Move the upload/download sqlite3 file logic to a manager
Also add a factory method so the manager can be controlled via config
This commit is contained in:
parent
50cc6a12d9
commit
9ee9697582
59
ankisyncd/full_sync.py
Normal file
59
ankisyncd/full_sync.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
from sqlite3 import dbapi2 as sqlite
|
||||||
|
|
||||||
|
import anki.db
|
||||||
|
|
||||||
|
class FullSyncManager:
|
||||||
|
def upload(self, col, data, session):
|
||||||
|
# Verify integrity of the received database file before replacing our
|
||||||
|
# existing db.
|
||||||
|
temp_db_path = session.get_collection_path() + ".tmp"
|
||||||
|
with open(temp_db_path, 'wb') as f:
|
||||||
|
f.write(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with anki.db.DB(temp_db_path) as test_db:
|
||||||
|
if test_db.scalar("pragma integrity_check") != "ok":
|
||||||
|
raise HTTPBadRequest("Integrity check failed for uploaded "
|
||||||
|
"collection database file.")
|
||||||
|
except sqlite.Error as e:
|
||||||
|
raise HTTPBadRequest("Uploaded collection database file is "
|
||||||
|
"corrupt.")
|
||||||
|
|
||||||
|
# Overwrite existing db.
|
||||||
|
col.close()
|
||||||
|
try:
|
||||||
|
os.rename(temp_db_path, session.get_collection_path())
|
||||||
|
finally:
|
||||||
|
col.reopen()
|
||||||
|
col.load()
|
||||||
|
|
||||||
|
return "OK"
|
||||||
|
|
||||||
|
|
||||||
|
def download(self, col, session):
|
||||||
|
col.close()
|
||||||
|
try:
|
||||||
|
data = open(session.get_collection_path(), 'rb').read()
|
||||||
|
finally:
|
||||||
|
col.reopen()
|
||||||
|
col.load()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_sync_manager(config):
|
||||||
|
if "full_sync_manager" in config and config["full_sync_manager"]: # load from config
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
module_name, class_name = config['full_sync_manager'].rsplit('.', 1)
|
||||||
|
module = importlib.import_module(module_name.strip())
|
||||||
|
class_ = getattr(module, class_name.strip())
|
||||||
|
|
||||||
|
if not FullSyncManager in inspect.getmro(class_):
|
||||||
|
raise TypeError('''"full_sync_manager" found in the conf file but it doesn''t
|
||||||
|
inherit from FullSyncManager''')
|
||||||
|
return class_(config)
|
||||||
|
else:
|
||||||
|
return FullSyncManager()
|
||||||
@ -42,6 +42,7 @@ from anki.consts import REM_CARD, REM_NOTE
|
|||||||
|
|
||||||
from ankisyncd.users import get_user_manager
|
from ankisyncd.users import get_user_manager
|
||||||
from ankisyncd.sessions import get_session_manager
|
from ankisyncd.sessions import get_session_manager
|
||||||
|
from ankisyncd.full_sync import get_full_sync_manager
|
||||||
|
|
||||||
logger = logging.getLogger("ankisyncd")
|
logger = logging.getLogger("ankisyncd")
|
||||||
|
|
||||||
@ -399,6 +400,7 @@ class SyncApp:
|
|||||||
|
|
||||||
self.user_manager = get_user_manager(config)
|
self.user_manager = get_user_manager(config)
|
||||||
self.session_manager = get_session_manager(config)
|
self.session_manager = get_session_manager(config)
|
||||||
|
self.full_sync_manager = get_full_sync_manager(config)
|
||||||
self.collection_manager = getCollectionManager()
|
self.collection_manager = getCollectionManager()
|
||||||
|
|
||||||
# make sure the base_url has a trailing slash
|
# make sure the base_url has a trailing slash
|
||||||
@ -482,37 +484,13 @@ class SyncApp:
|
|||||||
def operation_upload(self, col, data, session):
|
def operation_upload(self, col, data, session):
|
||||||
# Verify integrity of the received database file before replacing our
|
# Verify integrity of the received database file before replacing our
|
||||||
# existing db.
|
# existing db.
|
||||||
temp_db_path = session.get_collection_path() + ".tmp"
|
|
||||||
with open(temp_db_path, 'wb') as f:
|
|
||||||
f.write(data)
|
|
||||||
|
|
||||||
try:
|
return self.full_sync_manager.upload(col, data, session)
|
||||||
with anki.db.DB(temp_db_path) as test_db:
|
|
||||||
if test_db.scalar("pragma integrity_check") != "ok":
|
|
||||||
raise HTTPBadRequest("Integrity check failed for uploaded "
|
|
||||||
"collection database file.")
|
|
||||||
except sqlite.Error as e:
|
|
||||||
raise HTTPBadRequest("Uploaded collection database file is "
|
|
||||||
"corrupt.")
|
|
||||||
|
|
||||||
# Overwrite existing db.
|
|
||||||
col.close()
|
|
||||||
try:
|
|
||||||
os.rename(temp_db_path, session.get_collection_path())
|
|
||||||
finally:
|
|
||||||
col.reopen()
|
|
||||||
col.load()
|
|
||||||
|
|
||||||
return "OK"
|
|
||||||
|
|
||||||
def operation_download(self, col, session):
|
def operation_download(self, col, session):
|
||||||
col.close()
|
# returns user data (not media) as a sqlite3 database for replacing their
|
||||||
try:
|
# local copy in Anki
|
||||||
data = open(session.get_collection_path(), 'rb').read()
|
return self.full_sync_manager.download(col, session)
|
||||||
finally:
|
|
||||||
col.reopen()
|
|
||||||
col.load()
|
|
||||||
return data
|
|
||||||
|
|
||||||
@wsgify
|
@wsgify
|
||||||
def __call__(self, req):
|
def __call__(self, req):
|
||||||
|
|||||||
44
tests/test_full_sync.py
Normal file
44
tests/test_full_sync.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
from ankisyncd.full_sync import FullSyncManager, get_full_sync_manager
|
||||||
|
|
||||||
|
import helpers.server_utils
|
||||||
|
|
||||||
|
class FakeFullSyncManager(FullSyncManager):
|
||||||
|
def __init__(self, config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BadFullSyncManager:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FullSyncManagerFactoryTest(unittest.TestCase):
|
||||||
|
def test_get_full_sync_manager(self):
|
||||||
|
# Get absolute path to development ini file.
|
||||||
|
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
ini_file_path = os.path.join(script_dir,
|
||||||
|
"assets",
|
||||||
|
"test.conf")
|
||||||
|
|
||||||
|
# Create temporary files and dirs the server will use.
|
||||||
|
server_paths = helpers.server_utils.create_server_paths()
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(ini_file_path)
|
||||||
|
|
||||||
|
# Use custom files and dirs in settings. Should be PersistenceManager
|
||||||
|
config['sync_app'].update(server_paths)
|
||||||
|
self.assertTrue(type(get_full_sync_manager(config['sync_app']) == FullSyncManager))
|
||||||
|
|
||||||
|
# A conf-specified FullSyncManager is loaded
|
||||||
|
config.set("sync_app", "full_sync_manager", 'test_full_sync.FakeFullSyncManager')
|
||||||
|
self.assertTrue(type(get_full_sync_manager(config['sync_app'])) == FakeFullSyncManager)
|
||||||
|
|
||||||
|
# Should fail at load time if the class doesn't inherit from FullSyncManager
|
||||||
|
config.set("sync_app", "full_sync_manager", 'test_full_sync.BadFullSyncManager')
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
pm = get_full_sync_manager(config['sync_app'])
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user