lint: Auto-format code using black
parent 88d6f7f3f0
commit 684b4a2e13
@@ -10,67 +10,77 @@ config = aqt.mw.addonManager.getConfig(__name__)
 # TODO: force the user to log out before changing any of the settings


 def addui(self, _):
     self = self.form
     parent_w = self.tab_2
     parent_l = self.vboxlayout
     self.useCustomServer = QCheckBox(parent_w)
     self.useCustomServer.setText("Use custom sync server")
     parent_l.addWidget(self.useCustomServer)
     cshl = QHBoxLayout()
     parent_l.addLayout(cshl)

     self.serverAddrLabel = QLabel(parent_w)
     self.serverAddrLabel.setText("Server address")
     cshl.addWidget(self.serverAddrLabel)
     self.customServerAddr = QLineEdit(parent_w)
     self.customServerAddr.setPlaceholderText(DEFAULT_ADDR)
     cshl.addWidget(self.customServerAddr)

     pconfig = getprofileconfig()
     if pconfig["enabled"]:
         self.useCustomServer.setCheckState(Qt.Checked)
     if pconfig["addr"]:
         self.customServerAddr.setText(pconfig["addr"])

     self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text))

     def onchecked(state):
         pconfig["enabled"] = state == Qt.Checked
         updateui(self, state)
         updateserver(self, self.customServerAddr.text())

     self.useCustomServer.stateChanged.connect(onchecked)

     updateui(self, self.useCustomServer.checkState())


 def updateserver(self, text):
     pconfig = getprofileconfig()
-    if pconfig['enabled']:
+    if pconfig["enabled"]:
         addr = text or self.customServerAddr.placeholderText()
-        pconfig['addr'] = addr
+        pconfig["addr"] = addr
         setserver()
         aqt.mw.addonManager.writeConfig(__name__, config)


 def updateui(self, state):
     self.serverAddrLabel.setEnabled(state == Qt.Checked)
     self.customServerAddr.setEnabled(state == Qt.Checked)


 def setserver():
     pconfig = getprofileconfig()
-    if pconfig['enabled']:
-        aqt.mw.pm.profile['hostNum'] = None
-        anki.sync.SYNC_BASE = "%s" + pconfig['addr']
+    if pconfig["enabled"]:
+        aqt.mw.pm.profile["hostNum"] = None
+        anki.sync.SYNC_BASE = "%s" + pconfig["addr"]
     else:
         anki.sync.SYNC_BASE = anki.consts.SYNC_BASE


 def getprofileconfig():
     if aqt.mw.pm.name not in config["profiles"]:
         # inherit global settings if present (used in earlier versions of the addon)
         config["profiles"][aqt.mw.pm.name] = {
             "enabled": config.get("enabled", False),
             "addr": config.get("addr", DEFAULT_ADDR),
         }
         aqt.mw.addonManager.writeConfig(__name__, config)
     return config["profiles"][aqt.mw.pm.name]


 addHook("profileLoaded", setserver)
-aqt.preferences.Preferences.__init__ = wrap(aqt.preferences.Preferences.__init__, addui, "after")
+aqt.preferences.Preferences.__init__ = wrap(
+    aqt.preferences.Preferences.__init__, addui, "after"
+)
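For reference, getprofileconfig() above keys the add-on's JSON config by Anki profile name. A minimal sketch of the shape it reads and writes; the keys ("profiles", "enabled", "addr") come from the code, while the profile name and address values are made-up placeholders:

```python
# Hypothetical add-on config consumed by getprofileconfig().
config = {
    "profiles": {
        "User 1": {                          # Anki profile name (aqt.mw.pm.name)
            "enabled": True,                 # use the custom sync server
            "addr": "http://127.0.0.1:27701/",  # placeholder server address
        }
    }
}
```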
@@ -2,6 +2,7 @@ import sys

 if __package__ is None and not hasattr(sys, "frozen"):
     import os.path
+
     path = os.path.realpath(os.path.abspath(__file__))
     sys.path.insert(0, os.path.dirname(os.path.dirname(path)))
@@ -30,7 +30,7 @@ class CollectionWrapper:
         self.close()

     def execute(self, func, args=[], kw={}, waitForReturn=True):
-        """ Executes the given function with the underlying anki.storage.Collection
+        """Executes the given function with the underlying anki.storage.Collection
         object as the first argument and any additional arguments specified by *args
         and **kw.
@@ -92,6 +92,7 @@ class CollectionWrapper:
         """Returns True if the collection is open, False otherwise."""
         return self.__col is not None

+
 class CollectionManager:
     """Manages a set of CollectionWrapper objects."""
@@ -109,7 +110,9 @@ class CollectionManager:
         try:
             col = self.collections[path]
         except KeyError:
-            col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
+            col = self.collections[path] = self.collection_wrapper(
+                self.config, path, setup_new_collection
+            )

         return col
@@ -119,19 +122,25 @@ class CollectionManager:
             del self.collections[path]
             col.close()

-def get_collection_wrapper(config, path, setup_new_collection = None):
+
+def get_collection_wrapper(config, path, setup_new_collection=None):
     if "collection_wrapper" in config and config["collection_wrapper"]:
-        logger.info("Found collection_wrapper in config, using {} for "
-                    "user data persistence".format(config['collection_wrapper']))
+        logger.info(
+            "Found collection_wrapper in config, using {} for "
+            "user data persistence".format(config["collection_wrapper"])
+        )
         import importlib
         import inspect
-        module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
+
+        module_name, class_name = config["collection_wrapper"].rsplit(".", 1)
         module = importlib.import_module(module_name.strip())
         class_ = getattr(module, class_name.strip())

         if not CollectionWrapper in inspect.getmro(class_):
-            raise TypeError('''"collection_wrapper" found in the conf file but it doesn't
-                            inherit from CollectionWrapper''')
+            raise TypeError(
+                """"collection_wrapper" found in the conf file but it doesn't
+                inherit from CollectionWrapper"""
+            )
         return class_(config, path, setup_new_collection)
     else:
         return CollectionWrapper(config, path, setup_new_collection)
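The collection_wrapper hook above loads a class from a dotted path given in the conf file. A minimal standalone sketch of that mechanism; the dotted path in the usage comment is a made-up example:

```python
# Sketch of the dotted-path class loading used by get_collection_wrapper().
import importlib

def load_class(dotted_path):
    module_name, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name.strip())
    return getattr(module, class_name.strip())

# cls = load_class("mypackage.MyWrapper")  # must inherit from CollectionWrapper
```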
@@ -7,9 +7,9 @@ logger = logging.getLogger("ankisyncd")

 paths = [
     "/etc/ankisyncd/ankisyncd.conf",
-    os.environ.get("XDG_CONFIG_HOME") and
-    (os.path.join(os.environ['XDG_CONFIG_HOME'], "ankisyncd", "ankisyncd.conf")) or
-    os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"),
+    os.environ.get("XDG_CONFIG_HOME")
+    and (os.path.join(os.environ["XDG_CONFIG_HOME"], "ankisyncd", "ankisyncd.conf"))
+    or os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"),
     os.path.join(dirname(dirname(realpath(__file__))), "ankisyncd.conf"),
 ]
@@ -19,11 +19,12 @@ paths = [
 def load_from_env(conf):
     logger.debug("Loading/overriding config values from ENV")
     for env in os.environ:
-        if env.startswith('ANKISYNCD_'):
+        if env.startswith("ANKISYNCD_"):
             config_key = env[10:].lower()
             conf[config_key] = os.getenv(env)
             logger.info("Setting {} from ENV".format(config_key))

+
 def load(path=None):
     choices = paths
     parser = configparser.ConfigParser()
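load_from_env() maps environment variables onto config keys by stripping the ten-character ANKISYNCD_ prefix (hence env[10:]) and lowercasing the remainder. For example, with an illustrative variable name:

```python
# ANKISYNCD_PORT=27701 would set conf["port"] = "27701".
env = "ANKISYNCD_PORT"
config_key = env[10:].lower()  # -> "port"
```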
@@ -33,7 +34,7 @@ def load(path=None):
         logger.debug("config.location: trying", path)
         try:
             parser.read(path)
-            conf = parser['sync_app']
+            conf = parser["sync_app"]
             logger.info("Loaded config from {}".format(path))
             load_from_env(conf)
             return conf
@@ -13,6 +13,7 @@ from anki.collection import Collection
 logger = logging.getLogger("ankisyncd.media")
 logger.setLevel(1)

+
 class FullSyncManager:
     def test_db(self, db: DB):
         """
@@ -34,15 +35,14 @@ class FullSyncManager:
         # Verify integrity of the received database file before replacing our
         # existing db.
         temp_db_path = session.get_collection_path() + ".tmp"
-        with open(temp_db_path, 'wb') as f:
+        with open(temp_db_path, "wb") as f:
             f.write(data)

         try:
             with DB(temp_db_path) as test_db:
                 self.test_db(test_db)
         except sqlite.Error as e:
-            raise HTTPBadRequest("Uploaded collection database file is "
-                                 "corrupt.")
+            raise HTTPBadRequest("Uploaded collection database file is " "corrupt.")

         # Overwrite existing db.
         col.close()
@@ -69,7 +69,7 @@ class FullSyncManager:
         col.close(downgrade=True)
         db_path = session.get_collection_path()
         try:
-            with open(db_path, 'rb') as tmp:
+            with open(db_path, "rb") as tmp:
                 data = tmp.read()
         finally:
             col.reopen()
@@ -80,16 +80,21 @@ class FullSyncManager:


 def get_full_sync_manager(config):
-    if "full_sync_manager" in config and config["full_sync_manager"]: # load from config
+    if (
+        "full_sync_manager" in config and config["full_sync_manager"]
+    ): # load from config
         import importlib
         import inspect
-        module_name, class_name = config['full_sync_manager'].rsplit('.', 1)
+
+        module_name, class_name = config["full_sync_manager"].rsplit(".", 1)
         module = importlib.import_module(module_name.strip())
         class_ = getattr(module, class_name.strip())

         if not FullSyncManager in inspect.getmro(class_):
-            raise TypeError('''"full_sync_manager" found in the conf file but it doesn't
-                            inherit from FullSyncManager''')
+            raise TypeError(
+                """"full_sync_manager" found in the conf file but it doesn't
+                inherit from FullSyncManager"""
+            )
         return class_(config)
     else:
         return FullSyncManager()
@@ -12,6 +12,7 @@ from anki.media import MediaManager

 logger = logging.getLogger("ankisyncd.media")

+
 class ServerMediaManager(MediaManager):
     def __init__(self, col, server=True):
         super().__init__(col, server)
@@ -20,14 +21,16 @@ class ServerMediaManager(MediaManager):

     def addMedia(self, media_to_add):
         self._db.executemany(
-            "INSERT OR REPLACE INTO media VALUES (?,?,?)",
-            media_to_add
+            "INSERT OR REPLACE INTO media VALUES (?,?,?)", media_to_add
         )
         self._db.commit()

     def changes(self, lastUsn):
-        return self._db.execute("select fname,usn,csum from media order by usn desc limit ?", self.lastUsn() - lastUsn)
+        return self._db.execute(
+            "select fname,usn,csum from media order by usn desc limit ?",
+            self.lastUsn() - lastUsn,
+        )

     def connect(self):
         path = self.dir() + ".server.db"
         create = not os.path.exists(path)
@@ -43,20 +43,26 @@ class SqliteSessionManager(SimpleSessionManager):

         conn = self._conn()
         cursor = conn.cursor()
-        cursor.execute("SELECT * FROM sqlite_master "
-                       "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
-                       "AND tbl_name = 'session'")
+        cursor.execute(
+            "SELECT * FROM sqlite_master "
+            "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
+            "AND tbl_name = 'session'"
+        )
         res = cursor.fetchone()
         conn.close()
         if res is not None:
-            raise Exception("Outdated database schema, run utils/migrate_user_tables.py")
+            raise Exception(
+                "Outdated database schema, run utils/migrate_user_tables.py"
+            )

     def _conn(self):
         new = not os.path.exists(self.session_db_path)
         conn = sqlite.connect(self.session_db_path)
         if new:
             cursor = conn.cursor()
-            cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)")
+            cursor.execute(
+                "CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)"
+            )
         return conn

     # Default to using sqlite3 syntax but overridable for sub-classes using other
@@ -73,7 +79,9 @@ class SqliteSessionManager(SimpleSessionManager):
         conn = self._conn()
         cursor = conn.cursor()

-        cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,))
+        cursor.execute(
+            self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,)
+        )
         res = cursor.fetchone()

         if res is not None:
@@ -89,7 +97,9 @@ class SqliteSessionManager(SimpleSessionManager):
         conn = self._conn()
         cursor = conn.cursor()

-        cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,))
+        cursor.execute(
+            self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,)
+        )
         res = cursor.fetchone()

         if res is not None:
@@ -103,8 +113,10 @@ class SqliteSessionManager(SimpleSessionManager):
         conn = self._conn()
         cursor = conn.cursor()

-        cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
-            (hkey, session.skey, session.name, session.path))
+        cursor.execute(
+            "INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
+            (hkey, session.skey, session.name, session.path),
+        )

         conn.commit()
@@ -117,25 +129,35 @@ class SqliteSessionManager(SimpleSessionManager):
         cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,))
         conn.commit()


 def get_session_manager(config):
     if "session_db_path" in config and config["session_db_path"]:
-        logger.info("Found session_db_path in config, using SqliteSessionManager for auth")
-        return SqliteSessionManager(config['session_db_path'])
+        logger.info(
+            "Found session_db_path in config, using SqliteSessionManager for auth"
+        )
+        return SqliteSessionManager(config["session_db_path"])
     elif "session_manager" in config and config["session_manager"]: # load from config
-        logger.info("Found session_manager in config, using {} for persisting sessions".format(
-            config['session_manager'])
+        logger.info(
+            "Found session_manager in config, using {} for persisting sessions".format(
+                config["session_manager"]
+            )
+        )
         import importlib
         import inspect
-        module_name, class_name = config['session_manager'].rsplit('.', 1)
+
+        module_name, class_name = config["session_manager"].rsplit(".", 1)
         module = importlib.import_module(module_name.strip())
         class_ = getattr(module, class_name.strip())

         if not SimpleSessionManager in inspect.getmro(class_):
-            raise TypeError('''"session_manager" found in the conf file but it doesn't
-                            inherit from SimpleSessionManager''')
+            raise TypeError(
+                """"session_manager" found in the conf file but it doesn't
+                inherit from SimpleSessionManager"""
+            )
         return class_(config)
     else:
-        logger.warning("Neither session_db_path nor session_manager set, "
-                       "ankisyncd will lose sessions on application restart")
+        logger.warning(
+            "Neither session_db_path nor session_manager set, "
+            "ankisyncd will lose sessions on application restart"
+        )
         return SimpleSessionManager()
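get_session_manager() above picks its backend from the config mapping. A hedged sketch of the three branches; the path and class name below are placeholders:

```python
# Branch 1: session_db_path selects SqliteSessionManager.
config = {"session_db_path": "/var/lib/ankisyncd/session.db"}  # placeholder path
# Branch 2: session_manager is a dotted path, imported and instantiated;
# it must inherit from SimpleSessionManager.
# config = {"session_manager": "mypkg.MySessionManager"}  # made-up class
# Branch 3: neither key set -> SimpleSessionManager(), in-memory only,
# so sessions are lost on application restart.
```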
@@ -10,7 +10,7 @@ import random
 import requests
 import json
 import os
-from typing import List,Tuple
+from typing import List, Tuple

 from anki.db import DB, DBError
 from anki.utils import ids2str, intTime, platDesc, checksum, devMode
@@ -24,38 +24,43 @@ from anki.lang import ngettext
 # https://github.com/ankitects/anki/blob/04b1ca75599f18eb783a8bf0bdeeeb32362f4da0/rslib/src/sync/http_client.rs#L11
 SYNC_VER = 10
 # https://github.com/ankitects/anki/blob/cca3fcb2418880d0430a5c5c2e6b81ba260065b7/anki/consts.py#L50
-SYNC_ZIP_SIZE = int(2.5*1024*1024)
+SYNC_ZIP_SIZE = int(2.5 * 1024 * 1024)
 # https://github.com/ankitects/anki/blob/cca3fcb2418880d0430a5c5c2e6b81ba260065b7/anki/consts.py#L51
 SYNC_ZIP_COUNT = 25

 # syncing vars
 HTTP_TIMEOUT = 90
 HTTP_PROXY = None
-HTTP_BUF_SIZE = 64*1024
+HTTP_BUF_SIZE = 64 * 1024

 # Incremental syncing
 ##########################################################################


 class Syncer(object):
     def __init__(self, col, server=None):
         self.col = col
         self.server = server

     # new added functions related to Syncer:
     # these are removed from latest anki module
     ########################################################################
     def scm(self):
         """return schema"""
-        scm=self.col.db.scalar("select scm from col")
+        scm = self.col.db.scalar("select scm from col")
         return scm

     def increment_usn(self):
         """usn+1 in db"""
         self.col.db.execute("update col set usn = usn + 1")
-    def set_modified_time(self,now:int):
+
+    def set_modified_time(self, now: int):
         self.col.db.execute("update col set mod=?", now)
-    def set_last_sync(self,now:int):
+
+    def set_last_sync(self, now: int):
         self.col.db.execute("update col set ls = ?", now)
-    #########################################################################
+
+    #########################################################################
     def meta(self):
         return dict(
             mod=self.col.mod,
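The helpers above implement the USN bookkeeping the sync protocol relies on: locally modified rows carry usn = -1 until a sync stamps them, and the col table tracks the modification and last-sync times. A minimal sketch of what finish() (later in this file) does with these helpers; db and now_ms are stand-ins for the collection database handle and a millisecond timestamp:

```python
# Sketch of the col-table updates performed at the end of a sync.
def finish_sync(db, now_ms):
    db.execute("update col set mod=?", now_ms)   # set_modified_time()
    db.execute("update col set ls = ?", now_ms)  # set_last_sync()
    db.execute("update col set usn = usn + 1")   # increment_usn()
```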
@@ -64,30 +69,29 @@ class Syncer(object):
             ts=intTime(),
             musn=0,
             msg="",
-            cont=True
+            cont=True,
         )

     def changes(self):
         "Bundle up small objects."
-        d = dict(models=self.getModels(),
-                 decks=self.getDecks(),
-                 tags=self.getTags())
+        d = dict(models=self.getModels(), decks=self.getDecks(), tags=self.getTags())
         if self.lnewer:
-            d['conf'] = self.col.all_config()
-            d['crt'] = self.col.crt
+            d["conf"] = self.col.all_config()
+            d["crt"] = self.col.crt
         return d

     def mergeChanges(self, lchg, rchg):
         # then the other objects
-        self.mergeModels(rchg['models'])
-        self.mergeDecks(rchg['decks'])
-        if 'conf' in rchg:
-            self.mergeConf(rchg['conf'])
+        self.mergeModels(rchg["models"])
+        self.mergeDecks(rchg["decks"])
+        if "conf" in rchg:
+            self.mergeConf(rchg["conf"])
         # this was left out of earlier betas
-        if 'crt' in rchg:
-            self.col.crt = rchg['crt']
+        if "crt" in rchg:
+            self.col.crt = rchg["crt"]
         self.prepareToChunk()
-    # this fn was cloned from anki module(version 2.1.36)
+
+    # this fn was cloned from anki module(version 2.1.36)
     def basicCheck(self) -> bool:
         "Basic integrity check for syncing. True if ok."
         # cards without notes
@@ -118,9 +122,10 @@ select id from notes where mid = ?) limit 1"""
         ):
             return False
         return True

     def sanityCheck(self):
-        tables=["cards",
+        tables = [
+            "cards",
             "notes",
             "revlog",
             "graves",
@@ -130,17 +135,17 @@ select id from notes where mid = ?) limit 1"""
             "notetypes",
         ]
         for tb in tables:
-            if self.col.db.scalar(f'select null from {tb} where usn=-1'):
-                return f'table had usn=-1: {tb}'
+            if self.col.db.scalar(f"select null from {tb} where usn=-1"):
+                return f"table had usn=-1: {tb}"
         self.col.sched.reset()

         # return summary of deck
         # make sched.counts() equal to default [0,0,0]
         # to make sure sync normally if sched.counts()
         # are not equal between different clients due to
         # different deck selection
         return [
-            list([0,0,0]),
+            list([0, 0, 0]),
             self.col.db.scalar("select count() from cards"),
             self.col.db.scalar("select count() from notes"),
             self.col.db.scalar("select count() from revlog"),
@@ -155,7 +160,7 @@ select id from notes where mid = ?) limit 1"""

     def finish(self, now=None):
         if now is not None:
             # ensure we save the mod time even if no changes made
             self.set_modified_time(now)
             self.set_last_sync(now)
             self.increment_usn()
@@ -174,17 +179,29 @@ select id from notes where mid = ?) limit 1"""
     def queryTable(self, table):
         lim = self.usnLim()
         if table == "revlog":
-            return self.col.db.execute("""
+            return self.col.db.execute(
+                """
 select id, cid, ?, ease, ivl, lastIvl, factor, time, type
-from revlog where %s""" % lim, self.maxUsn)
+from revlog where %s"""
+                % lim,
+                self.maxUsn,
+            )
         elif table == "cards":
-            return self.col.db.execute("""
+            return self.col.db.execute(
+                """
 select id, nid, did, ord, mod, ?, type, queue, due, ivl, factor, reps,
-lapses, left, odue, odid, flags, data from cards where %s""" % lim, self.maxUsn)
+lapses, left, odue, odid, flags, data from cards where %s"""
+                % lim,
+                self.maxUsn,
+            )
         else:
-            return self.col.db.execute("""
+            return self.col.db.execute(
+                """
 select id, guid, mid, mod, ?, tags, flds, '', '', flags, data
-from notes where %s""" % lim, self.maxUsn)
+from notes where %s"""
+                % lim,
+                self.maxUsn,
+            )

     def chunk(self):
         buf = dict(done=False)
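Note the bare `?` placeholder inside each SELECT list: it is bound to self.maxUsn, so every row comes back already stamped with the new USN instead of the pending value. A hedged sketch; col and new_usn are stand-ins, and the WHERE clause assumes usnLim() expands to a usn filter such as usn = -1:

```python
rows = col.db.execute(
    "select id, cid, ?, ease, ivl, lastIvl, factor, time, type "
    "from revlog where usn = -1",
    new_usn,  # substituted for the usn column in every returned row
)
```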
@@ -195,95 +212,96 @@ from notes where %s""" % lim, self.maxUsn)
                 f"update {curTable} set usn=? where usn=-1", self.maxUsn
             )
         if not self.tablesLeft:
-            buf['done'] = True
+            buf["done"] = True
         return buf

     def applyChunk(self, chunk):
         if "revlog" in chunk:
-            self.mergeRevlog(chunk['revlog'])
+            self.mergeRevlog(chunk["revlog"])
         if "cards" in chunk:
-            self.mergeCards(chunk['cards'])
+            self.mergeCards(chunk["cards"])
         if "notes" in chunk:
-            self.mergeNotes(chunk['notes'])
+            self.mergeNotes(chunk["notes"])

     # Deletions
     ##########################################################################

-    def add_grave(self, ids: List[int], type: int,usn: int):
-        items=[(id,type,usn) for id in ids]
+    def add_grave(self, ids: List[int], type: int, usn: int):
+        items = [(id, type, usn) for id in ids]
         # make sure table graves fields order and schema version match
         # query sql1='pragma table_info(graves)' version query schema='select ver from col'
         self.col.db.executemany(
-            "INSERT OR IGNORE INTO graves (oid, type, usn) VALUES (?, ?, ?)" ,
-            items)
+            "INSERT OR IGNORE INTO graves (oid, type, usn) VALUES (?, ?, ?)", items
+        )

-    def apply_graves(self, graves,latest_usn: int):
+    def apply_graves(self, graves, latest_usn: int):
         # remove card and the card's orphaned notes
-        self.col.remove_cards_and_orphaned_notes(graves['cards'])
-        self.add_grave(graves['cards'], REM_CARD,latest_usn)
+        self.col.remove_cards_and_orphaned_notes(graves["cards"])
+        self.add_grave(graves["cards"], REM_CARD, latest_usn)
         # only notes
-        self.col.remove_notes(graves['notes'])
-        self.add_grave(graves['notes'], REM_NOTE,latest_usn)
+        self.col.remove_notes(graves["notes"])
+        self.add_grave(graves["notes"], REM_NOTE, latest_usn)

         # since level 0 deck ,we only remove deck ,but backend will delete child,it is ok, the delete
         # will have once effect
-        self.col.decks.remove(graves['decks'])
-        self.add_grave(graves['decks'], REM_DECK,latest_usn)
+        self.col.decks.remove(graves["decks"])
+        self.add_grave(graves["decks"], REM_DECK, latest_usn)

     # Models
     ##########################################################################

     def getModels(self):
-        mods = [m for m in self.col.models.all() if m['usn'] == -1]
+        mods = [m for m in self.col.models.all() if m["usn"] == -1]
         for m in mods:
-            m['usn'] = self.maxUsn
+            m["usn"] = self.maxUsn
         self.col.models.save()
         return mods

     def mergeModels(self, rchg):
         for r in rchg:
-            l = self.col.models.get(r['id'])
+            l = self.col.models.get(r["id"])
             # if missing locally or server is newer, update
-            if not l or r['mod'] > l['mod']:
+            if not l or r["mod"] > l["mod"]:
                 self.col.models.update(r)

     # Decks
     ##########################################################################

     def getDecks(self):
-        decks = [g for g in self.col.decks.all() if g['usn'] == -1]
+        decks = [g for g in self.col.decks.all() if g["usn"] == -1]
         for g in decks:
-            g['usn'] = self.maxUsn
-        dconf = [g for g in self.col.decks.allConf() if g['usn'] == -1]
+            g["usn"] = self.maxUsn
+        dconf = [g for g in self.col.decks.allConf() if g["usn"] == -1]
         for g in dconf:
-            g['usn'] = self.maxUsn
+            g["usn"] = self.maxUsn
         self.col.decks.save()
         return [decks, dconf]

     def mergeDecks(self, rchg):
         for r in rchg[0]:
-            l = self.col.decks.get(r['id'], False)
+            l = self.col.decks.get(r["id"], False)
             # work around mod time being stored as string
-            if l and not isinstance(l['mod'], int):
-                l['mod'] = int(l['mod'])
+            if l and not isinstance(l["mod"], int):
+                l["mod"] = int(l["mod"])

             # if missing locally or server is newer, update
-            if not l or r['mod'] > l['mod']:
+            if not l or r["mod"] > l["mod"]:
                 self.col.decks.update(r)
         for r in rchg[1]:
             try:
-                l = self.col.decks.getConf(r['id'])
+                l = self.col.decks.getConf(r["id"])
             except KeyError:
                 l = None
             # if missing locally or server is newer, update
-            if not l or r['mod'] > l['mod']:
+            if not l or r["mod"] > l["mod"]:
                 self.col.decks.updateConf(r)

     # Tags
     ##########################################################################
     def allItems(self) -> List[Tuple[str, int]]:
-        tags=self.col.db.execute("select tag, usn from tags")
-        return [(tag, int(usn)) for tag,usn in tags]
+        tags = self.col.db.execute("select tag, usn from tags")
+        return [(tag, int(usn)) for tag, usn in tags]

     def getTags(self):
         tags = []
         for t, usn in self.allItems():
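For reference, the graves payload consumed by apply_graves() groups deleted object ids by kind; a sketch with made-up ids:

```python
graves = {
    "cards": [1594800000000],  # placeholder card id
    "notes": [1594800000001],  # placeholder note id
    "decks": [1594800000002],  # placeholder deck id
}
# apply_graves(graves, latest_usn) removes each group locally and records a
# (oid, type, usn) grave row tagged REM_CARD, REM_NOTE, or REM_DECK.
```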
@@ -301,15 +319,16 @@ from notes where %s""" % lim, self.maxUsn)

     def mergeRevlog(self, logs):
         self.col.db.executemany(
-            "insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)",
-            logs)
+            "insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)", logs
+        )

     def newerRows(self, data, table, modIdx):
         ids = (r[0] for r in data)
         lmods = {}
         for id, mod in self.col.db.execute(
-            "select id, mod from %s where id in %s and %s" % (
-                table, ids2str(ids), self.usnLim())):
+            "select id, mod from %s where id in %s and %s"
+            % (table, ids2str(ids), self.usnLim())
+        ):
             lmods[id] = mod
         update = []
         for r in data:
@@ -323,14 +342,17 @@ from notes where %s""" % lim, self.maxUsn)
         self.col.db.executemany(
             "insert or replace into cards values "
             "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
-            self.newerRows(cards, "cards", 4))
+            self.newerRows(cards, "cards", 4),
+        )

     def mergeNotes(self, notes):
         rows = self.newerRows(notes, "notes", 3)
         self.col.db.executemany(
-            "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)",
-            rows)
-        self.col.after_note_updates([f[0] for f in rows], mark_modified=False, generate_cards=False)
+            "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", rows
+        )
+        self.col.after_note_updates(
+            [f[0] for f in rows], mark_modified=False, generate_cards=False
+        )

     # Col config
     ##########################################################################
@@ -341,11 +363,14 @@ from notes where %s""" % lim, self.maxUsn)
     def mergeConf(self, conf):
         for key, value in conf.items():
             self.col.set_config(key, value)

         # self.col.backend.set_all_config(json.dumps(conf).encode())


 # Wrapper for requests that tracks upload/download progress
 ##########################################################################


 class AnkiRequestsClient(object):
     verify = True
     timeout = 60
@@ -355,15 +380,23 @@ class AnkiRequestsClient(object):

     def post(self, url, data, headers):
         data = _MonitoringFile(data)
-        headers['User-Agent'] = self._agentName()
+        headers["User-Agent"] = self._agentName()
         return self.session.post(
-            url, data=data, headers=headers, stream=True, timeout=self.timeout, verify=self.verify)
+            url,
+            data=data,
+            headers=headers,
+            stream=True,
+            timeout=self.timeout,
+            verify=self.verify,
+        )

     def get(self, url, headers=None):
         if headers is None:
             headers = {}
-        headers['User-Agent'] = self._agentName()
-        return self.session.get(url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify)
+        headers["User-Agent"] = self._agentName()
+        return self.session.get(
+            url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify
+        )

     def streamContent(self, resp):
         resp.raise_for_status()
@@ -375,24 +408,30 @@ class AnkiRequestsClient(object):

     def _agentName(self):
         from anki import version

         return "Anki {}".format(version)


 # allow user to accept invalid certs in work/school settings
 if os.environ.get("ANKI_NOVERIFYSSL"):
     AnkiRequestsClient.verify = False

     import warnings

     warnings.filterwarnings("ignore")


 class _MonitoringFile(io.BufferedReader):
     def read(self, size=-1):
         data = io.BufferedReader.read(self, HTTP_BUF_SIZE)

         return data


 # HTTP syncing tools
 ##########################################################################


 class HttpSyncer(object):
     def __init__(self, hkey=None, client=None, hostNum=None):
         self.hkey = hkey
@@ -421,24 +460,29 @@ class HttpSyncer(object):
     # support file uploading, so this is the more compatible choice.

     def _buildPostData(self, fobj, comp):
-        BOUNDARY=b"Anki-sync-boundary"
-        bdry = b"--"+BOUNDARY
+        BOUNDARY = b"Anki-sync-boundary"
+        bdry = b"--" + BOUNDARY
         buf = io.BytesIO()
         # post vars
-        self.postVars['c'] = 1 if comp else 0
+        self.postVars["c"] = 1 if comp else 0
         for (key, value) in list(self.postVars.items()):
             buf.write(bdry + b"\r\n")
             buf.write(
-                ('Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n' %
-                 (key, value)).encode("utf8"))
+                (
+                    'Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n'
+                    % (key, value)
+                ).encode("utf8")
+            )
         # payload as raw data or json
         rawSize = 0
         if fobj:
             # header
             buf.write(bdry + b"\r\n")
-            buf.write(b"""\
+            buf.write(
+                b"""\
 Content-Disposition: form-data; name="data"; filename="data"\r\n\
-Content-Type: application/octet-stream\r\n\r\n""")
+Content-Type: application/octet-stream\r\n\r\n"""
+            )
             # write file into buffer, optionally compressing
             if comp:
                 tgt = gzip.GzipFile(mode="wb", fileobj=buf, compresslevel=comp)
@@ -453,16 +497,17 @@ Content-Type: application/octet-stream\r\n\r\n""")
                 rawSize += len(data)
                 tgt.write(data)
             buf.write(b"\r\n")
-        buf.write(bdry + b'--\r\n')
+        buf.write(bdry + b"--\r\n")
         size = buf.tell()
         # connection headers
         headers = {
-            'Content-Type': 'multipart/form-data; boundary=%s' % BOUNDARY.decode("utf8"),
-            'Content-Length': str(size),
+            "Content-Type": "multipart/form-data; boundary=%s"
+            % BOUNDARY.decode("utf8"),
+            "Content-Length": str(size),
         }
         buf.seek(0)

-        if size >= 100*1024*1024 or rawSize >= 250*1024*1024:
+        if size >= 100 * 1024 * 1024 or rawSize >= 250 * 1024 * 1024:
             raise Exception("Collection too large to upload to AnkiWeb.")

         return headers, buf
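A rough picture of the multipart body _buildPostData() assembles, reconstructed from the code above (the exact blank-line placement is schematic): each post var becomes a form-data part, the "c" field signals gzip compression, and the optional payload travels as a part named "data".

```
--Anki-sync-boundary
Content-Disposition: form-data; name="c"

1
--Anki-sync-boundary
Content-Disposition: form-data; name="data"; filename="data"
Content-Type: application/octet-stream

<raw or gzip-compressed payload>
--Anki-sync-boundary--
```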
@@ -470,7 +515,7 @@ Content-Type: application/octet-stream\r\n\r\n""")
     def req(self, method, fobj=None, comp=6, badAuthRaises=True):
         headers, body = self._buildPostData(fobj, comp)

-        r = self.client.post(self.syncURL()+method, data=body, headers=headers)
+        r = self.client.post(self.syncURL() + method, data=body, headers=headers)
         if not badAuthRaises and r.status_code == 403:
             return False
         self.assertOk(r)
@@ -478,9 +523,11 @@ Content-Type: application/octet-stream\r\n\r\n""")
         buf = self.client.streamContent(r)
         return buf


 # Incremental sync over HTTP
 ######################################################################


 class RemoteServer(HttpSyncer):
     def __init__(self, hkey, hostNum):
         super().__init__(self, hkey, hostNum=hostNum)
@@ -489,12 +536,14 @@ class RemoteServer(HttpSyncer):
         "Returns hkey or none if user/pw incorrect."
         self.postVars = dict()
         ret = self.req(
-            "hostKey", io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")),
-            badAuthRaises=False)
+            "hostKey",
+            io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")),
+            badAuthRaises=False,
+        )
         if not ret:
             # invalid auth
             return
-        self.hkey = json.loads(ret.decode("utf8"))['key']
+        self.hkey = json.loads(ret.decode("utf8"))["key"]
         return self.hkey

     def meta(self):
@@ -503,9 +552,17 @@ class RemoteServer(HttpSyncer):
             s=self.skey,
         )
         ret = self.req(
-            "meta", io.BytesIO(json.dumps(dict(
-                v=SYNC_VER, cv="ankidesktop,%s,%s"%(versionWithBuild(), platDesc()))).encode("utf8")),
-            badAuthRaises=False)
+            "meta",
+            io.BytesIO(
+                json.dumps(
+                    dict(
+                        v=SYNC_VER,
+                        cv="ankidesktop,%s,%s" % (versionWithBuild(), platDesc()),
+                    )
+                ).encode("utf8")
+            ),
+            badAuthRaises=False,
+        )
         if not ret:
             # invalid auth
             return
@@ -537,17 +594,20 @@ class RemoteServer(HttpSyncer):

     def _run(self, cmd, data):
         return json.loads(
-            self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8"))
+            self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8")
+        )


 # Full syncing
 ##########################################################################


 class FullSyncer(HttpSyncer):
     def __init__(self, col, hkey, client, hostNum):
         super().__init__(self, hkey, client, hostNum=hostNum)
         self.postVars = dict(
             k=self.hkey,
-            v="ankidesktop,%s,%s"%(anki.version, platDesc()),
+            v="ankidesktop,%s,%s" % (anki.version, platDesc()),
         )
         self.col = col
@@ -586,9 +646,11 @@ class FullSyncer(HttpSyncer):
             return False
         return True


 # Remote media syncing
 ##########################################################################


 class RemoteMediaServer(HttpSyncer):
     def __init__(self, col, hkey, client, hostNum):
         self.col = col
@@ -597,12 +659,12 @@ class RemoteMediaServer(HttpSyncer):

     def begin(self):
         self.postVars = dict(
-            k=self.hkey,
-            v="ankidesktop,%s,%s"%(anki.version, platDesc())
+            k=self.hkey, v="ankidesktop,%s,%s" % (anki.version, platDesc())
         )
-        ret = self._dataOnly(self.req(
-            "begin", io.BytesIO(json.dumps(dict()).encode("utf8"))))
-        self.skey = ret['sk']
+        ret = self._dataOnly(
+            self.req("begin", io.BytesIO(json.dumps(dict()).encode("utf8")))
+        )
+        self.skey = ret["sk"]
         return ret

     # args: lastUsn
@@ -611,7 +673,8 @@ class RemoteMediaServer(HttpSyncer):
             sk=self.skey,
         )
         return self._dataOnly(
-            self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8"))))
+            self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8")))
+        )

     # args: files
     def downloadFiles(self, **kw):
@@ -619,20 +682,20 @@ class RemoteMediaServer(HttpSyncer):

     def uploadChanges(self, zip):
         # no compression, as we compress the zip file instead
-        return self._dataOnly(
-            self.req("uploadChanges", io.BytesIO(zip), comp=0))
+        return self._dataOnly(self.req("uploadChanges", io.BytesIO(zip), comp=0))

     # args: local
     def mediaSanity(self, **kw):
         return self._dataOnly(
-            self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8"))))
+            self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8")))
+        )

     def _dataOnly(self, resp):
         resp = json.loads(resp.decode("utf8"))
-        if resp['err']:
-            self.col.log("error returned:%s"%resp['err'])
-            raise Exception("SyncError:%s"%resp['err'])
-        return resp['data']
+        if resp["err"]:
+            self.col.log("error returned:%s" % resp["err"])
+            raise Exception("SyncError:%s" % resp["err"])
+        return resp["data"]

     # only for unit tests
     def mediatest(self, cmd):
@@ -640,5 +703,7 @@ class RemoteMediaServer(HttpSyncer):
             k=self.hkey,
         )
         return self._dataOnly(
-            self.req("newMediaTest", io.BytesIO(
-                json.dumps(dict(cmd=cmd)).encode("utf8"))))
+            self.req(
+                "newMediaTest", io.BytesIO(json.dumps(dict(cmd=cmd)).encode("utf8"))
+            )
+        )
@@ -43,7 +43,16 @@ logger = logging.getLogger("ankisyncd")


 class SyncCollectionHandler(Syncer):
-    operations = ['meta', 'applyChanges', 'start', 'applyGraves', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']
+    operations = [
+        "meta",
+        "applyChanges",
+        "start",
+        "applyGraves",
+        "chunk",
+        "applyChunk",
+        "sanityCheck2",
+        "finish",
+    ]

     def __init__(self, col, session):
         # So that 'server' (the 3rd argument) can't get set
@@ -56,9 +65,9 @@ class SyncCollectionHandler(Syncer):
             return False

         note = {"alpha": 0, "beta": 0, "rc": 0}
-        client, version, platform = cv.split(',')
+        client, version, platform = cv.split(",")

-        if 'arch' not in version:
+        if "arch" not in version:
             for name in note.keys():
                 if name in version:
                     vs = version.split(name)
@@ -66,17 +75,17 @@ class SyncCollectionHandler(Syncer):
                     note[name] = int(vs[-1])

         # convert the version string, ignoring non-numeric suffixes like in beta versions of Anki
-        version_nosuffix = re.sub(r'[^0-9.].*$', '', version)
-        version_int = [int(x) for x in version_nosuffix.split('.')]
+        version_nosuffix = re.sub(r"[^0-9.].*$", "", version)
+        version_int = [int(x) for x in version_nosuffix.split(".")]

-        if client == 'ankidesktop':
+        if client == "ankidesktop":
             return version_int < [2, 0, 27]
-        elif client == 'ankidroid':
+        elif client == "ankidroid":
             if version_int == [2, 3]:
                 if note["alpha"]:
                     return note["alpha"] < 4
             else:
                 return version_int < [2, 2, 3]
         else:  # unknown client, assume current version
             return False
@@ -84,23 +93,33 @@ class SyncCollectionHandler(Syncer):
         if self._old_client(cv):
             return Response(status=501) # client needs upgrade
         if v > SYNC_VER:
-            return {"cont": False, "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(v, SYNC_VER)}
+            return {
+                "cont": False,
+                "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(
+                    v, SYNC_VER
+                ),
+            }
         if v < 9 and self.col.schedVer() >= 2:
-            return {"cont": False, "msg": "Your client doesn't support the v{} scheduler.".format(self.col.schedVer())}
+            return {
+                "cont": False,
+                "msg": "Your client doesn't support the v{} scheduler.".format(
+                    self.col.schedVer()
+                ),
+            }

         # Make sure the media database is open!
         self.col.media.connect()

         return {
-            'mod': self.col.mod,
-            'scm': self.scm(),
-            'usn': self.col.usn(),
-            'ts': anki.utils.intTime(),
-            'musn': self.col.media.lastUsn(),
-            'uname': self.session.name,
-            'msg': '',
-            'cont': True,
-            'hostNum': 0,
+            "mod": self.col.mod,
+            "scm": self.scm(),
+            "usn": self.col.usn(),
+            "ts": anki.utils.intTime(),
+            "musn": self.col.media.lastUsn(),
+            "uname": self.session.name,
+            "msg": "",
+            "cont": True,
+            "hostNum": 0,
         }

     def usnLim(self):
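A hedged example of a successful meta handshake response built by the dict above; the keys match the code, the values are placeholders:

```python
{
    "mod": 1594800000000,   # collection modification time
    "scm": 1594700000000,   # schema modification time
    "usn": 42,              # server update sequence number
    "ts": 1594800123,       # server timestamp
    "musn": 7,              # media USN
    "uname": "alice",       # placeholder session name
    "msg": "",
    "cont": True,
    "hostNum": 0,
}
```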
@@ -108,10 +127,16 @@ class SyncCollectionHandler(Syncer):

     # ankidesktop >=2.1rc2 sends graves in applyGraves, but still expects
     # server-side deletions to be returned by start
-    def start(self, minUsn, lnewer, graves={"cards": [], "notes": [], "decks": []}, offset=None):
+    def start(
+        self,
+        minUsn,
+        lnewer,
+        graves={"cards": [], "notes": [], "decks": []},
+        offset=None,
+    ):
         # The offset para is passed by client V2 scheduler,which is minutes_west.
         # Since now have not thorougly test the V2 scheduler, we leave this comments here, and
         # just enable the V2 scheduler in the serve code.

         self.maxUsn = self.col.usn()
         self.minUsn = minUsn
@@ -122,11 +147,11 @@ class SyncCollectionHandler(Syncer):
         # Only if Operations like deleting deck are performed on Ankidroid
         # can (client) graves is not None
         if graves is not None:
-            self.apply_graves(graves,self.maxUsn)
+            self.apply_graves(graves, self.maxUsn)
         return lgraves

     def applyGraves(self, chunk):
-        self.apply_graves(chunk,self.maxUsn)
+        self.apply_graves(chunk, self.maxUsn)

     def applyChanges(self, changes):
         self.rchg = changes
@@ -136,16 +161,13 @@ class SyncCollectionHandler(Syncer):
         return lchg

     def sanityCheck2(self, client):
-        client[0]=[0,0,0]
+        client[0] = [0, 0, 0]
         server = self.sanityCheck()
         if client != server:
-            logger.info(
-                f"sanity check failed with server: {server} client: {client}"
-            )
+            logger.info(f"sanity check failed with server: {server} client: {client}")
             return dict(status="bad", c=client, s=server)
         return dict(status="ok")

     def finish(self):
         return super().finish(anki.utils.intTime(1000))
@@ -158,7 +180,8 @@ class SyncCollectionHandler(Syncer):
         decks = []

         curs = self.col.db.execute(
-            "select oid, type from graves where usn >= ?", self.minUsn)
+            "select oid, type from graves where usn >= ?", self.minUsn
+        )

         for oid, type in curs:
             if type == REM_CARD:
@@ -171,20 +194,26 @@ class SyncCollectionHandler(Syncer):
         return dict(cards=cards, notes=notes, decks=decks)

     def getModels(self):
-        return [m for m in self.col.models.all() if m['usn'] >= self.minUsn]
+        return [m for m in self.col.models.all() if m["usn"] >= self.minUsn]

     def getDecks(self):
         return [
-            [g for g in self.col.decks.all() if g['usn'] >= self.minUsn],
-            [g for g in self.col.decks.all_config() if g['usn'] >= self.minUsn]
+            [g for g in self.col.decks.all() if g["usn"] >= self.minUsn],
+            [g for g in self.col.decks.all_config() if g["usn"] >= self.minUsn],
         ]

     def getTags(self):
-        return [t for t, usn in self.allItems()
-                if usn >= self.minUsn]
+        return [t for t, usn in self.allItems() if usn >= self.minUsn]


 class SyncMediaHandler:
-    operations = ['begin', 'mediaChanges', 'mediaSanity', 'uploadChanges', 'downloadFiles']
+    operations = [
+        "begin",
+        "mediaChanges",
+        "mediaSanity",
+        "uploadChanges",
+        "downloadFiles",
+    ]

     def __init__(self, col, session):
         self.col = col
@@ -192,11 +221,11 @@ class SyncMediaHandler:

     def begin(self, skey):
         return {
-            'data': {
-                'sk': skey,
-                'usn': self.col.media.lastUsn(),
+            "data": {
+                "sk": skey,
+                "usn": self.col.media.lastUsn(),
             },
-            'err': '',
+            "err": "",
         }

     def uploadChanges(self, data):
@@ -210,24 +239,27 @@ class SyncMediaHandler:
         processed_count = self._adopt_media_changes_from_zip(z)

         return {
-            'data': [processed_count, self.col.media.lastUsn()],
-            'err': '',
+            "data": [processed_count, self.col.media.lastUsn()],
+            "err": "",
         }

     @staticmethod
     def _check_zip_data(zip_file):
-        max_zip_size = 100*1024*1024
+        max_zip_size = 100 * 1024 * 1024
         max_meta_file_size = 100000

         meta_file_size = zip_file.getinfo("_meta").file_size
         sum_file_sizes = sum(info.file_size for info in zip_file.infolist())

         if meta_file_size > max_meta_file_size:
-            raise ValueError("Zip file's metadata file is larger than %s "
-                             "Bytes." % max_meta_file_size)
+            raise ValueError(
+                "Zip file's metadata file is larger than %s "
+                "Bytes." % max_meta_file_size
+            )
         elif sum_file_sizes > max_zip_size:
-            raise ValueError("Zip file contents are larger than %s Bytes." %
-                             max_zip_size)
+            raise ValueError(
+                "Zip file contents are larger than %s Bytes." % max_zip_size
+            )

     def _adopt_media_changes_from_zip(self, zip_file):
         """
@@ -261,7 +293,7 @@ class SyncMediaHandler:
             file_path = os.path.join(media_dir, filename)

             # Save file to media directory.
-            with open(file_path, 'wb') as f:
+            with open(file_path, "wb") as f:
                 f.write(file_data)

             usn += 1
@@ -279,7 +311,9 @@ class SyncMediaHandler:
         if media_to_add:
             self.col.media.addMedia(media_to_add)

-        assert self.col.media.lastUsn() == oldUsn + processed_count # TODO: move to some unit test
+        assert (
+            self.col.media.lastUsn() == oldUsn + processed_count
+        )  # TODO: move to some unit test
         return processed_count

     @staticmethod
@@ -302,13 +336,15 @@ class SyncMediaHandler:
         Marks all files in list filenames as deleted and removes them from the
         media directory.
         """
-        logger.debug('Removing %d files from media dir.' % len(filenames))
+        logger.debug("Removing %d files from media dir." % len(filenames))
         for filename in filenames:
             try:
                 self.col.media.syncDelete(filename)
             except OSError as err:
-                logger.error("Error when removing file '%s' from media dir: "
-                             "%s" % (filename, str(err)))
+                logger.error(
+                    "Error when removing file '%s' from media dir: "
+                    "%s" % (filename, str(err))
+                )

     def downloadFiles(self, files):
         flist = {}
@@ -334,13 +370,17 @@ class SyncMediaHandler:
         server_lastUsn = self.col.media.lastUsn()

         if lastUsn < server_lastUsn or lastUsn == 0:
-            for fname,usn,csum, in self.col.media.changes(lastUsn):
+            for (
+                fname,
+                usn,
+                csum,
+            ) in self.col.media.changes(lastUsn):
                 result.append([fname, usn, csum])
             # anki assumes server_lastUsn == result[-1][1]
             # ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
             result.reverse()

-        return {'data': result, 'err': ''}
+        return {"data": result, "err": ""}

     def mediaSanity(self, local=None):
         if self.col.media.mediaCount() == local:
@@ -348,7 +388,8 @@ class SyncMediaHandler:
         else:
             result = "FAILED"

-        return {'data': result, 'err': ''}
+        return {"data": result, "err": ""}


 class SyncUserSession:
     def __init__(self, name, path, collection_manager, setup_new_collection=None):
@@ -371,16 +412,18 @@ class SyncUserSession:
         return anki.utils.checksum(str(random.random()))[:8]

     def get_collection_path(self):
-        return os.path.realpath(os.path.join(self.path, 'collection.anki2'))
+        return os.path.realpath(os.path.join(self.path, "collection.anki2"))

     def get_thread(self):
-        return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)
+        return self.collection_manager.get_collection(
+            self.get_collection_path(), self.setup_new_collection
+        )

     def get_handler_for_operation(self, operation, col):
         if operation in SyncCollectionHandler.operations:
-            attr, handler_class = 'collection_handler', SyncCollectionHandler
+            attr, handler_class = "collection_handler", SyncCollectionHandler
         elif operation in SyncMediaHandler.operations:
-            attr, handler_class = 'media_handler', SyncMediaHandler
+            attr, handler_class = "media_handler", SyncMediaHandler
         else:
             raise Exception("no handler for {}".format(operation))
@ -391,146 +434,178 @@ class SyncUserSession:
|
||||
# for inactivity and then later re-open it (creating a new Collection object).
|
||||
handler.col = col
|
||||
return handler
|
||||
|
||||
|
||||
class Requests(object):
|
||||
def __init__(self,environ: dict):
|
||||
self.environ=environ
|
||||
def __init__(self, environ: dict):
|
||||
self.environ = environ
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self.request_items_dict
|
||||
|
||||
@params.setter
|
||||
def params(self,value):
|
||||
def params(self, value):
|
||||
"""
|
||||
A dictionary-like object containing both the parameters from
|
||||
the query string and request body.
|
||||
"""
|
||||
self.request_items_dict= value
|
||||
self.request_items_dict = value
|
||||
|
||||
@property
|
||||
def path(self)-> str:
|
||||
        return self.environ['PATH_INFO']
    def path(self) -> str:
        return self.environ["PATH_INFO"]

    @property
    def POST(self):
        return self._request_items_dict

    @POST.setter
    def POST(self,value):
        self._request_items_dict=value
    def POST(self, value):
        self._request_items_dict = value

    @property
    def parse(self):
        '''Return a MultiDict containing all the variables from a form
        """Return a MultiDict containing all the variables from a form
        request.\n
        include not only post req,but also get'''
        include not only post req,but also get"""
        env = self.environ
        query_string=env['QUERY_STRING']
        content_len= env.get('CONTENT_LENGTH', '0')
        input = env.get('wsgi.input')
        length = 0 if content_len == '' else int(content_len)
        body=b''
        request_items_dict={}
        query_string = env["QUERY_STRING"]
        content_len = env.get("CONTENT_LENGTH", "0")
        input = env.get("wsgi.input")
        length = 0 if content_len == "" else int(content_len)
        body = b""
        request_items_dict = {}
        if length == 0:
            if input is None:
                return request_items_dict
            if env.get('HTTP_TRANSFER_ENCODING','0') == 'chunked':
            if env.get("HTTP_TRANSFER_ENCODING", "0") == "chunked":
                # readlines and read(no argument) will block
                # convert byte str to number base 16
                leng=int(input.readline(),16)
                c=0
                bdry=b''
                data=[]
                data_other=[]
                while leng >0:
                    c+=1
                    dt = input.read(leng+2)
                    if c==1:
                        bdry=dt
                    elif c>=3:
                leng = int(input.readline(), 16)
                c = 0
                bdry = b""
                data = []
                data_other = []
                while leng > 0:
                    c += 1
                    dt = input.read(leng + 2)
                    if c == 1:
                        bdry = dt
                    elif c >= 3:
                        # data
                        data_other.append(dt)
                    leng = int(input.readline(),16)
                data_other=[item for item in data_other if item!=b'\r\n\r\n']
                    leng = int(input.readline(), 16)
                data_other = [item for item in data_other if item != b"\r\n\r\n"]
                for item in data_other:
                    if bdry in item:
                        break
                    # only strip \r\n if there are extra \n
                    # eg b'?V\xc1\x8f>\xf9\xb1\n\r\n'
                    data.append(item[:-2])
                request_items_dict['data']=b''.join(data)
                others=data_other[len(data):]
                boundary=others[0]
                others=b''.join(others).split(boundary.strip())
                request_items_dict["data"] = b"".join(data)
                others = data_other[len(data) :]
                boundary = others[0]
                others = b"".join(others).split(boundary.strip())
                others.pop()
                others.pop(0)
                for i in others:
                    i=i.splitlines()
                    key=re.findall(b'name="(.*?)"',i[2],flags=re.M)[0].decode('utf-8')
                    v=i[-1].decode('utf-8')
                    request_items_dict[key]=v
                    i = i.splitlines()
                    key = re.findall(b'name="(.*?)"', i[2], flags=re.M)[0].decode(
                        "utf-8"
                    )
                    v = i[-1].decode("utf-8")
                    request_items_dict[key] = v
                return request_items_dict

            if query_string !='':

            if query_string != "":
                # GET method
                body=query_string
                request_items_dict=urllib.parse.parse_qs(body)
                for k,v in request_items_dict.items():
                    request_items_dict[k]=''.join(v)
                body = query_string
                request_items_dict = urllib.parse.parse_qs(body)
                for k, v in request_items_dict.items():
                    request_items_dict[k] = "".join(v)
                return request_items_dict

        else:
            body = env['wsgi.input'].read(length)

        if body is None or body ==b'':
            body = env["wsgi.input"].read(length)

        if body is None or body == b"":
            return request_items_dict
        # process body to dict
        repeat=body.splitlines()[0]
        items=re.split(repeat,body)
        repeat = body.splitlines()[0]
        items = re.split(repeat, body)
        # del first ,last item
        items.pop()
        items.pop(0)
        for item in items:
            if b'name="data"' in item:
                data_field=None
                # remove \r\n
                if b'application/octet-stream' in item:
                data_field = None
                # remove \r\n
                if b"application/octet-stream" in item:
                    # Ankidroid case
                    item=re.sub(b'Content-Disposition: form-data; name="data"; filename="data"',b'',item)
                    item=re.sub(b'Content-Type: application/octet-stream',b'',item)
                    data_field=item.strip()
                    item = re.sub(
                        b'Content-Disposition: form-data; name="data"; filename="data"',
                        b"",
                        item,
                    )
                    item = re.sub(b"Content-Type: application/octet-stream", b"", item)
                    data_field = item.strip()
                else:
                    # PKzip file stream and others
                    item=re.sub(b'Content-Disposition: form-data; name="data"; filename="data"',b'',item)
                    data_field=item.strip()
                request_items_dict['data']=data_field
                    item = re.sub(
                        b'Content-Disposition: form-data; name="data"; filename="data"',
                        b"",
                        item,
                    )
                    data_field = item.strip()
                request_items_dict["data"] = data_field
                continue
            item=re.sub(b'\r\n',b'',item,flags=re.MULTILINE)
            key=re.findall(b'name="(.*?)"',item)[0].decode('utf-8')
            v=item[item.rfind(b'"')+1:].decode('utf-8')
            request_items_dict[key]=v
            item = re.sub(b"\r\n", b"", item, flags=re.MULTILINE)
            key = re.findall(b'name="(.*?)"', item)[0].decode("utf-8")
            v = item[item.rfind(b'"') + 1 :].decode("utf-8")
            request_items_dict[key] = v
        return request_items_dict

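For context on the chunked branch of `parse` above: an HTTP/1.1 chunked body arrives as a sequence of frames, each a hexadecimal length line followed by that many bytes and a CRLF, terminated by a zero-length chunk. That is why the code reads a length with `int(input.readline(), 16)` and then `input.read(leng + 2)` per iteration. A minimal, self-contained sketch of that framing (the helper name and sample bytes are illustrative, not part of this codebase):

import io

def read_chunked(stream):
    # each chunk is "<hex length>\r\n<data>\r\n"; a 0-length chunk ends the body
    body = b""
    while True:
        length = int(stream.readline().strip(), 16)
        if length == 0:
            stream.readline()  # consume the blank line after the terminator
            return body
        body += stream.read(length)
        stream.read(2)  # consume the CRLF trailing each chunk

sample = io.BytesIO(b"4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n")
assert read_chunked(sample) == b"Wikipedia"
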
class chunked(object):
    '''decorator'''
    """decorator"""

    def __init__(self, func):
        wraps(func)(self)

    def __call__(self, *args, **kwargs):
        clss=args[0]
        environ=args[1]
        clss = args[0]
        environ = args[1]
        start_response = args[2]
        b=Requests(environ)
        args=(clss,b,)
        w= self.__wrapped__(*args, **kwargs)
        resp=Response(w)
        b = Requests(environ)
        args = (
            clss,
            b,
        )
        w = self.__wrapped__(*args, **kwargs)
        resp = Response(w)
        return resp(environ, start_response)

    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            return types.MethodType(self, instance)

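A note on `__get__` above: an instance of `chunked` is a callable object, not a function, so Python's automatic method binding (which works through the descriptor protocol on functions) does not apply to it. Returning `types.MethodType(self, instance)` restores that binding by hand. A small runnable sketch of the same trick, under an assumed decorator name (`logged` is illustrative):

import types
from functools import wraps

class logged:
    def __init__(self, func):
        wraps(func)(self)  # copies metadata and sets self.__wrapped__

    def __call__(self, *args, **kwargs):
        print("calling", self.__wrapped__.__name__)
        return self.__wrapped__(*args, **kwargs)

    def __get__(self, instance, cls):
        # without this, Greeter().hello would be the bare `logged` object
        # and `self` would never be bound
        if instance is None:
            return self
        return types.MethodType(self, instance)

class Greeter:
    @logged
    def hello(self, name):
        return "hello " + name

assert Greeter().hello("world") == "hello world"
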
class SyncApp:
    valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
    valid_urls = (
        SyncCollectionHandler.operations
        + SyncMediaHandler.operations
        + ["hostKey", "upload", "download"]
    )

    def __init__(self, config):
        from ankisyncd.thread import get_collection_manager

        self.data_root = os.path.abspath(config['data_root'])
        self.base_url = config['base_url']
        self.base_media_url = config['base_media_url']
        self.data_root = os.path.abspath(config["data_root"])
        self.base_url = config["base_url"]
        self.base_media_url = config["base_media_url"]
        self.setup_new_collection = None

        self.user_manager = get_user_manager(config)
@ -539,22 +614,31 @@ class SyncApp:
        self.collection_manager = get_collection_manager(config)

        # make sure the base_url has a trailing slash
        if not self.base_url.endswith('/'):
            self.base_url += '/'
        if not self.base_media_url.endswith('/'):
            self.base_media_url += '/'
        if not self.base_url.endswith("/"):
            self.base_url += "/"
        if not self.base_media_url.endswith("/"):
            self.base_media_url += "/"

    def generateHostKey(self, username):
        """Generates a new host key to be used by the given username to identify their session.
        This value is random."""

        import hashlib, time, random, string

        chars = string.ascii_letters + string.digits
        val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]).encode()
        val = ":".join(
            [
                username,
                str(int(time.time())),
                "".join(random.choice(chars) for x in range(8)),
            ]
        ).encode()
        return hashlib.md5(val).hexdigest()

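As the docstring says, the host key is an opaque random session token: username, a Unix timestamp, and eight random alphanumeric characters joined with ':' and hashed. A standalone sketch of an equivalent construction; substituting the stdlib `secrets` module for `random.choice` is our assumption, not what the server ships:

import hashlib
import secrets
import string
import time

def generate_host_key(username):
    chars = string.ascii_letters + string.digits
    nonce = "".join(secrets.choice(chars) for _ in range(8))
    val = ":".join([username, str(int(time.time())), nonce]).encode()
    return hashlib.md5(val).hexdigest()  # opaque 32-char token, not a password hash

print(generate_host_key("testuser"))
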
    def create_session(self, username, user_path):
        return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)
        return SyncUserSession(
            username, user_path, self.collection_manager, self.setup_new_collection
        )

    def _decode_data(self, data, compression=0):
        if compression:
@ -564,7 +648,7 @@ class SyncApp:
        try:
            data = json.loads(data.decode())
        except (ValueError, UnicodeDecodeError):
            data = {'data': data}
            data = {"data": data}

        return data

@ -581,7 +665,7 @@ class SyncApp:
        session = self.create_session(username, user_path)
        self.session_manager.save(hkey, session)

        return {'key': hkey}
        return {"key": hkey}

    def operation_upload(self, col, data, session):
        # Verify integrity of the received database file before replacing our
@ -599,56 +683,56 @@ class SyncApp:
        # a cgi file can only be read once, and blocks if read a second time,
        # so we call Requests.parse only once and bind its return value to the
        # POST and params properties
        req.params=req.parse
        req.POST=req.params
        req.params = req.parse
        req.POST = req.params
        try:
            hkey = req.params['k']
            hkey = req.params["k"]
        except KeyError:
            hkey = None
        session = self.session_manager.load(hkey, self.create_session)
        if session is None:
            try:
                skey = req.POST['sk']
                skey = req.POST["sk"]
                session = self.session_manager.load_from_skey(skey, self.create_session)
            except KeyError:
                skey = None

        try:
            compression = int(req.POST['c'])
            compression = int(req.POST["c"])
        except KeyError:
            compression = 0

        try:
            data = req.POST['data']
            data = req.POST["data"]
            data = self._decode_data(data, compression)
        except KeyError:
            data = {}

        if req.path.startswith(self.base_url):
            url = req.path[len(self.base_url):]
            url = req.path[len(self.base_url) :]
            if url not in self.valid_urls:
                raise HTTPNotFound()

            if url == 'hostKey':
            if url == "hostKey":
                result = self.operation_hostKey(data.get("u"), data.get("p"))
                if result:
                    return json.dumps(result)
                else:
                    # TODO: do I have to pass 'null' for the client to receive None?
                    raise HTTPForbidden('null')
                    raise HTTPForbidden("null")

            if session is None:
                raise HTTPForbidden()

            if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
                # 'meta' passes the SYNC_VER but it isn't used in the handler
                if url == 'meta':
                    if session.skey == None and 's' in req.POST:
                        session.skey = req.POST['s']
                    if 'v' in data:
                        session.version = data['v']
                    if 'cv' in data:
                        session.client_version = data['cv']
                if url == "meta":
                    if session.skey == None and "s" in req.POST:
                        session.skey = req.POST["s"]
                    if "v" in data:
                        session.version = data["v"]
                    if "cv" in data:
                        session.client_version = data["cv"]

                    self.session_manager.save(hkey, session)
                    session = self.session_manager.load(hkey, self.create_session)
@ -660,12 +744,12 @@ class SyncApp:

                return result

            elif url == 'upload':
            elif url == "upload":
                thread = session.get_thread()
                result = thread.execute(self.operation_upload, [data['data'], session])
                result = thread.execute(self.operation_upload, [data["data"], session])
                return result

            elif url == 'download':
            elif url == "download":
                thread = session.get_thread()
                result = thread.execute(self.operation_download, [session])
                return result
@ -678,13 +762,13 @@ class SyncApp:
            if session is None:
                raise HTTPForbidden()

            url = req.path[len(self.base_media_url):]
            url = req.path[len(self.base_media_url) :]

            if url not in self.valid_urls:
                raise HTTPNotFound()

            if url == "begin":
                data['skey'] = session.skey
                data["skey"] = session.skey

            result = self._execute_handler_method_in_thread(url, data, session)

@ -726,10 +810,16 @@ class SyncApp:
def make_app(global_conf, **local_conf):
    return SyncApp(**local_conf)


def main():
    logging.basicConfig(level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s")
    logging.basicConfig(
        level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s"
    )
    import ankisyncd
    logger.info("ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage))

    logger.info(
        "ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage)
    )
    from wsgiref.simple_server import make_server, WSGIRequestHandler
    from ankisyncd.thread import shutdown
    import ankisyncd.config
@ -738,10 +828,10 @@ def main():
        logger = logging.getLogger("ankisyncd.http")

        def log_error(self, format, *args):
            self.logger.error("%s %s", self.address_string(), format%args)
            self.logger.error("%s %s", self.address_string(), format % args)

        def log_message(self, format, *args):
            self.logger.info("%s %s", self.address_string(), format%args)
            self.logger.info("%s %s", self.address_string(), format % args)

    if len(sys.argv) > 1:
        # backwards compat
@ -750,7 +840,9 @@ def main():
    config = ankisyncd.config.load()

    ankiserver = SyncApp(config)
    httpd = make_server(config['host'], int(config['port']), ankiserver, handler_class=RequestHandler)
    httpd = make_server(
        config["host"], int(config["port"]), ankiserver, handler_class=RequestHandler
    )

    try:
        logger.info("Serving HTTP on {} port {}...".format(*httpd.server_address))
@ -760,5 +852,6 @@ def main():
    finally:
        shutdown()

if __name__ == '__main__':

if __name__ == "__main__":
    main()

@ -5,6 +5,7 @@ from queue import Queue

import time, logging


def short_repr(obj, logger=logging.getLogger(), maxlen=80):
    """Like repr, but shortens strings and bytestrings if logger's logging level
    is above DEBUG. Currently shallow and very limited, only implemented for
@ -28,6 +29,7 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80):

    return repr(o)


class ThreadingCollectionWrapper:
    """Provides the same interface as CollectionWrapper, but it creates a new Thread to
    interact with the collection."""
@ -56,10 +58,11 @@ class ThreadingCollectionWrapper:

    def current(self):
        from threading import current_thread

        return current_thread() == self._thread

    def execute(self, func, args=[], kw={}, waitForReturn=True):
        """ Executes a given function on this thread with the *args and **kw.
        """Executes a given function on this thread with the *args and **kw.

        If 'waitForReturn' is True, then it will block until the function has
        executed and return its return value. If False, it will return None
@ -86,19 +89,30 @@ class ThreadingCollectionWrapper:
        while self._running:
            func, args, kw, return_queue = self._queue.get(True)

            if hasattr(func, '__name__'):
            if hasattr(func, "__name__"):
                func_name = func.__name__
            else:
                func_name = func.__class__.__name__

            self.logger.info("Running %s(*%s, **%s)", func_name, short_repr(args, self.logger), short_repr(kw, self.logger))
            self.logger.info(
                "Running %s(*%s, **%s)",
                func_name,
                short_repr(args, self.logger),
                short_repr(kw, self.logger),
            )
            self.last_timestamp = time.time()

            try:
                ret = self.wrapper.execute(func, args, kw, return_queue)
            except Exception as e:
                self.logger.error("Unable to %s(*%s, **%s): %s",
                                  func_name, repr(args), repr(kw), e, exc_info=True)
                self.logger.error(
                    "Unable to %s(*%s, **%s): %s",
                    func_name,
                    repr(args),
                    repr(kw),
                    e,
                    exc_info=True,
                )
                # we return the Exception which will be raise'd on the other end
                ret = e

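The `_run` loop above is a single-consumer worker pattern: jobs arrive on a queue as (func, args, kw, return_queue) tuples, and when a job raises, the exception object itself is pushed onto the return queue so the calling thread can re-raise it. A reduced, runnable sketch of the same pattern with illustrative names:

import queue
import threading

def worker_loop(jobs):
    while True:
        func, args, ret_q = jobs.get()
        try:
            ret_q.put(func(*args))
        except Exception as e:
            ret_q.put(e)  # handed back to the caller instead of a result

jobs = queue.Queue()
threading.Thread(target=worker_loop, args=(jobs,), daemon=True).start()

ret = queue.Queue()
jobs.put((sum, ([1, 2, 3],), ret))
print(ret.get())  # 6
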
@ -125,10 +139,11 @@ class ThreadingCollectionWrapper:
    def stop(self):
        def _stop(col):
            self._running = False

        self.execute(_stop, waitForReturn=False)

    def stop_and_wait(self):
        """ Tell the thread to stop and wait for it to happen. """
        """Tell the thread to stop and wait for it to happen."""
        self.stop()
        if self._thread is not None:
            self._thread.join()
@ -146,11 +161,13 @@ class ThreadingCollectionWrapper:

        def _close(col):
            self.wrapper.close()

        self.execute(_close, waitForReturn=False)

    def opened(self):
        return self.wrapper.opened()


class ThreadingCollectionManager(CollectionManager):
    """Manages a set of ThreadingCollectionWrapper objects."""

@ -174,14 +191,21 @@ class ThreadingCollectionManager(CollectionManager):
    # TODO: it would be awesome to have a safe way to stop inactive threads completely!
    # TODO: we need a way to inform other code that the collection has been closed
    def _monitor_run(self):
        """ Monitors threads for inactivity and closes the collection on them
        """Monitors threads for inactivity and closes the collection on them
        (leaves the thread itself running -- hopefully waiting peacefully with only a
        small memory footprint!) """
        small memory footprint!)"""
        while True:
            cur = time.time()
            for path, thread in self.collections.items():
                if thread.running and thread.wrapper.opened() and thread.qempty() and cur - thread.last_timestamp >= self.monitor_inactivity:
                    self.logger.info("Monitor is closing collection on inactive %s", thread)
                if (
                    thread.running
                    and thread.wrapper.opened()
                    and thread.qempty()
                    and cur - thread.last_timestamp >= self.monitor_inactivity
                ):
                    self.logger.info(
                        "Monitor is closing collection on inactive %s", thread
                    )
                    thread.close()
            time.sleep(self.monitor_frequency)

@ -196,12 +220,14 @@ class ThreadingCollectionManager(CollectionManager):
        # let the parent do whatever else it might want to do...
        super(ThreadingCollectionManager, self).shutdown()


#
# For working with the global ThreadingCollectionManager:
#

collection_manager = None


def get_collection_manager(config):
    """Return the global ThreadingCollectionManager for this process."""
    global collection_manager
@ -209,10 +235,10 @@ def get_collection_manager(config):
        collection_manager = ThreadingCollectionManager(config)
    return collection_manager


def shutdown():
    """If the global ThreadingCollectionManager exists, shut it down."""
    global collection_manager
    if collection_manager is not None:
        collection_manager.shutdown()
        collection_manager = None


@ -11,7 +11,7 @@ logger = logging.getLogger("ankisyncd.users")
class SimpleUserManager:
    """A simple user manager that always allows any user."""

    def __init__(self, collection_path=''):
    def __init__(self, collection_path=""):
        self.collection_path = collection_path

    def authenticate(self, username, password):
@ -34,8 +34,11 @@ class SimpleUserManager:
    def _create_user_dir(self, username):
        user_dir_path = os.path.join(self.collection_path, username)
        if not os.path.isdir(user_dir_path):
            logger.info("Creating collection directory for user '{}' at {}"
                        .format(username, user_dir_path))
            logger.info(
                "Creating collection directory for user '{}' at {}".format(
                    username, user_dir_path
                )
            )
            os.makedirs(user_dir_path)


@ -54,13 +57,17 @@ class SqliteUserManager(SimpleUserManager):

        conn = self._conn()
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM sqlite_master "
                       "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
                       "AND tbl_name = 'auth'")
        cursor.execute(
            "SELECT * FROM sqlite_master "
            "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
            "AND tbl_name = 'auth'"
        )
        res = cursor.fetchone()
        conn.close()
        if res is not None:
            raise Exception("Outdated database schema, run utils/migrate_user_tables.py")
            raise Exception(
                "Outdated database schema, run utils/migrate_user_tables.py"
            )

    # Default to using sqlite3 but overridable for sub-classes using other
    # DB API 2 driver variants
@ -124,8 +131,7 @@ class SqliteUserManager(SimpleUserManager):
        conn = self._conn()
        cursor = conn.cursor()
        logger.info("Adding user '{}' to auth db.".format(username))
        cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"),
                       (username, pass_hash))
        cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"), (username, pass_hash))
        conn.commit()
        conn.close()

@ -139,7 +145,9 @@ class SqliteUserManager(SimpleUserManager):

        conn = self._conn()
        cursor = conn.cursor()
        cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username))
        cursor.execute(
            self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username)
        )
        conn.commit()
        conn.close()

@ -156,8 +164,9 @@ class SqliteUserManager(SimpleUserManager):
        conn.close()

        if db_hash is None:
            logger.info("Authentication failed for nonexistent user {}."
                        .format(username))
            logger.info(
                "Authentication failed for nonexistent user {}.".format(username)
            )
            return False

        expected_value = str(db_hash[0])
@ -181,17 +190,22 @@ class SqliteUserManager(SimpleUserManager):
    @staticmethod
    def _create_pass_hash(username, password):
        salt = binascii.b2a_hex(os.urandom(8))
        pass_hash = (hashlib.sha256((username + password).encode() + salt).hexdigest() +
                     salt.decode())
        pass_hash = (
            hashlib.sha256((username + password).encode() + salt).hexdigest()
            + salt.decode()
        )
        return pass_hash

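Reading `_create_pass_hash` above: the stored credential is a 64-character SHA-256 hex digest with the 16-character hex salt appended to it, so the verifying side splits the trailing salt off and recomputes the digest. A sketch of that matching check (the helper name is ours; the real comparison lives in SqliteUserManager.authenticate):

import hashlib

def check_pass_hash(username, password, pass_hash):
    salt = pass_hash[-16:]  # the salt _create_pass_hash appended
    expected = (
        hashlib.sha256((username + password).encode() + salt.encode()).hexdigest()
        + salt
    )
    return expected == pass_hash
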
    def create_auth_db(self):
        conn = self._conn()
        cursor = conn.cursor()
        logger.info("Creating auth db at {}."
                    .format(self.auth_db_path))
        cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth
            (username VARCHAR PRIMARY KEY, hash VARCHAR)"""))
        logger.info("Creating auth db at {}.".format(self.auth_db_path))
        cursor.execute(
            self.fs(
                """CREATE TABLE IF NOT EXISTS auth
            (username VARCHAR PRIMARY KEY, hash VARCHAR)"""
            )
        )
        conn.commit()
        conn.close()

@ -199,19 +213,28 @@ class SqliteUserManager(SimpleUserManager):
def get_user_manager(config):
    if "auth_db_path" in config and config["auth_db_path"]:
        logger.info("Found auth_db_path in config, using SqliteUserManager for auth")
        return SqliteUserManager(config['auth_db_path'], config['data_root'])
        return SqliteUserManager(config["auth_db_path"], config["data_root"])
    elif "user_manager" in config and config["user_manager"]:  # load from config
        logger.info("Found user_manager in config, using {} for auth".format(config['user_manager']))
        logger.info(
            "Found user_manager in config, using {} for auth".format(
                config["user_manager"]
            )
        )
        import importlib
        import inspect
        module_name, class_name = config['user_manager'].rsplit('.', 1)

        module_name, class_name = config["user_manager"].rsplit(".", 1)
        module = importlib.import_module(module_name.strip())
        class_ = getattr(module, class_name.strip())

        if not SimpleUserManager in inspect.getmro(class_):
            raise TypeError('''"user_manager" found in the conf file but it doesn''t
                            inherit from SimpleUserManager''')
            raise TypeError(
                """"user_manager" found in the conf file but it doesn''t
                            inherit from SimpleUserManager"""
            )
        return class_(config)
    else:
        logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password")
        return SimpleUserManager()
        logger.warning(
            "neither auth_db_path nor user_manager set, ankisyncd will accept any password"
        )
        return SimpleUserManager()

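`get_user_manager` above picks its backend in three steps: an `auth_db_path` wins and selects SqliteUserManager; otherwise a dotted `user_manager` path is imported dynamically and rejected unless SimpleUserManager appears in its MRO; and with neither set, every password is accepted. The import-and-validate step, generalized into a hypothetical helper:

import importlib
import inspect

def load_plugin(dotted_path, base_class):
    module_name, class_name = dotted_path.rsplit(".", 1)
    class_ = getattr(importlib.import_module(module_name), class_name)
    if base_class not in inspect.getmro(class_):
        raise TypeError(
            "{} does not inherit from {}".format(dotted_path, base_class.__name__)
        )
    return class_

# e.g. load_plugin("collections.OrderedDict", dict) returns OrderedDict
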
@ -1,4 +1,4 @@
from ankisyncd_cli import ankisyncctl

if __name__ == '__main__':
    ankisyncctl.main()
if __name__ == "__main__":
    ankisyncctl.main()

@ -9,6 +9,7 @@ from ankisyncd.users import get_user_manager

config = config.load()


def usage():
    print("usage: {} <command> [<args>]".format(sys.argv[0]))
    print()
@ -18,12 +19,14 @@ def usage():
    print("  lsuser - list users")
    print("  passwd <username> - change password of a user")


def adduser(username):
    password = getpass.getpass("Enter password for {}: ".format(username))

    user_manager = get_user_manager(config)
    user_manager.add_user(username, password)


def deluser(username):
    user_manager = get_user_manager(config)
    try:
@ -31,6 +34,7 @@ def deluser(username):
    except ValueError as error:
        print("Could not delete user {}: {}".format(username, error), file=sys.stderr)


def lsuser():
    user_manager = get_user_manager(config)
    try:
@ -40,6 +44,7 @@ def lsuser():
    except ValueError as error:
        print("Could not list users: {}".format(error), file=sys.stderr)


def passwd(username):
    user_manager = get_user_manager(config)

@ -51,7 +56,11 @@ def passwd(username):
    try:
        user_manager.set_password_for_user(username, password)
    except ValueError as error:
        print("Could not set password for user {}: {}".format(username, error), file=sys.stderr)
        print(
            "Could not set password for user {}: {}".format(username, error),
            file=sys.stderr,
        )


def main():
    argc = len(sys.argv)
@ -78,5 +87,6 @@ def main():
        usage()
        exit(1)


if __name__ == "__main__":
    main()

@ -7,11 +7,13 @@ word in many other SQL dialects.
"""
import os
import sys
path = os.path.realpath(os.path.abspath(os.path.join(__file__, '../')))

path = os.path.realpath(os.path.abspath(os.path.join(__file__, "../")))
sys.path.insert(0, os.path.dirname(path))

import sqlite3
import ankisyncd.config

conf = ankisyncd.config.load()


@ -21,15 +23,21 @@ def main():
    conn = sqlite3.connect(conf["auth_db_path"])

    cursor = conn.cursor()
    cursor.execute("SELECT * FROM sqlite_master "
                   "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
                   "AND tbl_name = 'auth'")
    cursor.execute(
        "SELECT * FROM sqlite_master "
        "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
        "AND tbl_name = 'auth'"
    )
    res = cursor.fetchone()

    if res is not None:
        cursor.execute("ALTER TABLE auth RENAME TO auth_old")
        cursor.execute("CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)")
        cursor.execute("INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old")
        cursor.execute(
            "CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)"
        )
        cursor.execute(
            "INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old"
        )
        cursor.execute("DROP TABLE auth_old")
        conn.commit()
        print("Successfully updated table 'auth'")
@ -44,17 +52,23 @@ def main():
    conn = sqlite3.connect(conf["session_db_path"])

    cursor = conn.cursor()
    cursor.execute("SELECT * FROM sqlite_master "
                   "WHERE sql LIKE '%user VARCHAR%' "
                   "AND tbl_name = 'session'")
    cursor.execute(
        "SELECT * FROM sqlite_master "
        "WHERE sql LIKE '%user VARCHAR%' "
        "AND tbl_name = 'session'"
    )
    res = cursor.fetchone()

    if res is not None:
        cursor.execute("ALTER TABLE session RENAME TO session_old")
        cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, "
                       "username VARCHAR, path VARCHAR)")
        cursor.execute("INSERT INTO session (hkey, skey, username, path) "
                       "SELECT hkey, skey, user, path FROM session_old")
        cursor.execute(
            "CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, "
            "username VARCHAR, path VARCHAR)"
        )
        cursor.execute(
            "INSERT INTO session (hkey, skey, username, path) "
            "SELECT hkey, skey, user, path FROM session_old"
        )
        cursor.execute("DROP TABLE session_old")
        conn.commit()
        print("Successfully updated table 'session'")

@ -8,5 +8,5 @@ setup(
    author="Anki Community",
    author_email="kothary.vikash+ankicommunity@gmail.com",
    packages=find_packages(),
    url='https://ankicommunity.github.io/'
    url="https://ankicommunity.github.io/",
)

@ -16,7 +16,7 @@ class CollectionTestBase(unittest.TestCase):

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
        self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
        self.collection_path = os.path.join(self.temp_dir, "collection.anki2")
        cm = CollectionManager({})
        collectionWrapper = cm.get_collection(self.collection_path)
        self.collection = collectionWrapper._get_collection()
@ -32,26 +32,26 @@ class CollectionTestBase(unittest.TestCase):
    def add_note(self, data):
        from anki.notes import Note

        model = self.collection.models.byName(data['model'])
        model = self.collection.models.byName(data["model"])

        note = Note(self.collection, model)
        for name, value in data['fields'].items():
        for name, value in data["fields"].items():
            note[name] = value

        if 'tags' in data:
            note.setTagsFromStr(data['tags'])
        if "tags" in data:
            note.setTagsFromStr(data["tags"])

        self.collection.addNote(note)

    # TODO: refactor into a parent class
    def add_default_note(self, count=1):
        data = {
            'model': 'Basic',
            'fields': {
                'Front': 'The front',
                'Back': 'The back',
            "model": "Basic",
            "fields": {
                "Front": "The front",
                "Back": "The back",
            },
            'tags': "Tag1 Tag2",
            "tags": "Tag1 Tag2",
        }
        for idx in range(0, count):
            self.add_note(data)

@ -1 +1 @@
from . import db_utils
from . import db_utils

@ -5,6 +5,7 @@ import tempfile

from anki.collection import Collection


class CollectionUtils:
    """
    Provides utility methods for creating, inspecting and manipulating anki

@ -34,7 +34,7 @@ def to_sql(database):
    else:
        connection = database

    res = '\n'.join(connection.iterdump())
    res = "\n".join(connection.iterdump())

    if type(database) == str:
        connection.close()
@ -54,18 +54,15 @@ def diff(left_db_path, right_db_path):

    command = ["sqldiff", left_db_path, right_db_path]

    child_process = subprocess.Popen(command,
                                     shell=False,
                                     stdout=subprocess.PIPE)
    child_process = subprocess.Popen(command, shell=False, stdout=subprocess.PIPE)
    stdout, stderr = child_process.communicate()
    exit_code = child_process.returncode

    if exit_code != 0 or stderr is not None:
        raise RuntimeError("Command {} encountered an error, exit "
                           "code: {}, stderr: {}"
                           .format(" ".join(command),
                                   exit_code,
                                   stderr))
        raise RuntimeError(
            "Command {} encountered an error, exit "
            "code: {}, stderr: {}".format(" ".join(command), exit_code, stderr)
        )

    # Any output from sqldiff means the databases differ.
    # Any output from sqldiff means the databases differ.
    return stdout != ""

@ -21,9 +21,8 @@ class MockServerConnection:
        r = self.test_app.post(url, params=data.read(), headers=headers, status="*")
        return types.SimpleNamespace(status_code=r.status_int, body=r.body)


    def streamContent(self, r):
        return r.body
        return r.body


class MockRemoteServer(RemoteServer):

@ -12,9 +12,7 @@ mediamanager_orig_funcs = {
    "_logChanges": None,
}

db_orig_funcs = {
    "__init__": None
}
db_orig_funcs = {"__init__": None}


def monkeypatch_mediamanager():
@ -35,6 +33,7 @@ def monkeypatch_mediamanager():

            os.chdir(old_cwd)
            return res

        return wrapper

    MediaManager.findChanges = make_cwd_safe(MediaManager.findChanges)
@ -47,6 +46,7 @@ def unpatch_mediamanager():

    mediamanager_orig_funcs["findChanges"] = None


def monkeypatch_db():
    """
    Monkey patches Anki's DB.__init__ to connect to allow access to the db
@ -58,9 +58,7 @@ def monkeypatch_db():
    def patched___init__(self, path, text=None, timeout=0):
        # Code taken from Anki's DB.__init__()
        # Allow more than one thread to use this connection.
        self._db = sqlite.connect(path,
                                  timeout=timeout,
                                  check_same_thread=False)
        self._db = sqlite.connect(path, timeout=timeout, check_same_thread=False)
        if text:
            self._db.text_factory = text
        self._path = path

@ -23,50 +23,58 @@ def create_server_paths():
        "data_root": os.path.join(dir, "data"),
    }


def create_sync_app(server_paths, config_path):
    config = configparser.ConfigParser()
    config.read(config_path)

    # Use custom files and dirs in settings.
    config['sync_app'].update(server_paths)
    config["sync_app"].update(server_paths)

    return SyncApp(config["sync_app"])

    return SyncApp(config['sync_app'])

def get_session_for_hkey(server, hkey):
    return server.session_manager.load(hkey)


def get_thread_for_hkey(server, hkey):
    session = get_session_for_hkey(server, hkey)
    thread = session.get_thread()
    return thread


def get_col_wrapper_for_hkey(server, hkey):
    thread = get_thread_for_hkey(server, hkey)
    col_wrapper = thread.wrapper
    return col_wrapper


def get_col_for_hkey(server, hkey):
    col_wrapper = get_col_wrapper_for_hkey(server, hkey)
    col_wrapper.open()  # Make sure the col is opened.
    return col_wrapper._CollectionWrapper__col


def get_col_db_path_for_hkey(server, hkey):
    col = get_col_for_hkey(server, hkey)
    return col.db._path

def get_syncer_for_hkey(server, hkey, syncer_type='collection'):

def get_syncer_for_hkey(server, hkey, syncer_type="collection"):
    col = get_col_for_hkey(server, hkey)

    session = get_session_for_hkey(server, hkey)

    syncer_type = syncer_type.lower()
    if syncer_type == 'collection':
    if syncer_type == "collection":
        handler_method = SyncCollectionHandler.operations[0]
    elif syncer_type == 'media':
    elif syncer_type == "media":
        handler_method = SyncMediaHandler.operations[0]

    return session.get_handler_for_operation(handler_method, col)


def add_files_to_client_mediadb(media, filepaths, update_db=False):
    for filepath in filepaths:
        logging.debug("Adding file '{}' to client media DB".format(filepath))
@ -76,16 +84,15 @@ def add_files_to_client_mediadb(media, filepaths, update_db=False):
    if update_db:
        media.findChanges()  # Write changes to db.


def add_files_to_server_mediadb(media, filepaths):
    for filepath in filepaths:
        logging.debug("Adding file '{}' to server media DB".format(filepath))
        fname = os.path.basename(filepath)
        with open(filepath, 'rb') as infile:
        with open(filepath, "rb") as infile:
            data = infile.read()
        csum = anki.utils.checksum(data)

        with open(os.path.join(media.dir(), fname), 'wb') as f:
        with open(os.path.join(media.dir(), fname), "wb") as f:
            f.write(data)
        media.addMedia(
            ((fname, media.lastUsn() + 1, csum),)
        )
        media.addMedia(((fname, media.lastUsn() + 1, csum),))

@ -11,7 +11,6 @@ from helpers.monkey_patches import monkeypatch_db, unpatch_db


class SyncAppFunctionalTestBase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.colutils = CollectionUtils()
@ -28,28 +27,29 @@ class SyncAppFunctionalTestBase(unittest.TestCase):
        self.server_paths = helpers.server_utils.create_server_paths()

        # Add a test user to the temp auth db the server will use.
        self.user_manager = SqliteUserManager(self.server_paths['auth_db_path'],
                                              self.server_paths['data_root'])
        self.user_manager.add_user('testuser', 'testpassword')
        self.user_manager = SqliteUserManager(
            self.server_paths["auth_db_path"], self.server_paths["data_root"]
        )
        self.user_manager.add_user("testuser", "testpassword")

        # Get absolute path to development ini file.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        ini_file_path = os.path.join(script_dir,
                                     "assets",
                                     "test.conf")
        ini_file_path = os.path.join(script_dir, "assets", "test.conf")

        # Create SyncApp instance using the dev ini file and the temporary
        # paths.
        self.server_app = helpers.server_utils.create_sync_app(self.server_paths,
                                                               ini_file_path)
        self.server_app = helpers.server_utils.create_sync_app(
            self.server_paths, ini_file_path
        )

        # Wrap the SyncApp object in TestApp instance for testing.
        self.server_test_app = TestApp(self.server_app)

        # MockRemoteServer instance needed for testing normal collection
        # syncing and for retrieving hkey for other tests.
        self.mock_remote_server = MockRemoteServer(hkey=None,
                                                   server_test_app=self.server_test_app)
        self.mock_remote_server = MockRemoteServer(
            hkey=None, server_test_app=self.server_test_app
        )

    def tearDown(self):
        self.server_paths = {}

@ -9,39 +9,52 @@ from ankisyncd.collection import get_collection_wrapper

import helpers.server_utils


class FakeCollectionWrapper(CollectionWrapper):
    def __init__(self, config, path, setup_new_collection=None):
        self. _CollectionWrapper__col = None
        self._CollectionWrapper__col = None
        pass


class BadCollectionWrapper:
    pass


class CollectionWrapperFactoryTest(unittest.TestCase):
    def test_get_collection_wrapper(self):
        # Get absolute path to development ini file.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        ini_file_path = os.path.join(script_dir,
                                     "assets",
                                     "test.conf")
        ini_file_path = os.path.join(script_dir, "assets", "test.conf")

        # Create temporary files and dirs the server will use.
        server_paths = helpers.server_utils.create_server_paths()

        config = configparser.ConfigParser()
        config.read(ini_file_path)
        path = os.path.realpath('fake/collection.anki2')
        path = os.path.realpath("fake/collection.anki2")

        # Use custom files and dirs in settings. Should be CollectionWrapper
        config['sync_app'].update(server_paths)
        self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper))
        config["sync_app"].update(server_paths)
        self.assertTrue(
            type(get_collection_wrapper(config["sync_app"], path) == CollectionWrapper)
        )

        # A conf-specified CollectionWrapper is loaded
        config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper')
        self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper)
        config.set(
            "sync_app",
            "collection_wrapper",
            "test_collection_wrappers.FakeCollectionWrapper",
        )
        self.assertTrue(
            type(get_collection_wrapper(config["sync_app"], path))
            == FakeCollectionWrapper
        )

        # Should fail at load time if the class doesn't inherit from CollectionWrapper
        config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper')
        config.set(
            "sync_app",
            "collection_wrapper",
            "test_collection_wrappers.BadCollectionWrapper",
        )
        with self.assertRaises(TypeError):
            pm = get_collection_wrapper(config['sync_app'], path)

            pm = get_collection_wrapper(config["sync_app"], path)

@ -8,20 +8,21 @@ from ankisyncd.full_sync import FullSyncManager, get_full_sync_manager

import helpers.server_utils


class FakeFullSyncManager(FullSyncManager):
    def __init__(self, config):
        pass


class BadFullSyncManager:
    pass


class FullSyncManagerFactoryTest(unittest.TestCase):
    def test_get_full_sync_manager(self):
        # Get absolute path to development ini file.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        ini_file_path = os.path.join(script_dir,
                                     "assets",
                                     "test.conf")
        ini_file_path = os.path.join(script_dir, "assets", "test.conf")

        # Create temporary files and dirs the server will use.
        server_paths = helpers.server_utils.create_server_paths()
@ -30,15 +31,20 @@ class FullSyncManagerFactoryTest(unittest.TestCase):
        config.read(ini_file_path)

        # Use custom files and dirs in settings. Should be PersistenceManager
        config['sync_app'].update(server_paths)
        self.assertTrue(type(get_full_sync_manager(config['sync_app']) == FullSyncManager))
        config["sync_app"].update(server_paths)
        self.assertTrue(
            type(get_full_sync_manager(config["sync_app"]) == FullSyncManager)
        )

        # A conf-specified FullSyncManager is loaded
        config.set("sync_app", "full_sync_manager", 'test_full_sync.FakeFullSyncManager')
        self.assertTrue(type(get_full_sync_manager(config['sync_app'])) == FakeFullSyncManager)
        config.set(
            "sync_app", "full_sync_manager", "test_full_sync.FakeFullSyncManager"
        )
        self.assertTrue(
            type(get_full_sync_manager(config["sync_app"])) == FakeFullSyncManager
        )

        # Should fail at load time if the class doesn't inherit from FullSyncManager
        config.set("sync_app", "full_sync_manager", 'test_full_sync.BadFullSyncManager')
        config.set("sync_app", "full_sync_manager", "test_full_sync.BadFullSyncManager")
        with self.assertRaises(TypeError):
            pm = get_full_sync_manager(config['sync_app'])

            pm = get_full_sync_manager(config["sync_app"])

@ -45,26 +45,22 @@ class ServerMediaManagerTest(unittest.TestCase):
            list(cm.db.execute("SELECT fname, csum FROM media")),
        )
        self.assertEqual(cm.lastUsn(), sm.lastUsn())
        self.assertEqual(
            list(sm.db.execute("SELECT usn FROM media")),
            [(161,), (161,)]
        )
        self.assertEqual(list(sm.db.execute("SELECT usn FROM media")), [(161,), (161,)])

    def test_mediaChanges_lastUsn_order(self):
        col = self.colutils.create_empty_col()
        col.media = ankisyncd.media.ServerMediaManager(col)
        session = MagicMock()
        session.name = 'test'
        session.name = "test"
        mh = ankisyncd.sync_app.SyncMediaHandler(col, session)
        mh.col.media.addMedia(
            (
                ('fileA', 101, '53059abba1a72c7aff34a3eaf7fef10ed65541ce'),
                ('fileB', 100, 'a5ae546046d09559399c80fa7076fb10f1ce4bcd'),
                ("fileA", 101, "53059abba1a72c7aff34a3eaf7fef10ed65541ce"),
                ("fileB", 100, "a5ae546046d09559399c80fa7076fb10f1ce4bcd"),
            )
        )
        # anki assumes mh.col.media.lastUsn() == mh.mediaChanges()['data'][-1][1]
        # ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
        self.assertEqual(
            mh.mediaChanges(lastUsn=99)['data'][-1][1],
            mh.col.media.lastUsn()
            mh.mediaChanges(lastUsn=99)["data"][-1][1], mh.col.media.lastUsn()
        )

@ -14,20 +14,21 @@ from ankisyncd.sync_app import SyncUserSession

import helpers.server_utils


class FakeSessionManager(SimpleSessionManager):
    def __init__(self, config):
        pass


class BadSessionManager:
    pass


class SessionManagerFactoryTest(unittest.TestCase):
    def test_get_session_manager(self):
        # Get absolute path to development ini file.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        ini_file_path = os.path.join(script_dir,
                                     "assets",
                                     "test.conf")
        ini_file_path = os.path.join(script_dir, "assets", "test.conf")

        # Create temporary files and dirs the server will use.
        server_paths = helpers.server_utils.create_server_paths()
@ -36,32 +37,40 @@ class SessionManagerFactoryTest(unittest.TestCase):
        config.read(ini_file_path)

        # Use custom files and dirs in settings. Should be SqliteSessionManager
        config['sync_app'].update(server_paths)
        self.assertTrue(type(get_session_manager(config['sync_app']) == SqliteSessionManager))
        config["sync_app"].update(server_paths)
        self.assertTrue(
            type(get_session_manager(config["sync_app"]) == SqliteSessionManager)
        )

        # No value defaults to SimpleSessionManager
        config.remove_option("sync_app", "session_db_path")
        self.assertTrue(type(get_session_manager(config['sync_app'])) == SimpleSessionManager)
        self.assertTrue(
            type(get_session_manager(config["sync_app"])) == SimpleSessionManager
        )

        # A conf-specified SessionManager is loaded
        config.set("sync_app", "session_manager", 'test_sessions.FakeSessionManager')
        self.assertTrue(type(get_session_manager(config['sync_app'])) == FakeSessionManager)
        config.set("sync_app", "session_manager", "test_sessions.FakeSessionManager")
        self.assertTrue(
            type(get_session_manager(config["sync_app"])) == FakeSessionManager
        )

        # Should fail at load time if the class doesn't inherit from SimpleSessionManager
        config.set("sync_app", "session_manager", 'test_sessions.BadSessionManager')
        config.set("sync_app", "session_manager", "test_sessions.BadSessionManager")
        with self.assertRaises(TypeError):
            sm = get_session_manager(config['sync_app'])
            sm = get_session_manager(config["sync_app"])

        # Add the session_db_path back, it should take precedence over BadSessionManager
        config['sync_app'].update(server_paths)
        self.assertTrue(type(get_session_manager(config['sync_app'])) == SqliteSessionManager)
        config["sync_app"].update(server_paths)
        self.assertTrue(
            type(get_session_manager(config["sync_app"])) == SqliteSessionManager
        )


class SimpleSessionManagerTest(unittest.TestCase):
    test_hkey = '1234567890'
    test_hkey = "1234567890"
    sdir = tempfile.mkdtemp(suffix="_session")
    os.rmdir(sdir)
    test_session = SyncUserSession('testName', sdir, None, None)
    test_session = SyncUserSession("testName", sdir, None, None)

    def setUp(self):
        self.sessionManager = SimpleSessionManager()
@ -71,10 +80,12 @@ class SimpleSessionManagerTest(unittest.TestCase):

    def test_save(self):
        self.sessionManager.save(self.test_hkey, self.test_session)
        self.assertEqual(self.sessionManager.sessions[self.test_hkey].name,
                         self.test_session.name)
        self.assertEqual(self.sessionManager.sessions[self.test_hkey].path,
                         self.test_session.path)
        self.assertEqual(
            self.sessionManager.sessions[self.test_hkey].name, self.test_session.name
        )
        self.assertEqual(
            self.sessionManager.sessions[self.test_hkey].path, self.test_session.path
        )

    def test_delete(self):
        self.sessionManager.save(self.test_hkey, self.test_session)
@ -111,13 +122,11 @@ class SqliteSessionManagerTest(SimpleSessionManagerTest):

        conn = sqlite3.connect(self._test_sess_db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT username, path FROM session WHERE hkey=?",
                       (self.test_hkey,))
        cursor.execute(
            "SELECT username, path FROM session WHERE hkey=?", (self.test_hkey,)
        )
        res = cursor.fetchone()
        conn.close()

        self.assertEqual(res[0], self.test_session.name)
        self.assertEqual(res[1], self.test_session.path)

@ -16,10 +16,9 @@ class SyncCollectionHandlerTest(CollectionTestBase):
    def setUp(self):
        super().setUp()
        self.session = MagicMock()
        self.session.name = 'test'
        self.session.name = "test"
        self.syncCollectionHandler = SyncCollectionHandler(
            self.collection,
            self.session
            self.collection, self.session
        )

    def tearDown(self):
@ -28,48 +27,48 @@ class SyncCollectionHandlerTest(CollectionTestBase):

    def test_old_client(self):
        old = (
            ','.join(('ankidesktop', '2.0.12', 'lin::')),
            ','.join(('ankidesktop', '2.0.26', 'lin::')),
            ','.join(('ankidroid', '2.1', '')),
            ','.join(('ankidroid', '2.2', '')),
            ','.join(('ankidroid', '2.2.2', '')),
            ','.join(('ankidroid', '2.3alpha3', '')),
            ",".join(("ankidesktop", "2.0.12", "lin::")),
            ",".join(("ankidesktop", "2.0.26", "lin::")),
            ",".join(("ankidroid", "2.1", "")),
            ",".join(("ankidroid", "2.2", "")),
            ",".join(("ankidroid", "2.2.2", "")),
            ",".join(("ankidroid", "2.3alpha3", "")),
        )

        current = (
            None,
            ','.join(('ankidesktop', '2.0.27', 'lin::')),
            ','.join(('ankidesktop', '2.0.32', 'lin::')),
            ','.join(('ankidesktop', '2.1.0', 'lin::')),
            ','.join(('ankidesktop', '2.1.6-beta2', 'lin::')),
            ','.join(('ankidesktop', '2.1.9 (dev)', 'lin::')),
            ','.join(('ankidesktop', '2.1.26 (arch-linux-2.1.26-1)', 'lin:arch:')),
            ','.join(('ankidroid', '2.2.3', '')),
            ','.join(('ankidroid', '2.3alpha4', '')),
            ','.join(('ankidroid', '2.3alpha5', '')),
            ','.join(('ankidroid', '2.3beta1', '')),
            ','.join(('ankidroid', '2.3', '')),
            ','.join(('ankidroid', '2.9', '')),
            ",".join(("ankidesktop", "2.0.27", "lin::")),
            ",".join(("ankidesktop", "2.0.32", "lin::")),
            ",".join(("ankidesktop", "2.1.0", "lin::")),
            ",".join(("ankidesktop", "2.1.6-beta2", "lin::")),
            ",".join(("ankidesktop", "2.1.9 (dev)", "lin::")),
            ",".join(("ankidesktop", "2.1.26 (arch-linux-2.1.26-1)", "lin:arch:")),
            ",".join(("ankidroid", "2.2.3", "")),
            ",".join(("ankidroid", "2.3alpha4", "")),
            ",".join(("ankidroid", "2.3alpha5", "")),
            ",".join(("ankidroid", "2.3beta1", "")),
            ",".join(("ankidroid", "2.3", "")),
            ",".join(("ankidroid", "2.9", "")),
        )

        for cv in old:
            if not SyncCollectionHandler._old_client(cv):
                raise AssertionError("old_client(\"%s\") is False" % cv)
                raise AssertionError('old_client("%s") is False' % cv)

        for cv in current:
            if SyncCollectionHandler._old_client(cv):
                raise AssertionError("old_client(\"%s\") is True" % cv)
                raise AssertionError('old_client("%s") is True' % cv)

    def test_meta(self):
        meta = self.syncCollectionHandler.meta(v=SYNC_VER)
        self.assertEqual(meta['scm'], self.collection.scm)
        self.assertTrue((type(meta['ts']) == int) and meta['ts'] > 0)
        self.assertEqual(meta['mod'], self.collection.mod)
        self.assertEqual(meta['usn'], self.collection._usn)
        self.assertEqual(meta['uname'], self.session.name)
        self.assertEqual(meta['musn'], self.collection.media.lastUsn())
        self.assertEqual(meta['msg'], '')
        self.assertEqual(meta['cont'], True)
        self.assertEqual(meta["scm"], self.collection.scm)
        self.assertTrue((type(meta["ts"]) == int) and meta["ts"] > 0)
        self.assertEqual(meta["mod"], self.collection.mod)
        self.assertEqual(meta["usn"], self.collection._usn)
        self.assertEqual(meta["uname"], self.session.name)
        self.assertEqual(meta["musn"], self.collection.media.lastUsn())
        self.assertEqual(meta["msg"], "")
        self.assertEqual(meta["cont"], True)


class SyncAppTest(unittest.TestCase):

@ -10,20 +10,21 @@ from ankisyncd.users import get_user_manager

import helpers.server_utils


class FakeUserManager(SimpleUserManager):
    def __init__(self, config):
        pass


class BadUserManager:
    pass


class UserManagerFactoryTest(unittest.TestCase):
    def test_get_user_manager(self):
        # Get absolute path to development ini file.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        ini_file_path = os.path.join(script_dir,
                                     "assets",
                                     "test.conf")
        ini_file_path = os.path.join(script_dir, "assets", "test.conf")

        # Create temporary files and dirs the server will use.
        server_paths = helpers.server_utils.create_server_paths()
@ -32,25 +33,25 @@ class UserManagerFactoryTest(unittest.TestCase):
        config.read(ini_file_path)

        # Use custom files and dirs in settings. Should be SqliteUserManager
        config['sync_app'].update(server_paths)
        self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager))
        config["sync_app"].update(server_paths)
        self.assertTrue(type(get_user_manager(config["sync_app"]) == SqliteUserManager))

        # No value defaults to SimpleUserManager
        config.remove_option("sync_app", "auth_db_path")
        self.assertTrue(type(get_user_manager(config['sync_app'])) == SimpleUserManager)
        self.assertTrue(type(get_user_manager(config["sync_app"])) == SimpleUserManager)

        # A conf-specified UserManager is loaded
        config.set("sync_app", "user_manager", 'test_users.FakeUserManager')
        self.assertTrue(type(get_user_manager(config['sync_app'])) == FakeUserManager)
        config.set("sync_app", "user_manager", "test_users.FakeUserManager")
        self.assertTrue(type(get_user_manager(config["sync_app"])) == FakeUserManager)

        # Should fail at load time if the class doesn't inherit from SimpleUserManager
        config.set("sync_app", "user_manager", 'test_users.BadUserManager')
        config.set("sync_app", "user_manager", "test_users.BadUserManager")
        with self.assertRaises(TypeError):
            um = get_user_manager(config['sync_app'])
            um = get_user_manager(config["sync_app"])

        # Add the auth_db_path back, it should take precedence over BadUserManager
        config['sync_app'].update(server_paths)
        self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager))
        config["sync_app"].update(server_paths)
        self.assertTrue(type(get_user_manager(config["sync_app"]) == SqliteUserManager))


class SimpleUserManagerTest(unittest.TestCase):
@ -61,22 +62,18 @@ class SimpleUserManagerTest(unittest.TestCase):
        self._user_manager = None

    def test_authenticate(self):
        good_test_un = 'username'
        good_test_pw = 'password'
        bad_test_un = 'notAUsername'
        bad_test_pw = 'notAPassword'
        good_test_un = "username"
        good_test_pw = "password"
        bad_test_un = "notAUsername"
        bad_test_pw = "notAPassword"

        self.assertTrue(self.user_manager.authenticate(good_test_un,
                                                       good_test_pw))
        self.assertTrue(self.user_manager.authenticate(bad_test_un,
                                                       bad_test_pw))
        self.assertTrue(self.user_manager.authenticate(good_test_un,
                                                       bad_test_pw))
        self.assertTrue(self.user_manager.authenticate(bad_test_un,
                                                       good_test_pw))
        self.assertTrue(self.user_manager.authenticate(good_test_un, good_test_pw))
        self.assertTrue(self.user_manager.authenticate(bad_test_un, bad_test_pw))
        self.assertTrue(self.user_manager.authenticate(good_test_un, bad_test_pw))
        self.assertTrue(self.user_manager.authenticate(bad_test_un, good_test_pw))

    def test_userdir(self):
        username = 'my_username'
        username = "my_username"
        dirname = self.user_manager.userdir(username)
        self.assertEqual(dirname, username)

@ -87,8 +84,7 @@ class SqliteUserManagerTest(unittest.TestCase):
        self.basedir = basedir
        self.auth_db_path = os.path.join(basedir, "auth.db")
        self.collection_path = os.path.join(basedir, "collections")
        self.user_manager = SqliteUserManager(self.auth_db_path,
                                              self.collection_path)
        self.user_manager = SqliteUserManager(self.auth_db_path, self.collection_path)

    def tearDown(self):
        shutil.rmtree(self.basedir)
@ -151,18 +147,22 @@ class SqliteUserManagerTest(unittest.TestCase):
        self.assertTrue(os.path.isdir(expected_dir_path))

    def test_add_users(self):
        users_data = [("my_first_username", "my_first_password"),
                      ("my_second_username", "my_second_password")]
        users_data = [
            ("my_first_username", "my_first_password"),
            ("my_second_username", "my_second_password"),
        ]
        self.user_manager.create_auth_db()
        self.user_manager.add_users(users_data)

        user_list = self.user_manager.user_list()
        self.assertIn("my_first_username", user_list)
        self.assertIn("my_second_username", user_list)
        self.assertTrue(os.path.isdir(os.path.join(self.collection_path,
                                                   "my_first_username")))
        self.assertTrue(os.path.isdir(os.path.join(self.collection_path,
                                                   "my_second_username")))
        self.assertTrue(
            os.path.isdir(os.path.join(self.collection_path, "my_first_username"))
        )
        self.assertTrue(
            os.path.isdir(os.path.join(self.collection_path, "my_second_username"))
        )

    def test__add_user_to_auth_db(self):
        username = "my_username"
@ -191,8 +191,7 @@ class SqliteUserManagerTest(unittest.TestCase):
        self.user_manager.create_auth_db()
        self.user_manager.add_user(username, password)

        self.assertTrue(self.user_manager.authenticate(username,
                                                       password))
        self.assertTrue(self.user_manager.authenticate(username, password))

    def test_set_password_for_user(self):
        username = "my_username"
@ -203,8 +202,5 @@ class SqliteUserManagerTest(unittest.TestCase):
        self.user_manager.add_user(username, password)

        self.user_manager.set_password_for_user(username, new_password)
        self.assertFalse(self.user_manager.authenticate(username,
                                                        password))
        self.assertTrue(self.user_manager.authenticate(username,
                                                       new_password))

        self.assertFalse(self.user_manager.authenticate(username, password))
        self.assertTrue(self.user_manager.authenticate(username, new_password))

@ -17,4 +17,3 @@ class SyncAppFunctionalHostKeyTest(SyncAppFunctionalTestBase):
        self.assertIsNone(self.server.hostKey("testuser", "wrongpassword"))
        self.assertIsNone(self.server.hostKey("wronguser", "wrongpassword"))
        self.assertIsNone(self.server.hostKey("wronguser", "testpassword"))