anki-sync-server/ankisyncd/sync_app.py

682 lines
23 KiB
Python
Raw Normal View History

2013-10-19 13:46:55 +08:00
# ankisyncd - A personal Anki sync server
2013-10-14 04:42:05 +08:00
# Copyright (C) 2013 David Snopek
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2013-08-14 19:22:18 +08:00
from ConfigParser import SafeConfigParser
from webob.dec import wsgify
from webob.exc import *
from webob import Response
2013-10-14 04:42:05 +08:00
import os
2013-08-02 01:07:49 +08:00
import hashlib
import logging
import random
2015-11-29 12:53:58 +08:00
import string
2013-08-02 01:07:49 +08:00
2013-10-19 13:46:55 +08:00
import ankisyncd
import anki
2013-10-14 04:42:05 +08:00
from anki.sync import Syncer, MediaSyncer
from anki.utils import intTime, checksum
from anki.consts import SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
try:
import simplejson as json
except ImportError:
import json
2013-10-14 04:42:05 +08:00
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
try:
from pysqlite2 import dbapi2 as sqlite
except ImportError:
from sqlite3 import dbapi2 as sqlite
2013-10-14 04:42:05 +08:00
class SyncCollectionHandler(Syncer):
    """Server-side handler for the collection sync operations.

    Only ``meta`` is overridden in this class; the remaining names listed in
    ``operations`` are expected to come from ``anki.sync.Syncer``.
    """

    operations = ['meta', 'applyChanges', 'start', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']

    # First release of each client that expects the dictionary-style reply
    # from meta(). Clients not listed here always get the legacy tuple.
    _DICT_META_MIN_VERSION = {
        'ankidroid': (2, 3),
        'ankidesktop': (2, 0, 13),
    }

    def __init__(self, col):
        # So that 'server' (the 3rd argument) can't get set
        Syncer.__init__(self, col)

    @staticmethod
    def _parse_version(version):
        """Turn a dotted version string into a tuple of ints.

        ASCII letters are dropped from every component (so ``'13beta'``
        becomes ``13``); a component with nothing numeric left counts as 0.
        """
        parts = []
        for piece in version.split('.'):
            digits = ''.join(ch for ch in str(piece) if ch not in string.ascii_letters)
            parts.append(int(digits) if digits.isdigit() else 0)
        return tuple(parts)

    def meta(self, cv=None):
        """Return the collection's sync metadata in the format the client expects.

        ``cv`` is the client identification string ``"client,version,platform"``.
        When it is omitted the client is treated as ankidesktop 2.0.12.

        Returns a dict (``scm``, ``ts``, ``mod``, ``usn``, ``musn``, ``msg``,
        ``cont``) for clients at or above the versions in
        ``_DICT_META_MIN_VERSION``, otherwise the legacy tuple
        ``(mod, scm, usn, ts, musn)``.
        """
        # Make sure the media database is open!
        if self.col.media.db is None:
            self.col.media.connect()

        if cv is not None:
            client, version, platform = cv.split(',')
        else:
            client = 'ankidesktop'
            version = '2.0.12'
            platform = 'unknown'

        # Compare versions as tuples so that e.g. 2.1.0 ranks above 2.0.13;
        # checking each component separately wrongly rejected such versions
        # and crashed on version strings with fewer than three components.
        minimum = self._DICT_META_MIN_VERSION.get(client)

        # Some insanity added in Anki 2.0.13
        if minimum is not None and self._parse_version(version) >= minimum:
            return {
                'scm': self.col.scm,
                'ts': intTime(),
                'mod': self.col.mod,
                'usn': self.col._usn,
                'musn': self.col.media.lastUsn(),
                'msg': '',
                'cont': True,
            }
        return (self.col.mod, self.col.scm, self.col._usn, intTime(), self.col.media.lastUsn())
class SyncMediaHandler(MediaSyncer):
    """Server-side handler for the media sync operations."""

    operations = ['begin', 'mediaChanges', 'mediaSanity', 'mediaList', 'uploadChanges', 'downloadFiles']

    # Limits applied to uploaded archives (zip-bomb protection).
    _MAX_META_SIZE = 100000
    _MAX_UNPACKED_SIZE = 100 * 1024 * 1024

    def __init__(self, col):
        MediaSyncer.__init__(self, col)

    @staticmethod
    def _is_safe_media_name(name):
        """Return True if ``name`` is a plain file name with no path components.

        Names come from the client and are joined onto the media directory,
        so anything that could point outside of it is rejected.
        """
        if not name or name in ('.', '..'):
            return False
        return '/' not in name and '\\' not in name and '\0' not in name

    def begin(self, skey):
        """Return a JSON reply carrying the session key and the current media usn."""
        return json.dumps({
            'data': {
                'sk': skey,
                'usn': self.col.media.lastUsn(),
            },
            'err': '',
        })

    def uploadChanges(self, data, skey):
        """Adds files based from ZIP file data and returns the usn.

        ``data`` is the raw ZIP archive; its ``_meta`` member is JSON whose
        entries are read as ``[real file name, name inside the archive]``.
        ``skey`` is accepted but not used here.

        Returns a JSON string ``{'data': [files stored, new usn], 'err': ''}``.
        Raises ValueError when the archive exceeds the size limits or holds a
        member that ``_meta`` does not describe.
        """
        import zipfile

        media_mgr = self.col.media
        usn = media_mgr.lastUsn()

        # Copied from anki.media.MediaManager.syncAdd(). Modified to not need the
        # _usn file and, instead, to increment the server usn with each file added.
        finished = False
        media = []
        total_size = 0
        processed = 0

        archive = zipfile.ZipFile(StringIO(data), "r")
        try:
            # get meta info first
            # (explicit checks rather than assert, which is stripped under -O)
            if archive.getinfo("_meta").file_size >= self._MAX_META_SIZE:
                raise ValueError("media upload: _meta is too large")
            meta = json.loads(archive.read("_meta"))

            # Archive member name -> real file name; the first entry wins.
            real_names = {}
            for entry in meta:
                real_names.setdefault(entry[1], entry[0])

            # then loop through all files
            for info in archive.infolist():
                # check for zip bombs
                total_size += info.file_size
                if total_size >= self._MAX_UNPACKED_SIZE:
                    raise ValueError("media upload: archive is too large")

                if info.filename in ("_meta", "_usn"):
                    # ignore previously-retrieved meta
                    continue
                if info.filename == "_finished":
                    # last zip in set
                    finished = True
                    continue

                if info.filename not in real_names:
                    raise ValueError("media upload: %r is missing from _meta" % info.filename)
                name = real_names[info.filename]
                payload = archive.read(info)
                csum = checksum(payload)

                # Never let a client-chosen name escape the media directory.
                if not self._is_safe_media_name(name):
                    continue
                # can we store the file on this system?
                # NOTE: this function changed its name in Anki 2.0.12 to media.hasIllegal()
                if hasattr(media_mgr, 'illegal') and media_mgr.illegal(name):
                    continue
                if hasattr(media_mgr, 'hasIllegal') and media_mgr.hasIllegal(name):
                    continue

                # save file
                path = os.path.join(media_mgr.dir(), name)
                with open(path, "wb") as out:
                    out.write(payload)

                # update db
                media.append((name, csum, media_mgr._mtime(path), 0))
                processed += 1
                usn += 1
        finally:
            archive.close()

        # update media db and note new starting usn
        if media:
            media_mgr.db.executemany(
                "insert or replace into media values (?,?,?,?)", media)
        media_mgr.setLastUsn(usn)  # commits

        # if we have finished adding, we need to record the new folder mtime
        # so that we don't trigger a needless scan
        if finished:
            media_mgr.syncMod()

        return json.dumps({'data': [processed, usn], 'err': ''})

    def downloadFiles(self, files):
        """Return a ZIP archive (as a string) containing the requested media files.

        Members are stored under the names ``"0"``, ``"1"``, ...; the ``_meta``
        member is a JSON object mapping those names back to the real file
        names. Adding stops once SYNC_ZIP_SIZE or SYNC_ZIP_COUNT is exceeded.

        Raises ValueError if a requested name contains path components.
        """
        import zipfile

        media_dir = self.col.media.dir()
        flist = {}
        cnt = 0
        sz = 0
        buf = StringIO()
        archive = zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED)
        try:
            for fname in files:
                # The list is client-supplied: refuse to read outside the media dir.
                if not self._is_safe_media_name(fname):
                    raise ValueError("illegal media file name: %r" % (fname,))
                path = os.path.join(media_dir, fname)
                archive.write(path, str(cnt))
                flist[str(cnt)] = fname
                sz += os.path.getsize(path)
                if sz > SYNC_ZIP_SIZE or cnt > SYNC_ZIP_COUNT:
                    break
                cnt += 1
            archive.writestr("_meta", json.dumps(flist))
        finally:
            # The archive must be closed before the buffer holds a complete ZIP.
            archive.close()
        return buf.getvalue()

    def mediaChanges(self, lastUsn, skey):
        """Return a JSON list of ``[fname, usn, csum]`` for the media table.

        Every row is reported, tagged with the server's current usn, when the
        client's ``lastUsn`` is behind the server's or is 0; otherwise the
        list is empty. ``skey`` is accepted but not used here.
        """
        result = []
        usn = self.col.media.lastUsn()
        if lastUsn < usn or lastUsn == 0:
            rows = self.col.media.db.execute("select fname,mtime,csum from media")
            for fname, mtime, csum in rows:
                result.append([fname, usn, csum])
        return json.dumps({'data': result, 'err': ''})

    def mediaSanity(self, local=None):
        """Report "OK" when ``local`` equals the server's media count, else "FAILED"."""
        if self.col.media.mediaCount() == local:
            result = "OK"
        else:
            result = "FAILED"
        return json.dumps({'data': result, 'err': ''})
class SyncUserSession(object):
    """State kept for one user between sync requests.

    Holds the session key, the user's data directory, the collection manager
    used to obtain the collection thread, and lazily created handler objects.
    """

    def __init__(self, name, path, collection_manager, setup_new_collection=None):
        """Create the session and make sure the user's directory exists.

        ``name`` is the user name, ``path`` the directory holding the user's
        collection. ``setup_new_collection`` is passed through unchanged to
        ``collection_manager.get_collection``.
        """
        import time
        self.skey = self._generate_session_key()
        self.name = name
        self.path = path
        self.collection_manager = collection_manager
        self.setup_new_collection = setup_new_collection
        self.version = 0
        self.client_version = ''
        self.created = time.time()

        # make sure the user path exists
        self._ensure_dir(path)

        self.collection_handler = None
        self.media_handler = None

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (including missing parents) unless it already exists."""
        if os.path.isdir(path):
            return
        try:
            os.makedirs(path)
        except OSError:
            # Another request may have created it in the meantime; only
            # re-raise when the directory really is not there.
            if not os.path.isdir(path):
                raise

    def _generate_session_key(self):
        """Return a fresh 8-character session key."""
        # Session keys identify a logged-in user, so draw the randomness from
        # the operating system instead of the predictable default generator.
        return checksum(str(random.SystemRandom().random()))[:8]

    def get_collection_path(self):
        """Return the resolved path of this user's ``collection.anki2`` file."""
        return os.path.realpath(os.path.join(self.path, 'collection.anki2'))

    def get_thread(self):
        """Return what the collection manager hands out for this user's collection."""
        return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)

    def get_handler_for_operation(self, operation, col):
        """Return the cached handler responsible for ``operation``, bound to ``col``.

        Operations listed in ``SyncCollectionHandler.operations`` get the
        collection handler; everything else gets the media handler. Handlers
        are created on first use and reused afterwards.
        """
        if operation in SyncCollectionHandler.operations:
            cache_name, handler_class = 'collection_handler', SyncCollectionHandler
        else:
            cache_name, handler_class = 'media_handler', SyncMediaHandler

        handler = getattr(self, cache_name)
        if handler is None:
            handler = handler_class(col)
            setattr(self, cache_name, handler)

        # The col object may actually be new now! This happens when we close a collection
        # for inactivity and then later re-open it (creating a new Collection object).
        handler.col = col
        return handler
2013-10-14 04:42:05 +08:00
class SimpleSessionManager(object):
    """A simple session manager that keeps the sessions in memory."""

    def __init__(self):
        # Maps host key -> session object.
        self.sessions = {}

    def load(self, hkey, session_factory=None):
        """Return the session stored under ``hkey``, or None.

        ``session_factory`` is accepted for interface compatibility and is
        not used by the in-memory implementation.
        """
        return self.sessions.get(hkey)

    def load_from_skey(self, skey, session_factory=None):
        """Return the first session whose ``skey`` equals ``skey``, or None."""
        for session in self.sessions.values():
            if session.skey == skey:
                return session
        return None

    def save(self, hkey, session):
        """Store ``session`` under ``hkey``, replacing any previous entry."""
        self.sessions[hkey] = session

    def delete(self, hkey):
        """Forget the session stored under ``hkey``; unknown keys are ignored."""
        # pop() rather than del: deleting a session that is not (or no longer)
        # held in memory must not raise.
        self.sessions.pop(hkey, None)
class SimpleUserManager(object):
    """A simple user manager that always allows any user."""

    def authenticate(self, username, password):
        """
        Returns True if this username is allowed to connect with this password. False otherwise.
        Override this to change how users are authenticated.
        """
        return True

    def username2dirname(self, username):
        """
        Returns the directory name for the given user. By default, this is just the username.
        Override this to adjust the mapping between users and their directory.

        Returns None for a username that cannot be used as a single directory
        name (empty, '.', '..', or containing a path separator or NUL), so
        that a client-chosen name can never point outside the data root.
        """
        if not username or username in ('.', '..'):
            return None
        separators = set(['/', '\\', '\0', os.sep])
        if os.altsep:
            separators.add(os.altsep)
        for sep in separators:
            if sep in username:
                return None
        return username
2013-08-02 01:07:49 +08:00
2013-10-14 04:42:05 +08:00
class SyncApp(object):
    """WSGI application implementing the sync server.

    Requests under ``base_url`` are collection sync operations plus
    ``hostKey``, ``upload``, ``download`` and ``getDecks``; requests under
    ``base_media_url`` are media sync operations. The optional ``hook_*``
    attributes and ``setup_new_collection`` default to None and can be
    assigned by the embedding code.
    """

    valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks']

    def __init__(self, config):
        """Read ``data_root``, ``base_url`` and ``base_media_url`` from the
        ``sync_app`` section of ``config`` and set up the default managers."""
        from ankisyncd.thread import getCollectionManager

        self.data_root = os.path.abspath(config.get("sync_app", "data_root"))
        self.base_url = config.get("sync_app", "base_url")
        self.base_media_url = config.get("sync_app", "base_media_url")
        self.setup_new_collection = None
        self.hook_pre_sync = None
        self.hook_post_sync = None
        self.hook_download = None
        self.hook_upload = None
        self.session_manager = SimpleSessionManager()
        self.user_manager = SimpleUserManager()
        self.collection_manager = getCollectionManager()

        # make sure the base_url has a trailing slash
        if not self.base_url.endswith('/'):
            self.base_url += '/'
        if not self.base_media_url.endswith('/'):
            self.base_media_url += '/'

    def generateHostKey(self, username):
        """Generates a new host key to be used by the given username to identify their session.
        This value is random."""
        import hashlib, time, random, string
        chars = string.ascii_letters + string.digits
        # The host key authenticates later requests, so the nonce comes from
        # the operating system's entropy source.
        rng = random.SystemRandom()
        nonce = ''.join(rng.choice(chars) for _ in range(32))
        val = ':'.join([username, str(int(time.time())), nonce])
        # Hash bytes: hashing a unicode string fails for non-ASCII usernames.
        if not isinstance(val, bytes):
            val = val.encode('utf-8')
        return hashlib.md5(val).hexdigest()

    def create_session(self, username, user_path):
        """Build a SyncUserSession for ``username`` stored at ``user_path``."""
        return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)

    def _decode_data(self, data, compression=0):
        """Decode a request body into a dict.

        The body is gunzipped first when ``compression`` is truthy. Text that
        starts with ``{`` and ends with ``}`` is parsed as JSON (ValueError on
        bad JSON); anything else, including an empty body, is returned as
        ``{'data': <raw body>}``.
        """
        import gzip

        if compression:
            buf = gzip.GzipFile(mode="rb", fileobj=StringIO(data))
            try:
                data = buf.read()
            finally:
                buf.close()

        # really lame check for JSON
        # (slices instead of indexes so that an empty body does not raise)
        if data[:1] == '{' and data[-1:] == '}':
            return json.loads(data)
        return {'data': data}

    @staticmethod
    def _is_intact_database(path):
        """Return True if ``path`` passes SQLite's ``pragma integrity_check``."""
        try:
            conn = sqlite.connect(path)
        except sqlite.Error:
            return False
        try:
            row = conn.execute("pragma integrity_check").fetchone()
            return row is not None and row[0] == "ok"
        except sqlite.DatabaseError:
            # e.g. the file is not a database at all
            return False
        finally:
            conn.close()

    @staticmethod
    def _replace_file(src, dst):
        """Move ``src`` over ``dst``, replacing ``dst`` if it exists."""
        try:
            os.rename(src, dst)
        except OSError:
            # Some platforms refuse to rename onto an existing file.
            if not os.path.exists(dst):
                raise
            os.remove(dst)
            os.rename(src, dst)

    def operation_upload(self, col, data, session):
        """Replace the user's collection file with the uploaded ``data``.

        The upload is written to a staging file next to the collection and
        only moved into place after it passes an integrity check, so a broken
        upload never overwrites the existing collection. ``col`` is closed
        while the file is swapped and reopened afterwards.

        Returns True on success and False when the upload is rejected. Runs
        ``hook_upload(col, session)`` after a successful upload if it is set.
        """
        target = session.get_collection_path()
        staging = target + '.upload'
        try:
            with open(staging, 'wb') as fd:
                fd.write(data)
            if not self._is_intact_database(staging):
                return False

            col.close()
            try:
                self._replace_file(staging, target)
            finally:
                col.reopen()
        finally:
            # Only still present when the upload was rejected or the swap failed.
            if os.path.exists(staging):
                os.remove(staging)

        # run hook_upload if one is defined
        if self.hook_upload is not None:
            self.hook_upload(col, session)

        return True

    def operation_download(self, col, session):
        """Return the raw bytes of the user's collection file.

        Runs ``hook_download(col, session)`` first if it is set. ``col`` is
        closed while the file is read and reopened afterwards.
        """
        # run hook_download if one is defined
        if self.hook_download is not None:
            self.hook_download(col, session)

        col.close()
        try:
            with open(session.get_collection_path(), 'rb') as fd:
                data = fd.read()
        finally:
            col.reopen()
        return data

    def _find_session(self, req):
        """Return ``(hkey, session)`` for the request; ``session`` may be None.

        The session is looked up by the host key in POST field ``k`` and,
        failing that, by the session key in POST field ``sk``.
        """
        hkey = req.POST.get('k')
        session = self.session_manager.load(hkey, self.create_session)
        if session is None:
            skey = req.POST.get('sk')
            if skey is not None:
                session = self.session_manager.load_from_skey(skey, self.create_session)
        return hkey, session

    @staticmethod
    def _read_compression_flag(req):
        """Return POST field ``c`` as an int (0 when absent); 400 if not a number."""
        try:
            return int(req.POST['c'])
        except KeyError:
            return 0
        except (TypeError, ValueError):
            raise HTTPBadRequest()

    def _read_request_data(self, req, compression):
        """Return the decoded POST field ``data`` ({} when absent); 400 on bad JSON."""
        try:
            raw = req.POST['data'].file.read()
            return self._decode_data(raw, compression)
        except KeyError:
            return {}
        except ValueError:
            # Bad JSON
            raise HTTPBadRequest()

    def _respond_old_client(self, req):
        """Answer 'getDecks': tell the client that its version is too old."""
        # This is an Anki 1.x client! Tell them to upgrade.
        import zlib
        # get() instead of getone(): a request without 'u' must not fail.
        u = req.params.get('u')
        if u:
            logging.warning("'%s' is attempting to sync with an Anki 1.x client", u)
        return Response(
            status='200 OK',
            content_type='application/json',
            content_encoding='deflate',
            body=zlib.compress(json.dumps({'status': 'oldVersion'})))

    def _respond_host_key(self, data):
        """Answer 'hostKey': authenticate the user and hand out a new host key.

        Raises HTTPForbidden when credentials are missing or rejected, or when
        the user manager returns no directory name for the user.
        """
        try:
            u = data['u']
            p = data['p']
        except KeyError:
            raise HTTPForbidden('Must pass username and password')

        if not self.user_manager.authenticate(u, p):
            # TODO: do I have to pass 'null' for the client to receive None?
            raise HTTPForbidden('null')

        dirname = self.user_manager.username2dirname(u)
        if dirname is None:
            raise HTTPForbidden()

        hkey = self.generateHostKey(u)
        user_path = os.path.join(self.data_root, dirname)
        session = self.create_session(u, user_path)
        self.session_manager.save(hkey, session)

        return Response(
            status='200 OK',
            content_type='application/json',
            body=json.dumps({'key': hkey}))

    def _respond_sync_operation(self, req, url, data, hkey, session):
        """Run one collection/media handler operation and wrap the result.

        For 'meta' the session is first updated from the request (``s``,
        ``v``, ``cv``) and saved. ``hook_pre_sync`` runs before 'start' and
        ``hook_post_sync`` after 'finish', both inside the collection thread.
        """
        # 'meta' passes the SYNC_VER but it isn't used in the handler
        if url == 'meta':
            if session.skey is None and 's' in req.POST:
                session.skey = req.POST['s']
            if 'v' in data:
                session.version = data['v']
                del data['v']
            if 'cv' in data:
                session.client_version = data['cv']
            self.session_manager.save(hkey, session)
            session = self.session_manager.load(hkey, self.create_session)

        thread = session.get_thread()

        # run hook_pre_sync if one is defined
        if url == 'start' and self.hook_pre_sync is not None:
            thread.execute(self.hook_pre_sync, [session])

        result = self._execute_handler_method_in_thread(url, data, session)

        # If it's a complex data type, we convert it to JSON
        if not isinstance(result, (str, unicode)):
            result = json.dumps(result)

        if url == 'finish':
            # TODO: Apparently 'finish' isn't when we're done because 'mediaList' comes
            # after it... When can we possibly delete the session?
            #self.session_manager.delete(hkey)

            # run hook_post_sync if one is defined
            if self.hook_post_sync is not None:
                thread.execute(self.hook_post_sync, [session])

        return Response(
            status='200 OK',
            content_type='application/json',
            body=result)

    @wsgify
    def __call__(self, req):
        """Dispatch one HTTP request to the matching sync operation.

        Unknown operations give 404, operations that need a session without
        one give 403, and malformed request data gives 400. Paths outside
        both base URLs get a plain-text banner.
        """
        # Get and verify the session
        hkey, session = self._find_session(req)
        compression = self._read_compression_flag(req)
        data = self._read_request_data(req, compression)

        if req.path.startswith(self.base_url):
            url = req.path[len(self.base_url):]
            if url not in self.valid_urls:
                raise HTTPNotFound()

            # These two do not require an existing session.
            if url == 'getDecks':
                return self._respond_old_client(req)
            if url == 'hostKey':
                return self._respond_host_key(data)

            if session is None:
                raise HTTPForbidden()

            if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
                return self._respond_sync_operation(req, url, data, hkey, session)

            elif url == 'upload':
                try:
                    payload = data['data']
                except KeyError:
                    # No raw body was sent, so there is nothing to store.
                    raise HTTPBadRequest()
                thread = session.get_thread()
                result = thread.execute(self.operation_upload, [payload, session])
                return Response(
                    status='200 OK',
                    content_type='text/plain',
                    body='OK' if result else 'Error')

            elif url == 'download':
                thread = session.get_thread()
                result = thread.execute(self.operation_download, [session])
                return Response(
                    status='200 OK',
                    content_type='text/plain',
                    body=result)

            # This was one of our operations but it didn't get handled... Oops!
            raise HTTPInternalServerError()

        # media sync
        elif req.path.startswith(self.base_media_url):
            if session is None:
                raise HTTPForbidden()

            url = req.path[len(self.base_media_url):]
            if url not in self.valid_urls:
                raise HTTPNotFound()

            # These handler methods take the session key as an argument.
            if url in ('begin', 'mediaChanges', 'uploadChanges'):
                data['skey'] = session.skey

            return self._execute_handler_method_in_thread(url, data, session)

        return Response(status='200 OK', content_type='text/plain', body='Anki Sync Server')

    @staticmethod
    def _execute_handler_method_in_thread(method_name, keyword_args, session):
        """
        Gets and runs the handler method specified by method_name inside the
        thread for session. The handler method will access the collection as
        self.col.
        """
        def run_func(col):
            # Retrieve the correct handler method.
            handler = session.get_handler_for_operation(method_name, col)
            handler_method = getattr(handler, method_name)
            res = handler_method(**keyword_args)
            col.save()
            return res
        run_func.__name__ = method_name  # More useful debugging messages.

        # Send the closure to the thread for execution.
        thread = session.get_thread()
        return thread.execute(run_func)
2013-10-14 04:42:05 +08:00
class SqliteSessionManager(SimpleSessionManager):
    """Stores sessions in a SQLite database to prevent the user from being logged out
    everytime the SyncApp is restarted."""

    def __init__(self, session_db_path):
        SimpleSessionManager.__init__(self)
        self.session_db_path = os.path.abspath(session_db_path)

    def _conn(self):
        """Open a connection to the session database, creating the table if needed.

        The caller is responsible for closing the returned connection.
        """
        conn = sqlite.connect(self.session_db_path)
        try:
            # IF NOT EXISTS also covers a database file that exists but does
            # not contain the table yet.
            conn.execute("CREATE TABLE IF NOT EXISTS session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, user VARCHAR, path VARCHAR)")
            conn.commit()
        except Exception:
            conn.close()
            raise
        return conn

    def _fetch_one(self, sql, params):
        """Run a query and return its first row (or None), closing the connection."""
        conn = self._conn()
        try:
            return conn.execute(sql, params).fetchone()
        finally:
            conn.close()

    def _execute(self, sql, params):
        """Run a modifying statement and commit it, closing the connection."""
        conn = self._conn()
        try:
            conn.execute(sql, params)
            conn.commit()
        finally:
            conn.close()

    def load(self, hkey, session_factory=None):
        """Return the session for ``hkey``, or None.

        A session not held in memory is rebuilt from its database row with
        ``session_factory(user, path)`` and cached; without a factory such a
        row cannot be restored and None is returned.
        """
        session = SimpleSessionManager.load(self, hkey)
        if session is not None:
            return session

        row = self._fetch_one("SELECT skey, user, path FROM session WHERE hkey=?", (hkey,))
        if row is None or session_factory is None:
            return None

        session = self.sessions[hkey] = session_factory(row[1], row[2])
        session.skey = row[0]
        return session

    def load_from_skey(self, skey, session_factory=None):
        """Return the session whose session key is ``skey``, or None.

        Works like load(), but looks the row up by ``skey``.
        """
        session = SimpleSessionManager.load_from_skey(self, skey)
        if session is not None:
            return session

        row = self._fetch_one("SELECT hkey, user, path FROM session WHERE skey=?", (skey,))
        if row is None or session_factory is None:
            return None

        session = self.sessions[row[0]] = session_factory(row[1], row[2])
        session.skey = skey
        return session

    def save(self, hkey, session):
        """Store ``session`` under ``hkey`` in memory and in the database."""
        SimpleSessionManager.save(self, hkey, session)
        self._execute(
            "INSERT OR REPLACE INTO session (hkey, skey, user, path) VALUES (?, ?, ?, ?)",
            (hkey, session.skey, session.name, session.path))

    def delete(self, hkey):
        """Remove the session for ``hkey`` from memory and from the database."""
        # The session may exist only in the database (e.g. after a restart),
        # so a missing in-memory entry must not prevent the row being deleted.
        self.sessions.pop(hkey, None)
        self._execute("DELETE FROM session WHERE hkey=?", (hkey,))
class SqliteUserManager(SimpleUserManager):
    """Authenticates users against a SQLite database."""

    def __init__(self, auth_db_path):
        self.auth_db_path = os.path.abspath(auth_db_path)

    def authenticate(self, username, password):
        """Returns True if this username is allowed to connect with this password. False otherwise.

        The stored value is read as ``<sha256 hex digest><16-character salt>``,
        where the digest covers ``username + password + salt``.
        """
        conn = sqlite.connect(self.auth_db_path)
        try:
            row = conn.execute("SELECT hash FROM auth WHERE user=?", (username,)).fetchone()
        finally:
            # Close the connection even if the query fails.
            conn.close()

        if row is None:
            return False

        db_hash = str(row[0])
        salt = db_hash[-16:]
        material = username + password + salt
        # Hash bytes: hashing a unicode string fails for non-ASCII credentials.
        if not isinstance(material, bytes):
            material = material.encode('utf-8')
        return hashlib.sha256(material).hexdigest() + salt == db_hash
2013-10-14 04:42:05 +08:00
# Our entry point
def make_app(global_conf, **local_conf):
    """Build the SyncApp from keyword configuration.

    When ``session_db_path`` or ``auth_db_path`` is given, a matching
    ``session_manager`` / ``user_manager`` entry is added to ``local_conf``
    before it is passed on. ``global_conf`` is not used.
    """
    # NOTE(review): SyncApp.__init__ as defined in this file takes a single
    # ``config`` argument, so this keyword call looks out of date - confirm
    # how this entry point is expected to be used before relying on it.
    if 'session_db_path' in local_conf:
        local_conf['session_manager'] = SqliteSessionManager(local_conf['session_db_path'])
    if 'auth_db_path' in local_conf:
        local_conf['user_manager'] = SqliteUserManager(local_conf['auth_db_path'])
    return SyncApp(**local_conf)
def main():
    """Serve the sync app with wsgiref until interrupted.

    Reads ``ankisyncd.conf`` from the current directory and listens on the
    port given by ``port`` in its ``sync_app`` section. ``shutdown()`` from
    ``ankisyncd.thread`` is always called on the way out.
    """
    from wsgiref.simple_server import make_server
    from ankisyncd.thread import shutdown

    config = SafeConfigParser()
    config.read("ankisyncd.conf")

    ankiserver = SyncApp(config)
    httpd = make_server('', config.getint("sync_app", "port"), ankiserver)

    try:
        # Parenthesised single-argument form prints the same text on Python 2.
        print("Starting...")
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("Exiting ...")
    finally:
        shutdown()


if __name__ == '__main__':
    main()