diff --git a/ankisyncctl.py b/ankisyncctl.py index bb98918..94e504d 100755 --- a/ankisyncctl.py +++ b/ankisyncctl.py @@ -4,11 +4,10 @@ import sys import getpass import ankisyncd.config -from ankisyncd.users import SqliteUserManager +from ankisyncd.users import get_user_manager + config = ankisyncd.config.load() -AUTHDBPATH = config['auth_db_path'] -COLLECTIONPATH = config['data_root'] def usage(): print("usage: {} []".format(sys.argv[0])) @@ -22,18 +21,18 @@ def usage(): def adduser(username): password = getpass.getpass("Enter password for {}: ".format(username)) - user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH) + user_manager = get_user_manager(config) user_manager.add_user(username, password) def deluser(username): - user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH) + user_manager = get_user_manager(config) try: user_manager.del_user(username) except ValueError as error: print("Could not delete user {}: {}".format(username, error), file=sys.stderr) def lsuser(): - user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH) + user_manager = get_user_manager(config) try: users = user_manager.user_list() for username in users: @@ -42,7 +41,7 @@ def lsuser(): print("Could not list users: {}".format(error), file=sys.stderr) def passwd(username): - user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH) + user_manager = get_user_manager(config) if username not in user_manager.user_list(): print("User {} doesn't exist".format(username)) diff --git a/ankisyncd/sync_app.py b/ankisyncd/sync_app.py index f75a7a3..02f4d75 100644 --- a/ankisyncd/sync_app.py +++ b/ankisyncd/sync_app.py @@ -40,7 +40,7 @@ import anki.utils from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT from anki.consts import REM_CARD, REM_NOTE -from ankisyncd.users import SimpleUserManager, SqliteUserManager +from ankisyncd.users import get_user_manager logger = logging.getLogger("ankisyncd") @@ -421,12 +421,7 @@ class SyncApp: else: self.session_manager = SimpleSessionManager() - if "auth_db_path" in config: - self.user_manager = SqliteUserManager(config['auth_db_path']) - else: - logger.warn("auth_db_path not set, ankisyncd will accept any password") - self.user_manager = SimpleUserManager() - + self.user_manager = get_user_manager(config) self.collection_manager = getCollectionManager() # make sure the base_url has a trailing slash diff --git a/ankisyncd/users.py b/ankisyncd/users.py index 5418725..da2af36 100644 --- a/ankisyncd/users.py +++ b/ankisyncd/users.py @@ -46,16 +46,29 @@ class SqliteUserManager(SimpleUserManager): SimpleUserManager.__init__(self, collection_path) self.auth_db_path = os.path.realpath(auth_db_path) + # Default to using sqlite3 but overridable for sub-classes using other + # DB API 2 driver variants def auth_db_exists(self): return os.path.isfile(self.auth_db_path) + # Default to using sqlite3 but overridable for sub-classes using other + # DB API 2 driver variants + def _conn(self): + return sqlite.connect(self.auth_db_path) + + # Default to using sqlite3 syntax but overridable for sub-classes using other + # DB API 2 driver variants + @staticmethod + def fs(sql): + return sql + def user_list(self): if not self.auth_db_exists(): raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path)) - conn = sqlite.connect(self.auth_db_path) + conn = self._conn() cursor = conn.cursor() - cursor.execute("SELECT user FROM auth") + cursor.execute(self.fs("SELECT username FROM auth")) rows = cursor.fetchall() conn.commit() conn.close() @@ -67,13 +80,14 @@ class SqliteUserManager(SimpleUserManager): return username in users def del_user(self, username): + # Warning, this doesn't remove the user directory or clean it if not self.auth_db_exists(): raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path)) - conn = sqlite.connect(self.auth_db_path) + conn = self._conn() cursor = conn.cursor() logger.info("Removing user '{}' from auth db".format(username)) - cursor.execute("DELETE FROM auth WHERE user=?", (username,)) + cursor.execute(self.fs("DELETE FROM auth WHERE username=?"), (username,)) conn.commit() conn.close() @@ -91,10 +105,10 @@ class SqliteUserManager(SimpleUserManager): pass_hash = self._create_pass_hash(username, password) - conn = sqlite.connect(self.auth_db_path) + conn = self._conn() cursor = conn.cursor() logger.info("Adding user '{}' to auth db.".format(username)) - cursor.execute("INSERT INTO auth VALUES (?, ?)", + cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"), (username, pass_hash)) conn.commit() conn.close() @@ -107,9 +121,9 @@ class SqliteUserManager(SimpleUserManager): hash = self._create_pass_hash(username, new_password) - conn = sqlite.connect(self.auth_db_path) + conn = self._conn() cursor = conn.cursor() - cursor.execute("UPDATE auth SET hash=? WHERE user=?", (hash, username)) + cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username)) conn.commit() conn.close() @@ -118,10 +132,10 @@ class SqliteUserManager(SimpleUserManager): def authenticate(self, username, password): """Returns True if this username is allowed to connect with this password. False otherwise.""" - conn = sqlite.connect(self.auth_db_path) + conn = self._conn() cursor = conn.cursor() param = (username,) - cursor.execute("SELECT hash FROM auth WHERE user=?", param) + cursor.execute(self.fs("SELECT hash FROM auth WHERE username=?"), param) db_hash = cursor.fetchone() conn.close() @@ -156,11 +170,32 @@ class SqliteUserManager(SimpleUserManager): return pass_hash def create_auth_db(self): - conn = sqlite.connect(self.auth_db_path) + conn = self._conn() cursor = conn.cursor() logger.info("Creating auth db at {}." .format(self.auth_db_path)) - cursor.execute("""CREATE TABLE IF NOT EXISTS auth - (user VARCHAR PRIMARY KEY, hash VARCHAR)""") + cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth + (username VARCHAR PRIMARY KEY, hash VARCHAR)""")) conn.commit() conn.close() + + +def get_user_manager(config): + if "auth_db_path" in config and config["auth_db_path"]: + logger.info("Found auth_db_path in config, using SqliteUserManager for auth") + return SqliteUserManager(config['auth_db_path'], config['data_root']) + elif "user_manager" in config and config["user_manager"]: # load from config + logger.info("Found user_manager in config, using {} for auth".format(config['user_manager'])) + import importlib + import inspect + module_name, class_name = config['user_manager'].rsplit('.', 1) + module = importlib.import_module(module_name.strip()) + class_ = getattr(module, class_name.strip()) + + if not SimpleUserManager in inspect.getmro(class_): + raise TypeError('''"user_manager" found in the conf file but it doesn''t + inherit from SimpleUserManager''') + return class_(config) + else: + logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password") + return SimpleUserManager() diff --git a/tests/test_users.py b/tests/test_users.py index a90b75e..13cfb31 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -3,8 +3,54 @@ import os import shutil import tempfile import unittest +import configparser from ankisyncd.users import SimpleUserManager, SqliteUserManager +from ankisyncd.users import get_user_manager + +import helpers.server_utils + +class FakeUserManager(SimpleUserManager): + def __init__(self, config): + pass + +class BadUserManager: + pass + +class UserManagerFactoryTest(unittest.TestCase): + def test_get_user_manager(self): + # Get absolute path to development ini file. + script_dir = os.path.dirname(os.path.realpath(__file__)) + ini_file_path = os.path.join(script_dir, + "assets", + "test.conf") + + # Create temporary files and dirs the server will use. + server_paths = helpers.server_utils.create_server_paths() + + config = configparser.ConfigParser() + config.read(ini_file_path) + + # Use custom files and dirs in settings. Should be SqliteUserManager + config['sync_app'].update(server_paths) + self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager)) + + # No value defaults to SimpleUserManager + config.remove_option("sync_app", "auth_db_path") + self.assertTrue(type(get_user_manager(config['sync_app'])) == SimpleUserManager) + + # A conf-specified UserManager is loaded + config.set("sync_app", "user_manager", 'test_users.FakeUserManager') + self.assertTrue(type(get_user_manager(config['sync_app'])) == FakeUserManager) + + # Should fail at load time if the class doesn't inherit from SimpleUserManager + config.set("sync_app", "user_manager", 'test_users.BadUserManager') + with self.assertRaises(TypeError): + um = get_user_manager(config['sync_app']) + + # Add the auth_db_path back, it should take precedence over BadUserManager + config['sync_app'].update(server_paths) + self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager)) class SimpleUserManagerTest(unittest.TestCase):