anki-sync-server/ankisyncd/sessions.py
Anton Melser 50cc6a12d9 Use factory method for session_manager
Also add some abstraction over the SQL to allow for different SQL
dialects
2019-03-09 14:05:15 +08:00

127 lines
4.2 KiB
Python

# -*- coding: utf-8 -*-
import os
import logging
from sqlite3 import dbapi2 as sqlite
logger = logging.getLogger("ankisyncd.sessions")
class SimpleSessionManager:
    """Session manager backed by a plain in-process dictionary.

    Sessions are held in ``self.sessions`` (host key -> session object) and
    are not persisted anywhere else.
    """

    def __init__(self):
        self.sessions = {}

    def load(self, hkey, session_factory=None):
        """Return the session stored under ``hkey``, or ``None`` if absent.

        ``session_factory`` is accepted but not used by this implementation.
        """
        try:
            return self.sessions[hkey]
        except KeyError:
            return None

    def load_from_skey(self, skey, session_factory=None):
        """Return the first stored session whose ``skey`` equals ``skey``.

        Returns ``None`` when no stored session matches.
        ``session_factory`` is accepted but not used by this implementation.
        """
        for candidate in self.sessions.values():
            if candidate.skey == skey:
                return candidate
        return None

    def save(self, hkey, session):
        """Store ``session`` under ``hkey``, replacing any existing entry."""
        self.sessions[hkey] = session

    def delete(self, hkey):
        """Remove the session stored under ``hkey``.

        Raises ``KeyError`` if no session is stored under ``hkey``.
        """
        self.sessions.pop(hkey)
class SqliteSessionManager(SimpleSessionManager):
    """Stores sessions in a SQLite database to prevent the user from being
    logged out every time the SyncApp is restarted.

    Sessions are also cached in memory through ``SimpleSessionManager``; the
    database is only queried when a session is not found in that cache.
    """

    def __init__(self, session_db_path):
        SimpleSessionManager.__init__(self)
        self.session_db_path = os.path.realpath(session_db_path)

    def _conn(self):
        """Open a new connection to the session database.

        The ``session`` table is created if it does not exist yet, so an
        existing but empty database file is handled as well as a missing one.
        The caller owns the returned connection and is expected to close it.
        """
        conn = sqlite.connect(self.session_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("CREATE TABLE IF NOT EXISTS session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)")
            conn.commit()
        except Exception:
            # Do not leak the connection when the schema setup fails.
            conn.close()
            raise
        return conn

    # Default to using sqlite3 syntax but overridable for sub-classes using other
    # DB API 2 driver variants
    @staticmethod
    def fs(sql):
        """Return ``sql`` formatted for the driver in use (unchanged here)."""
        return sql

    def _fetch_one(self, sql, params):
        """Run a query and return its first row (or ``None``), closing the connection."""
        conn = self._conn()
        try:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            return cursor.fetchone()
        finally:
            conn.close()

    def _execute_commit(self, sql, params):
        """Run a modifying statement, commit it, and close the connection."""
        conn = self._conn()
        try:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            conn.commit()
        finally:
            conn.close()

    def load(self, hkey, session_factory=None):
        """Return the session for ``hkey``, or ``None`` if it is unknown.

        The in-memory cache is consulted first. On a cache miss the database
        row is looked up and, if found, a session is built with
        ``session_factory(username, path)``, given the stored ``skey`` and
        cached. ``session_factory`` must therefore be callable whenever the
        session has to be rebuilt from the database.
        """
        session = SimpleSessionManager.load(self, hkey)
        if session is not None:
            return session

        res = self._fetch_one(
            self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,))
        if res is not None:
            session = self.sessions[hkey] = session_factory(res[1], res[2])
            session.skey = res[0]
        return session

    def load_from_skey(self, skey, session_factory=None):
        """Return the session whose ``skey`` matches, or ``None`` if unknown.

        The in-memory cache is consulted first. On a cache miss the database
        row is looked up and, if found, a session is built with
        ``session_factory(username, path)``, given ``skey`` and cached under
        the stored ``hkey``. ``session_factory`` must therefore be callable
        whenever the session has to be rebuilt from the database.
        """
        session = SimpleSessionManager.load_from_skey(self, skey)
        if session is not None:
            return session

        res = self._fetch_one(
            self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,))
        if res is not None:
            session = self.sessions[res[0]] = session_factory(res[1], res[2])
            session.skey = skey
        return session

    def save(self, hkey, session):
        """Cache ``session`` under ``hkey`` and insert or replace its database row.

        The row stores ``session.skey``, ``session.name`` and ``session.path``.
        """
        SimpleSessionManager.save(self, hkey, session)
        # Routed through fs() like every other statement so that a dialect
        # override is applied consistently.
        self._execute_commit(
            self.fs("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)"),
            (hkey, session.skey, session.name, session.path))

    def delete(self, hkey):
        """Remove the session for ``hkey`` from the cache and the database.

        A session that is not cached (the cache starts empty and is only
        filled by ``save`` and the ``load`` methods) is still removed from
        the database instead of failing with ``KeyError``.
        """
        self.sessions.pop(hkey, None)
        self._execute_commit(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,))
def get_session_manager(config):
    """Create the session manager selected by ``config``.

    Selection order:

    1. A non-empty ``session_db_path`` yields a ``SqliteSessionManager`` for
       that path.
    2. Otherwise a non-empty ``session_manager`` is read as a dotted
       ``module.ClassName`` path; the class is imported and instantiated
       with ``config`` as its only argument.
    3. Otherwise an in-memory ``SimpleSessionManager`` is returned and a
       warning is logged.

    Raises:
        ValueError: ``session_manager`` has no module part.
        TypeError: the configured object is not a class derived from
            ``SimpleSessionManager``.
    """
    if "session_db_path" in config and config["session_db_path"]:
        logger.info("Found session_db_path in config, using SqliteSessionManager for auth")
        return SqliteSessionManager(config['session_db_path'])

    if "session_manager" in config and config["session_manager"]:  # load from config
        import importlib
        import inspect

        dotted_path = config['session_manager']
        logger.info("Found session_manager in config, using %s for persisting sessions",
                    dotted_path)

        module_name, _, class_name = dotted_path.rpartition('.')
        if not module_name.strip():
            raise ValueError('"session_manager" must be a dotted path of the form '
                             '"module.ClassName", got {!r}'.format(dotted_path))

        module = importlib.import_module(module_name.strip())
        class_ = getattr(module, class_name.strip())
        # isclass() first: issubclass() itself raises when handed a non-class.
        if not (inspect.isclass(class_) and issubclass(class_, SimpleSessionManager)):
            raise TypeError('"session_manager" found in the conf file but it doesn\'t '
                            'inherit from SimpleSessionManager')
        return class_(config)

    logger.warning("Neither session_db_path nor session_manager set, "
                   "ankisyncd will lose sessions on application restart")
    return SimpleSessionManager()