lint: Auto-format code using black

This commit is contained in:
Vikash Kothary 2022-10-14 20:07:21 +01:00
parent 88d6f7f3f0
commit 684b4a2e13
30 changed files with 910 additions and 608 deletions

View File

@ -10,67 +10,77 @@ config = aqt.mw.addonManager.getConfig(__name__)
# TODO: force the user to log out before changing any of the settings
def addui(self, _):
self = self.form
parent_w = self.tab_2
parent_l = self.vboxlayout
self.useCustomServer = QCheckBox(parent_w)
self.useCustomServer.setText("Use custom sync server")
parent_l.addWidget(self.useCustomServer)
cshl = QHBoxLayout()
parent_l.addLayout(cshl)
self = self.form
parent_w = self.tab_2
parent_l = self.vboxlayout
self.useCustomServer = QCheckBox(parent_w)
self.useCustomServer.setText("Use custom sync server")
parent_l.addWidget(self.useCustomServer)
cshl = QHBoxLayout()
parent_l.addLayout(cshl)
self.serverAddrLabel = QLabel(parent_w)
self.serverAddrLabel.setText("Server address")
cshl.addWidget(self.serverAddrLabel)
self.customServerAddr = QLineEdit(parent_w)
self.customServerAddr.setPlaceholderText(DEFAULT_ADDR)
cshl.addWidget(self.customServerAddr)
self.serverAddrLabel = QLabel(parent_w)
self.serverAddrLabel.setText("Server address")
cshl.addWidget(self.serverAddrLabel)
self.customServerAddr = QLineEdit(parent_w)
self.customServerAddr.setPlaceholderText(DEFAULT_ADDR)
cshl.addWidget(self.customServerAddr)
pconfig = getprofileconfig()
if pconfig["enabled"]:
self.useCustomServer.setCheckState(Qt.Checked)
if pconfig["addr"]:
self.customServerAddr.setText(pconfig["addr"])
pconfig = getprofileconfig()
if pconfig["enabled"]:
self.useCustomServer.setCheckState(Qt.Checked)
if pconfig["addr"]:
self.customServerAddr.setText(pconfig["addr"])
self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text))
def onchecked(state):
pconfig["enabled"] = state == Qt.Checked
updateui(self, state)
updateserver(self, self.customServerAddr.text())
self.useCustomServer.stateChanged.connect(onchecked)
self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text))
def onchecked(state):
pconfig["enabled"] = state == Qt.Checked
updateui(self, state)
updateserver(self, self.customServerAddr.text())
self.useCustomServer.stateChanged.connect(onchecked)
updateui(self, self.useCustomServer.checkState())
updateui(self, self.useCustomServer.checkState())
def updateserver(self, text):
pconfig = getprofileconfig()
if pconfig['enabled']:
addr = text or self.customServerAddr.placeholderText()
pconfig['addr'] = addr
setserver()
aqt.mw.addonManager.writeConfig(__name__, config)
pconfig = getprofileconfig()
if pconfig["enabled"]:
addr = text or self.customServerAddr.placeholderText()
pconfig["addr"] = addr
setserver()
aqt.mw.addonManager.writeConfig(__name__, config)
def updateui(self, state):
self.serverAddrLabel.setEnabled(state == Qt.Checked)
self.customServerAddr.setEnabled(state == Qt.Checked)
self.serverAddrLabel.setEnabled(state == Qt.Checked)
self.customServerAddr.setEnabled(state == Qt.Checked)
def setserver():
pconfig = getprofileconfig()
if pconfig['enabled']:
aqt.mw.pm.profile['hostNum'] = None
anki.sync.SYNC_BASE = "%s" + pconfig['addr']
else:
anki.sync.SYNC_BASE = anki.consts.SYNC_BASE
pconfig = getprofileconfig()
if pconfig["enabled"]:
aqt.mw.pm.profile["hostNum"] = None
anki.sync.SYNC_BASE = "%s" + pconfig["addr"]
else:
anki.sync.SYNC_BASE = anki.consts.SYNC_BASE
def getprofileconfig():
if aqt.mw.pm.name not in config["profiles"]:
# inherit global settings if present (used in earlier versions of the addon)
config["profiles"][aqt.mw.pm.name] = {
"enabled": config.get("enabled", False),
"addr": config.get("addr", DEFAULT_ADDR),
}
aqt.mw.addonManager.writeConfig(__name__, config)
return config["profiles"][aqt.mw.pm.name]
if aqt.mw.pm.name not in config["profiles"]:
# inherit global settings if present (used in earlier versions of the addon)
config["profiles"][aqt.mw.pm.name] = {
"enabled": config.get("enabled", False),
"addr": config.get("addr", DEFAULT_ADDR),
}
aqt.mw.addonManager.writeConfig(__name__, config)
return config["profiles"][aqt.mw.pm.name]
addHook("profileLoaded", setserver)
aqt.preferences.Preferences.__init__ = wrap(aqt.preferences.Preferences.__init__, addui, "after")
aqt.preferences.Preferences.__init__ = wrap(
aqt.preferences.Preferences.__init__, addui, "after"
)

View File

@ -2,6 +2,7 @@ import sys
if __package__ is None and not hasattr(sys, "frozen"):
import os.path
path = os.path.realpath(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(os.path.dirname(path)))

View File

@ -30,7 +30,7 @@ class CollectionWrapper:
self.close()
def execute(self, func, args=[], kw={}, waitForReturn=True):
""" Executes the given function with the underlying anki.storage.Collection
"""Executes the given function with the underlying anki.storage.Collection
object as the first argument and any additional arguments specified by *args
and **kw.
@ -92,6 +92,7 @@ class CollectionWrapper:
"""Returns True if the collection is open, False otherwise."""
return self.__col is not None
class CollectionManager:
"""Manages a set of CollectionWrapper objects."""
@ -109,7 +110,9 @@ class CollectionManager:
try:
col = self.collections[path]
except KeyError:
col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
col = self.collections[path] = self.collection_wrapper(
self.config, path, setup_new_collection
)
return col
@ -119,19 +122,25 @@ class CollectionManager:
del self.collections[path]
col.close()
def get_collection_wrapper(config, path, setup_new_collection = None):
def get_collection_wrapper(config, path, setup_new_collection=None):
if "collection_wrapper" in config and config["collection_wrapper"]:
logger.info("Found collection_wrapper in config, using {} for "
"user data persistence".format(config['collection_wrapper']))
logger.info(
"Found collection_wrapper in config, using {} for "
"user data persistence".format(config["collection_wrapper"])
)
import importlib
import inspect
module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
module_name, class_name = config["collection_wrapper"].rsplit(".", 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not CollectionWrapper in inspect.getmro(class_):
raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t
inherit from CollectionWrapper''')
raise TypeError(
""""collection_wrapper" found in the conf file but it doesn''t
inherit from CollectionWrapper"""
)
return class_(config, path, setup_new_collection)
else:
return CollectionWrapper(config, path, setup_new_collection)

View File

@ -7,9 +7,9 @@ logger = logging.getLogger("ankisyncd")
paths = [
"/etc/ankisyncd/ankisyncd.conf",
os.environ.get("XDG_CONFIG_HOME") and
(os.path.join(os.environ['XDG_CONFIG_HOME'], "ankisyncd", "ankisyncd.conf")) or
os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"),
os.environ.get("XDG_CONFIG_HOME")
and (os.path.join(os.environ["XDG_CONFIG_HOME"], "ankisyncd", "ankisyncd.conf"))
or os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"),
os.path.join(dirname(dirname(realpath(__file__))), "ankisyncd.conf"),
]
@ -19,11 +19,12 @@ paths = [
def load_from_env(conf):
logger.debug("Loading/overriding config values from ENV")
for env in os.environ:
if env.startswith('ANKISYNCD_'):
if env.startswith("ANKISYNCD_"):
config_key = env[10:].lower()
conf[config_key] = os.getenv(env)
logger.info("Setting {} from ENV".format(config_key))
def load(path=None):
choices = paths
parser = configparser.ConfigParser()
@ -33,7 +34,7 @@ def load(path=None):
logger.debug("config.location: trying", path)
try:
parser.read(path)
conf = parser['sync_app']
conf = parser["sync_app"]
logger.info("Loaded config from {}".format(path))
load_from_env(conf)
return conf

View File

@ -13,6 +13,7 @@ from anki.collection import Collection
logger = logging.getLogger("ankisyncd.media")
logger.setLevel(1)
class FullSyncManager:
def test_db(self, db: DB):
"""
@ -34,15 +35,14 @@ class FullSyncManager:
# Verify integrity of the received database file before replacing our
# existing db.
temp_db_path = session.get_collection_path() + ".tmp"
with open(temp_db_path, 'wb') as f:
with open(temp_db_path, "wb") as f:
f.write(data)
try:
with DB(temp_db_path) as test_db:
self.test_db(test_db)
except sqlite.Error as e:
raise HTTPBadRequest("Uploaded collection database file is "
"corrupt.")
raise HTTPBadRequest("Uploaded collection database file is " "corrupt.")
# Overwrite existing db.
col.close()
@ -69,7 +69,7 @@ class FullSyncManager:
col.close(downgrade=True)
db_path = session.get_collection_path()
try:
with open(db_path, 'rb') as tmp:
with open(db_path, "rb") as tmp:
data = tmp.read()
finally:
col.reopen()
@ -80,16 +80,21 @@ class FullSyncManager:
def get_full_sync_manager(config):
if "full_sync_manager" in config and config["full_sync_manager"]: # load from config
if (
"full_sync_manager" in config and config["full_sync_manager"]
): # load from config
import importlib
import inspect
module_name, class_name = config['full_sync_manager'].rsplit('.', 1)
module_name, class_name = config["full_sync_manager"].rsplit(".", 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not FullSyncManager in inspect.getmro(class_):
raise TypeError('''"full_sync_manager" found in the conf file but it doesn''t
inherit from FullSyncManager''')
raise TypeError(
""""full_sync_manager" found in the conf file but it doesn''t
inherit from FullSyncManager"""
)
return class_(config)
else:
return FullSyncManager()

View File

@ -12,6 +12,7 @@ from anki.media import MediaManager
logger = logging.getLogger("ankisyncd.media")
class ServerMediaManager(MediaManager):
def __init__(self, col, server=True):
super().__init__(col, server)
@ -20,14 +21,16 @@ class ServerMediaManager(MediaManager):
def addMedia(self, media_to_add):
self._db.executemany(
"INSERT OR REPLACE INTO media VALUES (?,?,?)",
media_to_add
"INSERT OR REPLACE INTO media VALUES (?,?,?)", media_to_add
)
self._db.commit()
def changes(self, lastUsn):
return self._db.execute("select fname,usn,csum from media order by usn desc limit ?", self.lastUsn() - lastUsn)
return self._db.execute(
"select fname,usn,csum from media order by usn desc limit ?",
self.lastUsn() - lastUsn,
)
def connect(self):
path = self.dir() + ".server.db"
create = not os.path.exists(path)

View File

@ -43,20 +43,26 @@ class SqliteSessionManager(SimpleSessionManager):
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'session'")
cursor.execute(
"SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'session'"
)
res = cursor.fetchone()
conn.close()
if res is not None:
raise Exception("Outdated database schema, run utils/migrate_user_tables.py")
raise Exception(
"Outdated database schema, run utils/migrate_user_tables.py"
)
def _conn(self):
new = not os.path.exists(self.session_db_path)
conn = sqlite.connect(self.session_db_path)
if new:
cursor = conn.cursor()
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)")
cursor.execute(
"CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)"
)
return conn
# Default to using sqlite3 syntax but overridable for sub-classes using other
@ -73,7 +79,9 @@ class SqliteSessionManager(SimpleSessionManager):
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,))
cursor.execute(
self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,)
)
res = cursor.fetchone()
if res is not None:
@ -89,7 +97,9 @@ class SqliteSessionManager(SimpleSessionManager):
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,))
cursor.execute(
self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,)
)
res = cursor.fetchone()
if res is not None:
@ -103,8 +113,10 @@ class SqliteSessionManager(SimpleSessionManager):
conn = self._conn()
cursor = conn.cursor()
cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
(hkey, session.skey, session.name, session.path))
cursor.execute(
"INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
(hkey, session.skey, session.name, session.path),
)
conn.commit()
@ -117,25 +129,35 @@ class SqliteSessionManager(SimpleSessionManager):
cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,))
conn.commit()
def get_session_manager(config):
if "session_db_path" in config and config["session_db_path"]:
logger.info("Found session_db_path in config, using SqliteSessionManager for auth")
return SqliteSessionManager(config['session_db_path'])
logger.info(
"Found session_db_path in config, using SqliteSessionManager for auth"
)
return SqliteSessionManager(config["session_db_path"])
elif "session_manager" in config and config["session_manager"]: # load from config
logger.info("Found session_manager in config, using {} for persisting sessions".format(
config['session_manager'])
logger.info(
"Found session_manager in config, using {} for persisting sessions".format(
config["session_manager"]
)
)
import importlib
import inspect
module_name, class_name = config['session_manager'].rsplit('.', 1)
module_name, class_name = config["session_manager"].rsplit(".", 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not SimpleSessionManager in inspect.getmro(class_):
raise TypeError('''"session_manager" found in the conf file but it doesn''t
inherit from SimpleSessionManager''')
raise TypeError(
""""session_manager" found in the conf file but it doesn''t
inherit from SimpleSessionManager"""
)
return class_(config)
else:
logger.warning("Neither session_db_path nor session_manager set, "
"ankisyncd will lose sessions on application restart")
logger.warning(
"Neither session_db_path nor session_manager set, "
"ankisyncd will lose sessions on application restart"
)
return SimpleSessionManager()

View File

@ -10,7 +10,7 @@ import random
import requests
import json
import os
from typing import List,Tuple
from typing import List, Tuple
from anki.db import DB, DBError
from anki.utils import ids2str, intTime, platDesc, checksum, devMode
@ -24,38 +24,43 @@ from anki.lang import ngettext
# https://github.com/ankitects/anki/blob/04b1ca75599f18eb783a8bf0bdeeeb32362f4da0/rslib/src/sync/http_client.rs#L11
SYNC_VER = 10
# https://github.com/ankitects/anki/blob/cca3fcb2418880d0430a5c5c2e6b81ba260065b7/anki/consts.py#L50
SYNC_ZIP_SIZE = int(2.5*1024*1024)
SYNC_ZIP_SIZE = int(2.5 * 1024 * 1024)
# https://github.com/ankitects/anki/blob/cca3fcb2418880d0430a5c5c2e6b81ba260065b7/anki/consts.py#L51
SYNC_ZIP_COUNT = 25
# syncing vars
HTTP_TIMEOUT = 90
HTTP_PROXY = None
HTTP_BUF_SIZE = 64*1024
HTTP_BUF_SIZE = 64 * 1024
# Incremental syncing
##########################################################################
class Syncer(object):
def __init__(self, col, server=None):
self.col = col
self.server = server
# new added functions related to Syncer:
# these are removed from latest anki module
########################################################################
# new added functions related to Syncer:
# these are removed from latest anki module
########################################################################
def scm(self):
"""return schema"""
scm=self.col.db.scalar("select scm from col")
scm = self.col.db.scalar("select scm from col")
return scm
def increment_usn(self):
"""usn+1 in db"""
self.col.db.execute("update col set usn = usn + 1")
def set_modified_time(self,now:int):
def set_modified_time(self, now: int):
self.col.db.execute("update col set mod=?", now)
def set_last_sync(self,now:int):
def set_last_sync(self, now: int):
self.col.db.execute("update col set ls = ?", now)
#########################################################################
#########################################################################
def meta(self):
return dict(
mod=self.col.mod,
@ -64,30 +69,29 @@ class Syncer(object):
ts=intTime(),
musn=0,
msg="",
cont=True
cont=True,
)
def changes(self):
"Bundle up small objects."
d = dict(models=self.getModels(),
decks=self.getDecks(),
tags=self.getTags())
d = dict(models=self.getModels(), decks=self.getDecks(), tags=self.getTags())
if self.lnewer:
d['conf'] = self.col.all_config()
d['crt'] = self.col.crt
d["conf"] = self.col.all_config()
d["crt"] = self.col.crt
return d
def mergeChanges(self, lchg, rchg):
# then the other objects
self.mergeModels(rchg['models'])
self.mergeDecks(rchg['decks'])
if 'conf' in rchg:
self.mergeConf(rchg['conf'])
self.mergeModels(rchg["models"])
self.mergeDecks(rchg["decks"])
if "conf" in rchg:
self.mergeConf(rchg["conf"])
# this was left out of earlier betas
if 'crt' in rchg:
self.col.crt = rchg['crt']
if "crt" in rchg:
self.col.crt = rchg["crt"]
self.prepareToChunk()
# this fn was cloned from anki module(version 2.1.36)
# this fn was cloned from anki module(version 2.1.36)
def basicCheck(self) -> bool:
"Basic integrity check for syncing. True if ok."
# cards without notes
@ -118,9 +122,10 @@ select id from notes where mid = ?) limit 1"""
):
return False
return True
def sanityCheck(self):
tables=["cards",
tables = [
"cards",
"notes",
"revlog",
"graves",
@ -130,17 +135,17 @@ select id from notes where mid = ?) limit 1"""
"notetypes",
]
for tb in tables:
if self.col.db.scalar(f'select null from {tb} where usn=-1'):
return f'table had usn=-1: {tb}'
if self.col.db.scalar(f"select null from {tb} where usn=-1"):
return f"table had usn=-1: {tb}"
self.col.sched.reset()
# return summary of deck
# make sched.counts() equal to default [0,0,0]
# to make sure sync normally if sched.counts()
# are not equal between different clients due to
# different deck selection
# different deck selection
return [
list([0,0,0]),
list([0, 0, 0]),
self.col.db.scalar("select count() from cards"),
self.col.db.scalar("select count() from notes"),
self.col.db.scalar("select count() from revlog"),
@ -155,7 +160,7 @@ select id from notes where mid = ?) limit 1"""
def finish(self, now=None):
if now is not None:
# ensure we save the mod time even if no changes made
# ensure we save the mod time even if no changes made
self.set_modified_time(now)
self.set_last_sync(now)
self.increment_usn()
@ -174,17 +179,29 @@ select id from notes where mid = ?) limit 1"""
def queryTable(self, table):
lim = self.usnLim()
if table == "revlog":
return self.col.db.execute("""
return self.col.db.execute(
"""
select id, cid, ?, ease, ivl, lastIvl, factor, time, type
from revlog where %s""" % lim, self.maxUsn)
from revlog where %s"""
% lim,
self.maxUsn,
)
elif table == "cards":
return self.col.db.execute("""
return self.col.db.execute(
"""
select id, nid, did, ord, mod, ?, type, queue, due, ivl, factor, reps,
lapses, left, odue, odid, flags, data from cards where %s""" % lim, self.maxUsn)
lapses, left, odue, odid, flags, data from cards where %s"""
% lim,
self.maxUsn,
)
else:
return self.col.db.execute("""
return self.col.db.execute(
"""
select id, guid, mid, mod, ?, tags, flds, '', '', flags, data
from notes where %s""" % lim, self.maxUsn)
from notes where %s"""
% lim,
self.maxUsn,
)
def chunk(self):
buf = dict(done=False)
@ -195,95 +212,96 @@ from notes where %s""" % lim, self.maxUsn)
f"update {curTable} set usn=? where usn=-1", self.maxUsn
)
if not self.tablesLeft:
buf['done'] = True
buf["done"] = True
return buf
def applyChunk(self, chunk):
if "revlog" in chunk:
self.mergeRevlog(chunk['revlog'])
self.mergeRevlog(chunk["revlog"])
if "cards" in chunk:
self.mergeCards(chunk['cards'])
self.mergeCards(chunk["cards"])
if "notes" in chunk:
self.mergeNotes(chunk['notes'])
self.mergeNotes(chunk["notes"])
# Deletions
##########################################################################
def add_grave(self, ids: List[int], type: int,usn: int):
items=[(id,type,usn) for id in ids]
def add_grave(self, ids: List[int], type: int, usn: int):
items = [(id, type, usn) for id in ids]
# make sure table graves fields order and schema version match
# query sql1='pragma table_info(graves)' version query schema='select ver from col'
self.col.db.executemany(
"INSERT OR IGNORE INTO graves (oid, type, usn) VALUES (?, ?, ?)" ,
items)
def apply_graves(self, graves,latest_usn: int):
# remove card and the card's orphaned notes
self.col.remove_cards_and_orphaned_notes(graves['cards'])
self.add_grave(graves['cards'], REM_CARD,latest_usn)
"INSERT OR IGNORE INTO graves (oid, type, usn) VALUES (?, ?, ?)", items
)
def apply_graves(self, graves, latest_usn: int):
# remove card and the card's orphaned notes
self.col.remove_cards_and_orphaned_notes(graves["cards"])
self.add_grave(graves["cards"], REM_CARD, latest_usn)
# only notes
self.col.remove_notes(graves['notes'])
self.add_grave(graves['notes'], REM_NOTE,latest_usn)
self.col.remove_notes(graves["notes"])
self.add_grave(graves["notes"], REM_NOTE, latest_usn)
# since level 0 deck ,we only remove deck ,but backend will delete child,it is ok, the delete
# will have once effect
self.col.decks.remove(graves['decks'])
self.add_grave(graves['decks'], REM_DECK,latest_usn)
self.col.decks.remove(graves["decks"])
self.add_grave(graves["decks"], REM_DECK, latest_usn)
# Models
##########################################################################
def getModels(self):
mods = [m for m in self.col.models.all() if m['usn'] == -1]
mods = [m for m in self.col.models.all() if m["usn"] == -1]
for m in mods:
m['usn'] = self.maxUsn
m["usn"] = self.maxUsn
self.col.models.save()
return mods
def mergeModels(self, rchg):
for r in rchg:
l = self.col.models.get(r['id'])
l = self.col.models.get(r["id"])
# if missing locally or server is newer, update
if not l or r['mod'] > l['mod']:
if not l or r["mod"] > l["mod"]:
self.col.models.update(r)
# Decks
##########################################################################
def getDecks(self):
decks = [g for g in self.col.decks.all() if g['usn'] == -1]
decks = [g for g in self.col.decks.all() if g["usn"] == -1]
for g in decks:
g['usn'] = self.maxUsn
dconf = [g for g in self.col.decks.allConf() if g['usn'] == -1]
g["usn"] = self.maxUsn
dconf = [g for g in self.col.decks.allConf() if g["usn"] == -1]
for g in dconf:
g['usn'] = self.maxUsn
g["usn"] = self.maxUsn
self.col.decks.save()
return [decks, dconf]
def mergeDecks(self, rchg):
for r in rchg[0]:
l = self.col.decks.get(r['id'], False)
l = self.col.decks.get(r["id"], False)
# work around mod time being stored as string
if l and not isinstance(l['mod'], int):
l['mod'] = int(l['mod'])
if l and not isinstance(l["mod"], int):
l["mod"] = int(l["mod"])
# if missing locally or server is newer, update
if not l or r['mod'] > l['mod']:
if not l or r["mod"] > l["mod"]:
self.col.decks.update(r)
for r in rchg[1]:
try:
l = self.col.decks.getConf(r['id'])
l = self.col.decks.getConf(r["id"])
except KeyError:
l = None
# if missing locally or server is newer, update
if not l or r['mod'] > l['mod']:
if not l or r["mod"] > l["mod"]:
self.col.decks.updateConf(r)
# Tags
##########################################################################
def allItems(self) -> List[Tuple[str, int]]:
tags=self.col.db.execute("select tag, usn from tags")
return [(tag, int(usn)) for tag,usn in tags]
tags = self.col.db.execute("select tag, usn from tags")
return [(tag, int(usn)) for tag, usn in tags]
def getTags(self):
tags = []
for t, usn in self.allItems():
@ -301,15 +319,16 @@ from notes where %s""" % lim, self.maxUsn)
def mergeRevlog(self, logs):
self.col.db.executemany(
"insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)",
logs)
"insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)", logs
)
def newerRows(self, data, table, modIdx):
ids = (r[0] for r in data)
lmods = {}
for id, mod in self.col.db.execute(
"select id, mod from %s where id in %s and %s" % (
table, ids2str(ids), self.usnLim())):
"select id, mod from %s where id in %s and %s"
% (table, ids2str(ids), self.usnLim())
):
lmods[id] = mod
update = []
for r in data:
@ -323,14 +342,17 @@ from notes where %s""" % lim, self.maxUsn)
self.col.db.executemany(
"insert or replace into cards values "
"(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
self.newerRows(cards, "cards", 4))
self.newerRows(cards, "cards", 4),
)
def mergeNotes(self, notes):
rows = self.newerRows(notes, "notes", 3)
self.col.db.executemany(
"insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)",
rows)
self.col.after_note_updates([f[0] for f in rows], mark_modified=False, generate_cards=False)
"insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", rows
)
self.col.after_note_updates(
[f[0] for f in rows], mark_modified=False, generate_cards=False
)
# Col config
##########################################################################
@ -341,11 +363,14 @@ from notes where %s""" % lim, self.maxUsn)
def mergeConf(self, conf):
for key, value in conf.items():
self.col.set_config(key, value)
# self.col.backend.set_all_config(json.dumps(conf).encode())
# Wrapper for requests that tracks upload/download progress
##########################################################################
class AnkiRequestsClient(object):
verify = True
timeout = 60
@ -355,15 +380,23 @@ class AnkiRequestsClient(object):
def post(self, url, data, headers):
data = _MonitoringFile(data)
headers['User-Agent'] = self._agentName()
headers["User-Agent"] = self._agentName()
return self.session.post(
url, data=data, headers=headers, stream=True, timeout=self.timeout, verify=self.verify)
url,
data=data,
headers=headers,
stream=True,
timeout=self.timeout,
verify=self.verify,
)
def get(self, url, headers=None):
if headers is None:
headers = {}
headers['User-Agent'] = self._agentName()
return self.session.get(url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify)
headers["User-Agent"] = self._agentName()
return self.session.get(
url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify
)
def streamContent(self, resp):
resp.raise_for_status()
@ -375,24 +408,30 @@ class AnkiRequestsClient(object):
def _agentName(self):
from anki import version
return "Anki {}".format(version)
# allow user to accept invalid certs in work/school settings
if os.environ.get("ANKI_NOVERIFYSSL"):
AnkiRequestsClient.verify = False
import warnings
warnings.filterwarnings("ignore")
class _MonitoringFile(io.BufferedReader):
def read(self, size=-1):
data = io.BufferedReader.read(self, HTTP_BUF_SIZE)
return data
# HTTP syncing tools
##########################################################################
class HttpSyncer(object):
def __init__(self, hkey=None, client=None, hostNum=None):
self.hkey = hkey
@ -421,24 +460,29 @@ class HttpSyncer(object):
# support file uploading, so this is the more compatible choice.
def _buildPostData(self, fobj, comp):
BOUNDARY=b"Anki-sync-boundary"
bdry = b"--"+BOUNDARY
BOUNDARY = b"Anki-sync-boundary"
bdry = b"--" + BOUNDARY
buf = io.BytesIO()
# post vars
self.postVars['c'] = 1 if comp else 0
self.postVars["c"] = 1 if comp else 0
for (key, value) in list(self.postVars.items()):
buf.write(bdry + b"\r\n")
buf.write(
('Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n' %
(key, value)).encode("utf8"))
(
'Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n'
% (key, value)
).encode("utf8")
)
# payload as raw data or json
rawSize = 0
if fobj:
# header
buf.write(bdry + b"\r\n")
buf.write(b"""\
buf.write(
b"""\
Content-Disposition: form-data; name="data"; filename="data"\r\n\
Content-Type: application/octet-stream\r\n\r\n""")
Content-Type: application/octet-stream\r\n\r\n"""
)
# write file into buffer, optionally compressing
if comp:
tgt = gzip.GzipFile(mode="wb", fileobj=buf, compresslevel=comp)
@ -453,16 +497,17 @@ Content-Type: application/octet-stream\r\n\r\n""")
rawSize += len(data)
tgt.write(data)
buf.write(b"\r\n")
buf.write(bdry + b'--\r\n')
buf.write(bdry + b"--\r\n")
size = buf.tell()
# connection headers
headers = {
'Content-Type': 'multipart/form-data; boundary=%s' % BOUNDARY.decode("utf8"),
'Content-Length': str(size),
"Content-Type": "multipart/form-data; boundary=%s"
% BOUNDARY.decode("utf8"),
"Content-Length": str(size),
}
buf.seek(0)
if size >= 100*1024*1024 or rawSize >= 250*1024*1024:
if size >= 100 * 1024 * 1024 or rawSize >= 250 * 1024 * 1024:
raise Exception("Collection too large to upload to AnkiWeb.")
return headers, buf
@ -470,7 +515,7 @@ Content-Type: application/octet-stream\r\n\r\n""")
def req(self, method, fobj=None, comp=6, badAuthRaises=True):
headers, body = self._buildPostData(fobj, comp)
r = self.client.post(self.syncURL()+method, data=body, headers=headers)
r = self.client.post(self.syncURL() + method, data=body, headers=headers)
if not badAuthRaises and r.status_code == 403:
return False
self.assertOk(r)
@ -478,9 +523,11 @@ Content-Type: application/octet-stream\r\n\r\n""")
buf = self.client.streamContent(r)
return buf
# Incremental sync over HTTP
######################################################################
class RemoteServer(HttpSyncer):
def __init__(self, hkey, hostNum):
super().__init__(self, hkey, hostNum=hostNum)
@ -489,12 +536,14 @@ class RemoteServer(HttpSyncer):
"Returns hkey or none if user/pw incorrect."
self.postVars = dict()
ret = self.req(
"hostKey", io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")),
badAuthRaises=False)
"hostKey",
io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")),
badAuthRaises=False,
)
if not ret:
# invalid auth
return
self.hkey = json.loads(ret.decode("utf8"))['key']
self.hkey = json.loads(ret.decode("utf8"))["key"]
return self.hkey
def meta(self):
@ -503,9 +552,17 @@ class RemoteServer(HttpSyncer):
s=self.skey,
)
ret = self.req(
"meta", io.BytesIO(json.dumps(dict(
v=SYNC_VER, cv="ankidesktop,%s,%s"%(versionWithBuild(), platDesc()))).encode("utf8")),
badAuthRaises=False)
"meta",
io.BytesIO(
json.dumps(
dict(
v=SYNC_VER,
cv="ankidesktop,%s,%s" % (versionWithBuild(), platDesc()),
)
).encode("utf8")
),
badAuthRaises=False,
)
if not ret:
# invalid auth
return
@ -537,17 +594,20 @@ class RemoteServer(HttpSyncer):
def _run(self, cmd, data):
return json.loads(
self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8"))
self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8")
)
# Full syncing
##########################################################################
class FullSyncer(HttpSyncer):
def __init__(self, col, hkey, client, hostNum):
super().__init__(self, hkey, client, hostNum=hostNum)
self.postVars = dict(
k=self.hkey,
v="ankidesktop,%s,%s"%(anki.version, platDesc()),
v="ankidesktop,%s,%s" % (anki.version, platDesc()),
)
self.col = col
@ -586,9 +646,11 @@ class FullSyncer(HttpSyncer):
return False
return True
# Remote media syncing
##########################################################################
class RemoteMediaServer(HttpSyncer):
def __init__(self, col, hkey, client, hostNum):
self.col = col
@ -597,12 +659,12 @@ class RemoteMediaServer(HttpSyncer):
def begin(self):
self.postVars = dict(
k=self.hkey,
v="ankidesktop,%s,%s"%(anki.version, platDesc())
k=self.hkey, v="ankidesktop,%s,%s" % (anki.version, platDesc())
)
ret = self._dataOnly(self.req(
"begin", io.BytesIO(json.dumps(dict()).encode("utf8"))))
self.skey = ret['sk']
ret = self._dataOnly(
self.req("begin", io.BytesIO(json.dumps(dict()).encode("utf8")))
)
self.skey = ret["sk"]
return ret
# args: lastUsn
@ -611,7 +673,8 @@ class RemoteMediaServer(HttpSyncer):
sk=self.skey,
)
return self._dataOnly(
self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8"))))
self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8")))
)
# args: files
def downloadFiles(self, **kw):
@ -619,20 +682,20 @@ class RemoteMediaServer(HttpSyncer):
def uploadChanges(self, zip):
# no compression, as we compress the zip file instead
return self._dataOnly(
self.req("uploadChanges", io.BytesIO(zip), comp=0))
return self._dataOnly(self.req("uploadChanges", io.BytesIO(zip), comp=0))
# args: local
def mediaSanity(self, **kw):
return self._dataOnly(
self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8"))))
self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8")))
)
def _dataOnly(self, resp):
resp = json.loads(resp.decode("utf8"))
if resp['err']:
self.col.log("error returned:%s"%resp['err'])
raise Exception("SyncError:%s"%resp['err'])
return resp['data']
if resp["err"]:
self.col.log("error returned:%s" % resp["err"])
raise Exception("SyncError:%s" % resp["err"])
return resp["data"]
# only for unit tests
def mediatest(self, cmd):
@ -640,5 +703,7 @@ class RemoteMediaServer(HttpSyncer):
k=self.hkey,
)
return self._dataOnly(
self.req("newMediaTest", io.BytesIO(
json.dumps(dict(cmd=cmd)).encode("utf8"))))
self.req(
"newMediaTest", io.BytesIO(json.dumps(dict(cmd=cmd)).encode("utf8"))
)
)

View File

@ -43,7 +43,16 @@ logger = logging.getLogger("ankisyncd")
class SyncCollectionHandler(Syncer):
operations = ['meta', 'applyChanges', 'start', 'applyGraves', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']
operations = [
"meta",
"applyChanges",
"start",
"applyGraves",
"chunk",
"applyChunk",
"sanityCheck2",
"finish",
]
def __init__(self, col, session):
# So that 'server' (the 3rd argument) can't get set
@ -56,9 +65,9 @@ class SyncCollectionHandler(Syncer):
return False
note = {"alpha": 0, "beta": 0, "rc": 0}
client, version, platform = cv.split(',')
client, version, platform = cv.split(",")
if 'arch' not in version:
if "arch" not in version:
for name in note.keys():
if name in version:
vs = version.split(name)
@ -66,17 +75,17 @@ class SyncCollectionHandler(Syncer):
note[name] = int(vs[-1])
# convert the version string, ignoring non-numeric suffixes like in beta versions of Anki
version_nosuffix = re.sub(r'[^0-9.].*$', '', version)
version_int = [int(x) for x in version_nosuffix.split('.')]
version_nosuffix = re.sub(r"[^0-9.].*$", "", version)
version_int = [int(x) for x in version_nosuffix.split(".")]
if client == 'ankidesktop':
if client == "ankidesktop":
return version_int < [2, 0, 27]
elif client == 'ankidroid':
elif client == "ankidroid":
if version_int == [2, 3]:
if note["alpha"]:
return note["alpha"] < 4
if note["alpha"]:
return note["alpha"] < 4
else:
return version_int < [2, 2, 3]
return version_int < [2, 2, 3]
else: # unknown client, assume current version
return False
@ -84,23 +93,33 @@ class SyncCollectionHandler(Syncer):
if self._old_client(cv):
return Response(status=501) # client needs upgrade
if v > SYNC_VER:
return {"cont": False, "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(v, SYNC_VER)}
return {
"cont": False,
"msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(
v, SYNC_VER
),
}
if v < 9 and self.col.schedVer() >= 2:
return {"cont": False, "msg": "Your client doesn't support the v{} scheduler.".format(self.col.schedVer())}
return {
"cont": False,
"msg": "Your client doesn't support the v{} scheduler.".format(
self.col.schedVer()
),
}
# Make sure the media database is open!
self.col.media.connect()
return {
'mod': self.col.mod,
'scm': self.scm(),
'usn': self.col.usn(),
'ts': anki.utils.intTime(),
'musn': self.col.media.lastUsn(),
'uname': self.session.name,
'msg': '',
'cont': True,
'hostNum': 0,
"mod": self.col.mod,
"scm": self.scm(),
"usn": self.col.usn(),
"ts": anki.utils.intTime(),
"musn": self.col.media.lastUsn(),
"uname": self.session.name,
"msg": "",
"cont": True,
"hostNum": 0,
}
def usnLim(self):
@ -108,10 +127,16 @@ class SyncCollectionHandler(Syncer):
# ankidesktop >=2.1rc2 sends graves in applyGraves, but still expects
# server-side deletions to be returned by start
def start(self, minUsn, lnewer, graves={"cards": [], "notes": [], "decks": []}, offset=None):
def start(
self,
minUsn,
lnewer,
graves={"cards": [], "notes": [], "decks": []},
offset=None,
):
# The offset para is passed by client V2 scheduler,which is minutes_west.
# Since now have not thorougly test the V2 scheduler, we leave this comments here, and
# just enable the V2 scheduler in the serve code.
# Since now have not thorougly test the V2 scheduler, we leave this comments here, and
# just enable the V2 scheduler in the serve code.
self.maxUsn = self.col.usn()
self.minUsn = minUsn
@ -122,11 +147,11 @@ class SyncCollectionHandler(Syncer):
# Only if Operations like deleting deck are performed on Ankidroid
# can (client) graves is not None
if graves is not None:
self.apply_graves(graves,self.maxUsn)
self.apply_graves(graves, self.maxUsn)
return lgraves
def applyGraves(self, chunk):
self.apply_graves(chunk,self.maxUsn)
self.apply_graves(chunk, self.maxUsn)
def applyChanges(self, changes):
self.rchg = changes
@ -136,16 +161,13 @@ class SyncCollectionHandler(Syncer):
return lchg
def sanityCheck2(self, client):
client[0]=[0,0,0]
client[0] = [0, 0, 0]
server = self.sanityCheck()
if client != server:
logger.info(
f"sanity check failed with server: {server} client: {client}"
)
logger.info(f"sanity check failed with server: {server} client: {client}")
return dict(status="bad", c=client, s=server)
return dict(status="ok")
def finish(self):
return super().finish(anki.utils.intTime(1000))
@ -158,7 +180,8 @@ class SyncCollectionHandler(Syncer):
decks = []
curs = self.col.db.execute(
"select oid, type from graves where usn >= ?", self.minUsn)
"select oid, type from graves where usn >= ?", self.minUsn
)
for oid, type in curs:
if type == REM_CARD:
@ -171,20 +194,26 @@ class SyncCollectionHandler(Syncer):
return dict(cards=cards, notes=notes, decks=decks)
def getModels(self):
return [m for m in self.col.models.all() if m['usn'] >= self.minUsn]
return [m for m in self.col.models.all() if m["usn"] >= self.minUsn]
def getDecks(self):
return [
[g for g in self.col.decks.all() if g['usn'] >= self.minUsn],
[g for g in self.col.decks.all_config() if g['usn'] >= self.minUsn]
[g for g in self.col.decks.all() if g["usn"] >= self.minUsn],
[g for g in self.col.decks.all_config() if g["usn"] >= self.minUsn],
]
def getTags(self):
return [t for t, usn in self.allItems()
if usn >= self.minUsn]
return [t for t, usn in self.allItems() if usn >= self.minUsn]
class SyncMediaHandler:
operations = ['begin', 'mediaChanges', 'mediaSanity', 'uploadChanges', 'downloadFiles']
operations = [
"begin",
"mediaChanges",
"mediaSanity",
"uploadChanges",
"downloadFiles",
]
def __init__(self, col, session):
self.col = col
@ -192,11 +221,11 @@ class SyncMediaHandler:
def begin(self, skey):
return {
'data': {
'sk': skey,
'usn': self.col.media.lastUsn(),
"data": {
"sk": skey,
"usn": self.col.media.lastUsn(),
},
'err': '',
"err": "",
}
def uploadChanges(self, data):
@ -210,24 +239,27 @@ class SyncMediaHandler:
processed_count = self._adopt_media_changes_from_zip(z)
return {
'data': [processed_count, self.col.media.lastUsn()],
'err': '',
"data": [processed_count, self.col.media.lastUsn()],
"err": "",
}
@staticmethod
def _check_zip_data(zip_file):
max_zip_size = 100*1024*1024
max_zip_size = 100 * 1024 * 1024
max_meta_file_size = 100000
meta_file_size = zip_file.getinfo("_meta").file_size
sum_file_sizes = sum(info.file_size for info in zip_file.infolist())
if meta_file_size > max_meta_file_size:
raise ValueError("Zip file's metadata file is larger than %s "
"Bytes." % max_meta_file_size)
raise ValueError(
"Zip file's metadata file is larger than %s "
"Bytes." % max_meta_file_size
)
elif sum_file_sizes > max_zip_size:
raise ValueError("Zip file contents are larger than %s Bytes." %
max_zip_size)
raise ValueError(
"Zip file contents are larger than %s Bytes." % max_zip_size
)
def _adopt_media_changes_from_zip(self, zip_file):
"""
@ -261,7 +293,7 @@ class SyncMediaHandler:
file_path = os.path.join(media_dir, filename)
# Save file to media directory.
with open(file_path, 'wb') as f:
with open(file_path, "wb") as f:
f.write(file_data)
usn += 1
@ -279,7 +311,9 @@ class SyncMediaHandler:
if media_to_add:
self.col.media.addMedia(media_to_add)
assert self.col.media.lastUsn() == oldUsn + processed_count # TODO: move to some unit test
assert (
self.col.media.lastUsn() == oldUsn + processed_count
) # TODO: move to some unit test
return processed_count
@staticmethod
@ -302,13 +336,15 @@ class SyncMediaHandler:
Marks all files in list filenames as deleted and removes them from the
media directory.
"""
logger.debug('Removing %d files from media dir.' % len(filenames))
logger.debug("Removing %d files from media dir." % len(filenames))
for filename in filenames:
try:
self.col.media.syncDelete(filename)
except OSError as err:
logger.error("Error when removing file '%s' from media dir: "
"%s" % (filename, str(err)))
logger.error(
"Error when removing file '%s' from media dir: "
"%s" % (filename, str(err))
)
def downloadFiles(self, files):
flist = {}
@ -334,13 +370,17 @@ class SyncMediaHandler:
server_lastUsn = self.col.media.lastUsn()
if lastUsn < server_lastUsn or lastUsn == 0:
for fname,usn,csum, in self.col.media.changes(lastUsn):
for (
fname,
usn,
csum,
) in self.col.media.changes(lastUsn):
result.append([fname, usn, csum])
# anki assumes server_lastUsn == result[-1][1]
# ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
result.reverse()
return {'data': result, 'err': ''}
return {"data": result, "err": ""}
def mediaSanity(self, local=None):
if self.col.media.mediaCount() == local:
@ -348,7 +388,8 @@ class SyncMediaHandler:
else:
result = "FAILED"
return {'data': result, 'err': ''}
return {"data": result, "err": ""}
class SyncUserSession:
def __init__(self, name, path, collection_manager, setup_new_collection=None):
@ -371,16 +412,18 @@ class SyncUserSession:
return anki.utils.checksum(str(random.random()))[:8]
def get_collection_path(self):
return os.path.realpath(os.path.join(self.path, 'collection.anki2'))
return os.path.realpath(os.path.join(self.path, "collection.anki2"))
def get_thread(self):
return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)
return self.collection_manager.get_collection(
self.get_collection_path(), self.setup_new_collection
)
def get_handler_for_operation(self, operation, col):
if operation in SyncCollectionHandler.operations:
attr, handler_class = 'collection_handler', SyncCollectionHandler
attr, handler_class = "collection_handler", SyncCollectionHandler
elif operation in SyncMediaHandler.operations:
attr, handler_class = 'media_handler', SyncMediaHandler
attr, handler_class = "media_handler", SyncMediaHandler
else:
raise Exception("no handler for {}".format(operation))
@ -391,146 +434,178 @@ class SyncUserSession:
# for inactivity and then later re-open it (creating a new Collection object).
handler.col = col
return handler
class Requests(object):
def __init__(self,environ: dict):
self.environ=environ
def __init__(self, environ: dict):
self.environ = environ
@property
def params(self):
return self.request_items_dict
@params.setter
def params(self,value):
def params(self, value):
"""
A dictionary-like object containing both the parameters from
the query string and request body.
"""
self.request_items_dict= value
self.request_items_dict = value
@property
def path(self)-> str:
return self.environ['PATH_INFO']
def path(self) -> str:
return self.environ["PATH_INFO"]
@property
def POST(self):
return self._request_items_dict
@POST.setter
def POST(self,value):
self._request_items_dict=value
def POST(self, value):
self._request_items_dict = value
@property
def parse(self):
'''Return a MultiDict containing all the variables from a form
"""Return a MultiDict containing all the variables from a form
request.\n
include not only post req,but also get'''
include not only post req,but also get"""
env = self.environ
query_string=env['QUERY_STRING']
content_len= env.get('CONTENT_LENGTH', '0')
input = env.get('wsgi.input')
length = 0 if content_len == '' else int(content_len)
body=b''
request_items_dict={}
query_string = env["QUERY_STRING"]
content_len = env.get("CONTENT_LENGTH", "0")
input = env.get("wsgi.input")
length = 0 if content_len == "" else int(content_len)
body = b""
request_items_dict = {}
if length == 0:
if input is None:
return request_items_dict
if env.get('HTTP_TRANSFER_ENCODING','0') == 'chunked':
if env.get("HTTP_TRANSFER_ENCODING", "0") == "chunked":
# readlines and read(no argument) will block
# convert byte str to number base 16
leng=int(input.readline(),16)
c=0
bdry=b''
data=[]
data_other=[]
while leng >0:
c+=1
dt = input.read(leng+2)
if c==1:
bdry=dt
elif c>=3:
leng = int(input.readline(), 16)
c = 0
bdry = b""
data = []
data_other = []
while leng > 0:
c += 1
dt = input.read(leng + 2)
if c == 1:
bdry = dt
elif c >= 3:
# data
data_other.append(dt)
leng = int(input.readline(),16)
data_other=[item for item in data_other if item!=b'\r\n\r\n']
leng = int(input.readline(), 16)
data_other = [item for item in data_other if item != b"\r\n\r\n"]
for item in data_other:
if bdry in item:
break
# only strip \r\n if there are extra \n
# eg b'?V\xc1\x8f>\xf9\xb1\n\r\n'
data.append(item[:-2])
request_items_dict['data']=b''.join(data)
others=data_other[len(data):]
boundary=others[0]
others=b''.join(others).split(boundary.strip())
request_items_dict["data"] = b"".join(data)
others = data_other[len(data) :]
boundary = others[0]
others = b"".join(others).split(boundary.strip())
others.pop()
others.pop(0)
for i in others:
i=i.splitlines()
key=re.findall(b'name="(.*?)"',i[2],flags=re.M)[0].decode('utf-8')
v=i[-1].decode('utf-8')
request_items_dict[key]=v
i = i.splitlines()
key = re.findall(b'name="(.*?)"', i[2], flags=re.M)[0].decode(
"utf-8"
)
v = i[-1].decode("utf-8")
request_items_dict[key] = v
return request_items_dict
if query_string !='':
if query_string != "":
# GET method
body=query_string
request_items_dict=urllib.parse.parse_qs(body)
for k,v in request_items_dict.items():
request_items_dict[k]=''.join(v)
body = query_string
request_items_dict = urllib.parse.parse_qs(body)
for k, v in request_items_dict.items():
request_items_dict[k] = "".join(v)
return request_items_dict
else:
body = env['wsgi.input'].read(length)
if body is None or body ==b'':
body = env["wsgi.input"].read(length)
if body is None or body == b"":
return request_items_dict
# process body to dict
repeat=body.splitlines()[0]
items=re.split(repeat,body)
repeat = body.splitlines()[0]
items = re.split(repeat, body)
# del first ,last item
items.pop()
items.pop(0)
for item in items:
if b'name="data"' in item:
data_field=None
# remove \r\n
if b'application/octet-stream' in item:
data_field = None
# remove \r\n
if b"application/octet-stream" in item:
# Ankidroid case
item=re.sub(b'Content-Disposition: form-data; name="data"; filename="data"',b'',item)
item=re.sub(b'Content-Type: application/octet-stream',b'',item)
data_field=item.strip()
item = re.sub(
b'Content-Disposition: form-data; name="data"; filename="data"',
b"",
item,
)
item = re.sub(b"Content-Type: application/octet-stream", b"", item)
data_field = item.strip()
else:
# PKzip file stream and others
item=re.sub(b'Content-Disposition: form-data; name="data"; filename="data"',b'',item)
data_field=item.strip()
request_items_dict['data']=data_field
item = re.sub(
b'Content-Disposition: form-data; name="data"; filename="data"',
b"",
item,
)
data_field = item.strip()
request_items_dict["data"] = data_field
continue
item=re.sub(b'\r\n',b'',item,flags=re.MULTILINE)
key=re.findall(b'name="(.*?)"',item)[0].decode('utf-8')
v=item[item.rfind(b'"')+1:].decode('utf-8')
request_items_dict[key]=v
item = re.sub(b"\r\n", b"", item, flags=re.MULTILINE)
key = re.findall(b'name="(.*?)"', item)[0].decode("utf-8")
v = item[item.rfind(b'"') + 1 :].decode("utf-8")
request_items_dict[key] = v
return request_items_dict
class chunked(object):
'''decorator'''
"""decorator"""
def __init__(self, func):
wraps(func)(self)
def __call__(self, *args, **kwargs):
clss=args[0]
environ=args[1]
clss = args[0]
environ = args[1]
start_response = args[2]
b=Requests(environ)
args=(clss,b,)
w= self.__wrapped__(*args, **kwargs)
resp=Response(w)
b = Requests(environ)
args = (
clss,
b,
)
w = self.__wrapped__(*args, **kwargs)
resp = Response(w)
return resp(environ, start_response)
def __get__(self, instance, cls):
if instance is None:
return self
else:
return types.MethodType(self, instance)
class SyncApp:
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
valid_urls = (
SyncCollectionHandler.operations
+ SyncMediaHandler.operations
+ ["hostKey", "upload", "download"]
)
def __init__(self, config):
from ankisyncd.thread import get_collection_manager
self.data_root = os.path.abspath(config['data_root'])
self.base_url = config['base_url']
self.base_media_url = config['base_media_url']
self.data_root = os.path.abspath(config["data_root"])
self.base_url = config["base_url"]
self.base_media_url = config["base_media_url"]
self.setup_new_collection = None
self.user_manager = get_user_manager(config)
@ -539,22 +614,31 @@ class SyncApp:
self.collection_manager = get_collection_manager(config)
# make sure the base_url has a trailing slash
if not self.base_url.endswith('/'):
self.base_url += '/'
if not self.base_media_url.endswith('/'):
self.base_media_url += '/'
if not self.base_url.endswith("/"):
self.base_url += "/"
if not self.base_media_url.endswith("/"):
self.base_media_url += "/"
def generateHostKey(self, username):
"""Generates a new host key to be used by the given username to identify their session.
This values is random."""
import hashlib, time, random, string
chars = string.ascii_letters + string.digits
val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]).encode()
val = ":".join(
[
username,
str(int(time.time())),
"".join(random.choice(chars) for x in range(8)),
]
).encode()
return hashlib.md5(val).hexdigest()
def create_session(self, username, user_path):
return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)
return SyncUserSession(
username, user_path, self.collection_manager, self.setup_new_collection
)
def _decode_data(self, data, compression=0):
if compression:
@ -564,7 +648,7 @@ class SyncApp:
try:
data = json.loads(data.decode())
except (ValueError, UnicodeDecodeError):
data = {'data': data}
data = {"data": data}
return data
@ -581,7 +665,7 @@ class SyncApp:
session = self.create_session(username, user_path)
self.session_manager.save(hkey, session)
return {'key': hkey}
return {"key": hkey}
def operation_upload(self, col, data, session):
# Verify integrity of the received database file before replacing our
@ -599,56 +683,56 @@ class SyncApp:
# cgi file can only be read once,and will be blocked after being read once more
# so i call Requests.parse only once,and bind its return result to properties
# POST and params (set return result as property values)
req.params=req.parse
req.POST=req.params
req.params = req.parse
req.POST = req.params
try:
hkey = req.params['k']
hkey = req.params["k"]
except KeyError:
hkey = None
session = self.session_manager.load(hkey, self.create_session)
if session is None:
try:
skey = req.POST['sk']
skey = req.POST["sk"]
session = self.session_manager.load_from_skey(skey, self.create_session)
except KeyError:
skey = None
try:
compression = int(req.POST['c'])
compression = int(req.POST["c"])
except KeyError:
compression = 0
try:
data = req.POST['data']
data = req.POST["data"]
data = self._decode_data(data, compression)
except KeyError:
data = {}
if req.path.startswith(self.base_url):
url = req.path[len(self.base_url):]
url = req.path[len(self.base_url) :]
if url not in self.valid_urls:
raise HTTPNotFound()
if url == 'hostKey':
if url == "hostKey":
result = self.operation_hostKey(data.get("u"), data.get("p"))
if result:
return json.dumps(result)
else:
# TODO: do I have to pass 'null' for the client to receive None?
raise HTTPForbidden('null')
raise HTTPForbidden("null")
if session is None:
raise HTTPForbidden()
if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
# 'meta' passes the SYNC_VER but it isn't used in the handler
if url == 'meta':
if session.skey == None and 's' in req.POST:
session.skey = req.POST['s']
if 'v' in data:
session.version = data['v']
if 'cv' in data:
session.client_version = data['cv']
if url == "meta":
if session.skey == None and "s" in req.POST:
session.skey = req.POST["s"]
if "v" in data:
session.version = data["v"]
if "cv" in data:
session.client_version = data["cv"]
self.session_manager.save(hkey, session)
session = self.session_manager.load(hkey, self.create_session)
@ -660,12 +744,12 @@ class SyncApp:
return result
elif url == 'upload':
elif url == "upload":
thread = session.get_thread()
result = thread.execute(self.operation_upload, [data['data'], session])
result = thread.execute(self.operation_upload, [data["data"], session])
return result
elif url == 'download':
elif url == "download":
thread = session.get_thread()
result = thread.execute(self.operation_download, [session])
return result
@ -678,13 +762,13 @@ class SyncApp:
if session is None:
raise HTTPForbidden()
url = req.path[len(self.base_media_url):]
url = req.path[len(self.base_media_url) :]
if url not in self.valid_urls:
raise HTTPNotFound()
if url == "begin":
data['skey'] = session.skey
data["skey"] = session.skey
result = self._execute_handler_method_in_thread(url, data, session)
@ -726,10 +810,16 @@ class SyncApp:
def make_app(global_conf, **local_conf):
return SyncApp(**local_conf)
def main():
logging.basicConfig(level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s")
logging.basicConfig(
level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s"
)
import ankisyncd
logger.info("ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage))
logger.info(
"ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage)
)
from wsgiref.simple_server import make_server, WSGIRequestHandler
from ankisyncd.thread import shutdown
import ankisyncd.config
@ -738,10 +828,10 @@ def main():
logger = logging.getLogger("ankisyncd.http")
def log_error(self, format, *args):
self.logger.error("%s %s", self.address_string(), format%args)
self.logger.error("%s %s", self.address_string(), format % args)
def log_message(self, format, *args):
self.logger.info("%s %s", self.address_string(), format%args)
self.logger.info("%s %s", self.address_string(), format % args)
if len(sys.argv) > 1:
# backwards compat
@ -750,7 +840,9 @@ def main():
config = ankisyncd.config.load()
ankiserver = SyncApp(config)
httpd = make_server(config['host'], int(config['port']), ankiserver, handler_class=RequestHandler)
httpd = make_server(
config["host"], int(config["port"]), ankiserver, handler_class=RequestHandler
)
try:
logger.info("Serving HTTP on {} port {}...".format(*httpd.server_address))
@ -760,5 +852,6 @@ def main():
finally:
shutdown()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -5,6 +5,7 @@ from queue import Queue
import time, logging
def short_repr(obj, logger=logging.getLogger(), maxlen=80):
"""Like repr, but shortens strings and bytestrings if logger's logging level
is above DEBUG. Currently shallow and very limited, only implemented for
@ -28,6 +29,7 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80):
return repr(o)
class ThreadingCollectionWrapper:
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
interact with the collection."""
@ -56,10 +58,11 @@ class ThreadingCollectionWrapper:
def current(self):
from threading import current_thread
return current_thread() == self._thread
def execute(self, func, args=[], kw={}, waitForReturn=True):
""" Executes a given function on this thread with the *args and **kw.
"""Executes a given function on this thread with the *args and **kw.
If 'waitForReturn' is True, then it will block until the function has
executed and return its return value. If False, it will return None
@ -86,19 +89,30 @@ class ThreadingCollectionWrapper:
while self._running:
func, args, kw, return_queue = self._queue.get(True)
if hasattr(func, '__name__'):
if hasattr(func, "__name__"):
func_name = func.__name__
else:
func_name = func.__class__.__name__
self.logger.info("Running %s(*%s, **%s)", func_name, short_repr(args, self.logger), short_repr(kw, self.logger))
self.logger.info(
"Running %s(*%s, **%s)",
func_name,
short_repr(args, self.logger),
short_repr(kw, self.logger),
)
self.last_timestamp = time.time()
try:
ret = self.wrapper.execute(func, args, kw, return_queue)
except Exception as e:
self.logger.error("Unable to %s(*%s, **%s): %s",
func_name, repr(args), repr(kw), e, exc_info=True)
self.logger.error(
"Unable to %s(*%s, **%s): %s",
func_name,
repr(args),
repr(kw),
e,
exc_info=True,
)
# we return the Exception which will be raise'd on the other end
ret = e
@ -125,10 +139,11 @@ class ThreadingCollectionWrapper:
def stop(self):
def _stop(col):
self._running = False
self.execute(_stop, waitForReturn=False)
def stop_and_wait(self):
""" Tell the thread to stop and wait for it to happen. """
"""Tell the thread to stop and wait for it to happen."""
self.stop()
if self._thread is not None:
self._thread.join()
@ -146,11 +161,13 @@ class ThreadingCollectionWrapper:
def _close(col):
self.wrapper.close()
self.execute(_close, waitForReturn=False)
def opened(self):
return self.wrapper.opened()
class ThreadingCollectionManager(CollectionManager):
"""Manages a set of ThreadingCollectionWrapper objects."""
@ -174,14 +191,21 @@ class ThreadingCollectionManager(CollectionManager):
# TODO: it would be awesome to have a safe way to stop inactive threads completely!
# TODO: we need a way to inform other code that the collection has been closed
def _monitor_run(self):
""" Monitors threads for inactivity and closes the collection on them
"""Monitors threads for inactivity and closes the collection on them
(leaves the thread itself running -- hopefully waiting peacefully with only a
small memory footprint!) """
small memory footprint!)"""
while True:
cur = time.time()
for path, thread in self.collections.items():
if thread.running and thread.wrapper.opened() and thread.qempty() and cur - thread.last_timestamp >= self.monitor_inactivity:
self.logger.info("Monitor is closing collection on inactive %s", thread)
if (
thread.running
and thread.wrapper.opened()
and thread.qempty()
and cur - thread.last_timestamp >= self.monitor_inactivity
):
self.logger.info(
"Monitor is closing collection on inactive %s", thread
)
thread.close()
time.sleep(self.monitor_frequency)
@ -196,12 +220,14 @@ class ThreadingCollectionManager(CollectionManager):
# let the parent do whatever else it might want to do...
super(ThreadingCollectionManager, self).shutdown()
#
# For working with the global ThreadingCollectionManager:
#
collection_manager = None
def get_collection_manager(config):
"""Return the global ThreadingCollectionManager for this process."""
global collection_manager
@ -209,10 +235,10 @@ def get_collection_manager(config):
collection_manager = ThreadingCollectionManager(config)
return collection_manager
def shutdown():
"""If the global ThreadingCollectionManager exists, shut it down."""
global collection_manager
if collection_manager is not None:
collection_manager.shutdown()
collection_manager = None

View File

@ -11,7 +11,7 @@ logger = logging.getLogger("ankisyncd.users")
class SimpleUserManager:
"""A simple user manager that always allows any user."""
def __init__(self, collection_path=''):
def __init__(self, collection_path=""):
self.collection_path = collection_path
def authenticate(self, username, password):
@ -34,8 +34,11 @@ class SimpleUserManager:
def _create_user_dir(self, username):
user_dir_path = os.path.join(self.collection_path, username)
if not os.path.isdir(user_dir_path):
logger.info("Creating collection directory for user '{}' at {}"
.format(username, user_dir_path))
logger.info(
"Creating collection directory for user '{}' at {}".format(
username, user_dir_path
)
)
os.makedirs(user_dir_path)
@ -54,13 +57,17 @@ class SqliteUserManager(SimpleUserManager):
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'auth'")
cursor.execute(
"SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'auth'"
)
res = cursor.fetchone()
conn.close()
if res is not None:
raise Exception("Outdated database schema, run utils/migrate_user_tables.py")
raise Exception(
"Outdated database schema, run utils/migrate_user_tables.py"
)
# Default to using sqlite3 but overridable for sub-classes using other
# DB API 2 driver variants
@ -124,8 +131,7 @@ class SqliteUserManager(SimpleUserManager):
conn = self._conn()
cursor = conn.cursor()
logger.info("Adding user '{}' to auth db.".format(username))
cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"),
(username, pass_hash))
cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"), (username, pass_hash))
conn.commit()
conn.close()
@ -139,7 +145,9 @@ class SqliteUserManager(SimpleUserManager):
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username))
cursor.execute(
self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username)
)
conn.commit()
conn.close()
@ -156,8 +164,9 @@ class SqliteUserManager(SimpleUserManager):
conn.close()
if db_hash is None:
logger.info("Authentication failed for nonexistent user {}."
.format(username))
logger.info(
"Authentication failed for nonexistent user {}.".format(username)
)
return False
expected_value = str(db_hash[0])
@ -181,17 +190,22 @@ class SqliteUserManager(SimpleUserManager):
@staticmethod
def _create_pass_hash(username, password):
salt = binascii.b2a_hex(os.urandom(8))
pass_hash = (hashlib.sha256((username + password).encode() + salt).hexdigest() +
salt.decode())
pass_hash = (
hashlib.sha256((username + password).encode() + salt).hexdigest()
+ salt.decode()
)
return pass_hash
def create_auth_db(self):
conn = self._conn()
cursor = conn.cursor()
logger.info("Creating auth db at {}."
.format(self.auth_db_path))
cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth
(username VARCHAR PRIMARY KEY, hash VARCHAR)"""))
logger.info("Creating auth db at {}.".format(self.auth_db_path))
cursor.execute(
self.fs(
"""CREATE TABLE IF NOT EXISTS auth
(username VARCHAR PRIMARY KEY, hash VARCHAR)"""
)
)
conn.commit()
conn.close()
@ -199,19 +213,28 @@ class SqliteUserManager(SimpleUserManager):
def get_user_manager(config):
if "auth_db_path" in config and config["auth_db_path"]:
logger.info("Found auth_db_path in config, using SqliteUserManager for auth")
return SqliteUserManager(config['auth_db_path'], config['data_root'])
return SqliteUserManager(config["auth_db_path"], config["data_root"])
elif "user_manager" in config and config["user_manager"]: # load from config
logger.info("Found user_manager in config, using {} for auth".format(config['user_manager']))
logger.info(
"Found user_manager in config, using {} for auth".format(
config["user_manager"]
)
)
import importlib
import inspect
module_name, class_name = config['user_manager'].rsplit('.', 1)
module_name, class_name = config["user_manager"].rsplit(".", 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not SimpleUserManager in inspect.getmro(class_):
raise TypeError('''"user_manager" found in the conf file but it doesn''t
inherit from SimpleUserManager''')
raise TypeError(
""""user_manager" found in the conf file but it doesn''t
inherit from SimpleUserManager"""
)
return class_(config)
else:
logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password")
return SimpleUserManager()
logger.warning(
"neither auth_db_path nor user_manager set, ankisyncd will accept any password"
)
return SimpleUserManager()

View File

@ -1,4 +1,4 @@
from ankisyncd_cli import ankisyncctl
if __name__ == '__main__':
ankisyncctl.main()
if __name__ == "__main__":
ankisyncctl.main()

View File

@ -9,6 +9,7 @@ from ankisyncd.users import get_user_manager
config = config.load()
def usage():
print("usage: {} <command> [<args>]".format(sys.argv[0]))
print()
@ -18,12 +19,14 @@ def usage():
print(" lsuser - list users")
print(" passwd <username> - change password of a user")
def adduser(username):
password = getpass.getpass("Enter password for {}: ".format(username))
user_manager = get_user_manager(config)
user_manager.add_user(username, password)
def deluser(username):
user_manager = get_user_manager(config)
try:
@ -31,6 +34,7 @@ def deluser(username):
except ValueError as error:
print("Could not delete user {}: {}".format(username, error), file=sys.stderr)
def lsuser():
user_manager = get_user_manager(config)
try:
@ -40,6 +44,7 @@ def lsuser():
except ValueError as error:
print("Could not list users: {}".format(error), file=sys.stderr)
def passwd(username):
user_manager = get_user_manager(config)
@ -51,7 +56,11 @@ def passwd(username):
try:
user_manager.set_password_for_user(username, password)
except ValueError as error:
print("Could not set password for user {}: {}".format(username, error), file=sys.stderr)
print(
"Could not set password for user {}: {}".format(username, error),
file=sys.stderr,
)
def main():
argc = len(sys.argv)
@ -78,5 +87,6 @@ def main():
usage()
exit(1)
if __name__ == "__main__":
main()

View File

@ -7,11 +7,13 @@ word in many other SQL dialects.
"""
import os
import sys
path = os.path.realpath(os.path.abspath(os.path.join(__file__, '../')))
path = os.path.realpath(os.path.abspath(os.path.join(__file__, "../")))
sys.path.insert(0, os.path.dirname(path))
import sqlite3
import ankisyncd.config
conf = ankisyncd.config.load()
@ -21,15 +23,21 @@ def main():
conn = sqlite3.connect(conf["auth_db_path"])
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'auth'")
cursor.execute(
"SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'auth'"
)
res = cursor.fetchone()
if res is not None:
cursor.execute("ALTER TABLE auth RENAME TO auth_old")
cursor.execute("CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)")
cursor.execute("INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old")
cursor.execute(
"CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)"
)
cursor.execute(
"INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old"
)
cursor.execute("DROP TABLE auth_old")
conn.commit()
print("Successfully updated table 'auth'")
@ -44,17 +52,23 @@ def main():
conn = sqlite3.connect(conf["session_db_path"])
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR%' "
"AND tbl_name = 'session'")
cursor.execute(
"SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR%' "
"AND tbl_name = 'session'"
)
res = cursor.fetchone()
if res is not None:
cursor.execute("ALTER TABLE session RENAME TO session_old")
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, "
"username VARCHAR, path VARCHAR)")
cursor.execute("INSERT INTO session (hkey, skey, username, path) "
"SELECT hkey, skey, user, path FROM session_old")
cursor.execute(
"CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, "
"username VARCHAR, path VARCHAR)"
)
cursor.execute(
"INSERT INTO session (hkey, skey, username, path) "
"SELECT hkey, skey, user, path FROM session_old"
)
cursor.execute("DROP TABLE session_old")
conn.commit()
print("Successfully updated table 'session'")

View File

@ -8,5 +8,5 @@ setup(
author="Anki Community",
author_email="kothary.vikash+ankicommunity@gmail.com",
packages=find_packages(),
url='https://ankicommunity.github.io/'
url="https://ankicommunity.github.io/",
)

View File

@ -16,7 +16,7 @@ class CollectionTestBase(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
self.collection_path = os.path.join(self.temp_dir, "collection.anki2")
cm = CollectionManager({})
collectionWrapper = cm.get_collection(self.collection_path)
self.collection = collectionWrapper._get_collection()
@ -32,26 +32,26 @@ class CollectionTestBase(unittest.TestCase):
def add_note(self, data):
from anki.notes import Note
model = self.collection.models.byName(data['model'])
model = self.collection.models.byName(data["model"])
note = Note(self.collection, model)
for name, value in data['fields'].items():
for name, value in data["fields"].items():
note[name] = value
if 'tags' in data:
note.setTagsFromStr(data['tags'])
if "tags" in data:
note.setTagsFromStr(data["tags"])
self.collection.addNote(note)
# TODO: refactor into a parent class
def add_default_note(self, count=1):
data = {
'model': 'Basic',
'fields': {
'Front': 'The front',
'Back': 'The back',
"model": "Basic",
"fields": {
"Front": "The front",
"Back": "The back",
},
'tags': "Tag1 Tag2",
"tags": "Tag1 Tag2",
}
for idx in range(0, count):
self.add_note(data)

View File

@ -1 +1 @@
from . import db_utils
from . import db_utils

View File

@ -5,6 +5,7 @@ import tempfile
from anki.collection import Collection
class CollectionUtils:
"""
Provides utility methods for creating, inspecting and manipulating anki

View File

@ -34,7 +34,7 @@ def to_sql(database):
else:
connection = database
res = '\n'.join(connection.iterdump())
res = "\n".join(connection.iterdump())
if type(database) == str:
connection.close()
@ -54,18 +54,15 @@ def diff(left_db_path, right_db_path):
command = ["sqldiff", left_db_path, right_db_path]
child_process = subprocess.Popen(command,
shell=False,
stdout=subprocess.PIPE)
child_process = subprocess.Popen(command, shell=False, stdout=subprocess.PIPE)
stdout, stderr = child_process.communicate()
exit_code = child_process.returncode
if exit_code != 0 or stderr is not None:
raise RuntimeError("Command {} encountered an error, exit "
"code: {}, stderr: {}"
.format(" ".join(command),
exit_code,
stderr))
raise RuntimeError(
"Command {} encountered an error, exit "
"code: {}, stderr: {}".format(" ".join(command), exit_code, stderr)
)
# Any output from sqldiff means the databases differ.
# Any output from sqldiff means the databases differ.
return stdout != ""

View File

@ -21,9 +21,8 @@ class MockServerConnection:
r = self.test_app.post(url, params=data.read(), headers=headers, status="*")
return types.SimpleNamespace(status_code=r.status_int, body=r.body)
def streamContent(self, r):
return r.body
return r.body
class MockRemoteServer(RemoteServer):

View File

@ -12,9 +12,7 @@ mediamanager_orig_funcs = {
"_logChanges": None,
}
db_orig_funcs = {
"__init__": None
}
db_orig_funcs = {"__init__": None}
def monkeypatch_mediamanager():
@ -35,6 +33,7 @@ def monkeypatch_mediamanager():
os.chdir(old_cwd)
return res
return wrapper
MediaManager.findChanges = make_cwd_safe(MediaManager.findChanges)
@ -47,6 +46,7 @@ def unpatch_mediamanager():
mediamanager_orig_funcs["findChanges"] = None
def monkeypatch_db():
"""
Monkey patches Anki's DB.__init__ to connect to allow access to the db
@ -58,9 +58,7 @@ def monkeypatch_db():
def patched___init__(self, path, text=None, timeout=0):
# Code taken from Anki's DB.__init__()
# Allow more than one thread to use this connection.
self._db = sqlite.connect(path,
timeout=timeout,
check_same_thread=False)
self._db = sqlite.connect(path, timeout=timeout, check_same_thread=False)
if text:
self._db.text_factory = text
self._path = path

View File

@ -23,50 +23,58 @@ def create_server_paths():
"data_root": os.path.join(dir, "data"),
}
def create_sync_app(server_paths, config_path):
config = configparser.ConfigParser()
config.read(config_path)
# Use custom files and dirs in settings.
config['sync_app'].update(server_paths)
config["sync_app"].update(server_paths)
return SyncApp(config["sync_app"])
return SyncApp(config['sync_app'])
def get_session_for_hkey(server, hkey):
return server.session_manager.load(hkey)
def get_thread_for_hkey(server, hkey):
session = get_session_for_hkey(server, hkey)
thread = session.get_thread()
return thread
def get_col_wrapper_for_hkey(server, hkey):
thread = get_thread_for_hkey(server, hkey)
col_wrapper = thread.wrapper
return col_wrapper
def get_col_for_hkey(server, hkey):
col_wrapper = get_col_wrapper_for_hkey(server, hkey)
col_wrapper.open() # Make sure the col is opened.
return col_wrapper._CollectionWrapper__col
def get_col_db_path_for_hkey(server, hkey):
col = get_col_for_hkey(server, hkey)
return col.db._path
def get_syncer_for_hkey(server, hkey, syncer_type='collection'):
def get_syncer_for_hkey(server, hkey, syncer_type="collection"):
col = get_col_for_hkey(server, hkey)
session = get_session_for_hkey(server, hkey)
syncer_type = syncer_type.lower()
if syncer_type == 'collection':
if syncer_type == "collection":
handler_method = SyncCollectionHandler.operations[0]
elif syncer_type == 'media':
elif syncer_type == "media":
handler_method = SyncMediaHandler.operations[0]
return session.get_handler_for_operation(handler_method, col)
def add_files_to_client_mediadb(media, filepaths, update_db=False):
for filepath in filepaths:
logging.debug("Adding file '{}' to client media DB".format(filepath))
@ -76,16 +84,15 @@ def add_files_to_client_mediadb(media, filepaths, update_db=False):
if update_db:
media.findChanges() # Write changes to db.
def add_files_to_server_mediadb(media, filepaths):
for filepath in filepaths:
logging.debug("Adding file '{}' to server media DB".format(filepath))
fname = os.path.basename(filepath)
with open(filepath, 'rb') as infile:
with open(filepath, "rb") as infile:
data = infile.read()
csum = anki.utils.checksum(data)
with open(os.path.join(media.dir(), fname), 'wb') as f:
with open(os.path.join(media.dir(), fname), "wb") as f:
f.write(data)
media.addMedia(
((fname, media.lastUsn() + 1, csum),)
)
media.addMedia(((fname, media.lastUsn() + 1, csum),))

View File

@ -11,7 +11,6 @@ from helpers.monkey_patches import monkeypatch_db, unpatch_db
class SyncAppFunctionalTestBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.colutils = CollectionUtils()
@ -28,28 +27,29 @@ class SyncAppFunctionalTestBase(unittest.TestCase):
self.server_paths = helpers.server_utils.create_server_paths()
# Add a test user to the temp auth db the server will use.
self.user_manager = SqliteUserManager(self.server_paths['auth_db_path'],
self.server_paths['data_root'])
self.user_manager.add_user('testuser', 'testpassword')
self.user_manager = SqliteUserManager(
self.server_paths["auth_db_path"], self.server_paths["data_root"]
)
self.user_manager.add_user("testuser", "testpassword")
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
ini_file_path = os.path.join(script_dir, "assets", "test.conf")
# Create SyncApp instance using the dev ini file and the temporary
# paths.
self.server_app = helpers.server_utils.create_sync_app(self.server_paths,
ini_file_path)
self.server_app = helpers.server_utils.create_sync_app(
self.server_paths, ini_file_path
)
# Wrap the SyncApp object in TestApp instance for testing.
self.server_test_app = TestApp(self.server_app)
# MockRemoteServer instance needed for testing normal collection
# syncing and for retrieving hkey for other tests.
self.mock_remote_server = MockRemoteServer(hkey=None,
server_test_app=self.server_test_app)
self.mock_remote_server = MockRemoteServer(
hkey=None, server_test_app=self.server_test_app
)
def tearDown(self):
self.server_paths = {}

View File

@ -9,39 +9,52 @@ from ankisyncd.collection import get_collection_wrapper
import helpers.server_utils
class FakeCollectionWrapper(CollectionWrapper):
def __init__(self, config, path, setup_new_collection=None):
self. _CollectionWrapper__col = None
self._CollectionWrapper__col = None
pass
class BadCollectionWrapper:
pass
class CollectionWrapperFactoryTest(unittest.TestCase):
def test_get_collection_wrapper(self):
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
ini_file_path = os.path.join(script_dir, "assets", "test.conf")
# Create temporary files and dirs the server will use.
server_paths = helpers.server_utils.create_server_paths()
config = configparser.ConfigParser()
config.read(ini_file_path)
path = os.path.realpath('fake/collection.anki2')
path = os.path.realpath("fake/collection.anki2")
# Use custom files and dirs in settings. Should be CollectionWrapper
config['sync_app'].update(server_paths)
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper))
config["sync_app"].update(server_paths)
self.assertTrue(
type(get_collection_wrapper(config["sync_app"], path) == CollectionWrapper)
)
# A conf-specified CollectionWrapper is loaded
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper')
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper)
config.set(
"sync_app",
"collection_wrapper",
"test_collection_wrappers.FakeCollectionWrapper",
)
self.assertTrue(
type(get_collection_wrapper(config["sync_app"], path))
== FakeCollectionWrapper
)
# Should fail at load time if the class doesn't inherit from CollectionWrapper
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper')
config.set(
"sync_app",
"collection_wrapper",
"test_collection_wrappers.BadCollectionWrapper",
)
with self.assertRaises(TypeError):
pm = get_collection_wrapper(config['sync_app'], path)
pm = get_collection_wrapper(config["sync_app"], path)

View File

@ -8,20 +8,21 @@ from ankisyncd.full_sync import FullSyncManager, get_full_sync_manager
import helpers.server_utils
class FakeFullSyncManager(FullSyncManager):
def __init__(self, config):
pass
class BadFullSyncManager:
pass
class FullSyncManagerFactoryTest(unittest.TestCase):
def test_get_full_sync_manager(self):
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
ini_file_path = os.path.join(script_dir, "assets", "test.conf")
# Create temporary files and dirs the server will use.
server_paths = helpers.server_utils.create_server_paths()
@ -30,15 +31,20 @@ class FullSyncManagerFactoryTest(unittest.TestCase):
config.read(ini_file_path)
# Use custom files and dirs in settings. Should be PersistenceManager
config['sync_app'].update(server_paths)
self.assertTrue(type(get_full_sync_manager(config['sync_app']) == FullSyncManager))
config["sync_app"].update(server_paths)
self.assertTrue(
type(get_full_sync_manager(config["sync_app"]) == FullSyncManager)
)
# A conf-specified FullSyncManager is loaded
config.set("sync_app", "full_sync_manager", 'test_full_sync.FakeFullSyncManager')
self.assertTrue(type(get_full_sync_manager(config['sync_app'])) == FakeFullSyncManager)
config.set(
"sync_app", "full_sync_manager", "test_full_sync.FakeFullSyncManager"
)
self.assertTrue(
type(get_full_sync_manager(config["sync_app"])) == FakeFullSyncManager
)
# Should fail at load time if the class doesn't inherit from FullSyncManager
config.set("sync_app", "full_sync_manager", 'test_full_sync.BadFullSyncManager')
config.set("sync_app", "full_sync_manager", "test_full_sync.BadFullSyncManager")
with self.assertRaises(TypeError):
pm = get_full_sync_manager(config['sync_app'])
pm = get_full_sync_manager(config["sync_app"])

View File

@ -45,26 +45,22 @@ class ServerMediaManagerTest(unittest.TestCase):
list(cm.db.execute("SELECT fname, csum FROM media")),
)
self.assertEqual(cm.lastUsn(), sm.lastUsn())
self.assertEqual(
list(sm.db.execute("SELECT usn FROM media")),
[(161,), (161,)]
)
self.assertEqual(list(sm.db.execute("SELECT usn FROM media")), [(161,), (161,)])
def test_mediaChanges_lastUsn_order(self):
col = self.colutils.create_empty_col()
col.media = ankisyncd.media.ServerMediaManager(col)
session = MagicMock()
session.name = 'test'
session.name = "test"
mh = ankisyncd.sync_app.SyncMediaHandler(col, session)
mh.col.media.addMedia(
(
('fileA', 101, '53059abba1a72c7aff34a3eaf7fef10ed65541ce'),
('fileB', 100, 'a5ae546046d09559399c80fa7076fb10f1ce4bcd'),
("fileA", 101, "53059abba1a72c7aff34a3eaf7fef10ed65541ce"),
("fileB", 100, "a5ae546046d09559399c80fa7076fb10f1ce4bcd"),
)
)
# anki assumes mh.col.media.lastUsn() == mh.mediaChanges()['data'][-1][1]
# ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
self.assertEqual(
mh.mediaChanges(lastUsn=99)['data'][-1][1],
mh.col.media.lastUsn()
mh.mediaChanges(lastUsn=99)["data"][-1][1], mh.col.media.lastUsn()
)

View File

@ -14,20 +14,21 @@ from ankisyncd.sync_app import SyncUserSession
import helpers.server_utils
class FakeSessionManager(SimpleSessionManager):
def __init__(self, config):
pass
class BadSessionManager:
pass
class SessionManagerFactoryTest(unittest.TestCase):
def test_get_session_manager(self):
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
ini_file_path = os.path.join(script_dir, "assets", "test.conf")
# Create temporary files and dirs the server will use.
server_paths = helpers.server_utils.create_server_paths()
@ -36,32 +37,40 @@ class SessionManagerFactoryTest(unittest.TestCase):
config.read(ini_file_path)
# Use custom files and dirs in settings. Should be SqliteSessionManager
config['sync_app'].update(server_paths)
self.assertTrue(type(get_session_manager(config['sync_app']) == SqliteSessionManager))
config["sync_app"].update(server_paths)
self.assertTrue(
type(get_session_manager(config["sync_app"]) == SqliteSessionManager)
)
# No value defaults to SimpleSessionManager
config.remove_option("sync_app", "session_db_path")
self.assertTrue(type(get_session_manager(config['sync_app'])) == SimpleSessionManager)
self.assertTrue(
type(get_session_manager(config["sync_app"])) == SimpleSessionManager
)
# A conf-specified SessionManager is loaded
config.set("sync_app", "session_manager", 'test_sessions.FakeSessionManager')
self.assertTrue(type(get_session_manager(config['sync_app'])) == FakeSessionManager)
config.set("sync_app", "session_manager", "test_sessions.FakeSessionManager")
self.assertTrue(
type(get_session_manager(config["sync_app"])) == FakeSessionManager
)
# Should fail at load time if the class doesn't inherit from SimpleSessionManager
config.set("sync_app", "session_manager", 'test_sessions.BadSessionManager')
config.set("sync_app", "session_manager", "test_sessions.BadSessionManager")
with self.assertRaises(TypeError):
sm = get_session_manager(config['sync_app'])
sm = get_session_manager(config["sync_app"])
# Add the session_db_path back, it should take precedence over BadSessionManager
config['sync_app'].update(server_paths)
self.assertTrue(type(get_session_manager(config['sync_app'])) == SqliteSessionManager)
config["sync_app"].update(server_paths)
self.assertTrue(
type(get_session_manager(config["sync_app"])) == SqliteSessionManager
)
class SimpleSessionManagerTest(unittest.TestCase):
test_hkey = '1234567890'
test_hkey = "1234567890"
sdir = tempfile.mkdtemp(suffix="_session")
os.rmdir(sdir)
test_session = SyncUserSession('testName', sdir, None, None)
test_session = SyncUserSession("testName", sdir, None, None)
def setUp(self):
self.sessionManager = SimpleSessionManager()
@ -71,10 +80,12 @@ class SimpleSessionManagerTest(unittest.TestCase):
def test_save(self):
self.sessionManager.save(self.test_hkey, self.test_session)
self.assertEqual(self.sessionManager.sessions[self.test_hkey].name,
self.test_session.name)
self.assertEqual(self.sessionManager.sessions[self.test_hkey].path,
self.test_session.path)
self.assertEqual(
self.sessionManager.sessions[self.test_hkey].name, self.test_session.name
)
self.assertEqual(
self.sessionManager.sessions[self.test_hkey].path, self.test_session.path
)
def test_delete(self):
self.sessionManager.save(self.test_hkey, self.test_session)
@ -111,13 +122,11 @@ class SqliteSessionManagerTest(SimpleSessionManagerTest):
conn = sqlite3.connect(self._test_sess_db_path)
cursor = conn.cursor()
cursor.execute("SELECT username, path FROM session WHERE hkey=?",
(self.test_hkey,))
cursor.execute(
"SELECT username, path FROM session WHERE hkey=?", (self.test_hkey,)
)
res = cursor.fetchone()
conn.close()
self.assertEqual(res[0], self.test_session.name)
self.assertEqual(res[1], self.test_session.path)

View File

@ -16,10 +16,9 @@ class SyncCollectionHandlerTest(CollectionTestBase):
def setUp(self):
super().setUp()
self.session = MagicMock()
self.session.name = 'test'
self.session.name = "test"
self.syncCollectionHandler = SyncCollectionHandler(
self.collection,
self.session
self.collection, self.session
)
def tearDown(self):
@ -28,48 +27,48 @@ class SyncCollectionHandlerTest(CollectionTestBase):
def test_old_client(self):
old = (
','.join(('ankidesktop', '2.0.12', 'lin::')),
','.join(('ankidesktop', '2.0.26', 'lin::')),
','.join(('ankidroid', '2.1', '')),
','.join(('ankidroid', '2.2', '')),
','.join(('ankidroid', '2.2.2', '')),
','.join(('ankidroid', '2.3alpha3', '')),
",".join(("ankidesktop", "2.0.12", "lin::")),
",".join(("ankidesktop", "2.0.26", "lin::")),
",".join(("ankidroid", "2.1", "")),
",".join(("ankidroid", "2.2", "")),
",".join(("ankidroid", "2.2.2", "")),
",".join(("ankidroid", "2.3alpha3", "")),
)
current = (
None,
','.join(('ankidesktop', '2.0.27', 'lin::')),
','.join(('ankidesktop', '2.0.32', 'lin::')),
','.join(('ankidesktop', '2.1.0', 'lin::')),
','.join(('ankidesktop', '2.1.6-beta2', 'lin::')),
','.join(('ankidesktop', '2.1.9 (dev)', 'lin::')),
','.join(('ankidesktop', '2.1.26 (arch-linux-2.1.26-1)', 'lin:arch:')),
','.join(('ankidroid', '2.2.3', '')),
','.join(('ankidroid', '2.3alpha4', '')),
','.join(('ankidroid', '2.3alpha5', '')),
','.join(('ankidroid', '2.3beta1', '')),
','.join(('ankidroid', '2.3', '')),
','.join(('ankidroid', '2.9', '')),
",".join(("ankidesktop", "2.0.27", "lin::")),
",".join(("ankidesktop", "2.0.32", "lin::")),
",".join(("ankidesktop", "2.1.0", "lin::")),
",".join(("ankidesktop", "2.1.6-beta2", "lin::")),
",".join(("ankidesktop", "2.1.9 (dev)", "lin::")),
",".join(("ankidesktop", "2.1.26 (arch-linux-2.1.26-1)", "lin:arch:")),
",".join(("ankidroid", "2.2.3", "")),
",".join(("ankidroid", "2.3alpha4", "")),
",".join(("ankidroid", "2.3alpha5", "")),
",".join(("ankidroid", "2.3beta1", "")),
",".join(("ankidroid", "2.3", "")),
",".join(("ankidroid", "2.9", "")),
)
for cv in old:
if not SyncCollectionHandler._old_client(cv):
raise AssertionError("old_client(\"%s\") is False" % cv)
raise AssertionError('old_client("%s") is False' % cv)
for cv in current:
if SyncCollectionHandler._old_client(cv):
raise AssertionError("old_client(\"%s\") is True" % cv)
raise AssertionError('old_client("%s") is True' % cv)
def test_meta(self):
meta = self.syncCollectionHandler.meta(v=SYNC_VER)
self.assertEqual(meta['scm'], self.collection.scm)
self.assertTrue((type(meta['ts']) == int) and meta['ts'] > 0)
self.assertEqual(meta['mod'], self.collection.mod)
self.assertEqual(meta['usn'], self.collection._usn)
self.assertEqual(meta['uname'], self.session.name)
self.assertEqual(meta['musn'], self.collection.media.lastUsn())
self.assertEqual(meta['msg'], '')
self.assertEqual(meta['cont'], True)
self.assertEqual(meta["scm"], self.collection.scm)
self.assertTrue((type(meta["ts"]) == int) and meta["ts"] > 0)
self.assertEqual(meta["mod"], self.collection.mod)
self.assertEqual(meta["usn"], self.collection._usn)
self.assertEqual(meta["uname"], self.session.name)
self.assertEqual(meta["musn"], self.collection.media.lastUsn())
self.assertEqual(meta["msg"], "")
self.assertEqual(meta["cont"], True)
class SyncAppTest(unittest.TestCase):

View File

@ -10,20 +10,21 @@ from ankisyncd.users import get_user_manager
import helpers.server_utils
class FakeUserManager(SimpleUserManager):
def __init__(self, config):
pass
class BadUserManager:
pass
class UserManagerFactoryTest(unittest.TestCase):
def test_get_user_manager(self):
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
ini_file_path = os.path.join(script_dir, "assets", "test.conf")
# Create temporary files and dirs the server will use.
server_paths = helpers.server_utils.create_server_paths()
@ -32,25 +33,25 @@ class UserManagerFactoryTest(unittest.TestCase):
config.read(ini_file_path)
# Use custom files and dirs in settings. Should be SqliteUserManager
config['sync_app'].update(server_paths)
self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager))
config["sync_app"].update(server_paths)
self.assertTrue(type(get_user_manager(config["sync_app"]) == SqliteUserManager))
# No value defaults to SimpleUserManager
config.remove_option("sync_app", "auth_db_path")
self.assertTrue(type(get_user_manager(config['sync_app'])) == SimpleUserManager)
self.assertTrue(type(get_user_manager(config["sync_app"])) == SimpleUserManager)
# A conf-specified UserManager is loaded
config.set("sync_app", "user_manager", 'test_users.FakeUserManager')
self.assertTrue(type(get_user_manager(config['sync_app'])) == FakeUserManager)
config.set("sync_app", "user_manager", "test_users.FakeUserManager")
self.assertTrue(type(get_user_manager(config["sync_app"])) == FakeUserManager)
# Should fail at load time if the class doesn't inherit from SimpleUserManager
config.set("sync_app", "user_manager", 'test_users.BadUserManager')
config.set("sync_app", "user_manager", "test_users.BadUserManager")
with self.assertRaises(TypeError):
um = get_user_manager(config['sync_app'])
um = get_user_manager(config["sync_app"])
# Add the auth_db_path back, it should take precedence over BadUserManager
config['sync_app'].update(server_paths)
self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager))
config["sync_app"].update(server_paths)
self.assertTrue(type(get_user_manager(config["sync_app"]) == SqliteUserManager))
class SimpleUserManagerTest(unittest.TestCase):
@ -61,22 +62,18 @@ class SimpleUserManagerTest(unittest.TestCase):
self._user_manager = None
def test_authenticate(self):
good_test_un = 'username'
good_test_pw = 'password'
bad_test_un = 'notAUsername'
bad_test_pw = 'notAPassword'
good_test_un = "username"
good_test_pw = "password"
bad_test_un = "notAUsername"
bad_test_pw = "notAPassword"
self.assertTrue(self.user_manager.authenticate(good_test_un,
good_test_pw))
self.assertTrue(self.user_manager.authenticate(bad_test_un,
bad_test_pw))
self.assertTrue(self.user_manager.authenticate(good_test_un,
bad_test_pw))
self.assertTrue(self.user_manager.authenticate(bad_test_un,
good_test_pw))
self.assertTrue(self.user_manager.authenticate(good_test_un, good_test_pw))
self.assertTrue(self.user_manager.authenticate(bad_test_un, bad_test_pw))
self.assertTrue(self.user_manager.authenticate(good_test_un, bad_test_pw))
self.assertTrue(self.user_manager.authenticate(bad_test_un, good_test_pw))
def test_userdir(self):
username = 'my_username'
username = "my_username"
dirname = self.user_manager.userdir(username)
self.assertEqual(dirname, username)
@ -87,8 +84,7 @@ class SqliteUserManagerTest(unittest.TestCase):
self.basedir = basedir
self.auth_db_path = os.path.join(basedir, "auth.db")
self.collection_path = os.path.join(basedir, "collections")
self.user_manager = SqliteUserManager(self.auth_db_path,
self.collection_path)
self.user_manager = SqliteUserManager(self.auth_db_path, self.collection_path)
def tearDown(self):
shutil.rmtree(self.basedir)
@ -151,18 +147,22 @@ class SqliteUserManagerTest(unittest.TestCase):
self.assertTrue(os.path.isdir(expected_dir_path))
def test_add_users(self):
users_data = [("my_first_username", "my_first_password"),
("my_second_username", "my_second_password")]
users_data = [
("my_first_username", "my_first_password"),
("my_second_username", "my_second_password"),
]
self.user_manager.create_auth_db()
self.user_manager.add_users(users_data)
user_list = self.user_manager.user_list()
self.assertIn("my_first_username", user_list)
self.assertIn("my_second_username", user_list)
self.assertTrue(os.path.isdir(os.path.join(self.collection_path,
"my_first_username")))
self.assertTrue(os.path.isdir(os.path.join(self.collection_path,
"my_second_username")))
self.assertTrue(
os.path.isdir(os.path.join(self.collection_path, "my_first_username"))
)
self.assertTrue(
os.path.isdir(os.path.join(self.collection_path, "my_second_username"))
)
def test__add_user_to_auth_db(self):
username = "my_username"
@ -191,8 +191,7 @@ class SqliteUserManagerTest(unittest.TestCase):
self.user_manager.create_auth_db()
self.user_manager.add_user(username, password)
self.assertTrue(self.user_manager.authenticate(username,
password))
self.assertTrue(self.user_manager.authenticate(username, password))
def test_set_password_for_user(self):
username = "my_username"
@ -203,8 +202,5 @@ class SqliteUserManagerTest(unittest.TestCase):
self.user_manager.add_user(username, password)
self.user_manager.set_password_for_user(username, new_password)
self.assertFalse(self.user_manager.authenticate(username,
password))
self.assertTrue(self.user_manager.authenticate(username,
new_password))
self.assertFalse(self.user_manager.authenticate(username, password))
self.assertTrue(self.user_manager.authenticate(username, new_password))

View File

@ -17,4 +17,3 @@ class SyncAppFunctionalHostKeyTest(SyncAppFunctionalTestBase):
self.assertIsNone(self.server.hostKey("testuser", "wrongpassword"))
self.assertIsNone(self.server.hostKey("wronguser", "wrongpassword"))
self.assertIsNone(self.server.hostKey("wronguser", "testpassword"))