Load the CollectionWrapper from a factory method

This allows a class implementing CollectionWrapper's interface to be
added from config
This commit is contained in:
Anton Melser 2019-01-28 21:39:57 +08:00
parent 9ee9697582
commit fa89b0e0a2
4 changed files with 87 additions and 16 deletions

View File

@ -2,6 +2,10 @@ import anki
import anki.storage import anki.storage
import os, errno import os, errno
import logging
logger = logging.getLogger("ankisyncd.collection")
class CollectionWrapper: class CollectionWrapper:
"""A simple wrapper around an anki.storage.Collection object. """A simple wrapper around an anki.storage.Collection object.
@ -9,12 +13,12 @@ class CollectionWrapper:
This allows us to manage and refer to the collection, whether it's open or not. It This allows us to manage and refer to the collection, whether it's open or not. It
also provides a special "continuation passing" interface for executing functions also provides a special "continuation passing" interface for executing functions
on the collection, which makes it easy to switch to a threading mode. on the collection, which makes it easy to switch to a threading mode.
See ThreadingCollectionWrapper for a version that maintains a seperate thread for See ThreadingCollectionWrapper for a version that maintains a seperate thread for
interacting with the collection. interacting with the collection.
""" """
def __init__(self, path, setup_new_collection=None): def __init__(self, _config, path, setup_new_collection=None):
self.path = os.path.realpath(path) self.path = os.path.realpath(path)
self.username = os.path.basename(os.path.dirname(self.path)) self.username = os.path.basename(os.path.dirname(self.path))
self.setup_new_collection = setup_new_collection self.setup_new_collection = setup_new_collection
@ -50,7 +54,7 @@ class CollectionWrapper:
# mkdir -p the path, because it might not exist # mkdir -p the path, because it might not exist
os.makedirs(os.path.dirname(self.path), exist_ok=True) os.makedirs(os.path.dirname(self.path), exist_ok=True)
col = anki.storage.Collection(self.path) col = self._get_collection()
# Do any special setup # Do any special setup
if self.setup_new_collection is not None: if self.setup_new_collection is not None:
@ -58,11 +62,14 @@ class CollectionWrapper:
return col return col
def _get_collection(self):
return anki.storage.Collection(self.path)
def open(self): def open(self):
"""Open the collection, or create it if it doesn't exist.""" """Open the collection, or create it if it doesn't exist."""
if self.__col is None: if self.__col is None:
if os.path.exists(self.path): if os.path.exists(self.path):
self.__col = anki.storage.Collection(self.path) self.__col = self._get_collection()
else: else:
self.__col = self.__create_collection() self.__col = self.__create_collection()
@ -83,8 +90,9 @@ class CollectionManager:
collection_wrapper = CollectionWrapper collection_wrapper = CollectionWrapper
def __init__(self): def __init__(self, config):
self.collections = {} self.collections = {}
self.config = config
def get_collection(self, path, setup_new_collection=None): def get_collection(self, path, setup_new_collection=None):
"""Gets a CollectionWrapper for the given path.""" """Gets a CollectionWrapper for the given path."""
@ -94,7 +102,7 @@ class CollectionManager:
try: try:
col = self.collections[path] col = self.collections[path]
except KeyError: except KeyError:
col = self.collections[path] = self.collection_wrapper(path, setup_new_collection) col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
return col return col
@ -104,3 +112,19 @@ class CollectionManager:
del self.collections[path] del self.collections[path]
col.close() col.close()
def get_collection_wrapper(config, path, setup_new_collection = None):
if "collection_wrapper" in config and config["collection_wrapper"]:
logger.info("Found collection_wrapper in config, using {} for "
"user data persistence".format(config['collection_wrapper']))
import importlib
import inspect
module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not CollectionWrapper in inspect.getmro(class_):
raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t
inherit from CollectionWrapper''')
return class_(config, path, setup_new_collection)
else:
return CollectionWrapper(config, path, setup_new_collection)

View File

@ -388,7 +388,7 @@ class SyncApp:
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download'] valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
def __init__(self, config): def __init__(self, config):
from ankisyncd.thread import getCollectionManager from ankisyncd.thread import get_collection_manager
self.data_root = os.path.abspath(config['data_root']) self.data_root = os.path.abspath(config['data_root'])
self.base_url = config['base_url'] self.base_url = config['base_url']
@ -401,7 +401,7 @@ class SyncApp:
self.user_manager = get_user_manager(config) self.user_manager = get_user_manager(config)
self.session_manager = get_session_manager(config) self.session_manager = get_session_manager(config)
self.full_sync_manager = get_full_sync_manager(config) self.full_sync_manager = get_full_sync_manager(config)
self.collection_manager = getCollectionManager() self.collection_manager = get_collection_manager(config)
# make sure the base_url has a trailing slash # make sure the base_url has a trailing slash
if not self.base_url.endswith('/'): if not self.base_url.endswith('/'):

View File

@ -1,4 +1,4 @@
from ankisyncd.collection import CollectionWrapper, CollectionManager from ankisyncd.collection import CollectionManager, get_collection_wrapper
from threading import Thread from threading import Thread
from queue import Queue from queue import Queue
@ -29,12 +29,12 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80):
return repr(o) return repr(o)
class ThreadingCollectionWrapper: class ThreadingCollectionWrapper:
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to """Provides the same interface as CollectionWrapper, but it creates a new Thread to
interact with the collection.""" interact with the collection."""
def __init__(self, path, setup_new_collection=None): def __init__(self, config, path, setup_new_collection=None):
self.path = path self.path = path
self.wrapper = CollectionWrapper(path, setup_new_collection) self.wrapper = get_collection_wrapper(config, path, setup_new_collection)
self.logger = logging.getLogger("ankisyncd." + str(self)) self.logger = logging.getLogger("ankisyncd." + str(self))
self._queue = Queue() self._queue = Queue()
@ -156,8 +156,8 @@ class ThreadingCollectionManager(CollectionManager):
collection_wrapper = ThreadingCollectionWrapper collection_wrapper = ThreadingCollectionWrapper
def __init__(self): def __init__(self, config):
super(ThreadingCollectionManager, self).__init__() super(ThreadingCollectionManager, self).__init__(config)
self.monitor_frequency = 15 self.monitor_frequency = 15
self.monitor_inactivity = 90 self.monitor_inactivity = 90
@ -202,11 +202,11 @@ class ThreadingCollectionManager(CollectionManager):
collection_manager = None collection_manager = None
def getCollectionManager(): def get_collection_manager(config):
"""Return the global ThreadingCollectionManager for this process.""" """Return the global ThreadingCollectionManager for this process."""
global collection_manager global collection_manager
if collection_manager is None: if collection_manager is None:
collection_manager = ThreadingCollectionManager() collection_manager = ThreadingCollectionManager(config)
return collection_manager return collection_manager
def shutdown(): def shutdown():

View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
import os
import unittest
import configparser
from ankisyncd.collection import CollectionWrapper
from ankisyncd.collection import get_collection_wrapper
import helpers.server_utils
class FakeCollectionWrapper(CollectionWrapper):
def __init__(self, config, path, setup_new_collection=None):
self. _CollectionWrapper__col = None
pass
class BadCollectionWrapper:
pass
class CollectionWrapperFactoryTest(unittest.TestCase):
def test_get_collection_wrapper(self):
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
# Create temporary files and dirs the server will use.
server_paths = helpers.server_utils.create_server_paths()
config = configparser.ConfigParser()
config.read(ini_file_path)
path = os.path.realpath('fake/collection.anki2')
# Use custom files and dirs in settings. Should be CollectionWrapper
config['sync_app'].update(server_paths)
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper))
# A conf-specified CollectionWrapper is loaded
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper')
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper)
# Should fail at load time if the class doesn't inherit from CollectionWrapper
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper')
with self.assertRaises(TypeError):
pm = get_collection_wrapper(config['sync_app'], path)