Use factory method for user_manager
Also add some abstraction over the SQL to allow for different SQL dialects
This commit is contained in:
parent
bfeaeae2e5
commit
ea0cbc669b
@ -4,11 +4,10 @@ import sys
|
||||
import getpass
|
||||
|
||||
import ankisyncd.config
|
||||
from ankisyncd.users import SqliteUserManager
|
||||
from ankisyncd.users import get_user_manager
|
||||
|
||||
|
||||
config = ankisyncd.config.load()
|
||||
AUTHDBPATH = config['auth_db_path']
|
||||
COLLECTIONPATH = config['data_root']
|
||||
|
||||
def usage():
|
||||
print("usage: {} <command> [<args>]".format(sys.argv[0]))
|
||||
@ -22,18 +21,18 @@ def usage():
|
||||
def adduser(username):
|
||||
password = getpass.getpass("Enter password for {}: ".format(username))
|
||||
|
||||
user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
|
||||
user_manager = get_user_manager(config)
|
||||
user_manager.add_user(username, password)
|
||||
|
||||
def deluser(username):
|
||||
user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
|
||||
user_manager = get_user_manager(config)
|
||||
try:
|
||||
user_manager.del_user(username)
|
||||
except ValueError as error:
|
||||
print("Could not delete user {}: {}".format(username, error), file=sys.stderr)
|
||||
|
||||
def lsuser():
|
||||
user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
|
||||
user_manager = get_user_manager(config)
|
||||
try:
|
||||
users = user_manager.user_list()
|
||||
for username in users:
|
||||
@ -42,7 +41,7 @@ def lsuser():
|
||||
print("Could not list users: {}".format(error), file=sys.stderr)
|
||||
|
||||
def passwd(username):
|
||||
user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
|
||||
user_manager = get_user_manager(config)
|
||||
|
||||
if username not in user_manager.user_list():
|
||||
print("User {} doesn't exist".format(username))
|
||||
|
||||
@ -40,7 +40,7 @@ import anki.utils
|
||||
from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
|
||||
from anki.consts import REM_CARD, REM_NOTE
|
||||
|
||||
from ankisyncd.users import SimpleUserManager, SqliteUserManager
|
||||
from ankisyncd.users import get_user_manager
|
||||
|
||||
logger = logging.getLogger("ankisyncd")
|
||||
|
||||
@ -421,12 +421,7 @@ class SyncApp:
|
||||
else:
|
||||
self.session_manager = SimpleSessionManager()
|
||||
|
||||
if "auth_db_path" in config:
|
||||
self.user_manager = SqliteUserManager(config['auth_db_path'])
|
||||
else:
|
||||
logger.warn("auth_db_path not set, ankisyncd will accept any password")
|
||||
self.user_manager = SimpleUserManager()
|
||||
|
||||
self.user_manager = get_user_manager(config)
|
||||
self.collection_manager = getCollectionManager()
|
||||
|
||||
# make sure the base_url has a trailing slash
|
||||
|
||||
@ -46,16 +46,29 @@ class SqliteUserManager(SimpleUserManager):
|
||||
SimpleUserManager.__init__(self, collection_path)
|
||||
self.auth_db_path = os.path.realpath(auth_db_path)
|
||||
|
||||
# Default to using sqlite3 but overridable for sub-classes using other
|
||||
# DB API 2 driver variants
|
||||
def auth_db_exists(self):
|
||||
return os.path.isfile(self.auth_db_path)
|
||||
|
||||
# Default to using sqlite3 but overridable for sub-classes using other
|
||||
# DB API 2 driver variants
|
||||
def _conn(self):
|
||||
return sqlite.connect(self.auth_db_path)
|
||||
|
||||
# Default to using sqlite3 syntax but overridable for sub-classes using other
|
||||
# DB API 2 driver variants
|
||||
@staticmethod
|
||||
def fs(sql):
|
||||
return sql
|
||||
|
||||
def user_list(self):
|
||||
if not self.auth_db_exists():
|
||||
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
|
||||
|
||||
conn = sqlite.connect(self.auth_db_path)
|
||||
conn = self._conn()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT user FROM auth")
|
||||
cursor.execute(self.fs("SELECT username FROM auth"))
|
||||
rows = cursor.fetchall()
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -67,13 +80,14 @@ class SqliteUserManager(SimpleUserManager):
|
||||
return username in users
|
||||
|
||||
def del_user(self, username):
|
||||
# Warning, this doesn't remove the user directory or clean it
|
||||
if not self.auth_db_exists():
|
||||
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
|
||||
|
||||
conn = sqlite.connect(self.auth_db_path)
|
||||
conn = self._conn()
|
||||
cursor = conn.cursor()
|
||||
logger.info("Removing user '{}' from auth db".format(username))
|
||||
cursor.execute("DELETE FROM auth WHERE user=?", (username,))
|
||||
cursor.execute(self.fs("DELETE FROM auth WHERE username=?"), (username,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@ -91,10 +105,10 @@ class SqliteUserManager(SimpleUserManager):
|
||||
|
||||
pass_hash = self._create_pass_hash(username, password)
|
||||
|
||||
conn = sqlite.connect(self.auth_db_path)
|
||||
conn = self._conn()
|
||||
cursor = conn.cursor()
|
||||
logger.info("Adding user '{}' to auth db.".format(username))
|
||||
cursor.execute("INSERT INTO auth VALUES (?, ?)",
|
||||
cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"),
|
||||
(username, pass_hash))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -107,9 +121,9 @@ class SqliteUserManager(SimpleUserManager):
|
||||
|
||||
hash = self._create_pass_hash(username, new_password)
|
||||
|
||||
conn = sqlite.connect(self.auth_db_path)
|
||||
conn = self._conn()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("UPDATE auth SET hash=? WHERE user=?", (hash, username))
|
||||
cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@ -118,10 +132,10 @@ class SqliteUserManager(SimpleUserManager):
|
||||
def authenticate(self, username, password):
|
||||
"""Returns True if this username is allowed to connect with this password. False otherwise."""
|
||||
|
||||
conn = sqlite.connect(self.auth_db_path)
|
||||
conn = self._conn()
|
||||
cursor = conn.cursor()
|
||||
param = (username,)
|
||||
cursor.execute("SELECT hash FROM auth WHERE user=?", param)
|
||||
cursor.execute(self.fs("SELECT hash FROM auth WHERE username=?"), param)
|
||||
db_hash = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@ -156,11 +170,32 @@ class SqliteUserManager(SimpleUserManager):
|
||||
return pass_hash
|
||||
|
||||
def create_auth_db(self):
|
||||
conn = sqlite.connect(self.auth_db_path)
|
||||
conn = self._conn()
|
||||
cursor = conn.cursor()
|
||||
logger.info("Creating auth db at {}."
|
||||
.format(self.auth_db_path))
|
||||
cursor.execute("""CREATE TABLE IF NOT EXISTS auth
|
||||
(user VARCHAR PRIMARY KEY, hash VARCHAR)""")
|
||||
cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth
|
||||
(username VARCHAR PRIMARY KEY, hash VARCHAR)"""))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def get_user_manager(config):
|
||||
if "auth_db_path" in config and config["auth_db_path"]:
|
||||
logger.info("Found auth_db_path in config, using SqliteUserManager for auth")
|
||||
return SqliteUserManager(config['auth_db_path'], config['data_root'])
|
||||
elif "user_manager" in config and config["user_manager"]: # load from config
|
||||
logger.info("Found user_manager in config, using {} for auth".format(config['user_manager']))
|
||||
import importlib
|
||||
import inspect
|
||||
module_name, class_name = config['user_manager'].rsplit('.', 1)
|
||||
module = importlib.import_module(module_name.strip())
|
||||
class_ = getattr(module, class_name.strip())
|
||||
|
||||
if not SimpleUserManager in inspect.getmro(class_):
|
||||
raise TypeError('''"user_manager" found in the conf file but it doesn''t
|
||||
inherit from SimpleUserManager''')
|
||||
return class_(config)
|
||||
else:
|
||||
logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password")
|
||||
return SimpleUserManager()
|
||||
|
||||
@ -3,8 +3,54 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import configparser
|
||||
|
||||
from ankisyncd.users import SimpleUserManager, SqliteUserManager
|
||||
from ankisyncd.users import get_user_manager
|
||||
|
||||
import helpers.server_utils
|
||||
|
||||
class FakeUserManager(SimpleUserManager):
|
||||
def __init__(self, config):
|
||||
pass
|
||||
|
||||
class BadUserManager:
|
||||
pass
|
||||
|
||||
class UserManagerFactoryTest(unittest.TestCase):
|
||||
def test_get_user_manager(self):
|
||||
# Get absolute path to development ini file.
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
ini_file_path = os.path.join(script_dir,
|
||||
"assets",
|
||||
"test.conf")
|
||||
|
||||
# Create temporary files and dirs the server will use.
|
||||
server_paths = helpers.server_utils.create_server_paths()
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(ini_file_path)
|
||||
|
||||
# Use custom files and dirs in settings. Should be SqliteUserManager
|
||||
config['sync_app'].update(server_paths)
|
||||
self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager))
|
||||
|
||||
# No value defaults to SimpleUserManager
|
||||
config.remove_option("sync_app", "auth_db_path")
|
||||
self.assertTrue(type(get_user_manager(config['sync_app'])) == SimpleUserManager)
|
||||
|
||||
# A conf-specified UserManager is loaded
|
||||
config.set("sync_app", "user_manager", 'test_users.FakeUserManager')
|
||||
self.assertTrue(type(get_user_manager(config['sync_app'])) == FakeUserManager)
|
||||
|
||||
# Should fail at load time if the class doesn't inherit from SimpleUserManager
|
||||
config.set("sync_app", "user_manager", 'test_users.BadUserManager')
|
||||
with self.assertRaises(TypeError):
|
||||
um = get_user_manager(config['sync_app'])
|
||||
|
||||
# Add the auth_db_path back, it should take precedence over BadUserManager
|
||||
config['sync_app'].update(server_paths)
|
||||
self.assertTrue(type(get_user_manager(config['sync_app']) == SqliteUserManager))
|
||||
|
||||
|
||||
class SimpleUserManagerTest(unittest.TestCase):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user