diff --git a/AnkiServer/__init__.py b/AnkiServer/__init__.py index 26d92db..5958cdb 100644 --- a/AnkiServer/__init__.py +++ b/AnkiServer/__init__.py @@ -2,14 +2,14 @@ import sys sys.path.insert(0, "/usr/share/anki") -def server_runner(app, global_conf, **kw): - """ Special version of paste.httpserver.server_runner which calls - AnkiServer.threading.shutdown() on server exit.""" - - from paste.httpserver import server_runner as paste_server_runner - from AnkiServer.threading import shutdown - try: - paste_server_runner(app, global_conf, **kw) - finally: - shutdown() - +#def server_runner(app, global_conf, **kw): +# """ Special version of paste.httpserver.server_runner which calls +# AnkiServer.thread.shutdown() on server exit.""" +# +# from paste.httpserver import server_runner as paste_server_runner +# from AnkiServer.thread import shutdown +# try: +# paste_server_runner(app, global_conf, **kw) +# finally: +# shutdown() +# diff --git a/AnkiServer/apps/__init__.py b/AnkiServer/apps/__init__.py deleted file mode 100644 index ed88d78..0000000 --- a/AnkiServer/apps/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# package - diff --git a/AnkiServer/apps/rest_app.py b/AnkiServer/apps/rest_app.py deleted file mode 100644 index c8a8188..0000000 --- a/AnkiServer/apps/rest_app.py +++ /dev/null @@ -1,874 +0,0 @@ - -from webob.dec import wsgify -from webob.exc import * -from webob import Response - -#from pprint import pprint - -try: - import simplejson as json - from simplejson import JSONDecodeError -except ImportError: - import json - JSONDecodeError = ValueError - -import os, logging - -import anki.consts -import anki.lang -from anki.lang import _ as t - -__all__ = ['RestApp', 'RestHandlerBase', 'noReturnValue'] - -def noReturnValue(func): - func.hasReturnValue = False - return func - -class RestHandlerBase(object): - """Parent class for a handler group.""" - hasReturnValue = True - -class _RestHandlerWrapper(RestHandlerBase): - """Wrapper for functions that we can't modify.""" - def __init__(self, func_name, func, hasReturnValue=True): - self.func_name = func_name - self.func = func - self.hasReturnValue = hasReturnValue - def __call__(self, *args, **kw): - return self.func(*args, **kw) - -class RestHandlerRequest(object): - def __init__(self, app, data, ids, session): - self.app = app - self.data = data - self.ids = ids - self.session = session - - def copy(self): - return RestHandlerRequest(self.app, self.data.copy(), self.ids[:], self.session) - - def __eq__(self, other): - return self.app == other.app and self.data == other.data and self.ids == other.ids and self.session == other.session - -class RestApp(object): - """A WSGI app that implements RESTful operations on Collections, Decks and Cards.""" - - # Defines not only the valid handler types, but their position in the URL string - handler_types = ['collection', ['model', 'note', 'deck', 'card']] - - def __init__(self, data_root, allowed_hosts='*', setup_new_collection=None, use_default_handlers=True, collection_manager=None): - from AnkiServer.threading import getCollectionManager - - self.data_root = os.path.abspath(data_root) - self.allowed_hosts = allowed_hosts - self.setup_new_collection = setup_new_collection - - if collection_manager is not None: - self.collection_manager = collection_manager - else: - self.collection_manager = getCollectionManager() - - self.handlers = {} - for type_list in self.handler_types: - if type(type_list) is not list: - type_list = [type_list] - for handler_type in type_list: - self.handlers[handler_type] = {} - - if use_default_handlers: - self.add_handler_group('collection', CollectionHandler()) - self.add_handler_group('note', NoteHandler()) - self.add_handler_group('model', ModelHandler()) - self.add_handler_group('deck', DeckHandler()) - self.add_handler_group('card', CardHandler()) - - # hold per collection session data - self.sessions = {} - - def add_handler(self, type, name, handler): - """Adds a callback handler for a type (collection, deck, card) with a unique name. - - - 'type' is the item that will be worked on, for example: collection, deck, and card. - - - 'name' is a unique name for the handler that gets used in the URL. - - - 'handler' is a callable that takes (collection, data, ids). - """ - - if self.handlers[type].has_key(name): - raise "Handler already for %(type)s/%(name)s exists!" - self.handlers[type][name] = handler - - def add_handler_group(self, type, group): - """Adds several handlers for every public method on an object descended from RestHandlerBase. - - This allows you to create a single class with several methods, so that you can quickly - create a group of related handlers.""" - - import inspect - for name, method in inspect.getmembers(group, predicate=inspect.ismethod): - if not name.startswith('_'): - if hasattr(group, 'hasReturnValue') and not hasattr(method, 'hasReturnValue'): - method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue) - self.add_handler(type, name, method) - - def execute_handler(self, type, name, col, req): - """Executes the handler with the given type and name, passing in the col and req as arguments.""" - - handler, hasReturnValue = self._getHandler(type, name) - ret = handler(col, req) - if hasReturnValue: - return ret - - def list_collections(self): - """Returns an array of valid collection names in our self.data_path.""" - return [x for x in os.listdir(self.data_root) if os.path.exists(os.path.join(self.data_root, x, 'collection.anki2'))] - - def _checkRequest(self, req): - """Raises an exception if the request isn't allowed or valid for some reason.""" - if self.allowed_hosts != '*': - try: - remote_addr = req.headers['X-Forwarded-For'] - except KeyError: - remote_addr = req.remote_addr - if remote_addr != self.allowed_hosts: - raise HTTPForbidden() - - if req.method != 'POST': - raise HTTPMethodNotAllowed(allow=['POST']) - - def _parsePath(self, path): - """Takes a request path and returns a tuple containing the handler type, name - and a list of ids. - - Raises an HTTPNotFound exception if the path is invalid.""" - - if path in ('', '/'): - raise HTTPNotFound() - - # split the URL into a list of parts - if path[0] == '/': - path = path[1:] - parts = path.split('/') - - # pull the type and context from the URL parts - handler_type = None - ids = [] - for type_list in self.handler_types: - if len(parts) == 0: - break - - # some URL positions can have multiple types - if type(type_list) is not list: - type_list = [type_list] - - # get the handler_type - if parts[0] not in type_list: - break - handler_type = parts.pop(0) - - # add the id to the id list - if len(parts) > 0: - ids.append(parts.pop(0)) - # break if we don't have enough parts to make a new type/id pair - if len(parts) < 2: - break - - # sanity check to make sure the URL is valid - if len(parts) > 1 or len(ids) == 0: - raise HTTPNotFound() - - # get the handler name - if len(parts) == 0: - name = 'index' - else: - name = parts[0] - - return (handler_type, name, ids) - - def _getCollectionPath(self, collection_id): - """Returns the path to the collection based on the collection_id from the request. - - Raises HTTPBadRequest if the collection_id is invalid.""" - - path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2')) - if path[0:len(self.data_root)] != self.data_root: - # attempting to escape our data jail! - raise HTTPBadRequest('"%s" is not a valid collection' % collection_id) - - return path - - def _getHandler(self, type, name): - """Returns a tuple containing handler function for this type and name, and a boolean flag - if that handler has a return value. - - Raises an HTTPNotFound exception if the handler doesn't exist.""" - - # get the handler function - try: - handler = self.handlers[type][name] - except KeyError: - raise HTTPNotFound() - - # get if we have a return value - hasReturnValue = True - if hasattr(handler, 'hasReturnValue'): - hasReturnValue = handler.hasReturnValue - - return (handler, hasReturnValue) - - def _parseRequestBody(self, req): - """Parses the request body (JSON) into a Python dict and returns it. - - Raises an HTTPBadRequest exception if the request isn't valid JSON.""" - - try: - data = json.loads(req.body) - except JSONDecodeError, e: - logging.error(req.path+': Unable to parse JSON: '+str(e), exc_info=True) - raise HTTPBadRequest() - - # fix for a JSON encoding 'quirk' in PHP - if type(data) == list and len(data) == 0: - data = {} - - # make the keys into non-unicode strings - data = dict([(str(k), v) for k, v in data.items()]) - - return data - - @wsgify - def __call__(self, req): - # make sure the request is valid - self._checkRequest(req) - - if req.path == '/list_collections': - return Response(json.dumps(self.list_collections()), content_type='application/json') - - # parse the path - type, name, ids = self._parsePath(req.path) - - # get the collection path - collection_path = self._getCollectionPath(ids[0]) - print collection_path - - # get the handler function - handler, hasReturnValue = self._getHandler(type, name) - - # parse the request body - data = self._parseRequestBody(req) - - # get the users session - try: - session = self.sessions[ids[0]] - except KeyError: - session = self.sessions[ids[0]] = {} - - # debug - from pprint import pprint - pprint(data) - - # run it! - col = self.collection_manager.get_collection(collection_path, self.setup_new_collection) - handler_request = RestHandlerRequest(self, data, ids, session) - try: - output = col.execute(handler, [handler_request], {}, hasReturnValue) - except Exception, e: - logging.error(e) - return HTTPInternalServerError() - - if output is None: - return Response('', content_type='text/plain') - else: - return Response(json.dumps(output), content_type='application/json') - -class CollectionHandler(RestHandlerBase): - """Default handler group for 'collection' type.""" - - # - # MODELS - Store fields definitions and templates for notes - # - - def list_models(self, col, req): - # This is already a list of dicts, so it doesn't need to be serialized - return col.models.all() - - def find_model_by_name(self, col, req): - # This is already a list of dicts, so it doesn't need to be serialized - return col.models.byName(req.data['model']) - - # - # NOTES - Information (in fields per the model) that can generate a card - # (based on a template from the model). - # - - def find_notes(self, col, req): - query = req.data.get('query', '') - ids = col.findNotes(query) - - if req.data.get('preload', False): - notes = [NoteHandler._serialize(col.getNote(id)) for id in ids] - else: - notes = [{'id': id} for id in ids] - - return notes - - def latest_notes(self, col, req): - # TODO: use SQLAlchemy objects to do this - sql = "SELECT n.id FROM notes AS n"; - args = [] - if req.data.has_key('updated_since'): - sql += ' WHERE n.mod > ?' - args.append(req.data['updated_since']) - sql += ' ORDER BY n.mod DESC' - sql += ' LIMIT ' + str(req.data.get('limit', 10)) - ids = col.db.list(sql, *args) - - if req.data.get('preload', False): - notes = [NoteHandler._serialize(col.getNote(id)) for id in ids] - else: - notes = [{'id': id} for id in ids] - - return notes - - @noReturnValue - def add_note(self, col, req): - from anki.notes import Note - - # TODO: I think this would be better with 'model' for the name - # and 'mid' for the model id. - if type(req.data['model']) in (str, unicode): - model = col.models.byName(req.data['model']) - else: - model = col.models.get(req.data['model']) - - note = Note(col, model) - for name, value in req.data['fields'].items(): - note[name] = value - - if req.data.has_key('tags'): - note.setTagsFromStr(req.data['tags']) - - col.addNote(note) - - def list_tags(self, col, req): - return col.tags.all() - - # - # DECKS - Groups of cards - # - - def list_decks(self, col, req): - # This is already a list of dicts, so it doesn't need to be serialized - return col.decks.all() - - @noReturnValue - def select_deck(self, col, req): - col.decks.select(req.data['deck_id']) - - dyn_modes = { - 'random': anki.consts.DYN_RANDOM, - 'added': anki.consts.DYN_ADDED, - 'due': anki.consts.DYN_DUE, - } - - def create_dynamic_deck(self, col, req): - name = req.data.get('name', t('Custom Study Session')) - deck = col.decks.byName(name) - if deck: - if not deck['dyn']: - raise HTTPBadRequest("There is an existing non-dynamic deck with the name %s" % name) - - # safe to empty it because it's a dynamic deck - # TODO: maybe this should be an option? - col.sched.emptyDyn(deck['id']) - else: - deck = col.decks.get(col.decks.newDyn(name)) - - query = req.data.get('query', '') - count = int(req.data.get('count', 100)) - mode = req.data.get('mode', 'random') - - try: - mode = self.dyn_modes[mode] - except KeyError: - raise HTTPBadRequest("Unknown mode: %s" % mode) - - deck['terms'][0] = [query, count, mode] - - if mode != anki.consts.DYN_RANDOM: - deck['resched'] = True - else: - deck['resched'] = False - - if not col.sched.rebuildDyn(deck['id']): - raise HTTPBadRequest("No cards matched the criteria you provided") - - col.decks.save(deck) - col.sched.reset() - - return deck - - # - # CARD - A specific card in a deck with a history of review (generated from - # a note based on the template). - # - - def find_cards(self, col, req): - query = req.data.get('query', '') - order = req.data.get('order', False) - ids = anki.find.Finder(col).findCards(query, order) - - if req.data.get('preload', False): - cards = [CardHandler._serialize(col.getCard(id), req.data) for id in ids] - else: - cards = [{'id': id} for id in ids] - - return cards - - def latest_cards(self, col, req): - # TODO: use SQLAlchemy objects to do this - sql = "SELECT c.id FROM notes AS n INNER JOIN cards AS c ON c.nid = n.id"; - args = [] - if req.data.has_key('updated_since'): - sql += ' WHERE n.mod > ?' - args.append(req.data['updated_since']) - sql += ' ORDER BY n.mod DESC' - sql += ' LIMIT ' + str(req.data.get('limit', 10)) - ids = col.db.list(sql, *args) - - if req.data.get('preload', False): - cards = [CardHandler._serialize(col.getCard(id), req.data) for id in ids] - else: - cards = [{'id': id} for id in ids] - - return cards - - # - # SCHEDULER - Controls card review, ie. intervals, what cards are due, answering a card, etc. - # - - def reset_scheduler(self, col, req): - if req.data.has_key('deck'): - deck = DeckHandler._get_deck(col, req.data['deck']) - col.decks.select(deck['id']) - - col.sched.reset() - counts = col.sched.counts() - return { - 'new_cards': counts[0], - 'learning_cards': counts[1], - 'review_cards': counts[1], - } - - def extend_scheduler_limits(self, col, req): - new_cards = int(req.data.get('new_cards', 0)) - review_cards = int(req.data.get('review_cards', 0)) - col.sched.extendLimits(new_cards, review_cards) - col.sched.reset() - - button_labels = ['Easy', 'Good', 'Hard'] - - def _get_answer_buttons(self, col, card): - l = [] - - # Put the correct number of buttons - cnt = col.sched.answerButtons(card) - for idx in range(0, cnt - 1): - l.append(self.button_labels[idx]) - l.append('Again') - l.reverse() - - # Loop through and add the ease, estimated time (in seconds) and other info - return [{ - 'ease': ease, - 'label': label, - 'string_label': t(label), - 'interval': col.sched.nextIvl(card, ease), - 'string_interval': col.sched.nextIvlStr(card, ease), - } for ease, label in enumerate(l, 1)] - - def next_card(self, col, req): - if req.data.has_key('deck'): - deck = DeckHandler._get_deck(col, req.data['deck']) - col.decks.select(deck['id']) - - card = col.sched.getCard() - if card is None: - return None - - # put it into the card cache to be removed when we answer it - #if not req.session.has_key('cards'): - # req.session['cards'] = {} - #req.session['cards'][long(card.id)] = card - - card.startTimer() - - result = CardHandler._serialize(card, req.data) - result['answer_buttons'] = self._get_answer_buttons(col, card) - - return result - - @noReturnValue - def answer_card(self, col, req): - import time - - card_id = long(req.data['id']) - ease = int(req.data['ease']) - - card = col.getCard(card_id) - if req.data.has_key('timerStarted'): - card.timerStarted = float(req.data['timerStarted']) - else: - card.timerStarted = time.time() - - col.sched.answerCard(card, ease) - - @noReturnValue - def suspend_cards(self, col, req): - card_ids = req.data['ids'] - col.sched.suspendCards(card_ids) - - @noReturnValue - def unsuspend_cards(self, col, req): - card_ids = req.data['ids'] - col.sched.unsuspendCards(card_ids) - - def cards_recent_ease(self, col, req): - """Returns the most recent ease for each card.""" - - # TODO: Use sqlalchemy to build this query! - sql = "SELECT r.cid, r.ease, r.id FROM revlog AS r INNER JOIN (SELECT cid, MAX(id) AS id FROM revlog GROUP BY cid) AS q ON r.cid = q.cid AND r.id = q.id" - where = [] - if req.data.has_key('ids'): - where.append('ids IN (' + (','.join(["'%s'" % x for x in req.data['ids']])) + ')') - if len(where) > 0: - sql += ' WHERE ' + ' AND '.join(where) - - result = [] - for r in col.db.all(sql): - result.append({'id': r[0], 'ease': r[1], 'timestamp': int(r[2] / 1000)}) - - return result - - def latest_revlog(self, col, req): - """Returns recent entries from the revlog.""" - - # TODO: Use sqlalchemy to build this query! - sql = "SELECT r.id, r.ease, r.cid, r.usn, r.ivl, r.lastIvl, r.factor, r.time, r.type FROM revlog AS r" - args = [] - if req.data.has_key('updated_since'): - sql += ' WHERE r.id > ?' - args.append(long(req.data['updated_since']) * 1000) - sql += ' ORDER BY r.id DESC' - sql += ' LIMIT ' + str(req.data.get('limit', 100)) - - revlog = col.db.all(sql, *args) - return [{ - 'id': r[0], - 'ease': r[1], - 'timestamp': int(r[0] / 1000), - 'card_id': r[2], - 'usn': r[3], - 'interval': r[4], - 'last_interval': r[5], - 'factor': r[6], - 'time': r[7], - 'type': r[8], - } for r in revlog] - - stats_reports = { - 'today': 'todayStats', - 'due': 'dueGraph', - 'reps': 'repsGraph', - 'interval': 'ivlGraph', - 'hourly': 'hourGraph', - 'ease': 'easeGraph', - 'card': 'cardGraph', - 'footer': 'footer', - } - stats_reports_order = ['today', 'due', 'reps', 'interval', 'hourly', 'ease', 'card', 'footer'] - - def stats_report(self, col, req): - import anki.stats - import re - - stats = anki.stats.CollectionStats(col) - stats.width = int(req.data.get('width', 600)) - stats.height = int(req.data.get('height', 200)) - reports = req.data.get('reports', self.stats_reports_order) - include_css = req.data.get('include_css', False) - include_jquery = req.data.get('include_jquery', False) - include_flot = req.data.get('include_flot', False) - - if include_css: - from anki.statsbg import bg - html = stats.css % bg - else: - html = '' - - for name in reports: - if not self.stats_reports.has_key(name): - raise HTTPBadRequest("Unknown report name: %s" % name) - func = getattr(stats, self.stats_reports[name]) - - html += '
' % name - html += func() - html += '
' - - # fix an error in some inline styles - # TODO: submit a patch to Anki! - html = re.sub(r'style="width:([0-9\.]+); height:([0-9\.]+);"', r'style="width:\1px; height: \2px;"', html) - html = re.sub(r'-webkit-transform: ([^;]+);', r'-webkit-transform: \1; -moz-transform: \1; -ms-transform: \1; -o-transform: \1; transform: \1;', html) - - scripts = [] - if include_jquery or include_flot: - import anki.js - if include_jquery: - scripts.append(anki.js.jquery) - if include_flot: - scripts.append(anki.js.plot) - if len(scripts) > 0: - html = "" % ''.join(scripts) + html - - return html - - # - # GLOBAL / MISC - # - - @noReturnValue - def set_language(self, col, req): - anki.lang.setLang(req.data['code']) - -class ImportExportHandler(RestHandlerBase): - """Handler group for the 'collection' type, but it's not added by default.""" - - def _get_filedata(self, data): - import urllib2 - - if data.has_key('data'): - return data['data'] - - fd = None - try: - fd = urllib2.urlopen(data['url']) - filedata = fd.read() - finally: - if fd is not None: - fd.close() - - return filedata - - def _get_importer_class(self, data): - filetype = data['filetype'] - - from AnkiServer.importer import get_importer_class - importer_class = get_importer_class(filetype) - if importer_class is None: - raise HTTPBadRequest("Unknown filetype '%s'" % filetype) - - return importer_class - - def import_file(self, col, req): - import AnkiServer.importer - import tempfile - - # get the importer class - importer_class = self._get_importer_class(req.data) - - # get the file data - filedata = self._get_filedata(req.data) - - # write the file data to a temporary file - try: - path = None - with tempfile.NamedTemporaryFile('wt', delete=False) as fd: - path = fd.name - fd.write(filedata) - - AnkiServer.importer.import_file(importer_class, col, path) - finally: - if path is not None: - os.unlink(path) - -class ModelHandler(RestHandlerBase): - """Default handler group for 'model' type.""" - - def field_names(self, col, req): - model = col.models.get(req.ids[1]) - if model is None: - raise HTTPNotFound() - return col.models.fieldNames(model) - -class NoteHandler(RestHandlerBase): - """Default handler group for 'note' type.""" - - @staticmethod - def _serialize(note): - d = { - 'id': note.id, - 'guid': note.guid, - 'model': note.model(), - 'mid': note.mid, - 'mod': note.mod, - 'scm': note.scm, - 'tags': note.tags, - 'string_tags': ' '.join(note.tags), - 'fields': {}, - 'flags': note.flags, - 'usn': note.usn, - } - - # add all the fields - for name, value in note.items(): - d['fields'][name] = value - - return d - - def index(self, col, req): - note = col.getNote(req.ids[1]) - return self._serialize(note) - - @noReturnValue - def add_tags(self, col, req): - note = col.getNote(req.ids[1]) - for tag in req.data['tags']: - note.addTag(tag) - note.flush() - - @noReturnValue - def remove_tags(self, col, req): - note = col.getNote(req.ids[1]) - for tag in req.data['tags']: - note.delTag(tag) - note.flush() - -class DeckHandler(RestHandlerBase): - """Default handler group for 'deck' type.""" - - @staticmethod - def _get_deck(col, val): - try: - did = long(val) - deck = col.decks.get(did, False) - except ValueError: - deck = col.decks.byName(val) - - if deck is None: - raise HTTPNotFound('No deck with id or name: ' + str(val)) - - return deck - - def index(self, col, req): - return self._get_deck(col, req.ids[1]) - - def next_card(self, col, req): - req_copy = req.copy() - req_copy.data['deck'] = req.ids[1] - del req_copy.ids[1] - - # forward this to the CollectionHandler - return req.app.execute_handler('collection', 'next_card', col, req_copy) - - def get_conf(self, col, req): - # TODO: should probably live in a ConfHandler - return col.decks.confForDid(req.ids[1]) - - @noReturnValue - def set_update_conf(self, col, req): - data = req.data.copy() - del data['id'] - - conf = col.decks.confForDid(req.ids[1]) - conf = conf.copy() - conf.update(data) - - col.decks.updateConf(conf) - -class CardHandler(RestHandlerBase): - """Default handler group for 'card' type.""" - - @staticmethod - def _serialize(card, opts): - d = { - 'id': card.id, - 'isEmpty': card.isEmpty(), - 'css': card.css(), - 'question': card._getQA()['q'], - 'answer': card._getQA()['a'], - 'did': card.did, - 'due': card.due, - 'factor': card.factor, - 'ivl': card.ivl, - 'lapses': card.lapses, - 'left': card.left, - 'mod': card.mod, - 'nid': card.nid, - 'odid': card.odid, - 'odue': card.odue, - 'ord': card.ord, - 'queue': card.queue, - 'reps': card.reps, - 'type': card.type, - 'usn': card.usn, - 'timerStarted': card.timerStarted, - } - - if opts.get('load_note', False): - d['note'] = NoteHandler._serialize(card.col.getNote(card.nid)) - - if opts.get('load_deck', False): - d['deck'] = card.col.decks.get(card.did) - - if opts.get('load_latest_revlog', False): - d['latest_revlog'] = CardHandler._latest_revlog(card.col, card.id) - - return d - - @staticmethod - def _latest_revlog(col, card_id): - r = col.db.first("SELECT r.id, r.ease FROM revlog AS r WHERE r.cid = ? ORDER BY id DESC LIMIT 1", card_id) - if r: - return {'id': r[0], 'ease': r[1], 'timestamp': int(r[0] / 1000)} - - def index(self, col, req): - card = col.getCard(req.ids[1]) - return self._serialize(card, req.data) - - def _forward_to_note(self, col, req, name): - card = col.getCard(req.ids[1]) - - req_copy = req.copy() - req_copy.ids[1] = card.nid - - return req.app.execute_handler('note', name, col, req) - - @noReturnValue - def add_tags(self, col, req): - self._forward_to_note(col, req, 'add_tags') - - @noReturnValue - def remove_tags(self, col, req): - self._forward_to_note(col, req, 'remove_tags') - - def stats_report(self, col, req): - card = col.getCard(req.ids[1]) - return col.cardStats(card) - - def latest_revlog(self, col, req): - return self._latest_revlog(col, req.ids[1]) - -# Our entry point -def make_app(global_conf, **local_conf): - # TODO: we should setup the default language from conf! - - # setup the logger - from AnkiServer.utils import setup_logging - setup_logging(local_conf.get('logging.config_file')) - - return RestApp( - data_root=local_conf.get('data_root', '.'), - allowed_hosts=local_conf.get('allowed_hosts', '*') - ) - diff --git a/AnkiServer/apps/sync_old.py b/AnkiServer/apps/sync_old.py deleted file mode 100644 index d504993..0000000 --- a/AnkiServer/apps/sync_old.py +++ /dev/null @@ -1,377 +0,0 @@ - -from webob.dec import wsgify -from webob.exc import * -from webob import Response - -import anki -from anki.sync import HttpSyncServer, CHUNK_SIZE -from anki.db import sqlite -from anki.utils import checksum - -import AnkiServer.deck - -import MySQLdb - -try: - import simplejson as json -except ImportError: - import json - -import os, zlib, tempfile, time - -def makeArgs(mdict): - d = dict(mdict.items()) - # TODO: use password/username/version for something? - for k in ['p','u','v','d']: - if d.has_key(k): - del d[k] - return d - -class FileIterable(object): - def __init__(self, fn): - self.fn = fn - def __iter__(self): - return FileIterator(self.fn) - -class FileIterator(object): - def __init__(self, fn): - self.fn = fn - self.fo = open(self.fn, 'rb') - self.c = zlib.compressobj() - self.flushed = False - def __iter__(self): - return self - def next(self): - data = self.fo.read(CHUNK_SIZE) - if not data: - if not self.flushed: - self.flushed = True - return self.c.flush() - else: - raise StopIteration - return self.c.compress(data) - -def lock_deck(path): - """ Gets exclusive access to this deck path. If there is a DeckThread running on this - deck, this will wait for its current operations to complete before temporarily stopping - it. """ - - from AnkiServer.deck import thread_pool - - if thread_pool.decks.has_key(path): - thread_pool.decks[path].stop_and_wait() - thread_pool.lock(path) - -def unlock_deck(path): - """ Release exclusive access to this deck path. """ - from AnkiServer.deck import thread_pool - thread_pool.unlock(path) - -class SyncAppHandler(HttpSyncServer): - operations = ['summary','applyPayload','finish','createDeck','getOneWayPayload'] - - def __init__(self): - HttpSyncServer.__init__(self) - - def createDeck(self, name): - # The HttpSyncServer.createDeck doesn't return a valid value! This seems to be - # a bug in libanki.sync ... - return self.stuff({"status": "OK"}) - - def finish(self): - # The HttpSyncServer has no finish() function... I can only assume this is a bug too! - return self.stuff("OK") - -class SyncApp(object): - valid_urls = SyncAppHandler.operations + ['getDecks','fullup','fulldown'] - - def __init__(self, **kw): - self.data_root = os.path.abspath(kw.get('data_root', '.')) - self.base_url = kw.get('base_url', '/') - self.users = {} - - # make sure the base_url has a trailing slash - if len(self.base_url) == 0: - self.base_url = '/' - elif self.base_url[-1] != '/': - self.base_url = base_url + '/' - - # setup mysql connection - mysql_args = {} - for k, v in kw.items(): - if k.startswith('mysql.'): - mysql_args[k[6:]] = v - self.mysql_args = mysql_args - self.conn = None - - # get SQL statements - self.sql_check_password = kw.get('sql_check_password') - self.sql_username2dirname = kw.get('sql_username2dirname') - - default_libanki_version = '.'.join(anki.version.split('.')[:2]) - - def user_libanki_version(self, u): - try: - s = self.users[u]['libanki'] - except KeyError: - return self.default_libanki_version - - parts = s.split('.') - if parts[0] == '1': - if parts[1] == '0': - return '1.0' - elif parts[1] in ('1','2'): - return '1.2' - - return self.default_libanki_version - - # Mimcs from anki.sync.SyncTools.stuff() - def _stuff(self, data): - return zlib.compress(json.dumps(data)) - - def _connect_mysql(self): - if self.conn is None and len(self.mysql_args) > 0: - self.conn = MySQLdb.connect(**self.mysql_args) - - def _execute_sql(self, sql, args=()): - self._connect_mysql() - try: - cur = self.conn.cursor() - cur.execute(sql, args) - except MySQLdb.OperationalError, e: - if e.args[0] == 2006: - # MySQL server has gone away message - self.conn = None - self._connect_mysql() - cur = self.conn.cursor() - cur.execute(sql, args) - return cur - - def check_password(self, username, password): - if len(self.mysql_args) > 0 and self.sql_check_password is not None: - cur = self._execute_sql(self.sql_check_password, (username, password)) - row = cur.fetchone() - return row is not None - - return True - - def username2dirname(self, username): - if len(self.mysql_args) > 0 and self.sql_username2dirname is not None: - cur = self._execute_sql(self.sql_username2dirname, (username,)) - row = cur.fetchone() - if row is None: - return None - return str(row[0]) - - return username - - def _getDecks(self, user_path): - decks = {} - - if os.path.exists(user_path): - # It is a dict of {'deckName':[modified,lastSync]} - for fn in os.listdir(unicode(user_path, 'utf-8')): - if len(fn) > 5 and fn[-5:] == '.anki': - d = os.path.abspath(os.path.join(user_path, fn)) - - # For simplicity, we will always open a thread. But this probably - # isn't necessary! - thread = AnkiServer.deck.thread_pool.start(d) - def lookupModifiedLastSync(wrapper): - deck = wrapper.open() - return [deck.modified, deck.lastSync] - res = thread.execute(lookupModifiedLastSync, [thread.wrapper]) - -# if thread_pool.threads.has_key(d): -# thread = thread_pool.threads[d] -# def lookupModifiedLastSync(wrapper): -# deck = wrapper.open() -# return [deck.modified, deck.lastSync] -# res = thread.execute(lookup, [thread.wrapper]) -# else: -# conn = sqlite.connect(d) -# cur = conn.cursor() -# cur.execute("select modified, lastSync from decks") -# -# res = list(cur.fetchone()) -# -# cur.close() -# conn.close() - - #self.decks[fn[:-5]] = ["%.5f" % x for x in res] - decks[fn[:-5]] = res - - # same as HttpSyncServer.getDecks() - return self._stuff({ - "status": "OK", - "decks": decks, - "timestamp": time.time(), - }) - - def _fullup(self, wrapper, infile, version): - wrapper.close() - path = wrapper.path - - # DRS: most of this function was graciously copied - # from anki.sync.SyncTools.fullSyncFromServer() - (fd, tmpname) = tempfile.mkstemp(dir=os.getcwd(), prefix="fullsync") - outfile = open(tmpname, 'wb') - decomp = zlib.decompressobj() - while 1: - data = infile.read(CHUNK_SIZE) - if not data: - outfile.write(decomp.flush()) - break - outfile.write(decomp.decompress(data)) - infile.close() - outfile.close() - os.close(fd) - # if we were successful, overwrite old deck - if os.path.exists(path): - os.unlink(path) - os.rename(tmpname, path) - # reset the deck name - c = sqlite.connect(path) - lastSync = time.time() - if version == '1': - c.execute("update decks set lastSync = ?", [lastSync]) - elif version == '2': - c.execute("update decks set syncName = ?, lastSync = ?", - [checksum(path.encode("utf-8")), lastSync]) - c.commit() - c.close() - - return lastSync - - def _stuffedResp(self, data): - return Response( - status='200 OK', - content_type='application/json', - content_encoding='deflate', - body=data) - - @wsgify - def __call__(self, req): - if req.path.startswith(self.base_url): - url = req.path[len(self.base_url):] - if url not in self.valid_urls: - raise HTTPNotFound() - - # get and check username and password - try: - u = req.str_params.getone('u') - p = req.str_params.getone('p') - except KeyError: - raise HTTPBadRequest('Must pass username and password') - if not self.check_password(u, p): - #raise HTTPBadRequest('Incorrect username or password') - return self._stuffedResp(self._stuff({'status':'invalidUserPass'})) - dirname = self.username2dirname(u) - if dirname is None: - raise HTTPBadRequest('Incorrect username or password') - user_path = os.path.join(self.data_root, dirname) - - # get and lock the (optional) deck for this request - d = None - try: - d = unicode(req.str_params.getone('d'), 'utf-8') - # AnkiDesktop actually passes us the string value 'None'! - if d == 'None': - d = None - except KeyError: - pass - if d is not None: - # get the full deck path name - d = os.path.abspath(os.path.join(user_path, d)+'.anki') - if d[:len(user_path)] != user_path: - raise HTTPBadRequest('Bad deck name') - thread = AnkiServer.deck.thread_pool.start(d) - else: - thread = None - - if url == 'getDecks': - # force the version up to 1.2.x - v = req.str_params.getone('libanki') - if v.startswith('0.') or v.startswith('1.0'): - return self._stuffedResp(self._stuff({'status':'oldVersion'})) - - # store the data the user passes us keyed with the username. This - # will be used later by SyncAppHandler for version compatibility. - self.users[u] = makeArgs(req.str_params) - return self._stuffedResp(self._getDecks(user_path)) - - elif url in SyncAppHandler.operations: - handler = SyncAppHandler() - func = getattr(handler, url) - args = makeArgs(req.str_params) - - if thread is not None: - # If this is for a specific deck, then it needs to run - # inside of the DeckThread. - def runFunc(wrapper): - handler.deck = wrapper.open() - ret = func(**args) - handler.deck.save() - return ret - runFunc.func_name = url - ret = thread.execute(runFunc, [thread.wrapper]) - else: - # Otherwise, we can simply execute it in this thread. - ret = func(**args) - - # clean-up user data stored in getDecks - if url == 'finish': - del self.users[u] - - return self._stuffedResp(ret) - - elif url == 'fulldown': - # set the syncTime before we send it - def setupForSync(wrapper): - wrapper.close() - c = sqlite.connect(d) - lastSync = time.time() - c.execute("update decks set lastSync = ?", [lastSync]) - c.commit() - c.close() - thread.execute(setupForSync, [thread.wrapper]) - - return Response(status='200 OK', content_type='application/octet-stream', content_encoding='deflate', content_disposition='attachment; filename="'+os.path.basename(d).encode('utf-8')+'"', app_iter=FileIterable(d)) - elif url == 'fullup': - #version = self.user_libanki_version(u) - try: - version = req.str_params.getone('v') - except KeyError: - version = '1' - - infile = req.str_params['deck'].file - lastSync = thread.execute(self._fullup, [thread.wrapper, infile, version]) - - # append the 'lastSync' value for libanki 1.1 and 1.2 - if version == '2': - body = 'OK '+str(lastSync) - else: - body = 'OK' - - return Response(status='200 OK', content_type='application/text', body=body) - - return Response(status='200 OK', content_type='text/plain', body='Anki Server') - -# Our entry point -def make_app(global_conf, **local_conf): - return SyncApp(**local_conf) - -def main(): - from wsgiref.simple_server import make_server - - ankiserver = DeckApp('.', '/sync/') - httpd = make_server('', 8001, ankiserver) - try: - httpd.serve_forever() - except KeyboardInterrupt: - print "Exiting ..." - finally: - AnkiServer.deck.thread_pool.shutdown() - -if __name__ == '__main__': main() - diff --git a/AnkiServer/importer.py b/AnkiServer/importer.py deleted file mode 100644 index 762440c..0000000 --- a/AnkiServer/importer.py +++ /dev/null @@ -1,100 +0,0 @@ - -from anki.importing.csvfile import TextImporter -from anki.importing.apkg import AnkiPackageImporter -from anki.importing.anki1 import Anki1Importer -from anki.importing.supermemo_xml import SupermemoXmlImporter -from anki.importing.mnemo import MnemosyneImporter -from anki.importing.pauker import PaukerImporter - -__all__ = ['get_importer_class', 'import_file'] - -importers = { - 'text': TextImporter, - 'apkg': AnkiPackageImporter, - 'anki1': Anki1Importer, - 'supermemo_xml': SupermemoXmlImporter, - 'mnemosyne': MnemosyneImporter, - 'pauker': PaukerImporter, -} - -def get_importer_class(type): - global importers - return importers.get(type) - -def import_file(importer_class, col, path, allow_update = False): - importer = importer_class(col, path) - - if allow_update: - importer.allowUpdate = True - - if importer.needMapper: - importer.open() - - importer.run() - -# -# Monkey patch anki.importing.anki2 to support updating existing notes. -# TODO: submit a patch to Anki! -# - -def _importNotes(self): - # build guid -> (id,mod,mid) hash & map of existing note ids - self._notes = {} - existing = {} - for id, guid, mod, mid in self.dst.db.execute( - "select id, guid, mod, mid from notes"): - self._notes[guid] = (id, mod, mid) - existing[id] = True - # we may need to rewrite the guid if the model schemas don't match, - # so we need to keep track of the changes for the card import stage - self._changedGuids = {} - # iterate over source collection - add = [] - dirty = [] - usn = self.dst.usn() - dupes = 0 - for note in self.src.db.execute( - "select * from notes"): - # turn the db result into a mutable list - note = list(note) - shouldAdd = self._uniquifyNote(note) - if shouldAdd: - # ensure id is unique - while note[0] in existing: - note[0] += 999 - existing[note[0]] = True - # bump usn - note[4] = usn - # update media references in case of dupes - note[6] = self._mungeMedia(note[MID], note[6]) - add.append(note) - dirty.append(note[0]) - # note we have the added the guid - self._notes[note[GUID]] = (note[0], note[3], note[MID]) - else: - dupes += 1 - - # update existing note - newer = note[3] > mod - if self.allowUpdate and self._mid(mid) == mid and newer: - localNid = self._notes[note[GUID]][0] - note[0] = localNid - note[4] = usn - add.append(note) - dirty.append(note[0]) - - if dupes: - self.log.append(_("Already in collection: %s.") % (ngettext( - "%d note", "%d notes", dupes) % dupes)) - # add to col - self.dst.db.executemany( - "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", - add) - self.dst.updateFieldCache(dirty) - self.dst.tags.registerNotes(dirty) - -from anki.importing.anki2 import Anki2Importer, MID, GUID -from anki.lang import _, ngettext -Anki2Importer._importNotes = _importNotes -Anki2Importer.allowUpdate = False - diff --git a/AnkiServer/logpatch.py b/AnkiServer/logpatch.py deleted file mode 100644 index eeeda58..0000000 --- a/AnkiServer/logpatch.py +++ /dev/null @@ -1,96 +0,0 @@ - -import logging -import logging.handlers -import types - -# The SMTPHandler taken from python 2.6 -class SMTPHandler(logging.Handler): - """ - A handler class which sends an SMTP email for each logging event. - """ - def __init__(self, mailhost, fromaddr, toaddrs, subject, credentials=None): - """ - Initialize the handler. - - Initialize the instance with the from and to addresses and subject - line of the email. To specify a non-standard SMTP port, use the - (host, port) tuple format for the mailhost argument. To specify - authentication credentials, supply a (username, password) tuple - for the credentials argument. - """ - logging.Handler.__init__(self) - if type(mailhost) == types.TupleType: - self.mailhost, self.mailport = mailhost - else: - self.mailhost, self.mailport = mailhost, None - if type(credentials) == types.TupleType: - self.username, self.password = credentials - else: - self.username = None - self.fromaddr = fromaddr - if type(toaddrs) == types.StringType: - toaddrs = [toaddrs] - self.toaddrs = toaddrs - self.subject = subject - - def getSubject(self, record): - """ - Determine the subject for the email. - - If you want to specify a subject line which is record-dependent, - override this method. - """ - return self.subject - - weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - - monthname = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - - def date_time(self): - """ - Return the current date and time formatted for a MIME header. - Needed for Python 1.5.2 (no email package available) - """ - year, month, day, hh, mm, ss, wd, y, z = time.gmtime(time.time()) - s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - self.weekdayname[wd], - day, self.monthname[month], year, - hh, mm, ss) - return s - - def emit(self, record): - """ - Emit a record. - - Format the record and send it to the specified addressees. - """ - try: - import smtplib - try: - from email.utils import formatdate - except ImportError: - formatdate = self.date_time - port = self.mailport - if not port: - port = smtplib.SMTP_PORT - smtp = smtplib.SMTP(self.mailhost, port) - msg = self.format(record) - msg = "From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\n\r\n%s" % ( - self.fromaddr, - string.join(self.toaddrs, ","), - self.getSubject(record), - formatdate(), msg) - if self.username: - smtp.login(self.username, self.password) - smtp.sendmail(self.fromaddr, self.toaddrs, msg) - smtp.quit() - except (KeyboardInterrupt, SystemExit): - raise - except: - self.handleError(record) - -# Monkey patch logging.handlers -logging.handlers.SMTPHandler = SMTPHandler - diff --git a/AnkiServer/apps/sync_app.py b/AnkiServer/sync_app.py similarity index 93% rename from AnkiServer/apps/sync_app.py rename to AnkiServer/sync_app.py index d373f26..ada6957 100644 --- a/AnkiServer/apps/sync_app.py +++ b/AnkiServer/sync_app.py @@ -1,3 +1,4 @@ +from ConfigParser import SafeConfigParser from webob.dec import wsgify from webob.exc import * @@ -94,18 +95,14 @@ class SyncUserSession(object): class SyncApp(object): valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks'] - def __init__(self, **kw): - from AnkiServer.threading import getCollectionManager + def __init__(self, config): + from AnkiServer.thread import getCollectionManager - self.data_root = os.path.abspath(kw.get('data_root', '.')) - self.base_url = kw.get('base_url', '/') - self.auth_db_path = os.path.abspath(kw.get('auth_db_path', '.')) + self.data_root = os.path.abspath(config.get("sync_app", "data_root")) + self.base_url = config.get("sync_app", "base_url") self.sessions = {} - try: - self.collection_manager = kw['collection_manager'] - except KeyError: - self.collection_manager = getCollectionManager() + self.collection_manager = getCollectionManager() # make sure the base_url has a trailing slash if len(self.base_url) == 0: @@ -297,6 +294,11 @@ class SyncApp(object): return Response(status='200 OK', content_type='text/plain', body='Anki Sync Server') class DatabaseAuthSyncApp(SyncApp): + def __init__(self, config): + SyncApp.__init__(self, config) + + self.auth_db_path = os.path.abspath(config.get("sync_app", "auth_db_path")) + def authenticate(self, username, password): """Returns True if this username is allowed to connect with this password. False otherwise.""" @@ -308,6 +310,8 @@ class DatabaseAuthSyncApp(SyncApp): db_ret = cursor.fetchone() + conn.close() + if db_ret != None: db_hash = str(db_ret[0]) salt = db_hash[-16:] @@ -317,16 +321,16 @@ class DatabaseAuthSyncApp(SyncApp): return (db_ret != None and hashobj.hexdigest()+salt == db_hash) -# Our entry point -def make_app(global_conf, **local_conf): - return DatabaseAuthSyncApp(**local_conf) - def main(): from wsgiref.simple_server import make_server - from AnkiServer.threading import shutdown + from AnkiServer.thread import shutdown + + config = SafeConfigParser() + config.read("production.ini") + + ankiserver = DatabaseAuthSyncApp(config) + httpd = make_server('', config.getint("sync_app", "port"), ankiserver) - ankiserver = SyncApp() - httpd = make_server('', 8001, ankiserver) try: print "Starting..." httpd.serve_forever() diff --git a/AnkiServer/threading.py b/AnkiServer/thread.py similarity index 100% rename from AnkiServer/threading.py rename to AnkiServer/thread.py diff --git a/AnkiServer/utils.py b/AnkiServer/utils.py deleted file mode 100644 index 844ef14..0000000 --- a/AnkiServer/utils.py +++ /dev/null @@ -1,18 +0,0 @@ - -def setup_logging(config_file=None): - """Setup logging based on a config_file.""" - - import logging - - if config_file is not None: - # monkey patch the logging.config.SMTPHandler if necessary - import sys - if sys.version_info[0] == 2 and sys.version_info[1] == 5: - import AnkiServer.logpatch - - # load the config file - import logging.config - logging.config.fileConfig(config_file) - else: - logging.getLogger().setLevel(logging.INFO) - diff --git a/example.ini b/example.ini index 7d2da25..b005492 100644 --- a/example.ini +++ b/example.ini @@ -1,34 +1,6 @@ - -[server:main] -#use = egg:Paste#http -use = egg:AnkiServer#server +[sync_app] host = 127.0.0.1 port = 27701 - -[filter-app:main] -use = egg:Paste#translogger -next = real - -[app:real] -use = egg:Paste#urlmap -/collection = rest_app -/sync = sync_app - -[app:rest_app] -use = egg:AnkiServer#rest_app -data_root = ./collections -allowed_hosts = 127.0.0.1 -logging.config_file = logging.conf - -[app:sync_app] -use = egg:AnkiServer#sync_app data_root = ./collections base_url = /sync/ auth_db_path = ./auth.db -mysql.host = 127.0.0.1 -mysql.user = db_user -mysql.passwd = db_password -mysql.db = db -sql_check_password = SELECT uid FROM users WHERE name=%s AND pass=MD5(%s) AND status=1 -sql_username2dirname = SELECT uid AS dirname FROM users WHERE name=%s - diff --git a/logging.conf b/logging.conf deleted file mode 100644 index fb94bf0..0000000 --- a/logging.conf +++ /dev/null @@ -1,41 +0,0 @@ - -[loggers] -keys=root - -[handlers] -keys=screen,file,email - -[formatters] -keys=normal,email - -[logger_root] -level=INFO -handlers=screen -#handlers=file -#handlers=file,email - -[handler_file] -class=FileHandler -formatter=normal -args=('server.log','a') - -[handler_screen] -class=StreamHandler -level=NOTSET -formatter=normal -args=(sys.stdout,) - -[handler_email] -class=handlers.SMTPHandler -level=ERROR -formatter=email -args=('smtp.example.com', 'support@example.com', ['support_guy1@example.com', 'support_guy2@example.com'], 'AnkiServer error', ('smtp_user', 'smtp_password')) - -[formatter_normal] -format=%(asctime)s:%(name)s:%(levelname)s:%(message)s -datefmt= - -[formatter_email] -format=%(asctime)s - %(name)s - %(levelname)s - %(message)s -datefmt= - diff --git a/setup.py b/setup.py deleted file mode 100644 index b4c239e..0000000 --- a/setup.py +++ /dev/null @@ -1,23 +0,0 @@ - -from setuptools import setup, find_packages - -setup( - name="AnkiServer", - version="2.0.0a1", - description="A personal Anki sync server (so you can sync against your own server rather than AnkiWeb)", - author="David Snopek", - author_email="dsnopek@gmail.com", - install_requires=["PasteDeploy>=1.3.2"], - # TODO: should these really be in install_requires? - requires=["webob(>=0.9.7)"], - test_suite='nose.collector', - entry_points=""" - [paste.app_factory] - sync_app = AnkiServer.apps.sync_app:make_app - rest_app = AnkiServer.apps.rest_app:make_app - - [paste.server_runner] - server = AnkiServer:server_runner - """, -) - diff --git a/supervisor-anki-server.conf b/supervisor-anki-server.conf deleted file mode 100644 index 1dc556d..0000000 --- a/supervisor-anki-server.conf +++ /dev/null @@ -1,12 +0,0 @@ -[program:anki-server] -command=/usr/local/bin/paster serve production.ini -directory=/var/anki-server -user=www-data -autostart=true -autorestart=true -redirect_stderr=true - -; Sometimes necessary if Anki is complaining about a UTF-8 locale. Make sure -; that the local you pick is actually installed on your system. -;environment=LANG=en_US.UTF-8,LC_ALL=en_US.UTF-8,LC_LANG=en_US.UTF-8 - diff --git a/tests/test_collection.py b/tests/test_collection.py deleted file mode 100644 index a38ed5f..0000000 --- a/tests/test_collection.py +++ /dev/null @@ -1,140 +0,0 @@ - -import os -import shutil -import tempfile -import unittest - -import mock -from mock import MagicMock, sentinel - -import AnkiServer -from AnkiServer.collection import CollectionWrapper, CollectionManager - -class CollectionWrapperTest(unittest.TestCase): - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.collection_path = os.path.join(self.temp_dir, 'collection.anki2'); - - def tearDown(self): - shutil.rmtree(self.temp_dir) - - def test_lifecycle_real(self): - """Testing common life-cycle with existing and non-existant collections. This - test uses the real Anki objects and actually creates a new collection on disk.""" - - wrapper = CollectionWrapper(self.collection_path) - self.assertFalse(os.path.exists(self.collection_path)) - self.assertFalse(wrapper.opened()) - - wrapper.open() - self.assertTrue(os.path.exists(self.collection_path)) - self.assertTrue(wrapper.opened()) - - # calling open twice shouldn't break anything - wrapper.open() - - wrapper.close() - self.assertTrue(os.path.exists(self.collection_path)) - self.assertFalse(wrapper.opened()) - - # open the same collection again (not a creation) - wrapper = CollectionWrapper(self.collection_path) - self.assertFalse(wrapper.opened()) - wrapper.open() - self.assertTrue(wrapper.opened()) - wrapper.close() - self.assertFalse(wrapper.opened()) - self.assertTrue(os.path.exists(self.collection_path)) - - def test_del(self): - with mock.patch('anki.storage.Collection') as anki_storage_Collection: - col = anki_storage_Collection.return_value - wrapper = CollectionWrapper(self.collection_path) - wrapper.open() - wrapper = None - col.close.assert_called_with() - - def test_setup_func(self): - # Run it when the collection doesn't exist - with mock.patch('anki.storage.Collection') as anki_storage_Collection: - col = anki_storage_Collection.return_value - setup_new_collection = MagicMock() - self.assertFalse(os.path.exists(self.collection_path)) - wrapper = CollectionWrapper(self.collection_path, setup_new_collection) - wrapper.open() - anki_storage_Collection.assert_called_with(self.collection_path) - setup_new_collection.assert_called_with(col) - wrapper = None - - # Make sure that no collection was actually created - self.assertFalse(os.path.exists(self.collection_path)) - - # Create a faux collection file - with file(self.collection_path, 'wt') as fd: - fd.write('Collection!') - - # Run it when the collection does exist - with mock.patch('anki.storage.Collection'): - setup_new_collection = lambda col: self.fail("Setup function called when collection already exists!") - self.assertTrue(os.path.exists(self.collection_path)) - wrapper = CollectionWrapper(self.collection_path, setup_new_collection) - wrapper.open() - anki_storage_Collection.assert_called_with(self.collection_path) - wrapper = None - - def test_execute(self): - with mock.patch('anki.storage.Collection') as anki_storage_Collection: - col = anki_storage_Collection.return_value - func = MagicMock() - func.return_value = sentinel.some_object - - # check that execute works and auto-creates the collection - wrapper = CollectionWrapper(self.collection_path) - ret = wrapper.execute(func, [1, 2, 3], {'key': 'aoeu'}) - self.assertEqual(ret, sentinel.some_object) - anki_storage_Collection.assert_called_with(self.collection_path) - func.assert_called_with(col, 1, 2, 3, key='aoeu') - - # check that execute always returns False if waitForReturn=False - func.reset_mock() - ret = wrapper.execute(func, [1, 2, 3], {'key': 'aoeu'}, waitForReturn=False) - self.assertEqual(ret, None) - func.assert_called_with(col, 1, 2, 3, key='aoeu') - -class CollectionManagerTest(unittest.TestCase): - def test_lifecycle(self): - with mock.patch('AnkiServer.collection.CollectionManager.collection_wrapper') as CollectionWrapper: - wrapper = MagicMock() - CollectionWrapper.return_value = wrapper - - manager = CollectionManager() - - # check getting a new collection - ret = manager.get_collection('path1') - CollectionWrapper.assert_called_with(os.path.realpath('path1'), None) - self.assertEqual(ret, wrapper) - - # change the return value, so that it would return a new object - new_wrapper = MagicMock() - CollectionWrapper.return_value = new_wrapper - CollectionWrapper.reset_mock() - - # get the new wrapper - ret = manager.get_collection('path2') - CollectionWrapper.assert_called_with(os.path.realpath('path2'), None) - self.assertEqual(ret, new_wrapper) - - # make sure the wrapper and new_wrapper are different - self.assertNotEqual(wrapper, new_wrapper) - - # assert that calling with the first path again, returns the first wrapper - ret = manager.get_collection('path1') - self.assertEqual(ret, wrapper) - - manager.shutdown() - wrapper.close.assert_called_with() - new_wrapper.close.assert_called_with() - -if __name__ == '__main__': - unittest.main() - diff --git a/tests/test_importer.py b/tests/test_importer.py deleted file mode 100644 index 448f534..0000000 --- a/tests/test_importer.py +++ /dev/null @@ -1,113 +0,0 @@ - -import os -import shutil -import tempfile -import unittest - -import mock -from mock import MagicMock, sentinel - -import AnkiServer -from AnkiServer.importer import get_importer_class, import_file - -import anki.storage - -# TODO: refactor into some kind of utility -def add_note(col, data): - from anki.notes import Note - - model = col.models.byName(data['model']) - - note = Note(col, model) - for name, value in data['fields'].items(): - note[name] = value - - if data.has_key('tags'): - note.setTagsFromStr(data['tags']) - - col.addNote(note) - -class ImporterTest(unittest.TestCase): - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.collection_path = os.path.join(self.temp_dir, 'collection.anki2') - self.collection = anki.storage.Collection(self.collection_path) - - def tearDown(self): - self.collection.close() - self.collection = None - shutil.rmtree(self.temp_dir) - - - # TODO: refactor into a parent class - def add_default_note(self, count=1): - data = { - 'model': 'Basic', - 'fields': { - 'Front': 'The front', - 'Back': 'The back', - }, - 'tags': "Tag1 Tag2", - } - for idx in range(0, count): - add_note(self.collection, data) - self.add_note(data) - - def test_resync(self): - from anki.exporting import AnkiPackageExporter - from anki.utils import intTime - - # create a new collection with a single note - src_collection = anki.storage.Collection(os.path.join(self.temp_dir, 'src_collection.anki2')) - add_note(src_collection, { - 'model': 'Basic', - 'fields': { - 'Front': 'The front', - 'Back': 'The back', - }, - 'tags': 'Tag1 Tag2', - }) - note_id = src_collection.findNotes('')[0] - note = src_collection.getNote(note_id) - self.assertEqual(note.id, note_id) - self.assertEqual(note['Front'], 'The front') - self.assertEqual(note['Back'], 'The back') - - # export to an .apkg file - dst1_path = os.path.join(self.temp_dir, 'export1.apkg') - exporter = AnkiPackageExporter(src_collection) - exporter.exportInto(dst1_path) - - # import it into the main collection - import_file(get_importer_class('apkg'), self.collection, dst1_path) - - # make sure the note exists - note = self.collection.getNote(note_id) - self.assertEqual(note.id, note_id) - self.assertEqual(note['Front'], 'The front') - self.assertEqual(note['Back'], 'The back') - - # now we change the source collection and re-export it - note = src_collection.getNote(note_id) - note['Front'] = 'The new front' - note.tags.append('Tag3') - note.flush(intTime()+1) - dst2_path = os.path.join(self.temp_dir, 'export2.apkg') - exporter = AnkiPackageExporter(src_collection) - exporter.exportInto(dst2_path) - - # first, import it without allow_update - no change should happen - import_file(get_importer_class('apkg'), self.collection, dst2_path) - note = self.collection.getNote(note_id) - self.assertEqual(note['Front'], 'The front') - self.assertEqual(note.tags, ['Tag1', 'Tag2']) - - # now, import it with allow_update=True, so the note should change - import_file(get_importer_class('apkg'), self.collection, dst2_path, allow_update=True) - note = self.collection.getNote(note_id) - self.assertEqual(note['Front'], 'The new front') - self.assertEqual(note.tags, ['Tag1', 'Tag2', 'Tag3']) - -if __name__ == '__main__': - unittest.main() - diff --git a/tests/test_rest_app.py b/tests/test_rest_app.py deleted file mode 100644 index 5033512..0000000 --- a/tests/test_rest_app.py +++ /dev/null @@ -1,618 +0,0 @@ -# -*- coding: utf-8 -*- - -import os -import shutil -import tempfile -import unittest -import logging -from pprint import pprint - -import mock -from mock import MagicMock - -import AnkiServer -from AnkiServer.collection import CollectionManager -from AnkiServer.apps.rest_app import RestApp, RestHandlerRequest, CollectionHandler, ImportExportHandler, NoteHandler, ModelHandler, DeckHandler, CardHandler - -from webob.exc import * - -import anki -import anki.storage - -class RestAppTest(unittest.TestCase): - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.collection_manager = CollectionManager() - self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager) - - # disable all but critical errors! - logging.disable(logging.CRITICAL) - - def tearDown(self): - self.collection_manager.shutdown() - self.collection_manager = None - self.rest_app = None - shutil.rmtree(self.temp_dir) - - def test_list_collections(self): - os.mkdir(os.path.join(self.temp_dir, 'test1')) - os.mkdir(os.path.join(self.temp_dir, 'test2')) - - with open(os.path.join(self.temp_dir, 'test1', 'collection.anki2'), 'wt') as fd: - fd.write('Testing!') - - self.assertEqual(self.rest_app.list_collections(), ['test1']) - - def test_parsePath(self): - tests = [ - ('collection/user', ('collection', 'index', ['user'])), - ('collection/user/handler', ('collection', 'handler', ['user'])), - ('collection/user/note/123', ('note', 'index', ['user', '123'])), - ('collection/user/note/123/handler', ('note', 'handler', ['user', '123'])), - ('collection/user/deck/name', ('deck', 'index', ['user', 'name'])), - ('collection/user/deck/name/handler', ('deck', 'handler', ['user', 'name'])), - #('collection/user/deck/name/card/123', ('card', 'index', ['user', 'name', '123'])), - #('collection/user/deck/name/card/123/handler', ('card', 'handler', ['user', 'name', '123'])), - ('collection/user/card/123', ('card', 'index', ['user', '123'])), - ('collection/user/card/123/handler', ('card', 'handler', ['user', '123'])), - # the leading slash should make no difference! - ('/collection/user', ('collection', 'index', ['user'])), - ] - - for path, result in tests: - self.assertEqual(self.rest_app._parsePath(path), result) - - def test_parsePath_not_found(self): - tests = [ - 'bad', - 'bad/oaeu', - 'collection', - 'collection/user/handler/bad', - '', - '/', - ] - - for path in tests: - self.assertRaises(HTTPNotFound, self.rest_app._parsePath, path) - - def test_getCollectionPath(self): - def fullpath(collection_id): - return os.path.normpath(os.path.join(self.temp_dir, collection_id, 'collection.anki2')) - - # This is simple and straight forward! - self.assertEqual(self.rest_app._getCollectionPath('user'), fullpath('user')) - - # These are dangerous - the user is trying to hack us! - dangerous = ['../user', '/etc/passwd', '/tmp/aBaBaB', '/root/.ssh/id_rsa'] - for collection_id in dangerous: - self.assertRaises(HTTPBadRequest, self.rest_app._getCollectionPath, collection_id) - - def test_getHandler(self): - def handlerOne(): - pass - - def handlerTwo(): - pass - handlerTwo.hasReturnValue = False - - self.rest_app.add_handler('collection', 'handlerOne', handlerOne) - self.rest_app.add_handler('deck', 'handlerTwo', handlerTwo) - - (handler, hasReturnValue) = self.rest_app._getHandler('collection', 'handlerOne') - self.assertEqual(handler, handlerOne) - self.assertEqual(hasReturnValue, True) - - (handler, hasReturnValue) = self.rest_app._getHandler('deck', 'handlerTwo') - self.assertEqual(handler, handlerTwo) - self.assertEqual(hasReturnValue, False) - - # try some bad handler names and types - self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'collection', 'nonExistantHandler') - self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'nonExistantType', 'handlerOne') - - def test_parseRequestBody(self): - req = MagicMock() - req.body = '{"key":"value"}' - - data = self.rest_app._parseRequestBody(req) - self.assertEqual(data, {'key': 'value'}) - self.assertEqual(data.keys(), ['key']) - self.assertEqual(type(data.keys()[0]), str) - - # test some bad data - req.body = '{aaaaaaa}' - self.assertRaises(HTTPBadRequest, self.rest_app._parseRequestBody, req) - -class CollectionTestBase(unittest.TestCase): - """Parent class for tests that need a collection set up and torn down.""" - - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.collection_path = os.path.join(self.temp_dir, 'collection.anki2'); - self.collection = anki.storage.Collection(self.collection_path) - self.mock_app = MagicMock() - - def tearDown(self): - self.collection.close() - self.collection = None - shutil.rmtree(self.temp_dir) - self.mock_app.reset_mock() - - # TODO: refactor into some kind of utility - def add_note(self, data): - from anki.notes import Note - - model = self.collection.models.byName(data['model']) - - note = Note(self.collection, model) - for name, value in data['fields'].items(): - note[name] = value - - if data.has_key('tags'): - note.setTagsFromStr(data['tags']) - - self.collection.addNote(note) - - # TODO: refactor into a parent class - def add_default_note(self, count=1): - data = { - 'model': 'Basic', - 'fields': { - 'Front': 'The front', - 'Back': 'The back', - }, - 'tags': "Tag1 Tag2", - } - for idx in range(0, count): - self.add_note(data) - -class CollectionHandlerTest(CollectionTestBase): - def setUp(self): - super(CollectionHandlerTest, self).setUp() - self.handler = CollectionHandler() - - def execute(self, name, data): - ids = ['collection_name'] - func = getattr(self.handler, name) - req = RestHandlerRequest(self.mock_app, data, ids, {}) - return func(self.collection, req) - - def test_list_decks(self): - data = {} - ret = self.execute('list_decks', data) - - # It contains only the 'Default' deck - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0]['name'], 'Default') - - def test_select_deck(self): - data = {'deck_id': '1'} - ret = self.execute('select_deck', data) - self.assertEqual(ret, None); - - def test_create_dynamic_deck_simple(self): - self.add_default_note(5) - - data = { - 'name': 'Dyn deck', - 'mode': 'random', - 'count': 2, - 'query': "deck:\"Default\" (tag:'Tag1' or tag:'Tag2') (-tag:'Tag3')", - } - ret = self.execute('create_dynamic_deck', data) - self.assertEqual(ret['name'], 'Dyn deck') - self.assertEqual(ret['dyn'], True) - - cards = self.collection.findCards('deck:"Dyn deck"') - self.assertEqual(len(cards), 2) - - def test_list_models(self): - data = {} - ret = self.execute('list_models', data) - - # get a sorted name list that we can actually check - names = [model['name'] for model in ret] - names.sort() - - # These are the default models created by Anki in a new collection - default_models = [ - 'Basic', - 'Basic (and reversed card)', - 'Basic (optional reversed card)', - 'Cloze' - ] - - self.assertEqual(names, default_models) - - def test_find_model_by_name(self): - data = {'model': 'Basic'} - ret = self.execute('find_model_by_name', data) - self.assertEqual(ret['name'], 'Basic') - - def test_find_notes(self): - ret = self.execute('find_notes', {}) - self.assertEqual(ret, []) - - # add a note programatically - self.add_default_note() - - # get the id for the one note on this collection - note_id = self.collection.findNotes('')[0] - - ret = self.execute('find_notes', {}) - self.assertEqual(ret, [{'id': note_id}]) - - ret = self.execute('find_notes', {'query': 'tag:Tag1'}) - self.assertEqual(ret, [{'id': note_id}]) - - ret = self.execute('find_notes', {'query': 'tag:TagX'}) - self.assertEqual(ret, []) - - ret = self.execute('find_notes', {'preload': True}) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0]['id'], note_id) - self.assertEqual(ret[0]['model']['name'], 'Basic') - - def test_add_note(self): - # make sure there are no notes (yet) - self.assertEqual(self.collection.findNotes(''), []) - - # add a note programatically - note = { - 'model': 'Basic', - 'fields': { - 'Front': 'The front', - 'Back': 'The back', - }, - 'tags': "Tag1 Tag2", - } - self.execute('add_note', note) - - notes = self.collection.findNotes('') - self.assertEqual(len(notes), 1) - - note_id = notes[0] - note = self.collection.getNote(note_id) - - self.assertEqual(note.model()['name'], 'Basic') - self.assertEqual(note['Front'], 'The front') - self.assertEqual(note['Back'], 'The back') - self.assertEqual(note.tags, ['Tag1', 'Tag2']) - - def test_list_tags(self): - ret = self.execute('list_tags', {}) - self.assertEqual(ret, []) - - self.add_default_note() - - ret = self.execute('list_tags', {}) - ret.sort() - self.assertEqual(ret, ['Tag1', 'Tag2']) - - def test_set_language(self): - import anki.lang - - self.assertEqual(anki.lang._('Again'), 'Again') - - try: - data = {'code': 'pl'} - self.execute('set_language', data) - self.assertEqual(anki.lang._('Again'), u'Znowu') - finally: - # return everything to normal! - anki.lang.setLang('en') - - def test_reset_scheduler(self): - self.add_default_note(3) - - ret = self.execute('reset_scheduler', {'deck': 'Default'}) - self.assertEqual(ret, { - 'new_cards': 3, - 'learning_cards': 0, - 'review_cards': 0, - }) - - def test_next_card(self): - ret = self.execute('next_card', {}) - self.assertEqual(ret, None) - - # add a note programatically - self.add_default_note() - - # get the id for the one card and note on this collection - note_id = self.collection.findNotes('')[0] - card_id = self.collection.findCards('')[0] - - self.collection.sched.reset() - ret = self.execute('next_card', {}) - self.assertEqual(ret['id'], card_id) - self.assertEqual(ret['nid'], note_id) - self.assertEqual(ret['css'], '') - self.assertEqual(ret['question'], 'The front') - self.assertEqual(ret['answer'], 'The front\n\n
\n\nThe back') - self.assertEqual(ret['answer_buttons'], [ - {'ease': 1, - 'label': 'Again', - 'string_label': 'Again', - 'interval': 60, - 'string_interval': '<1 minute'}, - {'ease': 2, - 'label': 'Good', - 'string_label': 'Good', - 'interval': 600, - 'string_interval': '<10 minutes'}, - {'ease': 3, - 'label': 'Easy', - 'string_label': 'Easy', - 'interval': 345600, - 'string_interval': '4 days'}]) - - def test_next_card_translation(self): - # add a note programatically - self.add_default_note() - - # get the card in Polish so we can test translation too - anki.lang.setLang('pl') - try: - ret = self.execute('next_card', {}) - finally: - anki.lang.setLang('en') - - self.assertEqual(ret['answer_buttons'], [ - {'ease': 1, - 'label': 'Again', - 'string_label': u'Znowu', - 'interval': 60, - 'string_interval': '<1 minuta'}, - {'ease': 2, - 'label': 'Good', - 'string_label': u'Dobra', - 'interval': 600, - 'string_interval': '<10 minut'}, - {'ease': 3, - 'label': 'Easy', - 'string_label': u'Łatwa', - 'interval': 345600, - 'string_interval': '4 dni'}]) - - def test_next_card_five_times(self): - self.add_default_note(5) - for idx in range(0, 5): - ret = self.execute('next_card', {}) - self.assertTrue(ret is not None) - - def test_answer_card(self): - import time - - self.add_default_note() - - # instantiate a deck handler to get the card - card = self.execute('next_card', {}) - self.assertEqual(card['reps'], 0) - - self.execute('answer_card', {'id': card['id'], 'ease': 2, 'timerStarted': time.time()}) - - # reset the scheduler and try to get the next card again - there should be none! - self.collection.sched.reset() - card = self.execute('next_card', {}) - self.assertEqual(card['reps'], 1) - - def test_suspend_cards(self): - # add a note programatically - self.add_default_note() - - # get the id for the one card on this collection - card_id = self.collection.findCards('')[0] - - # suspend it - self.execute('suspend_cards', {'ids': [card_id]}) - - # test that getting the next card will be None - card = self.collection.sched.getCard() - self.assertEqual(card, None) - - # unsuspend it - self.execute('unsuspend_cards', {'ids': [card_id]}) - - # test that now we're getting the next card! - self.collection.sched.reset() - card = self.collection.sched.getCard() - self.assertEqual(card.id, card_id) - - def test_cards_recent_ease(self): - self.add_default_note() - card_id = self.collection.findCards('')[0] - - # answer the card - self.collection.reset() - card = self.collection.sched.getCard() - card.startTimer() - # answer multiple times to see that we only get the latest! - self.collection.sched.answerCard(card, 1) - self.collection.sched.answerCard(card, 3) - self.collection.sched.answerCard(card, 2) - - # pull the latest revision - ret = self.execute('cards_recent_ease', {}) - self.assertEqual(ret[0]['id'], card_id) - self.assertEqual(ret[0]['ease'], 2) - -class ImportExportHandlerTest(CollectionTestBase): - export_rows = [ - ['Card front 1', 'Card back 1', 'Tag1 Tag2'], - ['Card front 2', 'Card back 2', 'Tag1 Tag3'], - ] - - def setUp(self): - super(ImportExportHandlerTest, self).setUp() - self.handler = ImportExportHandler() - - def execute(self, name, data): - ids = ['collection_name'] - func = getattr(self.handler, name) - req = RestHandlerRequest(self.mock_app, data, ids, {}) - return func(self.collection, req) - - def generate_text_export(self): - # Create a simple export file - export_data = '' - for row in self.export_rows: - export_data += '\t'.join(row) + '\n' - export_path = os.path.join(self.temp_dir, 'export.txt') - with file(export_path, 'wt') as fd: - fd.write(export_data) - - return (export_data, export_path) - - def check_import(self): - note_ids = self.collection.findNotes('') - notes = [self.collection.getNote(note_id) for note_id in note_ids] - self.assertEqual(len(notes), len(self.export_rows)) - - for index, test_data in enumerate(self.export_rows): - self.assertEqual(notes[index]['Front'], test_data[0]) - self.assertEqual(notes[index]['Back'], test_data[1]) - self.assertEqual(' '.join(notes[index].tags), test_data[2]) - - def test_import_text_data(self): - (export_data, export_path) = self.generate_text_export() - - data = { - 'filetype': 'text', - 'data': export_data, - } - ret = self.execute('import_file', data) - self.check_import() - - def test_import_text_url(self): - (export_data, export_path) = self.generate_text_export() - - data = { - 'filetype': 'text', - 'url': 'file://' + os.path.realpath(export_path), - } - ret = self.execute('import_file', data) - self.check_import() - -class NoteHandlerTest(CollectionTestBase): - def setUp(self): - super(NoteHandlerTest, self).setUp() - self.handler = NoteHandler() - - def execute(self, name, data, note_id): - ids = ['collection_name', note_id] - func = getattr(self.handler, name) - req = RestHandlerRequest(self.mock_app, data, ids, {}) - return func(self.collection, req) - - def test_index(self): - self.add_default_note() - - note_id = self.collection.findNotes('')[0] - - ret = self.execute('index', {}, note_id) - self.assertEqual(ret['id'], note_id) - self.assertEqual(len(ret['fields']), 2) - self.assertEqual(ret['flags'], 0) - self.assertEqual(ret['model']['name'], 'Basic') - self.assertEqual(ret['tags'], ['Tag1', 'Tag2']) - self.assertEqual(ret['string_tags'], 'Tag1 Tag2') - self.assertEqual(ret['usn'], -1) - - def test_add_tags(self): - self.add_default_note() - note_id = self.collection.findNotes('')[0] - note = self.collection.getNote(note_id) - self.assertFalse('NT1' in note.tags) - self.assertFalse('NT2' in note.tags) - - self.execute('add_tags', {'tags': ['NT1', 'NT2']}, note_id) - note = self.collection.getNote(note_id) - self.assertTrue('NT1' in note.tags) - self.assertTrue('NT2' in note.tags) - - def test_remove_tags(self): - self.add_default_note() - note_id = self.collection.findNotes('')[0] - note = self.collection.getNote(note_id) - self.assertTrue('Tag1' in note.tags) - self.assertTrue('Tag2' in note.tags) - - self.execute('remove_tags', {'tags': ['Tag1', 'Tag2']}, note_id) - note = self.collection.getNote(note_id) - self.assertFalse('Tag1' in note.tags) - self.assertFalse('Tag2' in note.tags) - -class DeckHandlerTest(CollectionTestBase): - def setUp(self): - super(DeckHandlerTest, self).setUp() - self.handler = DeckHandler() - - def execute(self, name, data): - ids = ['collection_name', '1'] - func = getattr(self.handler, name) - req = RestHandlerRequest(self.mock_app, data, ids, {}) - return func(self.collection, req) - - def test_index(self): - ret = self.execute('index', {}) - #pprint(ret) - self.assertEqual(ret['name'], 'Default') - self.assertEqual(ret['id'], 1) - self.assertEqual(ret['dyn'], False) - - def test_next_card(self): - self.mock_app.execute_handler.return_value = None - - ret = self.execute('next_card', {}) - self.assertEqual(ret, None) - self.mock_app.execute_handler.assert_called_with('collection', 'next_card', self.collection, RestHandlerRequest(self.mock_app, {'deck': '1'}, ['collection_name'], {})) - - def test_get_conf(self): - ret = self.execute('get_conf', {}) - #pprint(ret) - self.assertEqual(ret['name'], 'Default') - self.assertEqual(ret['id'], 1) - self.assertEqual(ret['dyn'], False) - -class CardHandlerTest(CollectionTestBase): - def setUp(self): - super(CardHandlerTest, self).setUp() - self.handler = CardHandler() - - def execute(self, name, data, card_id): - ids = ['collection_name', card_id] - func = getattr(self.handler, name) - req = RestHandlerRequest(self.mock_app, data, ids, {}) - return func(self.collection, req) - - def test_index_simple(self): - self.add_default_note() - - note_id = self.collection.findNotes('')[0] - card_id = self.collection.findCards('')[0] - - ret = self.execute('index', {}, card_id) - self.assertEqual(ret['id'], card_id) - self.assertEqual(ret['nid'], note_id) - self.assertEqual(ret['did'], 1) - self.assertFalse(ret.has_key('note')) - self.assertFalse(ret.has_key('deck')) - - def test_index_load(self): - self.add_default_note() - - note_id = self.collection.findNotes('')[0] - card_id = self.collection.findCards('')[0] - - ret = self.execute('index', {'load_note': 1, 'load_deck': 1}, card_id) - self.assertEqual(ret['id'], card_id) - self.assertEqual(ret['nid'], note_id) - self.assertEqual(ret['did'], 1) - self.assertEqual(ret['note']['id'], note_id) - self.assertEqual(ret['note']['model']['name'], 'Basic') - self.assertEqual(ret['deck']['name'], 'Default') - -if __name__ == '__main__': - unittest.main() - diff --git a/tests/test_sync_app.py b/tests/test_sync_app.py deleted file mode 100644 index 9159eac..0000000 --- a/tests/test_sync_app.py +++ /dev/null @@ -1,9 +0,0 @@ - -import unittest - -import AnkiServer -from AnkiServer.apps.sync_app import SyncApp - -class SyncAppTest(unittest.TestCase): - pass -