diff --git a/AnkiServer/importer.py b/AnkiServer/importer.py index 2c33f4c..ece6dda 100644 --- a/AnkiServer/importer.py +++ b/AnkiServer/importer.py @@ -6,6 +6,8 @@ from anki.importing.supermemo_xml import SupermemoXmlImporter from anki.importing.mnemo import MnemosyneImporter from anki.importing.pauker import PaukerImporter +__all__ = ['get_importer_class', 'import_file'] + importers = { 'text': TextImporter, 'apkg': AnkiPackageImporter, @@ -19,11 +21,80 @@ def get_importer_class(type): global importers return importers.get(type) -def import_file(importer_class, col, path): +def import_file(importer_class, col, path, allow_update = False): importer = importer_class(col, path) + if allow_update: + importer.allowUpdate = True + if importer.needMapper: importer.open() importer.run() +# +# Monkey patch anki.importing.anki2 to support updating existing notes. +# TODO: submit a patch to Anki! +# + +def _importNotes(self): + # build guid -> (id,mod,mid) hash & map of existing note ids + self._notes = {} + existing = {} + for id, guid, mod, mid in self.dst.db.execute( + "select id, guid, mod, mid from notes"): + self._notes[guid] = (id, mod, mid) + existing[id] = True + # we may need to rewrite the guid if the model schemas don't match, + # so we need to keep track of the changes for the card import stage + self._changedGuids = {} + # iterate over source collection + add = [] + dirty = [] + usn = self.dst.usn() + dupes = 0 + for note in self.src.db.execute( + "select * from notes"): + # turn the db result into a mutable list + note = list(note) + shouldAdd = self._uniquifyNote(note) + if shouldAdd: + # ensure id is unique + while note[0] in existing: + note[0] += 999 + existing[note[0]] = True + # bump usn + note[4] = usn + # update media references in case of dupes + note[6] = self._mungeMedia(note[MID], note[6]) + add.append(note) + dirty.append(note[0]) + # note we have the added the guid + self._notes[note[GUID]] = (note[0], note[3], note[MID]) + else: + # update existing note + newer = note[3] > mod + if self.allowUpdate and self._mid(mid) == mid and newer: + localNid = self._notes[guid][0] + note[0] = localNid + note[4] = usn + add.append(note) + dirty.append(note[0]) + else: + dupes += 1 + + if dupes: + self.log.append(_("Already in collection: %s.") % (ngettext( + "%d note", "%d notes", dupes) % dupes)) + # add to col + self.dst.db.executemany( + "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", + add) + self.dst.updateFieldCache(dirty) + self.dst.tags.registerNotes(dirty) + +from anki.importing.anki2 import Anki2Importer, MID, GUID +from anki.lang import _, ngettext +Anki2Importer._importNotes = _importNotes +Anki2Importer.allowUpdate = False + diff --git a/tests/test_importer.py b/tests/test_importer.py new file mode 100644 index 0000000..448f534 --- /dev/null +++ b/tests/test_importer.py @@ -0,0 +1,113 @@ + +import os +import shutil +import tempfile +import unittest + +import mock +from mock import MagicMock, sentinel + +import AnkiServer +from AnkiServer.importer import get_importer_class, import_file + +import anki.storage + +# TODO: refactor into some kind of utility +def add_note(col, data): + from anki.notes import Note + + model = col.models.byName(data['model']) + + note = Note(col, model) + for name, value in data['fields'].items(): + note[name] = value + + if data.has_key('tags'): + note.setTagsFromStr(data['tags']) + + col.addNote(note) + +class ImporterTest(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.collection_path = os.path.join(self.temp_dir, 'collection.anki2') + self.collection = anki.storage.Collection(self.collection_path) + + def tearDown(self): + self.collection.close() + self.collection = None + shutil.rmtree(self.temp_dir) + + + # TODO: refactor into a parent class + def add_default_note(self, count=1): + data = { + 'model': 'Basic', + 'fields': { + 'Front': 'The front', + 'Back': 'The back', + }, + 'tags': "Tag1 Tag2", + } + for idx in range(0, count): + add_note(self.collection, data) + self.add_note(data) + + def test_resync(self): + from anki.exporting import AnkiPackageExporter + from anki.utils import intTime + + # create a new collection with a single note + src_collection = anki.storage.Collection(os.path.join(self.temp_dir, 'src_collection.anki2')) + add_note(src_collection, { + 'model': 'Basic', + 'fields': { + 'Front': 'The front', + 'Back': 'The back', + }, + 'tags': 'Tag1 Tag2', + }) + note_id = src_collection.findNotes('')[0] + note = src_collection.getNote(note_id) + self.assertEqual(note.id, note_id) + self.assertEqual(note['Front'], 'The front') + self.assertEqual(note['Back'], 'The back') + + # export to an .apkg file + dst1_path = os.path.join(self.temp_dir, 'export1.apkg') + exporter = AnkiPackageExporter(src_collection) + exporter.exportInto(dst1_path) + + # import it into the main collection + import_file(get_importer_class('apkg'), self.collection, dst1_path) + + # make sure the note exists + note = self.collection.getNote(note_id) + self.assertEqual(note.id, note_id) + self.assertEqual(note['Front'], 'The front') + self.assertEqual(note['Back'], 'The back') + + # now we change the source collection and re-export it + note = src_collection.getNote(note_id) + note['Front'] = 'The new front' + note.tags.append('Tag3') + note.flush(intTime()+1) + dst2_path = os.path.join(self.temp_dir, 'export2.apkg') + exporter = AnkiPackageExporter(src_collection) + exporter.exportInto(dst2_path) + + # first, import it without allow_update - no change should happen + import_file(get_importer_class('apkg'), self.collection, dst2_path) + note = self.collection.getNote(note_id) + self.assertEqual(note['Front'], 'The front') + self.assertEqual(note.tags, ['Tag1', 'Tag2']) + + # now, import it with allow_update=True, so the note should change + import_file(get_importer_class('apkg'), self.collection, dst2_path, allow_update=True) + note = self.collection.getNote(note_id) + self.assertEqual(note['Front'], 'The new front') + self.assertEqual(note.tags, ['Tag1', 'Tag2', 'Tag3']) + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_rest_app.py b/tests/test_rest_app.py index 25730da..15aa53a 100644 --- a/tests/test_rest_app.py +++ b/tests/test_rest_app.py @@ -129,6 +129,7 @@ class CollectionTestBase(unittest.TestCase): shutil.rmtree(self.temp_dir) self.mock_app.reset_mock() + # TODO: refactor into some kind of utility def add_note(self, data): from anki.notes import Note @@ -143,6 +144,7 @@ class CollectionTestBase(unittest.TestCase): self.collection.addNote(note) + # TODO: refactor into a parent class def add_default_note(self, count=1): data = { 'model': 'Basic',