from __future__ import absolute_import
import scipy.stats
import autograd.numpy as np
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.extend import primitive, defvjp
pdf = primitive(scipy.stats.multivariate_normal.pdf)
logpdf = primitive(scipy.stats.multivariate_normal.logpdf)
entropy = primitive(scipy.stats.multivariate_normal.entropy)
# With thanks to Eric Bresch.
# Some formulas are from
# "An extended collection of matrix derivative results
# for forward and reverse mode algorithmic differentiation"
# by Mike Giles
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
def generalized_outer_product(x):
if np.ndim(x) == 1:
return np.outer(x, x)
return np.matmul(x, np.swapaxes(x, -1, -2))
def covgrad(x, mean, cov, allow_singular=False):
if allow_singular:
raise NotImplementedError("The multivariate normal pdf is not "
"differentiable w.r.t. a singular covariance matix")
J = np.linalg.inv(cov)
solved = np.matmul(J, np.expand_dims(x - mean, -1))
return 1./2 * (generalized_outer_product(solved) - J)
def solve(allow_singular):
if allow_singular:
return lambda A, x: np.dot(np.linalg.pinv(A), x)
else:
return np.linalg.solve
defvjp(logpdf,
lambda ans, x, mean, cov, allow_singular=False:
unbroadcast_f(x, lambda g: -np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T),
lambda ans, x, mean, cov, allow_singular=False:
unbroadcast_f(mean, lambda g: np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T),
lambda ans, x, mean, cov, allow_singular=False:
unbroadcast_f(cov, lambda g: np.reshape(g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular)))
# Same as log pdf, but multiplied by the pdf (ans).
defvjp(pdf,
lambda ans, x, mean, cov, allow_singular=False:
unbroadcast_f(x, lambda g: -np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T),
lambda ans, x, mean, cov, allow_singular=False:
unbroadcast_f(mean, lambda g: np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T),
lambda ans, x, mean, cov, allow_singular=False:
unbroadcast_f(cov, lambda g: np.reshape(ans * g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular)))
defvjp(entropy, None,
lambda ans, mean, cov:
unbroadcast_f(cov, lambda g: 0.5 * g * np.linalg.inv(cov).T))