import warnings
from contextlib import contextmanager
from collections import defaultdict
from .util import subvals, toposort
from .wrap_util import wraps
def trace(start_node, fun, x):
    """Evaluate ``fun`` on ``x`` under a fresh trace level.

    ``x`` is boxed at the new level with ``start_node`` as its node and passed
    to ``fun``. If the result is a box on that same trace level, returns the
    result box's ``(value, node)`` pair. Otherwise it emits a warning that the
    output does not depend on the input and returns ``(result, None)``.
    """
    with trace_stack.new_trace() as level:
        boxed_input = new_box(x, level, start_node)
        result = fun(boxed_input)
        traced = isbox(result) and result._trace == boxed_input._trace
        if not traced:
            warnings.warn("Output seems independent of input.")
            return result, None
        return result._value, result._node
class Node(object):
    """Abstract base class for computation-graph nodes.

    Subclasses must override ``__init__`` (``primitive`` builds nodes by
    calling the node type with ``(value, fun, args, kwargs, parent_argnums,
    parents)``) and ``initialize_root`` (called by ``new_root``).
    """
    __slots__ = []

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        # Raise instead of ``assert False``: assertions are stripped under
        # ``python -O``, which would let this abstract base be instantiated
        # silently.
        raise NotImplementedError("Node subclasses must implement __init__")

    def initialize_root(self, *args, **kwargs):
        raise NotImplementedError(
            "Node subclasses must implement initialize_root")

    @classmethod
    def new_root(cls, *args, **kwargs):
        """Create a root node of type ``cls`` via ``initialize_root``.

        The instance is created with ``__new__``, so ``__init__`` is bypassed
        and all arguments are forwarded to ``initialize_root`` instead.
        """
        root = cls.__new__(cls)
        root.initialize_root(*args, **kwargs)
        return root
def primitive(f_raw):
    """
    Wraps a function so that its gradient can be specified and its invocation
    can be recorded. For examples, see the docs.

    The returned wrapper looks for boxed arguments on the highest trace level
    present among the positional arguments. With none, it calls ``f_raw``
    directly. Otherwise it substitutes those boxes' values into the argument
    list and calls itself again on that list, so boxes from lower trace levels
    are handled by the recursive call. It then records the call in a node
    built by the boxes' node type and returns the result boxed at the same
    trace level. If the wrapper was registered with ``register_notrace`` for
    that node type, it skips node recording and returns the recursive result
    unboxed.
    """
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        boxed_args, trace, node_constructor = find_top_boxed_args(args)
        if not boxed_args:
            # Nothing is being traced through this call.
            return f_raw(*args, **kwargs)

        argvals = subvals(args, [(i, b._value) for i, b in boxed_args])
        if f_wrapped in notrace_primitives[node_constructor]:
            # Exempt from recording at this level: compute on unboxed values.
            return f_wrapped(*argvals, **kwargs)

        argnums, top_boxes = zip(*boxed_args)
        parents = tuple(b._node for b in top_boxes)
        ans = f_wrapped(*argvals, **kwargs)
        node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
        return new_box(ans, trace, node)

    f_wrapped.fun = f_raw
    f_wrapped._is_autograd_primitive = True
    return f_wrapped
# Maps a node type to the set of primitives that ``primitive`` should evaluate
# without recording a graph node when that node type is on top.
notrace_primitives = defaultdict(set)


def register_notrace(trace_type, primitive_fun):
    """Exempt ``primitive_fun`` from node recording for ``trace_type``.

    ``trace_type`` is the registry key; ``primitive`` looks the registry up
    with the node type of the top-level boxed arguments.
    """
    exempt = notrace_primitives[trace_type]
    exempt.add(primitive_fun)
def notrace_primitive(f_raw):
    """Wrap ``f_raw`` so it is always called on fully unboxed positional args.

    Each positional argument is passed through ``getval`` before the call, so
    ``f_raw`` never sees a box there. Keyword arguments are forwarded as-is,
    without unboxing.
    """
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        return f_raw(*[getval(arg) for arg in args], **kwargs)

    f_wrapped._is_primitive = True
    return f_wrapped
def find_top_boxed_args(args):
    """Collect the boxed arguments that live on the highest trace level.

    Returns a triple ``(boxes, trace, node_type)``:
      * ``boxes``: list of ``(argnum, box)`` for every positional argument
        boxed on the highest trace level seen (empty if no argument is a box);
      * ``trace``: that highest level, or -1 if there are no boxes;
      * ``node_type``: ``type(box._node)`` of the first box found at that
        level, or ``None`` if there are no boxes.
    """
    best_trace = -1
    best_boxes = []
    best_node_type = None
    for position, candidate in enumerate(args):
        if not isbox(candidate):
            continue
        level = candidate._trace
        if level > best_trace:
            # A deeper trace supersedes everything collected so far.
            best_boxes = [(position, candidate)]
            best_trace = level
            best_node_type = type(candidate._node)
        elif level == best_trace:
            best_boxes.append((position, candidate))
    return best_boxes, best_trace, best_node_type
class TraceStack(object):
    """Counter tracking the current nesting depth of active traces.

    ``top`` is -1 when no trace is active; each ``new_trace`` context
    increments it on entry and restores it on exit.
    """
    def __init__(self):
        self.top = -1

    @contextmanager
    def new_trace(self):
        """Enter a new trace level and yield its integer id.

        The level is popped in a ``finally`` clause so that an exception raised
        inside the ``with`` body (e.g. by the function being traced) does not
        leave ``top`` permanently incremented for every later trace.
        """
        self.top += 1
        try:
            yield self.top
        finally:
            self.top -= 1


trace_stack = TraceStack()  # module-wide trace-level counter used by ``trace``
class Box(object):
    """Wrapper pairing a traced value with its trace level and graph node.

    Box subclasses call ``register`` to associate themselves with the value
    type they wrap; ``new_box`` uses that registry to pick the box class for a
    value.
    """
    type_mappings = {}  # value type (and each registered box class) -> box class
    types = set()       # every box class that has been registered
    __slots__ = ['_value', '_trace', '_node']

    def __init__(self, value, trace, node):
        self._value = value
        self._node = node
        self._trace = trace

    def __bool__(self):
        # Truthiness delegates to the wrapped value.
        return bool(self._value)

    __nonzero__ = __bool__  # Python 2 spelling of __bool__

    def __str__(self):
        return "Autograd {0} with value {1}".format(
            type(self).__name__, str(self._value))

    @classmethod
    def register(cls, value_type):
        """Make ``cls`` the box class for values whose exact type is ``value_type``.

        ``cls`` is also mapped to itself, so boxing a value that is already a
        ``cls`` box produces another ``cls`` box.
        """
        Box.types.add(cls)
        Box.type_mappings[value_type] = cls
        Box.type_mappings[cls] = cls


box_type_mappings = Box.type_mappings  # module-level alias of the registry


def new_box(value, trace, node):
    """Wrap ``value`` in the box class registered for ``type(value)``.

    Raises:
        TypeError: if no box class is registered for ``type(value)``.
    """
    # Only the registry lookup belongs in the ``try``: a KeyError raised by a
    # box class's own constructor must propagate, not be misreported as an
    # unsupported type.
    try:
        box_class = box_type_mappings[type(value)]
    except KeyError:
        raise TypeError("Can't differentiate w.r.t. type {}".format(type(value)))
    return box_class(value, trace, node)


box_types = Box.types  # module-level alias of the set of registered box classes


def isbox(x):
    """Return True if the exact type of ``x`` is a registered box class."""
    # Exact-type set lookup: almost 3X faster than isinstance(x, Box).
    return type(x) in box_types


def getval(x):
    """Recursively unwrap ``x`` until it is no longer a box."""
    return getval(x._value) if isbox(x) else x