diff --git a/AnkiServer/apps/rest_app.py b/AnkiServer/apps/rest_app.py index 4d39a5c..3ffa1d4 100644 --- a/AnkiServer/apps/rest_app.py +++ b/AnkiServer/apps/rest_app.py @@ -14,10 +14,6 @@ import os, logging __all__ = ['RestApp', 'RestHandlerBase', 'hasReturnValue', 'noReturnValue'] -def hasReturnValue(func): - func.hasReturnValue = True - return func - def noReturnValue(func): func.hasReturnValue = False return func @@ -270,9 +266,12 @@ class CollectionHandler(RestHandlerBase): return nodes + @noReturnValue def add_note(self, col, data, ids): from anki.notes import Note + # TODO: I think this would be better with 'model' for the name + # and 'mid' for the model id. if type(data['model']) in (str, unicode): model = col.models.byName(data['model']) else: @@ -323,6 +322,75 @@ class CollectionHandler(RestHandlerBase): def sched_reset(self, col, data, ids): col.sched.reset() + +class ImportExportHandler(RestHandlerBase): + """Handler group for the 'collection' type, but it's not added by default.""" + + def _get_filedata(self, data): + import urllib2 + + if data.has_key('data'): + return data['data'] + + fd = None + try: + fd = urllib2.urlopen(data['url']) + filedata = fd.read() + finally: + if fd is not None: + fd.close() + + return filedata + + def _get_importer_class(self, data): + filetype = data['filetype'] + + # We do this as an if/elif/else guy, because I don't want to even import + # the modules until someone actually attempts to import the type + if filetype == 'text': + from anki.importing.csvfile import TextImporter + return TextImporter + elif filetype == 'apkg': + from anki.importing.apkg import AnkiPackageImporter + return AnkiPackageImporter + elif filetype == 'anki1': + from anki.importing.anki1 import Anki1Importer + return Anki1Importer + elif filetype == 'supermemo_xml': + from anki.importing.supermemo_xml import SupermemoXmlImporter + return SupermemoXmlImporter + elif filetype == 'mnemosyne': + from anki.importing.mnemo import MnemosyneImporter + return MnemosyneImporter + elif filetype == 'pauker': + from anki.importing.pauker import PaukerImporter + return PaukerImporter + else: + raise HTTPBadRequest("Unknown filetype '%s'" % filetype) + + def import_file(self, col, data, ids): + import tempfile + + # get the importer class + importer_class = self._get_importer_class(data) + + # get the file data + filedata = self._get_filedata(data) + + # write the file data to a temporary file + try: + path = None + with tempfile.NamedTemporaryFile('wt', delete=False) as fd: + path = fd.name + fd.write(filedata) + + importer = importer_class(col, path) + importer.open() + importer.run() + finally: + if path is not None: + os.unlink(path) + class ModelHandler(RestHandlerBase): """Default handler group for 'model' type.""" diff --git a/tests/test_rest_app.py b/tests/test_rest_app.py index b6360f5..db0bb30 100644 --- a/tests/test_rest_app.py +++ b/tests/test_rest_app.py @@ -11,7 +11,7 @@ from mock import MagicMock import AnkiServer from AnkiServer.collection import CollectionManager -from AnkiServer.apps.rest_app import RestApp, CollectionHandler, NoteHandler, ModelHandler, DeckHandler, CardHandler +from AnkiServer.apps.rest_app import RestApp, CollectionHandler, ImportExportHandler, NoteHandler, ModelHandler, DeckHandler, CardHandler from webob.exc import * @@ -242,6 +242,62 @@ class CollectionHandlerTest(CollectionTestBase): self.assertEqual(note['Back'], 'The back') self.assertEqual(note.tags, ['Tag1', 'Tag2']) +class ImportExportHandlerTest(CollectionTestBase): + export_rows = [ + ['Card front 1', 'Card back 1', 'Tag1 Tag2'], + ['Card front 2', 'Card back 2', 'Tag1 Tag3'], + ] + + def setUp(self): + super(ImportExportHandlerTest, self).setUp() + self.handler = ImportExportHandler() + + def execute(self, name, data): + ids = ['collection_name'] + func = getattr(self.handler, name) + return func(self.collection, data, ids) + + def generate_text_export(self): + # Create a simple export file + export_data = '' + for row in self.export_rows: + export_data += '\t'.join(row) + '\n' + export_path = os.path.join(self.temp_dir, 'export.txt') + with file(export_path, 'wt') as fd: + fd.write(export_data) + + return (export_data, export_path) + + def check_import(self): + note_ids = self.collection.findNotes('') + notes = [self.collection.getNote(note_id) for note_id in note_ids] + self.assertEqual(len(notes), len(self.export_rows)) + + for index, test_data in enumerate(self.export_rows): + self.assertEqual(notes[index]['Front'], test_data[0]) + self.assertEqual(notes[index]['Back'], test_data[1]) + self.assertEqual(' '.join(notes[index].tags), test_data[2]) + + def test_import_text_data(self): + (export_data, export_path) = self.generate_text_export() + + data = { + 'filetype': 'text', + 'data': export_data, + } + ret = self.execute('import_file', data) + self.check_import() + + def test_import_text_url(self): + (export_data, export_path) = self.generate_text_export() + + data = { + 'filetype': 'text', + 'url': 'file://' + os.path.realpath(export_path), + } + ret = self.execute('import_file', data) + self.check_import() + class DeckHandlerTest(CollectionTestBase): def setUp(self): super(DeckHandlerTest, self).setUp()