diff --git a/src/addon/__init__.py b/src/addon/__init__.py index 4a0c68b..0408d0f 100644 --- a/src/addon/__init__.py +++ b/src/addon/__init__.py @@ -10,67 +10,77 @@ config = aqt.mw.addonManager.getConfig(__name__) # TODO: force the user to log out before changing any of the settings + def addui(self, _): - self = self.form - parent_w = self.tab_2 - parent_l = self.vboxlayout - self.useCustomServer = QCheckBox(parent_w) - self.useCustomServer.setText("Use custom sync server") - parent_l.addWidget(self.useCustomServer) - cshl = QHBoxLayout() - parent_l.addLayout(cshl) + self = self.form + parent_w = self.tab_2 + parent_l = self.vboxlayout + self.useCustomServer = QCheckBox(parent_w) + self.useCustomServer.setText("Use custom sync server") + parent_l.addWidget(self.useCustomServer) + cshl = QHBoxLayout() + parent_l.addLayout(cshl) - self.serverAddrLabel = QLabel(parent_w) - self.serverAddrLabel.setText("Server address") - cshl.addWidget(self.serverAddrLabel) - self.customServerAddr = QLineEdit(parent_w) - self.customServerAddr.setPlaceholderText(DEFAULT_ADDR) - cshl.addWidget(self.customServerAddr) + self.serverAddrLabel = QLabel(parent_w) + self.serverAddrLabel.setText("Server address") + cshl.addWidget(self.serverAddrLabel) + self.customServerAddr = QLineEdit(parent_w) + self.customServerAddr.setPlaceholderText(DEFAULT_ADDR) + cshl.addWidget(self.customServerAddr) - pconfig = getprofileconfig() - if pconfig["enabled"]: - self.useCustomServer.setCheckState(Qt.Checked) - if pconfig["addr"]: - self.customServerAddr.setText(pconfig["addr"]) + pconfig = getprofileconfig() + if pconfig["enabled"]: + self.useCustomServer.setCheckState(Qt.Checked) + if pconfig["addr"]: + self.customServerAddr.setText(pconfig["addr"]) - self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text)) - def onchecked(state): - pconfig["enabled"] = state == Qt.Checked - updateui(self, state) - updateserver(self, self.customServerAddr.text()) - self.useCustomServer.stateChanged.connect(onchecked) + self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text)) + + def onchecked(state): + pconfig["enabled"] = state == Qt.Checked + updateui(self, state) + updateserver(self, self.customServerAddr.text()) + + self.useCustomServer.stateChanged.connect(onchecked) + + updateui(self, self.useCustomServer.checkState()) - updateui(self, self.useCustomServer.checkState()) def updateserver(self, text): - pconfig = getprofileconfig() - if pconfig['enabled']: - addr = text or self.customServerAddr.placeholderText() - pconfig['addr'] = addr - setserver() - aqt.mw.addonManager.writeConfig(__name__, config) + pconfig = getprofileconfig() + if pconfig["enabled"]: + addr = text or self.customServerAddr.placeholderText() + pconfig["addr"] = addr + setserver() + aqt.mw.addonManager.writeConfig(__name__, config) + def updateui(self, state): - self.serverAddrLabel.setEnabled(state == Qt.Checked) - self.customServerAddr.setEnabled(state == Qt.Checked) + self.serverAddrLabel.setEnabled(state == Qt.Checked) + self.customServerAddr.setEnabled(state == Qt.Checked) + def setserver(): - pconfig = getprofileconfig() - if pconfig['enabled']: - aqt.mw.pm.profile['hostNum'] = None - anki.sync.SYNC_BASE = "%s" + pconfig['addr'] - else: - anki.sync.SYNC_BASE = anki.consts.SYNC_BASE + pconfig = getprofileconfig() + if pconfig["enabled"]: + aqt.mw.pm.profile["hostNum"] = None + anki.sync.SYNC_BASE = "%s" + pconfig["addr"] + else: + anki.sync.SYNC_BASE = anki.consts.SYNC_BASE + def getprofileconfig(): - if aqt.mw.pm.name not in config["profiles"]: - # inherit global settings if present (used in earlier versions of the addon) - config["profiles"][aqt.mw.pm.name] = { - "enabled": config.get("enabled", False), - "addr": config.get("addr", DEFAULT_ADDR), - } - aqt.mw.addonManager.writeConfig(__name__, config) - return config["profiles"][aqt.mw.pm.name] + if aqt.mw.pm.name not in config["profiles"]: + # inherit global settings if present (used in earlier versions of the addon) + config["profiles"][aqt.mw.pm.name] = { + "enabled": config.get("enabled", False), + "addr": config.get("addr", DEFAULT_ADDR), + } + aqt.mw.addonManager.writeConfig(__name__, config) + return config["profiles"][aqt.mw.pm.name] + addHook("profileLoaded", setserver) -aqt.preferences.Preferences.__init__ = wrap(aqt.preferences.Preferences.__init__, addui, "after") +aqt.preferences.Preferences.__init__ = wrap( + aqt.preferences.Preferences.__init__, addui, "after" +) diff --git a/src/ankisyncd/__main__.py b/src/ankisyncd/__main__.py index c6a4f13..656c54a 100644 --- a/src/ankisyncd/__main__.py +++ b/src/ankisyncd/__main__.py @@ -2,6 +2,7 @@ import sys if __package__ is None and not hasattr(sys, "frozen"): import os.path + path = os.path.realpath(os.path.abspath(__file__)) sys.path.insert(0, os.path.dirname(os.path.dirname(path))) diff --git a/src/ankisyncd/collection.py b/src/ankisyncd/collection.py index d29e97e..f93178a 100644 --- a/src/ankisyncd/collection.py +++ b/src/ankisyncd/collection.py @@ -30,7 +30,7 @@ class CollectionWrapper: self.close() def execute(self, func, args=[], kw={}, waitForReturn=True): - """ Executes the given function with the underlying anki.storage.Collection + """Executes the given function with the underlying anki.storage.Collection object as the first argument and any additional arguments specified by *args and **kw. @@ -92,6 +92,7 @@ class CollectionWrapper: """Returns True if the collection is open, False otherwise.""" return self.__col is not None + class CollectionManager: """Manages a set of CollectionWrapper objects.""" @@ -109,7 +110,9 @@ class CollectionManager: try: col = self.collections[path] except KeyError: - col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection) + col = self.collections[path] = self.collection_wrapper( + self.config, path, setup_new_collection + ) return col @@ -119,19 +122,25 @@ class CollectionManager: del self.collections[path] col.close() -def get_collection_wrapper(config, path, setup_new_collection = None): + +def get_collection_wrapper(config, path, setup_new_collection=None): if "collection_wrapper" in config and config["collection_wrapper"]: - logger.info("Found collection_wrapper in config, using {} for " - "user data persistence".format(config['collection_wrapper'])) + logger.info( + "Found collection_wrapper in config, using {} for " + "user data persistence".format(config["collection_wrapper"]) + ) import importlib import inspect - module_name, class_name = config['collection_wrapper'].rsplit('.', 1) + + module_name, class_name = config["collection_wrapper"].rsplit(".", 1) module = importlib.import_module(module_name.strip()) class_ = getattr(module, class_name.strip()) if not CollectionWrapper in inspect.getmro(class_): - raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t - inherit from CollectionWrapper''') + raise TypeError( + """"collection_wrapper" found in the conf file but it doesn''t + inherit from CollectionWrapper""" + ) return class_(config, path, setup_new_collection) else: return CollectionWrapper(config, path, setup_new_collection) diff --git a/src/ankisyncd/config.py b/src/ankisyncd/config.py index 7eaaa4c..a440b44 100644 --- a/src/ankisyncd/config.py +++ b/src/ankisyncd/config.py @@ -7,9 +7,9 @@ logger = logging.getLogger("ankisyncd") paths = [ "/etc/ankisyncd/ankisyncd.conf", - os.environ.get("XDG_CONFIG_HOME") and - (os.path.join(os.environ['XDG_CONFIG_HOME'], "ankisyncd", "ankisyncd.conf")) or - os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"), + os.environ.get("XDG_CONFIG_HOME") + and (os.path.join(os.environ["XDG_CONFIG_HOME"], "ankisyncd", "ankisyncd.conf")) + or os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"), os.path.join(dirname(dirname(realpath(__file__))), "ankisyncd.conf"), ] @@ -19,11 +19,12 @@ paths = [ def load_from_env(conf): logger.debug("Loading/overriding config values from ENV") for env in os.environ: - if env.startswith('ANKISYNCD_'): + if env.startswith("ANKISYNCD_"): config_key = env[10:].lower() conf[config_key] = os.getenv(env) logger.info("Setting {} from ENV".format(config_key)) + def load(path=None): choices = paths parser = configparser.ConfigParser() @@ -33,7 +34,7 @@ def load(path=None): logger.debug("config.location: trying", path) try: parser.read(path) - conf = parser['sync_app'] + conf = parser["sync_app"] logger.info("Loaded config from {}".format(path)) load_from_env(conf) return conf diff --git a/src/ankisyncd/full_sync.py b/src/ankisyncd/full_sync.py index a6c9b9d..a458b5b 100644 --- a/src/ankisyncd/full_sync.py +++ b/src/ankisyncd/full_sync.py @@ -13,6 +13,7 @@ from anki.collection import Collection logger = logging.getLogger("ankisyncd.media") logger.setLevel(1) + class FullSyncManager: def test_db(self, db: DB): """ @@ -34,15 +35,14 @@ class FullSyncManager: # Verify integrity of the received database file before replacing our # existing db. temp_db_path = session.get_collection_path() + ".tmp" - with open(temp_db_path, 'wb') as f: + with open(temp_db_path, "wb") as f: f.write(data) try: with DB(temp_db_path) as test_db: self.test_db(test_db) except sqlite.Error as e: - raise HTTPBadRequest("Uploaded collection database file is " - "corrupt.") + raise HTTPBadRequest("Uploaded collection database file is " "corrupt.") # Overwrite existing db. col.close() @@ -69,7 +69,7 @@ class FullSyncManager: col.close(downgrade=True) db_path = session.get_collection_path() try: - with open(db_path, 'rb') as tmp: + with open(db_path, "rb") as tmp: data = tmp.read() finally: col.reopen() @@ -80,16 +80,21 @@ class FullSyncManager: def get_full_sync_manager(config): - if "full_sync_manager" in config and config["full_sync_manager"]: # load from config + if ( + "full_sync_manager" in config and config["full_sync_manager"] + ): # load from config import importlib import inspect - module_name, class_name = config['full_sync_manager'].rsplit('.', 1) + + module_name, class_name = config["full_sync_manager"].rsplit(".", 1) module = importlib.import_module(module_name.strip()) class_ = getattr(module, class_name.strip()) if not FullSyncManager in inspect.getmro(class_): - raise TypeError('''"full_sync_manager" found in the conf file but it doesn''t - inherit from FullSyncManager''') + raise TypeError( + """"full_sync_manager" found in the conf file but it doesn''t + inherit from FullSyncManager""" + ) return class_(config) else: return FullSyncManager() diff --git a/src/ankisyncd/media.py b/src/ankisyncd/media.py index 47341f4..a824f33 100644 --- a/src/ankisyncd/media.py +++ b/src/ankisyncd/media.py @@ -12,6 +12,7 @@ from anki.media import MediaManager logger = logging.getLogger("ankisyncd.media") + class ServerMediaManager(MediaManager): def __init__(self, col, server=True): super().__init__(col, server) @@ -20,14 +21,16 @@ class ServerMediaManager(MediaManager): def addMedia(self, media_to_add): self._db.executemany( - "INSERT OR REPLACE INTO media VALUES (?,?,?)", - media_to_add + "INSERT OR REPLACE INTO media VALUES (?,?,?)", media_to_add ) self._db.commit() def changes(self, lastUsn): - return self._db.execute("select fname,usn,csum from media order by usn desc limit ?", self.lastUsn() - lastUsn) - + return self._db.execute( + "select fname,usn,csum from media order by usn desc limit ?", + self.lastUsn() - lastUsn, + ) + def connect(self): path = self.dir() + ".server.db" create = not os.path.exists(path) diff --git a/src/ankisyncd/sessions.py b/src/ankisyncd/sessions.py index 7c609db..64ab1dc 100644 --- a/src/ankisyncd/sessions.py +++ b/src/ankisyncd/sessions.py @@ -43,20 +43,26 @@ class SqliteSessionManager(SimpleSessionManager): conn = self._conn() cursor = conn.cursor() - cursor.execute("SELECT * FROM sqlite_master " - "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' " - "AND tbl_name = 'session'") + cursor.execute( + "SELECT * FROM sqlite_master " + "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' " + "AND tbl_name = 'session'" + ) res = cursor.fetchone() conn.close() if res is not None: - raise Exception("Outdated database schema, run utils/migrate_user_tables.py") + raise Exception( + "Outdated database schema, run utils/migrate_user_tables.py" + ) def _conn(self): new = not os.path.exists(self.session_db_path) conn = sqlite.connect(self.session_db_path) if new: cursor = conn.cursor() - cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)") + cursor.execute( + "CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)" + ) return conn # Default to using sqlite3 syntax but overridable for sub-classes using other @@ -73,7 +79,9 @@ class SqliteSessionManager(SimpleSessionManager): conn = self._conn() cursor = conn.cursor() - cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,)) + cursor.execute( + self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,) + ) res = cursor.fetchone() if res is not None: @@ -89,7 +97,9 @@ class SqliteSessionManager(SimpleSessionManager): conn = self._conn() cursor = conn.cursor() - cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,)) + cursor.execute( + self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,) + ) res = cursor.fetchone() if res is not None: @@ -103,8 +113,10 @@ class SqliteSessionManager(SimpleSessionManager): conn = self._conn() cursor = conn.cursor() - cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)", - (hkey, session.skey, session.name, session.path)) + cursor.execute( + "INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)", + (hkey, session.skey, session.name, session.path), + ) conn.commit() @@ -117,25 +129,35 @@ class SqliteSessionManager(SimpleSessionManager): cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,)) conn.commit() + def get_session_manager(config): if "session_db_path" in config and config["session_db_path"]: - logger.info("Found session_db_path in config, using SqliteSessionManager for auth") - return SqliteSessionManager(config['session_db_path']) + logger.info( + "Found session_db_path in config, using SqliteSessionManager for auth" + ) + return SqliteSessionManager(config["session_db_path"]) elif "session_manager" in config and config["session_manager"]: # load from config - logger.info("Found session_manager in config, using {} for persisting sessions".format( - config['session_manager']) + logger.info( + "Found session_manager in config, using {} for persisting sessions".format( + config["session_manager"] + ) ) import importlib import inspect - module_name, class_name = config['session_manager'].rsplit('.', 1) + + module_name, class_name = config["session_manager"].rsplit(".", 1) module = importlib.import_module(module_name.strip()) class_ = getattr(module, class_name.strip()) if not SimpleSessionManager in inspect.getmro(class_): - raise TypeError('''"session_manager" found in the conf file but it doesn''t - inherit from SimpleSessionManager''') + raise TypeError( + """"session_manager" found in the conf file but it doesn''t + inherit from SimpleSessionManager""" + ) return class_(config) else: - logger.warning("Neither session_db_path nor session_manager set, " - "ankisyncd will lose sessions on application restart") + logger.warning( + "Neither session_db_path nor session_manager set, " + "ankisyncd will lose sessions on application restart" + ) return SimpleSessionManager() diff --git a/src/ankisyncd/sync.py b/src/ankisyncd/sync.py index 9c45e40..5796246 100644 --- a/src/ankisyncd/sync.py +++ b/src/ankisyncd/sync.py @@ -10,7 +10,7 @@ import random import requests import json import os -from typing import List,Tuple +from typing import List, Tuple from anki.db import DB, DBError from anki.utils import ids2str, intTime, platDesc, checksum, devMode @@ -24,38 +24,43 @@ from anki.lang import ngettext # https://github.com/ankitects/anki/blob/04b1ca75599f18eb783a8bf0bdeeeb32362f4da0/rslib/src/sync/http_client.rs#L11 SYNC_VER = 10 # https://github.com/ankitects/anki/blob/cca3fcb2418880d0430a5c5c2e6b81ba260065b7/anki/consts.py#L50 -SYNC_ZIP_SIZE = int(2.5*1024*1024) +SYNC_ZIP_SIZE = int(2.5 * 1024 * 1024) # https://github.com/ankitects/anki/blob/cca3fcb2418880d0430a5c5c2e6b81ba260065b7/anki/consts.py#L51 SYNC_ZIP_COUNT = 25 # syncing vars HTTP_TIMEOUT = 90 HTTP_PROXY = None -HTTP_BUF_SIZE = 64*1024 +HTTP_BUF_SIZE = 64 * 1024 # Incremental syncing ########################################################################## + class Syncer(object): def __init__(self, col, server=None): self.col = col self.server = server -# new added functions related to Syncer: -# these are removed from latest anki module -######################################################################## + # new added functions related to Syncer: + # these are removed from latest anki module + ######################################################################## def scm(self): """return schema""" - scm=self.col.db.scalar("select scm from col") + scm = self.col.db.scalar("select scm from col") return scm + def increment_usn(self): """usn+1 in db""" self.col.db.execute("update col set usn = usn + 1") - def set_modified_time(self,now:int): + + def set_modified_time(self, now: int): self.col.db.execute("update col set mod=?", now) - def set_last_sync(self,now:int): + + def set_last_sync(self, now: int): self.col.db.execute("update col set ls = ?", now) -######################################################################### + + ######################################################################### def meta(self): return dict( mod=self.col.mod, @@ -64,30 +69,29 @@ class Syncer(object): ts=intTime(), musn=0, msg="", - cont=True + cont=True, ) def changes(self): "Bundle up small objects." - d = dict(models=self.getModels(), - decks=self.getDecks(), - tags=self.getTags()) + d = dict(models=self.getModels(), decks=self.getDecks(), tags=self.getTags()) if self.lnewer: - d['conf'] = self.col.all_config() - d['crt'] = self.col.crt + d["conf"] = self.col.all_config() + d["crt"] = self.col.crt return d def mergeChanges(self, lchg, rchg): # then the other objects - self.mergeModels(rchg['models']) - self.mergeDecks(rchg['decks']) - if 'conf' in rchg: - self.mergeConf(rchg['conf']) + self.mergeModels(rchg["models"]) + self.mergeDecks(rchg["decks"]) + if "conf" in rchg: + self.mergeConf(rchg["conf"]) # this was left out of earlier betas - if 'crt' in rchg: - self.col.crt = rchg['crt'] + if "crt" in rchg: + self.col.crt = rchg["crt"] self.prepareToChunk() -# this fn was cloned from anki module(version 2.1.36) + + # this fn was cloned from anki module(version 2.1.36) def basicCheck(self) -> bool: "Basic integrity check for syncing. True if ok." # cards without notes @@ -118,9 +122,10 @@ select id from notes where mid = ?) limit 1""" ): return False return True - + def sanityCheck(self): - tables=["cards", + tables = [ + "cards", "notes", "revlog", "graves", @@ -130,17 +135,17 @@ select id from notes where mid = ?) limit 1""" "notetypes", ] for tb in tables: - if self.col.db.scalar(f'select null from {tb} where usn=-1'): - return f'table had usn=-1: {tb}' + if self.col.db.scalar(f"select null from {tb} where usn=-1"): + return f"table had usn=-1: {tb}" self.col.sched.reset() - + # return summary of deck # make sched.counts() equal to default [0,0,0] # to make sure sync normally if sched.counts() # are not equal between different clients due to - # different deck selection + # different deck selection return [ - list([0,0,0]), + list([0, 0, 0]), self.col.db.scalar("select count() from cards"), self.col.db.scalar("select count() from notes"), self.col.db.scalar("select count() from revlog"), @@ -155,7 +160,7 @@ select id from notes where mid = ?) limit 1""" def finish(self, now=None): if now is not None: - # ensure we save the mod time even if no changes made + # ensure we save the mod time even if no changes made self.set_modified_time(now) self.set_last_sync(now) self.increment_usn() @@ -174,17 +179,29 @@ select id from notes where mid = ?) limit 1""" def queryTable(self, table): lim = self.usnLim() if table == "revlog": - return self.col.db.execute(""" + return self.col.db.execute( + """ select id, cid, ?, ease, ivl, lastIvl, factor, time, type -from revlog where %s""" % lim, self.maxUsn) +from revlog where %s""" + % lim, + self.maxUsn, + ) elif table == "cards": - return self.col.db.execute(""" + return self.col.db.execute( + """ select id, nid, did, ord, mod, ?, type, queue, due, ivl, factor, reps, -lapses, left, odue, odid, flags, data from cards where %s""" % lim, self.maxUsn) +lapses, left, odue, odid, flags, data from cards where %s""" + % lim, + self.maxUsn, + ) else: - return self.col.db.execute(""" + return self.col.db.execute( + """ select id, guid, mid, mod, ?, tags, flds, '', '', flags, data -from notes where %s""" % lim, self.maxUsn) +from notes where %s""" + % lim, + self.maxUsn, + ) def chunk(self): buf = dict(done=False) @@ -195,95 +212,96 @@ from notes where %s""" % lim, self.maxUsn) f"update {curTable} set usn=? where usn=-1", self.maxUsn ) if not self.tablesLeft: - buf['done'] = True + buf["done"] = True return buf def applyChunk(self, chunk): if "revlog" in chunk: - self.mergeRevlog(chunk['revlog']) + self.mergeRevlog(chunk["revlog"]) if "cards" in chunk: - self.mergeCards(chunk['cards']) + self.mergeCards(chunk["cards"]) if "notes" in chunk: - self.mergeNotes(chunk['notes']) + self.mergeNotes(chunk["notes"]) # Deletions ########################################################################## - def add_grave(self, ids: List[int], type: int,usn: int): - items=[(id,type,usn) for id in ids] + def add_grave(self, ids: List[int], type: int, usn: int): + items = [(id, type, usn) for id in ids] # make sure table graves fields order and schema version match # query sql1='pragma table_info(graves)' version query schema='select ver from col' self.col.db.executemany( - "INSERT OR IGNORE INTO graves (oid, type, usn) VALUES (?, ?, ?)" , - items) - - def apply_graves(self, graves,latest_usn: int): - # remove card and the card's orphaned notes - self.col.remove_cards_and_orphaned_notes(graves['cards']) - self.add_grave(graves['cards'], REM_CARD,latest_usn) + "INSERT OR IGNORE INTO graves (oid, type, usn) VALUES (?, ?, ?)", items + ) + + def apply_graves(self, graves, latest_usn: int): + # remove card and the card's orphaned notes + self.col.remove_cards_and_orphaned_notes(graves["cards"]) + self.add_grave(graves["cards"], REM_CARD, latest_usn) # only notes - self.col.remove_notes(graves['notes']) - self.add_grave(graves['notes'], REM_NOTE,latest_usn) + self.col.remove_notes(graves["notes"]) + self.add_grave(graves["notes"], REM_NOTE, latest_usn) # since level 0 deck ,we only remove deck ,but backend will delete child,it is ok, the delete # will have once effect - self.col.decks.remove(graves['decks']) - self.add_grave(graves['decks'], REM_DECK,latest_usn) + self.col.decks.remove(graves["decks"]) + self.add_grave(graves["decks"], REM_DECK, latest_usn) # Models ########################################################################## def getModels(self): - mods = [m for m in self.col.models.all() if m['usn'] == -1] + mods = [m for m in self.col.models.all() if m["usn"] == -1] for m in mods: - m['usn'] = self.maxUsn + m["usn"] = self.maxUsn self.col.models.save() return mods def mergeModels(self, rchg): for r in rchg: - l = self.col.models.get(r['id']) + l = self.col.models.get(r["id"]) # if missing locally or server is newer, update - if not l or r['mod'] > l['mod']: + if not l or r["mod"] > l["mod"]: self.col.models.update(r) # Decks ########################################################################## def getDecks(self): - decks = [g for g in self.col.decks.all() if g['usn'] == -1] + decks = [g for g in self.col.decks.all() if g["usn"] == -1] for g in decks: - g['usn'] = self.maxUsn - dconf = [g for g in self.col.decks.allConf() if g['usn'] == -1] + g["usn"] = self.maxUsn + dconf = [g for g in self.col.decks.allConf() if g["usn"] == -1] for g in dconf: - g['usn'] = self.maxUsn + g["usn"] = self.maxUsn self.col.decks.save() return [decks, dconf] def mergeDecks(self, rchg): for r in rchg[0]: - l = self.col.decks.get(r['id'], False) + l = self.col.decks.get(r["id"], False) # work around mod time being stored as string - if l and not isinstance(l['mod'], int): - l['mod'] = int(l['mod']) + if l and not isinstance(l["mod"], int): + l["mod"] = int(l["mod"]) # if missing locally or server is newer, update - if not l or r['mod'] > l['mod']: + if not l or r["mod"] > l["mod"]: self.col.decks.update(r) for r in rchg[1]: try: - l = self.col.decks.getConf(r['id']) + l = self.col.decks.getConf(r["id"]) except KeyError: l = None # if missing locally or server is newer, update - if not l or r['mod'] > l['mod']: + if not l or r["mod"] > l["mod"]: self.col.decks.updateConf(r) # Tags ########################################################################## def allItems(self) -> List[Tuple[str, int]]: - tags=self.col.db.execute("select tag, usn from tags") - return [(tag, int(usn)) for tag,usn in tags] + tags = self.col.db.execute("select tag, usn from tags") + return [(tag, int(usn)) for tag, usn in tags] + def getTags(self): tags = [] for t, usn in self.allItems(): @@ -301,15 +319,16 @@ from notes where %s""" % lim, self.maxUsn) def mergeRevlog(self, logs): self.col.db.executemany( - "insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)", - logs) + "insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)", logs + ) def newerRows(self, data, table, modIdx): ids = (r[0] for r in data) lmods = {} for id, mod in self.col.db.execute( - "select id, mod from %s where id in %s and %s" % ( - table, ids2str(ids), self.usnLim())): + "select id, mod from %s where id in %s and %s" + % (table, ids2str(ids), self.usnLim()) + ): lmods[id] = mod update = [] for r in data: @@ -323,14 +342,17 @@ from notes where %s""" % lim, self.maxUsn) self.col.db.executemany( "insert or replace into cards values " "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - self.newerRows(cards, "cards", 4)) + self.newerRows(cards, "cards", 4), + ) def mergeNotes(self, notes): rows = self.newerRows(notes, "notes", 3) self.col.db.executemany( - "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", - rows) - self.col.after_note_updates([f[0] for f in rows], mark_modified=False, generate_cards=False) + "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", rows + ) + self.col.after_note_updates( + [f[0] for f in rows], mark_modified=False, generate_cards=False + ) # Col config ########################################################################## @@ -341,11 +363,14 @@ from notes where %s""" % lim, self.maxUsn) def mergeConf(self, conf): for key, value in conf.items(): self.col.set_config(key, value) + + # self.col.backend.set_all_config(json.dumps(conf).encode()) # Wrapper for requests that tracks upload/download progress ########################################################################## + class AnkiRequestsClient(object): verify = True timeout = 60 @@ -355,15 +380,23 @@ class AnkiRequestsClient(object): def post(self, url, data, headers): data = _MonitoringFile(data) - headers['User-Agent'] = self._agentName() + headers["User-Agent"] = self._agentName() return self.session.post( - url, data=data, headers=headers, stream=True, timeout=self.timeout, verify=self.verify) + url, + data=data, + headers=headers, + stream=True, + timeout=self.timeout, + verify=self.verify, + ) def get(self, url, headers=None): if headers is None: headers = {} - headers['User-Agent'] = self._agentName() - return self.session.get(url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify) + headers["User-Agent"] = self._agentName() + return self.session.get( + url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify + ) def streamContent(self, resp): resp.raise_for_status() @@ -375,24 +408,30 @@ class AnkiRequestsClient(object): def _agentName(self): from anki import version + return "Anki {}".format(version) + # allow user to accept invalid certs in work/school settings if os.environ.get("ANKI_NOVERIFYSSL"): AnkiRequestsClient.verify = False import warnings + warnings.filterwarnings("ignore") + class _MonitoringFile(io.BufferedReader): def read(self, size=-1): data = io.BufferedReader.read(self, HTTP_BUF_SIZE) return data + # HTTP syncing tools ########################################################################## + class HttpSyncer(object): def __init__(self, hkey=None, client=None, hostNum=None): self.hkey = hkey @@ -421,24 +460,29 @@ class HttpSyncer(object): # support file uploading, so this is the more compatible choice. def _buildPostData(self, fobj, comp): - BOUNDARY=b"Anki-sync-boundary" - bdry = b"--"+BOUNDARY + BOUNDARY = b"Anki-sync-boundary" + bdry = b"--" + BOUNDARY buf = io.BytesIO() # post vars - self.postVars['c'] = 1 if comp else 0 + self.postVars["c"] = 1 if comp else 0 for (key, value) in list(self.postVars.items()): buf.write(bdry + b"\r\n") buf.write( - ('Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n' % - (key, value)).encode("utf8")) + ( + 'Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n' + % (key, value) + ).encode("utf8") + ) # payload as raw data or json rawSize = 0 if fobj: # header buf.write(bdry + b"\r\n") - buf.write(b"""\ + buf.write( + b"""\ Content-Disposition: form-data; name="data"; filename="data"\r\n\ -Content-Type: application/octet-stream\r\n\r\n""") +Content-Type: application/octet-stream\r\n\r\n""" + ) # write file into buffer, optionally compressing if comp: tgt = gzip.GzipFile(mode="wb", fileobj=buf, compresslevel=comp) @@ -453,16 +497,17 @@ Content-Type: application/octet-stream\r\n\r\n""") rawSize += len(data) tgt.write(data) buf.write(b"\r\n") - buf.write(bdry + b'--\r\n') + buf.write(bdry + b"--\r\n") size = buf.tell() # connection headers headers = { - 'Content-Type': 'multipart/form-data; boundary=%s' % BOUNDARY.decode("utf8"), - 'Content-Length': str(size), + "Content-Type": "multipart/form-data; boundary=%s" + % BOUNDARY.decode("utf8"), + "Content-Length": str(size), } buf.seek(0) - if size >= 100*1024*1024 or rawSize >= 250*1024*1024: + if size >= 100 * 1024 * 1024 or rawSize >= 250 * 1024 * 1024: raise Exception("Collection too large to upload to AnkiWeb.") return headers, buf @@ -470,7 +515,7 @@ Content-Type: application/octet-stream\r\n\r\n""") def req(self, method, fobj=None, comp=6, badAuthRaises=True): headers, body = self._buildPostData(fobj, comp) - r = self.client.post(self.syncURL()+method, data=body, headers=headers) + r = self.client.post(self.syncURL() + method, data=body, headers=headers) if not badAuthRaises and r.status_code == 403: return False self.assertOk(r) @@ -478,9 +523,11 @@ Content-Type: application/octet-stream\r\n\r\n""") buf = self.client.streamContent(r) return buf + # Incremental sync over HTTP ###################################################################### + class RemoteServer(HttpSyncer): def __init__(self, hkey, hostNum): super().__init__(self, hkey, hostNum=hostNum) @@ -489,12 +536,14 @@ class RemoteServer(HttpSyncer): "Returns hkey or none if user/pw incorrect." self.postVars = dict() ret = self.req( - "hostKey", io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")), - badAuthRaises=False) + "hostKey", + io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")), + badAuthRaises=False, + ) if not ret: # invalid auth return - self.hkey = json.loads(ret.decode("utf8"))['key'] + self.hkey = json.loads(ret.decode("utf8"))["key"] return self.hkey def meta(self): @@ -503,9 +552,17 @@ class RemoteServer(HttpSyncer): s=self.skey, ) ret = self.req( - "meta", io.BytesIO(json.dumps(dict( - v=SYNC_VER, cv="ankidesktop,%s,%s"%(versionWithBuild(), platDesc()))).encode("utf8")), - badAuthRaises=False) + "meta", + io.BytesIO( + json.dumps( + dict( + v=SYNC_VER, + cv="ankidesktop,%s,%s" % (versionWithBuild(), platDesc()), + ) + ).encode("utf8") + ), + badAuthRaises=False, + ) if not ret: # invalid auth return @@ -537,17 +594,20 @@ class RemoteServer(HttpSyncer): def _run(self, cmd, data): return json.loads( - self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8")) + self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8") + ) + # Full syncing ########################################################################## + class FullSyncer(HttpSyncer): def __init__(self, col, hkey, client, hostNum): super().__init__(self, hkey, client, hostNum=hostNum) self.postVars = dict( k=self.hkey, - v="ankidesktop,%s,%s"%(anki.version, platDesc()), + v="ankidesktop,%s,%s" % (anki.version, platDesc()), ) self.col = col @@ -586,9 +646,11 @@ class FullSyncer(HttpSyncer): return False return True + # Remote media syncing ########################################################################## + class RemoteMediaServer(HttpSyncer): def __init__(self, col, hkey, client, hostNum): self.col = col @@ -597,12 +659,12 @@ class RemoteMediaServer(HttpSyncer): def begin(self): self.postVars = dict( - k=self.hkey, - v="ankidesktop,%s,%s"%(anki.version, platDesc()) + k=self.hkey, v="ankidesktop,%s,%s" % (anki.version, platDesc()) ) - ret = self._dataOnly(self.req( - "begin", io.BytesIO(json.dumps(dict()).encode("utf8")))) - self.skey = ret['sk'] + ret = self._dataOnly( + self.req("begin", io.BytesIO(json.dumps(dict()).encode("utf8"))) + ) + self.skey = ret["sk"] return ret # args: lastUsn @@ -611,7 +673,8 @@ class RemoteMediaServer(HttpSyncer): sk=self.skey, ) return self._dataOnly( - self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8")))) + self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8"))) + ) # args: files def downloadFiles(self, **kw): @@ -619,20 +682,20 @@ class RemoteMediaServer(HttpSyncer): def uploadChanges(self, zip): # no compression, as we compress the zip file instead - return self._dataOnly( - self.req("uploadChanges", io.BytesIO(zip), comp=0)) + return self._dataOnly(self.req("uploadChanges", io.BytesIO(zip), comp=0)) # args: local def mediaSanity(self, **kw): return self._dataOnly( - self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8")))) + self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8"))) + ) def _dataOnly(self, resp): resp = json.loads(resp.decode("utf8")) - if resp['err']: - self.col.log("error returned:%s"%resp['err']) - raise Exception("SyncError:%s"%resp['err']) - return resp['data'] + if resp["err"]: + self.col.log("error returned:%s" % resp["err"]) + raise Exception("SyncError:%s" % resp["err"]) + return resp["data"] # only for unit tests def mediatest(self, cmd): @@ -640,5 +703,7 @@ class RemoteMediaServer(HttpSyncer): k=self.hkey, ) return self._dataOnly( - self.req("newMediaTest", io.BytesIO( - json.dumps(dict(cmd=cmd)).encode("utf8")))) + self.req( + "newMediaTest", io.BytesIO(json.dumps(dict(cmd=cmd)).encode("utf8")) + ) + ) diff --git a/src/ankisyncd/sync_app.py b/src/ankisyncd/sync_app.py index bcf8652..8d01f0e 100644 --- a/src/ankisyncd/sync_app.py +++ b/src/ankisyncd/sync_app.py @@ -43,7 +43,16 @@ logger = logging.getLogger("ankisyncd") class SyncCollectionHandler(Syncer): - operations = ['meta', 'applyChanges', 'start', 'applyGraves', 'chunk', 'applyChunk', 'sanityCheck2', 'finish'] + operations = [ + "meta", + "applyChanges", + "start", + "applyGraves", + "chunk", + "applyChunk", + "sanityCheck2", + "finish", + ] def __init__(self, col, session): # So that 'server' (the 3rd argument) can't get set @@ -56,9 +65,9 @@ class SyncCollectionHandler(Syncer): return False note = {"alpha": 0, "beta": 0, "rc": 0} - client, version, platform = cv.split(',') + client, version, platform = cv.split(",") - if 'arch' not in version: + if "arch" not in version: for name in note.keys(): if name in version: vs = version.split(name) @@ -66,17 +75,17 @@ class SyncCollectionHandler(Syncer): note[name] = int(vs[-1]) # convert the version string, ignoring non-numeric suffixes like in beta versions of Anki - version_nosuffix = re.sub(r'[^0-9.].*$', '', version) - version_int = [int(x) for x in version_nosuffix.split('.')] + version_nosuffix = re.sub(r"[^0-9.].*$", "", version) + version_int = [int(x) for x in version_nosuffix.split(".")] - if client == 'ankidesktop': + if client == "ankidesktop": return version_int < [2, 0, 27] - elif client == 'ankidroid': + elif client == "ankidroid": if version_int == [2, 3]: - if note["alpha"]: - return note["alpha"] < 4 + if note["alpha"]: + return note["alpha"] < 4 else: - return version_int < [2, 2, 3] + return version_int < [2, 2, 3] else: # unknown client, assume current version return False @@ -84,23 +93,33 @@ class SyncCollectionHandler(Syncer): if self._old_client(cv): return Response(status=501) # client needs upgrade if v > SYNC_VER: - return {"cont": False, "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(v, SYNC_VER)} + return { + "cont": False, + "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format( + v, SYNC_VER + ), + } if v < 9 and self.col.schedVer() >= 2: - return {"cont": False, "msg": "Your client doesn't support the v{} scheduler.".format(self.col.schedVer())} + return { + "cont": False, + "msg": "Your client doesn't support the v{} scheduler.".format( + self.col.schedVer() + ), + } # Make sure the media database is open! self.col.media.connect() return { - 'mod': self.col.mod, - 'scm': self.scm(), - 'usn': self.col.usn(), - 'ts': anki.utils.intTime(), - 'musn': self.col.media.lastUsn(), - 'uname': self.session.name, - 'msg': '', - 'cont': True, - 'hostNum': 0, + "mod": self.col.mod, + "scm": self.scm(), + "usn": self.col.usn(), + "ts": anki.utils.intTime(), + "musn": self.col.media.lastUsn(), + "uname": self.session.name, + "msg": "", + "cont": True, + "hostNum": 0, } def usnLim(self): @@ -108,10 +127,16 @@ class SyncCollectionHandler(Syncer): # ankidesktop >=2.1rc2 sends graves in applyGraves, but still expects # server-side deletions to be returned by start - def start(self, minUsn, lnewer, graves={"cards": [], "notes": [], "decks": []}, offset=None): + def start( + self, + minUsn, + lnewer, + graves={"cards": [], "notes": [], "decks": []}, + offset=None, + ): # The offset para is passed by client V2 scheduler,which is minutes_west. - # Since now have not thorougly test the V2 scheduler, we leave this comments here, and - # just enable the V2 scheduler in the serve code. + # Since now have not thorougly test the V2 scheduler, we leave this comments here, and + # just enable the V2 scheduler in the serve code. self.maxUsn = self.col.usn() self.minUsn = minUsn @@ -122,11 +147,11 @@ class SyncCollectionHandler(Syncer): # Only if Operations like deleting deck are performed on Ankidroid # can (client) graves is not None if graves is not None: - self.apply_graves(graves,self.maxUsn) + self.apply_graves(graves, self.maxUsn) return lgraves def applyGraves(self, chunk): - self.apply_graves(chunk,self.maxUsn) + self.apply_graves(chunk, self.maxUsn) def applyChanges(self, changes): self.rchg = changes @@ -136,16 +161,13 @@ class SyncCollectionHandler(Syncer): return lchg def sanityCheck2(self, client): - client[0]=[0,0,0] + client[0] = [0, 0, 0] server = self.sanityCheck() if client != server: - logger.info( - f"sanity check failed with server: {server} client: {client}" - ) + logger.info(f"sanity check failed with server: {server} client: {client}") return dict(status="bad", c=client, s=server) return dict(status="ok") - def finish(self): return super().finish(anki.utils.intTime(1000)) @@ -158,7 +180,8 @@ class SyncCollectionHandler(Syncer): decks = [] curs = self.col.db.execute( - "select oid, type from graves where usn >= ?", self.minUsn) + "select oid, type from graves where usn >= ?", self.minUsn + ) for oid, type in curs: if type == REM_CARD: @@ -171,20 +194,26 @@ class SyncCollectionHandler(Syncer): return dict(cards=cards, notes=notes, decks=decks) def getModels(self): - return [m for m in self.col.models.all() if m['usn'] >= self.minUsn] + return [m for m in self.col.models.all() if m["usn"] >= self.minUsn] def getDecks(self): return [ - [g for g in self.col.decks.all() if g['usn'] >= self.minUsn], - [g for g in self.col.decks.all_config() if g['usn'] >= self.minUsn] + [g for g in self.col.decks.all() if g["usn"] >= self.minUsn], + [g for g in self.col.decks.all_config() if g["usn"] >= self.minUsn], ] def getTags(self): - return [t for t, usn in self.allItems() - if usn >= self.minUsn] + return [t for t, usn in self.allItems() if usn >= self.minUsn] + class SyncMediaHandler: - operations = ['begin', 'mediaChanges', 'mediaSanity', 'uploadChanges', 'downloadFiles'] + operations = [ + "begin", + "mediaChanges", + "mediaSanity", + "uploadChanges", + "downloadFiles", + ] def __init__(self, col, session): self.col = col @@ -192,11 +221,11 @@ class SyncMediaHandler: def begin(self, skey): return { - 'data': { - 'sk': skey, - 'usn': self.col.media.lastUsn(), + "data": { + "sk": skey, + "usn": self.col.media.lastUsn(), }, - 'err': '', + "err": "", } def uploadChanges(self, data): @@ -210,24 +239,27 @@ class SyncMediaHandler: processed_count = self._adopt_media_changes_from_zip(z) return { - 'data': [processed_count, self.col.media.lastUsn()], - 'err': '', + "data": [processed_count, self.col.media.lastUsn()], + "err": "", } @staticmethod def _check_zip_data(zip_file): - max_zip_size = 100*1024*1024 + max_zip_size = 100 * 1024 * 1024 max_meta_file_size = 100000 meta_file_size = zip_file.getinfo("_meta").file_size sum_file_sizes = sum(info.file_size for info in zip_file.infolist()) if meta_file_size > max_meta_file_size: - raise ValueError("Zip file's metadata file is larger than %s " - "Bytes." % max_meta_file_size) + raise ValueError( + "Zip file's metadata file is larger than %s " + "Bytes." % max_meta_file_size + ) elif sum_file_sizes > max_zip_size: - raise ValueError("Zip file contents are larger than %s Bytes." % - max_zip_size) + raise ValueError( + "Zip file contents are larger than %s Bytes." % max_zip_size + ) def _adopt_media_changes_from_zip(self, zip_file): """ @@ -261,7 +293,7 @@ class SyncMediaHandler: file_path = os.path.join(media_dir, filename) # Save file to media directory. - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: f.write(file_data) usn += 1 @@ -279,7 +311,9 @@ class SyncMediaHandler: if media_to_add: self.col.media.addMedia(media_to_add) - assert self.col.media.lastUsn() == oldUsn + processed_count # TODO: move to some unit test + assert ( + self.col.media.lastUsn() == oldUsn + processed_count + ) # TODO: move to some unit test return processed_count @staticmethod @@ -302,13 +336,15 @@ class SyncMediaHandler: Marks all files in list filenames as deleted and removes them from the media directory. """ - logger.debug('Removing %d files from media dir.' % len(filenames)) + logger.debug("Removing %d files from media dir." % len(filenames)) for filename in filenames: try: self.col.media.syncDelete(filename) except OSError as err: - logger.error("Error when removing file '%s' from media dir: " - "%s" % (filename, str(err))) + logger.error( + "Error when removing file '%s' from media dir: " + "%s" % (filename, str(err)) + ) def downloadFiles(self, files): flist = {} @@ -334,13 +370,17 @@ class SyncMediaHandler: server_lastUsn = self.col.media.lastUsn() if lastUsn < server_lastUsn or lastUsn == 0: - for fname,usn,csum, in self.col.media.changes(lastUsn): + for ( + fname, + usn, + csum, + ) in self.col.media.changes(lastUsn): result.append([fname, usn, csum]) # anki assumes server_lastUsn == result[-1][1] # ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7) result.reverse() - return {'data': result, 'err': ''} + return {"data": result, "err": ""} def mediaSanity(self, local=None): if self.col.media.mediaCount() == local: @@ -348,7 +388,8 @@ class SyncMediaHandler: else: result = "FAILED" - return {'data': result, 'err': ''} + return {"data": result, "err": ""} + class SyncUserSession: def __init__(self, name, path, collection_manager, setup_new_collection=None): @@ -371,16 +412,18 @@ class SyncUserSession: return anki.utils.checksum(str(random.random()))[:8] def get_collection_path(self): - return os.path.realpath(os.path.join(self.path, 'collection.anki2')) + return os.path.realpath(os.path.join(self.path, "collection.anki2")) def get_thread(self): - return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection) + return self.collection_manager.get_collection( + self.get_collection_path(), self.setup_new_collection + ) def get_handler_for_operation(self, operation, col): if operation in SyncCollectionHandler.operations: - attr, handler_class = 'collection_handler', SyncCollectionHandler + attr, handler_class = "collection_handler", SyncCollectionHandler elif operation in SyncMediaHandler.operations: - attr, handler_class = 'media_handler', SyncMediaHandler + attr, handler_class = "media_handler", SyncMediaHandler else: raise Exception("no handler for {}".format(operation)) @@ -391,146 +434,178 @@ class SyncUserSession: # for inactivity and then later re-open it (creating a new Collection object). handler.col = col return handler + + class Requests(object): - def __init__(self,environ: dict): - self.environ=environ + def __init__(self, environ: dict): + self.environ = environ + @property def params(self): return self.request_items_dict + @params.setter - def params(self,value): + def params(self, value): """ A dictionary-like object containing both the parameters from the query string and request body. """ - self.request_items_dict= value + self.request_items_dict = value + @property - def path(self)-> str: - return self.environ['PATH_INFO'] + def path(self) -> str: + return self.environ["PATH_INFO"] + @property def POST(self): return self._request_items_dict + @POST.setter - def POST(self,value): - self._request_items_dict=value + def POST(self, value): + self._request_items_dict = value + @property def parse(self): - '''Return a MultiDict containing all the variables from a form + """Return a MultiDict containing all the variables from a form request.\n - include not only post req,but also get''' + include not only post req,but also get""" env = self.environ - query_string=env['QUERY_STRING'] - content_len= env.get('CONTENT_LENGTH', '0') - input = env.get('wsgi.input') - length = 0 if content_len == '' else int(content_len) - body=b'' - request_items_dict={} + query_string = env["QUERY_STRING"] + content_len = env.get("CONTENT_LENGTH", "0") + input = env.get("wsgi.input") + length = 0 if content_len == "" else int(content_len) + body = b"" + request_items_dict = {} if length == 0: if input is None: return request_items_dict - if env.get('HTTP_TRANSFER_ENCODING','0') == 'chunked': + if env.get("HTTP_TRANSFER_ENCODING", "0") == "chunked": # readlines and read(no argument) will block # convert byte str to number base 16 - leng=int(input.readline(),16) - c=0 - bdry=b'' - data=[] - data_other=[] - while leng >0: - c+=1 - dt = input.read(leng+2) - if c==1: - bdry=dt - elif c>=3: + leng = int(input.readline(), 16) + c = 0 + bdry = b"" + data = [] + data_other = [] + while leng > 0: + c += 1 + dt = input.read(leng + 2) + if c == 1: + bdry = dt + elif c >= 3: # data data_other.append(dt) - leng = int(input.readline(),16) - data_other=[item for item in data_other if item!=b'\r\n\r\n'] + leng = int(input.readline(), 16) + data_other = [item for item in data_other if item != b"\r\n\r\n"] for item in data_other: if bdry in item: break # only strip \r\n if there are extra \n # eg b'?V\xc1\x8f>\xf9\xb1\n\r\n' data.append(item[:-2]) - request_items_dict['data']=b''.join(data) - others=data_other[len(data):] - boundary=others[0] - others=b''.join(others).split(boundary.strip()) + request_items_dict["data"] = b"".join(data) + others = data_other[len(data) :] + boundary = others[0] + others = b"".join(others).split(boundary.strip()) others.pop() others.pop(0) for i in others: - i=i.splitlines() - key=re.findall(b'name="(.*?)"',i[2],flags=re.M)[0].decode('utf-8') - v=i[-1].decode('utf-8') - request_items_dict[key]=v + i = i.splitlines() + key = re.findall(b'name="(.*?)"', i[2], flags=re.M)[0].decode( + "utf-8" + ) + v = i[-1].decode("utf-8") + request_items_dict[key] = v return request_items_dict - - if query_string !='': + + if query_string != "": # GET method - body=query_string - request_items_dict=urllib.parse.parse_qs(body) - for k,v in request_items_dict.items(): - request_items_dict[k]=''.join(v) + body = query_string + request_items_dict = urllib.parse.parse_qs(body) + for k, v in request_items_dict.items(): + request_items_dict[k] = "".join(v) return request_items_dict - + else: - body = env['wsgi.input'].read(length) - - if body is None or body ==b'': + body = env["wsgi.input"].read(length) + + if body is None or body == b"": return request_items_dict # process body to dict - repeat=body.splitlines()[0] - items=re.split(repeat,body) + repeat = body.splitlines()[0] + items = re.split(repeat, body) # del first ,last item items.pop() items.pop(0) for item in items: if b'name="data"' in item: - data_field=None - # remove \r\n - if b'application/octet-stream' in item: + data_field = None + # remove \r\n + if b"application/octet-stream" in item: # Ankidroid case - item=re.sub(b'Content-Disposition: form-data; name="data"; filename="data"',b'',item) - item=re.sub(b'Content-Type: application/octet-stream',b'',item) - data_field=item.strip() + item = re.sub( + b'Content-Disposition: form-data; name="data"; filename="data"', + b"", + item, + ) + item = re.sub(b"Content-Type: application/octet-stream", b"", item) + data_field = item.strip() else: # PKzip file stream and others - item=re.sub(b'Content-Disposition: form-data; name="data"; filename="data"',b'',item) - data_field=item.strip() - request_items_dict['data']=data_field + item = re.sub( + b'Content-Disposition: form-data; name="data"; filename="data"', + b"", + item, + ) + data_field = item.strip() + request_items_dict["data"] = data_field continue - item=re.sub(b'\r\n',b'',item,flags=re.MULTILINE) - key=re.findall(b'name="(.*?)"',item)[0].decode('utf-8') - v=item[item.rfind(b'"')+1:].decode('utf-8') - request_items_dict[key]=v + item = re.sub(b"\r\n", b"", item, flags=re.MULTILINE) + key = re.findall(b'name="(.*?)"', item)[0].decode("utf-8") + v = item[item.rfind(b'"') + 1 :].decode("utf-8") + request_items_dict[key] = v return request_items_dict + + class chunked(object): - '''decorator''' + """decorator""" + def __init__(self, func): wraps(func)(self) + def __call__(self, *args, **kwargs): - clss=args[0] - environ=args[1] + clss = args[0] + environ = args[1] start_response = args[2] - b=Requests(environ) - args=(clss,b,) - w= self.__wrapped__(*args, **kwargs) - resp=Response(w) + b = Requests(environ) + args = ( + clss, + b, + ) + w = self.__wrapped__(*args, **kwargs) + resp = Response(w) return resp(environ, start_response) + def __get__(self, instance, cls): if instance is None: return self else: return types.MethodType(self, instance) + + class SyncApp: - valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download'] + valid_urls = ( + SyncCollectionHandler.operations + + SyncMediaHandler.operations + + ["hostKey", "upload", "download"] + ) def __init__(self, config): from ankisyncd.thread import get_collection_manager - self.data_root = os.path.abspath(config['data_root']) - self.base_url = config['base_url'] - self.base_media_url = config['base_media_url'] + self.data_root = os.path.abspath(config["data_root"]) + self.base_url = config["base_url"] + self.base_media_url = config["base_media_url"] self.setup_new_collection = None self.user_manager = get_user_manager(config) @@ -539,22 +614,31 @@ class SyncApp: self.collection_manager = get_collection_manager(config) # make sure the base_url has a trailing slash - if not self.base_url.endswith('/'): - self.base_url += '/' - if not self.base_media_url.endswith('/'): - self.base_media_url += '/' + if not self.base_url.endswith("/"): + self.base_url += "/" + if not self.base_media_url.endswith("/"): + self.base_media_url += "/" def generateHostKey(self, username): """Generates a new host key to be used by the given username to identify their session. This values is random.""" import hashlib, time, random, string + chars = string.ascii_letters + string.digits - val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]).encode() + val = ":".join( + [ + username, + str(int(time.time())), + "".join(random.choice(chars) for x in range(8)), + ] + ).encode() return hashlib.md5(val).hexdigest() def create_session(self, username, user_path): - return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection) + return SyncUserSession( + username, user_path, self.collection_manager, self.setup_new_collection + ) def _decode_data(self, data, compression=0): if compression: @@ -564,7 +648,7 @@ class SyncApp: try: data = json.loads(data.decode()) except (ValueError, UnicodeDecodeError): - data = {'data': data} + data = {"data": data} return data @@ -581,7 +665,7 @@ class SyncApp: session = self.create_session(username, user_path) self.session_manager.save(hkey, session) - return {'key': hkey} + return {"key": hkey} def operation_upload(self, col, data, session): # Verify integrity of the received database file before replacing our @@ -599,56 +683,56 @@ class SyncApp: # cgi file can only be read once,and will be blocked after being read once more # so i call Requests.parse only once,and bind its return result to properties # POST and params (set return result as property values) - req.params=req.parse - req.POST=req.params + req.params = req.parse + req.POST = req.params try: - hkey = req.params['k'] + hkey = req.params["k"] except KeyError: hkey = None session = self.session_manager.load(hkey, self.create_session) if session is None: try: - skey = req.POST['sk'] + skey = req.POST["sk"] session = self.session_manager.load_from_skey(skey, self.create_session) except KeyError: skey = None try: - compression = int(req.POST['c']) + compression = int(req.POST["c"]) except KeyError: compression = 0 try: - data = req.POST['data'] + data = req.POST["data"] data = self._decode_data(data, compression) except KeyError: data = {} if req.path.startswith(self.base_url): - url = req.path[len(self.base_url):] + url = req.path[len(self.base_url) :] if url not in self.valid_urls: raise HTTPNotFound() - if url == 'hostKey': + if url == "hostKey": result = self.operation_hostKey(data.get("u"), data.get("p")) if result: return json.dumps(result) else: # TODO: do I have to pass 'null' for the client to receive None? - raise HTTPForbidden('null') + raise HTTPForbidden("null") if session is None: raise HTTPForbidden() if url in SyncCollectionHandler.operations + SyncMediaHandler.operations: # 'meta' passes the SYNC_VER but it isn't used in the handler - if url == 'meta': - if session.skey == None and 's' in req.POST: - session.skey = req.POST['s'] - if 'v' in data: - session.version = data['v'] - if 'cv' in data: - session.client_version = data['cv'] + if url == "meta": + if session.skey == None and "s" in req.POST: + session.skey = req.POST["s"] + if "v" in data: + session.version = data["v"] + if "cv" in data: + session.client_version = data["cv"] self.session_manager.save(hkey, session) session = self.session_manager.load(hkey, self.create_session) @@ -660,12 +744,12 @@ class SyncApp: return result - elif url == 'upload': + elif url == "upload": thread = session.get_thread() - result = thread.execute(self.operation_upload, [data['data'], session]) + result = thread.execute(self.operation_upload, [data["data"], session]) return result - elif url == 'download': + elif url == "download": thread = session.get_thread() result = thread.execute(self.operation_download, [session]) return result @@ -678,13 +762,13 @@ class SyncApp: if session is None: raise HTTPForbidden() - url = req.path[len(self.base_media_url):] + url = req.path[len(self.base_media_url) :] if url not in self.valid_urls: raise HTTPNotFound() if url == "begin": - data['skey'] = session.skey + data["skey"] = session.skey result = self._execute_handler_method_in_thread(url, data, session) @@ -726,10 +810,16 @@ class SyncApp: def make_app(global_conf, **local_conf): return SyncApp(**local_conf) + def main(): - logging.basicConfig(level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s") + logging.basicConfig( + level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s" + ) import ankisyncd - logger.info("ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage)) + + logger.info( + "ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage) + ) from wsgiref.simple_server import make_server, WSGIRequestHandler from ankisyncd.thread import shutdown import ankisyncd.config @@ -738,10 +828,10 @@ def main(): logger = logging.getLogger("ankisyncd.http") def log_error(self, format, *args): - self.logger.error("%s %s", self.address_string(), format%args) + self.logger.error("%s %s", self.address_string(), format % args) def log_message(self, format, *args): - self.logger.info("%s %s", self.address_string(), format%args) + self.logger.info("%s %s", self.address_string(), format % args) if len(sys.argv) > 1: # backwards compat @@ -750,7 +840,9 @@ def main(): config = ankisyncd.config.load() ankiserver = SyncApp(config) - httpd = make_server(config['host'], int(config['port']), ankiserver, handler_class=RequestHandler) + httpd = make_server( + config["host"], int(config["port"]), ankiserver, handler_class=RequestHandler + ) try: logger.info("Serving HTTP on {} port {}...".format(*httpd.server_address)) @@ -760,5 +852,6 @@ def main(): finally: shutdown() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/src/ankisyncd/thread.py b/src/ankisyncd/thread.py index e0173cf..216fa80 100644 --- a/src/ankisyncd/thread.py +++ b/src/ankisyncd/thread.py @@ -5,6 +5,7 @@ from queue import Queue import time, logging + def short_repr(obj, logger=logging.getLogger(), maxlen=80): """Like repr, but shortens strings and bytestrings if logger's logging level is above DEBUG. Currently shallow and very limited, only implemented for @@ -28,6 +29,7 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80): return repr(o) + class ThreadingCollectionWrapper: """Provides the same interface as CollectionWrapper, but it creates a new Thread to interact with the collection.""" @@ -56,10 +58,11 @@ class ThreadingCollectionWrapper: def current(self): from threading import current_thread + return current_thread() == self._thread def execute(self, func, args=[], kw={}, waitForReturn=True): - """ Executes a given function on this thread with the *args and **kw. + """Executes a given function on this thread with the *args and **kw. If 'waitForReturn' is True, then it will block until the function has executed and return its return value. If False, it will return None @@ -86,19 +89,30 @@ class ThreadingCollectionWrapper: while self._running: func, args, kw, return_queue = self._queue.get(True) - if hasattr(func, '__name__'): + if hasattr(func, "__name__"): func_name = func.__name__ else: func_name = func.__class__.__name__ - self.logger.info("Running %s(*%s, **%s)", func_name, short_repr(args, self.logger), short_repr(kw, self.logger)) + self.logger.info( + "Running %s(*%s, **%s)", + func_name, + short_repr(args, self.logger), + short_repr(kw, self.logger), + ) self.last_timestamp = time.time() try: ret = self.wrapper.execute(func, args, kw, return_queue) except Exception as e: - self.logger.error("Unable to %s(*%s, **%s): %s", - func_name, repr(args), repr(kw), e, exc_info=True) + self.logger.error( + "Unable to %s(*%s, **%s): %s", + func_name, + repr(args), + repr(kw), + e, + exc_info=True, + ) # we return the Exception which will be raise'd on the other end ret = e @@ -125,10 +139,11 @@ class ThreadingCollectionWrapper: def stop(self): def _stop(col): self._running = False + self.execute(_stop, waitForReturn=False) def stop_and_wait(self): - """ Tell the thread to stop and wait for it to happen. """ + """Tell the thread to stop and wait for it to happen.""" self.stop() if self._thread is not None: self._thread.join() @@ -146,11 +161,13 @@ class ThreadingCollectionWrapper: def _close(col): self.wrapper.close() + self.execute(_close, waitForReturn=False) def opened(self): return self.wrapper.opened() + class ThreadingCollectionManager(CollectionManager): """Manages a set of ThreadingCollectionWrapper objects.""" @@ -174,14 +191,21 @@ class ThreadingCollectionManager(CollectionManager): # TODO: it would be awesome to have a safe way to stop inactive threads completely! # TODO: we need a way to inform other code that the collection has been closed def _monitor_run(self): - """ Monitors threads for inactivity and closes the collection on them + """Monitors threads for inactivity and closes the collection on them (leaves the thread itself running -- hopefully waiting peacefully with only a - small memory footprint!) """ + small memory footprint!)""" while True: cur = time.time() for path, thread in self.collections.items(): - if thread.running and thread.wrapper.opened() and thread.qempty() and cur - thread.last_timestamp >= self.monitor_inactivity: - self.logger.info("Monitor is closing collection on inactive %s", thread) + if ( + thread.running + and thread.wrapper.opened() + and thread.qempty() + and cur - thread.last_timestamp >= self.monitor_inactivity + ): + self.logger.info( + "Monitor is closing collection on inactive %s", thread + ) thread.close() time.sleep(self.monitor_frequency) @@ -196,12 +220,14 @@ class ThreadingCollectionManager(CollectionManager): # let the parent do whatever else it might want to do... super(ThreadingCollectionManager, self).shutdown() + # # For working with the global ThreadingCollectionManager: # collection_manager = None + def get_collection_manager(config): """Return the global ThreadingCollectionManager for this process.""" global collection_manager @@ -209,10 +235,10 @@ def get_collection_manager(config): collection_manager = ThreadingCollectionManager(config) return collection_manager + def shutdown(): """If the global ThreadingCollectionManager exists, shut it down.""" global collection_manager if collection_manager is not None: collection_manager.shutdown() collection_manager = None - diff --git a/src/ankisyncd/users.py b/src/ankisyncd/users.py index 01736ba..ab1fb62 100644 --- a/src/ankisyncd/users.py +++ b/src/ankisyncd/users.py @@ -11,7 +11,7 @@ logger = logging.getLogger("ankisyncd.users") class SimpleUserManager: """A simple user manager that always allows any user.""" - def __init__(self, collection_path=''): + def __init__(self, collection_path=""): self.collection_path = collection_path def authenticate(self, username, password): @@ -34,8 +34,11 @@ class SimpleUserManager: def _create_user_dir(self, username): user_dir_path = os.path.join(self.collection_path, username) if not os.path.isdir(user_dir_path): - logger.info("Creating collection directory for user '{}' at {}" - .format(username, user_dir_path)) + logger.info( + "Creating collection directory for user '{}' at {}".format( + username, user_dir_path + ) + ) os.makedirs(user_dir_path) @@ -54,13 +57,17 @@ class SqliteUserManager(SimpleUserManager): conn = self._conn() cursor = conn.cursor() - cursor.execute("SELECT * FROM sqlite_master " - "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' " - "AND tbl_name = 'auth'") + cursor.execute( + "SELECT * FROM sqlite_master " + "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' " + "AND tbl_name = 'auth'" + ) res = cursor.fetchone() conn.close() if res is not None: - raise Exception("Outdated database schema, run utils/migrate_user_tables.py") + raise Exception( + "Outdated database schema, run utils/migrate_user_tables.py" + ) # Default to using sqlite3 but overridable for sub-classes using other # DB API 2 driver variants @@ -124,8 +131,7 @@ class SqliteUserManager(SimpleUserManager): conn = self._conn() cursor = conn.cursor() logger.info("Adding user '{}' to auth db.".format(username)) - cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"), - (username, pass_hash)) + cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"), (username, pass_hash)) conn.commit() conn.close() @@ -139,7 +145,9 @@ class SqliteUserManager(SimpleUserManager): conn = self._conn() cursor = conn.cursor() - cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username)) + cursor.execute( + self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username) + ) conn.commit() conn.close() @@ -156,8 +164,9 @@ class SqliteUserManager(SimpleUserManager): conn.close() if db_hash is None: - logger.info("Authentication failed for nonexistent user {}." - .format(username)) + logger.info( + "Authentication failed for nonexistent user {}.".format(username) + ) return False expected_value = str(db_hash[0]) @@ -181,17 +190,22 @@ class SqliteUserManager(SimpleUserManager): @staticmethod def _create_pass_hash(username, password): salt = binascii.b2a_hex(os.urandom(8)) - pass_hash = (hashlib.sha256((username + password).encode() + salt).hexdigest() + - salt.decode()) + pass_hash = ( + hashlib.sha256((username + password).encode() + salt).hexdigest() + + salt.decode() + ) return pass_hash def create_auth_db(self): conn = self._conn() cursor = conn.cursor() - logger.info("Creating auth db at {}." - .format(self.auth_db_path)) - cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth - (username VARCHAR PRIMARY KEY, hash VARCHAR)""")) + logger.info("Creating auth db at {}.".format(self.auth_db_path)) + cursor.execute( + self.fs( + """CREATE TABLE IF NOT EXISTS auth + (username VARCHAR PRIMARY KEY, hash VARCHAR)""" + ) + ) conn.commit() conn.close() @@ -199,19 +213,28 @@ class SqliteUserManager(SimpleUserManager): def get_user_manager(config): if "auth_db_path" in config and config["auth_db_path"]: logger.info("Found auth_db_path in config, using SqliteUserManager for auth") - return SqliteUserManager(config['auth_db_path'], config['data_root']) + return SqliteUserManager(config["auth_db_path"], config["data_root"]) elif "user_manager" in config and config["user_manager"]: # load from config - logger.info("Found user_manager in config, using {} for auth".format(config['user_manager'])) + logger.info( + "Found user_manager in config, using {} for auth".format( + config["user_manager"] + ) + ) import importlib import inspect - module_name, class_name = config['user_manager'].rsplit('.', 1) + + module_name, class_name = config["user_manager"].rsplit(".", 1) module = importlib.import_module(module_name.strip()) class_ = getattr(module, class_name.strip()) if not SimpleUserManager in inspect.getmro(class_): - raise TypeError('''"user_manager" found in the conf file but it doesn''t - inherit from SimpleUserManager''') + raise TypeError( + """"user_manager" found in the conf file but it doesn''t + inherit from SimpleUserManager""" + ) return class_(config) else: - logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password") - return SimpleUserManager() + logger.warning( + "neither auth_db_path nor user_manager set, ankisyncd will accept any password" + ) + return SimpleUserManager() diff --git a/src/ankisyncd_cli/__main__.py b/src/ankisyncd_cli/__main__.py index 1a160fd..e13ede5 100644 --- a/src/ankisyncd_cli/__main__.py +++ b/src/ankisyncd_cli/__main__.py @@ -1,4 +1,4 @@ from ankisyncd_cli import ankisyncctl -if __name__ == '__main__': - ankisyncctl.main() \ No newline at end of file +if __name__ == "__main__": + ankisyncctl.main() diff --git a/src/ankisyncd_cli/ankisyncctl.py b/src/ankisyncd_cli/ankisyncctl.py index 9ccbf15..9447abd 100644 --- a/src/ankisyncd_cli/ankisyncctl.py +++ b/src/ankisyncd_cli/ankisyncctl.py @@ -9,6 +9,7 @@ from ankisyncd.users import get_user_manager config = config.load() + def usage(): print("usage: {} []".format(sys.argv[0])) print() @@ -18,12 +19,14 @@ def usage(): print(" lsuser - list users") print(" passwd - change password of a user") + def adduser(username): password = getpass.getpass("Enter password for {}: ".format(username)) user_manager = get_user_manager(config) user_manager.add_user(username, password) + def deluser(username): user_manager = get_user_manager(config) try: @@ -31,6 +34,7 @@ def deluser(username): except ValueError as error: print("Could not delete user {}: {}".format(username, error), file=sys.stderr) + def lsuser(): user_manager = get_user_manager(config) try: @@ -40,6 +44,7 @@ def lsuser(): except ValueError as error: print("Could not list users: {}".format(error), file=sys.stderr) + def passwd(username): user_manager = get_user_manager(config) @@ -51,7 +56,11 @@ def passwd(username): try: user_manager.set_password_for_user(username, password) except ValueError as error: - print("Could not set password for user {}: {}".format(username, error), file=sys.stderr) + print( + "Could not set password for user {}: {}".format(username, error), + file=sys.stderr, + ) + def main(): argc = len(sys.argv) @@ -78,5 +87,6 @@ def main(): usage() exit(1) + if __name__ == "__main__": main() diff --git a/src/ankisyncd_cli/migrate_user_tables.py b/src/ankisyncd_cli/migrate_user_tables.py index 923aa0c..745e19a 100755 --- a/src/ankisyncd_cli/migrate_user_tables.py +++ b/src/ankisyncd_cli/migrate_user_tables.py @@ -7,11 +7,13 @@ word in many other SQL dialects. """ import os import sys -path = os.path.realpath(os.path.abspath(os.path.join(__file__, '../'))) + +path = os.path.realpath(os.path.abspath(os.path.join(__file__, "../"))) sys.path.insert(0, os.path.dirname(path)) import sqlite3 import ankisyncd.config + conf = ankisyncd.config.load() @@ -21,15 +23,21 @@ def main(): conn = sqlite3.connect(conf["auth_db_path"]) cursor = conn.cursor() - cursor.execute("SELECT * FROM sqlite_master " - "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' " - "AND tbl_name = 'auth'") + cursor.execute( + "SELECT * FROM sqlite_master " + "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' " + "AND tbl_name = 'auth'" + ) res = cursor.fetchone() if res is not None: cursor.execute("ALTER TABLE auth RENAME TO auth_old") - cursor.execute("CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)") - cursor.execute("INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old") + cursor.execute( + "CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)" + ) + cursor.execute( + "INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old" + ) cursor.execute("DROP TABLE auth_old") conn.commit() print("Successfully updated table 'auth'") @@ -44,17 +52,23 @@ def main(): conn = sqlite3.connect(conf["session_db_path"]) cursor = conn.cursor() - cursor.execute("SELECT * FROM sqlite_master " - "WHERE sql LIKE '%user VARCHAR%' " - "AND tbl_name = 'session'") + cursor.execute( + "SELECT * FROM sqlite_master " + "WHERE sql LIKE '%user VARCHAR%' " + "AND tbl_name = 'session'" + ) res = cursor.fetchone() if res is not None: cursor.execute("ALTER TABLE session RENAME TO session_old") - cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, " - "username VARCHAR, path VARCHAR)") - cursor.execute("INSERT INTO session (hkey, skey, username, path) " - "SELECT hkey, skey, user, path FROM session_old") + cursor.execute( + "CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, " + "username VARCHAR, path VARCHAR)" + ) + cursor.execute( + "INSERT INTO session (hkey, skey, username, path) " + "SELECT hkey, skey, user, path FROM session_old" + ) cursor.execute("DROP TABLE session_old") conn.commit() print("Successfully updated table 'session'") diff --git a/src/setup.py b/src/setup.py index cc310a1..1cecf03 100644 --- a/src/setup.py +++ b/src/setup.py @@ -8,5 +8,5 @@ setup( author="Anki Community", author_email="kothary.vikash+ankicommunity@gmail.com", packages=find_packages(), - url='https://ankicommunity.github.io/' + url="https://ankicommunity.github.io/", ) diff --git a/tests/collection_test_base.py b/tests/collection_test_base.py index b1da4fd..6e5acaa 100644 --- a/tests/collection_test_base.py +++ b/tests/collection_test_base.py @@ -16,7 +16,7 @@ class CollectionTestBase(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.mkdtemp() - self.collection_path = os.path.join(self.temp_dir, 'collection.anki2'); + self.collection_path = os.path.join(self.temp_dir, "collection.anki2") cm = CollectionManager({}) collectionWrapper = cm.get_collection(self.collection_path) self.collection = collectionWrapper._get_collection() @@ -32,26 +32,26 @@ class CollectionTestBase(unittest.TestCase): def add_note(self, data): from anki.notes import Note - model = self.collection.models.byName(data['model']) + model = self.collection.models.byName(data["model"]) note = Note(self.collection, model) - for name, value in data['fields'].items(): + for name, value in data["fields"].items(): note[name] = value - if 'tags' in data: - note.setTagsFromStr(data['tags']) + if "tags" in data: + note.setTagsFromStr(data["tags"]) self.collection.addNote(note) # TODO: refactor into a parent class def add_default_note(self, count=1): data = { - 'model': 'Basic', - 'fields': { - 'Front': 'The front', - 'Back': 'The back', + "model": "Basic", + "fields": { + "Front": "The front", + "Back": "The back", }, - 'tags': "Tag1 Tag2", + "tags": "Tag1 Tag2", } for idx in range(0, count): self.add_note(data) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 0db0522..b556fb4 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1 +1 @@ -from . import db_utils \ No newline at end of file +from . import db_utils diff --git a/tests/helpers/collection_utils.py b/tests/helpers/collection_utils.py index a2bfbe3..00625c0 100644 --- a/tests/helpers/collection_utils.py +++ b/tests/helpers/collection_utils.py @@ -5,6 +5,7 @@ import tempfile from anki.collection import Collection + class CollectionUtils: """ Provides utility methods for creating, inspecting and manipulating anki diff --git a/tests/helpers/db_utils.py b/tests/helpers/db_utils.py index b95e612..f5a30e6 100644 --- a/tests/helpers/db_utils.py +++ b/tests/helpers/db_utils.py @@ -34,7 +34,7 @@ def to_sql(database): else: connection = database - res = '\n'.join(connection.iterdump()) + res = "\n".join(connection.iterdump()) if type(database) == str: connection.close() @@ -54,18 +54,15 @@ def diff(left_db_path, right_db_path): command = ["sqldiff", left_db_path, right_db_path] - child_process = subprocess.Popen(command, - shell=False, - stdout=subprocess.PIPE) + child_process = subprocess.Popen(command, shell=False, stdout=subprocess.PIPE) stdout, stderr = child_process.communicate() exit_code = child_process.returncode if exit_code != 0 or stderr is not None: - raise RuntimeError("Command {} encountered an error, exit " - "code: {}, stderr: {}" - .format(" ".join(command), - exit_code, - stderr)) + raise RuntimeError( + "Command {} encountered an error, exit " + "code: {}, stderr: {}".format(" ".join(command), exit_code, stderr) + ) - # Any output from sqldiff means the databases differ. + # Any output from sqldiff means the databases differ. return stdout != "" diff --git a/tests/helpers/mock_servers.py b/tests/helpers/mock_servers.py index fdd3a18..aa3ef92 100644 --- a/tests/helpers/mock_servers.py +++ b/tests/helpers/mock_servers.py @@ -21,9 +21,8 @@ class MockServerConnection: r = self.test_app.post(url, params=data.read(), headers=headers, status="*") return types.SimpleNamespace(status_code=r.status_int, body=r.body) - def streamContent(self, r): - return r.body + return r.body class MockRemoteServer(RemoteServer): diff --git a/tests/helpers/monkey_patches.py b/tests/helpers/monkey_patches.py index e65fe80..de8cc3c 100644 --- a/tests/helpers/monkey_patches.py +++ b/tests/helpers/monkey_patches.py @@ -12,9 +12,7 @@ mediamanager_orig_funcs = { "_logChanges": None, } -db_orig_funcs = { - "__init__": None -} +db_orig_funcs = {"__init__": None} def monkeypatch_mediamanager(): @@ -35,6 +33,7 @@ def monkeypatch_mediamanager(): os.chdir(old_cwd) return res + return wrapper MediaManager.findChanges = make_cwd_safe(MediaManager.findChanges) @@ -47,6 +46,7 @@ def unpatch_mediamanager(): mediamanager_orig_funcs["findChanges"] = None + def monkeypatch_db(): """ Monkey patches Anki's DB.__init__ to connect to allow access to the db @@ -58,9 +58,7 @@ def monkeypatch_db(): def patched___init__(self, path, text=None, timeout=0): # Code taken from Anki's DB.__init__() # Allow more than one thread to use this connection. - self._db = sqlite.connect(path, - timeout=timeout, - check_same_thread=False) + self._db = sqlite.connect(path, timeout=timeout, check_same_thread=False) if text: self._db.text_factory = text self._path = path diff --git a/tests/helpers/server_utils.py b/tests/helpers/server_utils.py index 45e6b76..b74dd9f 100644 --- a/tests/helpers/server_utils.py +++ b/tests/helpers/server_utils.py @@ -23,50 +23,58 @@ def create_server_paths(): "data_root": os.path.join(dir, "data"), } + def create_sync_app(server_paths, config_path): config = configparser.ConfigParser() config.read(config_path) # Use custom files and dirs in settings. - config['sync_app'].update(server_paths) + config["sync_app"].update(server_paths) + + return SyncApp(config["sync_app"]) - return SyncApp(config['sync_app']) def get_session_for_hkey(server, hkey): return server.session_manager.load(hkey) + def get_thread_for_hkey(server, hkey): session = get_session_for_hkey(server, hkey) thread = session.get_thread() return thread + def get_col_wrapper_for_hkey(server, hkey): thread = get_thread_for_hkey(server, hkey) col_wrapper = thread.wrapper return col_wrapper + def get_col_for_hkey(server, hkey): col_wrapper = get_col_wrapper_for_hkey(server, hkey) col_wrapper.open() # Make sure the col is opened. return col_wrapper._CollectionWrapper__col + def get_col_db_path_for_hkey(server, hkey): col = get_col_for_hkey(server, hkey) return col.db._path -def get_syncer_for_hkey(server, hkey, syncer_type='collection'): + +def get_syncer_for_hkey(server, hkey, syncer_type="collection"): col = get_col_for_hkey(server, hkey) session = get_session_for_hkey(server, hkey) syncer_type = syncer_type.lower() - if syncer_type == 'collection': + if syncer_type == "collection": handler_method = SyncCollectionHandler.operations[0] - elif syncer_type == 'media': + elif syncer_type == "media": handler_method = SyncMediaHandler.operations[0] return session.get_handler_for_operation(handler_method, col) + def add_files_to_client_mediadb(media, filepaths, update_db=False): for filepath in filepaths: logging.debug("Adding file '{}' to client media DB".format(filepath)) @@ -76,16 +84,15 @@ def add_files_to_client_mediadb(media, filepaths, update_db=False): if update_db: media.findChanges() # Write changes to db. + def add_files_to_server_mediadb(media, filepaths): for filepath in filepaths: logging.debug("Adding file '{}' to server media DB".format(filepath)) fname = os.path.basename(filepath) - with open(filepath, 'rb') as infile: + with open(filepath, "rb") as infile: data = infile.read() csum = anki.utils.checksum(data) - with open(os.path.join(media.dir(), fname), 'wb') as f: + with open(os.path.join(media.dir(), fname), "wb") as f: f.write(data) - media.addMedia( - ((fname, media.lastUsn() + 1, csum),) - ) + media.addMedia(((fname, media.lastUsn() + 1, csum),)) diff --git a/tests/sync_app_functional_test_base.py b/tests/sync_app_functional_test_base.py index b4c589c..c416726 100644 --- a/tests/sync_app_functional_test_base.py +++ b/tests/sync_app_functional_test_base.py @@ -11,7 +11,6 @@ from helpers.monkey_patches import monkeypatch_db, unpatch_db class SyncAppFunctionalTestBase(unittest.TestCase): - @classmethod def setUpClass(cls): cls.colutils = CollectionUtils() @@ -28,28 +27,29 @@ class SyncAppFunctionalTestBase(unittest.TestCase): self.server_paths = helpers.server_utils.create_server_paths() # Add a test user to the temp auth db the server will use. - self.user_manager = SqliteUserManager(self.server_paths['auth_db_path'], - self.server_paths['data_root']) - self.user_manager.add_user('testuser', 'testpassword') + self.user_manager = SqliteUserManager( + self.server_paths["auth_db_path"], self.server_paths["data_root"] + ) + self.user_manager.add_user("testuser", "testpassword") # Get absolute path to development ini file. script_dir = os.path.dirname(os.path.realpath(__file__)) - ini_file_path = os.path.join(script_dir, - "assets", - "test.conf") + ini_file_path = os.path.join(script_dir, "assets", "test.conf") # Create SyncApp instance using the dev ini file and the temporary # paths. - self.server_app = helpers.server_utils.create_sync_app(self.server_paths, - ini_file_path) + self.server_app = helpers.server_utils.create_sync_app( + self.server_paths, ini_file_path + ) # Wrap the SyncApp object in TestApp instance for testing. self.server_test_app = TestApp(self.server_app) # MockRemoteServer instance needed for testing normal collection # syncing and for retrieving hkey for other tests. - self.mock_remote_server = MockRemoteServer(hkey=None, - server_test_app=self.server_test_app) + self.mock_remote_server = MockRemoteServer( + hkey=None, server_test_app=self.server_test_app + ) def tearDown(self): self.server_paths = {} diff --git a/tests/test_collection_wrappers.py b/tests/test_collection_wrappers.py index 37812f4..cb63e02 100644 --- a/tests/test_collection_wrappers.py +++ b/tests/test_collection_wrappers.py @@ -9,39 +9,52 @@ from ankisyncd.collection import get_collection_wrapper import helpers.server_utils + class FakeCollectionWrapper(CollectionWrapper): def __init__(self, config, path, setup_new_collection=None): - self. _CollectionWrapper__col = None + self._CollectionWrapper__col = None pass + class BadCollectionWrapper: pass + class CollectionWrapperFactoryTest(unittest.TestCase): def test_get_collection_wrapper(self): # Get absolute path to development ini file. script_dir = os.path.dirname(os.path.realpath(__file__)) - ini_file_path = os.path.join(script_dir, - "assets", - "test.conf") + ini_file_path = os.path.join(script_dir, "assets", "test.conf") # Create temporary files and dirs the server will use. server_paths = helpers.server_utils.create_server_paths() config = configparser.ConfigParser() config.read(ini_file_path) - path = os.path.realpath('fake/collection.anki2') + path = os.path.realpath("fake/collection.anki2") # Use custom files and dirs in settings. Should be CollectionWrapper - config['sync_app'].update(server_paths) - self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper)) + config["sync_app"].update(server_paths) + self.assertTrue( + type(get_collection_wrapper(config["sync_app"], path) == CollectionWrapper) + ) # A conf-specified CollectionWrapper is loaded - config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper') - self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper) + config.set( + "sync_app", + "collection_wrapper", + "test_collection_wrappers.FakeCollectionWrapper", + ) + self.assertTrue( + type(get_collection_wrapper(config["sync_app"], path)) + == FakeCollectionWrapper + ) # Should fail at load time if the class doesn't inherit from CollectionWrapper - config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper') + config.set( + "sync_app", + "collection_wrapper", + "test_collection_wrappers.BadCollectionWrapper", + ) with self.assertRaises(TypeError): - pm = get_collection_wrapper(config['sync_app'], path) - + pm = get_collection_wrapper(config["sync_app"], path) diff --git a/tests/test_full_sync.py b/tests/test_full_sync.py index 64feb4b..1deff2e 100644 --- a/tests/test_full_sync.py +++ b/tests/test_full_sync.py @@ -8,20 +8,21 @@ from ankisyncd.full_sync import FullSyncManager, get_full_sync_manager import helpers.server_utils + class FakeFullSyncManager(FullSyncManager): def __init__(self, config): pass + class BadFullSyncManager: pass + class FullSyncManagerFactoryTest(unittest.TestCase): def test_get_full_sync_manager(self): # Get absolute path to development ini file. script_dir = os.path.dirname(os.path.realpath(__file__)) - ini_file_path = os.path.join(script_dir, - "assets", - "test.conf") + ini_file_path = os.path.join(script_dir, "assets", "test.conf") # Create temporary files and dirs the server will use. server_paths = helpers.server_utils.create_server_paths() @@ -30,15 +31,20 @@ class FullSyncManagerFactoryTest(unittest.TestCase): config.read(ini_file_path) # Use custom files and dirs in settings. Should be PersistenceManager - config['sync_app'].update(server_paths) - self.assertTrue(type(get_full_sync_manager(config['sync_app']) == FullSyncManager)) + config["sync_app"].update(server_paths) + self.assertTrue( + type(get_full_sync_manager(config["sync_app"]) == FullSyncManager) + ) # A conf-specified FullSyncManager is loaded - config.set("sync_app", "full_sync_manager", 'test_full_sync.FakeFullSyncManager') - self.assertTrue(type(get_full_sync_manager(config['sync_app'])) == FakeFullSyncManager) + config.set( + "sync_app", "full_sync_manager", "test_full_sync.FakeFullSyncManager" + ) + self.assertTrue( + type(get_full_sync_manager(config["sync_app"])) == FakeFullSyncManager + ) # Should fail at load time if the class doesn't inherit from FullSyncManager - config.set("sync_app", "full_sync_manager", 'test_full_sync.BadFullSyncManager') + config.set("sync_app", "full_sync_manager", "test_full_sync.BadFullSyncManager") with self.assertRaises(TypeError): - pm = get_full_sync_manager(config['sync_app']) - + pm = get_full_sync_manager(config["sync_app"]) diff --git a/tests/test_media.py b/tests/test_media.py index a7aacd9..1ece809 100644 --- a/tests/test_media.py +++ b/tests/test_media.py @@ -45,26 +45,22 @@ class ServerMediaManagerTest(unittest.TestCase): list(cm.db.execute("SELECT fname, csum FROM media")), ) self.assertEqual(cm.lastUsn(), sm.lastUsn()) - self.assertEqual( - list(sm.db.execute("SELECT usn FROM media")), - [(161,), (161,)] - ) + self.assertEqual(list(sm.db.execute("SELECT usn FROM media")), [(161,), (161,)]) def test_mediaChanges_lastUsn_order(self): col = self.colutils.create_empty_col() col.media = ankisyncd.media.ServerMediaManager(col) session = MagicMock() - session.name = 'test' + session.name = "test" mh = ankisyncd.sync_app.SyncMediaHandler(col, session) mh.col.media.addMedia( ( - ('fileA', 101, '53059abba1a72c7aff34a3eaf7fef10ed65541ce'), - ('fileB', 100, 'a5ae546046d09559399c80fa7076fb10f1ce4bcd'), + ("fileA", 101, "53059abba1a72c7aff34a3eaf7fef10ed65541ce"), + ("fileB", 100, "a5ae546046d09559399c80fa7076fb10f1ce4bcd"), ) ) # anki assumes mh.col.media.lastUsn() == mh.mediaChanges()['data'][-1][1] # ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7) self.assertEqual( - mh.mediaChanges(lastUsn=99)['data'][-1][1], - mh.col.media.lastUsn() + mh.mediaChanges(lastUsn=99)["data"][-1][1], mh.col.media.lastUsn() ) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 78f297c..5d33d56 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -14,20 +14,21 @@ from ankisyncd.sync_app import SyncUserSession import helpers.server_utils + class FakeSessionManager(SimpleSessionManager): def __init__(self, config): pass + class BadSessionManager: pass + class SessionManagerFactoryTest(unittest.TestCase): def test_get_session_manager(self): # Get absolute path to development ini file. script_dir = os.path.dirname(os.path.realpath(__file__)) - ini_file_path = os.path.join(script_dir, - "assets", - "test.conf") + ini_file_path = os.path.join(script_dir, "assets", "test.conf") # Create temporary files and dirs the server will use. server_paths = helpers.server_utils.create_server_paths() @@ -36,32 +37,40 @@ class SessionManagerFactoryTest(unittest.TestCase): config.read(ini_file_path) # Use custom files and dirs in settings. Should be SqliteSessionManager - config['sync_app'].update(server_paths) - self.assertTrue(type(get_session_manager(config['sync_app']) == SqliteSessionManager)) + config["sync_app"].update(server_paths) + self.assertTrue( + type(get_session_manager(config["sync_app"]) == SqliteSessionManager) + ) # No value defaults to SimpleSessionManager config.remove_option("sync_app", "session_db_path") - self.assertTrue(type(get_session_manager(config['sync_app'])) == SimpleSessionManager) + self.assertTrue( + type(get_session_manager(config["sync_app"])) == SimpleSessionManager + ) # A conf-specified SessionManager is loaded - config.set("sync_app", "session_manager", 'test_sessions.FakeSessionManager') - self.assertTrue(type(get_session_manager(config['sync_app'])) == FakeSessionManager) + config.set("sync_app", "session_manager", "test_sessions.FakeSessionManager") + self.assertTrue( + type(get_session_manager(config["sync_app"])) == FakeSessionManager + ) # Should fail at load time if the class doesn't inherit from SimpleSessionManager - config.set("sync_app", "session_manager", 'test_sessions.BadSessionManager') + config.set("sync_app", "session_manager", "test_sessions.BadSessionManager") with self.assertRaises(TypeError): - sm = get_session_manager(config['sync_app']) + sm = get_session_manager(config["sync_app"]) # Add the session_db_path back, it should take precedence over BadSessionManager - config['sync_app'].update(server_paths) - self.assertTrue(type(get_session_manager(config['sync_app'])) == SqliteSessionManager) + config["sync_app"].update(server_paths) + self.assertTrue( + type(get_session_manager(config["sync_app"])) == SqliteSessionManager + ) class SimpleSessionManagerTest(unittest.TestCase): - test_hkey = '1234567890' + test_hkey = "1234567890" sdir = tempfile.mkdtemp(suffix="_session") os.rmdir(sdir) - test_session = SyncUserSession('testName', sdir, None, None) + test_session = SyncUserSession("testName", sdir, None, None) def setUp(self): self.sessionManager = SimpleSessionManager() @@ -71,10 +80,12 @@ class SimpleSessionManagerTest(unittest.TestCase): def test_save(self): self.sessionManager.save(self.test_hkey, self.test_session) - self.assertEqual(self.sessionManager.sessions[self.test_hkey].name, - self.test_session.name) - self.assertEqual(self.sessionManager.sessions[self.test_hkey].path, - self.test_session.path) + self.assertEqual( + self.sessionManager.sessions[self.test_hkey].name, self.test_session.name + ) + self.assertEqual( + self.sessionManager.sessions[self.test_hkey].path, self.test_session.path + ) def test_delete(self): self.sessionManager.save(self.test_hkey, self.test_session) @@ -111,13 +122,11 @@ class SqliteSessionManagerTest(SimpleSessionManagerTest): conn = sqlite3.connect(self._test_sess_db_path) cursor = conn.cursor() - cursor.execute("SELECT username, path FROM session WHERE hkey=?", - (self.test_hkey,)) + cursor.execute( + "SELECT username, path FROM session WHERE hkey=?", (self.test_hkey,) + ) res = cursor.fetchone() conn.close() self.assertEqual(res[0], self.test_session.name) self.assertEqual(res[1], self.test_session.path) - - - diff --git a/tests/test_sync_app.py b/tests/test_sync_app.py index 3147daf..674b756 100644 --- a/tests/test_sync_app.py +++ b/tests/test_sync_app.py @@ -16,10 +16,9 @@ class SyncCollectionHandlerTest(CollectionTestBase): def setUp(self): super().setUp() self.session = MagicMock() - self.session.name = 'test' + self.session.name = "test" self.syncCollectionHandler = SyncCollectionHandler( - self.collection, - self.session + self.collection, self.session ) def tearDown(self): @@ -28,48 +27,48 @@ class SyncCollectionHandlerTest(CollectionTestBase): def test_old_client(self): old = ( - ','.join(('ankidesktop', '2.0.12', 'lin::')), - ','.join(('ankidesktop', '2.0.26', 'lin::')), - ','.join(('ankidroid', '2.1', '')), - ','.join(('ankidroid', '2.2', '')), - ','.join(('ankidroid', '2.2.2', '')), - ','.join(('ankidroid', '2.3alpha3', '')), + ",".join(("ankidesktop", "2.0.12", "lin::")), + ",".join(("ankidesktop", "2.0.26", "lin::")), + ",".join(("ankidroid", "2.1", "")), + ",".join(("ankidroid", "2.2", "")), + ",".join(("ankidroid", "2.2.2", "")), + ",".join(("ankidroid", "2.3alpha3", "")), ) current = ( None, - ','.join(('ankidesktop', '2.0.27', 'lin::')), - ','.join(('ankidesktop', '2.0.32', 'lin::')), - ','.join(('ankidesktop', '2.1.0', 'lin::')), - ','.join(('ankidesktop', '2.1.6-beta2', 'lin::')), - ','.join(('ankidesktop', '2.1.9 (dev)', 'lin::')), - ','.join(('ankidesktop', '2.1.26 (arch-linux-2.1.26-1)', 'lin:arch:')), - ','.join(('ankidroid', '2.2.3', '')), - ','.join(('ankidroid', '2.3alpha4', '')), - ','.join(('ankidroid', '2.3alpha5', '')), - ','.join(('ankidroid', '2.3beta1', '')), - ','.join(('ankidroid', '2.3', '')), - ','.join(('ankidroid', '2.9', '')), + ",".join(("ankidesktop", "2.0.27", "lin::")), + ",".join(("ankidesktop", "2.0.32", "lin::")), + ",".join(("ankidesktop", "2.1.0", "lin::")), + ",".join(("ankidesktop", "2.1.6-beta2", "lin::")), + ",".join(("ankidesktop", "2.1.9 (dev)", "lin::")), + ",".join(("ankidesktop", "2.1.26 (arch-linux-2.1.26-1)", "lin:arch:")), + ",".join(("ankidroid", "2.2.3", "")), + ",".join(("ankidroid", "2.3alpha4", "")), + ",".join(("ankidroid", "2.3alpha5", "")), + ",".join(("ankidroid", "2.3beta1", "")), + ",".join(("ankidroid", "2.3", "")), + ",".join(("ankidroid", "2.9", "")), ) for cv in old: if not SyncCollectionHandler._old_client(cv): - raise AssertionError("old_client(\"%s\") is False" % cv) + raise AssertionError('old_client("%s") is False' % cv) for cv in current: if SyncCollectionHandler._old_client(cv): - raise AssertionError("old_client(\"%s\") is True" % cv) + raise AssertionError('old_client("%s") is True' % cv) def test_meta(self): meta = self.syncCollectionHandler.meta(v=SYNC_VER) - self.assertEqual(meta['scm'], self.collection.scm) - self.assertTrue((type(meta['ts']) == int) and meta['ts'] > 0) - self.assertEqual(meta['mod'], self.collection.mod) - self.assertEqual(meta['usn'], self.collection._usn) - self.assertEqual(meta['uname'], self.session.name) - self.assertEqual(meta['musn'], self.collection.media.lastUsn()) - self.assertEqual(meta['msg'], '') - self.assertEqual(meta['cont'], True) + self.assertEqual(meta["scm"], self.collection.scm) + self.assertTrue((type(meta["ts"]) == int) and meta["ts"] > 0) + self.assertEqual(meta["mod"], self.collection.mod) + self.assertEqual(meta["usn"], self.collection._usn) + self.assertEqual(meta["uname"], self.session.name) + self.assertEqual(meta["musn"], self.collection.media.lastUsn()) + self.assertEqual(meta["msg"], "") + self.assertEqual(meta["cont"], True) class SyncAppTest(unittest.TestCase): diff --git a/tests/test_users.py b/tests/test_users.py index 13cfb31..994095a 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -10,20 +10,21 @@ from ankisyncd.users import get_user_manager import helpers.server_utils + class FakeUserManager(SimpleUserManager): def __init__(self, config): pass + class BadUserManager: pass + class UserManagerFactoryTest(unittest.TestCase): def test_get_user_manager(self): # Get absolute path to development ini file. script_dir = os.path.dirname(os.path.realpath(__file__)) - ini_file_path = os.path.join(script_dir, - "assets", - "test.conf") + ini_file_path = os.path.join(script_dir, "assets", "test.conf") # Create temporary files and dirs the server will use. server_paths = helpers.server_utils.create_server_paths() @@ -32,25 +33,25 @@ class UserManagerFactoryTest(unittest.TestCase): config.read(ini_file_path) # Use custom files and dirs in settings. Should be SqliteUserManager - config['sync_app'].update(server_paths) - self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager)) + config["sync_app"].update(server_paths) + self.assertTrue(type(get_user_manager(config["sync_app"]) == SqliteUserManager)) # No value defaults to SimpleUserManager config.remove_option("sync_app", "auth_db_path") - self.assertTrue(type(get_user_manager(config['sync_app'])) == SimpleUserManager) + self.assertTrue(type(get_user_manager(config["sync_app"])) == SimpleUserManager) # A conf-specified UserManager is loaded - config.set("sync_app", "user_manager", 'test_users.FakeUserManager') - self.assertTrue(type(get_user_manager(config['sync_app'])) == FakeUserManager) + config.set("sync_app", "user_manager", "test_users.FakeUserManager") + self.assertTrue(type(get_user_manager(config["sync_app"])) == FakeUserManager) # Should fail at load time if the class doesn't inherit from SimpleUserManager - config.set("sync_app", "user_manager", 'test_users.BadUserManager') + config.set("sync_app", "user_manager", "test_users.BadUserManager") with self.assertRaises(TypeError): - um = get_user_manager(config['sync_app']) + um = get_user_manager(config["sync_app"]) # Add the auth_db_path back, it should take precedence over BadUserManager - config['sync_app'].update(server_paths) - self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager)) + config["sync_app"].update(server_paths) + self.assertTrue(type(get_user_manager(config["sync_app"]) == SqliteUserManager)) class SimpleUserManagerTest(unittest.TestCase): @@ -61,22 +62,18 @@ class SimpleUserManagerTest(unittest.TestCase): self._user_manager = None def test_authenticate(self): - good_test_un = 'username' - good_test_pw = 'password' - bad_test_un = 'notAUsername' - bad_test_pw = 'notAPassword' + good_test_un = "username" + good_test_pw = "password" + bad_test_un = "notAUsername" + bad_test_pw = "notAPassword" - self.assertTrue(self.user_manager.authenticate(good_test_un, - good_test_pw)) - self.assertTrue(self.user_manager.authenticate(bad_test_un, - bad_test_pw)) - self.assertTrue(self.user_manager.authenticate(good_test_un, - bad_test_pw)) - self.assertTrue(self.user_manager.authenticate(bad_test_un, - good_test_pw)) + self.assertTrue(self.user_manager.authenticate(good_test_un, good_test_pw)) + self.assertTrue(self.user_manager.authenticate(bad_test_un, bad_test_pw)) + self.assertTrue(self.user_manager.authenticate(good_test_un, bad_test_pw)) + self.assertTrue(self.user_manager.authenticate(bad_test_un, good_test_pw)) def test_userdir(self): - username = 'my_username' + username = "my_username" dirname = self.user_manager.userdir(username) self.assertEqual(dirname, username) @@ -87,8 +84,7 @@ class SqliteUserManagerTest(unittest.TestCase): self.basedir = basedir self.auth_db_path = os.path.join(basedir, "auth.db") self.collection_path = os.path.join(basedir, "collections") - self.user_manager = SqliteUserManager(self.auth_db_path, - self.collection_path) + self.user_manager = SqliteUserManager(self.auth_db_path, self.collection_path) def tearDown(self): shutil.rmtree(self.basedir) @@ -151,18 +147,22 @@ class SqliteUserManagerTest(unittest.TestCase): self.assertTrue(os.path.isdir(expected_dir_path)) def test_add_users(self): - users_data = [("my_first_username", "my_first_password"), - ("my_second_username", "my_second_password")] + users_data = [ + ("my_first_username", "my_first_password"), + ("my_second_username", "my_second_password"), + ] self.user_manager.create_auth_db() self.user_manager.add_users(users_data) user_list = self.user_manager.user_list() self.assertIn("my_first_username", user_list) self.assertIn("my_second_username", user_list) - self.assertTrue(os.path.isdir(os.path.join(self.collection_path, - "my_first_username"))) - self.assertTrue(os.path.isdir(os.path.join(self.collection_path, - "my_second_username"))) + self.assertTrue( + os.path.isdir(os.path.join(self.collection_path, "my_first_username")) + ) + self.assertTrue( + os.path.isdir(os.path.join(self.collection_path, "my_second_username")) + ) def test__add_user_to_auth_db(self): username = "my_username" @@ -191,8 +191,7 @@ class SqliteUserManagerTest(unittest.TestCase): self.user_manager.create_auth_db() self.user_manager.add_user(username, password) - self.assertTrue(self.user_manager.authenticate(username, - password)) + self.assertTrue(self.user_manager.authenticate(username, password)) def test_set_password_for_user(self): username = "my_username" @@ -203,8 +202,5 @@ class SqliteUserManagerTest(unittest.TestCase): self.user_manager.add_user(username, password) self.user_manager.set_password_for_user(username, new_password) - self.assertFalse(self.user_manager.authenticate(username, - password)) - self.assertTrue(self.user_manager.authenticate(username, - new_password)) - + self.assertFalse(self.user_manager.authenticate(username, password)) + self.assertTrue(self.user_manager.authenticate(username, new_password)) diff --git a/tests/test_web_hostkey.py b/tests/test_web_hostkey.py index b07563e..8ff2cdf 100644 --- a/tests/test_web_hostkey.py +++ b/tests/test_web_hostkey.py @@ -17,4 +17,3 @@ class SyncAppFunctionalHostKeyTest(SyncAppFunctionalTestBase): self.assertIsNone(self.server.hostKey("testuser", "wrongpassword")) self.assertIsNone(self.server.hostKey("wronguser", "wrongpassword")) self.assertIsNone(self.server.hostKey("wronguser", "testpassword")) -