Blob Blame History Raw
"""Gradients of the normal distribution."""
from __future__ import absolute_import
import scipy.stats
import autograd.numpy as anp
from autograd.extend import primitive, defvjp
from autograd.numpy.numpy_vjps import unbroadcast_f

pdf = primitive(scipy.stats.norm.pdf)
cdf = primitive(scipy.stats.norm.cdf)
sf = primitive(scipy.stats.norm.sf)
logpdf = primitive(scipy.stats.norm.logpdf)
logcdf = primitive(scipy.stats.norm.logcdf)
logsf = primitive(scipy.stats.norm.logsf)

defvjp(pdf,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(x, lambda g: -g * ans * (x - loc) / scale**2),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(scale, lambda g: g * ans * (((x - loc)/scale)**2 - 1.0)/scale))

defvjp(cdf,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)) ,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(scale, lambda g: -g * pdf(x, loc, scale)*(x-loc)/scale))

defvjp(logpdf,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(scale, lambda g: g * (-1.0/scale + (x - loc)**2/scale**3)))

defvjp(logcdf,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(x, lambda g: g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(loc, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(scale, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))*(x-loc)/scale))

defvjp(logsf, 
       lambda ans, x, loc=0.0, scale=1.0: 
       unbroadcast_f(x, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))),
       lambda ans, x, loc=0.0, scale=1.0: 
       unbroadcast_f(loc, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))),
       lambda ans, x, loc=0.0, scale=1.0: 
       unbroadcast_f(scale, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) * (x - loc) / scale))

defvjp(sf,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(x, lambda g: -g * pdf(x, loc, scale)) ,
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(loc, lambda g: g * pdf(x, loc, scale)),
       lambda ans, x, loc=0.0, scale=1.0:
       unbroadcast_f(scale, lambda g: g * pdf(x, loc, scale)*(x-loc)/scale))