diff --git a/ankisyncd/users.py b/ankisyncd/users.py new file mode 100644 index 0000000..ce83686 --- /dev/null +++ b/ankisyncd/users.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- + + +import binascii +from contextlib import closing +import hashlib +import logging +import os +import sqlite3 as sqlite + + +class UserManager: + def __init__(self, auth_db_path, collection_path): + self.auth_db_path = auth_db_path + self.collection_path = collection_path + + def auth_db_exists(self): + return os.path.isfile(self.auth_db_path) + + def user_list(self): + if not self.auth_db_exists(): + self.create_auth_db() + return [] + else: + conn = sqlite.connect(self.auth_db_path) + cursor = conn.cursor() + cursor.execute("SELECT user FROM auth") + rows = cursor.fetchall() + conn.commit() + conn.close() + + return [row[0] for row in rows] + + def user_exists(self, username): + users = self.user_list() + return username in users + + def del_user(self, username): + if not self.auth_db_exists(): + self.create_auth_db() + + conn = sqlite.connect(self.auth_db_path) + cursor = conn.cursor() + logging.info("Removing user '{}' from auth db." + .format(username)) + cursor.execute("DELETE FROM auth WHERE user=?", (username,)) + conn.commit() + conn.close() + + def add_user(self, username, password): + self._add_user_to_auth_db(username, password) + self._create_user_dir(username) + + def add_users(self, users_data): + for username, password in users_data: + self.add_user(username, password) + + def _add_user_to_auth_db(self, username, password): + if not self.auth_db_exists(): + self.create_auth_db() + + pass_hash = self._create_pass_hash(username, password) + + conn = sqlite.connect(self.auth_db_path) + cursor = conn.cursor() + logging.info("Adding user '{}' to auth db.".format(username)) + cursor.execute("INSERT INTO auth VALUES (?, ?)", + (username, pass_hash)) + conn.commit() + conn.close() + + @staticmethod + def _create_pass_hash(username, password): + salt = binascii.b2a_hex(os.urandom(8)) + pass_hash = (hashlib.sha256(username + password + salt).hexdigest() + + salt) + return pass_hash + + def create_auth_db(self): + conn = sqlite.connect(self.auth_db_path) + cursor = conn.cursor() + logging.info("Creating auth db at {}." + .format(self.auth_db_path)) + cursor.execute("""CREATE TABLE IF NOT EXISTS auth + (user VARCHAR PRIMARY KEY, hash VARCHAR)""") + conn.commit() + conn.close() + + def _create_user_dir(self, username): + user_dir_path = os.path.join(self.collection_path, username) + if not os.path.isdir(user_dir_path): + logging.info("Creating collection directory for user '{}' at {}" + .format(username, user_dir_path)) + os.makedirs(user_dir_path) diff --git a/tests/helpers/collection_utils.py b/tests/helpers/collection_utils.py new file mode 100644 index 0000000..3eb3483 --- /dev/null +++ b/tests/helpers/collection_utils.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + + +import os +import shutil +import tempfile + + +from anki import Collection +from helpers.file_utils import FileUtils + + +class CollectionUtils(object): + """ + Provides utility methods for creating, inspecting and manipulating anki + collections. + """ + + def __init__(self): + self.collections_to_close = [] + self.fileutils = FileUtils() + self.master_db_path = None + + def __create_master_col(self): + """ + Creates an empty master anki db that will be copied on each request + for a new db. This is more efficient than initializing a new db each + time. + """ + + file_descriptor, file_path = tempfile.mkstemp(suffix=".anki2") + os.close(file_descriptor) + os.unlink(file_path) # We only need the file path. + master_col = Collection(file_path) + self.__mark_col_paths_for_deletion(master_col) + master_col.db.close() + self.master_db_path = file_path + + self.fileutils.mark_for_deletion(self.master_db_path) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.clean_up() + + def __mark_collection_for_closing(self, collection): + self.collections_to_close.append(collection) + + def __mark_col_paths_for_deletion(self, collection): + """ + Marks the paths of all the database files and directories managed by + the collection for later deletion. + """ + self.fileutils.mark_for_deletion(collection.path) + self.fileutils.mark_for_deletion(collection.media.dir()) + self.fileutils.mark_for_deletion(collection.media.col.path) + + def clean_up(self): + """ + Removes all files created by the Collection objects we issued and the + master db file. + """ + + # Close collections. + for col in self.collections_to_close: + col.close() # This also closes the media col. + self.collections_to_close = [] + + # Remove the files created by the collections. + self.fileutils.clean_up() + + self.master_db_path = None + + def create_empty_col(self): + """ + Returns a Collection object using a copy of our master db file. + """ + + if self.master_db_path is None: + self.__create_master_col() + + file_descriptor, file_path = tempfile.mkstemp(suffix=".anki2") + + # Overwrite temp file with a copy of our master db. + shutil.copy(self.master_db_path, file_path) + collection = Collection(file_path) + + self.__mark_collection_for_closing(collection) + self.__mark_col_paths_for_deletion(collection) + return collection + + @staticmethod + def create_col_from_existing_db(db_file_path): + """ + Returns a Collection object created from an existing anki db file. + """ + + return Collection(db_file_path) diff --git a/tests/helpers/db_utils.py b/tests/helpers/db_utils.py new file mode 100644 index 0000000..24fd650 --- /dev/null +++ b/tests/helpers/db_utils.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + + +import os +import sqlite3 +import subprocess + + +from helpers.file_utils import FileUtils + + +class DBUtils(object): + """Provides methods for creating and comparing sqlite databases.""" + + def __init__(self): + self.fileutils = FileUtils() + + def clean_up(self): + self.fileutils.clean_up() + + def create_sqlite_db_with_sql(self, sql_string): + """ + Creates an SQLite db and executes the passed sql statements on it. + + :param sql_string: the sql statements to execute on the newly created + db + :return: the path to the created db file + """ + + db_path = self.fileutils.create_file_path(suffix=".anki2") + connection = sqlite3.connect(db_path) + cursor = connection.cursor() + cursor.executescript(sql_string) + connection.commit() + connection.close() + + return db_path + + @staticmethod + def sqlite_db_to_sql_string(database): + """ + Returns a string containing the sql export of the database. Used for + debugging. + + :param database: either the path to the SQLite db file or an open + connection to it + :return: a string representing the sql export of the database + """ + + if type(database) == str: + connection = sqlite3.connect(database) + else: + connection = database + + res = '\n'.join(connection.iterdump()) + + if type(database) == str: + connection.close() + + return res + + def media_dbs_differ(self, left_db_path, right_db_path, compare_timestamps=False): + """ + Compares two media sqlite database files for equality. mtime and dirMod + timestamps are not considered when comparing. + + :param left_db_path: path to the left db file + :param right_db_path: path to the right db file + :param compare_timestamps: flag determining if timestamp values + (media.mtime and meta.dirMod) are included + in the comparison + :return: True if the specified databases differ, False else + """ + + if not os.path.isfile(left_db_path): + raise IOError("file '" + left_db_path + "' does not exist") + elif not os.path.isfile(right_db_path): + raise IOError("file '" + right_db_path + "' does not exist") + + # Create temporary copies of the files to act on. + left_db_path = self.fileutils.create_file_copy(left_db_path) + right_db_path = self.fileutils.create_file_copy(right_db_path) + + if not compare_timestamps: + # Set all timestamps that are not NULL to 0. + for dbPath in [left_db_path, right_db_path]: + connection = sqlite3.connect(dbPath) + + connection.execute("""UPDATE media SET mtime=0 + WHERE mtime IS NOT NULL""") + + connection.execute("""UPDATE meta SET dirMod=0 + WHERE rowid=1""") + connection.commit() + connection.close() + + return self.__sqlite_dbs_differ(left_db_path, right_db_path) + + def __sqlite_dbs_differ(self, left_db_path, right_db_path): + """ + Uses the sqldiff cli tool to compare two sqlite files for equality. + Returns True if the databases differ, False if they don't. + + :param left_db_path: path to the left db file + :param right_db_path: path to the right db file + :return: True if the specified databases differ, False else + """ + + command = ["/bin/sqldiff", left_db_path, right_db_path] + + try: + child_process = subprocess.Popen(command, + shell=False, + stdout=subprocess.PIPE) + stdout, stderr = child_process.communicate() + exit_code = child_process.returncode + + if exit_code != 0 or stderr is not None: + raise RuntimeError("Command {} encountered an error, exit " + "code: {}, stderr: {}" + .format(" ".join(command), + exit_code, + stderr)) + + # Any output from sqldiff means the databases differ. + return stdout != "" + except OSError as err: + raise err diff --git a/tests/helpers/file_utils.py b/tests/helpers/file_utils.py new file mode 100644 index 0000000..35764cd --- /dev/null +++ b/tests/helpers/file_utils.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- + + +from cStringIO import StringIO +import json +import logging +import logging.config +import os +import random +import shutil +import tempfile +import unicodedata +import zipfile + + +from anki.consts import SYNC_ZIP_SIZE +from anki.utils import checksum + + +class FileUtils(object): + """ + Provides utility methods for creating temporary files and directories. All + created files and dirs are recursively removed when clean_up() is called. + Supports the with statement. + """ + + def __init__(self): + self.paths_to_delete = [] + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + self.clean_up() + + def clean_up(self): + """ + Recursively removes all files and directories created by this instance. + """ + + # Change cwd to a dir we're not about to delete so later calls to + # os.getcwd() and similar functions don't raise Exceptions. + os.chdir("/tmp") + + # Error callback for shutil.rmtree(). + def on_error(func, path, excinfo): + logging.error("Error removing file: func={}, path={}, excinfo={}" + .format(func, path, excinfo)) + + for path in self.paths_to_delete: + if os.path.isfile(path): + logging.debug("Removing temporary file '{}'.".format(path)) + os.remove(path) + elif os.path.isdir(path): + logging.debug(("Removing temporary dir tree '{}' with " + + "files {}").format(path, os.listdir(path))) + shutil.rmtree(path, onerror=on_error) + + self.paths_to_delete = [] + + def mark_for_deletion(self, path): + self.paths_to_delete.append(path) + + def create_file(self, suffix='', prefix='tmp'): + file_descriptor, file_path = tempfile.mkstemp(suffix=suffix, + prefix=prefix) + self.mark_for_deletion(file_path) + return file_path + + def create_dir(self, suffix='', prefix='tmp'): + dir_path = tempfile.mkdtemp(suffix=suffix, + prefix=prefix) + self.mark_for_deletion(dir_path) + return dir_path + + def create_file_path(self, suffix='', prefix='tmp'): + """Generates a file path.""" + + file_path = self.create_file(suffix, prefix) + os.unlink(file_path) + return file_path + + def create_dir_path(self, suffix='', prefix='tmp'): + dir_path = self.create_dir(suffix, prefix) + os.rmdir(dir_path) + return dir_path + + def create_named_file(self, filename, file_contents=None): + """ + Creates a temporary file with a custom name within a new temporary + directory and marks that parent dir for recursive deletion method. + """ + + # We need to create a parent directory for the file so we can freely + # choose the file name . + temp_file_parent_dir = tempfile.mkdtemp(prefix="anki") + self.mark_for_deletion(temp_file_parent_dir) + + file_path = os.path.join(temp_file_parent_dir, filename) + + if file_contents is not None: + open(file_path, 'w').write(file_contents) + + return file_path + + def create_named_file_path(self, filename): + file_path = self.create_named_file(filename) + return file_path + + def create_file_copy(self, path): + basename = os.path.basename(path) + temp_file_path = self.create_named_file_path(basename) + shutil.copyfile(path, temp_file_path) + return temp_file_path + + def create_named_files(self, filenames_and_data): + """ + Creates temporary files within the same new temporary parent directory + and marks that parent for recursive deletion. + + :param filenames_and_data: list of tuples (filename, file contents) + :return: list of paths to the created files + """ + + temp_files_parent_dir = tempfile.mkdtemp(prefix="anki") + self.mark_for_deletion(temp_files_parent_dir) + + file_paths = [] + for filename, file_contents in filenames_and_data: + path = os.path.join(temp_files_parent_dir, filename) + file_paths.append(path) + if file_contents is not None: + open(path, 'w').write(file_contents) + + return file_paths + + @staticmethod + def create_zip_with_existing_files(file_paths): + """ + The method zips existing files and returns the zip data. Logic is + adapted from Anki Desktop's MediaManager.mediaChangesZip(). + + :param file_paths: the paths of the files to include in the zip + :type file_paths: list + :return: the data of the created zip file + """ + + file_buffer = StringIO() + zip_file = zipfile.ZipFile(file_buffer, + 'w', + compression=zipfile.ZIP_DEFLATED) + + meta = [] + sz = 0 + + for count, filePath in enumerate(file_paths): + zip_file.write(filePath, str(count)) + normname = unicodedata.normalize( + "NFC", + os.path.basename(filePath) + ) + meta.append((normname, str(count))) + + sz += os.path.getsize(filePath) + if sz >= SYNC_ZIP_SIZE: + break + + zip_file.writestr("_meta", json.dumps(meta)) + zip_file.close() + + return file_buffer.getvalue() diff --git a/tests/test_users.py b/tests/test_users.py new file mode 100644 index 0000000..76614ab --- /dev/null +++ b/tests/test_users.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + + +import os +import unittest + + +from ankisyncd.users import UserManager +from helpers.file_utils import FileUtils + + +class SimpleUserManagerTest(unittest.TestCase): + _good_test_un = 'username' + _good_test_pw = 'password' + + _bad_test_un = 'notAUsername' + _bad_test_pw = 'notAPassword' + + @classmethod + def setUpClass(cls): + cls.fileutils = FileUtils() + + @classmethod + def tearDownClass(cls): + cls.fileutils.clean_up() + cls.fileutils = None + + def setUp(self): + self.auth_db_path = self.fileutils.create_file_path(suffix='auth.db') + self.collection_path = self.fileutils.create_dir_path() + self.user_manager = UserManager(self.auth_db_path, + self.collection_path) + + def tearDown(self): + self.user_manager = None + + def test_auth_db_exists(self): + self.assertFalse(self.user_manager.auth_db_exists()) + + self.user_manager.create_auth_db() + self.assertTrue(self.user_manager.auth_db_exists()) + + os.unlink(self.auth_db_path) + self.assertFalse(self.user_manager.auth_db_exists()) + + def test_user_list(self): + username = "my_username" + password = "my_password" + self.user_manager.create_auth_db() + + self.assertEqual(self.user_manager.user_list(), []) + + self.user_manager.add_user(username, password) + self.assertEqual(self.user_manager.user_list(), [username]) + + def test_user_exists(self): + username = "my_username" + password = "my_password" + self.user_manager.create_auth_db() + self.user_manager.add_user(username, password) + self.assertTrue(self.user_manager.user_exists(username)) + + self.user_manager.del_user(username) + self.assertFalse(self.user_manager.user_exists(username)) + + def test_del_user(self): + username = "my_username" + password = "my_password" + collection_dir_path = os.path.join(self.collection_path, username) + self.user_manager.create_auth_db() + self.user_manager.add_user(username, password) + self.user_manager.del_user(username) + + # User should be gone. + self.assertFalse(self.user_manager.user_exists(username)) + # User's collection dir should still be there. + self.assertTrue(os.path.isdir(collection_dir_path)) + + def test_add_user(self): + username = "my_username" + password = "my_password" + expected_dir_path = os.path.join(self.collection_path, username) + self.user_manager.create_auth_db() + + self.assertFalse(os.path.exists(expected_dir_path)) + + self.user_manager.add_user(username, password) + + # User db entry and collection dir should be present. + self.assertTrue(self.user_manager.user_exists(username)) + self.assertTrue(os.path.isdir(expected_dir_path)) + + def test_add_users(self): + users_data = [("my_first_username", "my_first_password"), + ("my_second_username", "my_second_password")] + self.user_manager.create_auth_db() + self.user_manager.add_users(users_data) + + user_list = self.user_manager.user_list() + self.assertIn("my_first_username", user_list) + self.assertIn("my_second_username", user_list) + self.assertTrue(os.path.isdir(os.path.join(self.collection_path, + "my_first_username"))) + self.assertTrue(os.path.isdir(os.path.join(self.collection_path, + "my_second_username"))) + + def test__add_user_to_auth_db(self): + username = "my_username" + password = "my_password" + self.user_manager.create_auth_db() + self.user_manager.add_user(username, password) + + self.assertTrue(self.user_manager.user_exists(username)) + + def test_create_auth_db(self): + self.assertFalse(os.path.exists(self.auth_db_path)) + self.user_manager.create_auth_db() + self.assertTrue(os.path.isfile(self.auth_db_path)) + + def test__create_user_dir(self): + username = "my_username" + expected_dir_path = os.path.join(self.collection_path, username) + self.assertFalse(os.path.exists(expected_dir_path)) + self.user_manager._create_user_dir(username) + self.assertTrue(os.path.isdir(expected_dir_path)) + + +if __name__ == '__main__': + unittest.main()