Remove things not needed by sync_app

This commit is contained in:
jdoe0 2013-08-14 18:22:18 +07:00
parent fd58fcd9ec
commit d573bf6b42
17 changed files with 32 additions and 2479 deletions

View File

@ -2,14 +2,14 @@
import sys
sys.path.insert(0, "/usr/share/anki")
def server_runner(app, global_conf, **kw):
""" Special version of paste.httpserver.server_runner which calls
AnkiServer.threading.shutdown() on server exit."""
from paste.httpserver import server_runner as paste_server_runner
from AnkiServer.threading import shutdown
try:
paste_server_runner(app, global_conf, **kw)
finally:
shutdown()
#def server_runner(app, global_conf, **kw):
# """ Special version of paste.httpserver.server_runner which calls
# AnkiServer.thread.shutdown() on server exit."""
#
# from paste.httpserver import server_runner as paste_server_runner
# from AnkiServer.thread import shutdown
# try:
# paste_server_runner(app, global_conf, **kw)
# finally:
# shutdown()
#

View File

@ -1,2 +0,0 @@
# package

View File

@ -1,874 +0,0 @@
from webob.dec import wsgify
from webob.exc import *
from webob import Response
#from pprint import pprint
try:
import simplejson as json
from simplejson import JSONDecodeError
except ImportError:
import json
JSONDecodeError = ValueError
import os, logging
import anki.consts
import anki.lang
from anki.lang import _ as t
__all__ = ['RestApp', 'RestHandlerBase', 'noReturnValue']
def noReturnValue(func):
func.hasReturnValue = False
return func
class RestHandlerBase(object):
"""Parent class for a handler group."""
hasReturnValue = True
class _RestHandlerWrapper(RestHandlerBase):
"""Wrapper for functions that we can't modify."""
def __init__(self, func_name, func, hasReturnValue=True):
self.func_name = func_name
self.func = func
self.hasReturnValue = hasReturnValue
def __call__(self, *args, **kw):
return self.func(*args, **kw)
class RestHandlerRequest(object):
def __init__(self, app, data, ids, session):
self.app = app
self.data = data
self.ids = ids
self.session = session
def copy(self):
return RestHandlerRequest(self.app, self.data.copy(), self.ids[:], self.session)
def __eq__(self, other):
return self.app == other.app and self.data == other.data and self.ids == other.ids and self.session == other.session
class RestApp(object):
"""A WSGI app that implements RESTful operations on Collections, Decks and Cards."""
# Defines not only the valid handler types, but their position in the URL string
handler_types = ['collection', ['model', 'note', 'deck', 'card']]
def __init__(self, data_root, allowed_hosts='*', setup_new_collection=None, use_default_handlers=True, collection_manager=None):
from AnkiServer.threading import getCollectionManager
self.data_root = os.path.abspath(data_root)
self.allowed_hosts = allowed_hosts
self.setup_new_collection = setup_new_collection
if collection_manager is not None:
self.collection_manager = collection_manager
else:
self.collection_manager = getCollectionManager()
self.handlers = {}
for type_list in self.handler_types:
if type(type_list) is not list:
type_list = [type_list]
for handler_type in type_list:
self.handlers[handler_type] = {}
if use_default_handlers:
self.add_handler_group('collection', CollectionHandler())
self.add_handler_group('note', NoteHandler())
self.add_handler_group('model', ModelHandler())
self.add_handler_group('deck', DeckHandler())
self.add_handler_group('card', CardHandler())
# hold per collection session data
self.sessions = {}
def add_handler(self, type, name, handler):
"""Adds a callback handler for a type (collection, deck, card) with a unique name.
- 'type' is the item that will be worked on, for example: collection, deck, and card.
- 'name' is a unique name for the handler that gets used in the URL.
- 'handler' is a callable that takes (collection, data, ids).
"""
if self.handlers[type].has_key(name):
raise "Handler already for %(type)s/%(name)s exists!"
self.handlers[type][name] = handler
def add_handler_group(self, type, group):
"""Adds several handlers for every public method on an object descended from RestHandlerBase.
This allows you to create a single class with several methods, so that you can quickly
create a group of related handlers."""
import inspect
for name, method in inspect.getmembers(group, predicate=inspect.ismethod):
if not name.startswith('_'):
if hasattr(group, 'hasReturnValue') and not hasattr(method, 'hasReturnValue'):
method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue)
self.add_handler(type, name, method)
def execute_handler(self, type, name, col, req):
"""Executes the handler with the given type and name, passing in the col and req as arguments."""
handler, hasReturnValue = self._getHandler(type, name)
ret = handler(col, req)
if hasReturnValue:
return ret
def list_collections(self):
"""Returns an array of valid collection names in our self.data_path."""
return [x for x in os.listdir(self.data_root) if os.path.exists(os.path.join(self.data_root, x, 'collection.anki2'))]
def _checkRequest(self, req):
"""Raises an exception if the request isn't allowed or valid for some reason."""
if self.allowed_hosts != '*':
try:
remote_addr = req.headers['X-Forwarded-For']
except KeyError:
remote_addr = req.remote_addr
if remote_addr != self.allowed_hosts:
raise HTTPForbidden()
if req.method != 'POST':
raise HTTPMethodNotAllowed(allow=['POST'])
def _parsePath(self, path):
"""Takes a request path and returns a tuple containing the handler type, name
and a list of ids.
Raises an HTTPNotFound exception if the path is invalid."""
if path in ('', '/'):
raise HTTPNotFound()
# split the URL into a list of parts
if path[0] == '/':
path = path[1:]
parts = path.split('/')
# pull the type and context from the URL parts
handler_type = None
ids = []
for type_list in self.handler_types:
if len(parts) == 0:
break
# some URL positions can have multiple types
if type(type_list) is not list:
type_list = [type_list]
# get the handler_type
if parts[0] not in type_list:
break
handler_type = parts.pop(0)
# add the id to the id list
if len(parts) > 0:
ids.append(parts.pop(0))
# break if we don't have enough parts to make a new type/id pair
if len(parts) < 2:
break
# sanity check to make sure the URL is valid
if len(parts) > 1 or len(ids) == 0:
raise HTTPNotFound()
# get the handler name
if len(parts) == 0:
name = 'index'
else:
name = parts[0]
return (handler_type, name, ids)
def _getCollectionPath(self, collection_id):
"""Returns the path to the collection based on the collection_id from the request.
Raises HTTPBadRequest if the collection_id is invalid."""
path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2'))
if path[0:len(self.data_root)] != self.data_root:
# attempting to escape our data jail!
raise HTTPBadRequest('"%s" is not a valid collection' % collection_id)
return path
def _getHandler(self, type, name):
"""Returns a tuple containing handler function for this type and name, and a boolean flag
if that handler has a return value.
Raises an HTTPNotFound exception if the handler doesn't exist."""
# get the handler function
try:
handler = self.handlers[type][name]
except KeyError:
raise HTTPNotFound()
# get if we have a return value
hasReturnValue = True
if hasattr(handler, 'hasReturnValue'):
hasReturnValue = handler.hasReturnValue
return (handler, hasReturnValue)
def _parseRequestBody(self, req):
"""Parses the request body (JSON) into a Python dict and returns it.
Raises an HTTPBadRequest exception if the request isn't valid JSON."""
try:
data = json.loads(req.body)
except JSONDecodeError, e:
logging.error(req.path+': Unable to parse JSON: '+str(e), exc_info=True)
raise HTTPBadRequest()
# fix for a JSON encoding 'quirk' in PHP
if type(data) == list and len(data) == 0:
data = {}
# make the keys into non-unicode strings
data = dict([(str(k), v) for k, v in data.items()])
return data
@wsgify
def __call__(self, req):
# make sure the request is valid
self._checkRequest(req)
if req.path == '/list_collections':
return Response(json.dumps(self.list_collections()), content_type='application/json')
# parse the path
type, name, ids = self._parsePath(req.path)
# get the collection path
collection_path = self._getCollectionPath(ids[0])
print collection_path
# get the handler function
handler, hasReturnValue = self._getHandler(type, name)
# parse the request body
data = self._parseRequestBody(req)
# get the users session
try:
session = self.sessions[ids[0]]
except KeyError:
session = self.sessions[ids[0]] = {}
# debug
from pprint import pprint
pprint(data)
# run it!
col = self.collection_manager.get_collection(collection_path, self.setup_new_collection)
handler_request = RestHandlerRequest(self, data, ids, session)
try:
output = col.execute(handler, [handler_request], {}, hasReturnValue)
except Exception, e:
logging.error(e)
return HTTPInternalServerError()
if output is None:
return Response('', content_type='text/plain')
else:
return Response(json.dumps(output), content_type='application/json')
class CollectionHandler(RestHandlerBase):
"""Default handler group for 'collection' type."""
#
# MODELS - Store fields definitions and templates for notes
#
def list_models(self, col, req):
# This is already a list of dicts, so it doesn't need to be serialized
return col.models.all()
def find_model_by_name(self, col, req):
# This is already a list of dicts, so it doesn't need to be serialized
return col.models.byName(req.data['model'])
#
# NOTES - Information (in fields per the model) that can generate a card
# (based on a template from the model).
#
def find_notes(self, col, req):
query = req.data.get('query', '')
ids = col.findNotes(query)
if req.data.get('preload', False):
notes = [NoteHandler._serialize(col.getNote(id)) for id in ids]
else:
notes = [{'id': id} for id in ids]
return notes
def latest_notes(self, col, req):
# TODO: use SQLAlchemy objects to do this
sql = "SELECT n.id FROM notes AS n";
args = []
if req.data.has_key('updated_since'):
sql += ' WHERE n.mod > ?'
args.append(req.data['updated_since'])
sql += ' ORDER BY n.mod DESC'
sql += ' LIMIT ' + str(req.data.get('limit', 10))
ids = col.db.list(sql, *args)
if req.data.get('preload', False):
notes = [NoteHandler._serialize(col.getNote(id)) for id in ids]
else:
notes = [{'id': id} for id in ids]
return notes
@noReturnValue
def add_note(self, col, req):
from anki.notes import Note
# TODO: I think this would be better with 'model' for the name
# and 'mid' for the model id.
if type(req.data['model']) in (str, unicode):
model = col.models.byName(req.data['model'])
else:
model = col.models.get(req.data['model'])
note = Note(col, model)
for name, value in req.data['fields'].items():
note[name] = value
if req.data.has_key('tags'):
note.setTagsFromStr(req.data['tags'])
col.addNote(note)
def list_tags(self, col, req):
return col.tags.all()
#
# DECKS - Groups of cards
#
def list_decks(self, col, req):
# This is already a list of dicts, so it doesn't need to be serialized
return col.decks.all()
@noReturnValue
def select_deck(self, col, req):
col.decks.select(req.data['deck_id'])
dyn_modes = {
'random': anki.consts.DYN_RANDOM,
'added': anki.consts.DYN_ADDED,
'due': anki.consts.DYN_DUE,
}
def create_dynamic_deck(self, col, req):
name = req.data.get('name', t('Custom Study Session'))
deck = col.decks.byName(name)
if deck:
if not deck['dyn']:
raise HTTPBadRequest("There is an existing non-dynamic deck with the name %s" % name)
# safe to empty it because it's a dynamic deck
# TODO: maybe this should be an option?
col.sched.emptyDyn(deck['id'])
else:
deck = col.decks.get(col.decks.newDyn(name))
query = req.data.get('query', '')
count = int(req.data.get('count', 100))
mode = req.data.get('mode', 'random')
try:
mode = self.dyn_modes[mode]
except KeyError:
raise HTTPBadRequest("Unknown mode: %s" % mode)
deck['terms'][0] = [query, count, mode]
if mode != anki.consts.DYN_RANDOM:
deck['resched'] = True
else:
deck['resched'] = False
if not col.sched.rebuildDyn(deck['id']):
raise HTTPBadRequest("No cards matched the criteria you provided")
col.decks.save(deck)
col.sched.reset()
return deck
#
# CARD - A specific card in a deck with a history of review (generated from
# a note based on the template).
#
def find_cards(self, col, req):
query = req.data.get('query', '')
order = req.data.get('order', False)
ids = anki.find.Finder(col).findCards(query, order)
if req.data.get('preload', False):
cards = [CardHandler._serialize(col.getCard(id), req.data) for id in ids]
else:
cards = [{'id': id} for id in ids]
return cards
def latest_cards(self, col, req):
# TODO: use SQLAlchemy objects to do this
sql = "SELECT c.id FROM notes AS n INNER JOIN cards AS c ON c.nid = n.id";
args = []
if req.data.has_key('updated_since'):
sql += ' WHERE n.mod > ?'
args.append(req.data['updated_since'])
sql += ' ORDER BY n.mod DESC'
sql += ' LIMIT ' + str(req.data.get('limit', 10))
ids = col.db.list(sql, *args)
if req.data.get('preload', False):
cards = [CardHandler._serialize(col.getCard(id), req.data) for id in ids]
else:
cards = [{'id': id} for id in ids]
return cards
#
# SCHEDULER - Controls card review, ie. intervals, what cards are due, answering a card, etc.
#
def reset_scheduler(self, col, req):
if req.data.has_key('deck'):
deck = DeckHandler._get_deck(col, req.data['deck'])
col.decks.select(deck['id'])
col.sched.reset()
counts = col.sched.counts()
return {
'new_cards': counts[0],
'learning_cards': counts[1],
'review_cards': counts[1],
}
def extend_scheduler_limits(self, col, req):
new_cards = int(req.data.get('new_cards', 0))
review_cards = int(req.data.get('review_cards', 0))
col.sched.extendLimits(new_cards, review_cards)
col.sched.reset()
button_labels = ['Easy', 'Good', 'Hard']
def _get_answer_buttons(self, col, card):
l = []
# Put the correct number of buttons
cnt = col.sched.answerButtons(card)
for idx in range(0, cnt - 1):
l.append(self.button_labels[idx])
l.append('Again')
l.reverse()
# Loop through and add the ease, estimated time (in seconds) and other info
return [{
'ease': ease,
'label': label,
'string_label': t(label),
'interval': col.sched.nextIvl(card, ease),
'string_interval': col.sched.nextIvlStr(card, ease),
} for ease, label in enumerate(l, 1)]
def next_card(self, col, req):
if req.data.has_key('deck'):
deck = DeckHandler._get_deck(col, req.data['deck'])
col.decks.select(deck['id'])
card = col.sched.getCard()
if card is None:
return None
# put it into the card cache to be removed when we answer it
#if not req.session.has_key('cards'):
# req.session['cards'] = {}
#req.session['cards'][long(card.id)] = card
card.startTimer()
result = CardHandler._serialize(card, req.data)
result['answer_buttons'] = self._get_answer_buttons(col, card)
return result
@noReturnValue
def answer_card(self, col, req):
import time
card_id = long(req.data['id'])
ease = int(req.data['ease'])
card = col.getCard(card_id)
if req.data.has_key('timerStarted'):
card.timerStarted = float(req.data['timerStarted'])
else:
card.timerStarted = time.time()
col.sched.answerCard(card, ease)
@noReturnValue
def suspend_cards(self, col, req):
card_ids = req.data['ids']
col.sched.suspendCards(card_ids)
@noReturnValue
def unsuspend_cards(self, col, req):
card_ids = req.data['ids']
col.sched.unsuspendCards(card_ids)
def cards_recent_ease(self, col, req):
"""Returns the most recent ease for each card."""
# TODO: Use sqlalchemy to build this query!
sql = "SELECT r.cid, r.ease, r.id FROM revlog AS r INNER JOIN (SELECT cid, MAX(id) AS id FROM revlog GROUP BY cid) AS q ON r.cid = q.cid AND r.id = q.id"
where = []
if req.data.has_key('ids'):
where.append('ids IN (' + (','.join(["'%s'" % x for x in req.data['ids']])) + ')')
if len(where) > 0:
sql += ' WHERE ' + ' AND '.join(where)
result = []
for r in col.db.all(sql):
result.append({'id': r[0], 'ease': r[1], 'timestamp': int(r[2] / 1000)})
return result
def latest_revlog(self, col, req):
"""Returns recent entries from the revlog."""
# TODO: Use sqlalchemy to build this query!
sql = "SELECT r.id, r.ease, r.cid, r.usn, r.ivl, r.lastIvl, r.factor, r.time, r.type FROM revlog AS r"
args = []
if req.data.has_key('updated_since'):
sql += ' WHERE r.id > ?'
args.append(long(req.data['updated_since']) * 1000)
sql += ' ORDER BY r.id DESC'
sql += ' LIMIT ' + str(req.data.get('limit', 100))
revlog = col.db.all(sql, *args)
return [{
'id': r[0],
'ease': r[1],
'timestamp': int(r[0] / 1000),
'card_id': r[2],
'usn': r[3],
'interval': r[4],
'last_interval': r[5],
'factor': r[6],
'time': r[7],
'type': r[8],
} for r in revlog]
stats_reports = {
'today': 'todayStats',
'due': 'dueGraph',
'reps': 'repsGraph',
'interval': 'ivlGraph',
'hourly': 'hourGraph',
'ease': 'easeGraph',
'card': 'cardGraph',
'footer': 'footer',
}
stats_reports_order = ['today', 'due', 'reps', 'interval', 'hourly', 'ease', 'card', 'footer']
def stats_report(self, col, req):
import anki.stats
import re
stats = anki.stats.CollectionStats(col)
stats.width = int(req.data.get('width', 600))
stats.height = int(req.data.get('height', 200))
reports = req.data.get('reports', self.stats_reports_order)
include_css = req.data.get('include_css', False)
include_jquery = req.data.get('include_jquery', False)
include_flot = req.data.get('include_flot', False)
if include_css:
from anki.statsbg import bg
html = stats.css % bg
else:
html = ''
for name in reports:
if not self.stats_reports.has_key(name):
raise HTTPBadRequest("Unknown report name: %s" % name)
func = getattr(stats, self.stats_reports[name])
html += '<div class="anki-graph anki-graph-%s">' % name
html += func()
html += '</div>'
# fix an error in some inline styles
# TODO: submit a patch to Anki!
html = re.sub(r'style="width:([0-9\.]+); height:([0-9\.]+);"', r'style="width:\1px; height: \2px;"', html)
html = re.sub(r'-webkit-transform: ([^;]+);', r'-webkit-transform: \1; -moz-transform: \1; -ms-transform: \1; -o-transform: \1; transform: \1;', html)
scripts = []
if include_jquery or include_flot:
import anki.js
if include_jquery:
scripts.append(anki.js.jquery)
if include_flot:
scripts.append(anki.js.plot)
if len(scripts) > 0:
html = "<script>%s\n</script>" % ''.join(scripts) + html
return html
#
# GLOBAL / MISC
#
@noReturnValue
def set_language(self, col, req):
anki.lang.setLang(req.data['code'])
class ImportExportHandler(RestHandlerBase):
"""Handler group for the 'collection' type, but it's not added by default."""
def _get_filedata(self, data):
import urllib2
if data.has_key('data'):
return data['data']
fd = None
try:
fd = urllib2.urlopen(data['url'])
filedata = fd.read()
finally:
if fd is not None:
fd.close()
return filedata
def _get_importer_class(self, data):
filetype = data['filetype']
from AnkiServer.importer import get_importer_class
importer_class = get_importer_class(filetype)
if importer_class is None:
raise HTTPBadRequest("Unknown filetype '%s'" % filetype)
return importer_class
def import_file(self, col, req):
import AnkiServer.importer
import tempfile
# get the importer class
importer_class = self._get_importer_class(req.data)
# get the file data
filedata = self._get_filedata(req.data)
# write the file data to a temporary file
try:
path = None
with tempfile.NamedTemporaryFile('wt', delete=False) as fd:
path = fd.name
fd.write(filedata)
AnkiServer.importer.import_file(importer_class, col, path)
finally:
if path is not None:
os.unlink(path)
class ModelHandler(RestHandlerBase):
"""Default handler group for 'model' type."""
def field_names(self, col, req):
model = col.models.get(req.ids[1])
if model is None:
raise HTTPNotFound()
return col.models.fieldNames(model)
class NoteHandler(RestHandlerBase):
"""Default handler group for 'note' type."""
@staticmethod
def _serialize(note):
d = {
'id': note.id,
'guid': note.guid,
'model': note.model(),
'mid': note.mid,
'mod': note.mod,
'scm': note.scm,
'tags': note.tags,
'string_tags': ' '.join(note.tags),
'fields': {},
'flags': note.flags,
'usn': note.usn,
}
# add all the fields
for name, value in note.items():
d['fields'][name] = value
return d
def index(self, col, req):
note = col.getNote(req.ids[1])
return self._serialize(note)
@noReturnValue
def add_tags(self, col, req):
note = col.getNote(req.ids[1])
for tag in req.data['tags']:
note.addTag(tag)
note.flush()
@noReturnValue
def remove_tags(self, col, req):
note = col.getNote(req.ids[1])
for tag in req.data['tags']:
note.delTag(tag)
note.flush()
class DeckHandler(RestHandlerBase):
"""Default handler group for 'deck' type."""
@staticmethod
def _get_deck(col, val):
try:
did = long(val)
deck = col.decks.get(did, False)
except ValueError:
deck = col.decks.byName(val)
if deck is None:
raise HTTPNotFound('No deck with id or name: ' + str(val))
return deck
def index(self, col, req):
return self._get_deck(col, req.ids[1])
def next_card(self, col, req):
req_copy = req.copy()
req_copy.data['deck'] = req.ids[1]
del req_copy.ids[1]
# forward this to the CollectionHandler
return req.app.execute_handler('collection', 'next_card', col, req_copy)
def get_conf(self, col, req):
# TODO: should probably live in a ConfHandler
return col.decks.confForDid(req.ids[1])
@noReturnValue
def set_update_conf(self, col, req):
data = req.data.copy()
del data['id']
conf = col.decks.confForDid(req.ids[1])
conf = conf.copy()
conf.update(data)
col.decks.updateConf(conf)
class CardHandler(RestHandlerBase):
"""Default handler group for 'card' type."""
@staticmethod
def _serialize(card, opts):
d = {
'id': card.id,
'isEmpty': card.isEmpty(),
'css': card.css(),
'question': card._getQA()['q'],
'answer': card._getQA()['a'],
'did': card.did,
'due': card.due,
'factor': card.factor,
'ivl': card.ivl,
'lapses': card.lapses,
'left': card.left,
'mod': card.mod,
'nid': card.nid,
'odid': card.odid,
'odue': card.odue,
'ord': card.ord,
'queue': card.queue,
'reps': card.reps,
'type': card.type,
'usn': card.usn,
'timerStarted': card.timerStarted,
}
if opts.get('load_note', False):
d['note'] = NoteHandler._serialize(card.col.getNote(card.nid))
if opts.get('load_deck', False):
d['deck'] = card.col.decks.get(card.did)
if opts.get('load_latest_revlog', False):
d['latest_revlog'] = CardHandler._latest_revlog(card.col, card.id)
return d
@staticmethod
def _latest_revlog(col, card_id):
r = col.db.first("SELECT r.id, r.ease FROM revlog AS r WHERE r.cid = ? ORDER BY id DESC LIMIT 1", card_id)
if r:
return {'id': r[0], 'ease': r[1], 'timestamp': int(r[0] / 1000)}
def index(self, col, req):
card = col.getCard(req.ids[1])
return self._serialize(card, req.data)
def _forward_to_note(self, col, req, name):
card = col.getCard(req.ids[1])
req_copy = req.copy()
req_copy.ids[1] = card.nid
return req.app.execute_handler('note', name, col, req)
@noReturnValue
def add_tags(self, col, req):
self._forward_to_note(col, req, 'add_tags')
@noReturnValue
def remove_tags(self, col, req):
self._forward_to_note(col, req, 'remove_tags')
def stats_report(self, col, req):
card = col.getCard(req.ids[1])
return col.cardStats(card)
def latest_revlog(self, col, req):
return self._latest_revlog(col, req.ids[1])
# Our entry point
def make_app(global_conf, **local_conf):
# TODO: we should setup the default language from conf!
# setup the logger
from AnkiServer.utils import setup_logging
setup_logging(local_conf.get('logging.config_file'))
return RestApp(
data_root=local_conf.get('data_root', '.'),
allowed_hosts=local_conf.get('allowed_hosts', '*')
)

View File

@ -1,377 +0,0 @@
from webob.dec import wsgify
from webob.exc import *
from webob import Response
import anki
from anki.sync import HttpSyncServer, CHUNK_SIZE
from anki.db import sqlite
from anki.utils import checksum
import AnkiServer.deck
import MySQLdb
try:
import simplejson as json
except ImportError:
import json
import os, zlib, tempfile, time
def makeArgs(mdict):
d = dict(mdict.items())
# TODO: use password/username/version for something?
for k in ['p','u','v','d']:
if d.has_key(k):
del d[k]
return d
class FileIterable(object):
def __init__(self, fn):
self.fn = fn
def __iter__(self):
return FileIterator(self.fn)
class FileIterator(object):
def __init__(self, fn):
self.fn = fn
self.fo = open(self.fn, 'rb')
self.c = zlib.compressobj()
self.flushed = False
def __iter__(self):
return self
def next(self):
data = self.fo.read(CHUNK_SIZE)
if not data:
if not self.flushed:
self.flushed = True
return self.c.flush()
else:
raise StopIteration
return self.c.compress(data)
def lock_deck(path):
""" Gets exclusive access to this deck path. If there is a DeckThread running on this
deck, this will wait for its current operations to complete before temporarily stopping
it. """
from AnkiServer.deck import thread_pool
if thread_pool.decks.has_key(path):
thread_pool.decks[path].stop_and_wait()
thread_pool.lock(path)
def unlock_deck(path):
""" Release exclusive access to this deck path. """
from AnkiServer.deck import thread_pool
thread_pool.unlock(path)
class SyncAppHandler(HttpSyncServer):
operations = ['summary','applyPayload','finish','createDeck','getOneWayPayload']
def __init__(self):
HttpSyncServer.__init__(self)
def createDeck(self, name):
# The HttpSyncServer.createDeck doesn't return a valid value! This seems to be
# a bug in libanki.sync ...
return self.stuff({"status": "OK"})
def finish(self):
# The HttpSyncServer has no finish() function... I can only assume this is a bug too!
return self.stuff("OK")
class SyncApp(object):
valid_urls = SyncAppHandler.operations + ['getDecks','fullup','fulldown']
def __init__(self, **kw):
self.data_root = os.path.abspath(kw.get('data_root', '.'))
self.base_url = kw.get('base_url', '/')
self.users = {}
# make sure the base_url has a trailing slash
if len(self.base_url) == 0:
self.base_url = '/'
elif self.base_url[-1] != '/':
self.base_url = base_url + '/'
# setup mysql connection
mysql_args = {}
for k, v in kw.items():
if k.startswith('mysql.'):
mysql_args[k[6:]] = v
self.mysql_args = mysql_args
self.conn = None
# get SQL statements
self.sql_check_password = kw.get('sql_check_password')
self.sql_username2dirname = kw.get('sql_username2dirname')
default_libanki_version = '.'.join(anki.version.split('.')[:2])
def user_libanki_version(self, u):
try:
s = self.users[u]['libanki']
except KeyError:
return self.default_libanki_version
parts = s.split('.')
if parts[0] == '1':
if parts[1] == '0':
return '1.0'
elif parts[1] in ('1','2'):
return '1.2'
return self.default_libanki_version
# Mimcs from anki.sync.SyncTools.stuff()
def _stuff(self, data):
return zlib.compress(json.dumps(data))
def _connect_mysql(self):
if self.conn is None and len(self.mysql_args) > 0:
self.conn = MySQLdb.connect(**self.mysql_args)
def _execute_sql(self, sql, args=()):
self._connect_mysql()
try:
cur = self.conn.cursor()
cur.execute(sql, args)
except MySQLdb.OperationalError, e:
if e.args[0] == 2006:
# MySQL server has gone away message
self.conn = None
self._connect_mysql()
cur = self.conn.cursor()
cur.execute(sql, args)
return cur
def check_password(self, username, password):
if len(self.mysql_args) > 0 and self.sql_check_password is not None:
cur = self._execute_sql(self.sql_check_password, (username, password))
row = cur.fetchone()
return row is not None
return True
def username2dirname(self, username):
if len(self.mysql_args) > 0 and self.sql_username2dirname is not None:
cur = self._execute_sql(self.sql_username2dirname, (username,))
row = cur.fetchone()
if row is None:
return None
return str(row[0])
return username
def _getDecks(self, user_path):
decks = {}
if os.path.exists(user_path):
# It is a dict of {'deckName':[modified,lastSync]}
for fn in os.listdir(unicode(user_path, 'utf-8')):
if len(fn) > 5 and fn[-5:] == '.anki':
d = os.path.abspath(os.path.join(user_path, fn))
# For simplicity, we will always open a thread. But this probably
# isn't necessary!
thread = AnkiServer.deck.thread_pool.start(d)
def lookupModifiedLastSync(wrapper):
deck = wrapper.open()
return [deck.modified, deck.lastSync]
res = thread.execute(lookupModifiedLastSync, [thread.wrapper])
# if thread_pool.threads.has_key(d):
# thread = thread_pool.threads[d]
# def lookupModifiedLastSync(wrapper):
# deck = wrapper.open()
# return [deck.modified, deck.lastSync]
# res = thread.execute(lookup, [thread.wrapper])
# else:
# conn = sqlite.connect(d)
# cur = conn.cursor()
# cur.execute("select modified, lastSync from decks")
#
# res = list(cur.fetchone())
#
# cur.close()
# conn.close()
#self.decks[fn[:-5]] = ["%.5f" % x for x in res]
decks[fn[:-5]] = res
# same as HttpSyncServer.getDecks()
return self._stuff({
"status": "OK",
"decks": decks,
"timestamp": time.time(),
})
def _fullup(self, wrapper, infile, version):
wrapper.close()
path = wrapper.path
# DRS: most of this function was graciously copied
# from anki.sync.SyncTools.fullSyncFromServer()
(fd, tmpname) = tempfile.mkstemp(dir=os.getcwd(), prefix="fullsync")
outfile = open(tmpname, 'wb')
decomp = zlib.decompressobj()
while 1:
data = infile.read(CHUNK_SIZE)
if not data:
outfile.write(decomp.flush())
break
outfile.write(decomp.decompress(data))
infile.close()
outfile.close()
os.close(fd)
# if we were successful, overwrite old deck
if os.path.exists(path):
os.unlink(path)
os.rename(tmpname, path)
# reset the deck name
c = sqlite.connect(path)
lastSync = time.time()
if version == '1':
c.execute("update decks set lastSync = ?", [lastSync])
elif version == '2':
c.execute("update decks set syncName = ?, lastSync = ?",
[checksum(path.encode("utf-8")), lastSync])
c.commit()
c.close()
return lastSync
def _stuffedResp(self, data):
return Response(
status='200 OK',
content_type='application/json',
content_encoding='deflate',
body=data)
@wsgify
def __call__(self, req):
if req.path.startswith(self.base_url):
url = req.path[len(self.base_url):]
if url not in self.valid_urls:
raise HTTPNotFound()
# get and check username and password
try:
u = req.str_params.getone('u')
p = req.str_params.getone('p')
except KeyError:
raise HTTPBadRequest('Must pass username and password')
if not self.check_password(u, p):
#raise HTTPBadRequest('Incorrect username or password')
return self._stuffedResp(self._stuff({'status':'invalidUserPass'}))
dirname = self.username2dirname(u)
if dirname is None:
raise HTTPBadRequest('Incorrect username or password')
user_path = os.path.join(self.data_root, dirname)
# get and lock the (optional) deck for this request
d = None
try:
d = unicode(req.str_params.getone('d'), 'utf-8')
# AnkiDesktop actually passes us the string value 'None'!
if d == 'None':
d = None
except KeyError:
pass
if d is not None:
# get the full deck path name
d = os.path.abspath(os.path.join(user_path, d)+'.anki')
if d[:len(user_path)] != user_path:
raise HTTPBadRequest('Bad deck name')
thread = AnkiServer.deck.thread_pool.start(d)
else:
thread = None
if url == 'getDecks':
# force the version up to 1.2.x
v = req.str_params.getone('libanki')
if v.startswith('0.') or v.startswith('1.0'):
return self._stuffedResp(self._stuff({'status':'oldVersion'}))
# store the data the user passes us keyed with the username. This
# will be used later by SyncAppHandler for version compatibility.
self.users[u] = makeArgs(req.str_params)
return self._stuffedResp(self._getDecks(user_path))
elif url in SyncAppHandler.operations:
handler = SyncAppHandler()
func = getattr(handler, url)
args = makeArgs(req.str_params)
if thread is not None:
# If this is for a specific deck, then it needs to run
# inside of the DeckThread.
def runFunc(wrapper):
handler.deck = wrapper.open()
ret = func(**args)
handler.deck.save()
return ret
runFunc.func_name = url
ret = thread.execute(runFunc, [thread.wrapper])
else:
# Otherwise, we can simply execute it in this thread.
ret = func(**args)
# clean-up user data stored in getDecks
if url == 'finish':
del self.users[u]
return self._stuffedResp(ret)
elif url == 'fulldown':
# set the syncTime before we send it
def setupForSync(wrapper):
wrapper.close()
c = sqlite.connect(d)
lastSync = time.time()
c.execute("update decks set lastSync = ?", [lastSync])
c.commit()
c.close()
thread.execute(setupForSync, [thread.wrapper])
return Response(status='200 OK', content_type='application/octet-stream', content_encoding='deflate', content_disposition='attachment; filename="'+os.path.basename(d).encode('utf-8')+'"', app_iter=FileIterable(d))
elif url == 'fullup':
#version = self.user_libanki_version(u)
try:
version = req.str_params.getone('v')
except KeyError:
version = '1'
infile = req.str_params['deck'].file
lastSync = thread.execute(self._fullup, [thread.wrapper, infile, version])
# append the 'lastSync' value for libanki 1.1 and 1.2
if version == '2':
body = 'OK '+str(lastSync)
else:
body = 'OK'
return Response(status='200 OK', content_type='application/text', body=body)
return Response(status='200 OK', content_type='text/plain', body='Anki Server')
# Our entry point
def make_app(global_conf, **local_conf):
return SyncApp(**local_conf)
def main():
from wsgiref.simple_server import make_server
ankiserver = DeckApp('.', '/sync/')
httpd = make_server('', 8001, ankiserver)
try:
httpd.serve_forever()
except KeyboardInterrupt:
print "Exiting ..."
finally:
AnkiServer.deck.thread_pool.shutdown()
if __name__ == '__main__': main()

View File

@ -1,100 +0,0 @@
from anki.importing.csvfile import TextImporter
from anki.importing.apkg import AnkiPackageImporter
from anki.importing.anki1 import Anki1Importer
from anki.importing.supermemo_xml import SupermemoXmlImporter
from anki.importing.mnemo import MnemosyneImporter
from anki.importing.pauker import PaukerImporter
__all__ = ['get_importer_class', 'import_file']
importers = {
'text': TextImporter,
'apkg': AnkiPackageImporter,
'anki1': Anki1Importer,
'supermemo_xml': SupermemoXmlImporter,
'mnemosyne': MnemosyneImporter,
'pauker': PaukerImporter,
}
def get_importer_class(type):
global importers
return importers.get(type)
def import_file(importer_class, col, path, allow_update = False):
importer = importer_class(col, path)
if allow_update:
importer.allowUpdate = True
if importer.needMapper:
importer.open()
importer.run()
#
# Monkey patch anki.importing.anki2 to support updating existing notes.
# TODO: submit a patch to Anki!
#
def _importNotes(self):
# build guid -> (id,mod,mid) hash & map of existing note ids
self._notes = {}
existing = {}
for id, guid, mod, mid in self.dst.db.execute(
"select id, guid, mod, mid from notes"):
self._notes[guid] = (id, mod, mid)
existing[id] = True
# we may need to rewrite the guid if the model schemas don't match,
# so we need to keep track of the changes for the card import stage
self._changedGuids = {}
# iterate over source collection
add = []
dirty = []
usn = self.dst.usn()
dupes = 0
for note in self.src.db.execute(
"select * from notes"):
# turn the db result into a mutable list
note = list(note)
shouldAdd = self._uniquifyNote(note)
if shouldAdd:
# ensure id is unique
while note[0] in existing:
note[0] += 999
existing[note[0]] = True
# bump usn
note[4] = usn
# update media references in case of dupes
note[6] = self._mungeMedia(note[MID], note[6])
add.append(note)
dirty.append(note[0])
# note we have the added the guid
self._notes[note[GUID]] = (note[0], note[3], note[MID])
else:
dupes += 1
# update existing note
newer = note[3] > mod
if self.allowUpdate and self._mid(mid) == mid and newer:
localNid = self._notes[note[GUID]][0]
note[0] = localNid
note[4] = usn
add.append(note)
dirty.append(note[0])
if dupes:
self.log.append(_("Already in collection: %s.") % (ngettext(
"%d note", "%d notes", dupes) % dupes))
# add to col
self.dst.db.executemany(
"insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)",
add)
self.dst.updateFieldCache(dirty)
self.dst.tags.registerNotes(dirty)
from anki.importing.anki2 import Anki2Importer, MID, GUID
from anki.lang import _, ngettext
Anki2Importer._importNotes = _importNotes
Anki2Importer.allowUpdate = False

View File

@ -1,96 +0,0 @@
import logging
import logging.handlers
import types
# The SMTPHandler taken from python 2.6
class SMTPHandler(logging.Handler):
"""
A handler class which sends an SMTP email for each logging event.
"""
def __init__(self, mailhost, fromaddr, toaddrs, subject, credentials=None):
"""
Initialize the handler.
Initialize the instance with the from and to addresses and subject
line of the email. To specify a non-standard SMTP port, use the
(host, port) tuple format for the mailhost argument. To specify
authentication credentials, supply a (username, password) tuple
for the credentials argument.
"""
logging.Handler.__init__(self)
if type(mailhost) == types.TupleType:
self.mailhost, self.mailport = mailhost
else:
self.mailhost, self.mailport = mailhost, None
if type(credentials) == types.TupleType:
self.username, self.password = credentials
else:
self.username = None
self.fromaddr = fromaddr
if type(toaddrs) == types.StringType:
toaddrs = [toaddrs]
self.toaddrs = toaddrs
self.subject = subject
def getSubject(self, record):
"""
Determine the subject for the email.
If you want to specify a subject line which is record-dependent,
override this method.
"""
return self.subject
weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
monthname = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
def date_time(self):
"""
Return the current date and time formatted for a MIME header.
Needed for Python 1.5.2 (no email package available)
"""
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(time.time())
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
self.weekdayname[wd],
day, self.monthname[month], year,
hh, mm, ss)
return s
def emit(self, record):
"""
Emit a record.
Format the record and send it to the specified addressees.
"""
try:
import smtplib
try:
from email.utils import formatdate
except ImportError:
formatdate = self.date_time
port = self.mailport
if not port:
port = smtplib.SMTP_PORT
smtp = smtplib.SMTP(self.mailhost, port)
msg = self.format(record)
msg = "From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\n\r\n%s" % (
self.fromaddr,
string.join(self.toaddrs, ","),
self.getSubject(record),
formatdate(), msg)
if self.username:
smtp.login(self.username, self.password)
smtp.sendmail(self.fromaddr, self.toaddrs, msg)
smtp.quit()
except (KeyboardInterrupt, SystemExit):
raise
except:
self.handleError(record)
# Monkey patch logging.handlers
logging.handlers.SMTPHandler = SMTPHandler

View File

@ -1,3 +1,4 @@
from ConfigParser import SafeConfigParser
from webob.dec import wsgify
from webob.exc import *
@ -94,18 +95,14 @@ class SyncUserSession(object):
class SyncApp(object):
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks']
def __init__(self, **kw):
from AnkiServer.threading import getCollectionManager
def __init__(self, config):
from AnkiServer.thread import getCollectionManager
self.data_root = os.path.abspath(kw.get('data_root', '.'))
self.base_url = kw.get('base_url', '/')
self.auth_db_path = os.path.abspath(kw.get('auth_db_path', '.'))
self.data_root = os.path.abspath(config.get("sync_app", "data_root"))
self.base_url = config.get("sync_app", "base_url")
self.sessions = {}
try:
self.collection_manager = kw['collection_manager']
except KeyError:
self.collection_manager = getCollectionManager()
self.collection_manager = getCollectionManager()
# make sure the base_url has a trailing slash
if len(self.base_url) == 0:
@ -297,6 +294,11 @@ class SyncApp(object):
return Response(status='200 OK', content_type='text/plain', body='Anki Sync Server')
class DatabaseAuthSyncApp(SyncApp):
def __init__(self, config):
SyncApp.__init__(self, config)
self.auth_db_path = os.path.abspath(config.get("sync_app", "auth_db_path"))
def authenticate(self, username, password):
"""Returns True if this username is allowed to connect with this password. False otherwise."""
@ -308,6 +310,8 @@ class DatabaseAuthSyncApp(SyncApp):
db_ret = cursor.fetchone()
conn.close()
if db_ret != None:
db_hash = str(db_ret[0])
salt = db_hash[-16:]
@ -317,16 +321,16 @@ class DatabaseAuthSyncApp(SyncApp):
return (db_ret != None and hashobj.hexdigest()+salt == db_hash)
# Our entry point
def make_app(global_conf, **local_conf):
return DatabaseAuthSyncApp(**local_conf)
def main():
from wsgiref.simple_server import make_server
from AnkiServer.threading import shutdown
from AnkiServer.thread import shutdown
config = SafeConfigParser()
config.read("production.ini")
ankiserver = DatabaseAuthSyncApp(config)
httpd = make_server('', config.getint("sync_app", "port"), ankiserver)
ankiserver = SyncApp()
httpd = make_server('', 8001, ankiserver)
try:
print "Starting..."
httpd.serve_forever()

View File

@ -1,18 +0,0 @@
def setup_logging(config_file=None):
"""Setup logging based on a config_file."""
import logging
if config_file is not None:
# monkey patch the logging.config.SMTPHandler if necessary
import sys
if sys.version_info[0] == 2 and sys.version_info[1] == 5:
import AnkiServer.logpatch
# load the config file
import logging.config
logging.config.fileConfig(config_file)
else:
logging.getLogger().setLevel(logging.INFO)

View File

@ -1,34 +1,6 @@
[server:main]
#use = egg:Paste#http
use = egg:AnkiServer#server
[sync_app]
host = 127.0.0.1
port = 27701
[filter-app:main]
use = egg:Paste#translogger
next = real
[app:real]
use = egg:Paste#urlmap
/collection = rest_app
/sync = sync_app
[app:rest_app]
use = egg:AnkiServer#rest_app
data_root = ./collections
allowed_hosts = 127.0.0.1
logging.config_file = logging.conf
[app:sync_app]
use = egg:AnkiServer#sync_app
data_root = ./collections
base_url = /sync/
auth_db_path = ./auth.db
mysql.host = 127.0.0.1
mysql.user = db_user
mysql.passwd = db_password
mysql.db = db
sql_check_password = SELECT uid FROM users WHERE name=%s AND pass=MD5(%s) AND status=1
sql_username2dirname = SELECT uid AS dirname FROM users WHERE name=%s

View File

@ -1,41 +0,0 @@
[loggers]
keys=root
[handlers]
keys=screen,file,email
[formatters]
keys=normal,email
[logger_root]
level=INFO
handlers=screen
#handlers=file
#handlers=file,email
[handler_file]
class=FileHandler
formatter=normal
args=('server.log','a')
[handler_screen]
class=StreamHandler
level=NOTSET
formatter=normal
args=(sys.stdout,)
[handler_email]
class=handlers.SMTPHandler
level=ERROR
formatter=email
args=('smtp.example.com', 'support@example.com', ['support_guy1@example.com', 'support_guy2@example.com'], 'AnkiServer error', ('smtp_user', 'smtp_password'))
[formatter_normal]
format=%(asctime)s:%(name)s:%(levelname)s:%(message)s
datefmt=
[formatter_email]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
datefmt=

View File

@ -1,23 +0,0 @@
from setuptools import setup, find_packages
setup(
name="AnkiServer",
version="2.0.0a1",
description="A personal Anki sync server (so you can sync against your own server rather than AnkiWeb)",
author="David Snopek",
author_email="dsnopek@gmail.com",
install_requires=["PasteDeploy>=1.3.2"],
# TODO: should these really be in install_requires?
requires=["webob(>=0.9.7)"],
test_suite='nose.collector',
entry_points="""
[paste.app_factory]
sync_app = AnkiServer.apps.sync_app:make_app
rest_app = AnkiServer.apps.rest_app:make_app
[paste.server_runner]
server = AnkiServer:server_runner
""",
)

View File

@ -1,12 +0,0 @@
[program:anki-server]
command=/usr/local/bin/paster serve production.ini
directory=/var/anki-server
user=www-data
autostart=true
autorestart=true
redirect_stderr=true
; Sometimes necessary if Anki is complaining about a UTF-8 locale. Make sure
; that the local you pick is actually installed on your system.
;environment=LANG=en_US.UTF-8,LC_ALL=en_US.UTF-8,LC_LANG=en_US.UTF-8

View File

@ -1,140 +0,0 @@
import os
import shutil
import tempfile
import unittest
import mock
from mock import MagicMock, sentinel
import AnkiServer
from AnkiServer.collection import CollectionWrapper, CollectionManager
class CollectionWrapperTest(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
def tearDown(self):
shutil.rmtree(self.temp_dir)
def test_lifecycle_real(self):
"""Testing common life-cycle with existing and non-existant collections. This
test uses the real Anki objects and actually creates a new collection on disk."""
wrapper = CollectionWrapper(self.collection_path)
self.assertFalse(os.path.exists(self.collection_path))
self.assertFalse(wrapper.opened())
wrapper.open()
self.assertTrue(os.path.exists(self.collection_path))
self.assertTrue(wrapper.opened())
# calling open twice shouldn't break anything
wrapper.open()
wrapper.close()
self.assertTrue(os.path.exists(self.collection_path))
self.assertFalse(wrapper.opened())
# open the same collection again (not a creation)
wrapper = CollectionWrapper(self.collection_path)
self.assertFalse(wrapper.opened())
wrapper.open()
self.assertTrue(wrapper.opened())
wrapper.close()
self.assertFalse(wrapper.opened())
self.assertTrue(os.path.exists(self.collection_path))
def test_del(self):
with mock.patch('anki.storage.Collection') as anki_storage_Collection:
col = anki_storage_Collection.return_value
wrapper = CollectionWrapper(self.collection_path)
wrapper.open()
wrapper = None
col.close.assert_called_with()
def test_setup_func(self):
# Run it when the collection doesn't exist
with mock.patch('anki.storage.Collection') as anki_storage_Collection:
col = anki_storage_Collection.return_value
setup_new_collection = MagicMock()
self.assertFalse(os.path.exists(self.collection_path))
wrapper = CollectionWrapper(self.collection_path, setup_new_collection)
wrapper.open()
anki_storage_Collection.assert_called_with(self.collection_path)
setup_new_collection.assert_called_with(col)
wrapper = None
# Make sure that no collection was actually created
self.assertFalse(os.path.exists(self.collection_path))
# Create a faux collection file
with file(self.collection_path, 'wt') as fd:
fd.write('Collection!')
# Run it when the collection does exist
with mock.patch('anki.storage.Collection'):
setup_new_collection = lambda col: self.fail("Setup function called when collection already exists!")
self.assertTrue(os.path.exists(self.collection_path))
wrapper = CollectionWrapper(self.collection_path, setup_new_collection)
wrapper.open()
anki_storage_Collection.assert_called_with(self.collection_path)
wrapper = None
def test_execute(self):
with mock.patch('anki.storage.Collection') as anki_storage_Collection:
col = anki_storage_Collection.return_value
func = MagicMock()
func.return_value = sentinel.some_object
# check that execute works and auto-creates the collection
wrapper = CollectionWrapper(self.collection_path)
ret = wrapper.execute(func, [1, 2, 3], {'key': 'aoeu'})
self.assertEqual(ret, sentinel.some_object)
anki_storage_Collection.assert_called_with(self.collection_path)
func.assert_called_with(col, 1, 2, 3, key='aoeu')
# check that execute always returns False if waitForReturn=False
func.reset_mock()
ret = wrapper.execute(func, [1, 2, 3], {'key': 'aoeu'}, waitForReturn=False)
self.assertEqual(ret, None)
func.assert_called_with(col, 1, 2, 3, key='aoeu')
class CollectionManagerTest(unittest.TestCase):
def test_lifecycle(self):
with mock.patch('AnkiServer.collection.CollectionManager.collection_wrapper') as CollectionWrapper:
wrapper = MagicMock()
CollectionWrapper.return_value = wrapper
manager = CollectionManager()
# check getting a new collection
ret = manager.get_collection('path1')
CollectionWrapper.assert_called_with(os.path.realpath('path1'), None)
self.assertEqual(ret, wrapper)
# change the return value, so that it would return a new object
new_wrapper = MagicMock()
CollectionWrapper.return_value = new_wrapper
CollectionWrapper.reset_mock()
# get the new wrapper
ret = manager.get_collection('path2')
CollectionWrapper.assert_called_with(os.path.realpath('path2'), None)
self.assertEqual(ret, new_wrapper)
# make sure the wrapper and new_wrapper are different
self.assertNotEqual(wrapper, new_wrapper)
# assert that calling with the first path again, returns the first wrapper
ret = manager.get_collection('path1')
self.assertEqual(ret, wrapper)
manager.shutdown()
wrapper.close.assert_called_with()
new_wrapper.close.assert_called_with()
if __name__ == '__main__':
unittest.main()

View File

@ -1,113 +0,0 @@
import os
import shutil
import tempfile
import unittest
import mock
from mock import MagicMock, sentinel
import AnkiServer
from AnkiServer.importer import get_importer_class, import_file
import anki.storage
# TODO: refactor into some kind of utility
def add_note(col, data):
from anki.notes import Note
model = col.models.byName(data['model'])
note = Note(col, model)
for name, value in data['fields'].items():
note[name] = value
if data.has_key('tags'):
note.setTagsFromStr(data['tags'])
col.addNote(note)
class ImporterTest(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2')
self.collection = anki.storage.Collection(self.collection_path)
def tearDown(self):
self.collection.close()
self.collection = None
shutil.rmtree(self.temp_dir)
# TODO: refactor into a parent class
def add_default_note(self, count=1):
data = {
'model': 'Basic',
'fields': {
'Front': 'The front',
'Back': 'The back',
},
'tags': "Tag1 Tag2",
}
for idx in range(0, count):
add_note(self.collection, data)
self.add_note(data)
def test_resync(self):
from anki.exporting import AnkiPackageExporter
from anki.utils import intTime
# create a new collection with a single note
src_collection = anki.storage.Collection(os.path.join(self.temp_dir, 'src_collection.anki2'))
add_note(src_collection, {
'model': 'Basic',
'fields': {
'Front': 'The front',
'Back': 'The back',
},
'tags': 'Tag1 Tag2',
})
note_id = src_collection.findNotes('')[0]
note = src_collection.getNote(note_id)
self.assertEqual(note.id, note_id)
self.assertEqual(note['Front'], 'The front')
self.assertEqual(note['Back'], 'The back')
# export to an .apkg file
dst1_path = os.path.join(self.temp_dir, 'export1.apkg')
exporter = AnkiPackageExporter(src_collection)
exporter.exportInto(dst1_path)
# import it into the main collection
import_file(get_importer_class('apkg'), self.collection, dst1_path)
# make sure the note exists
note = self.collection.getNote(note_id)
self.assertEqual(note.id, note_id)
self.assertEqual(note['Front'], 'The front')
self.assertEqual(note['Back'], 'The back')
# now we change the source collection and re-export it
note = src_collection.getNote(note_id)
note['Front'] = 'The new front'
note.tags.append('Tag3')
note.flush(intTime()+1)
dst2_path = os.path.join(self.temp_dir, 'export2.apkg')
exporter = AnkiPackageExporter(src_collection)
exporter.exportInto(dst2_path)
# first, import it without allow_update - no change should happen
import_file(get_importer_class('apkg'), self.collection, dst2_path)
note = self.collection.getNote(note_id)
self.assertEqual(note['Front'], 'The front')
self.assertEqual(note.tags, ['Tag1', 'Tag2'])
# now, import it with allow_update=True, so the note should change
import_file(get_importer_class('apkg'), self.collection, dst2_path, allow_update=True)
note = self.collection.getNote(note_id)
self.assertEqual(note['Front'], 'The new front')
self.assertEqual(note.tags, ['Tag1', 'Tag2', 'Tag3'])
if __name__ == '__main__':
unittest.main()

View File

@ -1,618 +0,0 @@
# -*- coding: utf-8 -*-
import os
import shutil
import tempfile
import unittest
import logging
from pprint import pprint
import mock
from mock import MagicMock
import AnkiServer
from AnkiServer.collection import CollectionManager
from AnkiServer.apps.rest_app import RestApp, RestHandlerRequest, CollectionHandler, ImportExportHandler, NoteHandler, ModelHandler, DeckHandler, CardHandler
from webob.exc import *
import anki
import anki.storage
class RestAppTest(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_manager = CollectionManager()
self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager)
# disable all but critical errors!
logging.disable(logging.CRITICAL)
def tearDown(self):
self.collection_manager.shutdown()
self.collection_manager = None
self.rest_app = None
shutil.rmtree(self.temp_dir)
def test_list_collections(self):
os.mkdir(os.path.join(self.temp_dir, 'test1'))
os.mkdir(os.path.join(self.temp_dir, 'test2'))
with open(os.path.join(self.temp_dir, 'test1', 'collection.anki2'), 'wt') as fd:
fd.write('Testing!')
self.assertEqual(self.rest_app.list_collections(), ['test1'])
def test_parsePath(self):
tests = [
('collection/user', ('collection', 'index', ['user'])),
('collection/user/handler', ('collection', 'handler', ['user'])),
('collection/user/note/123', ('note', 'index', ['user', '123'])),
('collection/user/note/123/handler', ('note', 'handler', ['user', '123'])),
('collection/user/deck/name', ('deck', 'index', ['user', 'name'])),
('collection/user/deck/name/handler', ('deck', 'handler', ['user', 'name'])),
#('collection/user/deck/name/card/123', ('card', 'index', ['user', 'name', '123'])),
#('collection/user/deck/name/card/123/handler', ('card', 'handler', ['user', 'name', '123'])),
('collection/user/card/123', ('card', 'index', ['user', '123'])),
('collection/user/card/123/handler', ('card', 'handler', ['user', '123'])),
# the leading slash should make no difference!
('/collection/user', ('collection', 'index', ['user'])),
]
for path, result in tests:
self.assertEqual(self.rest_app._parsePath(path), result)
def test_parsePath_not_found(self):
tests = [
'bad',
'bad/oaeu',
'collection',
'collection/user/handler/bad',
'',
'/',
]
for path in tests:
self.assertRaises(HTTPNotFound, self.rest_app._parsePath, path)
def test_getCollectionPath(self):
def fullpath(collection_id):
return os.path.normpath(os.path.join(self.temp_dir, collection_id, 'collection.anki2'))
# This is simple and straight forward!
self.assertEqual(self.rest_app._getCollectionPath('user'), fullpath('user'))
# These are dangerous - the user is trying to hack us!
dangerous = ['../user', '/etc/passwd', '/tmp/aBaBaB', '/root/.ssh/id_rsa']
for collection_id in dangerous:
self.assertRaises(HTTPBadRequest, self.rest_app._getCollectionPath, collection_id)
def test_getHandler(self):
def handlerOne():
pass
def handlerTwo():
pass
handlerTwo.hasReturnValue = False
self.rest_app.add_handler('collection', 'handlerOne', handlerOne)
self.rest_app.add_handler('deck', 'handlerTwo', handlerTwo)
(handler, hasReturnValue) = self.rest_app._getHandler('collection', 'handlerOne')
self.assertEqual(handler, handlerOne)
self.assertEqual(hasReturnValue, True)
(handler, hasReturnValue) = self.rest_app._getHandler('deck', 'handlerTwo')
self.assertEqual(handler, handlerTwo)
self.assertEqual(hasReturnValue, False)
# try some bad handler names and types
self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'collection', 'nonExistantHandler')
self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'nonExistantType', 'handlerOne')
def test_parseRequestBody(self):
req = MagicMock()
req.body = '{"key":"value"}'
data = self.rest_app._parseRequestBody(req)
self.assertEqual(data, {'key': 'value'})
self.assertEqual(data.keys(), ['key'])
self.assertEqual(type(data.keys()[0]), str)
# test some bad data
req.body = '{aaaaaaa}'
self.assertRaises(HTTPBadRequest, self.rest_app._parseRequestBody, req)
class CollectionTestBase(unittest.TestCase):
"""Parent class for tests that need a collection set up and torn down."""
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
self.collection = anki.storage.Collection(self.collection_path)
self.mock_app = MagicMock()
def tearDown(self):
self.collection.close()
self.collection = None
shutil.rmtree(self.temp_dir)
self.mock_app.reset_mock()
# TODO: refactor into some kind of utility
def add_note(self, data):
from anki.notes import Note
model = self.collection.models.byName(data['model'])
note = Note(self.collection, model)
for name, value in data['fields'].items():
note[name] = value
if data.has_key('tags'):
note.setTagsFromStr(data['tags'])
self.collection.addNote(note)
# TODO: refactor into a parent class
def add_default_note(self, count=1):
data = {
'model': 'Basic',
'fields': {
'Front': 'The front',
'Back': 'The back',
},
'tags': "Tag1 Tag2",
}
for idx in range(0, count):
self.add_note(data)
class CollectionHandlerTest(CollectionTestBase):
def setUp(self):
super(CollectionHandlerTest, self).setUp()
self.handler = CollectionHandler()
def execute(self, name, data):
ids = ['collection_name']
func = getattr(self.handler, name)
req = RestHandlerRequest(self.mock_app, data, ids, {})
return func(self.collection, req)
def test_list_decks(self):
data = {}
ret = self.execute('list_decks', data)
# It contains only the 'Default' deck
self.assertEqual(len(ret), 1)
self.assertEqual(ret[0]['name'], 'Default')
def test_select_deck(self):
data = {'deck_id': '1'}
ret = self.execute('select_deck', data)
self.assertEqual(ret, None);
def test_create_dynamic_deck_simple(self):
self.add_default_note(5)
data = {
'name': 'Dyn deck',
'mode': 'random',
'count': 2,
'query': "deck:\"Default\" (tag:'Tag1' or tag:'Tag2') (-tag:'Tag3')",
}
ret = self.execute('create_dynamic_deck', data)
self.assertEqual(ret['name'], 'Dyn deck')
self.assertEqual(ret['dyn'], True)
cards = self.collection.findCards('deck:"Dyn deck"')
self.assertEqual(len(cards), 2)
def test_list_models(self):
data = {}
ret = self.execute('list_models', data)
# get a sorted name list that we can actually check
names = [model['name'] for model in ret]
names.sort()
# These are the default models created by Anki in a new collection
default_models = [
'Basic',
'Basic (and reversed card)',
'Basic (optional reversed card)',
'Cloze'
]
self.assertEqual(names, default_models)
def test_find_model_by_name(self):
data = {'model': 'Basic'}
ret = self.execute('find_model_by_name', data)
self.assertEqual(ret['name'], 'Basic')
def test_find_notes(self):
ret = self.execute('find_notes', {})
self.assertEqual(ret, [])
# add a note programatically
self.add_default_note()
# get the id for the one note on this collection
note_id = self.collection.findNotes('')[0]
ret = self.execute('find_notes', {})
self.assertEqual(ret, [{'id': note_id}])
ret = self.execute('find_notes', {'query': 'tag:Tag1'})
self.assertEqual(ret, [{'id': note_id}])
ret = self.execute('find_notes', {'query': 'tag:TagX'})
self.assertEqual(ret, [])
ret = self.execute('find_notes', {'preload': True})
self.assertEqual(len(ret), 1)
self.assertEqual(ret[0]['id'], note_id)
self.assertEqual(ret[0]['model']['name'], 'Basic')
def test_add_note(self):
# make sure there are no notes (yet)
self.assertEqual(self.collection.findNotes(''), [])
# add a note programatically
note = {
'model': 'Basic',
'fields': {
'Front': 'The front',
'Back': 'The back',
},
'tags': "Tag1 Tag2",
}
self.execute('add_note', note)
notes = self.collection.findNotes('')
self.assertEqual(len(notes), 1)
note_id = notes[0]
note = self.collection.getNote(note_id)
self.assertEqual(note.model()['name'], 'Basic')
self.assertEqual(note['Front'], 'The front')
self.assertEqual(note['Back'], 'The back')
self.assertEqual(note.tags, ['Tag1', 'Tag2'])
def test_list_tags(self):
ret = self.execute('list_tags', {})
self.assertEqual(ret, [])
self.add_default_note()
ret = self.execute('list_tags', {})
ret.sort()
self.assertEqual(ret, ['Tag1', 'Tag2'])
def test_set_language(self):
import anki.lang
self.assertEqual(anki.lang._('Again'), 'Again')
try:
data = {'code': 'pl'}
self.execute('set_language', data)
self.assertEqual(anki.lang._('Again'), u'Znowu')
finally:
# return everything to normal!
anki.lang.setLang('en')
def test_reset_scheduler(self):
self.add_default_note(3)
ret = self.execute('reset_scheduler', {'deck': 'Default'})
self.assertEqual(ret, {
'new_cards': 3,
'learning_cards': 0,
'review_cards': 0,
})
def test_next_card(self):
ret = self.execute('next_card', {})
self.assertEqual(ret, None)
# add a note programatically
self.add_default_note()
# get the id for the one card and note on this collection
note_id = self.collection.findNotes('')[0]
card_id = self.collection.findCards('')[0]
self.collection.sched.reset()
ret = self.execute('next_card', {})
self.assertEqual(ret['id'], card_id)
self.assertEqual(ret['nid'], note_id)
self.assertEqual(ret['css'], '<style>.card {\n font-family: arial;\n font-size: 20px;\n text-align: center;\n color: black;\n background-color: white;\n}\n</style>')
self.assertEqual(ret['question'], 'The front')
self.assertEqual(ret['answer'], 'The front\n\n<hr id=answer>\n\nThe back')
self.assertEqual(ret['answer_buttons'], [
{'ease': 1,
'label': 'Again',
'string_label': 'Again',
'interval': 60,
'string_interval': '<1 minute'},
{'ease': 2,
'label': 'Good',
'string_label': 'Good',
'interval': 600,
'string_interval': '<10 minutes'},
{'ease': 3,
'label': 'Easy',
'string_label': 'Easy',
'interval': 345600,
'string_interval': '4 days'}])
def test_next_card_translation(self):
# add a note programatically
self.add_default_note()
# get the card in Polish so we can test translation too
anki.lang.setLang('pl')
try:
ret = self.execute('next_card', {})
finally:
anki.lang.setLang('en')
self.assertEqual(ret['answer_buttons'], [
{'ease': 1,
'label': 'Again',
'string_label': u'Znowu',
'interval': 60,
'string_interval': '<1 minuta'},
{'ease': 2,
'label': 'Good',
'string_label': u'Dobra',
'interval': 600,
'string_interval': '<10 minut'},
{'ease': 3,
'label': 'Easy',
'string_label': u'Łatwa',
'interval': 345600,
'string_interval': '4 dni'}])
def test_next_card_five_times(self):
self.add_default_note(5)
for idx in range(0, 5):
ret = self.execute('next_card', {})
self.assertTrue(ret is not None)
def test_answer_card(self):
import time
self.add_default_note()
# instantiate a deck handler to get the card
card = self.execute('next_card', {})
self.assertEqual(card['reps'], 0)
self.execute('answer_card', {'id': card['id'], 'ease': 2, 'timerStarted': time.time()})
# reset the scheduler and try to get the next card again - there should be none!
self.collection.sched.reset()
card = self.execute('next_card', {})
self.assertEqual(card['reps'], 1)
def test_suspend_cards(self):
# add a note programatically
self.add_default_note()
# get the id for the one card on this collection
card_id = self.collection.findCards('')[0]
# suspend it
self.execute('suspend_cards', {'ids': [card_id]})
# test that getting the next card will be None
card = self.collection.sched.getCard()
self.assertEqual(card, None)
# unsuspend it
self.execute('unsuspend_cards', {'ids': [card_id]})
# test that now we're getting the next card!
self.collection.sched.reset()
card = self.collection.sched.getCard()
self.assertEqual(card.id, card_id)
def test_cards_recent_ease(self):
self.add_default_note()
card_id = self.collection.findCards('')[0]
# answer the card
self.collection.reset()
card = self.collection.sched.getCard()
card.startTimer()
# answer multiple times to see that we only get the latest!
self.collection.sched.answerCard(card, 1)
self.collection.sched.answerCard(card, 3)
self.collection.sched.answerCard(card, 2)
# pull the latest revision
ret = self.execute('cards_recent_ease', {})
self.assertEqual(ret[0]['id'], card_id)
self.assertEqual(ret[0]['ease'], 2)
class ImportExportHandlerTest(CollectionTestBase):
export_rows = [
['Card front 1', 'Card back 1', 'Tag1 Tag2'],
['Card front 2', 'Card back 2', 'Tag1 Tag3'],
]
def setUp(self):
super(ImportExportHandlerTest, self).setUp()
self.handler = ImportExportHandler()
def execute(self, name, data):
ids = ['collection_name']
func = getattr(self.handler, name)
req = RestHandlerRequest(self.mock_app, data, ids, {})
return func(self.collection, req)
def generate_text_export(self):
# Create a simple export file
export_data = ''
for row in self.export_rows:
export_data += '\t'.join(row) + '\n'
export_path = os.path.join(self.temp_dir, 'export.txt')
with file(export_path, 'wt') as fd:
fd.write(export_data)
return (export_data, export_path)
def check_import(self):
note_ids = self.collection.findNotes('')
notes = [self.collection.getNote(note_id) for note_id in note_ids]
self.assertEqual(len(notes), len(self.export_rows))
for index, test_data in enumerate(self.export_rows):
self.assertEqual(notes[index]['Front'], test_data[0])
self.assertEqual(notes[index]['Back'], test_data[1])
self.assertEqual(' '.join(notes[index].tags), test_data[2])
def test_import_text_data(self):
(export_data, export_path) = self.generate_text_export()
data = {
'filetype': 'text',
'data': export_data,
}
ret = self.execute('import_file', data)
self.check_import()
def test_import_text_url(self):
(export_data, export_path) = self.generate_text_export()
data = {
'filetype': 'text',
'url': 'file://' + os.path.realpath(export_path),
}
ret = self.execute('import_file', data)
self.check_import()
class NoteHandlerTest(CollectionTestBase):
def setUp(self):
super(NoteHandlerTest, self).setUp()
self.handler = NoteHandler()
def execute(self, name, data, note_id):
ids = ['collection_name', note_id]
func = getattr(self.handler, name)
req = RestHandlerRequest(self.mock_app, data, ids, {})
return func(self.collection, req)
def test_index(self):
self.add_default_note()
note_id = self.collection.findNotes('')[0]
ret = self.execute('index', {}, note_id)
self.assertEqual(ret['id'], note_id)
self.assertEqual(len(ret['fields']), 2)
self.assertEqual(ret['flags'], 0)
self.assertEqual(ret['model']['name'], 'Basic')
self.assertEqual(ret['tags'], ['Tag1', 'Tag2'])
self.assertEqual(ret['string_tags'], 'Tag1 Tag2')
self.assertEqual(ret['usn'], -1)
def test_add_tags(self):
self.add_default_note()
note_id = self.collection.findNotes('')[0]
note = self.collection.getNote(note_id)
self.assertFalse('NT1' in note.tags)
self.assertFalse('NT2' in note.tags)
self.execute('add_tags', {'tags': ['NT1', 'NT2']}, note_id)
note = self.collection.getNote(note_id)
self.assertTrue('NT1' in note.tags)
self.assertTrue('NT2' in note.tags)
def test_remove_tags(self):
self.add_default_note()
note_id = self.collection.findNotes('')[0]
note = self.collection.getNote(note_id)
self.assertTrue('Tag1' in note.tags)
self.assertTrue('Tag2' in note.tags)
self.execute('remove_tags', {'tags': ['Tag1', 'Tag2']}, note_id)
note = self.collection.getNote(note_id)
self.assertFalse('Tag1' in note.tags)
self.assertFalse('Tag2' in note.tags)
class DeckHandlerTest(CollectionTestBase):
def setUp(self):
super(DeckHandlerTest, self).setUp()
self.handler = DeckHandler()
def execute(self, name, data):
ids = ['collection_name', '1']
func = getattr(self.handler, name)
req = RestHandlerRequest(self.mock_app, data, ids, {})
return func(self.collection, req)
def test_index(self):
ret = self.execute('index', {})
#pprint(ret)
self.assertEqual(ret['name'], 'Default')
self.assertEqual(ret['id'], 1)
self.assertEqual(ret['dyn'], False)
def test_next_card(self):
self.mock_app.execute_handler.return_value = None
ret = self.execute('next_card', {})
self.assertEqual(ret, None)
self.mock_app.execute_handler.assert_called_with('collection', 'next_card', self.collection, RestHandlerRequest(self.mock_app, {'deck': '1'}, ['collection_name'], {}))
def test_get_conf(self):
ret = self.execute('get_conf', {})
#pprint(ret)
self.assertEqual(ret['name'], 'Default')
self.assertEqual(ret['id'], 1)
self.assertEqual(ret['dyn'], False)
class CardHandlerTest(CollectionTestBase):
def setUp(self):
super(CardHandlerTest, self).setUp()
self.handler = CardHandler()
def execute(self, name, data, card_id):
ids = ['collection_name', card_id]
func = getattr(self.handler, name)
req = RestHandlerRequest(self.mock_app, data, ids, {})
return func(self.collection, req)
def test_index_simple(self):
self.add_default_note()
note_id = self.collection.findNotes('')[0]
card_id = self.collection.findCards('')[0]
ret = self.execute('index', {}, card_id)
self.assertEqual(ret['id'], card_id)
self.assertEqual(ret['nid'], note_id)
self.assertEqual(ret['did'], 1)
self.assertFalse(ret.has_key('note'))
self.assertFalse(ret.has_key('deck'))
def test_index_load(self):
self.add_default_note()
note_id = self.collection.findNotes('')[0]
card_id = self.collection.findCards('')[0]
ret = self.execute('index', {'load_note': 1, 'load_deck': 1}, card_id)
self.assertEqual(ret['id'], card_id)
self.assertEqual(ret['nid'], note_id)
self.assertEqual(ret['did'], 1)
self.assertEqual(ret['note']['id'], note_id)
self.assertEqual(ret['note']['model']['name'], 'Basic')
self.assertEqual(ret['deck']['name'], 'Default')
if __name__ == '__main__':
unittest.main()

View File

@ -1,9 +0,0 @@
import unittest
import AnkiServer
from AnkiServer.apps.sync_app import SyncApp
class SyncAppTest(unittest.TestCase):
pass