# anki-sync-server/src/ankisyncd/sync_app.py

# ankisyncd - A personal Anki sync server
# Copyright (C) 2013 David Snopek
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import gzip
import hashlib
import io
import json
import logging
import os
import random
import re
import string
import sys
import time
import unicodedata
import zipfile

from configparser import ConfigParser
from sqlite3 import dbapi2 as sqlite

from webob import Response
from webob.dec import wsgify
from webob.exc import *

import anki.db
import anki.utils
from anki.consts import REM_CARD, REM_NOTE

from ankisyncd.users import get_user_manager
from ankisyncd.sessions import get_session_manager
from ankisyncd.full_sync import get_full_sync_manager
from .sync import Syncer, SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT

logger = logging.getLogger("ankisyncd")


class SyncCollectionHandler(Syncer):
    operations = ['meta', 'applyChanges', 'start', 'applyGraves', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']

    def __init__(self, col, session):
        # Pass col only, so that 'server' (the 3rd argument) can't get set.
        super().__init__(col)
        self.session = session

    @staticmethod
    def _old_client(cv):
        if not cv:
            return False

        note = {"alpha": 0, "beta": 0, "rc": 0}
        client, version, platform = cv.split(',')

        for name in note.keys():
            if name in version:
                vs = version.split(name)
                version = vs[0]
                note[name] = int(vs[-1])

        # Convert the version string, ignoring non-numeric suffixes like in
        # beta versions of Anki.
        version_nosuffix = re.sub(r'[^0-9.].*$', '', version)
        version_int = [int(x) for x in version_nosuffix.split('.')]

        if client == 'ankidesktop':
            return version_int < [2, 0, 27]
        elif client == 'ankidroid':
            if version_int == [2, 3]:
                if note["alpha"]:
                    return note["alpha"] < 4
            else:
                return version_int < [2, 2, 3]
        else:  # unknown client, assume current version
            return False
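
    # Illustration (not in the original source): cv has the form
    # "client,version,platform", e.g. "ankidesktop,2.0.26,lin" parses to
    # version_int == [2, 0, 26], which is < [2, 0, 27], so _old_client
    # returns True and meta() replies with HTTP 501.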

    def meta(self, v=None, cv=None):
        if self._old_client(cv):
            return Response(status=501)  # client needs upgrade
        if v > SYNC_VER:
            return {"cont": False, "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(v, SYNC_VER)}
        if v < 9 and self.col.schedVer() >= 2:
            return {"cont": False, "msg": "Your client doesn't support the v{} scheduler.".format(self.col.schedVer())}

        # Make sure the media database is open!
        self.col.media.connect()

        return {
            'mod': self.col.mod,
            'scm': self.col.scm,
            'usn': self.col._usn,
            'ts': anki.utils.intTime(),
            'musn': self.col.media.lastUsn(),
            'uname': self.session.name,
            'msg': '',
            'cont': True,
            'hostNum': 0,
        }
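
    # For illustration (not in the original): a successful meta reply
    # serializes to JSON shaped like {"mod": ..., "scm": ..., "usn": 52,
    # "ts": ..., "musn": 10, "uname": "alice", "msg": "", "cont": true,
    # "hostNum": 0} (values made up), where mod/scm are the collection and
    # schema modification times and usn/musn are the collection and media
    # update sequence numbers.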

    def usnLim(self):
        return "usn >= %d" % self.minUsn

    # ankidesktop >=2.1rc2 sends graves in applyGraves, but still expects
    # server-side deletions to be returned by start
    def start(self, minUsn, lnewer, graves={"cards": [], "notes": [], "decks": []}, offset=None):
        if offset is not None:
            raise NotImplementedError('You are using the experimental V2 scheduler, which is not supported by the server.')
        self.maxUsn = self.col._usn
        self.minUsn = minUsn
        self.lnewer = not lnewer
        lgraves = self.removed()
        self.remove(graves)
        return lgraves
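
    # Example graves payload with made-up ids: {"cards": [1596747602584],
    # "notes": [1596747602583], "decks": []}; remove() applies the client's
    # deletions, while removed() collects the server's own since minUsn.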

    def applyGraves(self, chunk):
        self.remove(chunk)

    def applyChanges(self, changes):
        self.rchg = changes
        lchg = self.changes()
        # merge our side before returning
        self.mergeChanges(lchg, self.rchg)
        return lchg

    def sanityCheck2(self, client, full=None):
        server = self.sanityCheck()
        if client != server:
            logger.info(
                f"sanity check failed with server: {server} client: {client}"
            )
            return dict(status="bad", c=client, s=server)
        return dict(status="ok")

    def finish(self, mod=None):
        return super().finish(anki.utils.intTime(1000))

    # This function had to be put here in its entirety because Syncer.removed()
    # doesn't use self.usnLim() (which we override in this class) in queries.
    # "usn=-1" has been replaced with "usn >= ?", self.minUsn by hand.
    def removed(self):
        cards = []
        notes = []
        decks = []

        curs = self.col.db.execute(
            "select oid, type from graves where usn >= ?", self.minUsn)

        for oid, type in curs:
            if type == REM_CARD:
                cards.append(oid)
            elif type == REM_NOTE:
                notes.append(oid)
            else:
                decks.append(oid)

        return dict(cards=cards, notes=notes, decks=decks)

    def getModels(self):
        return [m for m in self.col.models.all() if m['usn'] >= self.minUsn]

    def getDecks(self):
        return [
            [g for g in self.col.decks.all() if g['usn'] >= self.minUsn],
            [g for g in self.col.decks.allConf() if g['usn'] >= self.minUsn]
        ]

    def getTags(self):
        return [t for t, usn in self.col.tags.allItems()
                if usn >= self.minUsn]
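
    # Note (added for clarity): minUsn is set from the client in start();
    # getModels/getDecks/getTags above return only objects with usn >= minUsn,
    # i.e. the server-side changes the client has not seen yet.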


class SyncMediaHandler:
    operations = ['begin', 'mediaChanges', 'mediaSanity', 'uploadChanges', 'downloadFiles']

    def __init__(self, col, session):
        self.col = col
        self.session = session

    def begin(self, skey):
        return {
            'data': {
                'sk': skey,
                'usn': self.col.media.lastUsn(),
            },
            'err': '',
        }

    def uploadChanges(self, data):
        """
        The zip file contains files the client hasn't synced with the server
        yet ('dirty'), and info on files it has deleted from its own media dir.
        """
        with zipfile.ZipFile(io.BytesIO(data), "r") as z:
            self._check_zip_data(z)
            processed_count = self._adopt_media_changes_from_zip(z)

        return {
            'data': [processed_count, self.col.media.lastUsn()],
            'err': '',
        }

    @staticmethod
    def _check_zip_data(zip_file):
        max_zip_size = 100*1024*1024
        max_meta_file_size = 100000

        meta_file_size = zip_file.getinfo("_meta").file_size
        sum_file_sizes = sum(info.file_size for info in zip_file.infolist())

        if meta_file_size > max_meta_file_size:
            raise ValueError("Zip file's metadata file is larger than %s "
                             "Bytes." % max_meta_file_size)
        elif sum_file_sizes > max_zip_size:
            raise ValueError("Zip file contents are larger than %s Bytes." %
                             max_zip_size)
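
    # In concrete terms: max_zip_size above allows 100 MiB (100*1024*1024 =
    # 104857600 bytes) for the whole upload, while max_meta_file_size caps the
    # "_meta" entry at 100000 bytes.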

    def _adopt_media_changes_from_zip(self, zip_file):
        """
        Adds and removes files to/from the database and media directory
        according to the data in zip file zipData.
        """
        # Get meta info first.
        meta = json.loads(zip_file.read("_meta").decode())

        # Remove media files that were removed on the client.
        media_to_remove = []
        for normname, ordinal in meta:
            if ordinal == '':
                media_to_remove.append(self._normalize_filename(normname))

        # Add media files that were added on the client.
        media_to_add = []
        usn = self.col.media.lastUsn()
        oldUsn = usn
        for i in zip_file.infolist():
            if i.filename == "_meta":  # Ignore previously retrieved metadata.
                continue

            file_data = zip_file.read(i)
            csum = anki.utils.checksum(file_data)
            filename = self._normalize_filename(meta[int(i.filename)][0])
            file_path = os.path.join(self.col.media.dir(), filename)

            # Save file to media directory.
            with open(file_path, 'wb') as f:
                f.write(file_data)

            usn += 1
            media_to_add.append((filename, usn, csum))

        # We count all files we are to remove, even if we don't have them in
        # our media directory and our db doesn't know about them.
        processed_count = len(media_to_remove) + len(media_to_add)
        assert len(meta) == processed_count  # sanity check

        if media_to_remove:
            self._remove_media_files(media_to_remove)
        if media_to_add:
            self.col.media.addMedia(media_to_add)

        assert self.col.media.lastUsn() == oldUsn + processed_count  # TODO: move to some unit test

        return processed_count
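
    # Illustrative "_meta" payload, inferred from the parsing above: a JSON
    # list of [filename, zipname] pairs such as [["added.jpg", "0"],
    # ["gone.jpg", ""]], where an empty zipname marks a client-side deletion
    # and a numeric zipname names the archive member holding the file's bytes.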

    @staticmethod
    def _normalize_filename(filename):
        """
        Performs unicode normalization for file names. Logic taken from Anki's
        MediaManager.addFilesFromZip().
        """
        # Normalize name for platform.
        if anki.utils.isMac:  # global
            filename = unicodedata.normalize("NFD", filename)
        else:
            filename = unicodedata.normalize("NFC", filename)

        return filename

    def _remove_media_files(self, filenames):
        """
        Marks all files in list filenames as deleted and removes them from the
        media directory.
        """
        logger.debug('Removing %d files from media dir.' % len(filenames))
        for filename in filenames:
            try:
                self.col.media.syncDelete(filename)
            except OSError as err:
                logger.error("Error when removing file '%s' from media dir: "
                             "%s" % (filename, str(err)))

    def downloadFiles(self, files):
        flist = {}
        cnt = 0
        sz = 0
        f = io.BytesIO()

        with zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED) as z:
            for fname in files:
                z.write(os.path.join(self.col.media.dir(), fname), str(cnt))
                flist[str(cnt)] = fname
                sz += os.path.getsize(os.path.join(self.col.media.dir(), fname))
                if sz > SYNC_ZIP_SIZE or cnt > SYNC_ZIP_COUNT:
                    break
                cnt += 1

            z.writestr("_meta", json.dumps(flist))

        return f.getvalue()
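
    # The "_meta" entry written above maps archive member names (simple
    # counters) back to real file names, e.g. {"0": "a.jpg", "1": "b.jpg"}
    # (illustrative names); the client uses it when unpacking the zip.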

    def mediaChanges(self, lastUsn):
        result = []
        server_lastUsn = self.col.media.lastUsn()

        if lastUsn < server_lastUsn or lastUsn == 0:
            for fname, usn, csum in self.col.media.changes(lastUsn):
                result.append([fname, usn, csum])

        # anki assumes server_lastUsn == result[-1][1]
        # ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
        result.reverse()

        return {'data': result, 'err': ''}

    def mediaSanity(self, local=None):
        if self.col.media.mediaCount() == local:
            result = "OK"
        else:
            result = "FAILED"

        return {'data': result, 'err': ''}


class SyncUserSession:
    def __init__(self, name, path, collection_manager, setup_new_collection=None):
        self.skey = self._generate_session_key()
        self.name = name
        self.path = path
        self.collection_manager = collection_manager
        self.setup_new_collection = setup_new_collection
        self.version = None
        self.client_version = None
        self.created = time.time()
        self.collection_handler = None
        self.media_handler = None

        # make sure the user path exists
        if not os.path.exists(path):
            os.mkdir(path)

    def _generate_session_key(self):
        return anki.utils.checksum(str(random.random()))[:8]

    def get_collection_path(self):
        return os.path.realpath(os.path.join(self.path, 'collection.anki2'))

    def get_thread(self):
        return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)

    def get_handler_for_operation(self, operation, col):
        if operation in SyncCollectionHandler.operations:
            attr, handler_class = 'collection_handler', SyncCollectionHandler
        elif operation in SyncMediaHandler.operations:
            attr, handler_class = 'media_handler', SyncMediaHandler
        else:
            raise Exception("no handler for {}".format(operation))

        if getattr(self, attr) is None:
            setattr(self, attr, handler_class(col, self))

        handler = getattr(self, attr)
        # The col object may actually be new now! This happens when we close a collection
        # for inactivity and then later re-open it (creating a new Collection object).
        handler.col = col
        return handler


class SyncApp:
    valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']

    def __init__(self, config):
        from ankisyncd.thread import get_collection_manager

        self.data_root = os.path.abspath(config['data_root'])
        self.base_url = config['base_url']
        self.base_media_url = config['base_media_url']
        self.setup_new_collection = None

        self.prehooks = {}
        self.posthooks = {}

        self.user_manager = get_user_manager(config)
        self.session_manager = get_session_manager(config)
        self.full_sync_manager = get_full_sync_manager(config)
        self.collection_manager = get_collection_manager(config)

        # make sure the base_url has a trailing slash
        if not self.base_url.endswith('/'):
            self.base_url += '/'
        if not self.base_media_url.endswith('/'):
            self.base_media_url += '/'

    # backwards compat
    @property
    def hook_pre_sync(self):
        return self.prehooks.get("start")

    @hook_pre_sync.setter
    def hook_pre_sync(self, value):
        self.prehooks['start'] = value

    @property
    def hook_post_sync(self):
        return self.posthooks.get("finish")

    @hook_post_sync.setter
    def hook_post_sync(self, value):
        self.posthooks['finish'] = value

    @property
    def hook_upload(self):
        return self.prehooks.get("upload")

    @hook_upload.setter
    def hook_upload(self, value):
        self.prehooks['upload'] = value

    @property
    def hook_download(self):
        return self.posthooks.get("download")

    @hook_download.setter
    def hook_download(self, value):
        self.posthooks['download'] = value

    def generateHostKey(self, username):
        """Generates a new host key to be used by the given username to identify their session.
        This value is random."""
        chars = string.ascii_letters + string.digits
        val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]).encode()
        return hashlib.md5(val).hexdigest()
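
    # Worked example with made-up values: for user "alice" at int(time.time())
    # == 1600000000 and random suffix "Ab3dEf9h", val is
    # b"alice:1600000000:Ab3dEf9h" and the host key is its md5 hexdigest.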

    def create_session(self, username, user_path):
        return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)

    def _decode_data(self, data, compression=0):
        if compression:
            with gzip.GzipFile(mode="rb", fileobj=io.BytesIO(data)) as gz:
                data = gz.read()

        try:
            data = json.loads(data.decode())
        except (ValueError, UnicodeDecodeError):
            data = {'data': data}

        return data
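
    # Hedged sketch of the payloads seen here: when the "c" form field is
    # nonzero, "data" is gzipped JSON, e.g. {"v": 10, "cv": "ankidesktop,2.1.19,lin"}
    # for a meta request (illustrative values); binary payloads such as a
    # full-sync collection upload fail JSON decoding and are wrapped as
    # {'data': <bytes>} instead.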

    def operation_hostKey(self, username, password):
        if not self.user_manager.authenticate(username, password):
            return

        dirname = self.user_manager.userdir(username)
        if dirname is None:
            return

        hkey = self.generateHostKey(username)
        user_path = os.path.join(self.data_root, dirname)
        session = self.create_session(username, user_path)
        self.session_manager.save(hkey, session)

        return {'key': hkey}

    def operation_upload(self, col, data, session):
        # Verify integrity of the received database file before replacing our
        # existing db.
        return self.full_sync_manager.upload(col, data, session)

    def operation_download(self, col, session):
        # returns user data (not media) as a sqlite3 database for replacing their
        # local copy in Anki
        return self.full_sync_manager.download(col, session)

    @wsgify
    def __call__(self, req):
        # Get and verify the session
        try:
            hkey = req.params['k']
        except KeyError:
            hkey = None

        session = self.session_manager.load(hkey, self.create_session)

        if session is None:
            try:
                skey = req.POST['sk']
                session = self.session_manager.load_from_skey(skey, self.create_session)
            except KeyError:
                skey = None

        try:
            compression = int(req.POST['c'])
        except KeyError:
            compression = 0

        try:
            data = req.POST['data'].file.read()
            data = self._decode_data(data, compression)
        except KeyError:
            data = {}

        if req.path.startswith(self.base_url):
            url = req.path[len(self.base_url):]
            if url not in self.valid_urls:
                raise HTTPNotFound()

            if url == 'hostKey':
                result = self.operation_hostKey(data.get("u"), data.get("p"))
                if result:
                    return json.dumps(result)
                else:
                    # TODO: do I have to pass 'null' for the client to receive None?
                    raise HTTPForbidden('null')

            if session is None:
                raise HTTPForbidden()

            if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
                # 'meta' passes the SYNC_VER but it isn't used in the handler
                if url == 'meta':
                    if session.skey is None and 's' in req.POST:
                        session.skey = req.POST['s']
                    if 'v' in data:
                        session.version = data['v']
                    if 'cv' in data:
                        session.client_version = data['cv']

                    self.session_manager.save(hkey, session)
                    session = self.session_manager.load(hkey, self.create_session)

                thread = session.get_thread()

                if url in self.prehooks:
                    thread.execute(self.prehooks[url], [session])

                result = self._execute_handler_method_in_thread(url, data, session)

                # If it's a complex data type, we convert it to JSON
                if type(result) not in (str, bytes, Response):
                    result = json.dumps(result)

                if url in self.posthooks:
                    thread.execute(self.posthooks[url], [session])

                return result

            elif url == 'upload':
                thread = session.get_thread()
                if url in self.prehooks:
                    thread.execute(self.prehooks[url], [session])

                result = thread.execute(self.operation_upload, [data['data'], session])

                if url in self.posthooks:
                    thread.execute(self.posthooks[url], [session])

                return result

            elif url == 'download':
                thread = session.get_thread()
                if url in self.prehooks:
                    thread.execute(self.prehooks[url], [session])

                result = thread.execute(self.operation_download, [session])

                if url in self.posthooks:
                    thread.execute(self.posthooks[url], [session])

                return result

            # This was one of our operations but it didn't get handled... Oops!
            raise HTTPInternalServerError()

        # media sync
        elif req.path.startswith(self.base_media_url):
            if session is None:
                raise HTTPForbidden()

            url = req.path[len(self.base_media_url):]

            if url not in self.valid_urls:
                raise HTTPNotFound()

            if url == "begin":
                data['skey'] = session.skey

            result = self._execute_handler_method_in_thread(url, data, session)

            # If it's a complex data type, we convert it to JSON
            if type(result) not in (str, bytes):
                result = json.dumps(result)

            return result

        return "Anki Sync Server"

    @staticmethod
    def _execute_handler_method_in_thread(method_name, keyword_args, session):
        """
        Gets and runs the handler method specified by method_name inside the
        thread for session. The handler method will access the collection as
        self.col.
        """
        def run_func(col, **keyword_args):
            # Retrieve the correct handler method.
            handler = session.get_handler_for_operation(method_name, col)
            handler_method = getattr(handler, method_name)

            res = handler_method(**keyword_args)

            col.save()
            return res

        run_func.__name__ = method_name  # More useful debugging messages.

        # Send the closure to the thread for execution.
        thread = session.get_thread()
        result = thread.execute(run_func, kw=keyword_args)

        return result


def make_app(global_conf, **local_conf):
    return SyncApp(**local_conf)


def main():
    logging.basicConfig(level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s")
    import ankisyncd
    logger.info("ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage))

    from wsgiref.simple_server import make_server, WSGIRequestHandler
    from ankisyncd.thread import shutdown
    import ankisyncd.config

    class RequestHandler(WSGIRequestHandler):
        logger = logging.getLogger("ankisyncd.http")

        def log_error(self, format, *args):
            self.logger.error("%s %s", self.address_string(), format % args)

        def log_message(self, format, *args):
            self.logger.info("%s %s", self.address_string(), format % args)

    if len(sys.argv) > 1:
        # backwards compat
        config = ankisyncd.config.load(sys.argv[1])
    else:
        config = ankisyncd.config.load()

    ankiserver = SyncApp(config)
    httpd = make_server(config['host'], int(config['port']), ankiserver, handler_class=RequestHandler)

    try:
        logger.info("Serving HTTP on {} port {}...".format(*httpd.server_address))
        httpd.serve_forever()
    except KeyboardInterrupt:
        logger.info("Exiting...")
    finally:
        shutdown()


if __name__ == '__main__':
    main()
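
# A minimal config sketch (assumption: the standard ankisyncd.conf layout with
# a [sync_app] section); this module itself only reads the keys shown:
#
#   [sync_app]
#   host = 127.0.0.1
#   port = 27701
#   data_root = ./collections
#   base_url = /sync/
#   base_media_url = /msync/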