from __future__ import absolute_import
import scipy.special
import autograd.numpy as np
from autograd.extend import primitive, defvjp, defjvp
from autograd.numpy.numpy_vjps import unbroadcast_f, repeat_to_match_shape
### Beta function ###
beta = primitive(scipy.special.beta)
betainc = primitive(scipy.special.betainc)
betaln = primitive(scipy.special.betaln)


def _beta_vjp_a(ans, a, b):
    """VJP maker for ``beta`` w.r.t. ``a``: ans * (psi(a) - psi(a + b))."""
    return unbroadcast_f(a, lambda g: g * ans * (psi(a) - psi(a + b)))


def _beta_vjp_b(ans, a, b):
    """VJP maker for ``beta`` w.r.t. ``b``: ans * (psi(b) - psi(a + b))."""
    return unbroadcast_f(b, lambda g: g * ans * (psi(b) - psi(a + b)))


def _betainc_vjp_x(ans, a, b, x):
    """VJP maker for ``betainc`` w.r.t. ``x`` (the only argument registered)."""
    def vjp(g):
        # x**(a - 1) * (1 - x)**(b - 1) / beta(a, b)
        return g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)
    return unbroadcast_f(x, vjp)


def _betaln_vjp_a(ans, a, b):
    """VJP maker for ``betaln`` w.r.t. ``a``: psi(a) - psi(a + b)."""
    return unbroadcast_f(a, lambda g: g * (psi(a) - psi(a + b)))


def _betaln_vjp_b(ans, a, b):
    """VJP maker for ``betaln`` w.r.t. ``b``: psi(b) - psi(a + b)."""
    return unbroadcast_f(b, lambda g: g * (psi(b) - psi(a + b)))


defvjp(beta, _beta_vjp_a, _beta_vjp_b)
defvjp(betainc, _betainc_vjp_x, argnums=[2])
defvjp(betaln, _betaln_vjp_a, _betaln_vjp_b)
### Gamma functions ###
polygamma = primitive(scipy.special.polygamma)
psi = primitive(scipy.special.psi)  # psi(x) is just polygamma(0, x)
digamma = primitive(scipy.special.digamma)  # digamma is another name for psi.
gamma = primitive(scipy.special.gamma)
gammaln = primitive(scipy.special.gammaln)
gammainc = primitive(scipy.special.gammainc)
gammaincc = primitive(scipy.special.gammaincc)
gammasgn = primitive(scipy.special.gammasgn)
rgamma = primitive(scipy.special.rgamma)
multigammaln = primitive(scipy.special.multigammaln)


def _polygamma_vjp_x(ans, n, x):
    """VJP maker for ``polygamma`` w.r.t. ``x``: polygamma(n + 1, x)."""
    return lambda g: g * polygamma(n + 1, x)


def _digamma_vjp(ans, x):
    """VJP maker used for both ``psi`` and ``digamma``: polygamma(1, x)."""
    return lambda g: g * polygamma(1, x)


def _gamma_vjp(ans, x):
    """VJP maker for ``gamma``: ans * psi(x)."""
    return lambda g: g * ans * psi(x)


def _gammaln_vjp(ans, x):
    """VJP maker for ``gammaln``: psi(x)."""
    return lambda g: g * psi(x)


def _rgamma_vjp(ans, x):
    """VJP maker for ``rgamma``: psi(x) / -gamma(x)."""
    return lambda g: g * psi(x) / -gamma(x)


def _multigammaln_vjp_a(ans, a, d):
    """VJP maker for ``multigammaln`` w.r.t. ``a``."""
    def vjp(g):
        # Sum digamma(a - j / 2) for j in range(d) over a trailing axis.
        shifted = np.expand_dims(a, -1) - np.arange(d)/2.
        return g * np.sum(digamma(shifted), -1)
    return vjp


# A ``None`` entry registers no gradient for that argument position.
defvjp(gammasgn, None)
defvjp(polygamma, None, _polygamma_vjp_x)
defvjp(psi, _digamma_vjp)
defvjp(digamma, _digamma_vjp)
defvjp(gamma, _gamma_vjp)
defvjp(gammaln, _gammaln_vjp)
defvjp(rgamma, _rgamma_vjp)
defvjp(multigammaln, _multigammaln_vjp_a, None)
def make_gammainc_vjp_arg1(sign):
    """Build the VJP maker for the ``x`` argument of gammainc / gammaincc.

    ``sign`` multiplies the derivative; it is registered below with ``1``
    for ``gammainc`` and ``-1`` for ``gammaincc``.
    """
    def gammainc_vjp_arg1(ans, a, x):
        # sign * exp(-x) * x**(a - 1) / gamma(a), evaluated once when the
        # VJP is built rather than on every backward call.
        scale = sign * np.exp(-x) * np.power(x, a - 1) / gamma(a)

        def vjp(g):
            return g * scale

        return unbroadcast_f(x, vjp)

    return gammainc_vjp_arg1


defvjp(gammainc, make_gammainc_vjp_arg1(1), argnums=[1])
defvjp(gammaincc, make_gammainc_vjp_arg1(-1), argnums=[1])
### Bessel functions ###
j0 = primitive(scipy.special.j0)
y0 = primitive(scipy.special.y0)
j1 = primitive(scipy.special.j1)
y1 = primitive(scipy.special.y1)
jn = primitive(scipy.special.jn)
yn = primitive(scipy.special.yn)


def _j0_vjp(ans, x):
    """VJP maker for ``j0``: -j1(x)."""
    return lambda g: -g * j1(x)


def _y0_vjp(ans, x):
    """VJP maker for ``y0``: -y1(x)."""
    return lambda g: -g * y1(x)


def _j1_vjp(ans, x):
    """VJP maker for ``j1``: (j0(x) - jn(2, x)) / 2."""
    return lambda g: g * (j0(x) - jn(2, x)) / 2.0


def _y1_vjp(ans, x):
    """VJP maker for ``y1``: (y0(x) - yn(2, x)) / 2."""
    return lambda g: g * (y0(x) - yn(2, x)) / 2.0


def _jn_vjp_x(ans, n, x):
    """VJP maker for ``jn`` w.r.t. ``x``: (jn(n - 1, x) - jn(n + 1, x)) / 2."""
    return lambda g: g * (jn(n - 1, x) - jn(n + 1, x)) / 2.0


def _yn_vjp_x(ans, n, x):
    """VJP maker for ``yn`` w.r.t. ``x``: (yn(n - 1, x) - yn(n + 1, x)) / 2."""
    return lambda g: g * (yn(n - 1, x) - yn(n + 1, x)) / 2.0


defvjp(j0, _j0_vjp)
defvjp(y0, _y0_vjp)
defvjp(j1, _j1_vjp)
defvjp(y1, _y1_vjp)
# No gradient is registered for the order argument ``n``.
defvjp(jn, None, _jn_vjp_x)
defvjp(yn, None, _yn_vjp_x)
### Faster versions of common Bessel functions ###
i0 = primitive(scipy.special.i0)
i1 = primitive(scipy.special.i1)
iv = primitive(scipy.special.iv)
ive = primitive(scipy.special.ive)


def _i0_vjp(ans, x):
    """VJP maker for ``i0``: i1(x)."""
    return lambda g: g * i1(x)


def _i1_vjp(ans, x):
    """VJP maker for ``i1``: (i0(x) + iv(2, x)) / 2."""
    return lambda g: g * (i0(x) + iv(2, x)) / 2.0


def _iv_vjp_x(ans, n, x):
    """VJP maker for ``iv`` w.r.t. ``x``: (iv(n - 1, x) + iv(n + 1, x)) / 2."""
    return lambda g: g * (iv(n - 1, x) + iv(n + 1, x)) / 2.0


def _ive_vjp_x(ans, n, x):
    """VJP maker for ``ive`` w.r.t. ``x``: ans * (n / x - sign(x)) + ive(n + 1, x)."""
    return lambda g: g * (ans * (n / x - np.sign(x)) + ive(n + 1, x))


defvjp(i0, _i0_vjp)
defvjp(i1, _i1_vjp)
# No gradient is registered for the order argument ``n``.
defvjp(iv, None, _iv_vjp_x)
defvjp(ive, None, _ive_vjp_x)
### Error Function ###
inv_root_pi = 0.56418958354775627928
erf = primitive(scipy.special.erf)
erfc = primitive(scipy.special.erfc)


def _erf_vjp(ans, x):
    """VJP maker for ``erf``: 2 * inv_root_pi * exp(-x**2)."""
    return lambda g: 2.*g*inv_root_pi*np.exp(-x**2)


def _erfc_vjp(ans, x):
    """VJP maker for ``erfc``: the ``erf`` rule with the opposite sign."""
    return lambda g: -2.*g*inv_root_pi*np.exp(-x**2)


defvjp(erf, _erf_vjp)
defvjp(erfc, _erfc_vjp)
### Inverse error function ###
root_pi = 1.7724538509055159
erfinv = primitive(scipy.special.erfinv)
erfcinv = primitive(scipy.special.erfcinv)
# d/dx erfinv(x) = root_pi / 2 * exp(erfinv(x)**2), and the same with a
# leading minus sign for erfcinv. `ans` already holds the forward value,
# so it is reused instead of re-evaluating the inverse function on every
# backward call (the same pattern the gamma and expit rules use).
defvjp(erfinv, lambda ans, x: lambda g: g * root_pi / 2 * np.exp(ans**2))
defvjp(erfcinv, lambda ans, x: lambda g: -g * root_pi / 2 * np.exp(ans**2))
### Logit and Expit ###
logit = primitive(scipy.special.logit)
expit = primitive(scipy.special.expit)


def _logit_vjp(ans, x):
    """VJP maker for ``logit``: 1 / (x * (1 - x))."""
    return lambda g: g / ( x * (1 - x))


def _expit_vjp(ans, x):
    """VJP maker for ``expit``, written in terms of the output: ans * (1 - ans)."""
    return lambda g: g * ans * (1 - ans)


defvjp(logit, _logit_vjp)
defvjp(expit, _expit_vjp)
### logsumexp ###
logsumexp = primitive(scipy.special.logsumexp)


def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Build the reverse-mode (VJP) function for ``logsumexp`` w.r.t. ``x``.

    The returned closure maps an output cotangent ``g`` to
    ``g * b * exp(x - ans)``, after passing both ``g`` and ``ans`` through
    ``repeat_to_match_shape`` with the shape and dtype of ``x``
    (presumably to undo the reduction over ``axis`` -- confirm against
    that helper).
    """
    x_shape = np.shape(x)
    x_dtype = np.result_type(x)

    def vjp(g):
        g_full, _unused = repeat_to_match_shape(g, x_shape, x_dtype, axis, keepdims)
        ans_full, _unused = repeat_to_match_shape(ans, x_shape, x_dtype, axis, keepdims)
        return g_full * b * np.exp(x - ans_full)

    return vjp


defvjp(logsumexp, make_grad_logsumexp)
def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """Forward-mode (JVP) rule for ``logsumexp`` with respect to ``x``.

    Args:
        g: Tangent of ``x``; multiplied elementwise with ``exp(x - ans)``.
        ans: Forward result of ``logsumexp`` for the same arguments.
        x: The array that was reduced.
        axis: ``None``, a single integer, or a tuple of integers naming the
            reduced axes. Negative values count from the end of ``x``.
        b: Scaling factor applied to ``exp(x)``.
        keepdims: Whether ``ans`` still carries the reduced axes as
            length-1 dimensions.

    Returns:
        ``sum(g * b * exp(x - ans))`` taken over ``axis`` with the given
        ``keepdims``.
    """
    if not keepdims and axis is not None:
        ndim = np.ndim(x)
        axes = axis if isinstance(axis, tuple) else (axis,)
        # Re-insert the reduced dimensions so ``ans`` broadcasts against
        # ``x``. Axes are converted to non-negative positions first:
        # inserting in ascending order is only correct for non-negative
        # indices, so a tuple such as (-2, -1) would otherwise put the new
        # dimensions in the wrong places.
        for ax in sorted(a + ndim if a < 0 else a for a in axes):
            ans = np.expand_dims(ans, ax)
    return np.sum(g * b * np.exp(x - ans), axis=axis, keepdims=keepdims)
# Register fwd_grad_logsumexp as the forward-mode (JVP) rule for logsumexp.
defjvp(logsumexp, fwd_grad_logsumexp)