from itertools import repeat
from autograd.wrap_util import wraps
from autograd.util import subvals, toposort
from autograd.tracer import trace, Node
from functools import partial
class ConstGraphNode(Node):
    """Tracer node that can recompute its value from fresh parent values.

    Each instance stores its parent nodes in ``parents`` and a closure in
    ``partial_fun`` that re-applies the recorded function once values for
    the parent argument positions are supplied.
    """

    __slots__ = ['parents', 'partial_fun']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Capture ``fun`` together with its arguments for later replay.

        ``value`` is part of the constructor signature but is not stored.
        ``parent_argnums`` is iterated again on every replay, so it must be
        re-iterable.
        """
        # Pair every parent position with None before storing the arguments.
        # NOTE(review): presumably subvals returns a copy of args with those
        # positions overwritten, so the closure keeps no reference to the
        # traced parent values -- confirm against autograd.util.subvals.
        template = subvals(args, ((argnum, None) for argnum in parent_argnums))

        def replay(parent_values):
            # Fill the parent positions with the supplied values, then call.
            filled = subvals(template, zip(parent_argnums, parent_values))
            return fun(*filled, **kwargs)

        self.partial_fun = replay
        self.parents = parents

    def initialize_root(self):
        """Initialise a root node: it has no parents."""
        self.parents = []
def const_graph_unary(fun):
    """Wrap a one-argument function so its traced graph is built once and replayed.

    The first successful call traces ``fun`` and caches the resulting nodes
    as a list whose first element is treated as the input node and whose
    last element is treated as the output node.  Every later call skips
    ``fun`` entirely and re-evaluates the cached nodes on the new input.

    Args:
        fun: Callable taking a single argument.

    Returns:
        A single-argument callable that traces on its first successful call
        and replays the cached graph afterwards.

    Raises:
        Exception: Raised by the returned callable ("Output is independent
            of input") when tracing produces no end node.  No graph is
            cached in that case, so the next call traces again.
    """
    graph = []
    _fun = [fun]  # Allow fun to be freed, since it may have bound args

    def maybe_cached_fun(x):
        if graph:
            _graph = graph[0]
            vals = {_graph[0]: x}
            for node in _graph[1:]:
                vals[node] = node.partial_fun([vals[p] for p in node.parents])
            # Look the output node up explicitly instead of relying on the
            # loop variable: when the graph consists of the input node only
            # (e.g. an identity function) the loop never runs and the loop
            # variable would be unbound.
            return vals[_graph[-1]]

        start_node = ConstGraphNode.new_root()
        # Keep the reference to fun until tracing has succeeded, so a failed
        # first call can be retried instead of hitting an IndexError from an
        # already-emptied holder.
        end_value, end_node = trace(start_node, _fun[0], x)
        if end_node is None:
            raise Exception("Output is independent of input")
        graph.append(list(toposort(end_node))[::-1])
        del _fun[:]  # Graph is cached; fun is no longer needed
        return end_value

    return maybe_cached_fun
def const_graph(fun, *args, **kwargs):
    """Return a version of ``fun`` that goes through ``const_graph_unary``.

    Positional and keyword arguments given here are pre-bound to ``fun``
    with ``functools.partial``.  The returned function accepts positional
    arguments only; they are packed into one tuple, handed to the wrapper
    produced by ``const_graph_unary``, and unpacked again after the
    pre-bound arguments when ``fun`` is called.
    """
    bound = partial(fun, *args, **kwargs)

    def call_unpacked(arg_tuple):
        # Adapter: const_graph_unary expects a function of one argument.
        return bound(*arg_tuple)

    cached_unary = const_graph_unary(call_unpacked)

    @wraps(fun)
    def _fun(*args):
        return cached_unary(args)

    return _fun
class FullGraphNode(Node):
    """Tracer node that keeps its value and the recipe that produced it.

    ``recipe`` is a 4-tuple ``(fun, args, kwargs, parent_pairs)`` where
    ``parent_pairs`` is a list of ``(argnum, parent_node)`` tuples.
    """

    __slots__ = ['value', 'recipe']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Store the computed value and everything needed to describe the call."""
        self.value = value
        # Materialise the pairs: a bare zip() is a one-shot iterator on
        # Python 3, so the parent list could be read only once.  A list also
        # matches what initialize_root stores.
        self.recipe = (fun, args, kwargs, list(zip(parent_argnums, parents)))

    def initialize_root(self):
        """Initialise a root node: no value, an identity recipe, no parents."""
        self.value = None
        self.recipe = (lambda x: x, (), {}, [])
def full_graph(fun, *args, **kwargs):
    """Trace ``fun`` with ``FullGraphNode`` and return the end node of the trace.

    Only the tuple of positional arguments is handed to ``trace``; keyword
    arguments are forwarded unchanged when ``fun`` is called.  The traced
    value itself is discarded, and whatever ``trace`` reports as the end
    node is returned as-is.
    """
    def call_with_tuple(arg_tuple):
        # Adapter: trace is given a function of one argument.
        return fun(*arg_tuple, **kwargs)

    root = FullGraphNode.new_root()
    _, last_node = trace(root, call_with_tuple, args)
    return last_node