Refactored AnkiServer/apps/rest_app.py so that it is actually testable.
This commit is contained in:
		
							parent
							
								
									3a31cb5889
								
							
						
					
					
						commit
						57d3ba5445
					
				| @ -63,13 +63,6 @@ class RestApp(object): | ||||
|             self.add_handler_group('deck', DeckHandlerGroup()) | ||||
|             self.add_handler_group('note', NoteHandlerGroup()) | ||||
| 
 | ||||
|     def _get_path(self, path): | ||||
|         npath = os.path.normpath(os.path.join(self.data_root, path, 'collection.anki2')) | ||||
|         if npath[0:len(self.data_root)] != self.data_root: | ||||
|             # attempting to escape our data jail! | ||||
|             raise HTTPBadRequest('"%s" is not a valid path/id' % path) | ||||
|         return npath | ||||
| 
 | ||||
|     def add_handler(self, type, name, handler): | ||||
|         """Adds a callback handler for a type (collection, deck, card) with a unique name. | ||||
|          | ||||
| @ -97,8 +90,8 @@ class RestApp(object): | ||||
|                     method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue) | ||||
|                 self.add_handler(type, name, method) | ||||
| 
 | ||||
|     @wsgify | ||||
|     def __call__(self, req): | ||||
|     def _checkRequest(self, req): | ||||
|         """Raises an exception if the request isn't allowed or valid for some reason.""" | ||||
|         if self.allowed_hosts != '*': | ||||
|             try: | ||||
|                 remote_addr = req.headers['X-Forwarded-For'] | ||||
| @ -106,12 +99,17 @@ class RestApp(object): | ||||
|                 remote_addr = req.remote_addr | ||||
|             if remote_addr != self.allowed_hosts: | ||||
|                 raise HTTPForbidden() | ||||
| 
 | ||||
|          | ||||
|         if req.method != 'POST': | ||||
|             raise HTTPMethodNotAllowed(allow=['POST']) | ||||
| 
 | ||||
|     def _parsePath(self, path): | ||||
|         """Takes a request path and returns a tuple containing the handler type, name | ||||
|         and a list of ids. | ||||
| 
 | ||||
|         Raises an HTTPNotFound exception if the path is invalid.""" | ||||
| 
 | ||||
|         # split the URL into a list of parts | ||||
|         path = req.path | ||||
|         if path[0] == '/': | ||||
|             path = path[1:] | ||||
|         parts = path.split('/') | ||||
| @ -136,9 +134,25 @@ class RestApp(object): | ||||
|         else: | ||||
|             name = parts[0] | ||||
| 
 | ||||
|         # get the collection path | ||||
|         collection_path = self._get_path(ids[0]) | ||||
|         print collection_path | ||||
|         return (type, name, ids) | ||||
| 
 | ||||
|     def _getCollectionPath(self, collection_id): | ||||
|         """Returns the path to the collection based on the collection_id from the request. | ||||
|          | ||||
|         Raises HTTPBadRequest if the collection_id is invalid.""" | ||||
| 
 | ||||
|         path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2')) | ||||
|         if path[0:len(self.data_root)] != self.data_root: | ||||
|             # attempting to escape our data jail! | ||||
|             raise HTTPBadRequest('"%s" is not a valid collection' % collection_id) | ||||
| 
 | ||||
|         return path | ||||
| 
 | ||||
|     def _getHandler(self, type, name): | ||||
|         """Returns a tuple containing handler function for this type and name, and a boolean flag | ||||
|         if that handler has a return value. | ||||
| 
 | ||||
|         Raises an HTTPNotFound exception if the handler doesn't exist.""" | ||||
| 
 | ||||
|         # get the handler function | ||||
|         try: | ||||
| @ -151,6 +165,13 @@ class RestApp(object): | ||||
|         if hasattr(handler, 'hasReturnValue'): | ||||
|             hasReturnValue = handler.hasReturnValue | ||||
| 
 | ||||
|         return (handler, hasReturnValue) | ||||
| 
 | ||||
|     def _parseRequestBody(self, req): | ||||
|         """Parses the request body (JSON) into a Python dict and returns it. | ||||
| 
 | ||||
|         Raises an HTTPBadRequest exception if the request isn't valid JSON.""" | ||||
|          | ||||
|         try: | ||||
|             data = json.loads(req.body) | ||||
|         except ValueError, e: | ||||
| @ -159,6 +180,26 @@ class RestApp(object): | ||||
|         # make the keys into non-unicode strings | ||||
|         data = dict([(str(k), v) for k, v in data.items()]) | ||||
| 
 | ||||
|         return data | ||||
| 
 | ||||
|     @wsgify | ||||
|     def __call__(self, req): | ||||
|         # make sure the request is valid | ||||
|         self._checkRequest(req) | ||||
| 
 | ||||
|         # parse the path | ||||
|         type, name, ids = self._parsePath(req.path) | ||||
| 
 | ||||
|         # get the collection path | ||||
|         collection_path = self._getCollectionPath(ids[0]) | ||||
|         print collection_path | ||||
| 
 | ||||
|         # get the handler function | ||||
|         handler, hasReturnValue = self._getHandler(type, name) | ||||
| 
 | ||||
|         # parse the request body | ||||
|         data = self._parseRequestBody(req) | ||||
| 
 | ||||
|         # debug | ||||
|         from pprint import pprint | ||||
|         pprint(data) | ||||
|  | ||||
| @ -5,11 +5,32 @@ import tempfile | ||||
| import unittest | ||||
| 
 | ||||
| import AnkiServer | ||||
| from AnkiServer.apps.rest_app import CollectionHandlerGroup, DeckHandlerGroup | ||||
| from AnkiServer.collection import CollectionManager | ||||
| from AnkiServer.apps.rest_app import RestApp, CollectionHandlerGroup, DeckHandlerGroup | ||||
| 
 | ||||
| import anki | ||||
| import anki.storage | ||||
| 
 | ||||
| class RestAppTest(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         self.temp_dir = tempfile.mkdtemp() | ||||
|         self.collection_manager = CollectionManager() | ||||
|         self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager) | ||||
| 
 | ||||
|     def tearDown(self): | ||||
|         self.collection_manager.shutdown() | ||||
|         self.collection_manager = None | ||||
|         self.rest_app = None | ||||
|         shutil.rmtree(self.temp_dir) | ||||
| 
 | ||||
|     def test_parsePath(self): | ||||
|         tests = [ | ||||
|             ('collection/aoeu', ('collection', 'index', ['aoeu'])), | ||||
|         ] | ||||
| 
 | ||||
|         for path, result in tests: | ||||
|             self.assertEqual(self.rest_app._parsePath(path), result) | ||||
| 
 | ||||
| class CollectionTestBase(unittest.TestCase): | ||||
|     """Parent class for tests that need a collection set up and torn down.""" | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user