Remove things not needed by sync_app
This commit is contained in:
parent
fd58fcd9ec
commit
d573bf6b42
@ -2,14 +2,14 @@
|
||||
import sys
|
||||
sys.path.insert(0, "/usr/share/anki")
|
||||
|
||||
def server_runner(app, global_conf, **kw):
|
||||
""" Special version of paste.httpserver.server_runner which calls
|
||||
AnkiServer.threading.shutdown() on server exit."""
|
||||
|
||||
from paste.httpserver import server_runner as paste_server_runner
|
||||
from AnkiServer.threading import shutdown
|
||||
try:
|
||||
paste_server_runner(app, global_conf, **kw)
|
||||
finally:
|
||||
shutdown()
|
||||
|
||||
#def server_runner(app, global_conf, **kw):
|
||||
# """ Special version of paste.httpserver.server_runner which calls
|
||||
# AnkiServer.thread.shutdown() on server exit."""
|
||||
#
|
||||
# from paste.httpserver import server_runner as paste_server_runner
|
||||
# from AnkiServer.thread import shutdown
|
||||
# try:
|
||||
# paste_server_runner(app, global_conf, **kw)
|
||||
# finally:
|
||||
# shutdown()
|
||||
#
|
||||
|
||||
@ -1,2 +0,0 @@
|
||||
# package
|
||||
|
||||
@ -1,874 +0,0 @@
|
||||
|
||||
from webob.dec import wsgify
|
||||
from webob.exc import *
|
||||
from webob import Response
|
||||
|
||||
#from pprint import pprint
|
||||
|
||||
try:
|
||||
import simplejson as json
|
||||
from simplejson import JSONDecodeError
|
||||
except ImportError:
|
||||
import json
|
||||
JSONDecodeError = ValueError
|
||||
|
||||
import os, logging
|
||||
|
||||
import anki.consts
|
||||
import anki.lang
|
||||
from anki.lang import _ as t
|
||||
|
||||
__all__ = ['RestApp', 'RestHandlerBase', 'noReturnValue']
|
||||
|
||||
def noReturnValue(func):
|
||||
func.hasReturnValue = False
|
||||
return func
|
||||
|
||||
class RestHandlerBase(object):
|
||||
"""Parent class for a handler group."""
|
||||
hasReturnValue = True
|
||||
|
||||
class _RestHandlerWrapper(RestHandlerBase):
|
||||
"""Wrapper for functions that we can't modify."""
|
||||
def __init__(self, func_name, func, hasReturnValue=True):
|
||||
self.func_name = func_name
|
||||
self.func = func
|
||||
self.hasReturnValue = hasReturnValue
|
||||
def __call__(self, *args, **kw):
|
||||
return self.func(*args, **kw)
|
||||
|
||||
class RestHandlerRequest(object):
|
||||
def __init__(self, app, data, ids, session):
|
||||
self.app = app
|
||||
self.data = data
|
||||
self.ids = ids
|
||||
self.session = session
|
||||
|
||||
def copy(self):
|
||||
return RestHandlerRequest(self.app, self.data.copy(), self.ids[:], self.session)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.app == other.app and self.data == other.data and self.ids == other.ids and self.session == other.session
|
||||
|
||||
class RestApp(object):
|
||||
"""A WSGI app that implements RESTful operations on Collections, Decks and Cards."""
|
||||
|
||||
# Defines not only the valid handler types, but their position in the URL string
|
||||
handler_types = ['collection', ['model', 'note', 'deck', 'card']]
|
||||
|
||||
def __init__(self, data_root, allowed_hosts='*', setup_new_collection=None, use_default_handlers=True, collection_manager=None):
|
||||
from AnkiServer.threading import getCollectionManager
|
||||
|
||||
self.data_root = os.path.abspath(data_root)
|
||||
self.allowed_hosts = allowed_hosts
|
||||
self.setup_new_collection = setup_new_collection
|
||||
|
||||
if collection_manager is not None:
|
||||
self.collection_manager = collection_manager
|
||||
else:
|
||||
self.collection_manager = getCollectionManager()
|
||||
|
||||
self.handlers = {}
|
||||
for type_list in self.handler_types:
|
||||
if type(type_list) is not list:
|
||||
type_list = [type_list]
|
||||
for handler_type in type_list:
|
||||
self.handlers[handler_type] = {}
|
||||
|
||||
if use_default_handlers:
|
||||
self.add_handler_group('collection', CollectionHandler())
|
||||
self.add_handler_group('note', NoteHandler())
|
||||
self.add_handler_group('model', ModelHandler())
|
||||
self.add_handler_group('deck', DeckHandler())
|
||||
self.add_handler_group('card', CardHandler())
|
||||
|
||||
# hold per collection session data
|
||||
self.sessions = {}
|
||||
|
||||
def add_handler(self, type, name, handler):
|
||||
"""Adds a callback handler for a type (collection, deck, card) with a unique name.
|
||||
|
||||
- 'type' is the item that will be worked on, for example: collection, deck, and card.
|
||||
|
||||
- 'name' is a unique name for the handler that gets used in the URL.
|
||||
|
||||
- 'handler' is a callable that takes (collection, data, ids).
|
||||
"""
|
||||
|
||||
if self.handlers[type].has_key(name):
|
||||
raise "Handler already for %(type)s/%(name)s exists!"
|
||||
self.handlers[type][name] = handler
|
||||
|
||||
def add_handler_group(self, type, group):
|
||||
"""Adds several handlers for every public method on an object descended from RestHandlerBase.
|
||||
|
||||
This allows you to create a single class with several methods, so that you can quickly
|
||||
create a group of related handlers."""
|
||||
|
||||
import inspect
|
||||
for name, method in inspect.getmembers(group, predicate=inspect.ismethod):
|
||||
if not name.startswith('_'):
|
||||
if hasattr(group, 'hasReturnValue') and not hasattr(method, 'hasReturnValue'):
|
||||
method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue)
|
||||
self.add_handler(type, name, method)
|
||||
|
||||
def execute_handler(self, type, name, col, req):
|
||||
"""Executes the handler with the given type and name, passing in the col and req as arguments."""
|
||||
|
||||
handler, hasReturnValue = self._getHandler(type, name)
|
||||
ret = handler(col, req)
|
||||
if hasReturnValue:
|
||||
return ret
|
||||
|
||||
def list_collections(self):
|
||||
"""Returns an array of valid collection names in our self.data_path."""
|
||||
return [x for x in os.listdir(self.data_root) if os.path.exists(os.path.join(self.data_root, x, 'collection.anki2'))]
|
||||
|
||||
def _checkRequest(self, req):
|
||||
"""Raises an exception if the request isn't allowed or valid for some reason."""
|
||||
if self.allowed_hosts != '*':
|
||||
try:
|
||||
remote_addr = req.headers['X-Forwarded-For']
|
||||
except KeyError:
|
||||
remote_addr = req.remote_addr
|
||||
if remote_addr != self.allowed_hosts:
|
||||
raise HTTPForbidden()
|
||||
|
||||
if req.method != 'POST':
|
||||
raise HTTPMethodNotAllowed(allow=['POST'])
|
||||
|
||||
def _parsePath(self, path):
|
||||
"""Takes a request path and returns a tuple containing the handler type, name
|
||||
and a list of ids.
|
||||
|
||||
Raises an HTTPNotFound exception if the path is invalid."""
|
||||
|
||||
if path in ('', '/'):
|
||||
raise HTTPNotFound()
|
||||
|
||||
# split the URL into a list of parts
|
||||
if path[0] == '/':
|
||||
path = path[1:]
|
||||
parts = path.split('/')
|
||||
|
||||
# pull the type and context from the URL parts
|
||||
handler_type = None
|
||||
ids = []
|
||||
for type_list in self.handler_types:
|
||||
if len(parts) == 0:
|
||||
break
|
||||
|
||||
# some URL positions can have multiple types
|
||||
if type(type_list) is not list:
|
||||
type_list = [type_list]
|
||||
|
||||
# get the handler_type
|
||||
if parts[0] not in type_list:
|
||||
break
|
||||
handler_type = parts.pop(0)
|
||||
|
||||
# add the id to the id list
|
||||
if len(parts) > 0:
|
||||
ids.append(parts.pop(0))
|
||||
# break if we don't have enough parts to make a new type/id pair
|
||||
if len(parts) < 2:
|
||||
break
|
||||
|
||||
# sanity check to make sure the URL is valid
|
||||
if len(parts) > 1 or len(ids) == 0:
|
||||
raise HTTPNotFound()
|
||||
|
||||
# get the handler name
|
||||
if len(parts) == 0:
|
||||
name = 'index'
|
||||
else:
|
||||
name = parts[0]
|
||||
|
||||
return (handler_type, name, ids)
|
||||
|
||||
def _getCollectionPath(self, collection_id):
|
||||
"""Returns the path to the collection based on the collection_id from the request.
|
||||
|
||||
Raises HTTPBadRequest if the collection_id is invalid."""
|
||||
|
||||
path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2'))
|
||||
if path[0:len(self.data_root)] != self.data_root:
|
||||
# attempting to escape our data jail!
|
||||
raise HTTPBadRequest('"%s" is not a valid collection' % collection_id)
|
||||
|
||||
return path
|
||||
|
||||
def _getHandler(self, type, name):
|
||||
"""Returns a tuple containing handler function for this type and name, and a boolean flag
|
||||
if that handler has a return value.
|
||||
|
||||
Raises an HTTPNotFound exception if the handler doesn't exist."""
|
||||
|
||||
# get the handler function
|
||||
try:
|
||||
handler = self.handlers[type][name]
|
||||
except KeyError:
|
||||
raise HTTPNotFound()
|
||||
|
||||
# get if we have a return value
|
||||
hasReturnValue = True
|
||||
if hasattr(handler, 'hasReturnValue'):
|
||||
hasReturnValue = handler.hasReturnValue
|
||||
|
||||
return (handler, hasReturnValue)
|
||||
|
||||
def _parseRequestBody(self, req):
|
||||
"""Parses the request body (JSON) into a Python dict and returns it.
|
||||
|
||||
Raises an HTTPBadRequest exception if the request isn't valid JSON."""
|
||||
|
||||
try:
|
||||
data = json.loads(req.body)
|
||||
except JSONDecodeError, e:
|
||||
logging.error(req.path+': Unable to parse JSON: '+str(e), exc_info=True)
|
||||
raise HTTPBadRequest()
|
||||
|
||||
# fix for a JSON encoding 'quirk' in PHP
|
||||
if type(data) == list and len(data) == 0:
|
||||
data = {}
|
||||
|
||||
# make the keys into non-unicode strings
|
||||
data = dict([(str(k), v) for k, v in data.items()])
|
||||
|
||||
return data
|
||||
|
||||
@wsgify
|
||||
def __call__(self, req):
|
||||
# make sure the request is valid
|
||||
self._checkRequest(req)
|
||||
|
||||
if req.path == '/list_collections':
|
||||
return Response(json.dumps(self.list_collections()), content_type='application/json')
|
||||
|
||||
# parse the path
|
||||
type, name, ids = self._parsePath(req.path)
|
||||
|
||||
# get the collection path
|
||||
collection_path = self._getCollectionPath(ids[0])
|
||||
print collection_path
|
||||
|
||||
# get the handler function
|
||||
handler, hasReturnValue = self._getHandler(type, name)
|
||||
|
||||
# parse the request body
|
||||
data = self._parseRequestBody(req)
|
||||
|
||||
# get the users session
|
||||
try:
|
||||
session = self.sessions[ids[0]]
|
||||
except KeyError:
|
||||
session = self.sessions[ids[0]] = {}
|
||||
|
||||
# debug
|
||||
from pprint import pprint
|
||||
pprint(data)
|
||||
|
||||
# run it!
|
||||
col = self.collection_manager.get_collection(collection_path, self.setup_new_collection)
|
||||
handler_request = RestHandlerRequest(self, data, ids, session)
|
||||
try:
|
||||
output = col.execute(handler, [handler_request], {}, hasReturnValue)
|
||||
except Exception, e:
|
||||
logging.error(e)
|
||||
return HTTPInternalServerError()
|
||||
|
||||
if output is None:
|
||||
return Response('', content_type='text/plain')
|
||||
else:
|
||||
return Response(json.dumps(output), content_type='application/json')
|
||||
|
||||
class CollectionHandler(RestHandlerBase):
|
||||
"""Default handler group for 'collection' type."""
|
||||
|
||||
#
|
||||
# MODELS - Store fields definitions and templates for notes
|
||||
#
|
||||
|
||||
def list_models(self, col, req):
|
||||
# This is already a list of dicts, so it doesn't need to be serialized
|
||||
return col.models.all()
|
||||
|
||||
def find_model_by_name(self, col, req):
|
||||
# This is already a list of dicts, so it doesn't need to be serialized
|
||||
return col.models.byName(req.data['model'])
|
||||
|
||||
#
|
||||
# NOTES - Information (in fields per the model) that can generate a card
|
||||
# (based on a template from the model).
|
||||
#
|
||||
|
||||
def find_notes(self, col, req):
|
||||
query = req.data.get('query', '')
|
||||
ids = col.findNotes(query)
|
||||
|
||||
if req.data.get('preload', False):
|
||||
notes = [NoteHandler._serialize(col.getNote(id)) for id in ids]
|
||||
else:
|
||||
notes = [{'id': id} for id in ids]
|
||||
|
||||
return notes
|
||||
|
||||
def latest_notes(self, col, req):
|
||||
# TODO: use SQLAlchemy objects to do this
|
||||
sql = "SELECT n.id FROM notes AS n";
|
||||
args = []
|
||||
if req.data.has_key('updated_since'):
|
||||
sql += ' WHERE n.mod > ?'
|
||||
args.append(req.data['updated_since'])
|
||||
sql += ' ORDER BY n.mod DESC'
|
||||
sql += ' LIMIT ' + str(req.data.get('limit', 10))
|
||||
ids = col.db.list(sql, *args)
|
||||
|
||||
if req.data.get('preload', False):
|
||||
notes = [NoteHandler._serialize(col.getNote(id)) for id in ids]
|
||||
else:
|
||||
notes = [{'id': id} for id in ids]
|
||||
|
||||
return notes
|
||||
|
||||
@noReturnValue
|
||||
def add_note(self, col, req):
|
||||
from anki.notes import Note
|
||||
|
||||
# TODO: I think this would be better with 'model' for the name
|
||||
# and 'mid' for the model id.
|
||||
if type(req.data['model']) in (str, unicode):
|
||||
model = col.models.byName(req.data['model'])
|
||||
else:
|
||||
model = col.models.get(req.data['model'])
|
||||
|
||||
note = Note(col, model)
|
||||
for name, value in req.data['fields'].items():
|
||||
note[name] = value
|
||||
|
||||
if req.data.has_key('tags'):
|
||||
note.setTagsFromStr(req.data['tags'])
|
||||
|
||||
col.addNote(note)
|
||||
|
||||
def list_tags(self, col, req):
|
||||
return col.tags.all()
|
||||
|
||||
#
|
||||
# DECKS - Groups of cards
|
||||
#
|
||||
|
||||
def list_decks(self, col, req):
|
||||
# This is already a list of dicts, so it doesn't need to be serialized
|
||||
return col.decks.all()
|
||||
|
||||
@noReturnValue
|
||||
def select_deck(self, col, req):
|
||||
col.decks.select(req.data['deck_id'])
|
||||
|
||||
dyn_modes = {
|
||||
'random': anki.consts.DYN_RANDOM,
|
||||
'added': anki.consts.DYN_ADDED,
|
||||
'due': anki.consts.DYN_DUE,
|
||||
}
|
||||
|
||||
def create_dynamic_deck(self, col, req):
|
||||
name = req.data.get('name', t('Custom Study Session'))
|
||||
deck = col.decks.byName(name)
|
||||
if deck:
|
||||
if not deck['dyn']:
|
||||
raise HTTPBadRequest("There is an existing non-dynamic deck with the name %s" % name)
|
||||
|
||||
# safe to empty it because it's a dynamic deck
|
||||
# TODO: maybe this should be an option?
|
||||
col.sched.emptyDyn(deck['id'])
|
||||
else:
|
||||
deck = col.decks.get(col.decks.newDyn(name))
|
||||
|
||||
query = req.data.get('query', '')
|
||||
count = int(req.data.get('count', 100))
|
||||
mode = req.data.get('mode', 'random')
|
||||
|
||||
try:
|
||||
mode = self.dyn_modes[mode]
|
||||
except KeyError:
|
||||
raise HTTPBadRequest("Unknown mode: %s" % mode)
|
||||
|
||||
deck['terms'][0] = [query, count, mode]
|
||||
|
||||
if mode != anki.consts.DYN_RANDOM:
|
||||
deck['resched'] = True
|
||||
else:
|
||||
deck['resched'] = False
|
||||
|
||||
if not col.sched.rebuildDyn(deck['id']):
|
||||
raise HTTPBadRequest("No cards matched the criteria you provided")
|
||||
|
||||
col.decks.save(deck)
|
||||
col.sched.reset()
|
||||
|
||||
return deck
|
||||
|
||||
#
|
||||
# CARD - A specific card in a deck with a history of review (generated from
|
||||
# a note based on the template).
|
||||
#
|
||||
|
||||
def find_cards(self, col, req):
|
||||
query = req.data.get('query', '')
|
||||
order = req.data.get('order', False)
|
||||
ids = anki.find.Finder(col).findCards(query, order)
|
||||
|
||||
if req.data.get('preload', False):
|
||||
cards = [CardHandler._serialize(col.getCard(id), req.data) for id in ids]
|
||||
else:
|
||||
cards = [{'id': id} for id in ids]
|
||||
|
||||
return cards
|
||||
|
||||
def latest_cards(self, col, req):
|
||||
# TODO: use SQLAlchemy objects to do this
|
||||
sql = "SELECT c.id FROM notes AS n INNER JOIN cards AS c ON c.nid = n.id";
|
||||
args = []
|
||||
if req.data.has_key('updated_since'):
|
||||
sql += ' WHERE n.mod > ?'
|
||||
args.append(req.data['updated_since'])
|
||||
sql += ' ORDER BY n.mod DESC'
|
||||
sql += ' LIMIT ' + str(req.data.get('limit', 10))
|
||||
ids = col.db.list(sql, *args)
|
||||
|
||||
if req.data.get('preload', False):
|
||||
cards = [CardHandler._serialize(col.getCard(id), req.data) for id in ids]
|
||||
else:
|
||||
cards = [{'id': id} for id in ids]
|
||||
|
||||
return cards
|
||||
|
||||
#
|
||||
# SCHEDULER - Controls card review, ie. intervals, what cards are due, answering a card, etc.
|
||||
#
|
||||
|
||||
def reset_scheduler(self, col, req):
|
||||
if req.data.has_key('deck'):
|
||||
deck = DeckHandler._get_deck(col, req.data['deck'])
|
||||
col.decks.select(deck['id'])
|
||||
|
||||
col.sched.reset()
|
||||
counts = col.sched.counts()
|
||||
return {
|
||||
'new_cards': counts[0],
|
||||
'learning_cards': counts[1],
|
||||
'review_cards': counts[1],
|
||||
}
|
||||
|
||||
def extend_scheduler_limits(self, col, req):
|
||||
new_cards = int(req.data.get('new_cards', 0))
|
||||
review_cards = int(req.data.get('review_cards', 0))
|
||||
col.sched.extendLimits(new_cards, review_cards)
|
||||
col.sched.reset()
|
||||
|
||||
button_labels = ['Easy', 'Good', 'Hard']
|
||||
|
||||
def _get_answer_buttons(self, col, card):
|
||||
l = []
|
||||
|
||||
# Put the correct number of buttons
|
||||
cnt = col.sched.answerButtons(card)
|
||||
for idx in range(0, cnt - 1):
|
||||
l.append(self.button_labels[idx])
|
||||
l.append('Again')
|
||||
l.reverse()
|
||||
|
||||
# Loop through and add the ease, estimated time (in seconds) and other info
|
||||
return [{
|
||||
'ease': ease,
|
||||
'label': label,
|
||||
'string_label': t(label),
|
||||
'interval': col.sched.nextIvl(card, ease),
|
||||
'string_interval': col.sched.nextIvlStr(card, ease),
|
||||
} for ease, label in enumerate(l, 1)]
|
||||
|
||||
def next_card(self, col, req):
|
||||
if req.data.has_key('deck'):
|
||||
deck = DeckHandler._get_deck(col, req.data['deck'])
|
||||
col.decks.select(deck['id'])
|
||||
|
||||
card = col.sched.getCard()
|
||||
if card is None:
|
||||
return None
|
||||
|
||||
# put it into the card cache to be removed when we answer it
|
||||
#if not req.session.has_key('cards'):
|
||||
# req.session['cards'] = {}
|
||||
#req.session['cards'][long(card.id)] = card
|
||||
|
||||
card.startTimer()
|
||||
|
||||
result = CardHandler._serialize(card, req.data)
|
||||
result['answer_buttons'] = self._get_answer_buttons(col, card)
|
||||
|
||||
return result
|
||||
|
||||
@noReturnValue
|
||||
def answer_card(self, col, req):
|
||||
import time
|
||||
|
||||
card_id = long(req.data['id'])
|
||||
ease = int(req.data['ease'])
|
||||
|
||||
card = col.getCard(card_id)
|
||||
if req.data.has_key('timerStarted'):
|
||||
card.timerStarted = float(req.data['timerStarted'])
|
||||
else:
|
||||
card.timerStarted = time.time()
|
||||
|
||||
col.sched.answerCard(card, ease)
|
||||
|
||||
@noReturnValue
|
||||
def suspend_cards(self, col, req):
|
||||
card_ids = req.data['ids']
|
||||
col.sched.suspendCards(card_ids)
|
||||
|
||||
@noReturnValue
|
||||
def unsuspend_cards(self, col, req):
|
||||
card_ids = req.data['ids']
|
||||
col.sched.unsuspendCards(card_ids)
|
||||
|
||||
def cards_recent_ease(self, col, req):
|
||||
"""Returns the most recent ease for each card."""
|
||||
|
||||
# TODO: Use sqlalchemy to build this query!
|
||||
sql = "SELECT r.cid, r.ease, r.id FROM revlog AS r INNER JOIN (SELECT cid, MAX(id) AS id FROM revlog GROUP BY cid) AS q ON r.cid = q.cid AND r.id = q.id"
|
||||
where = []
|
||||
if req.data.has_key('ids'):
|
||||
where.append('ids IN (' + (','.join(["'%s'" % x for x in req.data['ids']])) + ')')
|
||||
if len(where) > 0:
|
||||
sql += ' WHERE ' + ' AND '.join(where)
|
||||
|
||||
result = []
|
||||
for r in col.db.all(sql):
|
||||
result.append({'id': r[0], 'ease': r[1], 'timestamp': int(r[2] / 1000)})
|
||||
|
||||
return result
|
||||
|
||||
def latest_revlog(self, col, req):
|
||||
"""Returns recent entries from the revlog."""
|
||||
|
||||
# TODO: Use sqlalchemy to build this query!
|
||||
sql = "SELECT r.id, r.ease, r.cid, r.usn, r.ivl, r.lastIvl, r.factor, r.time, r.type FROM revlog AS r"
|
||||
args = []
|
||||
if req.data.has_key('updated_since'):
|
||||
sql += ' WHERE r.id > ?'
|
||||
args.append(long(req.data['updated_since']) * 1000)
|
||||
sql += ' ORDER BY r.id DESC'
|
||||
sql += ' LIMIT ' + str(req.data.get('limit', 100))
|
||||
|
||||
revlog = col.db.all(sql, *args)
|
||||
return [{
|
||||
'id': r[0],
|
||||
'ease': r[1],
|
||||
'timestamp': int(r[0] / 1000),
|
||||
'card_id': r[2],
|
||||
'usn': r[3],
|
||||
'interval': r[4],
|
||||
'last_interval': r[5],
|
||||
'factor': r[6],
|
||||
'time': r[7],
|
||||
'type': r[8],
|
||||
} for r in revlog]
|
||||
|
||||
stats_reports = {
|
||||
'today': 'todayStats',
|
||||
'due': 'dueGraph',
|
||||
'reps': 'repsGraph',
|
||||
'interval': 'ivlGraph',
|
||||
'hourly': 'hourGraph',
|
||||
'ease': 'easeGraph',
|
||||
'card': 'cardGraph',
|
||||
'footer': 'footer',
|
||||
}
|
||||
stats_reports_order = ['today', 'due', 'reps', 'interval', 'hourly', 'ease', 'card', 'footer']
|
||||
|
||||
def stats_report(self, col, req):
|
||||
import anki.stats
|
||||
import re
|
||||
|
||||
stats = anki.stats.CollectionStats(col)
|
||||
stats.width = int(req.data.get('width', 600))
|
||||
stats.height = int(req.data.get('height', 200))
|
||||
reports = req.data.get('reports', self.stats_reports_order)
|
||||
include_css = req.data.get('include_css', False)
|
||||
include_jquery = req.data.get('include_jquery', False)
|
||||
include_flot = req.data.get('include_flot', False)
|
||||
|
||||
if include_css:
|
||||
from anki.statsbg import bg
|
||||
html = stats.css % bg
|
||||
else:
|
||||
html = ''
|
||||
|
||||
for name in reports:
|
||||
if not self.stats_reports.has_key(name):
|
||||
raise HTTPBadRequest("Unknown report name: %s" % name)
|
||||
func = getattr(stats, self.stats_reports[name])
|
||||
|
||||
html += '<div class="anki-graph anki-graph-%s">' % name
|
||||
html += func()
|
||||
html += '</div>'
|
||||
|
||||
# fix an error in some inline styles
|
||||
# TODO: submit a patch to Anki!
|
||||
html = re.sub(r'style="width:([0-9\.]+); height:([0-9\.]+);"', r'style="width:\1px; height: \2px;"', html)
|
||||
html = re.sub(r'-webkit-transform: ([^;]+);', r'-webkit-transform: \1; -moz-transform: \1; -ms-transform: \1; -o-transform: \1; transform: \1;', html)
|
||||
|
||||
scripts = []
|
||||
if include_jquery or include_flot:
|
||||
import anki.js
|
||||
if include_jquery:
|
||||
scripts.append(anki.js.jquery)
|
||||
if include_flot:
|
||||
scripts.append(anki.js.plot)
|
||||
if len(scripts) > 0:
|
||||
html = "<script>%s\n</script>" % ''.join(scripts) + html
|
||||
|
||||
return html
|
||||
|
||||
#
|
||||
# GLOBAL / MISC
|
||||
#
|
||||
|
||||
@noReturnValue
|
||||
def set_language(self, col, req):
|
||||
anki.lang.setLang(req.data['code'])
|
||||
|
||||
class ImportExportHandler(RestHandlerBase):
|
||||
"""Handler group for the 'collection' type, but it's not added by default."""
|
||||
|
||||
def _get_filedata(self, data):
|
||||
import urllib2
|
||||
|
||||
if data.has_key('data'):
|
||||
return data['data']
|
||||
|
||||
fd = None
|
||||
try:
|
||||
fd = urllib2.urlopen(data['url'])
|
||||
filedata = fd.read()
|
||||
finally:
|
||||
if fd is not None:
|
||||
fd.close()
|
||||
|
||||
return filedata
|
||||
|
||||
def _get_importer_class(self, data):
|
||||
filetype = data['filetype']
|
||||
|
||||
from AnkiServer.importer import get_importer_class
|
||||
importer_class = get_importer_class(filetype)
|
||||
if importer_class is None:
|
||||
raise HTTPBadRequest("Unknown filetype '%s'" % filetype)
|
||||
|
||||
return importer_class
|
||||
|
||||
def import_file(self, col, req):
|
||||
import AnkiServer.importer
|
||||
import tempfile
|
||||
|
||||
# get the importer class
|
||||
importer_class = self._get_importer_class(req.data)
|
||||
|
||||
# get the file data
|
||||
filedata = self._get_filedata(req.data)
|
||||
|
||||
# write the file data to a temporary file
|
||||
try:
|
||||
path = None
|
||||
with tempfile.NamedTemporaryFile('wt', delete=False) as fd:
|
||||
path = fd.name
|
||||
fd.write(filedata)
|
||||
|
||||
AnkiServer.importer.import_file(importer_class, col, path)
|
||||
finally:
|
||||
if path is not None:
|
||||
os.unlink(path)
|
||||
|
||||
class ModelHandler(RestHandlerBase):
|
||||
"""Default handler group for 'model' type."""
|
||||
|
||||
def field_names(self, col, req):
|
||||
model = col.models.get(req.ids[1])
|
||||
if model is None:
|
||||
raise HTTPNotFound()
|
||||
return col.models.fieldNames(model)
|
||||
|
||||
class NoteHandler(RestHandlerBase):
|
||||
"""Default handler group for 'note' type."""
|
||||
|
||||
@staticmethod
|
||||
def _serialize(note):
|
||||
d = {
|
||||
'id': note.id,
|
||||
'guid': note.guid,
|
||||
'model': note.model(),
|
||||
'mid': note.mid,
|
||||
'mod': note.mod,
|
||||
'scm': note.scm,
|
||||
'tags': note.tags,
|
||||
'string_tags': ' '.join(note.tags),
|
||||
'fields': {},
|
||||
'flags': note.flags,
|
||||
'usn': note.usn,
|
||||
}
|
||||
|
||||
# add all the fields
|
||||
for name, value in note.items():
|
||||
d['fields'][name] = value
|
||||
|
||||
return d
|
||||
|
||||
def index(self, col, req):
|
||||
note = col.getNote(req.ids[1])
|
||||
return self._serialize(note)
|
||||
|
||||
@noReturnValue
|
||||
def add_tags(self, col, req):
|
||||
note = col.getNote(req.ids[1])
|
||||
for tag in req.data['tags']:
|
||||
note.addTag(tag)
|
||||
note.flush()
|
||||
|
||||
@noReturnValue
|
||||
def remove_tags(self, col, req):
|
||||
note = col.getNote(req.ids[1])
|
||||
for tag in req.data['tags']:
|
||||
note.delTag(tag)
|
||||
note.flush()
|
||||
|
||||
class DeckHandler(RestHandlerBase):
|
||||
"""Default handler group for 'deck' type."""
|
||||
|
||||
@staticmethod
|
||||
def _get_deck(col, val):
|
||||
try:
|
||||
did = long(val)
|
||||
deck = col.decks.get(did, False)
|
||||
except ValueError:
|
||||
deck = col.decks.byName(val)
|
||||
|
||||
if deck is None:
|
||||
raise HTTPNotFound('No deck with id or name: ' + str(val))
|
||||
|
||||
return deck
|
||||
|
||||
def index(self, col, req):
|
||||
return self._get_deck(col, req.ids[1])
|
||||
|
||||
def next_card(self, col, req):
|
||||
req_copy = req.copy()
|
||||
req_copy.data['deck'] = req.ids[1]
|
||||
del req_copy.ids[1]
|
||||
|
||||
# forward this to the CollectionHandler
|
||||
return req.app.execute_handler('collection', 'next_card', col, req_copy)
|
||||
|
||||
def get_conf(self, col, req):
|
||||
# TODO: should probably live in a ConfHandler
|
||||
return col.decks.confForDid(req.ids[1])
|
||||
|
||||
@noReturnValue
|
||||
def set_update_conf(self, col, req):
|
||||
data = req.data.copy()
|
||||
del data['id']
|
||||
|
||||
conf = col.decks.confForDid(req.ids[1])
|
||||
conf = conf.copy()
|
||||
conf.update(data)
|
||||
|
||||
col.decks.updateConf(conf)
|
||||
|
||||
class CardHandler(RestHandlerBase):
|
||||
"""Default handler group for 'card' type."""
|
||||
|
||||
@staticmethod
|
||||
def _serialize(card, opts):
|
||||
d = {
|
||||
'id': card.id,
|
||||
'isEmpty': card.isEmpty(),
|
||||
'css': card.css(),
|
||||
'question': card._getQA()['q'],
|
||||
'answer': card._getQA()['a'],
|
||||
'did': card.did,
|
||||
'due': card.due,
|
||||
'factor': card.factor,
|
||||
'ivl': card.ivl,
|
||||
'lapses': card.lapses,
|
||||
'left': card.left,
|
||||
'mod': card.mod,
|
||||
'nid': card.nid,
|
||||
'odid': card.odid,
|
||||
'odue': card.odue,
|
||||
'ord': card.ord,
|
||||
'queue': card.queue,
|
||||
'reps': card.reps,
|
||||
'type': card.type,
|
||||
'usn': card.usn,
|
||||
'timerStarted': card.timerStarted,
|
||||
}
|
||||
|
||||
if opts.get('load_note', False):
|
||||
d['note'] = NoteHandler._serialize(card.col.getNote(card.nid))
|
||||
|
||||
if opts.get('load_deck', False):
|
||||
d['deck'] = card.col.decks.get(card.did)
|
||||
|
||||
if opts.get('load_latest_revlog', False):
|
||||
d['latest_revlog'] = CardHandler._latest_revlog(card.col, card.id)
|
||||
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def _latest_revlog(col, card_id):
|
||||
r = col.db.first("SELECT r.id, r.ease FROM revlog AS r WHERE r.cid = ? ORDER BY id DESC LIMIT 1", card_id)
|
||||
if r:
|
||||
return {'id': r[0], 'ease': r[1], 'timestamp': int(r[0] / 1000)}
|
||||
|
||||
def index(self, col, req):
|
||||
card = col.getCard(req.ids[1])
|
||||
return self._serialize(card, req.data)
|
||||
|
||||
def _forward_to_note(self, col, req, name):
|
||||
card = col.getCard(req.ids[1])
|
||||
|
||||
req_copy = req.copy()
|
||||
req_copy.ids[1] = card.nid
|
||||
|
||||
return req.app.execute_handler('note', name, col, req)
|
||||
|
||||
@noReturnValue
|
||||
def add_tags(self, col, req):
|
||||
self._forward_to_note(col, req, 'add_tags')
|
||||
|
||||
@noReturnValue
|
||||
def remove_tags(self, col, req):
|
||||
self._forward_to_note(col, req, 'remove_tags')
|
||||
|
||||
def stats_report(self, col, req):
|
||||
card = col.getCard(req.ids[1])
|
||||
return col.cardStats(card)
|
||||
|
||||
def latest_revlog(self, col, req):
|
||||
return self._latest_revlog(col, req.ids[1])
|
||||
|
||||
# Our entry point
|
||||
def make_app(global_conf, **local_conf):
|
||||
# TODO: we should setup the default language from conf!
|
||||
|
||||
# setup the logger
|
||||
from AnkiServer.utils import setup_logging
|
||||
setup_logging(local_conf.get('logging.config_file'))
|
||||
|
||||
return RestApp(
|
||||
data_root=local_conf.get('data_root', '.'),
|
||||
allowed_hosts=local_conf.get('allowed_hosts', '*')
|
||||
)
|
||||
|
||||
@ -1,377 +0,0 @@
|
||||
|
||||
from webob.dec import wsgify
|
||||
from webob.exc import *
|
||||
from webob import Response
|
||||
|
||||
import anki
|
||||
from anki.sync import HttpSyncServer, CHUNK_SIZE
|
||||
from anki.db import sqlite
|
||||
from anki.utils import checksum
|
||||
|
||||
import AnkiServer.deck
|
||||
|
||||
import MySQLdb
|
||||
|
||||
try:
|
||||
import simplejson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
import os, zlib, tempfile, time
|
||||
|
||||
def makeArgs(mdict):
|
||||
d = dict(mdict.items())
|
||||
# TODO: use password/username/version for something?
|
||||
for k in ['p','u','v','d']:
|
||||
if d.has_key(k):
|
||||
del d[k]
|
||||
return d
|
||||
|
||||
class FileIterable(object):
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
def __iter__(self):
|
||||
return FileIterator(self.fn)
|
||||
|
||||
class FileIterator(object):
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
self.fo = open(self.fn, 'rb')
|
||||
self.c = zlib.compressobj()
|
||||
self.flushed = False
|
||||
def __iter__(self):
|
||||
return self
|
||||
def next(self):
|
||||
data = self.fo.read(CHUNK_SIZE)
|
||||
if not data:
|
||||
if not self.flushed:
|
||||
self.flushed = True
|
||||
return self.c.flush()
|
||||
else:
|
||||
raise StopIteration
|
||||
return self.c.compress(data)
|
||||
|
||||
def lock_deck(path):
|
||||
""" Gets exclusive access to this deck path. If there is a DeckThread running on this
|
||||
deck, this will wait for its current operations to complete before temporarily stopping
|
||||
it. """
|
||||
|
||||
from AnkiServer.deck import thread_pool
|
||||
|
||||
if thread_pool.decks.has_key(path):
|
||||
thread_pool.decks[path].stop_and_wait()
|
||||
thread_pool.lock(path)
|
||||
|
||||
def unlock_deck(path):
|
||||
""" Release exclusive access to this deck path. """
|
||||
from AnkiServer.deck import thread_pool
|
||||
thread_pool.unlock(path)
|
||||
|
||||
class SyncAppHandler(HttpSyncServer):
|
||||
operations = ['summary','applyPayload','finish','createDeck','getOneWayPayload']
|
||||
|
||||
def __init__(self):
|
||||
HttpSyncServer.__init__(self)
|
||||
|
||||
def createDeck(self, name):
|
||||
# The HttpSyncServer.createDeck doesn't return a valid value! This seems to be
|
||||
# a bug in libanki.sync ...
|
||||
return self.stuff({"status": "OK"})
|
||||
|
||||
def finish(self):
|
||||
# The HttpSyncServer has no finish() function... I can only assume this is a bug too!
|
||||
return self.stuff("OK")
|
||||
|
||||
class SyncApp(object):
|
||||
valid_urls = SyncAppHandler.operations + ['getDecks','fullup','fulldown']
|
||||
|
||||
def __init__(self, **kw):
|
||||
self.data_root = os.path.abspath(kw.get('data_root', '.'))
|
||||
self.base_url = kw.get('base_url', '/')
|
||||
self.users = {}
|
||||
|
||||
# make sure the base_url has a trailing slash
|
||||
if len(self.base_url) == 0:
|
||||
self.base_url = '/'
|
||||
elif self.base_url[-1] != '/':
|
||||
self.base_url = base_url + '/'
|
||||
|
||||
# setup mysql connection
|
||||
mysql_args = {}
|
||||
for k, v in kw.items():
|
||||
if k.startswith('mysql.'):
|
||||
mysql_args[k[6:]] = v
|
||||
self.mysql_args = mysql_args
|
||||
self.conn = None
|
||||
|
||||
# get SQL statements
|
||||
self.sql_check_password = kw.get('sql_check_password')
|
||||
self.sql_username2dirname = kw.get('sql_username2dirname')
|
||||
|
||||
default_libanki_version = '.'.join(anki.version.split('.')[:2])
|
||||
|
||||
def user_libanki_version(self, u):
|
||||
try:
|
||||
s = self.users[u]['libanki']
|
||||
except KeyError:
|
||||
return self.default_libanki_version
|
||||
|
||||
parts = s.split('.')
|
||||
if parts[0] == '1':
|
||||
if parts[1] == '0':
|
||||
return '1.0'
|
||||
elif parts[1] in ('1','2'):
|
||||
return '1.2'
|
||||
|
||||
return self.default_libanki_version
|
||||
|
||||
# Mimcs from anki.sync.SyncTools.stuff()
|
||||
def _stuff(self, data):
|
||||
return zlib.compress(json.dumps(data))
|
||||
|
||||
def _connect_mysql(self):
|
||||
if self.conn is None and len(self.mysql_args) > 0:
|
||||
self.conn = MySQLdb.connect(**self.mysql_args)
|
||||
|
||||
def _execute_sql(self, sql, args=()):
|
||||
self._connect_mysql()
|
||||
try:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute(sql, args)
|
||||
except MySQLdb.OperationalError, e:
|
||||
if e.args[0] == 2006:
|
||||
# MySQL server has gone away message
|
||||
self.conn = None
|
||||
self._connect_mysql()
|
||||
cur = self.conn.cursor()
|
||||
cur.execute(sql, args)
|
||||
return cur
|
||||
|
||||
def check_password(self, username, password):
|
||||
if len(self.mysql_args) > 0 and self.sql_check_password is not None:
|
||||
cur = self._execute_sql(self.sql_check_password, (username, password))
|
||||
row = cur.fetchone()
|
||||
return row is not None
|
||||
|
||||
return True
|
||||
|
||||
def username2dirname(self, username):
|
||||
if len(self.mysql_args) > 0 and self.sql_username2dirname is not None:
|
||||
cur = self._execute_sql(self.sql_username2dirname, (username,))
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return str(row[0])
|
||||
|
||||
return username
|
||||
|
||||
def _getDecks(self, user_path):
|
||||
decks = {}
|
||||
|
||||
if os.path.exists(user_path):
|
||||
# It is a dict of {'deckName':[modified,lastSync]}
|
||||
for fn in os.listdir(unicode(user_path, 'utf-8')):
|
||||
if len(fn) > 5 and fn[-5:] == '.anki':
|
||||
d = os.path.abspath(os.path.join(user_path, fn))
|
||||
|
||||
# For simplicity, we will always open a thread. But this probably
|
||||
# isn't necessary!
|
||||
thread = AnkiServer.deck.thread_pool.start(d)
|
||||
def lookupModifiedLastSync(wrapper):
|
||||
deck = wrapper.open()
|
||||
return [deck.modified, deck.lastSync]
|
||||
res = thread.execute(lookupModifiedLastSync, [thread.wrapper])
|
||||
|
||||
# if thread_pool.threads.has_key(d):
|
||||
# thread = thread_pool.threads[d]
|
||||
# def lookupModifiedLastSync(wrapper):
|
||||
# deck = wrapper.open()
|
||||
# return [deck.modified, deck.lastSync]
|
||||
# res = thread.execute(lookup, [thread.wrapper])
|
||||
# else:
|
||||
# conn = sqlite.connect(d)
|
||||
# cur = conn.cursor()
|
||||
# cur.execute("select modified, lastSync from decks")
|
||||
#
|
||||
# res = list(cur.fetchone())
|
||||
#
|
||||
# cur.close()
|
||||
# conn.close()
|
||||
|
||||
#self.decks[fn[:-5]] = ["%.5f" % x for x in res]
|
||||
decks[fn[:-5]] = res
|
||||
|
||||
# same as HttpSyncServer.getDecks()
|
||||
return self._stuff({
|
||||
"status": "OK",
|
||||
"decks": decks,
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
def _fullup(self, wrapper, infile, version):
|
||||
wrapper.close()
|
||||
path = wrapper.path
|
||||
|
||||
# DRS: most of this function was graciously copied
|
||||
# from anki.sync.SyncTools.fullSyncFromServer()
|
||||
(fd, tmpname) = tempfile.mkstemp(dir=os.getcwd(), prefix="fullsync")
|
||||
outfile = open(tmpname, 'wb')
|
||||
decomp = zlib.decompressobj()
|
||||
while 1:
|
||||
data = infile.read(CHUNK_SIZE)
|
||||
if not data:
|
||||
outfile.write(decomp.flush())
|
||||
break
|
||||
outfile.write(decomp.decompress(data))
|
||||
infile.close()
|
||||
outfile.close()
|
||||
os.close(fd)
|
||||
# if we were successful, overwrite old deck
|
||||
if os.path.exists(path):
|
||||
os.unlink(path)
|
||||
os.rename(tmpname, path)
|
||||
# reset the deck name
|
||||
c = sqlite.connect(path)
|
||||
lastSync = time.time()
|
||||
if version == '1':
|
||||
c.execute("update decks set lastSync = ?", [lastSync])
|
||||
elif version == '2':
|
||||
c.execute("update decks set syncName = ?, lastSync = ?",
|
||||
[checksum(path.encode("utf-8")), lastSync])
|
||||
c.commit()
|
||||
c.close()
|
||||
|
||||
return lastSync
|
||||
|
||||
def _stuffedResp(self, data):
|
||||
return Response(
|
||||
status='200 OK',
|
||||
content_type='application/json',
|
||||
content_encoding='deflate',
|
||||
body=data)
|
||||
|
||||
@wsgify
|
||||
def __call__(self, req):
|
||||
if req.path.startswith(self.base_url):
|
||||
url = req.path[len(self.base_url):]
|
||||
if url not in self.valid_urls:
|
||||
raise HTTPNotFound()
|
||||
|
||||
# get and check username and password
|
||||
try:
|
||||
u = req.str_params.getone('u')
|
||||
p = req.str_params.getone('p')
|
||||
except KeyError:
|
||||
raise HTTPBadRequest('Must pass username and password')
|
||||
if not self.check_password(u, p):
|
||||
#raise HTTPBadRequest('Incorrect username or password')
|
||||
return self._stuffedResp(self._stuff({'status':'invalidUserPass'}))
|
||||
dirname = self.username2dirname(u)
|
||||
if dirname is None:
|
||||
raise HTTPBadRequest('Incorrect username or password')
|
||||
user_path = os.path.join(self.data_root, dirname)
|
||||
|
||||
# get and lock the (optional) deck for this request
|
||||
d = None
|
||||
try:
|
||||
d = unicode(req.str_params.getone('d'), 'utf-8')
|
||||
# AnkiDesktop actually passes us the string value 'None'!
|
||||
if d == 'None':
|
||||
d = None
|
||||
except KeyError:
|
||||
pass
|
||||
if d is not None:
|
||||
# get the full deck path name
|
||||
d = os.path.abspath(os.path.join(user_path, d)+'.anki')
|
||||
if d[:len(user_path)] != user_path:
|
||||
raise HTTPBadRequest('Bad deck name')
|
||||
thread = AnkiServer.deck.thread_pool.start(d)
|
||||
else:
|
||||
thread = None
|
||||
|
||||
if url == 'getDecks':
|
||||
# force the version up to 1.2.x
|
||||
v = req.str_params.getone('libanki')
|
||||
if v.startswith('0.') or v.startswith('1.0'):
|
||||
return self._stuffedResp(self._stuff({'status':'oldVersion'}))
|
||||
|
||||
# store the data the user passes us keyed with the username. This
|
||||
# will be used later by SyncAppHandler for version compatibility.
|
||||
self.users[u] = makeArgs(req.str_params)
|
||||
return self._stuffedResp(self._getDecks(user_path))
|
||||
|
||||
elif url in SyncAppHandler.operations:
|
||||
handler = SyncAppHandler()
|
||||
func = getattr(handler, url)
|
||||
args = makeArgs(req.str_params)
|
||||
|
||||
if thread is not None:
|
||||
# If this is for a specific deck, then it needs to run
|
||||
# inside of the DeckThread.
|
||||
def runFunc(wrapper):
|
||||
handler.deck = wrapper.open()
|
||||
ret = func(**args)
|
||||
handler.deck.save()
|
||||
return ret
|
||||
runFunc.func_name = url
|
||||
ret = thread.execute(runFunc, [thread.wrapper])
|
||||
else:
|
||||
# Otherwise, we can simply execute it in this thread.
|
||||
ret = func(**args)
|
||||
|
||||
# clean-up user data stored in getDecks
|
||||
if url == 'finish':
|
||||
del self.users[u]
|
||||
|
||||
return self._stuffedResp(ret)
|
||||
|
||||
elif url == 'fulldown':
|
||||
# set the syncTime before we send it
|
||||
def setupForSync(wrapper):
|
||||
wrapper.close()
|
||||
c = sqlite.connect(d)
|
||||
lastSync = time.time()
|
||||
c.execute("update decks set lastSync = ?", [lastSync])
|
||||
c.commit()
|
||||
c.close()
|
||||
thread.execute(setupForSync, [thread.wrapper])
|
||||
|
||||
return Response(status='200 OK', content_type='application/octet-stream', content_encoding='deflate', content_disposition='attachment; filename="'+os.path.basename(d).encode('utf-8')+'"', app_iter=FileIterable(d))
|
||||
elif url == 'fullup':
|
||||
#version = self.user_libanki_version(u)
|
||||
try:
|
||||
version = req.str_params.getone('v')
|
||||
except KeyError:
|
||||
version = '1'
|
||||
|
||||
infile = req.str_params['deck'].file
|
||||
lastSync = thread.execute(self._fullup, [thread.wrapper, infile, version])
|
||||
|
||||
# append the 'lastSync' value for libanki 1.1 and 1.2
|
||||
if version == '2':
|
||||
body = 'OK '+str(lastSync)
|
||||
else:
|
||||
body = 'OK'
|
||||
|
||||
return Response(status='200 OK', content_type='application/text', body=body)
|
||||
|
||||
return Response(status='200 OK', content_type='text/plain', body='Anki Server')
|
||||
|
||||
# Our entry point
|
||||
def make_app(global_conf, **local_conf):
|
||||
return SyncApp(**local_conf)
|
||||
|
||||
def main():
|
||||
from wsgiref.simple_server import make_server
|
||||
|
||||
ankiserver = DeckApp('.', '/sync/')
|
||||
httpd = make_server('', 8001, ankiserver)
|
||||
try:
|
||||
httpd.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print "Exiting ..."
|
||||
finally:
|
||||
AnkiServer.deck.thread_pool.shutdown()
|
||||
|
||||
if __name__ == '__main__': main()
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
|
||||
from anki.importing.csvfile import TextImporter
|
||||
from anki.importing.apkg import AnkiPackageImporter
|
||||
from anki.importing.anki1 import Anki1Importer
|
||||
from anki.importing.supermemo_xml import SupermemoXmlImporter
|
||||
from anki.importing.mnemo import MnemosyneImporter
|
||||
from anki.importing.pauker import PaukerImporter
|
||||
|
||||
__all__ = ['get_importer_class', 'import_file']
|
||||
|
||||
importers = {
|
||||
'text': TextImporter,
|
||||
'apkg': AnkiPackageImporter,
|
||||
'anki1': Anki1Importer,
|
||||
'supermemo_xml': SupermemoXmlImporter,
|
||||
'mnemosyne': MnemosyneImporter,
|
||||
'pauker': PaukerImporter,
|
||||
}
|
||||
|
||||
def get_importer_class(type):
|
||||
global importers
|
||||
return importers.get(type)
|
||||
|
||||
def import_file(importer_class, col, path, allow_update = False):
|
||||
importer = importer_class(col, path)
|
||||
|
||||
if allow_update:
|
||||
importer.allowUpdate = True
|
||||
|
||||
if importer.needMapper:
|
||||
importer.open()
|
||||
|
||||
importer.run()
|
||||
|
||||
#
|
||||
# Monkey patch anki.importing.anki2 to support updating existing notes.
|
||||
# TODO: submit a patch to Anki!
|
||||
#
|
||||
|
||||
def _importNotes(self):
|
||||
# build guid -> (id,mod,mid) hash & map of existing note ids
|
||||
self._notes = {}
|
||||
existing = {}
|
||||
for id, guid, mod, mid in self.dst.db.execute(
|
||||
"select id, guid, mod, mid from notes"):
|
||||
self._notes[guid] = (id, mod, mid)
|
||||
existing[id] = True
|
||||
# we may need to rewrite the guid if the model schemas don't match,
|
||||
# so we need to keep track of the changes for the card import stage
|
||||
self._changedGuids = {}
|
||||
# iterate over source collection
|
||||
add = []
|
||||
dirty = []
|
||||
usn = self.dst.usn()
|
||||
dupes = 0
|
||||
for note in self.src.db.execute(
|
||||
"select * from notes"):
|
||||
# turn the db result into a mutable list
|
||||
note = list(note)
|
||||
shouldAdd = self._uniquifyNote(note)
|
||||
if shouldAdd:
|
||||
# ensure id is unique
|
||||
while note[0] in existing:
|
||||
note[0] += 999
|
||||
existing[note[0]] = True
|
||||
# bump usn
|
||||
note[4] = usn
|
||||
# update media references in case of dupes
|
||||
note[6] = self._mungeMedia(note[MID], note[6])
|
||||
add.append(note)
|
||||
dirty.append(note[0])
|
||||
# note we have the added the guid
|
||||
self._notes[note[GUID]] = (note[0], note[3], note[MID])
|
||||
else:
|
||||
dupes += 1
|
||||
|
||||
# update existing note
|
||||
newer = note[3] > mod
|
||||
if self.allowUpdate and self._mid(mid) == mid and newer:
|
||||
localNid = self._notes[note[GUID]][0]
|
||||
note[0] = localNid
|
||||
note[4] = usn
|
||||
add.append(note)
|
||||
dirty.append(note[0])
|
||||
|
||||
if dupes:
|
||||
self.log.append(_("Already in collection: %s.") % (ngettext(
|
||||
"%d note", "%d notes", dupes) % dupes))
|
||||
# add to col
|
||||
self.dst.db.executemany(
|
||||
"insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)",
|
||||
add)
|
||||
self.dst.updateFieldCache(dirty)
|
||||
self.dst.tags.registerNotes(dirty)
|
||||
|
||||
from anki.importing.anki2 import Anki2Importer, MID, GUID
|
||||
from anki.lang import _, ngettext
|
||||
Anki2Importer._importNotes = _importNotes
|
||||
Anki2Importer.allowUpdate = False
|
||||
|
||||
@ -1,96 +0,0 @@
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import types
|
||||
|
||||
# The SMTPHandler taken from python 2.6
|
||||
class SMTPHandler(logging.Handler):
|
||||
"""
|
||||
A handler class which sends an SMTP email for each logging event.
|
||||
"""
|
||||
def __init__(self, mailhost, fromaddr, toaddrs, subject, credentials=None):
|
||||
"""
|
||||
Initialize the handler.
|
||||
|
||||
Initialize the instance with the from and to addresses and subject
|
||||
line of the email. To specify a non-standard SMTP port, use the
|
||||
(host, port) tuple format for the mailhost argument. To specify
|
||||
authentication credentials, supply a (username, password) tuple
|
||||
for the credentials argument.
|
||||
"""
|
||||
logging.Handler.__init__(self)
|
||||
if type(mailhost) == types.TupleType:
|
||||
self.mailhost, self.mailport = mailhost
|
||||
else:
|
||||
self.mailhost, self.mailport = mailhost, None
|
||||
if type(credentials) == types.TupleType:
|
||||
self.username, self.password = credentials
|
||||
else:
|
||||
self.username = None
|
||||
self.fromaddr = fromaddr
|
||||
if type(toaddrs) == types.StringType:
|
||||
toaddrs = [toaddrs]
|
||||
self.toaddrs = toaddrs
|
||||
self.subject = subject
|
||||
|
||||
def getSubject(self, record):
|
||||
"""
|
||||
Determine the subject for the email.
|
||||
|
||||
If you want to specify a subject line which is record-dependent,
|
||||
override this method.
|
||||
"""
|
||||
return self.subject
|
||||
|
||||
weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
||||
|
||||
monthname = [None,
|
||||
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
|
||||
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
|
||||
|
||||
def date_time(self):
|
||||
"""
|
||||
Return the current date and time formatted for a MIME header.
|
||||
Needed for Python 1.5.2 (no email package available)
|
||||
"""
|
||||
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(time.time())
|
||||
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
|
||||
self.weekdayname[wd],
|
||||
day, self.monthname[month], year,
|
||||
hh, mm, ss)
|
||||
return s
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
Emit a record.
|
||||
|
||||
Format the record and send it to the specified addressees.
|
||||
"""
|
||||
try:
|
||||
import smtplib
|
||||
try:
|
||||
from email.utils import formatdate
|
||||
except ImportError:
|
||||
formatdate = self.date_time
|
||||
port = self.mailport
|
||||
if not port:
|
||||
port = smtplib.SMTP_PORT
|
||||
smtp = smtplib.SMTP(self.mailhost, port)
|
||||
msg = self.format(record)
|
||||
msg = "From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\n\r\n%s" % (
|
||||
self.fromaddr,
|
||||
string.join(self.toaddrs, ","),
|
||||
self.getSubject(record),
|
||||
formatdate(), msg)
|
||||
if self.username:
|
||||
smtp.login(self.username, self.password)
|
||||
smtp.sendmail(self.fromaddr, self.toaddrs, msg)
|
||||
smtp.quit()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except:
|
||||
self.handleError(record)
|
||||
|
||||
# Monkey patch logging.handlers
|
||||
logging.handlers.SMTPHandler = SMTPHandler
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from ConfigParser import SafeConfigParser
|
||||
|
||||
from webob.dec import wsgify
|
||||
from webob.exc import *
|
||||
@ -94,18 +95,14 @@ class SyncUserSession(object):
|
||||
class SyncApp(object):
|
||||
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks']
|
||||
|
||||
def __init__(self, **kw):
|
||||
from AnkiServer.threading import getCollectionManager
|
||||
def __init__(self, config):
|
||||
from AnkiServer.thread import getCollectionManager
|
||||
|
||||
self.data_root = os.path.abspath(kw.get('data_root', '.'))
|
||||
self.base_url = kw.get('base_url', '/')
|
||||
self.auth_db_path = os.path.abspath(kw.get('auth_db_path', '.'))
|
||||
self.data_root = os.path.abspath(config.get("sync_app", "data_root"))
|
||||
self.base_url = config.get("sync_app", "base_url")
|
||||
self.sessions = {}
|
||||
|
||||
try:
|
||||
self.collection_manager = kw['collection_manager']
|
||||
except KeyError:
|
||||
self.collection_manager = getCollectionManager()
|
||||
self.collection_manager = getCollectionManager()
|
||||
|
||||
# make sure the base_url has a trailing slash
|
||||
if len(self.base_url) == 0:
|
||||
@ -297,6 +294,11 @@ class SyncApp(object):
|
||||
return Response(status='200 OK', content_type='text/plain', body='Anki Sync Server')
|
||||
|
||||
class DatabaseAuthSyncApp(SyncApp):
|
||||
def __init__(self, config):
|
||||
SyncApp.__init__(self, config)
|
||||
|
||||
self.auth_db_path = os.path.abspath(config.get("sync_app", "auth_db_path"))
|
||||
|
||||
def authenticate(self, username, password):
|
||||
"""Returns True if this username is allowed to connect with this password. False otherwise."""
|
||||
|
||||
@ -308,6 +310,8 @@ class DatabaseAuthSyncApp(SyncApp):
|
||||
|
||||
db_ret = cursor.fetchone()
|
||||
|
||||
conn.close()
|
||||
|
||||
if db_ret != None:
|
||||
db_hash = str(db_ret[0])
|
||||
salt = db_hash[-16:]
|
||||
@ -317,16 +321,16 @@ class DatabaseAuthSyncApp(SyncApp):
|
||||
|
||||
return (db_ret != None and hashobj.hexdigest()+salt == db_hash)
|
||||
|
||||
# Our entry point
|
||||
def make_app(global_conf, **local_conf):
|
||||
return DatabaseAuthSyncApp(**local_conf)
|
||||
|
||||
def main():
|
||||
from wsgiref.simple_server import make_server
|
||||
from AnkiServer.threading import shutdown
|
||||
from AnkiServer.thread import shutdown
|
||||
|
||||
config = SafeConfigParser()
|
||||
config.read("production.ini")
|
||||
|
||||
ankiserver = DatabaseAuthSyncApp(config)
|
||||
httpd = make_server('', config.getint("sync_app", "port"), ankiserver)
|
||||
|
||||
ankiserver = SyncApp()
|
||||
httpd = make_server('', 8001, ankiserver)
|
||||
try:
|
||||
print "Starting..."
|
||||
httpd.serve_forever()
|
||||
@ -1,18 +0,0 @@
|
||||
|
||||
def setup_logging(config_file=None):
|
||||
"""Setup logging based on a config_file."""
|
||||
|
||||
import logging
|
||||
|
||||
if config_file is not None:
|
||||
# monkey patch the logging.config.SMTPHandler if necessary
|
||||
import sys
|
||||
if sys.version_info[0] == 2 and sys.version_info[1] == 5:
|
||||
import AnkiServer.logpatch
|
||||
|
||||
# load the config file
|
||||
import logging.config
|
||||
logging.config.fileConfig(config_file)
|
||||
else:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
30
example.ini
30
example.ini
@ -1,34 +1,6 @@
|
||||
|
||||
[server:main]
|
||||
#use = egg:Paste#http
|
||||
use = egg:AnkiServer#server
|
||||
[sync_app]
|
||||
host = 127.0.0.1
|
||||
port = 27701
|
||||
|
||||
[filter-app:main]
|
||||
use = egg:Paste#translogger
|
||||
next = real
|
||||
|
||||
[app:real]
|
||||
use = egg:Paste#urlmap
|
||||
/collection = rest_app
|
||||
/sync = sync_app
|
||||
|
||||
[app:rest_app]
|
||||
use = egg:AnkiServer#rest_app
|
||||
data_root = ./collections
|
||||
allowed_hosts = 127.0.0.1
|
||||
logging.config_file = logging.conf
|
||||
|
||||
[app:sync_app]
|
||||
use = egg:AnkiServer#sync_app
|
||||
data_root = ./collections
|
||||
base_url = /sync/
|
||||
auth_db_path = ./auth.db
|
||||
mysql.host = 127.0.0.1
|
||||
mysql.user = db_user
|
||||
mysql.passwd = db_password
|
||||
mysql.db = db
|
||||
sql_check_password = SELECT uid FROM users WHERE name=%s AND pass=MD5(%s) AND status=1
|
||||
sql_username2dirname = SELECT uid AS dirname FROM users WHERE name=%s
|
||||
|
||||
|
||||
41
logging.conf
41
logging.conf
@ -1,41 +0,0 @@
|
||||
|
||||
[loggers]
|
||||
keys=root
|
||||
|
||||
[handlers]
|
||||
keys=screen,file,email
|
||||
|
||||
[formatters]
|
||||
keys=normal,email
|
||||
|
||||
[logger_root]
|
||||
level=INFO
|
||||
handlers=screen
|
||||
#handlers=file
|
||||
#handlers=file,email
|
||||
|
||||
[handler_file]
|
||||
class=FileHandler
|
||||
formatter=normal
|
||||
args=('server.log','a')
|
||||
|
||||
[handler_screen]
|
||||
class=StreamHandler
|
||||
level=NOTSET
|
||||
formatter=normal
|
||||
args=(sys.stdout,)
|
||||
|
||||
[handler_email]
|
||||
class=handlers.SMTPHandler
|
||||
level=ERROR
|
||||
formatter=email
|
||||
args=('smtp.example.com', 'support@example.com', ['support_guy1@example.com', 'support_guy2@example.com'], 'AnkiServer error', ('smtp_user', 'smtp_password'))
|
||||
|
||||
[formatter_normal]
|
||||
format=%(asctime)s:%(name)s:%(levelname)s:%(message)s
|
||||
datefmt=
|
||||
|
||||
[formatter_email]
|
||||
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
||||
datefmt=
|
||||
|
||||
23
setup.py
23
setup.py
@ -1,23 +0,0 @@
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="AnkiServer",
|
||||
version="2.0.0a1",
|
||||
description="A personal Anki sync server (so you can sync against your own server rather than AnkiWeb)",
|
||||
author="David Snopek",
|
||||
author_email="dsnopek@gmail.com",
|
||||
install_requires=["PasteDeploy>=1.3.2"],
|
||||
# TODO: should these really be in install_requires?
|
||||
requires=["webob(>=0.9.7)"],
|
||||
test_suite='nose.collector',
|
||||
entry_points="""
|
||||
[paste.app_factory]
|
||||
sync_app = AnkiServer.apps.sync_app:make_app
|
||||
rest_app = AnkiServer.apps.rest_app:make_app
|
||||
|
||||
[paste.server_runner]
|
||||
server = AnkiServer:server_runner
|
||||
""",
|
||||
)
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
[program:anki-server]
|
||||
command=/usr/local/bin/paster serve production.ini
|
||||
directory=/var/anki-server
|
||||
user=www-data
|
||||
autostart=true
|
||||
autorestart=true
|
||||
redirect_stderr=true
|
||||
|
||||
; Sometimes necessary if Anki is complaining about a UTF-8 locale. Make sure
|
||||
; that the local you pick is actually installed on your system.
|
||||
;environment=LANG=en_US.UTF-8,LC_ALL=en_US.UTF-8,LC_LANG=en_US.UTF-8
|
||||
|
||||
@ -1,140 +0,0 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
from mock import MagicMock, sentinel
|
||||
|
||||
import AnkiServer
|
||||
from AnkiServer.collection import CollectionWrapper, CollectionManager
|
||||
|
||||
class CollectionWrapperTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_lifecycle_real(self):
|
||||
"""Testing common life-cycle with existing and non-existant collections. This
|
||||
test uses the real Anki objects and actually creates a new collection on disk."""
|
||||
|
||||
wrapper = CollectionWrapper(self.collection_path)
|
||||
self.assertFalse(os.path.exists(self.collection_path))
|
||||
self.assertFalse(wrapper.opened())
|
||||
|
||||
wrapper.open()
|
||||
self.assertTrue(os.path.exists(self.collection_path))
|
||||
self.assertTrue(wrapper.opened())
|
||||
|
||||
# calling open twice shouldn't break anything
|
||||
wrapper.open()
|
||||
|
||||
wrapper.close()
|
||||
self.assertTrue(os.path.exists(self.collection_path))
|
||||
self.assertFalse(wrapper.opened())
|
||||
|
||||
# open the same collection again (not a creation)
|
||||
wrapper = CollectionWrapper(self.collection_path)
|
||||
self.assertFalse(wrapper.opened())
|
||||
wrapper.open()
|
||||
self.assertTrue(wrapper.opened())
|
||||
wrapper.close()
|
||||
self.assertFalse(wrapper.opened())
|
||||
self.assertTrue(os.path.exists(self.collection_path))
|
||||
|
||||
def test_del(self):
|
||||
with mock.patch('anki.storage.Collection') as anki_storage_Collection:
|
||||
col = anki_storage_Collection.return_value
|
||||
wrapper = CollectionWrapper(self.collection_path)
|
||||
wrapper.open()
|
||||
wrapper = None
|
||||
col.close.assert_called_with()
|
||||
|
||||
def test_setup_func(self):
|
||||
# Run it when the collection doesn't exist
|
||||
with mock.patch('anki.storage.Collection') as anki_storage_Collection:
|
||||
col = anki_storage_Collection.return_value
|
||||
setup_new_collection = MagicMock()
|
||||
self.assertFalse(os.path.exists(self.collection_path))
|
||||
wrapper = CollectionWrapper(self.collection_path, setup_new_collection)
|
||||
wrapper.open()
|
||||
anki_storage_Collection.assert_called_with(self.collection_path)
|
||||
setup_new_collection.assert_called_with(col)
|
||||
wrapper = None
|
||||
|
||||
# Make sure that no collection was actually created
|
||||
self.assertFalse(os.path.exists(self.collection_path))
|
||||
|
||||
# Create a faux collection file
|
||||
with file(self.collection_path, 'wt') as fd:
|
||||
fd.write('Collection!')
|
||||
|
||||
# Run it when the collection does exist
|
||||
with mock.patch('anki.storage.Collection'):
|
||||
setup_new_collection = lambda col: self.fail("Setup function called when collection already exists!")
|
||||
self.assertTrue(os.path.exists(self.collection_path))
|
||||
wrapper = CollectionWrapper(self.collection_path, setup_new_collection)
|
||||
wrapper.open()
|
||||
anki_storage_Collection.assert_called_with(self.collection_path)
|
||||
wrapper = None
|
||||
|
||||
def test_execute(self):
|
||||
with mock.patch('anki.storage.Collection') as anki_storage_Collection:
|
||||
col = anki_storage_Collection.return_value
|
||||
func = MagicMock()
|
||||
func.return_value = sentinel.some_object
|
||||
|
||||
# check that execute works and auto-creates the collection
|
||||
wrapper = CollectionWrapper(self.collection_path)
|
||||
ret = wrapper.execute(func, [1, 2, 3], {'key': 'aoeu'})
|
||||
self.assertEqual(ret, sentinel.some_object)
|
||||
anki_storage_Collection.assert_called_with(self.collection_path)
|
||||
func.assert_called_with(col, 1, 2, 3, key='aoeu')
|
||||
|
||||
# check that execute always returns False if waitForReturn=False
|
||||
func.reset_mock()
|
||||
ret = wrapper.execute(func, [1, 2, 3], {'key': 'aoeu'}, waitForReturn=False)
|
||||
self.assertEqual(ret, None)
|
||||
func.assert_called_with(col, 1, 2, 3, key='aoeu')
|
||||
|
||||
class CollectionManagerTest(unittest.TestCase):
|
||||
def test_lifecycle(self):
|
||||
with mock.patch('AnkiServer.collection.CollectionManager.collection_wrapper') as CollectionWrapper:
|
||||
wrapper = MagicMock()
|
||||
CollectionWrapper.return_value = wrapper
|
||||
|
||||
manager = CollectionManager()
|
||||
|
||||
# check getting a new collection
|
||||
ret = manager.get_collection('path1')
|
||||
CollectionWrapper.assert_called_with(os.path.realpath('path1'), None)
|
||||
self.assertEqual(ret, wrapper)
|
||||
|
||||
# change the return value, so that it would return a new object
|
||||
new_wrapper = MagicMock()
|
||||
CollectionWrapper.return_value = new_wrapper
|
||||
CollectionWrapper.reset_mock()
|
||||
|
||||
# get the new wrapper
|
||||
ret = manager.get_collection('path2')
|
||||
CollectionWrapper.assert_called_with(os.path.realpath('path2'), None)
|
||||
self.assertEqual(ret, new_wrapper)
|
||||
|
||||
# make sure the wrapper and new_wrapper are different
|
||||
self.assertNotEqual(wrapper, new_wrapper)
|
||||
|
||||
# assert that calling with the first path again, returns the first wrapper
|
||||
ret = manager.get_collection('path1')
|
||||
self.assertEqual(ret, wrapper)
|
||||
|
||||
manager.shutdown()
|
||||
wrapper.close.assert_called_with()
|
||||
new_wrapper.close.assert_called_with()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@ -1,113 +0,0 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
from mock import MagicMock, sentinel
|
||||
|
||||
import AnkiServer
|
||||
from AnkiServer.importer import get_importer_class, import_file
|
||||
|
||||
import anki.storage
|
||||
|
||||
# TODO: refactor into some kind of utility
|
||||
def add_note(col, data):
|
||||
from anki.notes import Note
|
||||
|
||||
model = col.models.byName(data['model'])
|
||||
|
||||
note = Note(col, model)
|
||||
for name, value in data['fields'].items():
|
||||
note[name] = value
|
||||
|
||||
if data.has_key('tags'):
|
||||
note.setTagsFromStr(data['tags'])
|
||||
|
||||
col.addNote(note)
|
||||
|
||||
class ImporterTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2')
|
||||
self.collection = anki.storage.Collection(self.collection_path)
|
||||
|
||||
def tearDown(self):
|
||||
self.collection.close()
|
||||
self.collection = None
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
|
||||
# TODO: refactor into a parent class
|
||||
def add_default_note(self, count=1):
|
||||
data = {
|
||||
'model': 'Basic',
|
||||
'fields': {
|
||||
'Front': 'The front',
|
||||
'Back': 'The back',
|
||||
},
|
||||
'tags': "Tag1 Tag2",
|
||||
}
|
||||
for idx in range(0, count):
|
||||
add_note(self.collection, data)
|
||||
self.add_note(data)
|
||||
|
||||
def test_resync(self):
|
||||
from anki.exporting import AnkiPackageExporter
|
||||
from anki.utils import intTime
|
||||
|
||||
# create a new collection with a single note
|
||||
src_collection = anki.storage.Collection(os.path.join(self.temp_dir, 'src_collection.anki2'))
|
||||
add_note(src_collection, {
|
||||
'model': 'Basic',
|
||||
'fields': {
|
||||
'Front': 'The front',
|
||||
'Back': 'The back',
|
||||
},
|
||||
'tags': 'Tag1 Tag2',
|
||||
})
|
||||
note_id = src_collection.findNotes('')[0]
|
||||
note = src_collection.getNote(note_id)
|
||||
self.assertEqual(note.id, note_id)
|
||||
self.assertEqual(note['Front'], 'The front')
|
||||
self.assertEqual(note['Back'], 'The back')
|
||||
|
||||
# export to an .apkg file
|
||||
dst1_path = os.path.join(self.temp_dir, 'export1.apkg')
|
||||
exporter = AnkiPackageExporter(src_collection)
|
||||
exporter.exportInto(dst1_path)
|
||||
|
||||
# import it into the main collection
|
||||
import_file(get_importer_class('apkg'), self.collection, dst1_path)
|
||||
|
||||
# make sure the note exists
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertEqual(note.id, note_id)
|
||||
self.assertEqual(note['Front'], 'The front')
|
||||
self.assertEqual(note['Back'], 'The back')
|
||||
|
||||
# now we change the source collection and re-export it
|
||||
note = src_collection.getNote(note_id)
|
||||
note['Front'] = 'The new front'
|
||||
note.tags.append('Tag3')
|
||||
note.flush(intTime()+1)
|
||||
dst2_path = os.path.join(self.temp_dir, 'export2.apkg')
|
||||
exporter = AnkiPackageExporter(src_collection)
|
||||
exporter.exportInto(dst2_path)
|
||||
|
||||
# first, import it without allow_update - no change should happen
|
||||
import_file(get_importer_class('apkg'), self.collection, dst2_path)
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertEqual(note['Front'], 'The front')
|
||||
self.assertEqual(note.tags, ['Tag1', 'Tag2'])
|
||||
|
||||
# now, import it with allow_update=True, so the note should change
|
||||
import_file(get_importer_class('apkg'), self.collection, dst2_path, allow_update=True)
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertEqual(note['Front'], 'The new front')
|
||||
self.assertEqual(note.tags, ['Tag1', 'Tag2', 'Tag3'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@ -1,618 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import logging
|
||||
from pprint import pprint
|
||||
|
||||
import mock
|
||||
from mock import MagicMock
|
||||
|
||||
import AnkiServer
|
||||
from AnkiServer.collection import CollectionManager
|
||||
from AnkiServer.apps.rest_app import RestApp, RestHandlerRequest, CollectionHandler, ImportExportHandler, NoteHandler, ModelHandler, DeckHandler, CardHandler
|
||||
|
||||
from webob.exc import *
|
||||
|
||||
import anki
|
||||
import anki.storage
|
||||
|
||||
class RestAppTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.collection_manager = CollectionManager()
|
||||
self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager)
|
||||
|
||||
# disable all but critical errors!
|
||||
logging.disable(logging.CRITICAL)
|
||||
|
||||
def tearDown(self):
|
||||
self.collection_manager.shutdown()
|
||||
self.collection_manager = None
|
||||
self.rest_app = None
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_list_collections(self):
|
||||
os.mkdir(os.path.join(self.temp_dir, 'test1'))
|
||||
os.mkdir(os.path.join(self.temp_dir, 'test2'))
|
||||
|
||||
with open(os.path.join(self.temp_dir, 'test1', 'collection.anki2'), 'wt') as fd:
|
||||
fd.write('Testing!')
|
||||
|
||||
self.assertEqual(self.rest_app.list_collections(), ['test1'])
|
||||
|
||||
def test_parsePath(self):
|
||||
tests = [
|
||||
('collection/user', ('collection', 'index', ['user'])),
|
||||
('collection/user/handler', ('collection', 'handler', ['user'])),
|
||||
('collection/user/note/123', ('note', 'index', ['user', '123'])),
|
||||
('collection/user/note/123/handler', ('note', 'handler', ['user', '123'])),
|
||||
('collection/user/deck/name', ('deck', 'index', ['user', 'name'])),
|
||||
('collection/user/deck/name/handler', ('deck', 'handler', ['user', 'name'])),
|
||||
#('collection/user/deck/name/card/123', ('card', 'index', ['user', 'name', '123'])),
|
||||
#('collection/user/deck/name/card/123/handler', ('card', 'handler', ['user', 'name', '123'])),
|
||||
('collection/user/card/123', ('card', 'index', ['user', '123'])),
|
||||
('collection/user/card/123/handler', ('card', 'handler', ['user', '123'])),
|
||||
# the leading slash should make no difference!
|
||||
('/collection/user', ('collection', 'index', ['user'])),
|
||||
]
|
||||
|
||||
for path, result in tests:
|
||||
self.assertEqual(self.rest_app._parsePath(path), result)
|
||||
|
||||
def test_parsePath_not_found(self):
|
||||
tests = [
|
||||
'bad',
|
||||
'bad/oaeu',
|
||||
'collection',
|
||||
'collection/user/handler/bad',
|
||||
'',
|
||||
'/',
|
||||
]
|
||||
|
||||
for path in tests:
|
||||
self.assertRaises(HTTPNotFound, self.rest_app._parsePath, path)
|
||||
|
||||
def test_getCollectionPath(self):
|
||||
def fullpath(collection_id):
|
||||
return os.path.normpath(os.path.join(self.temp_dir, collection_id, 'collection.anki2'))
|
||||
|
||||
# This is simple and straight forward!
|
||||
self.assertEqual(self.rest_app._getCollectionPath('user'), fullpath('user'))
|
||||
|
||||
# These are dangerous - the user is trying to hack us!
|
||||
dangerous = ['../user', '/etc/passwd', '/tmp/aBaBaB', '/root/.ssh/id_rsa']
|
||||
for collection_id in dangerous:
|
||||
self.assertRaises(HTTPBadRequest, self.rest_app._getCollectionPath, collection_id)
|
||||
|
||||
def test_getHandler(self):
|
||||
def handlerOne():
|
||||
pass
|
||||
|
||||
def handlerTwo():
|
||||
pass
|
||||
handlerTwo.hasReturnValue = False
|
||||
|
||||
self.rest_app.add_handler('collection', 'handlerOne', handlerOne)
|
||||
self.rest_app.add_handler('deck', 'handlerTwo', handlerTwo)
|
||||
|
||||
(handler, hasReturnValue) = self.rest_app._getHandler('collection', 'handlerOne')
|
||||
self.assertEqual(handler, handlerOne)
|
||||
self.assertEqual(hasReturnValue, True)
|
||||
|
||||
(handler, hasReturnValue) = self.rest_app._getHandler('deck', 'handlerTwo')
|
||||
self.assertEqual(handler, handlerTwo)
|
||||
self.assertEqual(hasReturnValue, False)
|
||||
|
||||
# try some bad handler names and types
|
||||
self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'collection', 'nonExistantHandler')
|
||||
self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'nonExistantType', 'handlerOne')
|
||||
|
||||
def test_parseRequestBody(self):
|
||||
req = MagicMock()
|
||||
req.body = '{"key":"value"}'
|
||||
|
||||
data = self.rest_app._parseRequestBody(req)
|
||||
self.assertEqual(data, {'key': 'value'})
|
||||
self.assertEqual(data.keys(), ['key'])
|
||||
self.assertEqual(type(data.keys()[0]), str)
|
||||
|
||||
# test some bad data
|
||||
req.body = '{aaaaaaa}'
|
||||
self.assertRaises(HTTPBadRequest, self.rest_app._parseRequestBody, req)
|
||||
|
||||
class CollectionTestBase(unittest.TestCase):
|
||||
"""Parent class for tests that need a collection set up and torn down."""
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
|
||||
self.collection = anki.storage.Collection(self.collection_path)
|
||||
self.mock_app = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
self.collection.close()
|
||||
self.collection = None
|
||||
shutil.rmtree(self.temp_dir)
|
||||
self.mock_app.reset_mock()
|
||||
|
||||
# TODO: refactor into some kind of utility
|
||||
def add_note(self, data):
|
||||
from anki.notes import Note
|
||||
|
||||
model = self.collection.models.byName(data['model'])
|
||||
|
||||
note = Note(self.collection, model)
|
||||
for name, value in data['fields'].items():
|
||||
note[name] = value
|
||||
|
||||
if data.has_key('tags'):
|
||||
note.setTagsFromStr(data['tags'])
|
||||
|
||||
self.collection.addNote(note)
|
||||
|
||||
# TODO: refactor into a parent class
|
||||
def add_default_note(self, count=1):
|
||||
data = {
|
||||
'model': 'Basic',
|
||||
'fields': {
|
||||
'Front': 'The front',
|
||||
'Back': 'The back',
|
||||
},
|
||||
'tags': "Tag1 Tag2",
|
||||
}
|
||||
for idx in range(0, count):
|
||||
self.add_note(data)
|
||||
|
||||
class CollectionHandlerTest(CollectionTestBase):
|
||||
def setUp(self):
|
||||
super(CollectionHandlerTest, self).setUp()
|
||||
self.handler = CollectionHandler()
|
||||
|
||||
def execute(self, name, data):
|
||||
ids = ['collection_name']
|
||||
func = getattr(self.handler, name)
|
||||
req = RestHandlerRequest(self.mock_app, data, ids, {})
|
||||
return func(self.collection, req)
|
||||
|
||||
def test_list_decks(self):
|
||||
data = {}
|
||||
ret = self.execute('list_decks', data)
|
||||
|
||||
# It contains only the 'Default' deck
|
||||
self.assertEqual(len(ret), 1)
|
||||
self.assertEqual(ret[0]['name'], 'Default')
|
||||
|
||||
def test_select_deck(self):
|
||||
data = {'deck_id': '1'}
|
||||
ret = self.execute('select_deck', data)
|
||||
self.assertEqual(ret, None);
|
||||
|
||||
def test_create_dynamic_deck_simple(self):
|
||||
self.add_default_note(5)
|
||||
|
||||
data = {
|
||||
'name': 'Dyn deck',
|
||||
'mode': 'random',
|
||||
'count': 2,
|
||||
'query': "deck:\"Default\" (tag:'Tag1' or tag:'Tag2') (-tag:'Tag3')",
|
||||
}
|
||||
ret = self.execute('create_dynamic_deck', data)
|
||||
self.assertEqual(ret['name'], 'Dyn deck')
|
||||
self.assertEqual(ret['dyn'], True)
|
||||
|
||||
cards = self.collection.findCards('deck:"Dyn deck"')
|
||||
self.assertEqual(len(cards), 2)
|
||||
|
||||
def test_list_models(self):
|
||||
data = {}
|
||||
ret = self.execute('list_models', data)
|
||||
|
||||
# get a sorted name list that we can actually check
|
||||
names = [model['name'] for model in ret]
|
||||
names.sort()
|
||||
|
||||
# These are the default models created by Anki in a new collection
|
||||
default_models = [
|
||||
'Basic',
|
||||
'Basic (and reversed card)',
|
||||
'Basic (optional reversed card)',
|
||||
'Cloze'
|
||||
]
|
||||
|
||||
self.assertEqual(names, default_models)
|
||||
|
||||
def test_find_model_by_name(self):
|
||||
data = {'model': 'Basic'}
|
||||
ret = self.execute('find_model_by_name', data)
|
||||
self.assertEqual(ret['name'], 'Basic')
|
||||
|
||||
def test_find_notes(self):
|
||||
ret = self.execute('find_notes', {})
|
||||
self.assertEqual(ret, [])
|
||||
|
||||
# add a note programatically
|
||||
self.add_default_note()
|
||||
|
||||
# get the id for the one note on this collection
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
|
||||
ret = self.execute('find_notes', {})
|
||||
self.assertEqual(ret, [{'id': note_id}])
|
||||
|
||||
ret = self.execute('find_notes', {'query': 'tag:Tag1'})
|
||||
self.assertEqual(ret, [{'id': note_id}])
|
||||
|
||||
ret = self.execute('find_notes', {'query': 'tag:TagX'})
|
||||
self.assertEqual(ret, [])
|
||||
|
||||
ret = self.execute('find_notes', {'preload': True})
|
||||
self.assertEqual(len(ret), 1)
|
||||
self.assertEqual(ret[0]['id'], note_id)
|
||||
self.assertEqual(ret[0]['model']['name'], 'Basic')
|
||||
|
||||
def test_add_note(self):
|
||||
# make sure there are no notes (yet)
|
||||
self.assertEqual(self.collection.findNotes(''), [])
|
||||
|
||||
# add a note programatically
|
||||
note = {
|
||||
'model': 'Basic',
|
||||
'fields': {
|
||||
'Front': 'The front',
|
||||
'Back': 'The back',
|
||||
},
|
||||
'tags': "Tag1 Tag2",
|
||||
}
|
||||
self.execute('add_note', note)
|
||||
|
||||
notes = self.collection.findNotes('')
|
||||
self.assertEqual(len(notes), 1)
|
||||
|
||||
note_id = notes[0]
|
||||
note = self.collection.getNote(note_id)
|
||||
|
||||
self.assertEqual(note.model()['name'], 'Basic')
|
||||
self.assertEqual(note['Front'], 'The front')
|
||||
self.assertEqual(note['Back'], 'The back')
|
||||
self.assertEqual(note.tags, ['Tag1', 'Tag2'])
|
||||
|
||||
def test_list_tags(self):
|
||||
ret = self.execute('list_tags', {})
|
||||
self.assertEqual(ret, [])
|
||||
|
||||
self.add_default_note()
|
||||
|
||||
ret = self.execute('list_tags', {})
|
||||
ret.sort()
|
||||
self.assertEqual(ret, ['Tag1', 'Tag2'])
|
||||
|
||||
def test_set_language(self):
|
||||
import anki.lang
|
||||
|
||||
self.assertEqual(anki.lang._('Again'), 'Again')
|
||||
|
||||
try:
|
||||
data = {'code': 'pl'}
|
||||
self.execute('set_language', data)
|
||||
self.assertEqual(anki.lang._('Again'), u'Znowu')
|
||||
finally:
|
||||
# return everything to normal!
|
||||
anki.lang.setLang('en')
|
||||
|
||||
def test_reset_scheduler(self):
|
||||
self.add_default_note(3)
|
||||
|
||||
ret = self.execute('reset_scheduler', {'deck': 'Default'})
|
||||
self.assertEqual(ret, {
|
||||
'new_cards': 3,
|
||||
'learning_cards': 0,
|
||||
'review_cards': 0,
|
||||
})
|
||||
|
||||
def test_next_card(self):
|
||||
ret = self.execute('next_card', {})
|
||||
self.assertEqual(ret, None)
|
||||
|
||||
# add a note programatically
|
||||
self.add_default_note()
|
||||
|
||||
# get the id for the one card and note on this collection
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
card_id = self.collection.findCards('')[0]
|
||||
|
||||
self.collection.sched.reset()
|
||||
ret = self.execute('next_card', {})
|
||||
self.assertEqual(ret['id'], card_id)
|
||||
self.assertEqual(ret['nid'], note_id)
|
||||
self.assertEqual(ret['css'], '<style>.card {\n font-family: arial;\n font-size: 20px;\n text-align: center;\n color: black;\n background-color: white;\n}\n</style>')
|
||||
self.assertEqual(ret['question'], 'The front')
|
||||
self.assertEqual(ret['answer'], 'The front\n\n<hr id=answer>\n\nThe back')
|
||||
self.assertEqual(ret['answer_buttons'], [
|
||||
{'ease': 1,
|
||||
'label': 'Again',
|
||||
'string_label': 'Again',
|
||||
'interval': 60,
|
||||
'string_interval': '<1 minute'},
|
||||
{'ease': 2,
|
||||
'label': 'Good',
|
||||
'string_label': 'Good',
|
||||
'interval': 600,
|
||||
'string_interval': '<10 minutes'},
|
||||
{'ease': 3,
|
||||
'label': 'Easy',
|
||||
'string_label': 'Easy',
|
||||
'interval': 345600,
|
||||
'string_interval': '4 days'}])
|
||||
|
||||
def test_next_card_translation(self):
|
||||
# add a note programatically
|
||||
self.add_default_note()
|
||||
|
||||
# get the card in Polish so we can test translation too
|
||||
anki.lang.setLang('pl')
|
||||
try:
|
||||
ret = self.execute('next_card', {})
|
||||
finally:
|
||||
anki.lang.setLang('en')
|
||||
|
||||
self.assertEqual(ret['answer_buttons'], [
|
||||
{'ease': 1,
|
||||
'label': 'Again',
|
||||
'string_label': u'Znowu',
|
||||
'interval': 60,
|
||||
'string_interval': '<1 minuta'},
|
||||
{'ease': 2,
|
||||
'label': 'Good',
|
||||
'string_label': u'Dobra',
|
||||
'interval': 600,
|
||||
'string_interval': '<10 minut'},
|
||||
{'ease': 3,
|
||||
'label': 'Easy',
|
||||
'string_label': u'Łatwa',
|
||||
'interval': 345600,
|
||||
'string_interval': '4 dni'}])
|
||||
|
||||
def test_next_card_five_times(self):
|
||||
self.add_default_note(5)
|
||||
for idx in range(0, 5):
|
||||
ret = self.execute('next_card', {})
|
||||
self.assertTrue(ret is not None)
|
||||
|
||||
def test_answer_card(self):
|
||||
import time
|
||||
|
||||
self.add_default_note()
|
||||
|
||||
# instantiate a deck handler to get the card
|
||||
card = self.execute('next_card', {})
|
||||
self.assertEqual(card['reps'], 0)
|
||||
|
||||
self.execute('answer_card', {'id': card['id'], 'ease': 2, 'timerStarted': time.time()})
|
||||
|
||||
# reset the scheduler and try to get the next card again - there should be none!
|
||||
self.collection.sched.reset()
|
||||
card = self.execute('next_card', {})
|
||||
self.assertEqual(card['reps'], 1)
|
||||
|
||||
def test_suspend_cards(self):
|
||||
# add a note programatically
|
||||
self.add_default_note()
|
||||
|
||||
# get the id for the one card on this collection
|
||||
card_id = self.collection.findCards('')[0]
|
||||
|
||||
# suspend it
|
||||
self.execute('suspend_cards', {'ids': [card_id]})
|
||||
|
||||
# test that getting the next card will be None
|
||||
card = self.collection.sched.getCard()
|
||||
self.assertEqual(card, None)
|
||||
|
||||
# unsuspend it
|
||||
self.execute('unsuspend_cards', {'ids': [card_id]})
|
||||
|
||||
# test that now we're getting the next card!
|
||||
self.collection.sched.reset()
|
||||
card = self.collection.sched.getCard()
|
||||
self.assertEqual(card.id, card_id)
|
||||
|
||||
def test_cards_recent_ease(self):
|
||||
self.add_default_note()
|
||||
card_id = self.collection.findCards('')[0]
|
||||
|
||||
# answer the card
|
||||
self.collection.reset()
|
||||
card = self.collection.sched.getCard()
|
||||
card.startTimer()
|
||||
# answer multiple times to see that we only get the latest!
|
||||
self.collection.sched.answerCard(card, 1)
|
||||
self.collection.sched.answerCard(card, 3)
|
||||
self.collection.sched.answerCard(card, 2)
|
||||
|
||||
# pull the latest revision
|
||||
ret = self.execute('cards_recent_ease', {})
|
||||
self.assertEqual(ret[0]['id'], card_id)
|
||||
self.assertEqual(ret[0]['ease'], 2)
|
||||
|
||||
class ImportExportHandlerTest(CollectionTestBase):
|
||||
export_rows = [
|
||||
['Card front 1', 'Card back 1', 'Tag1 Tag2'],
|
||||
['Card front 2', 'Card back 2', 'Tag1 Tag3'],
|
||||
]
|
||||
|
||||
def setUp(self):
|
||||
super(ImportExportHandlerTest, self).setUp()
|
||||
self.handler = ImportExportHandler()
|
||||
|
||||
def execute(self, name, data):
|
||||
ids = ['collection_name']
|
||||
func = getattr(self.handler, name)
|
||||
req = RestHandlerRequest(self.mock_app, data, ids, {})
|
||||
return func(self.collection, req)
|
||||
|
||||
def generate_text_export(self):
|
||||
# Create a simple export file
|
||||
export_data = ''
|
||||
for row in self.export_rows:
|
||||
export_data += '\t'.join(row) + '\n'
|
||||
export_path = os.path.join(self.temp_dir, 'export.txt')
|
||||
with file(export_path, 'wt') as fd:
|
||||
fd.write(export_data)
|
||||
|
||||
return (export_data, export_path)
|
||||
|
||||
def check_import(self):
|
||||
note_ids = self.collection.findNotes('')
|
||||
notes = [self.collection.getNote(note_id) for note_id in note_ids]
|
||||
self.assertEqual(len(notes), len(self.export_rows))
|
||||
|
||||
for index, test_data in enumerate(self.export_rows):
|
||||
self.assertEqual(notes[index]['Front'], test_data[0])
|
||||
self.assertEqual(notes[index]['Back'], test_data[1])
|
||||
self.assertEqual(' '.join(notes[index].tags), test_data[2])
|
||||
|
||||
def test_import_text_data(self):
|
||||
(export_data, export_path) = self.generate_text_export()
|
||||
|
||||
data = {
|
||||
'filetype': 'text',
|
||||
'data': export_data,
|
||||
}
|
||||
ret = self.execute('import_file', data)
|
||||
self.check_import()
|
||||
|
||||
def test_import_text_url(self):
|
||||
(export_data, export_path) = self.generate_text_export()
|
||||
|
||||
data = {
|
||||
'filetype': 'text',
|
||||
'url': 'file://' + os.path.realpath(export_path),
|
||||
}
|
||||
ret = self.execute('import_file', data)
|
||||
self.check_import()
|
||||
|
||||
class NoteHandlerTest(CollectionTestBase):
|
||||
def setUp(self):
|
||||
super(NoteHandlerTest, self).setUp()
|
||||
self.handler = NoteHandler()
|
||||
|
||||
def execute(self, name, data, note_id):
|
||||
ids = ['collection_name', note_id]
|
||||
func = getattr(self.handler, name)
|
||||
req = RestHandlerRequest(self.mock_app, data, ids, {})
|
||||
return func(self.collection, req)
|
||||
|
||||
def test_index(self):
|
||||
self.add_default_note()
|
||||
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
|
||||
ret = self.execute('index', {}, note_id)
|
||||
self.assertEqual(ret['id'], note_id)
|
||||
self.assertEqual(len(ret['fields']), 2)
|
||||
self.assertEqual(ret['flags'], 0)
|
||||
self.assertEqual(ret['model']['name'], 'Basic')
|
||||
self.assertEqual(ret['tags'], ['Tag1', 'Tag2'])
|
||||
self.assertEqual(ret['string_tags'], 'Tag1 Tag2')
|
||||
self.assertEqual(ret['usn'], -1)
|
||||
|
||||
def test_add_tags(self):
|
||||
self.add_default_note()
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertFalse('NT1' in note.tags)
|
||||
self.assertFalse('NT2' in note.tags)
|
||||
|
||||
self.execute('add_tags', {'tags': ['NT1', 'NT2']}, note_id)
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertTrue('NT1' in note.tags)
|
||||
self.assertTrue('NT2' in note.tags)
|
||||
|
||||
def test_remove_tags(self):
|
||||
self.add_default_note()
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertTrue('Tag1' in note.tags)
|
||||
self.assertTrue('Tag2' in note.tags)
|
||||
|
||||
self.execute('remove_tags', {'tags': ['Tag1', 'Tag2']}, note_id)
|
||||
note = self.collection.getNote(note_id)
|
||||
self.assertFalse('Tag1' in note.tags)
|
||||
self.assertFalse('Tag2' in note.tags)
|
||||
|
||||
class DeckHandlerTest(CollectionTestBase):
|
||||
def setUp(self):
|
||||
super(DeckHandlerTest, self).setUp()
|
||||
self.handler = DeckHandler()
|
||||
|
||||
def execute(self, name, data):
|
||||
ids = ['collection_name', '1']
|
||||
func = getattr(self.handler, name)
|
||||
req = RestHandlerRequest(self.mock_app, data, ids, {})
|
||||
return func(self.collection, req)
|
||||
|
||||
def test_index(self):
|
||||
ret = self.execute('index', {})
|
||||
#pprint(ret)
|
||||
self.assertEqual(ret['name'], 'Default')
|
||||
self.assertEqual(ret['id'], 1)
|
||||
self.assertEqual(ret['dyn'], False)
|
||||
|
||||
def test_next_card(self):
|
||||
self.mock_app.execute_handler.return_value = None
|
||||
|
||||
ret = self.execute('next_card', {})
|
||||
self.assertEqual(ret, None)
|
||||
self.mock_app.execute_handler.assert_called_with('collection', 'next_card', self.collection, RestHandlerRequest(self.mock_app, {'deck': '1'}, ['collection_name'], {}))
|
||||
|
||||
def test_get_conf(self):
|
||||
ret = self.execute('get_conf', {})
|
||||
#pprint(ret)
|
||||
self.assertEqual(ret['name'], 'Default')
|
||||
self.assertEqual(ret['id'], 1)
|
||||
self.assertEqual(ret['dyn'], False)
|
||||
|
||||
class CardHandlerTest(CollectionTestBase):
|
||||
def setUp(self):
|
||||
super(CardHandlerTest, self).setUp()
|
||||
self.handler = CardHandler()
|
||||
|
||||
def execute(self, name, data, card_id):
|
||||
ids = ['collection_name', card_id]
|
||||
func = getattr(self.handler, name)
|
||||
req = RestHandlerRequest(self.mock_app, data, ids, {})
|
||||
return func(self.collection, req)
|
||||
|
||||
def test_index_simple(self):
|
||||
self.add_default_note()
|
||||
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
card_id = self.collection.findCards('')[0]
|
||||
|
||||
ret = self.execute('index', {}, card_id)
|
||||
self.assertEqual(ret['id'], card_id)
|
||||
self.assertEqual(ret['nid'], note_id)
|
||||
self.assertEqual(ret['did'], 1)
|
||||
self.assertFalse(ret.has_key('note'))
|
||||
self.assertFalse(ret.has_key('deck'))
|
||||
|
||||
def test_index_load(self):
|
||||
self.add_default_note()
|
||||
|
||||
note_id = self.collection.findNotes('')[0]
|
||||
card_id = self.collection.findCards('')[0]
|
||||
|
||||
ret = self.execute('index', {'load_note': 1, 'load_deck': 1}, card_id)
|
||||
self.assertEqual(ret['id'], card_id)
|
||||
self.assertEqual(ret['nid'], note_id)
|
||||
self.assertEqual(ret['did'], 1)
|
||||
self.assertEqual(ret['note']['id'], note_id)
|
||||
self.assertEqual(ret['note']['model']['name'], 'Basic')
|
||||
self.assertEqual(ret['deck']['name'], 'Default')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import AnkiServer
|
||||
from AnkiServer.apps.sync_app import SyncApp
|
||||
|
||||
class SyncAppTest(unittest.TestCase):
|
||||
pass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user