Blob Blame History Raw
from __future__ import absolute_import
import scipy.stats

import autograd.numpy as np
from autograd.scipy.special import digamma
from autograd.extend import primitive, defvjp

# Expose scipy's Dirichlet sampler, density and log-density as autograd
# primitives: autograd treats each as an opaque function whose gradients
# must be registered explicitly via defvjp (rvs gets no registered VJP in
# this module).
rvs, pdf, logpdf = map(
    primitive,
    (scipy.stats.dirichlet.rvs,
     scipy.stats.dirichlet.pdf,
     scipy.stats.dirichlet.logpdf),
)

def _alpha_column(alpha, x):
    """Reshape ``alpha`` so it broadcasts against ``x`` along axis 0.

    ``scipy.stats.dirichlet`` treats axis 0 of ``x`` as the component axis:
    an ``x`` of shape ``(K,)`` is a single point, while an ``x`` of shape
    ``(K, N)`` is N points whose pdf/logpdf has shape ``(N,)``.  A flat
    ``(K,)`` alpha would broadcast against the *last* axis of ``x`` instead,
    so singleton axes are appended to it (a no-op when ``x`` is 1-D).

    NOTE(review): the scipy form where ``x`` carries only K-1 components
    (last one implied) is not supported here, as in the original VJPs.
    """
    return np.reshape(alpha, np.shape(alpha) + (1,) * (np.ndim(x) - 1))


def _sum_batch_axes(per_point, x):
    """Sum an ``np.shape(x)``-shaped array over every axis except axis 0.

    The result has shape ``(K,)``, matching ``alpha``.  For 1-D ``x`` the
    reshape gives ``(K, 1)``, so the sum returns ``per_point`` unchanged.
    """
    return np.sum(np.reshape(per_point, (np.shape(x)[0], -1)), axis=1)


def _dlogpdf_dx(x, alpha):
    """Partial d log p / d x_i = (alpha_i - 1) / x_i, per point.

    Each x_i is treated as an independent variable (no projection onto the
    simplex).  Entries with x_i == 0 yield inf/nan.
    """
    return (_alpha_column(alpha, x) - 1) / x


def _dlogpdf_dalpha(x, alpha):
    """Per-point d log p / d alpha_i.

    The formula is digamma(sum(alpha)) - digamma(alpha_i) + log(x_i), and the
    result has the same shape as ``x``.
    """
    return (digamma(np.sum(alpha)) - digamma(_alpha_column(alpha, x))
            + np.log(x))


# g has the shape of ans, i.e. np.shape(x)[1:], so it broadcasts along the
# trailing (batch) axes of the per-point terms.  The alpha gradient is then
# summed over those batch axes, because alpha is shared by every point.
defvjp(logpdf,
       lambda ans, x, alpha: lambda g: g * _dlogpdf_dx(x, alpha),
       lambda ans, x, alpha: lambda g:
           _sum_batch_axes(g * _dlogpdf_dalpha(x, alpha), x))

# Same as log pdf, but multiplied by the pdf (ans), since dp = p * d(log p).
defvjp(pdf,
       lambda ans, x, alpha: lambda g: g * ans * _dlogpdf_dx(x, alpha),
       lambda ans, x, alpha: lambda g:
           _sum_batch_axes(g * ans * _dlogpdf_dalpha(x, alpha), x))