Load the CollectionWrapper from a factory method
This allows a class implementing CollectionWrapper's interface to be added from config
This commit is contained in:
parent
9ee9697582
commit
fa89b0e0a2
@ -2,6 +2,10 @@ import anki
|
||||
import anki.storage
|
||||
|
||||
import os, errno
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("ankisyncd.collection")
|
||||
|
||||
|
||||
class CollectionWrapper:
|
||||
"""A simple wrapper around an anki.storage.Collection object.
|
||||
@ -9,12 +13,12 @@ class CollectionWrapper:
|
||||
This allows us to manage and refer to the collection, whether it's open or not. It
|
||||
also provides a special "continuation passing" interface for executing functions
|
||||
on the collection, which makes it easy to switch to a threading mode.
|
||||
|
||||
|
||||
See ThreadingCollectionWrapper for a version that maintains a seperate thread for
|
||||
interacting with the collection.
|
||||
"""
|
||||
|
||||
def __init__(self, path, setup_new_collection=None):
|
||||
def __init__(self, _config, path, setup_new_collection=None):
|
||||
self.path = os.path.realpath(path)
|
||||
self.username = os.path.basename(os.path.dirname(self.path))
|
||||
self.setup_new_collection = setup_new_collection
|
||||
@ -50,7 +54,7 @@ class CollectionWrapper:
|
||||
# mkdir -p the path, because it might not exist
|
||||
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
||||
|
||||
col = anki.storage.Collection(self.path)
|
||||
col = self._get_collection()
|
||||
|
||||
# Do any special setup
|
||||
if self.setup_new_collection is not None:
|
||||
@ -58,11 +62,14 @@ class CollectionWrapper:
|
||||
|
||||
return col
|
||||
|
||||
def _get_collection(self):
|
||||
return anki.storage.Collection(self.path)
|
||||
|
||||
def open(self):
|
||||
"""Open the collection, or create it if it doesn't exist."""
|
||||
if self.__col is None:
|
||||
if os.path.exists(self.path):
|
||||
self.__col = anki.storage.Collection(self.path)
|
||||
self.__col = self._get_collection()
|
||||
else:
|
||||
self.__col = self.__create_collection()
|
||||
|
||||
@ -83,8 +90,9 @@ class CollectionManager:
|
||||
|
||||
collection_wrapper = CollectionWrapper
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config):
|
||||
self.collections = {}
|
||||
self.config = config
|
||||
|
||||
def get_collection(self, path, setup_new_collection=None):
|
||||
"""Gets a CollectionWrapper for the given path."""
|
||||
@ -94,7 +102,7 @@ class CollectionManager:
|
||||
try:
|
||||
col = self.collections[path]
|
||||
except KeyError:
|
||||
col = self.collections[path] = self.collection_wrapper(path, setup_new_collection)
|
||||
col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
|
||||
|
||||
return col
|
||||
|
||||
@ -104,3 +112,19 @@ class CollectionManager:
|
||||
del self.collections[path]
|
||||
col.close()
|
||||
|
||||
def get_collection_wrapper(config, path, setup_new_collection = None):
|
||||
if "collection_wrapper" in config and config["collection_wrapper"]:
|
||||
logger.info("Found collection_wrapper in config, using {} for "
|
||||
"user data persistence".format(config['collection_wrapper']))
|
||||
import importlib
|
||||
import inspect
|
||||
module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
|
||||
module = importlib.import_module(module_name.strip())
|
||||
class_ = getattr(module, class_name.strip())
|
||||
|
||||
if not CollectionWrapper in inspect.getmro(class_):
|
||||
raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t
|
||||
inherit from CollectionWrapper''')
|
||||
return class_(config, path, setup_new_collection)
|
||||
else:
|
||||
return CollectionWrapper(config, path, setup_new_collection)
|
||||
|
||||
@ -388,7 +388,7 @@ class SyncApp:
|
||||
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
|
||||
|
||||
def __init__(self, config):
|
||||
from ankisyncd.thread import getCollectionManager
|
||||
from ankisyncd.thread import get_collection_manager
|
||||
|
||||
self.data_root = os.path.abspath(config['data_root'])
|
||||
self.base_url = config['base_url']
|
||||
@ -401,7 +401,7 @@ class SyncApp:
|
||||
self.user_manager = get_user_manager(config)
|
||||
self.session_manager = get_session_manager(config)
|
||||
self.full_sync_manager = get_full_sync_manager(config)
|
||||
self.collection_manager = getCollectionManager()
|
||||
self.collection_manager = get_collection_manager(config)
|
||||
|
||||
# make sure the base_url has a trailing slash
|
||||
if not self.base_url.endswith('/'):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from ankisyncd.collection import CollectionWrapper, CollectionManager
|
||||
from ankisyncd.collection import CollectionManager, get_collection_wrapper
|
||||
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
@ -29,12 +29,12 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80):
|
||||
return repr(o)
|
||||
|
||||
class ThreadingCollectionWrapper:
|
||||
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
|
||||
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
|
||||
interact with the collection."""
|
||||
|
||||
def __init__(self, path, setup_new_collection=None):
|
||||
def __init__(self, config, path, setup_new_collection=None):
|
||||
self.path = path
|
||||
self.wrapper = CollectionWrapper(path, setup_new_collection)
|
||||
self.wrapper = get_collection_wrapper(config, path, setup_new_collection)
|
||||
self.logger = logging.getLogger("ankisyncd." + str(self))
|
||||
|
||||
self._queue = Queue()
|
||||
@ -156,8 +156,8 @@ class ThreadingCollectionManager(CollectionManager):
|
||||
|
||||
collection_wrapper = ThreadingCollectionWrapper
|
||||
|
||||
def __init__(self):
|
||||
super(ThreadingCollectionManager, self).__init__()
|
||||
def __init__(self, config):
|
||||
super(ThreadingCollectionManager, self).__init__(config)
|
||||
|
||||
self.monitor_frequency = 15
|
||||
self.monitor_inactivity = 90
|
||||
@ -202,11 +202,11 @@ class ThreadingCollectionManager(CollectionManager):
|
||||
|
||||
collection_manager = None
|
||||
|
||||
def getCollectionManager():
|
||||
def get_collection_manager(config):
|
||||
"""Return the global ThreadingCollectionManager for this process."""
|
||||
global collection_manager
|
||||
if collection_manager is None:
|
||||
collection_manager = ThreadingCollectionManager()
|
||||
collection_manager = ThreadingCollectionManager(config)
|
||||
return collection_manager
|
||||
|
||||
def shutdown():
|
||||
|
||||
47
tests/test_collection_wrappers.py
Normal file
47
tests/test_collection_wrappers.py
Normal file
@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import configparser
|
||||
|
||||
from ankisyncd.collection import CollectionWrapper
|
||||
from ankisyncd.collection import get_collection_wrapper
|
||||
|
||||
import helpers.server_utils
|
||||
|
||||
class FakeCollectionWrapper(CollectionWrapper):
|
||||
def __init__(self, config, path, setup_new_collection=None):
|
||||
self. _CollectionWrapper__col = None
|
||||
pass
|
||||
|
||||
class BadCollectionWrapper:
|
||||
pass
|
||||
|
||||
class CollectionWrapperFactoryTest(unittest.TestCase):
|
||||
def test_get_collection_wrapper(self):
|
||||
# Get absolute path to development ini file.
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
ini_file_path = os.path.join(script_dir,
|
||||
"assets",
|
||||
"test.conf")
|
||||
|
||||
# Create temporary files and dirs the server will use.
|
||||
server_paths = helpers.server_utils.create_server_paths()
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(ini_file_path)
|
||||
path = os.path.realpath('fake/collection.anki2')
|
||||
|
||||
# Use custom files and dirs in settings. Should be CollectionWrapper
|
||||
config['sync_app'].update(server_paths)
|
||||
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper))
|
||||
|
||||
# A conf-specified CollectionWrapper is loaded
|
||||
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper')
|
||||
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper)
|
||||
|
||||
# Should fail at load time if the class doesn't inherit from CollectionWrapper
|
||||
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper')
|
||||
with self.assertRaises(TypeError):
|
||||
pm = get_collection_wrapper(config['sync_app'], path)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user