Load the CollectionWrapper from a factory method
This allows a class implementing CollectionWrapper's interface to be added from config
This commit is contained in:
parent
9ee9697582
commit
fa89b0e0a2
@ -2,6 +2,10 @@ import anki
|
|||||||
import anki.storage
|
import anki.storage
|
||||||
|
|
||||||
import os, errno
|
import os, errno
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("ankisyncd.collection")
|
||||||
|
|
||||||
|
|
||||||
class CollectionWrapper:
|
class CollectionWrapper:
|
||||||
"""A simple wrapper around an anki.storage.Collection object.
|
"""A simple wrapper around an anki.storage.Collection object.
|
||||||
@ -14,7 +18,7 @@ class CollectionWrapper:
|
|||||||
interacting with the collection.
|
interacting with the collection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path, setup_new_collection=None):
|
def __init__(self, _config, path, setup_new_collection=None):
|
||||||
self.path = os.path.realpath(path)
|
self.path = os.path.realpath(path)
|
||||||
self.username = os.path.basename(os.path.dirname(self.path))
|
self.username = os.path.basename(os.path.dirname(self.path))
|
||||||
self.setup_new_collection = setup_new_collection
|
self.setup_new_collection = setup_new_collection
|
||||||
@ -50,7 +54,7 @@ class CollectionWrapper:
|
|||||||
# mkdir -p the path, because it might not exist
|
# mkdir -p the path, because it might not exist
|
||||||
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
||||||
|
|
||||||
col = anki.storage.Collection(self.path)
|
col = self._get_collection()
|
||||||
|
|
||||||
# Do any special setup
|
# Do any special setup
|
||||||
if self.setup_new_collection is not None:
|
if self.setup_new_collection is not None:
|
||||||
@ -58,11 +62,14 @@ class CollectionWrapper:
|
|||||||
|
|
||||||
return col
|
return col
|
||||||
|
|
||||||
|
def _get_collection(self):
|
||||||
|
return anki.storage.Collection(self.path)
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
"""Open the collection, or create it if it doesn't exist."""
|
"""Open the collection, or create it if it doesn't exist."""
|
||||||
if self.__col is None:
|
if self.__col is None:
|
||||||
if os.path.exists(self.path):
|
if os.path.exists(self.path):
|
||||||
self.__col = anki.storage.Collection(self.path)
|
self.__col = self._get_collection()
|
||||||
else:
|
else:
|
||||||
self.__col = self.__create_collection()
|
self.__col = self.__create_collection()
|
||||||
|
|
||||||
@ -83,8 +90,9 @@ class CollectionManager:
|
|||||||
|
|
||||||
collection_wrapper = CollectionWrapper
|
collection_wrapper = CollectionWrapper
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, config):
|
||||||
self.collections = {}
|
self.collections = {}
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def get_collection(self, path, setup_new_collection=None):
|
def get_collection(self, path, setup_new_collection=None):
|
||||||
"""Gets a CollectionWrapper for the given path."""
|
"""Gets a CollectionWrapper for the given path."""
|
||||||
@ -94,7 +102,7 @@ class CollectionManager:
|
|||||||
try:
|
try:
|
||||||
col = self.collections[path]
|
col = self.collections[path]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
col = self.collections[path] = self.collection_wrapper(path, setup_new_collection)
|
col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
|
||||||
|
|
||||||
return col
|
return col
|
||||||
|
|
||||||
@ -104,3 +112,19 @@ class CollectionManager:
|
|||||||
del self.collections[path]
|
del self.collections[path]
|
||||||
col.close()
|
col.close()
|
||||||
|
|
||||||
|
def get_collection_wrapper(config, path, setup_new_collection = None):
|
||||||
|
if "collection_wrapper" in config and config["collection_wrapper"]:
|
||||||
|
logger.info("Found collection_wrapper in config, using {} for "
|
||||||
|
"user data persistence".format(config['collection_wrapper']))
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
|
||||||
|
module = importlib.import_module(module_name.strip())
|
||||||
|
class_ = getattr(module, class_name.strip())
|
||||||
|
|
||||||
|
if not CollectionWrapper in inspect.getmro(class_):
|
||||||
|
raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t
|
||||||
|
inherit from CollectionWrapper''')
|
||||||
|
return class_(config, path, setup_new_collection)
|
||||||
|
else:
|
||||||
|
return CollectionWrapper(config, path, setup_new_collection)
|
||||||
|
|||||||
@ -388,7 +388,7 @@ class SyncApp:
|
|||||||
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
|
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
from ankisyncd.thread import getCollectionManager
|
from ankisyncd.thread import get_collection_manager
|
||||||
|
|
||||||
self.data_root = os.path.abspath(config['data_root'])
|
self.data_root = os.path.abspath(config['data_root'])
|
||||||
self.base_url = config['base_url']
|
self.base_url = config['base_url']
|
||||||
@ -401,7 +401,7 @@ class SyncApp:
|
|||||||
self.user_manager = get_user_manager(config)
|
self.user_manager = get_user_manager(config)
|
||||||
self.session_manager = get_session_manager(config)
|
self.session_manager = get_session_manager(config)
|
||||||
self.full_sync_manager = get_full_sync_manager(config)
|
self.full_sync_manager = get_full_sync_manager(config)
|
||||||
self.collection_manager = getCollectionManager()
|
self.collection_manager = get_collection_manager(config)
|
||||||
|
|
||||||
# make sure the base_url has a trailing slash
|
# make sure the base_url has a trailing slash
|
||||||
if not self.base_url.endswith('/'):
|
if not self.base_url.endswith('/'):
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from ankisyncd.collection import CollectionWrapper, CollectionManager
|
from ankisyncd.collection import CollectionManager, get_collection_wrapper
|
||||||
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -32,9 +32,9 @@ class ThreadingCollectionWrapper:
|
|||||||
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
|
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
|
||||||
interact with the collection."""
|
interact with the collection."""
|
||||||
|
|
||||||
def __init__(self, path, setup_new_collection=None):
|
def __init__(self, config, path, setup_new_collection=None):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.wrapper = CollectionWrapper(path, setup_new_collection)
|
self.wrapper = get_collection_wrapper(config, path, setup_new_collection)
|
||||||
self.logger = logging.getLogger("ankisyncd." + str(self))
|
self.logger = logging.getLogger("ankisyncd." + str(self))
|
||||||
|
|
||||||
self._queue = Queue()
|
self._queue = Queue()
|
||||||
@ -156,8 +156,8 @@ class ThreadingCollectionManager(CollectionManager):
|
|||||||
|
|
||||||
collection_wrapper = ThreadingCollectionWrapper
|
collection_wrapper = ThreadingCollectionWrapper
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, config):
|
||||||
super(ThreadingCollectionManager, self).__init__()
|
super(ThreadingCollectionManager, self).__init__(config)
|
||||||
|
|
||||||
self.monitor_frequency = 15
|
self.monitor_frequency = 15
|
||||||
self.monitor_inactivity = 90
|
self.monitor_inactivity = 90
|
||||||
@ -202,11 +202,11 @@ class ThreadingCollectionManager(CollectionManager):
|
|||||||
|
|
||||||
collection_manager = None
|
collection_manager = None
|
||||||
|
|
||||||
def getCollectionManager():
|
def get_collection_manager(config):
|
||||||
"""Return the global ThreadingCollectionManager for this process."""
|
"""Return the global ThreadingCollectionManager for this process."""
|
||||||
global collection_manager
|
global collection_manager
|
||||||
if collection_manager is None:
|
if collection_manager is None:
|
||||||
collection_manager = ThreadingCollectionManager()
|
collection_manager = ThreadingCollectionManager(config)
|
||||||
return collection_manager
|
return collection_manager
|
||||||
|
|
||||||
def shutdown():
|
def shutdown():
|
||||||
|
|||||||
47
tests/test_collection_wrappers.py
Normal file
47
tests/test_collection_wrappers.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
from ankisyncd.collection import CollectionWrapper
|
||||||
|
from ankisyncd.collection import get_collection_wrapper
|
||||||
|
|
||||||
|
import helpers.server_utils
|
||||||
|
|
||||||
|
class FakeCollectionWrapper(CollectionWrapper):
|
||||||
|
def __init__(self, config, path, setup_new_collection=None):
|
||||||
|
self. _CollectionWrapper__col = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BadCollectionWrapper:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class CollectionWrapperFactoryTest(unittest.TestCase):
|
||||||
|
def test_get_collection_wrapper(self):
|
||||||
|
# Get absolute path to development ini file.
|
||||||
|
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
ini_file_path = os.path.join(script_dir,
|
||||||
|
"assets",
|
||||||
|
"test.conf")
|
||||||
|
|
||||||
|
# Create temporary files and dirs the server will use.
|
||||||
|
server_paths = helpers.server_utils.create_server_paths()
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(ini_file_path)
|
||||||
|
path = os.path.realpath('fake/collection.anki2')
|
||||||
|
|
||||||
|
# Use custom files and dirs in settings. Should be CollectionWrapper
|
||||||
|
config['sync_app'].update(server_paths)
|
||||||
|
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path) == CollectionWrapper))
|
||||||
|
|
||||||
|
# A conf-specified CollectionWrapper is loaded
|
||||||
|
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.FakeCollectionWrapper')
|
||||||
|
self.assertTrue(type(get_collection_wrapper(config['sync_app'], path)) == FakeCollectionWrapper)
|
||||||
|
|
||||||
|
# Should fail at load time if the class doesn't inherit from CollectionWrapper
|
||||||
|
config.set("sync_app", "collection_wrapper", 'test_collection_wrappers.BadCollectionWrapper')
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
pm = get_collection_wrapper(config['sync_app'], path)
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user