from __future__ import absolute_import
from future.utils import string_types
from functools import partial
import numpy as onp
from ..util import func
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox
from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum,
SparseObject, VJPNode, register_notrace)
# ----- Non-differentiable functions -----
nograd_functions = [
anp.floor, anp.ceil, anp.round, anp.rint, anp.around, anp.fix, anp.trunc, anp.all,
anp.any, anp.argmax, anp.argmin, anp.argpartition, anp.argsort, anp.argwhere, anp.nonzero,
anp.flatnonzero, anp.count_nonzero, anp.searchsorted, anp.sign, anp.ndim, anp.shape,
anp.floor_divide, anp.logical_and, anp.logical_or, anp.logical_not, anp.logical_xor,
anp.isfinite, anp.isinf, anp.isnan, anp.isneginf, anp.isposinf, anp.allclose, anp.isclose,
anp.array_equal, anp.array_equiv, anp.greater, anp.greater_equal, anp.less, anp.less_equal,
anp.equal, anp.not_equal, anp.iscomplexobj, anp.iscomplex, anp.size, anp.isscalar,
anp.isreal, anp.zeros_like, anp.ones_like, anp.result_type]
for fun in nograd_functions:
register_notrace(VJPNode, fun)
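# register_notrace(VJPNode, fun) tells the reverse-mode tracer not to record these
# primitives at all: their outputs are discrete-valued or piecewise constant in their
# inputs, so their derivative with respect to any continuous input is zero (almost)
# everywhere.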
# ----- Functions that are constant w.r.t. continuous inputs -----
defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.))
# ----- Binary ufuncs -----
defvjp(anp.add, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: g))
defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g),
lambda ans, x, y : unbroadcast_f(y, lambda g: x * g))
defvjp(anp.subtract, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g))
defvjp(anp.divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y),
lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2))
defvjp(anp.maximum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.minimum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.fmax, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.fmin, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.logaddexp, lambda ans, x, y : unbroadcast_f(x, lambda g: g * anp.exp(x-ans)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.exp(y-ans)))
defvjp(anp.logaddexp2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * 2**(x-ans)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * 2**(y-ans)))
defvjp(anp.true_divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y),
lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2))
defvjp(anp.mod, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y)))
defvjp(anp.remainder, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y)))
defvjp(anp.power,
lambda ans, x, y : unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.)) * x ** y))
defvjp(anp.arctan2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * y / (x**2 + y**2)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * -x / (x**2 + y**2)))
defvjp(anp.hypot,
lambda ans, x, y : unbroadcast_f(x, lambda g: g * x / ans),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * y / ans))
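# The binary ufunc VJPs above wrap each partial derivative in unbroadcast_f (defined near
# the bottom of this file), which sums the incoming gradient over any axes that NumPy
# broadcasting added. A small sanity-check sketch, kept as a comment so importing this
# module stays side-effect free (`autograd.numpy` and `grad` are the package's public
# API, not defined here):
#
#   import autograd.numpy as np
#   from autograd import grad
#   f = lambda y: np.sum(np.ones((3, 4)) + y)   # y of shape (4,) broadcasts over 3 rows
#   grad(f)(np.zeros(4))                        # should give array([3., 3., 3., 3.]):
#                                               # the gradient is summed over the broadcast axis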
# ----- Simple grads -----
defvjp(anp.negative, lambda ans, x: lambda g: -g)
defvjp(anp.abs,
lambda ans, x : lambda g: g * replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.))
defvjp(anp.fabs, lambda ans, x : lambda g: anp.sign(x) * g) # fabs doesn't take complex numbers.
defvjp(anp.absolute, lambda ans, x : lambda g: g * anp.conj(x) / ans)
defvjp(anp.reciprocal, lambda ans, x : lambda g: - g / x**2)
defvjp(anp.exp, lambda ans, x : lambda g: ans * g)
defvjp(anp.exp2, lambda ans, x : lambda g: ans * anp.log(2) * g)
defvjp(anp.expm1, lambda ans, x : lambda g: (ans + 1) * g)
defvjp(anp.log, lambda ans, x : lambda g: g / x)
defvjp(anp.log2, lambda ans, x : lambda g: g / x / anp.log(2))
defvjp(anp.log10, lambda ans, x : lambda g: g / x / anp.log(10))
defvjp(anp.log1p, lambda ans, x : lambda g: g / (x + 1))
defvjp(anp.sin, lambda ans, x : lambda g: g * anp.cos(x))
defvjp(anp.cos, lambda ans, x : lambda g: - g * anp.sin(x))
defvjp(anp.tan, lambda ans, x : lambda g: g / anp.cos(x) **2)
defvjp(anp.arcsin, lambda ans, x : lambda g: g / anp.sqrt(1 - x**2))
defvjp(anp.arccos, lambda ans, x : lambda g:-g / anp.sqrt(1 - x**2))
defvjp(anp.arctan, lambda ans, x : lambda g: g / (1 + x**2))
defvjp(anp.sinh, lambda ans, x : lambda g: g * anp.cosh(x))
defvjp(anp.cosh, lambda ans, x : lambda g: g * anp.sinh(x))
defvjp(anp.tanh, lambda ans, x : lambda g: g / anp.cosh(x) **2)
defvjp(anp.arcsinh, lambda ans, x : lambda g: g / anp.sqrt(x**2 + 1))
defvjp(anp.arccosh, lambda ans, x : lambda g: g / anp.sqrt(x**2 - 1))
defvjp(anp.arctanh, lambda ans, x : lambda g: g / (1 - x**2))
defvjp(anp.rad2deg, lambda ans, x : lambda g: g / anp.pi * 180.0)
defvjp(anp.degrees, lambda ans, x : lambda g: g / anp.pi * 180.0)
defvjp(anp.deg2rad, lambda ans, x : lambda g: g * anp.pi / 180.0)
defvjp(anp.radians, lambda ans, x : lambda g: g * anp.pi / 180.0)
defvjp(anp.square, lambda ans, x : lambda g: g * 2 * x)
defvjp(anp.sqrt, lambda ans, x : lambda g: g * 0.5 * x**-0.5)
defvjp(anp.sinc, lambda ans, x : lambda g: g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2))
defvjp(anp.reshape, lambda ans, x, shape, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.roll, lambda ans, x, shift, axis=None : lambda g: anp.roll(g, -shift, axis=axis))
defvjp(anp.array_split, lambda ans, ary, idxs, axis=0 : lambda g: anp.concatenate(g, axis=axis))
defvjp(anp.split, lambda ans, ary, idxs, axis=0 : lambda g: anp.concatenate(g, axis=axis))
defvjp(anp.vsplit, lambda ans, ary, idxs : lambda g: anp.concatenate(g, axis=0))
defvjp(anp.hsplit, lambda ans, ary, idxs : lambda g: anp.concatenate(g, axis=1))
defvjp(anp.dsplit, lambda ans, ary, idxs : lambda g: anp.concatenate(g, axis=2))
defvjp(anp.ravel, lambda ans, x, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.expand_dims, lambda ans, x, axis : lambda g: anp.reshape(g, anp.shape(x)))
defvjp(anp.squeeze, lambda ans, x, axis=None : lambda g: anp.reshape(g, anp.shape(x)))
defvjp(anp.diag, lambda ans, x, k=0 : lambda g: anp.diag(g, k))
defvjp(anp.flipud, lambda ans, x, : lambda g: anp.flipud(g))
defvjp(anp.fliplr, lambda ans, x, : lambda g: anp.fliplr(g))
defvjp(anp.rot90, lambda ans, x, k=1 : lambda g: anp.rot90(g, -k))
defvjp(anp.trace, lambda ans, x, offset=0 : lambda g:
anp.einsum('ij,...->ij...', anp.eye(x.shape[0], x.shape[1], k=offset), g))
defvjp(anp.full, lambda ans, shape, fill_value, dtype=None : lambda g: anp.sum(g), argnums=(1,))
defvjp(anp.triu, lambda ans, x, k=0 : lambda g: anp.triu(g, k=k))
defvjp(anp.tril, lambda ans, x, k=0 : lambda g: anp.tril(g, k=k))
defvjp(anp.clip, lambda ans, x, a_min, a_max : lambda g: g * anp.logical_and(ans != a_min, ans != a_max))
defvjp(anp.swapaxes, lambda ans, x, axis1, axis2: lambda g: anp.swapaxes(g, axis2, axis1))
defvjp(anp.moveaxis, lambda ans, a, source, destination: lambda g:
anp.moveaxis(g, destination, source))
defvjp(anp.real_if_close, lambda ans, x : lambda g: match_complex(x, g))
defvjp(anp.real, lambda ans, x : lambda g: match_complex(x, g))
defvjp(anp.imag, lambda ans, x : lambda g: match_complex(x, -1j * g))
defvjp(anp.conj, lambda ans, x : lambda g: anp.conj(g))
defvjp(anp.conjugate, lambda ans, x: lambda g: anp.conj(g))
defvjp(anp.angle, lambda ans, x : lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x)**2))
defvjp(anp.where, None,
lambda ans, c, x=None, y=None : lambda g: anp.where(c, g, anp.zeros(g.shape)),
lambda ans, c, x=None, y=None : lambda g: anp.where(c, anp.zeros(g.shape), g))
defvjp(anp.cross, lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None : lambda g:
anp.cross(b, g, axisb, axisc, axisa, axis),
lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None : lambda g:
anp.cross(g, a, axisc, axisa, axisb, axis))
defvjp(anp.linspace, lambda ans, start, stop, num : lambda g: anp.dot(anp.linspace(1.0, 0.0, num), g),
lambda ans, start, stop, num : lambda g: anp.dot(anp.linspace(0.0, 1.0, num), g))
defvjp(anp._astype,
lambda ans, A, dtype, order='K', casting='unsafe', subok=True, copy=True:
lambda g: anp._astype(g, A.dtype))
# ----- Trickier grads -----
def grad_rollaxis(ans, a, axis, start=0):
if axis < 0:
raise NotImplementedError("Gradient of rollaxis not implemented for axis < 0. "
"Please use moveaxis instead.")
elif start < 0:
raise NotImplementedError("Gradient of rollaxis not implemented for start < 0. "
"Please use moveaxis instead.")
return lambda g: anp.rollaxis(g, start - 1, axis) if start > axis else anp.rollaxis(g, start, axis + 1)
defvjp(anp.rollaxis, grad_rollaxis)
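# The adjoint of a single np.diff along `axis` is a "negative diff" with boundary terms:
# concatenate(-g[first], -diff(g), g[last]) along that axis. `helper` below applies it n
# times to invert an n-th order difference.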
def grad_diff(ans, a, n=1, axis=-1):
nd = anp.ndim(a)
ans_shape = anp.shape(ans)
sl1 = [slice(None)]*nd
sl1[axis] = slice(None, 1)
sl2 = [slice(None)]*nd
sl2[axis] = slice(-1, None)
def undiff(g):
if g.shape[axis] > 0:
return anp.concatenate((-g[tuple(sl1)], -anp.diff(g, axis=axis), g[tuple(sl2)]), axis=axis)
shape = list(ans_shape)
shape[axis] = 1
return anp.zeros(shape)
def helper(g, n):
if n == 0:
return g
return helper(undiff(g), n-1)
return lambda g: helper(g, n)
defvjp(anp.diff, grad_diff)
def grad_gradient(ans, x, *vargs, **kwargs):
axis = kwargs.pop('axis', None)
if vargs or kwargs:
raise NotImplementedError(
"The only optional argument currently supported for np.gradient "
"is axis.")
if axis is None:
axis = range(x.ndim)
elif type(axis) is int:
axis = [axis]
else:
axis = list(axis)
x_dtype = x.dtype
x_shape = x.shape
nd = x.ndim
def vjp(g):
if anp.ndim(g) == nd:
# add axis if gradient was along one axis only
g = g[anp.newaxis]
# accumulate gradient
out = anp.zeros(x_shape, dtype=x_dtype)
for i, a in enumerate(axis):
# swap gradient axis to the front
g_swap = anp.swapaxes(g[i], 0, a)[:, anp.newaxis]
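            # Build the adjoint of np.gradient along this axis: the interior entries are the
            # (negated) central differences of g, and the first/last two entries account for
            # the one-sided stencils np.gradient uses at the boundaries.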
out_axis = anp.concatenate((
-g_swap[0] - 0.5 * g_swap[1],
g_swap[0] - 0.5 * g_swap[2],
(-1.) * anp.gradient(g_swap, axis=0)[2:-2, 0],
0.5 * g_swap[-3] - g_swap[-1],
0.5 * g_swap[-2] + g_swap[-1],
), axis=0)
out = out + anp.swapaxes(out_axis, 0, a)
return out
return vjp
defvjp(anp.gradient, grad_gradient)
def grad_repeat(ans, x, repeats, axis=None):
shape = anp.shape(x)
def vjp(g):
        if axis is None:  # If axis is None, np.repeat() repeats the flattened array.
expanded = anp.reshape(g, (anp.prod(shape),) + (repeats,))
return anp.reshape(anp.sum(expanded, axis=1, keepdims=False), shape)
else:
if shape[axis] == 1: # For this common case, the logic is simple.
return anp.sum(g, axis=axis, keepdims=True)
else:
expanded = anp.reshape(g, shape[0:axis+1] + (repeats,) + shape[axis+1:])
return anp.sum(expanded, axis=axis+1, keepdims=False)
return vjp
defvjp(anp.repeat, grad_repeat)
def grad_tile(ans, x, reps):
reps = [reps] if anp.isscalar(reps) else reps
x_shape = anp.shape(x)
def vjp(g):
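        # Every tiled copy of x sees the same upstream gradient, so split g into `rep`
        # chunks along each tiled axis and sum the chunks to accumulate their contributions.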
for axis, rep in enumerate(reps):
g = sum(anp.split(g, rep, axis))
return anp.reshape(g, x_shape)
return vjp
defvjp(anp.tile, grad_tile)
def grad_kron(argnum, ans, orig_A, orig_B):
    # kron has different promotion rules than dot. The reshapes are necessary if
    # and only if (1) orig_B is 1D or (2) orig_A and/or orig_B are 0D.
orig_A_shape = anp.shape(orig_A)
orig_B_shape = anp.shape(orig_B)
def vjp(G):
A, B = anp.atleast_2d(orig_A), anp.atleast_2d(orig_B)
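        # kron interleaves A's and B's indices blockwise. Reshape G so that A's indices come
        # first and B's last (the reshape plus swapaxes below), then contract the other
        # factor's indices away with tensordot.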
shape = list(A.shape + B.shape)
n = anp.ndim(A)
shape[n-1], shape[n] = shape[n], shape[n-1]
reshaped_G = anp.swapaxes(anp.reshape(G, shape), n-1, n)
if argnum == 0:
return match_complex(orig_A, anp.reshape(anp.tensordot(reshaped_G, B, axes=anp.ndim(B)), orig_A_shape))
else:
return match_complex(orig_B, anp.reshape(anp.tensordot(A, reshaped_G, axes=anp.ndim(A)), orig_B_shape))
return vjp
defvjp(anp.kron, partial(grad_kron, 0), partial(grad_kron, 1))
def grad_transpose(ans, x, axes=None):
if axes is not None:
axes = anp.argsort(axes)
return lambda g: anp.transpose(g, axes)
defvjp(anp.transpose, grad_transpose)
def repeat_to_match_shape(g, shape, dtype, axis, keepdims):
"""Returns the array g repeated along axis to fit vector space vs.
Also returns the number of repetitions of the array."""
if shape == ():
return g, 1
axis = list(axis) if isinstance(axis, tuple) else axis
new_shape = onp.array(shape)
new_shape[axis] = 1
num_reps = onp.prod(onp.array(shape)[axis])
# Can't use broadcast_to because of numpy bug: https://github.com/numpy/numpy/issues/9165
# return anp.broadcast_to(anp.reshape(g, new_shape), shape), num_reps
return anp.reshape(g, new_shape) + onp.zeros(shape, dtype=dtype), num_reps
def grad_broadcast_to(ans, x, new_shape):
old_shape = anp.shape(x)
assert anp.shape(ans) == new_shape
assert len(old_shape) == len(new_shape), "Can't handle extra leading dims"
broadcast_axes = tuple(onp.where(onp.logical_and(
onp.array(old_shape) == 1,
onp.array(new_shape) > 1))[0])
return lambda g: anp.sum(g, axis=broadcast_axes, keepdims=True)
defvjp(anp.broadcast_to, grad_broadcast_to)
def grad_np_sum(ans, x, axis=None, keepdims=False, dtype=None):
shape, dtype = anp.shape(x), anp.result_type(x)
return lambda g: repeat_to_match_shape(g, shape, dtype, axis, keepdims)[0]
defvjp(anp.sum, grad_np_sum)
def grad_np_mean(ans, x, axis=None, keepdims=False):
shape, dtype = anp.shape(x), anp.result_type(x)
def vjp(g):
g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
return g_repeated / num_reps
return vjp
defvjp(anp.mean, grad_np_mean)
def grad_np_prod(ans, x, axis=None, keepdims=False): # TODO: Support tuples of axes.
shape, dtype = anp.shape(x), anp.result_type(x)
def vjp(g):
g_repeated, _ = repeat_to_match_shape(g * ans, shape, dtype, axis, keepdims)
return g_repeated / x
return vjp
defvjp(anp.prod, grad_np_prod)
def grad_np_var(ans, x, axis=None, ddof=0, keepdims=False):
shape, _, dtype, iscomplex = anp.metadata(x)
def vjp(g):
if iscomplex:
g = g + 0j
g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
return 2.0 * g_repeated * x_minus_mean / (num_reps - ddof)
return vjp
defvjp(anp.var, grad_np_var)
def grad_np_std(ans, x, axis=None, ddof=0, keepdims=False):
shape, _, dtype, iscomplex = anp.metadata(x)
def vjp(g):
if iscomplex:
g = g + 0j
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        if num_reps <= 1:
            # Avoid division by zero: with a single element the std (ans) is 0, and the gradient is 0.
            return g_repeated * 0.0
else:
g_repeated, num_reps = repeat_to_match_shape(g / ans, shape, dtype, axis, keepdims)
x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
return g_repeated * x_minus_mean / (num_reps - ddof)
return vjp
defvjp(anp.std, grad_np_std)
def grad_chooser(ans, x, axis=None, keepdims=None):
    """Gradient for functions that choose a single item, such as min or max."""
    shape, dtype = anp.shape(x), anp.result_type(x)
    def vjp(g):
g_repeated, _ = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
argmax_locations = x == repeat_to_match_shape(ans, shape, dtype, axis, keepdims)[0]
return g_repeated * argmax_locations \
/ onp.sum(argmax_locations, axis=axis, keepdims=True)
return vjp
defvjp(anp.max, grad_chooser)
defvjp(anp.min, grad_chooser)
defvjp(anp.amax, grad_chooser)
defvjp(anp.amin, grad_chooser)
def reverse_axis(x, axis):
x = x.swapaxes(axis, 0)
x = x[::-1,...]
return x.swapaxes(0, axis)
def grad_np_cumsum(ans, x, axis=None):
def vjp(g):
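        # Note: the else branch covers both axis=None (cumsum over the flattened array) and
        # axis=0, where reversing with g[::-1] along the leading axis is exactly what's needed.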
if axis:
return reverse_axis(anp.cumsum(reverse_axis(g, axis), axis), axis)
else:
return anp.reshape(anp.cumsum(g[::-1], axis)[::-1], x.shape)
return vjp
defvjp(anp.cumsum, grad_np_cumsum)
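# np.inner contracts the last axis of A with the last axis of B (or is a plain product if
# either argument is 0-d), so its VJPs can reuse the tensordot adjoints defined below.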
def grad_inner(argnum, ans, A, B):
A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
if A_ndim == 0 or B_ndim == 0:
axes = ([], [])
else:
axes = ([A_ndim - 1], [B_ndim - 1])
if argnum == 0:
return lambda G: tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim)
elif argnum == 1:
return lambda G: tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim)
defvjp(anp.inner, partial(grad_inner, 0), partial(grad_inner, 1))
def matmul_adjoint_0(B, G, A_meta, B_ndim):
if anp.ndim(G) == 0: # A_ndim == B_ndim == 1
return unbroadcast(G * B, A_meta)
_, A_ndim, _, _ = A_meta
if A_ndim == 1:
G = anp.expand_dims(G, anp.ndim(G) - 1)
if B_ndim == 1: # The result we need is an outer product
B = anp.expand_dims(B, 0)
G = anp.expand_dims(G, anp.ndim(G))
else: # We need to swap the last two axes of B
B = anp.swapaxes(B, B_ndim - 2, B_ndim - 1)
result = anp.matmul(G, B)
return unbroadcast(result, A_meta)
def matmul_adjoint_1(A, G, A_ndim, B_meta):
if anp.ndim(G) == 0: # A_ndim == B_ndim == 1
return unbroadcast(G * A, B_meta)
_, B_ndim, _, _ = B_meta
B_is_vec = (B_ndim == 1)
if B_is_vec:
G = anp.expand_dims(G, anp.ndim(G))
if A_ndim == 1: # The result we need is an outer product
A = anp.expand_dims(A, 1)
G = anp.expand_dims(G, anp.ndim(G) - 1)
else: # We need to swap the last two axes of A
A = anp.swapaxes(A, A_ndim - 2, A_ndim - 1)
result = anp.matmul(A, G)
if B_is_vec:
result = anp.squeeze(result, anp.ndim(G) - 1)
return unbroadcast(result, B_meta)
def matmul_vjp_0(ans, A, B):
A_meta = anp.metadata(A)
B_ndim = anp.ndim(B)
return lambda g: matmul_adjoint_0(B, g, A_meta, B_ndim)
def matmul_vjp_1(ans, A, B):
A_ndim = anp.ndim(A)
B_meta = anp.metadata(B)
return lambda g: matmul_adjoint_1(A, g, A_ndim, B_meta)
defvjp(anp.matmul, matmul_vjp_0, matmul_vjp_1)
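# The dot adjoints are primitives with VJPs of their own (registered a few lines down),
# which is what makes second- and higher-order derivatives of np.dot work.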
@primitive
def dot_adjoint_0(B, G, A_meta, B_meta):
_, A_ndim, A_dtype, _ = A_meta
_, B_ndim, _, _ = B_meta
if B_ndim == 0 or B_ndim == 1 or A_ndim == 0:
contract_num = max(0, B_ndim - (A_ndim != 0))
out = onp.tensordot(G, B, contract_num)
else:
out = onp.tensordot(G, onp.swapaxes(B, -1, -2), B_ndim - 1)
return onp.asarray(out, dtype=A_dtype)
@primitive
def dot_adjoint_1(A, G, A_meta, B_meta):
_, A_ndim, _, _ = A_meta
_, B_ndim, B_dtype, _ = B_meta
needs_transpose = B_ndim > 1 and A_ndim != 0
swap = (lambda x: onp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x)
if A_ndim == 0 or A_ndim == 1 or B_ndim == 0:
contract_num = max(0, A_ndim - (B_ndim != 0))
out = swap(onp.tensordot(G, A, contract_num))
else:
out = swap(onp.tensordot(
G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
return onp.asarray(out, dtype=B_dtype)
def dot_vjp_0(ans, A, B):
A_meta, B_meta = anp.metadata(A), anp.metadata(B)
return lambda g: match_complex(A, dot_adjoint_0(B, g, A_meta, B_meta))
def dot_vjp_1(ans, A, B):
A_meta, B_meta = anp.metadata(A), anp.metadata(B)
return lambda g: match_complex(B, dot_adjoint_1(A, g, A_meta, B_meta))
defvjp(anp.dot, dot_vjp_0, dot_vjp_1)
defvjp(dot_adjoint_0, lambda ans, B, g, An, Bn: lambda A: match_complex(B, dot_adjoint_1(A, g, An, Bn)),
lambda ans, B, g, An, Bn: lambda A: match_complex(g, anp.dot(A, B)))
defvjp(dot_adjoint_1, lambda ans, A, g, An, Bn: lambda B: match_complex(A, dot_adjoint_0(B, g, An, Bn)),
lambda ans, A, g, An, Bn: lambda B: match_complex(g, anp.dot(A, B)))
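# Same pattern for tensordot: the adjoints are primitives, and their VJPs are registered
# after the definitions so that higher-order derivatives are available.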
@primitive
def tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim):
# The adjoint of the operator
# A |--> np.tensordot(A, B, axes)
if B_ndim == 0:
return G * B
G_axes = onp.arange(onp.ndim(G))
if type(axes) is int:
axes = max(axes, 0)
B_axes = onp.arange(B_ndim)
return onp.tensordot(G, B, [G_axes[A_ndim-axes:], B_axes[axes:]])
else:
axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
axes = [axes0, axes1]
A_axes = onp.arange(A_ndim)
B_axes = onp.arange(B_ndim)
summed_axes = [onp.asarray(axes[0], dtype='int64') % A_ndim,
onp.asarray(axes[1], dtype='int64') % B_ndim]
other_axes = [onp.delete(A_axes, summed_axes[0]),
onp.delete(B_axes, summed_axes[1])]
out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
perm = onp.argsort(onp.concatenate(
(other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
return onp.transpose(out, perm)
@primitive
def tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim):
# The adjoint of the operator
# B |--> np.tensordot(A, B, axes)
if A_ndim == 0:
return G * A
G_axes = onp.arange(onp.ndim(G))
if type(axes) is int:
axes = max(axes, 0)
A_axes = onp.arange(A_ndim)
return onp.tensordot(A, G, [A_axes[:A_ndim-axes], G_axes[:A_ndim-axes]])
else:
axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
axes = [axes0, axes1]
A_axes = onp.arange(A_ndim)
B_axes = onp.arange(B_ndim)
summed_axes = [onp.asarray(axes[0], dtype='int64') % A_ndim,
onp.asarray(axes[1], dtype='int64') % B_ndim]
other_axes = [onp.delete(A_axes, summed_axes[0]),
onp.delete(B_axes, summed_axes[1])]
out = onp.tensordot(A, G, [other_axes[0], G_axes[:len(other_axes[0])]])
perm = onp.argsort(onp.concatenate(
(summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
return onp.transpose(out, perm)
def tensordot_vjp_0(ans, A, B, axes=2):
A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
return lambda G: match_complex(A, tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim))
def tensordot_vjp_1(ans, A, B, axes=2):
A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
return lambda G: match_complex(B, tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim))
defvjp(anp.tensordot, tensordot_vjp_0, tensordot_vjp_1)
defvjp(tensordot_adjoint_0, lambda ans, B, G, axes, An, Bn: lambda A: match_complex(B, tensordot_adjoint_1(A, G, axes, An, Bn)),
lambda ans, B, G, axes, An, Bn: lambda A: match_complex(G, anp.tensordot(A, B, axes)))
defvjp(tensordot_adjoint_1, lambda ans, A, G, axes, An, Bn: lambda B: match_complex(A, tensordot_adjoint_0(B, G, axes, An, Bn)),
lambda ans, A, G, axes, An, Bn: lambda B: match_complex(G, anp.tensordot(A, B, axes)))
defvjp(anp.outer, lambda ans, a, b : lambda g: match_complex(a, anp.dot(g, b.T)),
lambda ans, a, b : lambda g: match_complex(b, anp.dot(a.T, g)))
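# anp.concatenate_args is called as concatenate_args(axis, *arrays), so argnum 0 is the
# axis and argnums 1..N are the arrays; the VJP for operand `argnum` slices that operand's
# chunk back out of the upstream gradient.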
def grad_concatenate_args(argnum, ans, axis_args, kwargs):
axis, args = axis_args[0], axis_args[1:]
sizes = [anp.shape(a)[axis] for a in args[:argnum]]
start = sum(sizes[:-1])
idxs = [slice(None)] * ans.ndim
idxs[axis] = slice(start, start + sizes[-1])
return lambda g: g[tuple(idxs)]
defvjp_argnum(anp.concatenate_args, grad_concatenate_args)
def wrapped_reshape(x, *args, **kwargs):
    # The reshape method can be called like A.reshape((5,4)) or A.reshape(5,4).
    # anp.reshape only accepts the tuple form, so we wrap it here before attaching
    # it to ArrayBox as its reshape method below.
if isinstance(args[0], int):
return anp.reshape(x, args, **kwargs)
else:
return anp.reshape(x, *args, **kwargs)
setattr(ArrayBox, 'reshape', wrapped_reshape)
def grad_sort(ans, x, axis=-1, kind='quicksort', order=None):
    # TODO: Cast input with np.asanyarray()
if len(x.shape) > 1:
raise NotImplementedError(
"Gradient of sort not implemented for multi-dimensional arrays.")
sort_perm = anp.argsort(x, axis, kind, order)
return lambda g: unpermuter(g, sort_perm)
defvjp(anp.sort, grad_sort)
defvjp(anp.msort, grad_sort) # Until multi-D is allowed, these are the same.
def grad_partition(ans, x, kth, axis=-1, kind='introselect', order=None):
    # TODO: Cast input with np.asanyarray()
if len(x.shape) > 1:
raise NotImplementedError(
"Gradient of partition not implemented for multi-dimensional arrays.")
partition_perm = anp.argpartition(x, kth, axis, kind, order)
return lambda g: unpermuter(g, partition_perm)
defvjp(anp.partition, grad_partition)
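# Invert a permutation: scatter g back to its pre-sort / pre-partition positions.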
def unpermuter(g, permutation):
unsort = anp.zeros(len(permutation), dtype=int)
unsort[permutation] = list(range(len(permutation)))
return g[unsort]
def grad_reshape_list(ans, *arys):
if len(arys) > 1:
raise NotImplementedError("Can't handle multiple arguments yet.")
return lambda g: anp.reshape(g, anp.shape(arys[0]))
defvjp(anp.atleast_1d, grad_reshape_list)
defvjp(anp.atleast_2d, grad_reshape_list)
defvjp(anp.atleast_3d, grad_reshape_list)
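# VJP of einsum with respect to operand `argnum`: contract the upstream gradient (labelled
# with the output subscripts) against all the other operands, using the target operand's
# subscripts as the new output. Both the subscript-string and the interleaved-sublist
# calling conventions are handled.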
def grad_einsum(argnum, ans, operands_, kwargs):
result_meta = anp.metadata(operands_[argnum])
def vjp(g):
operands = operands_
if isinstance(operands[0], string_types): # using "ijk" convention.
in_subs, out_subs, _ = anp.parse_einsum_input(*operands)
string, operands = operands[0], operands[1:]
in_subs_list = in_subs.split(',')
op_num = argnum - 1
subs_wrt = in_subs_list[op_num]
rest_of_ops = operands[:op_num] + operands[op_num+1:]
rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num+1:]
# subscripts that only appear in subs_wrt (and not in other subscript lists
# or in the output) are implicitly being summed out, as if contracted
# against a tensor of ones. we make that tensor of ones explicit to handle
# the necessary vjp broadcasting inside einsum.
other_named_subs = set(''.join([out_subs] + rest_of_subs))
naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt)
if sub not in other_named_subs]
if naked_summed:
naked_summed_dims, ones_subs = zip(*naked_summed)
ones_subs = ''.join(ones_subs)
ones = onp.ones(onp.array(operands[op_num].shape)[list(naked_summed_dims)])
new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs)
new_operands = (g, ones) + rest_of_ops
else:
new_input_subs = ','.join([out_subs] + rest_of_subs)
new_operands = (g,) + rest_of_ops
new_subscripts = new_input_subs + '->' + subs_wrt
return unbroadcast(anp.einsum(new_subscripts, *new_operands), result_meta)
else: # using (op0, sublist0, op1, sublist1, ..., sublistout) convention
if len(operands) % 2 == 0:
raise NotImplementedError("Need sublistout argument")
operands = list(operands)
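            # Pair g with the original output sublist, keep every other (operand, sublist)
            # pair, and use the target operand's sublist as the new output sublist.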
rest_of_ops = [operands[-1]] + operands[:argnum] + \
operands[(argnum+2):-1] + [operands[argnum+1]]
return unbroadcast_einsum(anp.einsum(g, *rest_of_ops), result_meta, operands[argnum + 1])
return vjp
defvjp_argnum(anp.einsum, grad_einsum)
defvjp(anp.diagonal,
lambda ans, A, offset=0, axis1=0, axis2=1 :
lambda g: anp.make_diagonal(g, offset, axis1, axis2))
defvjp(anp.make_diagonal,
lambda ans, D, offset=0, axis1=0, axis2=1 :
lambda g: anp.diagonal(g, offset, axis1, axis2))
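# match_complex projects a gradient back onto the domain of its target: if the target is
# real but the gradient is complex, take the real part; if the target is complex but the
# gradient is real, promote it to complex.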
def match_complex(target, x):
target_iscomplex = anp.iscomplexobj(target)
x_iscomplex = anp.iscomplexobj(x)
if x_iscomplex and not target_iscomplex:
return anp.real(x)
elif not x_iscomplex and target_iscomplex:
return x + 0j
else:
return x
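# unbroadcast reverses NumPy broadcasting: sum the gradient over the prepended dimensions
# and over every axis where the target had size 1, then drop any imaginary part if the
# target was real.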
def unbroadcast(x, target_meta, broadcast_idx=0):
target_shape, target_ndim, dtype, target_iscomplex = target_meta
while anp.ndim(x) > target_ndim:
x = anp.sum(x, axis=broadcast_idx)
for axis, size in enumerate(target_shape):
if size == 1:
x = anp.sum(x, axis=axis, keepdims=True)
if anp.iscomplexobj(x) and not target_iscomplex:
x = anp.real(x)
return x
def unbroadcast_f(target, f):
target_meta = anp.metadata(target)
return lambda g: unbroadcast(f(g), target_meta)
def unbroadcast_einsum(x, target_meta, subscript):
if Ellipsis not in subscript:
return x
elif subscript[0] == Ellipsis:
return unbroadcast(x, target_meta, 0)
elif subscript[-1] == Ellipsis:
return unbroadcast(x, target_meta, -1)
else:
return unbroadcast(x, target_meta, subscript.index(Ellipsis))
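# balanced_eq(x, z, y) is 1 where x equals the chosen value z, halved where y ties with x,
# so the min/max/fmin/fmax VJPs above split the gradient evenly between tied arguments.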
def balanced_eq(x, z, y):
return (x == z) / (1.0 + (x == y))
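# replace_zero(x, val) substitutes val wherever x == 0; it is used above to avoid 0/0 in
# the abs VJP and log(0) in the power VJP.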
def replace_zero(x, val):
return anp.where(x, x, val)
# ----- extra functions used internally -----
def array_from_args_gradmaker(argnum, ans, args, kwargs):
return lambda g: g[argnum-2]
defvjp_argnum(anp.array_from_args, array_from_args_gradmaker)
def array_from_scalar_or_array_gradmaker(ans, array_args, array_kwargs, scarray):
ndmin = array_kwargs.get('ndmin', 0)
scarray_ndim = anp.ndim(scarray)
if ndmin > scarray_ndim:
return lambda g: anp.squeeze(g, axis=tuple(range(ndmin - scarray_ndim)))
else:
return lambda g: g
defvjp(anp._array_from_scalar_or_array, array_from_scalar_or_array_gradmaker, argnums=(2,3))
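# untake is the adjoint of indexing: it returns a SparseObject whose mut_add scatters x
# into a zero array at idx via np.add.at, so repeated indices accumulate correctly in the
# VJP of ArrayBox.__getitem__.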
@primitive
def untake(x, idx, vs):
if isinstance(idx, list) and (len(idx) == 0 or not isinstance(idx[0], slice)):
idx = onp.array(idx, dtype='int64')
def mut_add(A):
onp.add.at(A, idx, x)
return A
return SparseObject(vs, mut_add)
defvjp(func(ArrayBox.__getitem__), lambda ans, A, idx: lambda g: untake(g, idx, vspace(A)))
defvjp(untake, lambda ans, x, idx, _: lambda g: g[idx])
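# VJP of np.pad (constant mode only): _unpad normalizes pad_width and slices the padded
# border back off the upstream gradient.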
def _unpad(array, width):
if anp.isscalar(width):
width = [[width, width]]
elif anp.shape(width) == (1,):
width = [anp.concatenate((width, width))]
elif anp.shape(width) == (2,):
width = [width]
if anp.shape(width)[0] == 1:
width = anp.repeat(width, anp.ndim(array), 0)
idxs = tuple(slice(l, -u or None) for l, u in width)
return array[idxs]
def pad_vjp(ans, array, pad_width, mode, **kwargs):
assert mode == "constant", "Only constant mode padding is supported."
return lambda g: _unpad(g, pad_width)
defvjp(anp.pad, pad_vjp)