from __future__ import division
from contextlib import contextmanager
import pytest
import warnings
import autograd.numpy as np
from autograd import grad, deriv
from autograd.extend import primitive
from autograd.test_util import check_grads
from autograd.core import primitive_vjps
def test_assert():
    """Run ``check_grads`` on a function whose body contains an ``assert``.

    Regression test from https://github.com/HIPS/autograd/issues/43
    """
    def checked_sum(values):
        # The assert calls np.allclose on the function's own input.
        round_tripped = (values*3.0)/3.0
        assert np.allclose(values, round_tripped)
        return np.sum(values)

    point = np.array([1.0, 2.0, 3.0])
    check_grads(checked_sum)(point)
def test_nograd():
    """Expect ``grad`` of a function returning ``np.allclose(...)`` to raise.

    We want this to raise a non-differentiability error (``TypeError``).
    """
    def is_round_trip_close(x):
        return np.allclose(x, (x*3.0)/3.0)

    point = np.array([1., 2., 3.])
    # Warnings emitted during the call are recorded rather than displayed.
    with pytest.raises(TypeError), warnings.catch_warnings(record=True):
        grad(is_round_trip_close)(point)
def test_no_vjp_def():
    """Expect ``grad`` of a freshly wrapped primitive to raise NotImplementedError."""
    @primitive
    def doubled(x):
        return 2. * x

    with pytest.raises(NotImplementedError):
        grad(doubled)(1.)
def test_no_jvp_def():
    """Expect ``deriv`` of a freshly wrapped primitive to raise NotImplementedError."""
    @primitive
    def doubled(x):
        return 2. * x

    with pytest.raises(NotImplementedError):
        deriv(doubled)(1.)
def test_falseyness():
    """Check gradients of a function that branches on ``np.iscomplex(x)``.

    Exercised with one real input and one complex input so that both
    branches of the conditional are taken.
    """
    def real_of_branch(x):
        if np.iscomplex(x):
            chosen = x**2
        else:
            chosen = np.sum(x)
        return np.real(chosen)

    for point in (2., 2. + 1j):
        check_grads(real_of_branch)(point)
def test_unimplemented_falseyness():
    """Check gradients through ``np.iscomplex`` with its vjp entry removed.

    The entry for ``np.iscomplex`` is popped from ``primitive_vjps`` for the
    duration of the checks and put back afterwards.
    """
    @contextmanager
    def remove_grad_definitions(fun):
        """Temporarily remove ``fun``'s entry from ``primitive_vjps``.

        The entry (if there was one) is restored on exit, including when the
        body of the ``with`` block raises, so a failing check cannot leave
        the shared registry modified for later tests.
        """
        vjpmaker = primitive_vjps.pop(fun, None)
        try:
            yield
        finally:
            # Restore whatever was popped; None means there was no entry.
            if vjpmaker is not None:
                primitive_vjps[fun] = vjpmaker

    with remove_grad_definitions(np.iscomplex):
        fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x))
        check_grads(fun)(5.)
        check_grads(fun)(2. + 1j)