Merge sync_app.py from upstream

This commit is contained in:
jdoe0 2013-10-13 15:42:05 -05:00
parent 0daf0a7cd9
commit bd58e02319

View File

@ -1,72 +1,218 @@
# AnkiServer - A personal Anki sync server
# Copyright (C) 2013 David Snopek
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from webob.dec import wsgify from webob.dec import wsgify
from webob.exc import * from webob.exc import *
from webob import Response from webob import Response
import sqlite3 import os
import hashlib import hashlib
import AnkiServer import AnkiServer
import anki import anki
from anki.sync import LocalServer, MediaSyncer from anki.sync import Syncer, MediaSyncer
from anki.utils import intTime, checksum
from anki.consts import SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
try: try:
import simplejson as json import simplejson as json
except ImportError: except ImportError:
import json import json
import os try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
class SyncCollectionHandler(LocalServer): try:
from pysqlite2 import dbapi2 as sqlite
except ImportError:
from sqlite3 import dbapi2 as sqlite
class SyncCollectionHandler(Syncer):
operations = ['meta', 'applyChanges', 'start', 'chunk', 'applyChunk', 'sanityCheck2', 'finish'] operations = ['meta', 'applyChanges', 'start', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']
def __init__(self, col): def __init__(self, col):
LocalServer.__init__(self, col) # So that 'server' (the 3rd argument) can't get set
Syncer.__init__(self, col)
def meta(self, cv=None):
# Make sure the media database is open!
if self.col.media.db is None:
self.col.media.connect()
def applyChanges(self, changes): if cv is not None:
#self.lmod, lscm, self.maxUsn, lts, dummy = self.meta() client, version, platform = cv.split(',')
# TODO: how should we set this value? else:
#self.lnewer = 1 client = 'ankidesktop'
version = '2.0.12'
platform = 'unknown'
result = LocalServer.applyChanges(self, changes) version_int = [int(x) for x in version.split('.')]
#self.prepareToChunk() # Some insanity added in Anki 2.0.13
if client == 'ankidesktop' and version_int[0] >= 2 and version_int[1] >= 0 and version_int[2] >= 13:
return result return {
'scm': self.col.scm,
#def chunk(self, ): 'ts': intTime(),
# self.prepareToChunk() 'mod': self.col.mod,
# return LocalServer.chunk() 'usn': self.col._usn,
'musn': self.col.media.usn(),
'msg': '',
'cont': True,
}
else:
return (self.col.mod, self.col.scm, self.col._usn, intTime(), self.col.media.usn())
class SyncMediaHandler(MediaSyncer): class SyncMediaHandler(MediaSyncer):
operations = ['remove', 'files', 'addFiles', 'mediaSanity'] operations = ['remove', 'files', 'addFiles', 'mediaSanity', 'mediaList']
def __init__(self, col): def __init__(self, col):
MediaSyncer.__init__(self, col) MediaSyncer.__init__(self, col)
def files(self, minUsn=0): def remove(self, fnames, minUsn):
import zipfile, StringIO rrem = MediaSyncer.remove(self, fnames, minUsn)
# increment the USN for each file removed
#self.col.media.setUsn(self.col.media.usn() + len(rrem))
return rrem
zipdata, fnames = MediaSyncer.files(self) def files(self, minUsn=0, need=None):
"""Gets files from the media database and returns them as ZIP file data."""
# add a _usn element to the zipdata import zipfile
fd = StringIO.StringIO(zipdata)
zfd = zipfile.ZipFile(fd, "a", compression=zipfile.ZIP_DEFLATED)
zfd.writestr("_usn", str(minUsn + len(fnames)))
zfd.close()
return fd.getvalue() # The client can pass None - I'm not sure what the correct action is in that case,
# for now, we're going to resync everything.
if need is None:
need = self.mediaList()
# Comparing minUsn to need, we attempt to determine which files have already
# been sent, and we remove them from the front of the list.
need = need[len(need) - (self.col.media.usn() - minUsn):]
# Copied and modified from anki.media.MediaManager.zipAdded(). Instead of going
# over the log, we loop over the files needed and increment the USN along the
# way. The zip also has an additional '_usn' member, which the client uses to
# update the usn on their end.
f = StringIO()
z = zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED)
sz = 0
cnt = 0
files = {}
while 1:
if len(need) == 0:
# add a flag so the server knows it can clean up
z.writestr("_finished", "")
break
fname = need.pop()
minUsn += 1
z.write(os.path.join(self.col.media.dir(), fname), str(cnt))
files[str(cnt)] = fname
sz += os.path.getsize(os.path.join(self.col.media.dir(), fname))
if sz > SYNC_ZIP_SIZE or cnt > SYNC_ZIP_COUNT:
break
cnt += 1
z.writestr("_meta", json.dumps(files))
z.writestr("_usn", str(minUsn))
z.close()
return f.getvalue()
def addFiles(self, data):
    """Store the files contained in a client-supplied ZIP and return the new usn.

    Copied from anki.media.MediaManager.syncAdd(). Modified to not need the
    _usn file and, instead, to increment the server usn with each file added.

    The argument name is 'zip' on MediaSyncer, but we always use 'data' when
    we receive non-JSON data, so this override exists to receive the right
    argument.

    data -- raw bytes of a ZIP archive. It must contain a "_meta" member (a
            JSON object mapping member names to real file names); it may
            contain "_usn" (ignored) and "_finished" (marks the last zip
            in a set).

    Returns the media usn after counting one increment per stored file.

    Raises AssertionError if "_meta" is 100000 bytes or larger, or if the
    declared uncompressed size of the members reaches 100 MiB. Raises
    KeyError if "_meta" is missing or a member is not listed in it.
    """
    import zipfile

    media_mgr = self.col.media
    media_dir = media_mgr.dir()
    usn = media_mgr.usn()

    finished = False
    media = []
    sizecnt = 0

    z = zipfile.ZipFile(StringIO(data), "r")
    try:
        # get meta info first
        # Raised explicitly (rather than via `assert`) so the limit is still
        # enforced when Python runs with -O; the exception type is unchanged.
        if z.getinfo("_meta").file_size >= 100000:
            raise AssertionError("_meta member is too large")
        meta = json.loads(z.read("_meta"))

        # then loop through all files
        for i in z.infolist():
            # check for zip bombs
            sizecnt += i.file_size
            if sizecnt >= 100*1024*1024:
                raise AssertionError("uncompressed zip contents are too large")

            if i.filename in ("_meta", "_usn"):
                # ignore previously-retrieved meta
                continue
            if i.filename == "_finished":
                # last zip in set
                finished = True
                continue

            contents = z.read(i)
            csum = checksum(contents)
            name = meta[i.filename]

            # can we store the file on this system?
            # TODO: this function changed its name in Anki 2.0.12 to media.hasIllegal()
            # NOTE(review): `name` comes from the client-supplied _meta; this
            # presumably relies on illegal() rejecting path separators so the
            # join below cannot escape the media dir -- confirm.
            if media_mgr.illegal(name):
                continue

            # save file (closed deterministically, even if the write fails)
            path = os.path.join(media_dir, name)
            with open(path, "wb") as out:
                out.write(contents)

            # update db
            media.append((name, csum, media_mgr._mtime(path)))
            # remove entries from local log
            media_mgr.db.execute("delete from log where fname = ?", name)
            usn += 1
    finally:
        z.close()

    # update media db and note new starting usn
    if media:
        media_mgr.db.executemany(
            "insert or replace into media values (?,?,?)", media)
    media_mgr.setUsn(usn) # commits

    # if we have finished adding, we need to record the new folder mtime
    # so that we don't trigger a needless scan
    if finished:
        media_mgr.syncMod()

    return usn
def mediaList(self):
    """Return a sorted list of every fname in this collection's media database."""
    rows = self.col.media.db.execute("select fname from media")
    # Each row is a 1-tuple holding the fname column.
    return sorted(row[0] for row in rows)
class SyncUserSession(object): class SyncUserSession(object):
def __init__(self, name, path, collection_manager): def __init__(self, name, path, collection_manager, setup_new_collection=None):
import time import time
self.name = name self.name = name
self.path = path self.path = path
self.collection_manager = collection_manager self.collection_manager = collection_manager
self.setup_new_collection = setup_new_collection
self.version = 0 self.version = 0
self.client_version = ''
self.created = time.time() self.created = time.time()
# make sure the user path exists # make sure the user path exists
@ -80,7 +226,7 @@ class SyncUserSession(object):
return os.path.realpath(os.path.join(self.path, 'collection.anki2')) return os.path.realpath(os.path.join(self.path, 'collection.anki2'))
def get_thread(self): def get_thread(self):
return self.collection_manager.get_collection(self.get_collection_path()) return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)
def get_handler_for_operation(self, operation, col): def get_handler_for_operation(self, operation, col):
if operation in SyncCollectionHandler.operations: if operation in SyncCollectionHandler.operations:
@ -90,25 +236,29 @@ class SyncUserSession(object):
if getattr(self, cache_name) is None: if getattr(self, cache_name) is None:
setattr(self, cache_name, handler_class(col)) setattr(self, cache_name, handler_class(col))
return getattr(self, cache_name) handler = getattr(self, cache_name)
# The col object may actually be new now! This happens when we close a collection
# for inactivity and then later re-open it (creating a new Collection object).
handler.col = col
return handler
class SyncApp(object): class SimpleSessionManager(object):
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks'] """A simple session manager that keeps the sessions in memory."""
def __init__(self, config): def __init__(self):
from AnkiServer.thread import getCollectionManager
self.data_root = os.path.abspath(config.get("sync_app", "data_root"))
self.base_url = config.get("sync_app", "base_url")
self.sessions = {} self.sessions = {}
self.collection_manager = getCollectionManager() def load(self, hkey, session_factory=None):
return self.sessions.get(hkey)
# make sure the base_url has a trailing slash def save(self, hkey, session):
if len(self.base_url) == 0: self.sessions[hkey] = session
self.base_url = '/'
elif self.base_url[-1] != '/': def delete(self, hkey):
self.base_url = base_url + '/' del self.sessions[hkey]
class SimpleUserManager(object):
"""A simple user manager that always allows any user."""
def authenticate(self, username, password): def authenticate(self, username, password):
""" """
@ -117,7 +267,7 @@ class SyncApp(object):
Override this to change how users are authenticated. Override this to change how users are authenticated.
""" """
return False return True
def username2dirname(self, username): def username2dirname(self, username):
""" """
@ -128,6 +278,30 @@ class SyncApp(object):
return username return username
class SyncApp(object):
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks']
def __init__(self, config):
    """Configure the sync application from a parsed configuration object.

    config -- object exposing get(section, option); the "sync_app" section
              is read for "data_root" and "base_url".
    """
    from AnkiServer.thread import getCollectionManager

    self.data_root = os.path.abspath(config.get("sync_app", "data_root"))
    self.base_url = config.get("sync_app", "base_url")

    # Optional callbacks; None means "not set".
    self.setup_new_collection = None
    self.hook_pre_sync = None
    self.hook_post_sync = None
    self.hook_download = None
    self.hook_upload = None

    self.session_manager = SimpleSessionManager()
    self.user_manager = SimpleUserManager()
    self.collection_manager = getCollectionManager()

    # make sure the base_url has a trailing slash (an empty value becomes '/').
    # This used to read the undefined local `base_url`, which raised NameError
    # for any non-empty URL that lacked the trailing slash.
    if not self.base_url.endswith('/'):
        self.base_url = self.base_url + '/'
def generateHostKey(self, username): def generateHostKey(self, username):
"""Generates a new host key to be used by the given username to identify their session. """Generates a new host key to be used by the given username to identify their session.
This values is random.""" This values is random."""
@ -137,26 +311,14 @@ class SyncApp(object):
val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]) val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))])
return hashlib.md5(val).hexdigest() return hashlib.md5(val).hexdigest()
def create_session(self, hkey, username, user_path): def create_session(self, username, user_path):
"""Creates, stores and returns a new session for the given hkey and username.""" return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)
session = self.sessions[hkey] = SyncUserSession(username, user_path, self.collection_manager)
return session
def load_session(self, hkey):
return self.sessions.get(hkey)
def save_session(self, hkey, session):
pass
def delete_session(self, hkey):
del self.sessions[hkey]
def _decode_data(self, data, compression=0): def _decode_data(self, data, compression=0):
import gzip, StringIO import gzip
if compression: if compression:
buf = gzip.GzipFile(mode="rb", fileobj=StringIO.StringIO(data)) buf = gzip.GzipFile(mode="rb", fileobj=StringIO(data))
data = buf.read() data = buf.read()
buf.close() buf.close()
@ -169,18 +331,42 @@ class SyncApp(object):
return data return data
def operation_upload(self, col, data, session): def operation_upload(self, col, data, session):
# TODO: deal with thread pool col.close()
fd = open(session.get_collection_path(), 'wb') # TODO: we should verify the database integrity before permanently overwriting
fd.write(data) # (ie. use a temporary file) and declaring this a success!
fd.close() #
# d = DB(path)
# assert d.scalar("pragma integrity_check") == "ok"
# d.close()
#
try:
with open(session.get_collection_path(), 'wb') as fd:
fd.write(data)
finally:
col.reopen()
def operation_download(self, col, data, session): # run hook_upload if one is defined
pass if self.hook_upload is not None:
self.hook_upload(col, session)
return True
def operation_download(self, col, session):
    """Return the raw bytes of the session's collection file.

    col     -- the open collection; it is closed while the file is read and
               reopened afterwards, even if reading fails.
    session -- provides get_collection_path(), the file to read.

    If self.hook_download is set it is called with (col, session) before
    the collection is closed.
    """
    # run hook_download if one is defined
    if self.hook_download is not None:
        self.hook_download(col, session)

    col.close()
    try:
        # `with` closes the file handle deterministically instead of
        # leaving it to the garbage collector.
        with open(session.get_collection_path(), 'rb') as fd:
            data = fd.read()
    finally:
        col.reopen()
    return data
@wsgify @wsgify
def __call__(self, req): def __call__(self, req):
print req.path #print req.path
if req.path.startswith(self.base_url): if req.path.startswith(self.base_url):
url = req.path[len(self.base_url):] url = req.path[len(self.base_url):]
if url not in self.valid_urls: if url not in self.valid_urls:
@ -188,15 +374,18 @@ class SyncApp(object):
if url == 'getDecks': if url == 'getDecks':
# This is an Anki 1.x client! Tell them to upgrade. # This is an Anki 1.x client! Tell them to upgrade.
import zlib import zlib, logging
u = req.params.getone('u')
if u:
logging.warn("'%s' is attempting to sync with an Anki 1.x client" % u)
return Response( return Response(
status='200 OK', status='200 OK',
content_type='application/json', content_type='application/json',
content_encoding='deflate', content_encoding='deflate',
body=zlib.compress(json.dumps({'status': 'oldVersion'}))) body=zlib.compress(json.dumps({'status': 'oldVersion'})))
try: try:
compression = req.POST['c'] compression = int(req.POST['c'])
except KeyError: except KeyError:
compression = 0 compression = 0
@ -208,7 +397,7 @@ class SyncApp(object):
except ValueError: except ValueError:
# Bad JSON # Bad JSON
raise HTTPBadRequest() raise HTTPBadRequest()
print 'data:', data #print 'data:', data
if url == 'hostKey': if url == 'hostKey':
try: try:
@ -216,14 +405,15 @@ class SyncApp(object):
p = data['p'] p = data['p']
except KeyError: except KeyError:
raise HTTPForbidden('Must pass username and password') raise HTTPForbidden('Must pass username and password')
if self.authenticate(u, p): if self.user_manager.authenticate(u, p):
dirname = self.username2dirname(u) dirname = self.user_manager.username2dirname(u)
if dirname is None: if dirname is None:
raise HTTPForbidden() raise HTTPForbidden()
hkey = self.generateHostKey(u) hkey = self.generateHostKey(u)
user_path = os.path.join(self.data_root, dirname) user_path = os.path.join(self.data_root, dirname)
session = self.create_session(hkey, u, user_path) session = self.create_session(u, user_path)
self.session_manager.save(hkey, session)
result = {'key': hkey} result = {'key': hkey}
return Response( return Response(
@ -239,27 +429,36 @@ class SyncApp(object):
hkey = req.POST['k'] hkey = req.POST['k']
except KeyError: except KeyError:
raise HTTPForbidden() raise HTTPForbidden()
session = self.load_session(hkey) session = self.session_manager.load(hkey, self.create_session)
if session is None: if session is None:
raise HTTPForbidden() raise HTTPForbidden()
if url in SyncCollectionHandler.operations + SyncMediaHandler.operations: if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
# 'meta' passes the SYNC_VER but it isn't used in the handler # 'meta' passes the SYNC_VER but it isn't used in the handler
if url == 'meta' and data.has_key('v'): if url == 'meta':
session.version = data['v'] if data.has_key('v'):
del data['v'] session.version = data['v']
del data['v']
if data.has_key('cv'):
session.client_version = data['cv']
thread = session.get_thread()
# run hook_pre_sync if one is defined
if url == 'start':
if self.hook_pre_sync is not None:
thread.execute(self.hook_pre_sync, [session])
# Create a closure to run this operation inside of the thread allocated to this collection # Create a closure to run this operation inside of the thread allocated to this collection
def runFunc(col): def runFunc(col):
handler = session.get_handler_for_operation(url, col) handler = session.get_handler_for_operation(url, col)
func = getattr(handler, url) func = getattr(handler, url)
result = func(**data) result = func(**data)
handler.col.save() col.save()
return result return result
runFunc.func_name = url runFunc.func_name = url
# Send to the thread to execute # Send to the thread to execute
thread = session.get_thread()
result = thread.execute(runFunc) result = thread.execute(runFunc)
# If it's a complex data type, we convert it to JSON # If it's a complex data type, we convert it to JSON
@ -267,42 +466,101 @@ class SyncApp(object):
result = json.dumps(result) result = json.dumps(result)
if url == 'finish': if url == 'finish':
self.delete_session(hkey) # TODO: Apparently 'finish' isn't when we're done because 'mediaList' comes
# after it... When can we possibly delete the session?
#self.session_manager.delete(hkey)
# run hook_post_sync if one is defined
if self.hook_post_sync is not None:
thread.execute(self.hook_post_sync, [session])
return Response( return Response(
status='200 OK', status='200 OK',
content_type='application/json', content_type='application/json',
body=result) body=result)
elif url in ('upload', 'download'): elif url == 'upload':
if url == 'upload':
func = self.operation_upload
else:
func = self.operation_download
thread = session.get_thread() thread = session.get_thread()
thread.execute(self.operation_upload, [data['data'], session]) result = thread.execute(self.operation_upload, [data['data'], session])
return Response( return Response(
status='200 OK', status='200 OK',
content_type='text/plain', content_type='text/plain',
body='OK') body='OK' if result else 'Error')
elif url == 'download':
thread = session.get_thread()
result = thread.execute(self.operation_download, [session])
return Response(
status='200 OK',
content_type='text/plain',
body=result)
# This was one of our operations but it didn't get handled... Oops! # This was one of our operations but it didn't get handled... Oops!
raise HTTPInternalServerError() raise HTTPInternalServerError()
return Response(status='200 OK', content_type='text/plain', body='Anki Sync Server') return Response(status='200 OK', content_type='text/plain', body='Anki Sync Server')
class DatabaseAuthSyncApp(SyncApp): class SqliteSessionManager(SimpleSessionManager):
def __init__(self, config): """Stores sessions in a SQLite database to prevent the user from being logged out
SyncApp.__init__(self, config) every time the SyncApp is restarted."""
self.auth_db_path = os.path.abspath(config.get("sync_app", "auth_db_path")) def __init__(self, session_db_path):
SimpleSessionManager.__init__(self)
self.session_db_path = os.path.abspath(session_db_path)
def _conn(self):
new = not os.path.exists(self.session_db_path)
conn = sqlite.connect(self.session_db_path)
if new:
cursor = conn.cursor()
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, user VARCHAR, path VARCHAR)")
return conn
def load(self, hkey, session_factory=None):
    """Return the session stored under hkey, or None if there is none.

    The in-memory cache is consulted first. On a miss the SQLite table is
    queried and, if a row exists, the session is rebuilt with
    session_factory(user, path) and cached.

    hkey            -- host key identifying the session.
    session_factory -- callable taking (user, path) and returning a session.
                       When it is None a persisted session cannot be rebuilt,
                       so None is returned.
    """
    session = SimpleSessionManager.load(self, hkey)
    if session is not None:
        return session

    conn = self._conn()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT user, path FROM session WHERE hkey=?", (hkey,))
        res = cursor.fetchone()
    finally:
        # Close the connection opened for this lookup instead of leaking it.
        conn.close()

    # Guard the default session_factory=None, which used to be called
    # unconditionally and raised TypeError when a row was found.
    if res is not None and session_factory is not None:
        session = self.sessions[hkey] = session_factory(res[0], res[1])

    return session
def save(self, hkey, session):
    """Store session under hkey in memory and in the SQLite table.

    hkey    -- host key identifying the session.
    session -- object whose .name and .path are written to the "user" and
               "path" columns.
    """
    SimpleSessionManager.save(self, hkey, session)

    conn = self._conn()
    try:
        cursor = conn.cursor()
        cursor.execute("INSERT OR REPLACE INTO session (hkey, user, path) VALUES (?, ?, ?)",
            (hkey, session.name, session.path))
        conn.commit()
    finally:
        # Close the connection opened for this write instead of leaking it.
        conn.close()
def delete(self, hkey):
    """Remove the session stored under hkey from memory and from SQLite.

    A session that exists only in the database (not in the in-memory cache)
    is still removed; previously the KeyError from the in-memory delete
    aborted the call before the row was deleted.
    """
    try:
        SimpleSessionManager.delete(self, hkey)
    except KeyError:
        # Not cached in this process; the persisted row is still removed below.
        pass

    conn = self._conn()
    try:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM session WHERE hkey=?", (hkey,))
        conn.commit()
    finally:
        # Close the connection opened for this write instead of leaking it.
        conn.close()
class SqliteUserManager(SimpleUserManager):
"""Authenticates users against a SQLite database."""
def __init__(self, auth_db_path):
self.auth_db_path = os.path.abspath(auth_db_path)
def authenticate(self, username, password): def authenticate(self, username, password):
"""Returns True if this username is allowed to connect with this password. False otherwise.""" """Returns True if this username is allowed to connect with this password. False otherwise."""
conn = sqlite3.connect(self.auth_db_path) conn = sqlite.connect(self.auth_db_path)
cursor = conn.cursor() cursor = conn.cursor()
param = (username,) param = (username,)
@ -310,8 +568,6 @@ class DatabaseAuthSyncApp(SyncApp):
db_ret = cursor.fetchone() db_ret = cursor.fetchone()
conn.close()
if db_ret != None: if db_ret != None:
db_hash = str(db_ret[0]) db_hash = str(db_ret[0])
salt = db_hash[-16:] salt = db_hash[-16:]
@ -323,6 +579,14 @@ class DatabaseAuthSyncApp(SyncApp):
return (db_ret != None and hashobj.hexdigest()+salt == db_hash) return (db_ret != None and hashobj.hexdigest()+salt == db_hash)
# Our entry point
def make_app(global_conf, **local_conf):
    """Build a SyncApp from keyword configuration.

    global_conf -- accepted but not used here.
    local_conf  -- keyword settings. If "session_db_path" is present a
                   SqliteSessionManager is added under "session_manager";
                   if "auth_db_path" is present a SqliteUserManager is added
                   under "user_manager". Everything is then passed to SyncApp.
    """
    # `in` replaces the deprecated dict.has_key().
    if 'session_db_path' in local_conf:
        local_conf['session_manager'] = SqliteSessionManager(local_conf['session_db_path'])
    if 'auth_db_path' in local_conf:
        local_conf['user_manager'] = SqliteUserManager(local_conf['auth_db_path'])
    # NOTE(review): SyncApp.__init__ as visible in this file takes a single
    # positional `config`, so this keyword call looks like it cannot succeed --
    # confirm against the SyncApp constructor.
    return SyncApp(**local_conf)
def main(): def main():
from wsgiref.simple_server import make_server from wsgiref.simple_server import make_server
from AnkiServer.thread import shutdown from AnkiServer.thread import shutdown
@ -330,7 +594,7 @@ def main():
config = SafeConfigParser() config = SafeConfigParser()
config.read("production.ini") config.read("production.ini")
ankiserver = DatabaseAuthSyncApp(config) ankiserver = SyncApp(config)
httpd = make_server('', config.getint("sync_app", "port"), ankiserver) httpd = make_server('', config.getint("sync_app", "port"), ankiserver)
try: try: