# Blob Blame History Raw  (web-viewer navigation text captured with the file; not part of the module)
from __future__ import absolute_import
import scipy.special
import autograd.numpy as np
from autograd.extend import primitive, defvjp, defjvp
from autograd.numpy.numpy_vjps import unbroadcast_f, repeat_to_match_shape

### Beta function ###
# SciPy's beta-family functions wrapped as autograd primitives.
beta = primitive(scipy.special.beta)
betainc = primitive(scipy.special.betainc)
betaln = primitive(scipy.special.betaln)


# NOTE: `psi` is bound further down in this module; the VJP makers below only
# look it up when a gradient is actually evaluated.

def _beta_vjp_wrt_a(ans, a, b):
    """VJP maker for ``beta`` w.r.t. ``a``: g * ans * (psi(a) - psi(a + b))."""
    def pullback(g):
        return g * ans * (psi(a) - psi(a + b))
    return unbroadcast_f(a, pullback)


def _beta_vjp_wrt_b(ans, a, b):
    """VJP maker for ``beta`` w.r.t. ``b``: g * ans * (psi(b) - psi(a + b))."""
    def pullback(g):
        return g * ans * (psi(b) - psi(a + b))
    return unbroadcast_f(b, pullback)


def _betainc_vjp_wrt_x(ans, a, b, x):
    """VJP maker for ``betainc`` w.r.t. ``x`` (the only argument with a rule)."""
    def pullback(g):
        return g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)
    return unbroadcast_f(x, pullback)


def _betaln_vjp_wrt_a(ans, a, b):
    """VJP maker for ``betaln`` w.r.t. ``a``: g * (psi(a) - psi(a + b))."""
    def pullback(g):
        return g * (psi(a) - psi(a + b))
    return unbroadcast_f(a, pullback)


def _betaln_vjp_wrt_b(ans, a, b):
    """VJP maker for ``betaln`` w.r.t. ``b``: g * (psi(b) - psi(a + b))."""
    def pullback(g):
        return g * (psi(b) - psi(a + b))
    return unbroadcast_f(b, pullback)


defvjp(beta, _beta_vjp_wrt_a, _beta_vjp_wrt_b)
# Only positional argument 2 (x) of betainc gets a gradient rule.
defvjp(betainc, _betainc_vjp_wrt_x, argnums=[2])
defvjp(betaln, _betaln_vjp_wrt_a, _betaln_vjp_wrt_b)

### Gamma functions ###
# SciPy's gamma-family functions wrapped as autograd primitives.
polygamma = primitive(scipy.special.polygamma)
psi = primitive(scipy.special.psi)          # psi(x) is just polygamma(0, x)
digamma = primitive(scipy.special.digamma)  # digamma is another name for psi.
gamma = primitive(scipy.special.gamma)
gammaln = primitive(scipy.special.gammaln)
gammainc = primitive(scipy.special.gammainc)
gammaincc = primitive(scipy.special.gammaincc)
gammasgn = primitive(scipy.special.gammasgn)
rgamma = primitive(scipy.special.rgamma)
multigammaln = primitive(scipy.special.multigammaln)


def _polygamma_vjp_wrt_x(ans, n, x):
    """VJP maker for ``polygamma(n, x)`` w.r.t. ``x``: g * polygamma(n + 1, x)."""
    return lambda g: g * polygamma(n + 1, x)


def _psi_vjp(ans, x):
    """VJP maker shared by ``psi`` and ``digamma``: g * polygamma(1, x)."""
    return lambda g: g * polygamma(1, x)


def _gamma_vjp(ans, x):
    """VJP maker for ``gamma``: g * ans * psi(x)."""
    return lambda g: g * ans * psi(x)


def _gammaln_vjp(ans, x):
    """VJP maker for ``gammaln``: g * psi(x)."""
    return lambda g: g * psi(x)


def _rgamma_vjp(ans, x):
    """VJP maker for ``rgamma``: g * psi(x) / -gamma(x)."""
    return lambda g: g * psi(x) / -gamma(x)


def _multigammaln_vjp_wrt_a(ans, a, d):
    """VJP maker for ``multigammaln(a, d)`` w.r.t. ``a``.

    Sums digamma over the ``d`` shifted arguments ``a - i/2`` for
    ``i = 0 .. d-1`` along a trailing axis added to ``a``.
    """
    def pullback(g):
        shifted = np.expand_dims(a, -1) - np.arange(d)/2.
        return g * np.sum(digamma(shifted), -1)
    return pullback


# `None` marks an argument (or the whole function) as having no gradient rule.
defvjp(gammasgn, None)
defvjp(polygamma, None, _polygamma_vjp_wrt_x)
defvjp(psi, _psi_vjp)
defvjp(digamma, _psi_vjp)
defvjp(gamma, _gamma_vjp)
defvjp(gammaln, _gammaln_vjp)
defvjp(rgamma, _rgamma_vjp)
defvjp(multigammaln, _multigammaln_vjp_wrt_a, None)

def make_gammainc_vjp_arg1(sign):
    """Build the VJP maker for argument 1 (``x``) of gammainc / gammaincc.

    ``sign`` scales the result: it is passed as ``1`` for ``gammainc`` and
    ``-1`` for ``gammaincc`` in the registrations below.
    """
    def vjp_maker(ans, a, x):
        # Computed once per VJP construction, before any cotangent arrives.
        scale = sign * np.exp(-x) * np.power(x, a - 1) / gamma(a)

        def pullback(g):
            return g * scale
        return unbroadcast_f(x, pullback)
    return vjp_maker


# Only positional argument 1 (x) gets a gradient rule.
defvjp(gammainc, make_gammainc_vjp_arg1(1), argnums=[1])
defvjp(gammaincc, make_gammainc_vjp_arg1(-1), argnums=[1])

### Bessel functions ###

# First-kind (j*) and second-kind (y*) wrappers, grouped by family.
j0 = primitive(scipy.special.j0)
j1 = primitive(scipy.special.j1)
jn = primitive(scipy.special.jn)
y0 = primitive(scipy.special.y0)
y1 = primitive(scipy.special.y1)
yn = primitive(scipy.special.yn)


def _j0_vjp(ans, x):
    """VJP maker for ``j0``: -g * j1(x)."""
    return lambda g: -g * j1(x)


def _y0_vjp(ans, x):
    """VJP maker for ``y0``: -g * y1(x)."""
    return lambda g: -g * y1(x)


def _j1_vjp(ans, x):
    """VJP maker for ``j1``: g * (j0(x) - jn(2, x)) / 2."""
    return lambda g: g * (j0(x) - jn(2, x)) / 2.0


def _y1_vjp(ans, x):
    """VJP maker for ``y1``: g * (y0(x) - yn(2, x)) / 2."""
    return lambda g: g * (y0(x) - yn(2, x)) / 2.0


def _jn_vjp_wrt_x(ans, n, x):
    """VJP maker for ``jn(n, x)`` w.r.t. ``x``: g * (jn(n-1, x) - jn(n+1, x)) / 2."""
    return lambda g: g * (jn(n - 1, x) - jn(n + 1, x)) / 2.0


def _yn_vjp_wrt_x(ans, n, x):
    """VJP maker for ``yn(n, x)`` w.r.t. ``x``: g * (yn(n-1, x) - yn(n+1, x)) / 2."""
    return lambda g: g * (yn(n - 1, x) - yn(n + 1, x)) / 2.0


defvjp(j0, _j0_vjp)
defvjp(y0, _y0_vjp)
defvjp(j1, _j1_vjp)
defvjp(y1, _y1_vjp)
# No gradient rule for the order argument `n` (first slot is None).
defvjp(jn, None, _jn_vjp_wrt_x)
defvjp(yn, None, _yn_vjp_wrt_x)


### Faster versions of common Bessel functions ###
# i0, i1, iv and ive wrap SciPy's modified Bessel functions of the first kind
# (ive being the exponentially scaled variant).
i0 = primitive(scipy.special.i0)
i1 = primitive(scipy.special.i1)
iv = primitive(scipy.special.iv)
ive = primitive(scipy.special.ive)


def _i0_vjp(ans, x):
    """VJP maker for ``i0``: g * i1(x)."""
    return lambda g: g * i1(x)


def _i1_vjp(ans, x):
    """VJP maker for ``i1``: g * (i0(x) + iv(2, x)) / 2."""
    return lambda g: g * (i0(x) + iv(2, x)) / 2.0


def _iv_vjp_wrt_x(ans, n, x):
    """VJP maker for ``iv(n, x)`` w.r.t. ``x``: g * (iv(n-1, x) + iv(n+1, x)) / 2."""
    return lambda g: g * (iv(n - 1, x) + iv(n + 1, x)) / 2.0


def _ive_vjp_wrt_x(ans, n, x):
    """VJP maker for ``ive(n, x)`` w.r.t. ``x``.

    Uses the forward value ``ans``: g * (ans * (n / x - sign(x)) + ive(n+1, x)).
    """
    return lambda g: g * (ans * (n / x - np.sign(x)) + ive(n + 1, x))


defvjp(i0, _i0_vjp)
defvjp(i1, _i1_vjp)
# No gradient rule for the order argument `n` (first slot is None).
defvjp(iv, None, _iv_vjp_wrt_x)
defvjp(ive, None, _ive_vjp_wrt_x)

### Error Function ###
inv_root_pi = 0.56418958354775627928  # 1 / sqrt(pi)
erf = primitive(scipy.special.erf)
erfc = primitive(scipy.special.erfc)


def _erf_vjp(ans, x):
    """VJP maker for ``erf``: 2 * g * inv_root_pi * exp(-x**2)."""
    return lambda g: 2.*g*inv_root_pi*np.exp(-x**2)


def _erfc_vjp(ans, x):
    """VJP maker for ``erfc``: same as ``erf`` with the opposite sign."""
    return lambda g: -2.*g*inv_root_pi*np.exp(-x**2)


defvjp(erf, _erf_vjp)
defvjp(erfc, _erfc_vjp)


### Inverse error function ###
root_pi = 1.7724538509055159  # sqrt(pi)
erfinv = primitive(scipy.special.erfinv)
erfcinv = primitive(scipy.special.erfcinv)

# The VJP maker receives the forward result as `ans`, which already equals
# erfinv(x) / erfcinv(x). Reusing it avoids evaluating the inverse function a
# second time on every backward pass (the same way expit and gamma use `ans`).
defvjp(erfinv,  lambda ans, x: lambda g:  g * root_pi / 2 * np.exp(ans**2))
defvjp(erfcinv, lambda ans, x: lambda g: -g * root_pi / 2 * np.exp(ans**2))

### Logit and Expit ###
logit = primitive(scipy.special.logit)
expit = primitive(scipy.special.expit)


def _logit_vjp(ans, x):
    """VJP maker for ``logit``: g / (x * (1 - x))."""
    return lambda g: g / ( x * (1 - x))


def _expit_vjp(ans, x):
    """VJP maker for ``expit``, written in terms of the output: g * ans * (1 - ans)."""
    return lambda g: g * ans * (1 - ans)


defvjp(logit, _logit_vjp)
defvjp(expit, _expit_vjp)

### logsumexp ###
logsumexp = primitive(scipy.special.logsumexp)


def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Build the reverse-mode rule (VJP) of ``logsumexp`` with respect to ``x``.

    ``ans`` is the forward result; ``axis``, ``b`` and ``keepdims`` are the
    keyword arguments of the forward call. The returned function maps an
    output cotangent ``g`` to ``g * b * exp(x - ans)``, with ``g`` and ``ans``
    first repeated back to the shape of ``x``.
    """
    x_shape = np.shape(x)
    x_dtype = np.result_type(x)

    def vjp(g):
        # repeat_to_match_shape returns a pair; only its first element is used.
        g_full, _g_reps = repeat_to_match_shape(g, x_shape, x_dtype, axis, keepdims)
        ans_full, _ans_reps = repeat_to_match_shape(ans, x_shape, x_dtype, axis, keepdims)
        return g_full * b * np.exp(x - ans_full)

    return vjp


defvjp(logsumexp, make_grad_logsumexp)

def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """Forward-mode rule (JVP) of ``logsumexp`` with respect to ``x``.

    Args:
        g: Tangent of ``x`` (broadcast-compatible with ``x``).
        ans: Forward result ``logsumexp(x, axis=axis, b=b, keepdims=keepdims)``.
        x: Input of the forward call.
        axis: ``None``, an integer, or a tuple of integers; negative values
            and NumPy integer scalars are accepted.
        b: Scaling factor applied to ``exp(x)`` in the forward call.
        keepdims: Whether the forward call kept the reduced axes.

    Returns:
        ``sum(g * b * exp(x - ans))`` reduced over ``axis`` with the same
        ``keepdims`` setting as the forward call.
    """
    if not keepdims and axis is not None:
        # The forward call dropped the reduced axes; put them back as
        # length-1 axes so `ans` broadcasts against `x`.
        ndim = np.ndim(x)
        axes = axis if isinstance(axis, tuple) else (axis,)
        # Convert negative axes to their non-negative positions first:
        # inserting in ascending order is only correct for non-negative
        # positions (e.g. axis=(-2, -1) must become (1, 2) for a 3-D input).
        positions = sorted(int(ax) + ndim if ax < 0 else int(ax) for ax in axes)
        for pos in positions:
            ans = np.expand_dims(ans, pos)
    return np.sum(g * b * np.exp(x - ans), axis=axis, keepdims=keepdims)

# Register fwd_grad_logsumexp as the forward-mode (JVP) rule of the logsumexp primitive.
defjvp(logsumexp, fwd_grad_logsumexp)