anki-sync-server/tests/helpers/monkey_patches.py
2017-11-04 02:06:42 +01:00

93 lines
3.0 KiB
Python

# -*- coding: utf-8 -*-
import os
import sqlite3 as sqlite
from anki.media import MediaManager
from anki.storage import DB
# Original (unpatched) MediaManager methods, keyed by method name. They are
# filled in by monkeypatch_mediamanager() and put back by
# unpatch_mediamanager(); None means no original is currently saved.
mediamanager_orig_funcs = dict.fromkeys(
    ("findChanges", "mediaChangesZip", "addFilesFromZip", "syncDelete")
)
# Original DB.__init__, saved by monkeypatch_db() and put back by unpatch_db().
db_orig_funcs = dict.fromkeys(("__init__",))
def monkeypatch_mediamanager():
    """
    Monkey patches anki.media.MediaManager's methods so they chdir to
    self.dir() before acting on its media directory and chdir back to the
    original cwd after finishing.

    Each method's unpatched version is stored in ``mediamanager_orig_funcs``
    (under the method's name) immediately before it is replaced, so that
    ``unpatch_mediamanager()`` gets back the real originals rather than
    previously installed wrappers.
    """
    import functools

    def make_cwd_safe(original_func):
        """Return a wrapper running *original_func* with cwd = instance.dir()."""
        @functools.wraps(original_func)
        def wrapper(instance, *args, **kwargs):
            old_cwd = os.getcwd()
            os.chdir(instance.dir())
            try:
                return original_func(instance, *args, **kwargs)
            finally:
                # Restore the cwd even if the wrapped method raises, so a
                # failing call does not leave the process inside the media dir.
                os.chdir(old_cwd)
        return wrapper

    for name in ("findChanges", "mediaChangesZip", "addFilesFromZip",
                 "syncDelete"):
        original = getattr(MediaManager, name)
        # Save this method's original *before* replacing it; saving all of
        # them inside make_cwd_safe would capture already-wrapped methods.
        mediamanager_orig_funcs[name] = original
        setattr(MediaManager, name, make_cwd_safe(original))
def unpatch_mediamanager():
    """
    Undoes monkey patches to Anki's MediaManager.

    Every method with a saved original in ``mediamanager_orig_funcs`` is put
    back on ``MediaManager``, and every saved entry is reset to ``None``.
    Entries that are already ``None`` (nothing saved) leave the corresponding
    method untouched instead of overwriting it with ``None``.
    """
    for name in list(mediamanager_orig_funcs):
        original = mediamanager_orig_funcs[name]
        if original is not None:
            setattr(MediaManager, name, original)
        # Clear every entry (the old code reset "mediaChangesZip" three times
        # and never cleared "addFilesFromZip" or "syncDelete").
        mediamanager_orig_funcs[name] = None
def monkeypatch_db():
    """
    Replace Anki's ``DB.__init__`` with a version whose sqlite connection may
    be used from more than one thread, so that test code can inspect and
    modify the database created by the app.

    The original initializer is saved in ``db_orig_funcs["__init__"]``.
    """
    db_orig_funcs["__init__"] = DB.__init__

    def patched___init__(self, path, text=None, timeout=0):
        # Mirrors Anki's DB.__init__(); the only difference is
        # check_same_thread=False below.
        db_path = path.encode("utf-8") if isinstance(path, str) else path
        # check_same_thread=False lets threads other than the one that
        # opened the connection use it.
        self._db = sqlite.connect(
            db_path, timeout=timeout, check_same_thread=False
        )
        if text:
            self._db.text_factory = text
        self._path = path
        self.echo = os.environ.get("DBECHO")  # echo db modifications
        self.mod = False  # flag that db has been modified?

    DB.__init__ = patched___init__
def unpatch_db():
    """
    Undoes monkey patches to Anki's DB.

    Restores ``DB.__init__`` from ``db_orig_funcs["__init__"]`` and clears
    the saved entry. If nothing is saved (the entry is ``None``),
    ``DB.__init__`` is left as-is instead of being set to ``None``.
    """
    original = db_orig_funcs["__init__"]
    if original is not None:
        DB.__init__ = original
    db_orig_funcs["__init__"] = None