Refactored AnkiServer/apps/rest_app.py so that it is actually testable.
This commit is contained in:
parent
3a31cb5889
commit
57d3ba5445
@ -63,13 +63,6 @@ class RestApp(object):
|
||||
self.add_handler_group('deck', DeckHandlerGroup())
|
||||
self.add_handler_group('note', NoteHandlerGroup())
|
||||
|
||||
def _get_path(self, path):
|
||||
npath = os.path.normpath(os.path.join(self.data_root, path, 'collection.anki2'))
|
||||
if npath[0:len(self.data_root)] != self.data_root:
|
||||
# attempting to escape our data jail!
|
||||
raise HTTPBadRequest('"%s" is not a valid path/id' % path)
|
||||
return npath
|
||||
|
||||
def add_handler(self, type, name, handler):
|
||||
"""Adds a callback handler for a type (collection, deck, card) with a unique name.
|
||||
|
||||
@ -97,8 +90,8 @@ class RestApp(object):
|
||||
method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue)
|
||||
self.add_handler(type, name, method)
|
||||
|
||||
@wsgify
|
||||
def __call__(self, req):
|
||||
def _checkRequest(self, req):
|
||||
"""Raises an exception if the request isn't allowed or valid for some reason."""
|
||||
if self.allowed_hosts != '*':
|
||||
try:
|
||||
remote_addr = req.headers['X-Forwarded-For']
|
||||
@ -106,12 +99,17 @@ class RestApp(object):
|
||||
remote_addr = req.remote_addr
|
||||
if remote_addr != self.allowed_hosts:
|
||||
raise HTTPForbidden()
|
||||
|
||||
|
||||
if req.method != 'POST':
|
||||
raise HTTPMethodNotAllowed(allow=['POST'])
|
||||
|
||||
def _parsePath(self, path):
|
||||
"""Takes a request path and returns a tuple containing the handler type, name
|
||||
and a list of ids.
|
||||
|
||||
Raises an HTTPNotFound exception if the path is invalid."""
|
||||
|
||||
# split the URL into a list of parts
|
||||
path = req.path
|
||||
if path[0] == '/':
|
||||
path = path[1:]
|
||||
parts = path.split('/')
|
||||
@ -136,9 +134,25 @@ class RestApp(object):
|
||||
else:
|
||||
name = parts[0]
|
||||
|
||||
# get the collection path
|
||||
collection_path = self._get_path(ids[0])
|
||||
print collection_path
|
||||
return (type, name, ids)
|
||||
|
||||
def _getCollectionPath(self, collection_id):
|
||||
"""Returns the path to the collection based on the collection_id from the request.
|
||||
|
||||
Raises HTTPBadRequest if the collection_id is invalid."""
|
||||
|
||||
path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2'))
|
||||
if path[0:len(self.data_root)] != self.data_root:
|
||||
# attempting to escape our data jail!
|
||||
raise HTTPBadRequest('"%s" is not a valid collection' % collection_id)
|
||||
|
||||
return path
|
||||
|
||||
def _getHandler(self, type, name):
|
||||
"""Returns a tuple containing handler function for this type and name, and a boolean flag
|
||||
if that handler has a return value.
|
||||
|
||||
Raises an HTTPNotFound exception if the handler doesn't exist."""
|
||||
|
||||
# get the handler function
|
||||
try:
|
||||
@ -151,6 +165,13 @@ class RestApp(object):
|
||||
if hasattr(handler, 'hasReturnValue'):
|
||||
hasReturnValue = handler.hasReturnValue
|
||||
|
||||
return (handler, hasReturnValue)
|
||||
|
||||
def _parseRequestBody(self, req):
|
||||
"""Parses the request body (JSON) into a Python dict and returns it.
|
||||
|
||||
Raises an HTTPBadRequest exception if the request isn't valid JSON."""
|
||||
|
||||
try:
|
||||
data = json.loads(req.body)
|
||||
except ValueError, e:
|
||||
@ -159,6 +180,26 @@ class RestApp(object):
|
||||
# make the keys into non-unicode strings
|
||||
data = dict([(str(k), v) for k, v in data.items()])
|
||||
|
||||
return data
|
||||
|
||||
@wsgify
|
||||
def __call__(self, req):
|
||||
# make sure the request is valid
|
||||
self._checkRequest(req)
|
||||
|
||||
# parse the path
|
||||
type, name, ids = self._parsePath(req.path)
|
||||
|
||||
# get the collection path
|
||||
collection_path = self._getCollectionPath(ids[0])
|
||||
print collection_path
|
||||
|
||||
# get the handler function
|
||||
handler, hasReturnValue = self._getHandler(type, name)
|
||||
|
||||
# parse the request body
|
||||
data = self._parseRequestBody(req)
|
||||
|
||||
# debug
|
||||
from pprint import pprint
|
||||
pprint(data)
|
||||
|
||||
@ -5,11 +5,32 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import AnkiServer
|
||||
from AnkiServer.apps.rest_app import CollectionHandlerGroup, DeckHandlerGroup
|
||||
from AnkiServer.collection import CollectionManager
|
||||
from AnkiServer.apps.rest_app import RestApp, CollectionHandlerGroup, DeckHandlerGroup
|
||||
|
||||
import anki
|
||||
import anki.storage
|
||||
|
||||
class RestAppTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.collection_manager = CollectionManager()
|
||||
self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager)
|
||||
|
||||
def tearDown(self):
|
||||
self.collection_manager.shutdown()
|
||||
self.collection_manager = None
|
||||
self.rest_app = None
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def test_parsePath(self):
|
||||
tests = [
|
||||
('collection/aoeu', ('collection', 'index', ['aoeu'])),
|
||||
]
|
||||
|
||||
for path, result in tests:
|
||||
self.assertEqual(self.rest_app._parsePath(path), result)
|
||||
|
||||
class CollectionTestBase(unittest.TestCase):
|
||||
"""Parent class for tests that need a collection set up and torn down."""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user