anki-sync-server/tests/test_rest_app.py

168 lines
5.5 KiB
Python

import os
import shutil
import tempfile
import unittest
import logging
import mock
from mock import MagicMock
import AnkiServer
from AnkiServer.collection import CollectionManager
from AnkiServer.apps.rest_app import RestApp, CollectionHandlerGroup, DeckHandlerGroup
from webob.exc import *
import anki
import anki.storage
class RestAppTest(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_manager = CollectionManager()
self.rest_app = RestApp(self.temp_dir, collection_manager=self.collection_manager)
# disable all but critical errors!
logging.disable(logging.CRITICAL)
def tearDown(self):
self.collection_manager.shutdown()
self.collection_manager = None
self.rest_app = None
shutil.rmtree(self.temp_dir)
def test_parsePath(self):
tests = [
('collection/user', ('collection', 'index', ['user'])),
('collection/user/handler', ('collection', 'handler', ['user'])),
('collection/user/deck/name', ('deck', 'index', ['user', 'name'])),
('collection/user/deck/name/handler', ('deck', 'handler', ['user', 'name'])),
('collection/user/deck/name/note/123', ('note', 'index', ['user', 'name', '123'])),
('collection/user/deck/name/note/123/handler', ('note', 'handler', ['user', 'name', '123'])),
# the leading slash should make no difference!
('/collection/user', ('collection', 'index', ['user'])),
]
for path, result in tests:
self.assertEqual(self.rest_app._parsePath(path), result)
def test_parsePath_not_found(self):
tests = [
'bad',
'bad/oaeu',
'collection',
'collection/user/handler/bad',
'',
'/',
]
for path in tests:
self.assertRaises(HTTPNotFound, self.rest_app._parsePath, path)
def test_getCollectionPath(self):
def fullpath(collection_id):
return os.path.normpath(os.path.join(self.temp_dir, collection_id, 'collection.anki2'))
# This is simple and straight forward!
self.assertEqual(self.rest_app._getCollectionPath('user'), fullpath('user'))
# These are dangerous - the user is trying to hack us!
dangerous = ['../user', '/etc/passwd', '/tmp/aBaBaB', '/root/.ssh/id_rsa']
for collection_id in dangerous:
self.assertRaises(HTTPBadRequest, self.rest_app._getCollectionPath, collection_id)
def test_getHandler(self):
def handlerOne():
pass
def handlerTwo():
pass
handlerTwo.hasReturnValue = False
self.rest_app.add_handler('collection', 'handlerOne', handlerOne)
self.rest_app.add_handler('deck', 'handlerTwo', handlerTwo)
(handler, hasReturnValue) = self.rest_app._getHandler('collection', 'handlerOne')
self.assertEqual(handler, handlerOne)
self.assertEqual(hasReturnValue, True)
(handler, hasReturnValue) = self.rest_app._getHandler('deck', 'handlerTwo')
self.assertEqual(handler, handlerTwo)
self.assertEqual(hasReturnValue, False)
# try some bad handler names and types
self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'collection', 'nonExistantHandler')
self.assertRaises(HTTPNotFound, self.rest_app._getHandler, 'nonExistantType', 'handlerOne')
def test_parseRequestBody(self):
req = MagicMock()
req.body = '{"key":"value"}'
data = self.rest_app._parseRequestBody(req)
self.assertEqual(data, {'key': 'value'})
self.assertEqual(data.keys(), ['key'])
self.assertEqual(type(data.keys()[0]), str)
# test some bad data
req.body = '{aaaaaaa}'
self.assertRaises(HTTPBadRequest, self.rest_app._parseRequestBody, req)
class CollectionTestBase(unittest.TestCase):
"""Parent class for tests that need a collection set up and torn down."""
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.collection_path = os.path.join(self.temp_dir, 'collection.anki2');
self.collection = anki.storage.Collection(self.collection_path)
def tearDown(self):
self.collection.close()
self.collection = None
shutil.rmtree(self.temp_dir)
class CollectionHandlerGroupTest(CollectionTestBase):
def setUp(self):
super(CollectionHandlerGroupTest, self).setUp()
self.handler = CollectionHandlerGroup()
def execute(self, name, data):
ids = ['collection_name']
func = getattr(self.handler, name)
return func(self.collection, data, ids)
def test_list_decks(self):
data = {}
ret = self.execute('list_decks', data)
# It contains only the 'Default' deck
self.assertEqual(len(ret), 1)
self.assertEqual(ret[0]['name'], 'Default')
def test_select_deck(self):
data = {'deck_id': '1'}
ret = self.execute('select_deck', data)
self.assertEqual(ret, None);
class DeckHandlerGroupTest(CollectionTestBase):
def setUp(self):
super(DeckHandlerGroupTest, self).setUp()
self.handler = DeckHandlerGroup()
def execute(self, name, data):
ids = ['collection_name', '1']
func = getattr(self.handler, name)
return func(self.collection, data, ids)
def test_next_card(self):
ret = self.execute('next_card', {})
self.assertEqual(ret, None)
# TODO: add a note programatically
if __name__ == '__main__':
unittest.main()