import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.test_util import scalar_close
from autograd import make_vjp, grad
from autograd.tracer import primitive
from autograd.misc import const_graph, flatten
def test_const_graph():
L = []
def foo(x, y):
L.append(None)
return grad(lambda x: np.sin(x) + x * 2)(x * y)
foo_wrapped = const_graph(foo)
assert len(L) == 0
assert scalar_close(foo(0., 0.),
foo_wrapped(0., 0.))
assert len(L) == 2
assert scalar_close(foo(1., 0.5),
foo_wrapped(1., 0.5))
assert len(L) == 3
assert scalar_close(foo(1., 0.5),
foo_wrapped(1., 0.5))
assert len(L) == 4
def test_const_graph_args():
L = []
@primitive
def process(var, varname):
L.append(varname)
return var
def foo(x, y, z):
x = process(x, 'x')
y = process(y, 'y')
z = process(z, 'z')
return x + 2*y + 3*z
foo_wrapped = const_graph(foo, 1., z=3.)
assert L == []
assert scalar_close(foo(1., 2., 3.),
foo_wrapped(2.))
assert L == ['x', 'y', 'z', 'x', 'y', 'z']
L = []
assert scalar_close(foo(1., 2., 3.),
foo_wrapped(2.))
assert L == ['x', 'y', 'z', 'y']
L = []
assert scalar_close(foo(1., 2., 3.),
foo_wrapped(2.))
assert L == ['x', 'y', 'z', 'y']
def test_flatten():
r = np.random.randn
x = (1.0, r(2,3), [r(1,4), {'x': 2.0, 'y': r(4,2)}])
x_flat, unflatten = flatten(x)
assert x_flat.shape == (20,)
assert x_flat[0] == 1.0
assert np.all(x_flat == flatten(unflatten(x_flat))[0])
y = (1.0, 2.0, [3.0, {'x': 2.0, 'y': 4.0}])
y_flat, unflatten = flatten(y)
assert y_flat.shape == (5,)
assert y == unflatten(y_flat)
def test_flatten_empty():
val = (npr.randn(4), [npr.randn(3,4), 2.5], (), (2.0, [1.0, npr.randn(2)]))
vect, unflatten = flatten(val)
val_recovered = unflatten(vect)
vect_2, _ = flatten(val_recovered)
assert np.all(vect == vect_2)
def test_flatten_dict():
val = {'k': npr.random((4, 4)),
'k2': npr.random((3, 3)),
'k3': 3.0,
'k4': [1.0, 4.0, 7.0, 9.0]}
vect, unflatten = flatten(val)
val_recovered = unflatten(vect)
vect_2, _ = flatten(val_recovered)
assert np.all(vect == vect_2)
def unflatten_tracing():
val = [npr.randn(4), [npr.randn(3,4), 2.5], (), (2.0, [1.0, npr.randn(2)])]
vect, unflatten = flatten(val)
def f(vect): return unflatten(vect)
flatten2, _ = make_vjp(f)(vect)
assert np.all(vect == flatten2(val))
def test_flatten_nodes_in_containers():
# see issue #232
def f(x, y):
xy, _ = flatten([x, y])
return np.sum(xy)
grad(f)(1.0, 2.0)
def test_flatten_complex():
val = 1 + 1j
flat, unflatten = flatten(val)
assert np.all(val == unflatten(flat))