diff --git a/AnkiServer/apps/rest_app.py b/AnkiServer/apps/rest_app.py index fd5cf2c..d0f7edb 100644 --- a/AnkiServer/apps/rest_app.py +++ b/AnkiServer/apps/rest_app.py @@ -63,13 +63,6 @@ class RestApp(object): self.add_handler_group('deck', DeckHandlerGroup()) self.add_handler_group('note', NoteHandlerGroup()) - def _get_path(self, path): - npath = os.path.normpath(os.path.join(self.data_root, path, 'collection.anki2')) - if npath[0:len(self.data_root)] != self.data_root: - # attempting to escape our data jail! - raise HTTPBadRequest('"%s" is not a valid path/id' % path) - return npath - def add_handler(self, type, name, handler): """Adds a callback handler for a type (collection, deck, card) with a unique name. @@ -97,8 +90,8 @@ class RestApp(object): method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue) self.add_handler(type, name, method) - @wsgify - def __call__(self, req): + def _checkRequest(self, req): + """Raises an exception if the request isn't allowed or valid for some reason.""" if self.allowed_hosts != '*': try: remote_addr = req.headers['X-Forwarded-For'] @@ -106,12 +99,17 @@ class RestApp(object): remote_addr = req.remote_addr if remote_addr != self.allowed_hosts: raise HTTPForbidden() - + if req.method != 'POST': raise HTTPMethodNotAllowed(allow=['POST']) + def _parsePath(self, path): + """Takes a request path and returns a tuple containing the handler type, name + and a list of ids. + + Raises an HTTPNotFound exception if the path is invalid.""" + # split the URL into a list of parts - path = req.path if path[0] == '/': path = path[1:] parts = path.split('/') @@ -136,9 +134,25 @@ class RestApp(object): else: name = parts[0] - # get the collection path - collection_path = self._get_path(ids[0]) - print collection_path + return (type, name, ids) + + def _getCollectionPath(self, collection_id): + """Returns the path to the collection based on the collection_id from the request. + + Raises HTTPBadRequest if the collection_id is invalid.""" + + path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2')) + if path[0:len(self.data_root)] != self.data_root: + # attempting to escape our data jail! + raise HTTPBadRequest('"%s" is not a valid collection' % collection_id) + + return path + + def _getHandler(self, type, name): + """Returns a tuple containing handler function for this type and name, and a boolean flag + if that handler has a return value. + + Raises an HTTPNotFound exception if the handler doesn't exist.""" # get the handler function try: @@ -151,6 +165,13 @@ class RestApp(object): if hasattr(handler, 'hasReturnValue'): hasReturnValue = handler.hasReturnValue + return (handler, hasReturnValue) + + def _parseRequestBody(self, req): + """Parses the request body (JSON) into a Python dict and returns it. + + Raises an HTTPBadRequest exception if the request isn't valid JSON.""" + try: data = json.loads(req.body) except ValueError, e: @@ -159,6 +180,26 @@ class RestApp(object): # make the keys into non-unicode strings data = dict([(str(k), v) for k, v in data.items()]) + return data + + @wsgify + def __call__(self, req): + # make sure the request is valid + self._checkRequest(req) + + # parse the path + type, name, ids = self._parsePath(req.path) + + # get the collection path + collection_path = self._getCollectionPath(ids[0]) + print collection_path + + # get the handler function + handler, hasReturnValue = self._getHandler(type, name) + + # parse the request body + data = self._parseRequestBody(req) + # debug from pprint import pprint pprint(data) diff --git a/tests/test_rest_app.py b/tests/test_rest_app.py index 6278f44..35a7a6e 100644 --- a/tests/test_rest_app.py +++ b/tests/test_rest_app.py @@ -5,11 +5,32 @@ import tempfile import unittest import AnkiServer -from AnkiServer.apps.rest_app import CollectionHandlerGroup, DeckHandlerGroup +from AnkiServer.collection import CollectionManager +from AnkiServer.apps.rest_app import RestApp, CollectionHandlerGroup, DeckHandlerGroup import anki import anki.storage +class RestAppTest(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.collection_manager = CollectionManager() + self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager) + + def tearDown(self): + self.collection_manager.shutdown() + self.collection_manager = None + self.rest_app = None + shutil.rmtree(self.temp_dir) + + def test_parsePath(self): + tests = [ + ('collection/aoeu', ('collection', 'index', ['aoeu'])), + ] + + for path, result in tests: + self.assertEqual(self.rest_app._parsePath(path), result) + class CollectionTestBase(unittest.TestCase): """Parent class for tests that need a collection set up and torn down."""