Use factory method for session_manager
Also add some abstraction over the SQL to allow for different SQL dialects
This commit is contained in:
parent
ea0cbc669b
commit
50cc6a12d9
126
ankisyncd/sessions.py
Normal file
126
ankisyncd/sessions.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from sqlite3 import dbapi2 as sqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger("ankisyncd.sessions")
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleSessionManager:
|
||||||
|
"""A simple session manager that keeps the sessions in memory."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.sessions = {}
|
||||||
|
|
||||||
|
def load(self, hkey, session_factory=None):
|
||||||
|
return self.sessions.get(hkey)
|
||||||
|
|
||||||
|
def load_from_skey(self, skey, session_factory=None):
|
||||||
|
for i in self.sessions:
|
||||||
|
if self.sessions[i].skey == skey:
|
||||||
|
return self.sessions[i]
|
||||||
|
|
||||||
|
def save(self, hkey, session):
|
||||||
|
self.sessions[hkey] = session
|
||||||
|
|
||||||
|
def delete(self, hkey):
|
||||||
|
del self.sessions[hkey]
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSessionManager(SimpleSessionManager):
|
||||||
|
"""Stores sessions in a SQLite database to prevent the user from being logged out
|
||||||
|
everytime the SyncApp is restarted."""
|
||||||
|
|
||||||
|
def __init__(self, session_db_path):
|
||||||
|
SimpleSessionManager.__init__(self)
|
||||||
|
|
||||||
|
self.session_db_path = os.path.realpath(session_db_path)
|
||||||
|
|
||||||
|
def _conn(self):
|
||||||
|
new = not os.path.exists(self.session_db_path)
|
||||||
|
conn = sqlite.connect(self.session_db_path)
|
||||||
|
if new:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
# Default to using sqlite3 syntax but overridable for sub-classes using other
|
||||||
|
# DB API 2 driver variants
|
||||||
|
@staticmethod
|
||||||
|
def fs(sql):
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def load(self, hkey, session_factory=None):
|
||||||
|
session = SimpleSessionManager.load(self, hkey)
|
||||||
|
if session is not None:
|
||||||
|
return session
|
||||||
|
|
||||||
|
conn = self._conn()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,))
|
||||||
|
res = cursor.fetchone()
|
||||||
|
|
||||||
|
if res is not None:
|
||||||
|
session = self.sessions[hkey] = session_factory(res[1], res[2])
|
||||||
|
session.skey = res[0]
|
||||||
|
return session
|
||||||
|
|
||||||
|
def load_from_skey(self, skey, session_factory=None):
|
||||||
|
session = SimpleSessionManager.load_from_skey(self, skey)
|
||||||
|
if session is not None:
|
||||||
|
return session
|
||||||
|
|
||||||
|
conn = self._conn()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,))
|
||||||
|
res = cursor.fetchone()
|
||||||
|
|
||||||
|
if res is not None:
|
||||||
|
session = self.sessions[res[0]] = session_factory(res[1], res[2])
|
||||||
|
session.skey = skey
|
||||||
|
return session
|
||||||
|
|
||||||
|
def save(self, hkey, session):
|
||||||
|
SimpleSessionManager.save(self, hkey, session)
|
||||||
|
|
||||||
|
conn = self._conn()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
|
||||||
|
(hkey, session.skey, session.name, session.path))
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def delete(self, hkey):
|
||||||
|
SimpleSessionManager.delete(self, hkey)
|
||||||
|
|
||||||
|
conn = self._conn()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_session_manager(config):
|
||||||
|
if "session_db_path" in config and config["session_db_path"]:
|
||||||
|
logger.info("Found session_db_path in config, using SqliteSessionManager for auth")
|
||||||
|
return SqliteSessionManager(config['session_db_path'])
|
||||||
|
elif "session_manager" in config and config["session_manager"]: # load from config
|
||||||
|
logger.info("Found session_manager in config, using {} for persisting sessions".format(
|
||||||
|
config['session_manager'])
|
||||||
|
)
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
module_name, class_name = config['session_manager'].rsplit('.', 1)
|
||||||
|
module = importlib.import_module(module_name.strip())
|
||||||
|
class_ = getattr(module, class_name.strip())
|
||||||
|
|
||||||
|
if not SimpleSessionManager in inspect.getmro(class_):
|
||||||
|
raise TypeError('''"session_manager" found in the conf file but it doesn''t
|
||||||
|
inherit from SimpleSessionManager''')
|
||||||
|
return class_(config)
|
||||||
|
else:
|
||||||
|
logger.warning("Neither session_db_path nor session_manager set, "
|
||||||
|
"ankisyncd will lose sessions on application restart")
|
||||||
|
return SimpleSessionManager()
|
||||||
@ -41,6 +41,7 @@ from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
|
|||||||
from anki.consts import REM_CARD, REM_NOTE
|
from anki.consts import REM_CARD, REM_NOTE
|
||||||
|
|
||||||
from ankisyncd.users import get_user_manager
|
from ankisyncd.users import get_user_manager
|
||||||
|
from ankisyncd.sessions import get_session_manager
|
||||||
|
|
||||||
logger = logging.getLogger("ankisyncd")
|
logger = logging.getLogger("ankisyncd")
|
||||||
|
|
||||||
@ -382,26 +383,6 @@ class SyncUserSession:
|
|||||||
handler.col = col
|
handler.col = col
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
class SimpleSessionManager:
|
|
||||||
"""A simple session manager that keeps the sessions in memory."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.sessions = {}
|
|
||||||
|
|
||||||
def load(self, hkey, session_factory=None):
|
|
||||||
return self.sessions.get(hkey)
|
|
||||||
|
|
||||||
def load_from_skey(self, skey, session_factory=None):
|
|
||||||
for i in self.sessions:
|
|
||||||
if self.sessions[i].skey == skey:
|
|
||||||
return self.sessions[i]
|
|
||||||
|
|
||||||
def save(self, hkey, session):
|
|
||||||
self.sessions[hkey] = session
|
|
||||||
|
|
||||||
def delete(self, hkey):
|
|
||||||
del self.sessions[hkey]
|
|
||||||
|
|
||||||
class SyncApp:
|
class SyncApp:
|
||||||
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
|
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
|
||||||
|
|
||||||
@ -416,12 +397,8 @@ class SyncApp:
|
|||||||
self.prehooks = {}
|
self.prehooks = {}
|
||||||
self.posthooks = {}
|
self.posthooks = {}
|
||||||
|
|
||||||
if "session_db_path" in config:
|
|
||||||
self.session_manager = SqliteSessionManager(config['session_db_path'])
|
|
||||||
else:
|
|
||||||
self.session_manager = SimpleSessionManager()
|
|
||||||
|
|
||||||
self.user_manager = get_user_manager(config)
|
self.user_manager = get_user_manager(config)
|
||||||
|
self.session_manager = get_session_manager(config)
|
||||||
self.collection_manager = getCollectionManager()
|
self.collection_manager = getCollectionManager()
|
||||||
|
|
||||||
# make sure the base_url has a trailing slash
|
# make sure the base_url has a trailing slash
|
||||||
@ -680,73 +657,6 @@ class SyncApp:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class SqliteSessionManager(SimpleSessionManager):
|
|
||||||
"""Stores sessions in a SQLite database to prevent the user from being logged out
|
|
||||||
everytime the SyncApp is restarted."""
|
|
||||||
|
|
||||||
def __init__(self, session_db_path):
|
|
||||||
SimpleSessionManager.__init__(self)
|
|
||||||
|
|
||||||
self.session_db_path = os.path.realpath(session_db_path)
|
|
||||||
|
|
||||||
def _conn(self):
|
|
||||||
new = not os.path.exists(self.session_db_path)
|
|
||||||
conn = sqlite.connect(self.session_db_path)
|
|
||||||
if new:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, user VARCHAR, path VARCHAR)")
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def load(self, hkey, session_factory=None):
|
|
||||||
session = SimpleSessionManager.load(self, hkey)
|
|
||||||
if session is not None:
|
|
||||||
return session
|
|
||||||
|
|
||||||
conn = self._conn()
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute("SELECT skey, user, path FROM session WHERE hkey=?", (hkey,))
|
|
||||||
res = cursor.fetchone()
|
|
||||||
|
|
||||||
if res is not None:
|
|
||||||
session = self.sessions[hkey] = session_factory(res[1], res[2])
|
|
||||||
session.skey = res[0]
|
|
||||||
return session
|
|
||||||
|
|
||||||
def load_from_skey(self, skey, session_factory=None):
|
|
||||||
session = SimpleSessionManager.load_from_skey(self, skey)
|
|
||||||
if session is not None:
|
|
||||||
return session
|
|
||||||
|
|
||||||
conn = self._conn()
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute("SELECT hkey, user, path FROM session WHERE skey=?", (skey,))
|
|
||||||
res = cursor.fetchone()
|
|
||||||
|
|
||||||
if res is not None:
|
|
||||||
session = self.sessions[res[0]] = session_factory(res[1], res[2])
|
|
||||||
session.skey = skey
|
|
||||||
return session
|
|
||||||
|
|
||||||
def save(self, hkey, session):
|
|
||||||
SimpleSessionManager.save(self, hkey, session)
|
|
||||||
|
|
||||||
conn = self._conn()
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, user, path) VALUES (?, ?, ?, ?)",
|
|
||||||
(hkey, session.skey, session.name, session.path))
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def delete(self, hkey):
|
|
||||||
SimpleSessionManager.delete(self, hkey)
|
|
||||||
|
|
||||||
conn = self._conn()
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute("DELETE FROM session WHERE hkey=?", (hkey,))
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def make_app(global_conf, **local_conf):
|
def make_app(global_conf, **local_conf):
|
||||||
return SyncApp(**local_conf)
|
return SyncApp(**local_conf)
|
||||||
|
|||||||
123
tests/test_sessions.py
Normal file
123
tests/test_sessions.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import sqlite3
|
||||||
|
import unittest
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
from ankisyncd.sessions import SimpleSessionManager
|
||||||
|
from ankisyncd.sessions import SqliteSessionManager
|
||||||
|
from ankisyncd.sessions import get_session_manager
|
||||||
|
|
||||||
|
from ankisyncd.sync_app import SyncUserSession
|
||||||
|
|
||||||
|
import helpers.server_utils
|
||||||
|
|
||||||
|
class FakeSessionManager(SimpleSessionManager):
|
||||||
|
def __init__(self, config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BadSessionManager:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class SessionManagerFactoryTest(unittest.TestCase):
|
||||||
|
def test_get_session_manager(self):
|
||||||
|
# Get absolute path to development ini file.
|
||||||
|
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
ini_file_path = os.path.join(script_dir,
|
||||||
|
"assets",
|
||||||
|
"test.conf")
|
||||||
|
|
||||||
|
# Create temporary files and dirs the server will use.
|
||||||
|
server_paths = helpers.server_utils.create_server_paths()
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(ini_file_path)
|
||||||
|
|
||||||
|
# Use custom files and dirs in settings. Should be SqliteSessionManager
|
||||||
|
config['sync_app'].update(server_paths)
|
||||||
|
self.assertTrue(type(get_session_manager(config['sync_app']) == SqliteSessionManager))
|
||||||
|
|
||||||
|
# No value defaults to SimpleSessionManager
|
||||||
|
config.remove_option("sync_app", "session_db_path")
|
||||||
|
self.assertTrue(type(get_session_manager(config['sync_app'])) == SimpleSessionManager)
|
||||||
|
|
||||||
|
# A conf-specified SessionManager is loaded
|
||||||
|
config.set("sync_app", "session_manager", 'test_sessions.FakeSessionManager')
|
||||||
|
self.assertTrue(type(get_session_manager(config['sync_app'])) == FakeSessionManager)
|
||||||
|
|
||||||
|
# Should fail at load time if the class doesn't inherit from SimpleSessionManager
|
||||||
|
config.set("sync_app", "session_manager", 'test_sessions.BadSessionManager')
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
sm = get_session_manager(config['sync_app'])
|
||||||
|
|
||||||
|
# Add the session_db_path back, it should take precedence over BadSessionManager
|
||||||
|
config['sync_app'].update(server_paths)
|
||||||
|
self.assertTrue(type(get_session_manager(config['sync_app'])) == SqliteSessionManager)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleSessionManagerTest(unittest.TestCase):
|
||||||
|
test_hkey = '1234567890'
|
||||||
|
sdir = tempfile.mkdtemp(suffix="_session")
|
||||||
|
os.rmdir(sdir)
|
||||||
|
test_session = SyncUserSession('testName', sdir, None, None)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.sessionManager = SimpleSessionManager()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.sessionManager = None
|
||||||
|
|
||||||
|
def test_save(self):
|
||||||
|
self.sessionManager.save(self.test_hkey, self.test_session)
|
||||||
|
self.assertEqual(self.sessionManager.sessions[self.test_hkey].name,
|
||||||
|
self.test_session.name)
|
||||||
|
self.assertEqual(self.sessionManager.sessions[self.test_hkey].path,
|
||||||
|
self.test_session.path)
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
self.sessionManager.save(self.test_hkey, self.test_session)
|
||||||
|
self.assertTrue(self.test_hkey in self.sessionManager.sessions)
|
||||||
|
|
||||||
|
self.sessionManager.delete(self.test_hkey)
|
||||||
|
|
||||||
|
self.assertTrue(self.test_hkey not in self.sessionManager.sessions)
|
||||||
|
|
||||||
|
def test_load(self):
|
||||||
|
self.sessionManager.save(self.test_hkey, self.test_session)
|
||||||
|
self.assertTrue(self.test_hkey in self.sessionManager.sessions)
|
||||||
|
|
||||||
|
loaded_session = self.sessionManager.load(self.test_hkey)
|
||||||
|
self.assertEqual(loaded_session.name, self.test_session.name)
|
||||||
|
self.assertEqual(loaded_session.path, self.test_session.path)
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSessionManagerTest(SimpleSessionManagerTest):
|
||||||
|
file_descriptor, _test_sess_db_path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(file_descriptor)
|
||||||
|
os.unlink(_test_sess_db_path)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.sessionManager = SqliteSessionManager(self._test_sess_db_path)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if os.path.exists(self._test_sess_db_path):
|
||||||
|
os.remove(self._test_sess_db_path)
|
||||||
|
|
||||||
|
def test_save(self):
|
||||||
|
SimpleSessionManagerTest.test_save(self)
|
||||||
|
self.assertTrue(os.path.exists(self._test_sess_db_path))
|
||||||
|
|
||||||
|
conn = sqlite3.connect(self._test_sess_db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT username, path FROM session WHERE hkey=?",
|
||||||
|
(self.test_hkey,))
|
||||||
|
res = cursor.fetchone()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
self.assertEqual(res[0], self.test_session.name)
|
||||||
|
self.assertEqual(res[1], self.test_session.path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -8,8 +8,6 @@ from anki.consts import SYNC_VER
|
|||||||
|
|
||||||
from ankisyncd.sync_app import SyncCollectionHandler
|
from ankisyncd.sync_app import SyncCollectionHandler
|
||||||
from ankisyncd.sync_app import SyncUserSession
|
from ankisyncd.sync_app import SyncUserSession
|
||||||
from ankisyncd.sync_app import SimpleSessionManager
|
|
||||||
from ankisyncd.sync_app import SqliteSessionManager
|
|
||||||
|
|
||||||
from collection_test_base import CollectionTestBase
|
from collection_test_base import CollectionTestBase
|
||||||
|
|
||||||
@ -67,68 +65,5 @@ class SyncCollectionHandlerTest(CollectionTestBase):
|
|||||||
self.assertEqual(meta['cont'], True)
|
self.assertEqual(meta['cont'], True)
|
||||||
|
|
||||||
|
|
||||||
class SimpleSessionManagerTest(unittest.TestCase):
|
|
||||||
test_hkey = '1234567890'
|
|
||||||
sdir = tempfile.mkdtemp(suffix="_session")
|
|
||||||
os.rmdir(sdir)
|
|
||||||
test_session = SyncUserSession('testName', sdir, None, None)
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.sessionManager = SimpleSessionManager()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.sessionManager = None
|
|
||||||
|
|
||||||
def test_save(self):
|
|
||||||
self.sessionManager.save(self.test_hkey, self.test_session)
|
|
||||||
self.assertEqual(self.sessionManager.sessions[self.test_hkey].name,
|
|
||||||
self.test_session.name)
|
|
||||||
self.assertEqual(self.sessionManager.sessions[self.test_hkey].path,
|
|
||||||
self.test_session.path)
|
|
||||||
|
|
||||||
def test_delete(self):
|
|
||||||
self.sessionManager.save(self.test_hkey, self.test_session)
|
|
||||||
self.assertTrue(self.test_hkey in self.sessionManager.sessions)
|
|
||||||
|
|
||||||
self.sessionManager.delete(self.test_hkey)
|
|
||||||
|
|
||||||
self.assertTrue(self.test_hkey not in self.sessionManager.sessions)
|
|
||||||
|
|
||||||
def test_load(self):
|
|
||||||
self.sessionManager.save(self.test_hkey, self.test_session)
|
|
||||||
self.assertTrue(self.test_hkey in self.sessionManager.sessions)
|
|
||||||
|
|
||||||
loaded_session = self.sessionManager.load(self.test_hkey)
|
|
||||||
self.assertEqual(loaded_session.name, self.test_session.name)
|
|
||||||
self.assertEqual(loaded_session.path, self.test_session.path)
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteSessionManagerTest(SimpleSessionManagerTest):
|
|
||||||
file_descriptor, _test_sess_db_path = tempfile.mkstemp(suffix=".db")
|
|
||||||
os.close(file_descriptor)
|
|
||||||
os.unlink(_test_sess_db_path)
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.sessionManager = SqliteSessionManager(self._test_sess_db_path)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
if os.path.exists(self._test_sess_db_path):
|
|
||||||
os.remove(self._test_sess_db_path)
|
|
||||||
|
|
||||||
def test_save(self):
|
|
||||||
SimpleSessionManagerTest.test_save(self)
|
|
||||||
self.assertTrue(os.path.exists(self._test_sess_db_path))
|
|
||||||
|
|
||||||
conn = sqlite3.connect(self._test_sess_db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT user, path FROM session WHERE hkey=?",
|
|
||||||
(self.test_hkey,))
|
|
||||||
res = cursor.fetchone()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
self.assertEqual(res[0], self.test_session.name)
|
|
||||||
self.assertEqual(res[1], self.test_session.path)
|
|
||||||
|
|
||||||
|
|
||||||
class SyncAppTest(unittest.TestCase):
|
class SyncAppTest(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user