# Blob Blame History Raw  (repository web-viewer navigation residue; not part of the module)
from itertools import repeat
from autograd.wrap_util import wraps
from autograd.util import subvals, toposort
from autograd.tracer import trace, Node
from functools import partial

class ConstGraphNode(Node):
    """Trace node that remembers how to recompute itself from its parents.

    Each instance stores its parent nodes and a closure, ``partial_fun``,
    which takes a sequence of parent values and calls the recorded function
    with those values substituted at the recorded parent positions.
    """
    __slots__ = ['parents', 'partial_fun']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Record ``fun`` and its call arguments for later re-evaluation.

        ``value`` is accepted as part of the constructor signature but is
        not stored by this node type.
        """
        # Overwrite the parent positions with None so the closure below does
        # not hold on to the values passed at those positions during tracing
        # (assuming subvals returns a substituted copy of args).
        placeholders = zip(parent_argnums, repeat(None))
        template_args = subvals(args, placeholders)

        def replay(parent_values):
            # Fill the parent positions with fresh values, then call fun.
            substitutions = zip(parent_argnums, parent_values)
            call_args = subvals(template_args, substitutions)
            return fun(*call_args, **kwargs)

        self.partial_fun = replay
        self.parents = parents

    def initialize_root(self):
        """Set up a root node, which has no parents."""
        self.parents = []

def const_graph_unary(fun):
    """Wrap a unary function so its computation graph is traced only once.

    The first call traces ``fun`` on the given input, records the resulting
    graph of ``ConstGraphNode`` objects in dependency order (root first,
    output last), and returns the traced value. Every later call replays the
    recorded graph on the new input instead of calling ``fun`` again.

    Args:
        fun: function of a single argument to trace and cache.

    Returns:
        A function of one argument with the behavior described above.

    Raises:
        Exception: on a tracing call, if ``trace`` reports no end node
            (the output does not depend on the input).
    """
    graph = []
    _fun = [fun]  # Allow fun to be freed, since it may have bound args
    def maybe_cached_fun(x):
        if graph:
            _graph = graph[0]
            # Seed the root with the new input, then evaluate each remaining
            # node from the already-computed values of its parents.
            vals = {_graph[0]: x}
            for node in _graph[1:]:
                vals[node] = node.partial_fun([vals[p] for p in node.parents])
            # Look the output node up by position rather than via the loop
            # variable: when the graph holds only the root (fun returned its
            # input unchanged) the loop never runs and ``node`` is unbound.
            return vals[_graph[-1]]

        start_node = ConstGraphNode.new_root()
        end_value, end_node = trace(start_node, _fun[0], x)
        if end_node is None:
            raise Exception("Output is independent of input")
        graph.append(list(toposort(end_node))[::-1])
        # Drop the reference only once the graph is recorded, so a failed
        # trace can be retried instead of raising IndexError on an empty list.
        del _fun[:]
        return end_value
    return maybe_cached_fun

def const_graph(fun, *args, **kwargs):
    """Return a graph-caching version of ``fun`` with leading arguments bound.

    ``args`` and ``kwargs`` are bound to ``fun`` up front. The returned
    function takes the remaining positional arguments, packs them into a
    single tuple, and passes that tuple to the wrapper produced by
    ``const_graph_unary``.
    """
    bound_fun = partial(fun, *args, **kwargs)

    def call_with_tuple(packed_args):
        # Adapter: const_graph_unary expects a function of one argument.
        return bound_fun(*packed_args)

    cached_call = const_graph_unary(call_with_tuple)

    @wraps(fun)
    def _fun(*args):
        return cached_call(args)

    return _fun

class FullGraphNode(Node):
    """Trace node that keeps its value and the full recipe that produced it.

    ``recipe`` is a 4-tuple ``(fun, args, kwargs, parent_pairs)`` where
    ``parent_pairs`` is a list of ``(argnum, parent_node)`` pairs.
    """
    __slots__ = ['value', 'recipe']
    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Store the computed value and the call that produced it."""
        self.value = value
        # Materialise the pairs as a list: on Python 3 a bare zip() is a
        # one-shot iterator, so the parents would appear empty after the
        # first traversal. A list also matches what initialize_root stores.
        self.recipe = (fun, args, kwargs, list(zip(parent_argnums, parents)))

    def initialize_root(self):
        """Set up a root node: no value, an identity function, no parents."""
        self.value = None
        self.recipe = (lambda x: x, (), {}, [])

def full_graph(fun, *args, **kwargs):
    """Trace ``fun(*args, **kwargs)`` and return the node for its output.

    The positional arguments are traced together as a single tuple input,
    starting from a fresh ``FullGraphNode`` root. The traced value is
    discarded; only the end node returned by ``trace`` is handed back.
    """
    def call_with_tuple(packed_args):
        # Adapter: trace drives a function of one argument.
        return fun(*packed_args, **kwargs)

    root = FullGraphNode.new_root()
    _value, end_node = trace(root, call_with_tuple, args)
    return end_node