Blob Blame History Raw
from autograd.tracer import primitive, getval
from autograd.extend import defvjp
from autograd.test_util import check_grads
from pytest import raises

def test_check_vjp_1st_order_fail():
    @primitive
    def foo(x):
        return x * 2.0
    defvjp(foo, lambda ans, x : lambda g: g * 2.001)

    with raises(AssertionError, match="\\(VJP\\) check of foo failed"):
         check_grads(foo, modes=['rev'])(1.0)

def test_check_vjp_2nd_order_fail():
    @primitive
    def foo(x):
        return x * 2.0
    defvjp(foo, lambda ans, x : lambda g: bar(g) * 2)

    @primitive
    def bar(x):
        return x
    defvjp(bar, lambda ans, x : lambda g: g * 1.001)

    with raises(AssertionError, match="\\(VJP\\) check of vjp_foo failed"):
         check_grads(foo, modes=['rev'])(1.0)