From 50cc6a12d9c4b171f7b68ef3e4fed190b1fbfdb2 Mon Sep 17 00:00:00 2001 From: Anton Melser Date: Mon, 28 Jan 2019 21:23:07 +0800 Subject: [PATCH] Use factory method for session_manager Also add some abstraction over the SQL to allow for different SQL dialects --- ankisyncd/sessions.py | 126 +++++++++++++++++++++++++++++++++++++++++ ankisyncd/sync_app.py | 94 +----------------------------- tests/test_sessions.py | 123 ++++++++++++++++++++++++++++++++++++++++ tests/test_sync_app.py | 65 --------------------- 4 files changed, 251 insertions(+), 157 deletions(-) create mode 100644 ankisyncd/sessions.py create mode 100644 tests/test_sessions.py diff --git a/ankisyncd/sessions.py b/ankisyncd/sessions.py new file mode 100644 index 0000000..7a7b2d4 --- /dev/null +++ b/ankisyncd/sessions.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +import os +import logging +from sqlite3 import dbapi2 as sqlite + +logger = logging.getLogger("ankisyncd.sessions") + + +class SimpleSessionManager: + """A simple session manager that keeps the sessions in memory.""" + + def __init__(self): + self.sessions = {} + + def load(self, hkey, session_factory=None): + return self.sessions.get(hkey) + + def load_from_skey(self, skey, session_factory=None): + for i in self.sessions: + if self.sessions[i].skey == skey: + return self.sessions[i] + + def save(self, hkey, session): + self.sessions[hkey] = session + + def delete(self, hkey): + del self.sessions[hkey] + + +class SqliteSessionManager(SimpleSessionManager): + """Stores sessions in a SQLite database to prevent the user from being logged out + everytime the SyncApp is restarted.""" + + def __init__(self, session_db_path): + SimpleSessionManager.__init__(self) + + self.session_db_path = os.path.realpath(session_db_path) + + def _conn(self): + new = not os.path.exists(self.session_db_path) + conn = sqlite.connect(self.session_db_path) + if new: + cursor = conn.cursor() + cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)") + return conn + + # Default to using sqlite3 syntax but overridable for sub-classes using other + # DB API 2 driver variants + @staticmethod + def fs(sql): + return sql + + def load(self, hkey, session_factory=None): + session = SimpleSessionManager.load(self, hkey) + if session is not None: + return session + + conn = self._conn() + cursor = conn.cursor() + + cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,)) + res = cursor.fetchone() + + if res is not None: + session = self.sessions[hkey] = session_factory(res[1], res[2]) + session.skey = res[0] + return session + + def load_from_skey(self, skey, session_factory=None): + session = SimpleSessionManager.load_from_skey(self, skey) + if session is not None: + return session + + conn = self._conn() + cursor = conn.cursor() + + cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,)) + res = cursor.fetchone() + + if res is not None: + session = self.sessions[res[0]] = session_factory(res[1], res[2]) + session.skey = skey + return session + + def save(self, hkey, session): + SimpleSessionManager.save(self, hkey, session) + + conn = self._conn() + cursor = conn.cursor() + + cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)", + (hkey, session.skey, session.name, session.path)) + + conn.commit() + + def delete(self, hkey): + SimpleSessionManager.delete(self, hkey) + + conn = self._conn() + cursor = conn.cursor() + + cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,)) + conn.commit() + +def get_session_manager(config): + if "session_db_path" in config and config["session_db_path"]: + logger.info("Found session_db_path in config, using SqliteSessionManager for auth") + return SqliteSessionManager(config['session_db_path']) + elif "session_manager" in config and config["session_manager"]: # load from config + logger.info("Found session_manager in config, using {} for persisting sessions".format( + config['session_manager']) + ) + import importlib + import inspect + module_name, class_name = config['session_manager'].rsplit('.', 1) + module = importlib.import_module(module_name.strip()) + class_ = getattr(module, class_name.strip()) + + if not SimpleSessionManager in inspect.getmro(class_): + raise TypeError('''"session_manager" found in the conf file but it doesn''t + inherit from SimpleSessionManager''') + return class_(config) + else: + logger.warning("Neither session_db_path nor session_manager set, " + "ankisyncd will lose sessions on application restart") + return SimpleSessionManager() diff --git a/ankisyncd/sync_app.py b/ankisyncd/sync_app.py index 02f4d75..fa7766a 100644 --- a/ankisyncd/sync_app.py +++ b/ankisyncd/sync_app.py @@ -41,6 +41,7 @@ from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT from anki.consts import REM_CARD, REM_NOTE from ankisyncd.users import get_user_manager +from ankisyncd.sessions import get_session_manager logger = logging.getLogger("ankisyncd") @@ -382,26 +383,6 @@ class SyncUserSession: handler.col = col return handler -class SimpleSessionManager: - """A simple session manager that keeps the sessions in memory.""" - - def __init__(self): - self.sessions = {} - - def load(self, hkey, session_factory=None): - return self.sessions.get(hkey) - - def load_from_skey(self, skey, session_factory=None): - for i in self.sessions: - if self.sessions[i].skey == skey: - return self.sessions[i] - - def save(self, hkey, session): - self.sessions[hkey] = session - - def delete(self, hkey): - del self.sessions[hkey] - class SyncApp: valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download'] @@ -416,12 +397,8 @@ class SyncApp: self.prehooks = {} self.posthooks = {} - if "session_db_path" in config: - self.session_manager = SqliteSessionManager(config['session_db_path']) - else: - self.session_manager = SimpleSessionManager() - self.user_manager = get_user_manager(config) + self.session_manager = get_session_manager(config) self.collection_manager = getCollectionManager() # make sure the base_url has a trailing slash @@ -680,73 +657,6 @@ class SyncApp: return result -class SqliteSessionManager(SimpleSessionManager): - """Stores sessions in a SQLite database to prevent the user from being logged out - everytime the SyncApp is restarted.""" - - def __init__(self, session_db_path): - SimpleSessionManager.__init__(self) - - self.session_db_path = os.path.realpath(session_db_path) - - def _conn(self): - new = not os.path.exists(self.session_db_path) - conn = sqlite.connect(self.session_db_path) - if new: - cursor = conn.cursor() - cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, user VARCHAR, path VARCHAR)") - return conn - - def load(self, hkey, session_factory=None): - session = SimpleSessionManager.load(self, hkey) - if session is not None: - return session - - conn = self._conn() - cursor = conn.cursor() - - cursor.execute("SELECT skey, user, path FROM session WHERE hkey=?", (hkey,)) - res = cursor.fetchone() - - if res is not None: - session = self.sessions[hkey] = session_factory(res[1], res[2]) - session.skey = res[0] - return session - - def load_from_skey(self, skey, session_factory=None): - session = SimpleSessionManager.load_from_skey(self, skey) - if session is not None: - return session - - conn = self._conn() - cursor = conn.cursor() - - cursor.execute("SELECT hkey, user, path FROM session WHERE skey=?", (skey,)) - res = cursor.fetchone() - - if res is not None: - session = self.sessions[res[0]] = session_factory(res[1], res[2]) - session.skey = skey - return session - - def save(self, hkey, session): - SimpleSessionManager.save(self, hkey, session) - - conn = self._conn() - cursor = conn.cursor() - - cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, user, path) VALUES (?, ?, ?, ?)", - (hkey, session.skey, session.name, session.path)) - conn.commit() - - def delete(self, hkey): - SimpleSessionManager.delete(self, hkey) - - conn = self._conn() - cursor = conn.cursor() - - cursor.execute("DELETE FROM session WHERE hkey=?", (hkey,)) - conn.commit() def make_app(global_conf, **local_conf): return SyncApp(**local_conf) diff --git a/tests/test_sessions.py b/tests/test_sessions.py new file mode 100644 index 0000000..78f297c --- /dev/null +++ b/tests/test_sessions.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- + +import os +import tempfile +import sqlite3 +import unittest +import configparser + +from ankisyncd.sessions import SimpleSessionManager +from ankisyncd.sessions import SqliteSessionManager +from ankisyncd.sessions import get_session_manager + +from ankisyncd.sync_app import SyncUserSession + +import helpers.server_utils + +class FakeSessionManager(SimpleSessionManager): + def __init__(self, config): + pass + +class BadSessionManager: + pass + +class SessionManagerFactoryTest(unittest.TestCase): + def test_get_session_manager(self): + # Get absolute path to development ini file. + script_dir = os.path.dirname(os.path.realpath(__file__)) + ini_file_path = os.path.join(script_dir, + "assets", + "test.conf") + + # Create temporary files and dirs the server will use. + server_paths = helpers.server_utils.create_server_paths() + + config = configparser.ConfigParser() + config.read(ini_file_path) + + # Use custom files and dirs in settings. Should be SqliteSessionManager + config['sync_app'].update(server_paths) + self.assertTrue(type(get_session_manager(config['sync_app']) == SqliteSessionManager)) + + # No value defaults to SimpleSessionManager + config.remove_option("sync_app", "session_db_path") + self.assertTrue(type(get_session_manager(config['sync_app'])) == SimpleSessionManager) + + # A conf-specified SessionManager is loaded + config.set("sync_app", "session_manager", 'test_sessions.FakeSessionManager') + self.assertTrue(type(get_session_manager(config['sync_app'])) == FakeSessionManager) + + # Should fail at load time if the class doesn't inherit from SimpleSessionManager + config.set("sync_app", "session_manager", 'test_sessions.BadSessionManager') + with self.assertRaises(TypeError): + sm = get_session_manager(config['sync_app']) + + # Add the session_db_path back, it should take precedence over BadSessionManager + config['sync_app'].update(server_paths) + self.assertTrue(type(get_session_manager(config['sync_app'])) == SqliteSessionManager) + + +class SimpleSessionManagerTest(unittest.TestCase): + test_hkey = '1234567890' + sdir = tempfile.mkdtemp(suffix="_session") + os.rmdir(sdir) + test_session = SyncUserSession('testName', sdir, None, None) + + def setUp(self): + self.sessionManager = SimpleSessionManager() + + def tearDown(self): + self.sessionManager = None + + def test_save(self): + self.sessionManager.save(self.test_hkey, self.test_session) + self.assertEqual(self.sessionManager.sessions[self.test_hkey].name, + self.test_session.name) + self.assertEqual(self.sessionManager.sessions[self.test_hkey].path, + self.test_session.path) + + def test_delete(self): + self.sessionManager.save(self.test_hkey, self.test_session) + self.assertTrue(self.test_hkey in self.sessionManager.sessions) + + self.sessionManager.delete(self.test_hkey) + + self.assertTrue(self.test_hkey not in self.sessionManager.sessions) + + def test_load(self): + self.sessionManager.save(self.test_hkey, self.test_session) + self.assertTrue(self.test_hkey in self.sessionManager.sessions) + + loaded_session = self.sessionManager.load(self.test_hkey) + self.assertEqual(loaded_session.name, self.test_session.name) + self.assertEqual(loaded_session.path, self.test_session.path) + + +class SqliteSessionManagerTest(SimpleSessionManagerTest): + file_descriptor, _test_sess_db_path = tempfile.mkstemp(suffix=".db") + os.close(file_descriptor) + os.unlink(_test_sess_db_path) + + def setUp(self): + self.sessionManager = SqliteSessionManager(self._test_sess_db_path) + + def tearDown(self): + if os.path.exists(self._test_sess_db_path): + os.remove(self._test_sess_db_path) + + def test_save(self): + SimpleSessionManagerTest.test_save(self) + self.assertTrue(os.path.exists(self._test_sess_db_path)) + + conn = sqlite3.connect(self._test_sess_db_path) + cursor = conn.cursor() + cursor.execute("SELECT username, path FROM session WHERE hkey=?", + (self.test_hkey,)) + res = cursor.fetchone() + conn.close() + + self.assertEqual(res[0], self.test_session.name) + self.assertEqual(res[1], self.test_session.path) + + + diff --git a/tests/test_sync_app.py b/tests/test_sync_app.py index 1877863..8e3ff89 100644 --- a/tests/test_sync_app.py +++ b/tests/test_sync_app.py @@ -8,8 +8,6 @@ from anki.consts import SYNC_VER from ankisyncd.sync_app import SyncCollectionHandler from ankisyncd.sync_app import SyncUserSession -from ankisyncd.sync_app import SimpleSessionManager -from ankisyncd.sync_app import SqliteSessionManager from collection_test_base import CollectionTestBase @@ -67,68 +65,5 @@ class SyncCollectionHandlerTest(CollectionTestBase): self.assertEqual(meta['cont'], True) -class SimpleSessionManagerTest(unittest.TestCase): - test_hkey = '1234567890' - sdir = tempfile.mkdtemp(suffix="_session") - os.rmdir(sdir) - test_session = SyncUserSession('testName', sdir, None, None) - - def setUp(self): - self.sessionManager = SimpleSessionManager() - - def tearDown(self): - self.sessionManager = None - - def test_save(self): - self.sessionManager.save(self.test_hkey, self.test_session) - self.assertEqual(self.sessionManager.sessions[self.test_hkey].name, - self.test_session.name) - self.assertEqual(self.sessionManager.sessions[self.test_hkey].path, - self.test_session.path) - - def test_delete(self): - self.sessionManager.save(self.test_hkey, self.test_session) - self.assertTrue(self.test_hkey in self.sessionManager.sessions) - - self.sessionManager.delete(self.test_hkey) - - self.assertTrue(self.test_hkey not in self.sessionManager.sessions) - - def test_load(self): - self.sessionManager.save(self.test_hkey, self.test_session) - self.assertTrue(self.test_hkey in self.sessionManager.sessions) - - loaded_session = self.sessionManager.load(self.test_hkey) - self.assertEqual(loaded_session.name, self.test_session.name) - self.assertEqual(loaded_session.path, self.test_session.path) - - -class SqliteSessionManagerTest(SimpleSessionManagerTest): - file_descriptor, _test_sess_db_path = tempfile.mkstemp(suffix=".db") - os.close(file_descriptor) - os.unlink(_test_sess_db_path) - - def setUp(self): - self.sessionManager = SqliteSessionManager(self._test_sess_db_path) - - def tearDown(self): - if os.path.exists(self._test_sess_db_path): - os.remove(self._test_sess_db_path) - - def test_save(self): - SimpleSessionManagerTest.test_save(self) - self.assertTrue(os.path.exists(self._test_sess_db_path)) - - conn = sqlite3.connect(self._test_sess_db_path) - cursor = conn.cursor() - cursor.execute("SELECT user, path FROM session WHERE hkey=?", - (self.test_hkey,)) - res = cursor.fetchone() - conn.close() - - self.assertEqual(res[0], self.test_session.name) - self.assertEqual(res[1], self.test_session.path) - - class SyncAppTest(unittest.TestCase): pass