diff --git a/ankisyncd/collection.py b/ankisyncd/collection.py index a6df739..4603f4b 100644 --- a/ankisyncd/collection.py +++ b/ankisyncd/collection.py @@ -2,6 +2,10 @@ import anki import anki.storage import os, errno +import logging + +logger = logging.getLogger("ankisyncd.collection") + class CollectionWrapper: """A simple wrapper around an anki.storage.Collection object. @@ -9,12 +13,12 @@ class CollectionWrapper: This allows us to manage and refer to the collection, whether it's open or not. It also provides a special "continuation passing" interface for executing functions on the collection, which makes it easy to switch to a threading mode. - + See ThreadingCollectionWrapper for a version that maintains a seperate thread for interacting with the collection. """ - def __init__(self, path, setup_new_collection=None): + def __init__(self, _config, path, setup_new_collection=None): self.path = os.path.realpath(path) self.username = os.path.basename(os.path.dirname(self.path)) self.setup_new_collection = setup_new_collection @@ -50,7 +54,7 @@ class CollectionWrapper: # mkdir -p the path, because it might not exist os.makedirs(os.path.dirname(self.path), exist_ok=True) - col = anki.storage.Collection(self.path) + col = self._get_collection() # Do any special setup if self.setup_new_collection is not None: @@ -58,11 +62,14 @@ class CollectionWrapper: return col + def _get_collection(self): + return anki.storage.Collection(self.path) + def open(self): """Open the collection, or create it if it doesn't exist.""" if self.__col is None: if os.path.exists(self.path): - self.__col = anki.storage.Collection(self.path) + self.__col = self._get_collection() else: self.__col = self.__create_collection() @@ -83,8 +90,9 @@ class CollectionManager: collection_wrapper = CollectionWrapper - def __init__(self): + def __init__(self, config): self.collections = {} + self.config = config def get_collection(self, path, setup_new_collection=None): """Gets a CollectionWrapper for the given path.""" @@ -94,7 +102,7 @@ class CollectionManager: try: col = self.collections[path] except KeyError: - col = self.collections[path] = self.collection_wrapper(path, setup_new_collection) + col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection) return col @@ -104,3 +112,19 @@ class CollectionManager: del self.collections[path] col.close() +def get_collection_wrapper(config, path, setup_new_collection = None): + if "collection_wrapper" in config and config["collection_wrapper"]: + logger.info("Found collection_wrapper in config, using {} for " + "user data persistence".format(config['collection_wrapper'])) + import importlib + import inspect + module_name, class_name = config['collection_wrapper'].rsplit('.', 1) + module = importlib.import_module(module_name.strip()) + class_ = getattr(module, class_name.strip()) + + if not CollectionWrapper in inspect.getmro(class_): + raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t + inherit from CollectionWrapper''') + return class_(config, path, setup_new_collection) + else: + return CollectionWrapper(config, path, setup_new_collection) diff --git a/ankisyncd/sync_app.py b/ankisyncd/sync_app.py index c4a8212..30a6dbb 100644 --- a/ankisyncd/sync_app.py +++ b/ankisyncd/sync_app.py @@ -388,7 +388,7 @@ class SyncApp: valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download'] def __init__(self, config): - from ankisyncd.thread import getCollectionManager + from ankisyncd.thread import get_collection_manager self.data_root = os.path.abspath(config['data_root']) self.base_url = config['base_url'] @@ -401,7 +401,7 @@ class SyncApp: self.user_manager = get_user_manager(config) self.session_manager = get_session_manager(config) self.full_sync_manager = get_full_sync_manager(config) - self.collection_manager = getCollectionManager() + self.collection_manager = get_collection_manager(config) # make sure the base_url has a trailing slash if not self.base_url.endswith('/'): diff --git a/ankisyncd/thread.py b/ankisyncd/thread.py index 95b5225..63e77c0 100644 --- a/ankisyncd/thread.py +++ b/ankisyncd/thread.py @@ -1,4 +1,4 @@ -from ankisyncd.collection import CollectionWrapper, CollectionManager +from ankisyncd.collection import CollectionManager, get_collection_wrapper from threading import Thread from queue import Queue @@ -29,12 +29,12 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80): return repr(o) class ThreadingCollectionWrapper: - """Provides the same interface as CollectionWrapper, but it creates a new Thread to + """Provides the same interface as CollectionWrapper, but it creates a new Thread to interact with the collection.""" - def __init__(self, path, setup_new_collection=None): + def __init__(self, config, path, setup_new_collection=None): self.path = path - self.wrapper = CollectionWrapper(path, setup_new_collection) + self.wrapper = get_collection_wrapper(config, path, setup_new_collection) self.logger = logging.getLogger("ankisyncd." + str(self)) self._queue = Queue() @@ -156,8 +156,8 @@ class ThreadingCollectionManager(CollectionManager): collection_wrapper = ThreadingCollectionWrapper - def __init__(self): - super(ThreadingCollectionManager, self).__init__() + def __init__(self, config): + super(ThreadingCollectionManager, self).__init__(config) self.monitor_frequency = 15 self.monitor_inactivity = 90 @@ -202,11 +202,11 @@ class ThreadingCollectionManager(CollectionManager): collection_manager = None -def getCollectionManager(): +def get_collection_manager(config): """Return the global ThreadingCollectionManager for this process.""" global collection_manager if collection_manager is None: - collection_manager = ThreadingCollectionManager() + collection_manager = ThreadingCollectionManager(config) return collection_manager def shutdown(): diff --git a/tests/test_collection_wrappers.py b/tests/test_collection_wrappers.py new file mode 100644 index 0000000..37812f4 --- /dev/null +++ b/tests/test_collection_wrappers.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +import os +import unittest +import configparser + +from ankisyncd.collection import CollectionWrapper +from ankisyncd.collection import get_collection_wrapper + +import helpers.server_utils + +class FakeCollectionWrapper(CollectionWrapper): + def __init__(self, config, path, setup_new_collection=None): + self. _CollectionWrapper__col = None + pass + +class BadCollectionWrapper: + pass + +class CollectionWrapperFactoryTest(unittest.TestCase): + def test_get_collection_wrapper(self): + # Get absolute path to development ini file. + script_dir = os.path.dirname(os.path.realpath(__file__)) + ini_file_path = os.path.join(script_dir, + "assets", + "test.conf") + + # Create temporary files and dirs the server will use. + server_paths = helpers.server_utils.create_server_paths() + + config = configparser.ConfigParser() + config.read(ini_file_path) + path = os.path.realpath('fake/collection.anki2') + + # Use custom files and dirs in settings. Should be CollectionWrapper + config['sync_app'].update(server_paths) + self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper)) + + # A conf-specified CollectionWrapper is loaded + config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper') + self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper) + + # Should fail at load time if the class doesn't inherit from CollectionWrapper + config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper') + with self.assertRaises(TypeError): + pm = get_collection_wrapper(config['sync_app'], path) +