from __future__ import absolute_import
import scipy.stats
import autograd.numpy as np
from autograd.scipy.special import digamma
from autograd.extend import primitive, defvjp
# Wrap the scipy Dirichlet routines as autograd primitives; the wrapped
# callables are exposed under the same names scipy uses.
_dirichlet = scipy.stats.dirichlet
logpdf = primitive(_dirichlet.logpdf)
pdf = primitive(_dirichlet.pdf)
rvs = primitive(_dirichlet.rvs)
def _logpdf_vjp_x(ans, x, alpha):
    """Build the VJP of ``logpdf`` with respect to ``x``.

    The elementwise partial derivative of the log-density in ``x`` is
    ``(alpha - 1) / x``; the returned function scales it by the incoming
    gradient ``g``.
    """
    def vjp(g):
        return g * (alpha - 1) / x
    return vjp


def _logpdf_vjp_alpha(ans, x, alpha):
    """Build the VJP of ``logpdf`` with respect to ``alpha``.

    The partial derivative in ``alpha`` is
    ``digamma(sum(alpha)) - digamma(alpha) + log(x)``, where the sum runs
    over every element of ``alpha``.
    """
    def vjp(g):
        return g * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x))
    return vjp


# Positional order matters: the first maker is for argument ``x``, the
# second for argument ``alpha``.
defvjp(logpdf, _logpdf_vjp_x, _logpdf_vjp_alpha)
# Since d(pdf) = pdf * d(logpdf), the pdf gradients are the logpdf gradients
# multiplied by the already-computed density value (``ans``).
def _pdf_vjp_x(ans, x, alpha):
    """Build the VJP of ``pdf`` with respect to ``x``.

    Uses ``ans * (alpha - 1) / x`` as the elementwise partial derivative,
    scaled by the incoming gradient ``g``.
    """
    def vjp(g):
        return g * ans * (alpha - 1) / x
    return vjp


def _pdf_vjp_alpha(ans, x, alpha):
    """Build the VJP of ``pdf`` with respect to ``alpha``.

    Uses ``ans * (digamma(sum(alpha)) - digamma(alpha) + log(x))`` as the
    partial derivative, where the sum runs over every element of ``alpha``.
    """
    def vjp(g):
        return g * ans * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x))
    return vjp


# Positional order matters: the first maker is for argument ``x``, the
# second for argument ``alpha``.
defvjp(pdf, _pdf_vjp_x, _pdf_vjp_alpha)