anki-sync-server/AnkiServer/apps/rest_app.py
2013-07-24 17:15:59 +01:00

662 lines
20 KiB
Python

from webob.dec import wsgify
from webob.exc import *
from webob import Response
#from pprint import pprint
try:
import simplejson as json
from simplejson import JSONDecodeError
except ImportError:
import json
JSONDecodeError = ValueError
import os, logging
import anki.lang
from anki.lang import _ as t
__all__ = ['RestApp', 'RestHandlerBase', 'noReturnValue']
def noReturnValue(func):
func.hasReturnValue = False
return func
class RestHandlerBase(object):
"""Parent class for a handler group."""
hasReturnValue = True
class _RestHandlerWrapper(RestHandlerBase):
"""Wrapper for functions that we can't modify."""
def __init__(self, func_name, func, hasReturnValue=True):
self.func_name = func_name
self.func = func
self.hasReturnValue = hasReturnValue
def __call__(self, *args, **kw):
return self.func(*args, **kw)
class RestHandlerRequest(object):
def __init__(self, app, data, ids, session):
self.app = app
self.data = data
self.ids = ids
self.session = session
def copy(self):
return RestHandlerRequest(self.app, self.data.copy(), self.ids[:], self.session)
def __eq__(self, other):
return self.app == other.app and self.data == other.data and self.ids == other.ids and self.session == other.session
class RestApp(object):
"""A WSGI app that implements RESTful operations on Collections, Decks and Cards."""
# Defines not only the valid handler types, but their position in the URL string
handler_types = ['collection', ['model', 'note', 'deck', 'card']]
def __init__(self, data_root, allowed_hosts='*', setup_new_collection=None, use_default_handlers=True, collection_manager=None):
from AnkiServer.threading import getCollectionManager
self.data_root = os.path.abspath(data_root)
self.allowed_hosts = allowed_hosts
self.setup_new_collection = setup_new_collection
if collection_manager is not None:
self.collection_manager = collection_manager
else:
self.collection_manager = getCollectionManager()
self.handlers = {}
for type_list in self.handler_types:
if type(type_list) is not list:
type_list = [type_list]
for handler_type in type_list:
self.handlers[handler_type] = {}
if use_default_handlers:
self.add_handler_group('collection', CollectionHandler())
self.add_handler_group('note', NoteHandler())
self.add_handler_group('model', ModelHandler())
self.add_handler_group('deck', DeckHandler())
self.add_handler_group('card', CardHandler())
# hold per collection session data
self.sessions = {}
def add_handler(self, type, name, handler):
"""Adds a callback handler for a type (collection, deck, card) with a unique name.
- 'type' is the item that will be worked on, for example: collection, deck, and card.
- 'name' is a unique name for the handler that gets used in the URL.
- 'handler' is a callable that takes (collection, data, ids).
"""
if self.handlers[type].has_key(name):
raise "Handler already for %(type)s/%(name)s exists!"
self.handlers[type][name] = handler
def add_handler_group(self, type, group):
"""Adds several handlers for every public method on an object descended from RestHandlerBase.
This allows you to create a single class with several methods, so that you can quickly
create a group of related handlers."""
import inspect
for name, method in inspect.getmembers(group, predicate=inspect.ismethod):
if not name.startswith('_'):
if hasattr(group, 'hasReturnValue') and not hasattr(method, 'hasReturnValue'):
method = _RestHandlerWrapper(group.__class__.__name__ + '.' + name, method, group.hasReturnValue)
self.add_handler(type, name, method)
def execute_handler(self, type, name, col, req):
"""Executes the handler with the given type and name, passing in the col and req as arguments."""
handler, hasReturnValue = self._getHandler(type, name)
ret = handler(col, req)
if hasReturnValue:
return ret
def _checkRequest(self, req):
"""Raises an exception if the request isn't allowed or valid for some reason."""
if self.allowed_hosts != '*':
try:
remote_addr = req.headers['X-Forwarded-For']
except KeyError:
remote_addr = req.remote_addr
if remote_addr != self.allowed_hosts:
raise HTTPForbidden()
if req.method != 'POST':
raise HTTPMethodNotAllowed(allow=['POST'])
def _parsePath(self, path):
"""Takes a request path and returns a tuple containing the handler type, name
and a list of ids.
Raises an HTTPNotFound exception if the path is invalid."""
if path in ('', '/'):
raise HTTPNotFound()
# split the URL into a list of parts
if path[0] == '/':
path = path[1:]
parts = path.split('/')
# pull the type and context from the URL parts
handler_type = None
ids = []
for type_list in self.handler_types:
if len(parts) == 0:
break
# some URL positions can have multiple types
if type(type_list) is not list:
type_list = [type_list]
# get the handler_type
if parts[0] not in type_list:
break
handler_type = parts.pop(0)
# add the id to the id list
if len(parts) > 0:
ids.append(parts.pop(0))
# break if we don't have enough parts to make a new type/id pair
if len(parts) < 2:
break
# sanity check to make sure the URL is valid
if len(parts) > 1 or len(ids) == 0:
raise HTTPNotFound()
# get the handler name
if len(parts) == 0:
name = 'index'
else:
name = parts[0]
return (handler_type, name, ids)
def _getCollectionPath(self, collection_id):
"""Returns the path to the collection based on the collection_id from the request.
Raises HTTPBadRequest if the collection_id is invalid."""
path = os.path.normpath(os.path.join(self.data_root, collection_id, 'collection.anki2'))
if path[0:len(self.data_root)] != self.data_root:
# attempting to escape our data jail!
raise HTTPBadRequest('"%s" is not a valid collection' % collection_id)
return path
def _getHandler(self, type, name):
"""Returns a tuple containing handler function for this type and name, and a boolean flag
if that handler has a return value.
Raises an HTTPNotFound exception if the handler doesn't exist."""
# get the handler function
try:
handler = self.handlers[type][name]
except KeyError:
raise HTTPNotFound()
# get if we have a return value
hasReturnValue = True
if hasattr(handler, 'hasReturnValue'):
hasReturnValue = handler.hasReturnValue
return (handler, hasReturnValue)
def _parseRequestBody(self, req):
"""Parses the request body (JSON) into a Python dict and returns it.
Raises an HTTPBadRequest exception if the request isn't valid JSON."""
try:
data = json.loads(req.body)
except JSONDecodeError, e:
logging.error(req.path+': Unable to parse JSON: '+str(e), exc_info=True)
raise HTTPBadRequest()
# make the keys into non-unicode strings
data = dict([(str(k), v) for k, v in data.items()])
return data
@wsgify
def __call__(self, req):
# make sure the request is valid
self._checkRequest(req)
# parse the path
type, name, ids = self._parsePath(req.path)
# get the collection path
collection_path = self._getCollectionPath(ids[0])
print collection_path
# get the handler function
handler, hasReturnValue = self._getHandler(type, name)
# parse the request body
data = self._parseRequestBody(req)
# get the users session
try:
session = self.sessions[ids[0]]
except KeyError:
session = self.sessions[ids[0]] = {}
# debug
from pprint import pprint
pprint(data)
# run it!
col = self.collection_manager.get_collection(collection_path, self.setup_new_collection)
handler_request = RestHandlerRequest(self, data, ids, session)
try:
output = col.execute(handler, [handler_request], {}, hasReturnValue)
except Exception, e:
logging.error(e)
return HTTPInternalServerError()
if output is None:
return Response('', content_type='text/plain')
else:
return Response(json.dumps(output), content_type='application/json')
class CollectionHandler(RestHandlerBase):
"""Default handler group for 'collection' type."""
#
# MODELS - Store fields definitions and templates for notes
#
def list_models(self, col, req):
# This is already a list of dicts, so it doesn't need to be serialized
return col.models.all()
def find_model_by_name(self, col, req):
# This is already a list of dicts, so it doesn't need to be serialized
return col.models.byName(req.data['model'])
#
# NOTES - Information (in fields per the model) that can generate a card
# (based on a template from the model).
#
def find_notes(self, col, req):
query = req.data.get('query', '')
ids = col.findNotes(query)
if req.data.get('preload', False):
nodes = [NoteHandler._serialize(col.getNote(id)) for id in ids]
else:
nodes = [{'id': id} for id in ids]
return nodes
@noReturnValue
def add_note(self, col, req):
from anki.notes import Note
# TODO: I think this would be better with 'model' for the name
# and 'mid' for the model id.
if type(req.data['model']) in (str, unicode):
model = col.models.byName(req.data['model'])
else:
model = col.models.get(req.data['model'])
note = Note(col, model)
for name, value in req.data['fields'].items():
note[name] = value
if req.data.has_key('tags'):
note.setTagsFromStr(req.data['tags'])
col.addNote(note)
#
# DECKS - Groups of cards
#
def list_decks(self, col, req):
# This is already a list of dicts, so it doesn't need to be serialized
return col.decks.all()
@noReturnValue
def select_deck(self, col, req):
col.decks.select(req.data['deck_id'])
#
# CARD - A specific card in a deck with a history of review (generated from
# a note based on the template).
#
def find_cards(self, col, req):
query = req.data.get('query', '')
ids = col.findCards(query)
if req.data.get('preload', False):
cards = [CardHandler._serialize(col.getCard(id), req.data) for id in req.ids]
else:
cards = [{'id': id} for id in req.ids]
return cards
#
# SCHEDULER - Controls card review, ie. intervals, what cards are due, answering a card, etc.
#
def reset_scheduler(self, col, req):
if req.data.has_key('deck'):
deck = DeckHandler._get_deck(col, req.data['deck'])
col.decks.select(deck['id'])
col.sched.reset()
counts = col.sched.counts()
return {
'new_cards': counts[0],
'learning_cards': counts[1],
'review_cards': counts[1],
}
def extend_scheduler_limits(self, col, req):
new_cards = int(req.data.get('new_cards', 0))
review_cards = int(req.data.get('review_cards', 0))
col.sched.extendLimits(new_cards, review_cards)
col.sched.reset()
button_labels = ['Easy', 'Good', 'Hard']
def _get_answer_buttons(self, col, card):
l = []
# Put the correct number of buttons
cnt = col.sched.answerButtons(card)
for idx in range(0, cnt - 1):
l.append(self.button_labels[idx])
l.append('Again')
l.reverse()
# Loop through and add the ease, estimated time (in seconds) and other info
return [{
'ease': ease,
'label': label,
'string_label': t(label),
'interval': col.sched.nextIvl(card, ease),
'string_interval': col.sched.nextIvlStr(card, ease),
} for ease, label in enumerate(l, 1)]
def next_card(self, col, req):
if req.data.has_key('deck'):
deck = DeckHandler._get_deck(col, req.data['deck'])
col.decks.select(deck['id'])
card = col.sched.getCard()
if card is None:
return None
# put it into the card cache to be removed when we answer it
#if not req.session.has_key('cards'):
# req.session['cards'] = {}
#req.session['cards'][long(card.id)] = card
card.startTimer()
result = CardHandler._serialize(card, req.data)
result['answer_buttons'] = self._get_answer_buttons(col, card)
return result
@noReturnValue
def answer_card(self, col, req):
import time
card_id = long(req.data['id'])
ease = int(req.data['ease'])
card = col.getCard(card_id)
if card.timerStarted is None:
card.timerStarted = float(req.data.get('timeStarted', time.time()))
col.sched.answerCard(card, ease)
@noReturnValue
def suspend_cards(self, col, req):
card_ids = req.data['ids']
col.sched.suspendCards(card_ids)
@noReturnValue
def unsuspend_cards(self, col, req):
card_ids = req.data['ids']
col.sched.unsuspendCards(card_ids)
#
# GLOBAL / MISC
#
@noReturnValue
def set_language(self, col, req):
anki.lang.setLang(req.data['code'])
class ImportExportHandler(RestHandlerBase):
"""Handler group for the 'collection' type, but it's not added by default."""
def _get_filedata(self, data):
import urllib2
if data.has_key('data'):
return data['data']
fd = None
try:
fd = urllib2.urlopen(data['url'])
filedata = fd.read()
finally:
if fd is not None:
fd.close()
return filedata
def _get_importer_class(self, data):
filetype = data['filetype']
from AnkiServer.importer import get_importer_class
importer_class = get_importer_class(filetype)
if importer_class is None:
raise HTTPBadRequest("Unknown filetype '%s'" % filetype)
return importer_class
def import_file(self, col, req):
import AnkiServer.importer
import tempfile
# get the importer class
importer_class = self._get_importer_class(req.data)
# get the file data
filedata = self._get_filedata(req.data)
# write the file data to a temporary file
try:
path = None
with tempfile.NamedTemporaryFile('wt', delete=False) as fd:
path = fd.name
fd.write(filedata)
AnkiServer.importer.import_file(importer_class, col, path)
finally:
if path is not None:
os.unlink(path)
class ModelHandler(RestHandlerBase):
"""Default handler group for 'model' type."""
def field_names(self, col, req):
model = col.models.get(req.ids[1])
if model is None:
raise HTTPNotFound()
return col.models.fieldNames(model)
class NoteHandler(RestHandlerBase):
"""Default handler group for 'note' type."""
@staticmethod
def _serialize(note):
d = {
'id': note.id,
'guid': note.guid,
'model': note.model(),
'mid': note.mid,
'mod': note.mod,
'scm': note.scm,
'tags': note.tags,
'string_tags': ' '.join(note.tags),
'fields': {},
'flags': note.flags,
'usn': note.usn,
}
# add all the fields
for name, value in note.items():
d['fields'][name] = value
return d
def index(self, col, req):
note = col.getNote(req.ids[1])
return self._serialize(note)
@noReturnValue
def add_tags(self, col, req):
note = col.getNote(req.ids[1])
for tag in req.data['tags']:
note.addTag(tag)
note.flush()
@noReturnValue
def remove_tags(self, col, req):
note = col.getNote(req.ids[1])
for tag in req.data['tags']:
note.delTag(tag)
note.flush()
class DeckHandler(RestHandlerBase):
"""Default handler group for 'deck' type."""
@staticmethod
def _get_deck(col, val):
try:
did = long(val)
deck = col.decks.get(did, False)
except ValueError:
deck = col.decks.byName(val)
if deck is None:
raise HTTPNotFound('No deck with id or name: ' + str(val))
return deck
def index(self, col, req):
return self._get_deck(col, req.ids[1])
def next_card(self, col, req):
req_copy = req.copy()
req_copy.data['deck'] = req.ids[1]
del req_copy.ids[1]
# forward this to the CollectionHandler
return req.app.execute_handler('collection', 'next_card', col, req_copy)
def get_conf(self, col, req):
# TODO: should probably live in a ConfHandler
return col.decks.confForDid(req.ids[1])
@noReturnValue
def set_update_conf(self, col, req):
data = req.data.copy()
del data['id']
conf = col.decks.confForDid(req.ids[1])
conf = conf.copy()
conf.update(data)
col.decks.updateConf(conf)
class CardHandler(RestHandlerBase):
"""Default handler group for 'card' type."""
@staticmethod
def _serialize(card, opts):
d = {
'id': card.id,
'isEmpty': card.isEmpty(),
'question': card.q(),
'answer': card.a(),
'did': card.did,
'due': card.due,
'factor': card.factor,
'ivl': card.ivl,
'lapses': card.lapses,
'left': card.left,
'mod': card.mod,
'nid': card.nid,
'odid': card.odid,
'odue': card.odue,
'ord': card.ord,
'queue': card.queue,
'reps': card.reps,
'type': card.type,
'usn': card.usn,
'timerStarted': card.timerStarted,
}
if opts.get('load_note', False):
d['note'] = NoteHandler._serialize(card.col.getNote(card.nid))
if opts.get('load_deck', False):
d['deck'] = card.col.decks.get(card.did)
return d
def index(self, col, req):
card = col.getCard(req.ids[1])
return self._serialize(card, req.data)
def _forward_to_note(self, card_id, name):
card = col.getCard(card_id)
req_copy = req.copy()
req_copy.ids[1] = card.nid
return self.app.execute_handler('note', name, col, req)
@noReturnValue
def add_tags(self, col, req):
self._forward_to_note(req.ids[1], 'add_tags')
@noReturnValue
def remove_tags(self, col, req):
self._forward_to_note(req.ids[1], 'remove_tags')
# Our entry point
def make_app(global_conf, **local_conf):
# TODO: we should setup the default language from conf!
# setup the logger
from AnkiServer.utils import setup_logging
setup_logging(local_conf.get('logging.config_file'))
return RestApp(
data_root=local_conf.get('data_root', '.'),
allowed_hosts=local_conf.get('allowed_hosts', '*')
)