From 93094ebb485c8021ef740775bf2193e6352127f6 Mon Sep 17 00:00:00 2001 From: David Snopek Date: Tue, 16 Jul 2013 14:20:31 +0100 Subject: [PATCH] * Added new handler type 'card' and moved the position of 'note' * Got us actually adding notes and cards to the Anki collection! --- AnkiServer/apps/rest_app.py | 71 ++++++++++++++++++++++++++++--------- tests/test_rest_app.py | 50 ++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/AnkiServer/apps/rest_app.py b/AnkiServer/apps/rest_app.py index 0c401fb..9481d25 100644 --- a/AnkiServer/apps/rest_app.py +++ b/AnkiServer/apps/rest_app.py @@ -43,7 +43,8 @@ class _RestHandlerWrapper(RestHandlerBase): class RestApp(object): """A WSGI app that implements RESTful operations on Collections, Decks and Cards.""" - handler_types = ['collection', 'deck', 'note'] + # Defines not only the valid handler types, but their position in the URL string + handler_types = ['collection', ['deck', 'note'], 'card'] def __init__(self, data_root, allowed_hosts='*', use_default_handlers=True, collection_manager=None): from AnkiServer.threading import getCollectionManager @@ -57,13 +58,17 @@ class RestApp(object): self.collection_manager = getCollectionManager() self.handlers = {} - for type in self.handler_types: - self.handlers[type] = {} + for type_list in self.handler_types: + if type(type_list) is not list: + type_list = [type_list] + for handler_type in type_list: + self.handlers[handler_type] = {} if use_default_handlers: self.add_handler_group('collection', CollectionHandlerGroup()) - self.add_handler_group('deck', DeckHandlerGroup()) self.add_handler_group('note', NoteHandlerGroup()) + self.add_handler_group('deck', DeckHandlerGroup()) + self.add_handler_group('card', CardHandlerGroup()) def add_handler(self, type, name, handler): """Adds a callback handler for a type (collection, deck, card) with a unique name. @@ -120,17 +125,25 @@ class RestApp(object): parts = path.split('/') # pull the type and context from the URL parts - type = None + handler_type = None ids = [] - for type in self.handler_types: + for type_list in self.handler_types: if len(parts) == 0: break - if parts[0] != type: - break - parts.pop(0) + # some URL positions can have multiple types + if type(type_list) is not list: + type_list = [type_list] + + # get the handler_type + if parts[0] not in type_list: + break + handler_type = parts.pop(0) + + # add the id to the id list if len(parts) > 0: ids.append(parts.pop(0)) + # break if we don't have enough parts to make a new type/id pair if len(parts) < 2: break @@ -144,7 +157,7 @@ class RestApp(object): else: name = parts[0] - return (type, name, ids) + return (handler_type, name, ids) def _getCollectionPath(self, collection_id): """Returns the path to the collection based on the collection_id from the request. @@ -238,6 +251,26 @@ class CollectionHandlerGroup(RestHandlerGroupBase): def select_deck(self, col, data, ids): col.decks.select(data['deck_id']) + @noReturnValue + def sched_reset(self, col, data, ids): + col.sched.reset() + +class NoteHandlerGroup(RestHandlerGroupBase): + """Default handler group for 'note' type.""" + + @staticmethod + def _serialize_note(note): + d = { + 'id': note.id, + 'model': note.model()['name'], + } + # TODO: do more stuff! + return d + + def index(self, col, data, ids): + note = col.getNote(ids[1]) + return self._serialize_note(note) + class DeckHandlerGroup(RestHandlerGroupBase): """Default handler group for 'deck' type.""" @@ -246,15 +279,21 @@ class DeckHandlerGroup(RestHandlerGroupBase): col.decks.select(deck_id) card = col.sched.getCard() + if card is None: + return None - return card + return CardHandlerGroup._serialize_card(card) -class NoteHandlerGroup(RestHandlerGroupBase): - """Default handler group for 'note' type.""" +class CardHandlerGroup(RestHandlerGroupBase): + """Default handler group for 'card' type.""" - def add_new(self, col, data, ids): - # col.addNote(...) - pass + @staticmethod + def _serialize_card(card): + d = { + 'id': card.id + } + # TODO: do more stuff! + return d # Our entry point def make_app(global_conf, **local_conf): diff --git a/tests/test_rest_app.py b/tests/test_rest_app.py index 84e402a..77b4320 100644 --- a/tests/test_rest_app.py +++ b/tests/test_rest_app.py @@ -4,6 +4,7 @@ import shutil import tempfile import unittest import logging +from pprint import pprint import mock from mock import MagicMock @@ -36,10 +37,12 @@ class RestAppTest(unittest.TestCase): tests = [ ('collection/user', ('collection', 'index', ['user'])), ('collection/user/handler', ('collection', 'handler', ['user'])), + ('collection/user/note/123', ('note', 'index', ['user', '123'])), + ('collection/user/note/123/handler', ('note', 'handler', ['user', '123'])), ('collection/user/deck/name', ('deck', 'index', ['user', 'name'])), ('collection/user/deck/name/handler', ('deck', 'handler', ['user', 'name'])), - ('collection/user/deck/name/note/123', ('note', 'index', ['user', 'name', '123'])), - ('collection/user/deck/name/note/123/handler', ('note', 'handler', ['user', 'name', '123'])), + ('collection/user/deck/name/card/123', ('card', 'index', ['user', 'name', '123'])), + ('collection/user/deck/name/card/123/handler', ('card', 'handler', ['user', 'name', '123'])), # the leading slash should make no difference! ('/collection/user', ('collection', 'index', ['user'])), ] @@ -121,6 +124,32 @@ class CollectionTestBase(unittest.TestCase): self.collection = None shutil.rmtree(self.temp_dir) + def add_note(self, data): + from anki.notes import Note + + # TODO: we need to check the input for the correct keys.. Can we automate + # this somehow? Maybe using KeyError or wrapper or something? + + #pprint(self.collection.models.all()) + #pprint(self.collection.models.current()) + + model = self.collection.models.byName(data['model']) + #pprint (self.collection.models.fieldNames(model)) + + note = Note(self.collection, model) + for name, value in data['fields'].items(): + note[name] = value + + if data.has_key('tags'): + note.setTagsFromStr(data['tags']) + + ret = self.collection.addNote(note) + + def find_notes(self, data): + query = data.get('query', '') + ids = self.collection.getNotes(query) + + class CollectionHandlerGroupTest(CollectionTestBase): def setUp(self): super(CollectionHandlerGroupTest, self).setUp() @@ -131,6 +160,7 @@ class CollectionHandlerGroupTest(CollectionTestBase): func = getattr(self.handler, name) return func(self.collection, data, ids) + def test_list_decks(self): data = {} ret = self.execute('list_decks', data) @@ -158,9 +188,23 @@ class DeckHandlerGroupTest(CollectionTestBase): ret = self.execute('next_card', {}) self.assertEqual(ret, None) - # TODO: add a note programatically + # add a note programatically + note = { + 'model': 'Basic', + 'fields': { + 'Front': 'The front', + 'Back': 'The back', + }, + 'tags': "Tag1 Tag2", + } + self.add_note(note) + # get the id for the one card on this collection + card_id = self.collection.findCards('')[0] + self.collection.sched.reset() + ret = self.execute('next_card', {}) + self.assertEqual(ret['id'], card_id) if __name__ == '__main__': unittest.main()