Load the CollectionWrapper from a factory method

This allows a class implementing CollectionWrapper's interface to be
added from config
This commit is contained in:
Anton Melser 2019-01-28 21:39:57 +08:00
parent 9ee9697582
commit fa89b0e0a2
4 changed files with 87 additions and 16 deletions

View File

@ -2,6 +2,10 @@ import anki
import anki.storage
import os, errno
import logging
logger = logging.getLogger("ankisyncd.collection")
class CollectionWrapper:
"""A simple wrapper around an anki.storage.Collection object.
@ -9,12 +13,12 @@ class CollectionWrapper:
This allows us to manage and refer to the collection, whether it's open or not. It
also provides a special "continuation passing" interface for executing functions
on the collection, which makes it easy to switch to a threading mode.
See ThreadingCollectionWrapper for a version that maintains a seperate thread for
interacting with the collection.
"""
def __init__(self, path, setup_new_collection=None):
def __init__(self, _config, path, setup_new_collection=None):
self.path = os.path.realpath(path)
self.username = os.path.basename(os.path.dirname(self.path))
self.setup_new_collection = setup_new_collection
@ -50,7 +54,7 @@ class CollectionWrapper:
# mkdir -p the path, because it might not exist
os.makedirs(os.path.dirname(self.path), exist_ok=True)
col = anki.storage.Collection(self.path)
col = self._get_collection()
# Do any special setup
if self.setup_new_collection is not None:
@ -58,11 +62,14 @@ class CollectionWrapper:
return col
def _get_collection(self):
return anki.storage.Collection(self.path)
def open(self):
"""Open the collection, or create it if it doesn't exist."""
if self.__col is None:
if os.path.exists(self.path):
self.__col = anki.storage.Collection(self.path)
self.__col = self._get_collection()
else:
self.__col = self.__create_collection()
@ -83,8 +90,9 @@ class CollectionManager:
collection_wrapper = CollectionWrapper
def __init__(self):
def __init__(self, config):
self.collections = {}
self.config = config
def get_collection(self, path, setup_new_collection=None):
"""Gets a CollectionWrapper for the given path."""
@ -94,7 +102,7 @@ class CollectionManager:
try:
col = self.collections[path]
except KeyError:
col = self.collections[path] = self.collection_wrapper(path, setup_new_collection)
col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
return col
@ -104,3 +112,19 @@ class CollectionManager:
del self.collections[path]
col.close()
def get_collection_wrapper(config, path, setup_new_collection = None):
if "collection_wrapper" in config and config["collection_wrapper"]:
logger.info("Found collection_wrapper in config, using {} for "
"user data persistence".format(config['collection_wrapper']))
import importlib
import inspect
module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not CollectionWrapper in inspect.getmro(class_):
raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t
inherit from CollectionWrapper''')
return class_(config, path, setup_new_collection)
else:
return CollectionWrapper(config, path, setup_new_collection)

View File

@ -388,7 +388,7 @@ class SyncApp:
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
def __init__(self, config):
from ankisyncd.thread import getCollectionManager
from ankisyncd.thread import get_collection_manager
self.data_root = os.path.abspath(config['data_root'])
self.base_url = config['base_url']
@ -401,7 +401,7 @@ class SyncApp:
self.user_manager = get_user_manager(config)
self.session_manager = get_session_manager(config)
self.full_sync_manager = get_full_sync_manager(config)
self.collection_manager = getCollectionManager()
self.collection_manager = get_collection_manager(config)
# make sure the base_url has a trailing slash
if not self.base_url.endswith('/'):

View File

@ -1,4 +1,4 @@
from ankisyncd.collection import CollectionWrapper, CollectionManager
from ankisyncd.collection import CollectionManager, get_collection_wrapper
from threading import Thread
from queue import Queue
@ -29,12 +29,12 @@ def short_repr(obj, logger=logging.getLogger(), maxlen=80):
return repr(o)
class ThreadingCollectionWrapper:
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
interact with the collection."""
def __init__(self, path, setup_new_collection=None):
def __init__(self, config, path, setup_new_collection=None):
self.path = path
self.wrapper = CollectionWrapper(path, setup_new_collection)
self.wrapper = get_collection_wrapper(config, path, setup_new_collection)
self.logger = logging.getLogger("ankisyncd." + str(self))
self._queue = Queue()
@ -156,8 +156,8 @@ class ThreadingCollectionManager(CollectionManager):
collection_wrapper = ThreadingCollectionWrapper
def __init__(self):
super(ThreadingCollectionManager, self).__init__()
def __init__(self, config):
super(ThreadingCollectionManager, self).__init__(config)
self.monitor_frequency = 15
self.monitor_inactivity = 90
@ -202,11 +202,11 @@ class ThreadingCollectionManager(CollectionManager):
collection_manager = None
def getCollectionManager():
def get_collection_manager(config):
"""Return the global ThreadingCollectionManager for this process."""
global collection_manager
if collection_manager is None:
collection_manager = ThreadingCollectionManager()
collection_manager = ThreadingCollectionManager(config)
return collection_manager
def shutdown():

View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
import os
import unittest
import configparser
from ankisyncd.collection import CollectionWrapper
from ankisyncd.collection import get_collection_wrapper
import helpers.server_utils
class FakeCollectionWrapper(CollectionWrapper):
def __init__(self, config, path, setup_new_collection=None):
self. _CollectionWrapper__col = None
pass
class BadCollectionWrapper:
pass
class CollectionWrapperFactoryTest(unittest.TestCase):
def test_get_collection_wrapper(self):
# Get absolute path to development ini file.
script_dir = os.path.dirname(os.path.realpath(__file__))
ini_file_path = os.path.join(script_dir,
"assets",
"test.conf")
# Create temporary files and dirs the server will use.
server_paths = helpers.server_utils.create_server_paths()
config = configparser.ConfigParser()
config.read(ini_file_path)
path = os.path.realpath('fake/collection.anki2')
# Use custom files and dirs in settings. Should be CollectionWrapper
config['sync_app'].update(server_paths)
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper))
# A conf-specified CollectionWrapper is loaded
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper')
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper)
# Should fail at load time if the class doesn't inherit from CollectionWrapper
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper')
with self.assertRaises(TypeError):
pm = get_collection_wrapper(config['sync_app'], path)