Refactored AnkiServer/apps/rest_app.py so that it is actually testable.

This commit is contained in:
David Snopek 2013-07-15 16:13:48 +01:00
parent 3a31cb5889
commit 57d3ba5445
2 changed files with 77 additions and 15 deletions

View File

@ -63,13 +63,6 @@ class RestApp(object):
self.add_handler_group('deck', DeckHandlerGroup())
self.add_handler_group('note', NoteHandlerGroup())
def _get_path(self, path):
npath = os.path.normpath(os.path.join(self.data_root, path, 'collection.anki2'))
if npath[0:len(self.data_root)] != self.data_root:
# attempting to escape our data jail!
raise HTTPBadRequest('"%s" is not a valid path/id' % path)
return npath
def add_handler(self, type, name, handler):
"""Adds a callback handler for a type (collection, deck, card) with a unique name.
@ -97,8 +90,8 @@ class RestApp(object):
method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue)
self.add_handler(type, name, method)
@wsgify
def __call__(self, req):
def _checkRequest(self, req):
"""Raises an exception if the request isn't allowed or valid for some reason."""
if self.allowed_hosts != '*':
try:
remote_addr = req.headers['X-Forwarded-For']
@ -106,12 +99,17 @@ class RestApp(object):
remote_addr = req.remote_addr
if remote_addr != self.allowed_hosts:
raise HTTPForbidden()
if req.method != 'POST':
raise HTTPMethodNotAllowed(allow=['POST'])
def _parsePath(self, path):
"""Takes a request path and returns a tuple containing the handler type, name
and a list of ids.
Raises an HTTPNotFound exception if the path is invalid."""
# split the URL into a list of parts
path = req.path
if path[0] == '/':
path = path[1:]
parts = path.split('/')
@ -136,9 +134,25 @@ class RestApp(object):
else:
name = parts[0]
# get the collection path
collection_path = self._get_path(ids[0])
print collection_path
return (type, name, ids)
def _getCollectionPath(self, collection_id):
"""Returns the path to the collection based on the collection_id from the request.
Raises HTTPBadRequest if the collection_id is invalid."""
path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2'))
if path[0:len(self.data_root)] != self.data_root:
# attempting to escape our data jail!
raise HTTPBadRequest('"%s" is not a valid collection' % collection_id)
return path
def _getHandler(self, type, name):
"""Returns a tuple containing handler function for this type and name, and a boolean flag
if that handler has a return value.
Raises an HTTPNotFound exception if the handler doesn't exist."""
# get the handler function
try:
@ -151,6 +165,13 @@ class RestApp(object):
if hasattr(handler, 'hasReturnValue'):
hasReturnValue = handler.hasReturnValue
return (handler, hasReturnValue)
def _parseRequestBody(self, req):
"""Parses the request body (JSON) into a Python dict and returns it.
Raises an HTTPBadRequest exception if the request isn't valid JSON."""
try:
data = json.loads(req.body)
except ValueError, e:
@ -159,6 +180,26 @@ class RestApp(object):
# make the keys into non-unicode strings
data = dict([(str(k), v) for k, v in data.items()])
return data
@wsgify
def __call__(self, req):
# make sure the request is valid
self._checkRequest(req)
# parse the path
type, name, ids = self._parsePath(req.path)
# get the collection path
collection_path = self._getCollectionPath(ids[0])
print collection_path
# get the handler function
handler, hasReturnValue = self._getHandler(type, name)
# parse the request body
data = self._parseRequestBody(req)
# debug
from pprint import pprint
pprint(data)

View File

@ -5,11 +5,32 @@ import tempfile
import unittest
import AnkiServer
from AnkiServer.apps.rest_app import CollectionHandlerGroup, DeckHandlerGroup
from AnkiServer.collection import CollectionManager
from AnkiServer.apps.rest_app import RestApp, CollectionHandlerGroup, DeckHandlerGroup
import anki
import anki.storage
class RestAppTest(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_manager = CollectionManager()
self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager)
def tearDown(self):
self.collection_manager.shutdown()
self.collection_manager = None
self.rest_app = None
shutil.rmtree(self.temp_dir)
def test_parsePath(self):
tests = [
('collection/aoeu', ('collection', 'index', ['aoeu'])),
]
for path, result in tests:
self.assertEqual(self.rest_app._parsePath(path), result)
class CollectionTestBase(unittest.TestCase):
"""Parent class for tests that need a collection set up and torn down."""