# Autograd gradient (VJP) definitions for scipy.stats.chi2 (cdf, logpdf, pdf).
from __future__ import absolute_import, division

import autograd.numpy as np
import scipy.stats
from autograd.extend import primitive, defvjp
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.scipy.special import gamma

# Wrap the SciPy chi-square distribution functions as autograd primitives so
# that vector-Jacobian products can be registered for them with ``defvjp``.
_chi2 = scipy.stats.chi2
cdf = primitive(_chi2.cdf)
logpdf = primitive(_chi2.logpdf)
pdf = primitive(_chi2.pdf)

def grad_chi2_logpdf(x, df):
    """Return the derivative of the chi-square log-density with respect to ``x``.

    The log-density is ``(df/2 - 1) * log(x) - x/2`` plus terms that do not
    depend on ``x``, so its derivative in ``x`` is
    ``(df/2 - 1) / x - 1/2 == (df - x - 2) / (2 * x)``.

    This closed form does not depend on ``df`` being an integer, so it is
    applied to every ``df`` rather than being zeroed out for fractional
    degrees of freedom.

    Args:
        x: Point(s) at which the log-density is differentiated.
        df: Degrees of freedom; broadcast against ``x``.

    Returns:
        ``(df - x - 2) / (2 * x)``, with the broadcast shape of ``x`` and ``df``.
    """
    return (df - x - 2) / (2 * x)

# Vector-Jacobian products with respect to ``x`` only (argnums=[0]); no VJP is
# registered here for ``df``.


def _cdf_vjp(ans, x, df):
    """Build the VJP of ``cdf`` in ``x``: ``g`` times the chi-square density."""
    def scale(g):
        # Density: 2**(-df/2) * exp(-x/2) * x**(df/2 - 1) / gamma(df/2).
        return g * np.power(2., -df/2) * np.exp(-x/2) * np.power(x, df/2 - 1) / gamma(df/2)
    return unbroadcast_f(x, scale)


def _logpdf_vjp(ans, x, df):
    """Build the VJP of ``logpdf`` in ``x`` from ``grad_chi2_logpdf``."""
    def scale(g):
        return g * grad_chi2_logpdf(x, df)
    return unbroadcast_f(x, scale)


def _pdf_vjp(ans, x, df):
    """Build the VJP of ``pdf`` in ``x``: ``pdf * d(logpdf)/dx``, scaled by ``g``."""
    def scale(g):
        # ``ans`` is the forward value pdf(x, df).
        return g * ans * grad_chi2_logpdf(x, df)
    return unbroadcast_f(x, scale)


defvjp(cdf, _cdf_vjp, argnums=[0])
defvjp(logpdf, _logpdf_vjp, argnums=[0])
defvjp(pdf, _pdf_vjp, argnums=[0])