Blob Blame History Raw
import itertools
from future.utils import with_metaclass
from .util import subvals
from .extend import (Box, primitive, notrace_primitive, VSpace, vspace,
                     SparseObject, defvjp, defvjp_argnum, defjvp, defjvp_argnum)

isinstance_ = isinstance
isinstance = notrace_primitive(isinstance)

type_ = type
type = notrace_primitive(type)

tuple_, list_, dict_ = tuple, list, dict

@primitive
def container_take(A, idx):
    return A[idx]
def grad_container_take(ans, A, idx):
    return lambda g: container_untake(g, idx, vspace(A))
defvjp(container_take, grad_container_take)
defjvp(container_take, 'same')

class SequenceBox(Box):
    __slots__ = []
    __getitem__ = container_take
    def __len__(self): return len(self._value)
    def __add__(self, other): return sequence_extend_right(self, *other)
    def __radd__(self, other): return sequence_extend_left(self, *other)
    def __contains__(self, elt): return elt in self._value
    def index(self, elt): return self._value.index(elt)
SequenceBox.register(tuple_)
SequenceBox.register(list_)

class DictBox(Box):
    __slots__ = []
    __getitem__= container_take
    def __len__(self): return len(self._value)
    def __iter__(self): return self._value.__iter__()
    def __contains__(self, elt): return elt in self._value
    def items(self): return list(self.iteritems())
    def keys(self): return list(self.iterkeys())
    def values(self): return list(self.itervalues())
    def iteritems(self): return ((k, self[k]) for k in self)
    def iterkeys(self): return iter(self)
    def itervalues(self): return (self[k] for k in self)
    def get(self, k, d=None): return self[k] if k in self else d
DictBox.register(dict_)

@primitive
def container_untake(x, idx, vs):
    if isinstance(idx, slice):
        accum = lambda result: [elt_vs._mut_add(a, b)
                                for elt_vs, a, b in zip(vs.shape[idx], result, x)]
    else:
        accum = lambda result: vs.shape[idx]._mut_add(result, x)
    def mut_add(A):
        return vs._subval(A, idx, accum(A[idx]))
    return SparseObject(vs, mut_add)
defvjp(container_untake, lambda ans, x, idx, _:
       lambda g: container_take(g, idx))
defjvp(container_untake, 'same')

@primitive
def sequence_extend_right(seq, *elts):
    return seq + type(seq)(elts)
def grad_sequence_extend_right(argnum, ans, args, kwargs):
    seq, elts = args[0], args[1:]
    return lambda g: g[:len(seq)] if argnum == 0 else g[len(seq) + argnum - 1]
defvjp_argnum(sequence_extend_right, grad_sequence_extend_right)

@primitive
def sequence_extend_left(seq, *elts):
    return type(seq)(elts) + seq
def grad_sequence_extend_left(argnum, ans, args, kwargs):
    seq, elts = args[0], args[1:]
    return lambda g: g[len(elts):] if argnum == 0 else g[argnum - 1]
defvjp_argnum(sequence_extend_left, grad_sequence_extend_left)

@primitive
def make_sequence(seq_type, *args):
    return seq_type(args)
defvjp_argnum(make_sequence, lambda argnum, *args: lambda g: g[argnum - 1])

def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs):
    return container_untake(g, argnum-1, vspace(ans))

defjvp_argnum(make_sequence, fwd_grad_make_sequence)


class TupleMeta(type_):
    def __instancecheck__(self, instance):
        return isinstance(instance, tuple_)
class tuple(with_metaclass(TupleMeta, tuple_)):
    def __new__(cls, xs):
        return make_sequence(tuple_, *xs)

class ListMeta(type_):
    def __instancecheck__(self, instance):
        return isinstance(instance, list_)
class list(with_metaclass(ListMeta, list_)):
    def __new__(cls, xs):
        return make_sequence(list_,  *xs)

class DictMeta(type_):
    def __instancecheck__(self, instance):
        return isinstance(instance, dict_)
class dict(with_metaclass(DictMeta, dict_)):
    def __new__(cls, *args, **kwargs):
        result = dict_(*args, **kwargs)
        if result:
            return _make_dict(result.keys(), list(result.values()))
        return result

@primitive
def _make_dict(keys, vals):
    return dict_(zip(keys, vals))
defvjp(_make_dict,
       lambda ans, keys, vals: lambda g:
       list(g[key] for key in keys), argnums=(1,))

class ContainerVSpace(VSpace):
    def __init__(self, value):
        self.shape = value
        self.shape = self._map(vspace)

    @property
    def size(self): return sum(self._values(self._map(lambda vs: vs.size)))
    def zeros(self): return self._map(lambda vs: vs.zeros())
    def ones(self):  return self._map(lambda vs: vs.ones())
    def randn(self): return self._map(lambda vs: vs.randn())
    def standard_basis(self):
        zero = self.zeros()
        for i, vs in self._kv_pairs(self.shape):
            for x in vs.standard_basis():
                yield self._subval(zero, i, x)
    def _add(self, xs, ys):
        return self._map(lambda vs, x, y: vs._add(x, y), xs, ys)
    def _mut_add(self, xs, ys):
        return self._map(lambda vs, x, y: vs._mut_add(x, y), xs, ys)
    def _scalar_mul(self, xs, a):
        return self._map(lambda vs, x: vs._scalar_mul(x, a), xs)
    def _inner_prod(self, xs, ys):
        return sum(self._values(self._map(lambda vs, x, y: vs._inner_prod(x, y), xs, ys)))
    def _covector(self, xs):
        return self._map(lambda vs, x: vs._covector(x), xs)

class SequenceVSpace(ContainerVSpace):
    def _values(self, x): return x
    def _kv_pairs(self, x): return enumerate(x)
    def _map(self, f, *args):
        return self.seq_type(map(f, self.shape, *args))
    def _subval(self, xs, idx, x):
        return self.seq_type(subvals(xs, [(idx, x)]))

class ListVSpace(SequenceVSpace):  seq_type = list_
class TupleVSpace(SequenceVSpace): seq_type = tuple_
class DictVSpace(ContainerVSpace):
    def _values(self, x):   return x.values()
    def _kv_pairs(self, x): return x.items()
    def _map(self, f, *args):return {k: f(vs, *[x[k] for x in args])
                                     for k, vs in self.shape.items()}
    def _subval(self, xs, idx, x):
        d = dict(xs.items())
        d[idx] = x
        return d

ListVSpace.register(list_)
TupleVSpace.register(tuple_)
DictVSpace.register(dict_)