from __future__ import absolute_import
from builtins import range, zip
from functools import partial
import autograd.numpy as np
import numpy as npo # original numpy
from autograd.extend import primitive, defvjp
from numpy.lib.stride_tricks import as_strided
from future.utils import iteritems
@primitive
def convolve(A, B, axes=None, dot_axes=[(),()], mode='full'):
    """N-dimensional convolution of A with B, with optional tensor-dot axes.

    Parameters
    ----------
    A, B : arrays to convolve.
    axes : pair [axes_of_A, axes_of_B] of dimension lists to convolve along.
        Defaults to all dimensions of A for both operands (so the default
        presumably assumes A.ndim == B.ndim -- TODO confirm with callers).
    dot_axes : pair of axis lists contracted as in a tensor dot product.
        NOTE(review): mutable default argument; harmless here only because
        it is never mutated in this function.
    mode : 'full' or 'valid' (as in numpy.convolve); 'same' not implemented.

    Returns an array laid out as (A's ignored axes, B's ignored axes,
    convolved axes) -- see the transpose in the swapped branch below.
    """
    assert mode in ['valid', 'full'], "Mode {0} not yet implemented".format(mode)
    if axes is None:
        axes = [list(range(A.ndim)), list(range(A.ndim))]
    # The sliding-window view below requires B to be at least as large as A
    # along every convolved axis; otherwise swap the operands.
    wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)])
    if wrong_order:
        if mode=='valid' and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]):
            raise Exception("One array must be larger than the other along all convolved dimensions")
        elif mode != 'full' or B.size <= A.size: # Tie breaker
            # Recurse with the operands (and their axis specs) swapped, then
            # transpose the result back into the (ignore_A, ignore_B, conv)
            # axis order the un-swapped call would have produced.
            i1 = B.ndim - len(dot_axes[1]) - len(axes[1]) # B ignore
            i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0]) # A ignore
            i3 = i2 + len(axes[0])
            ignore_B = list(range(i1))
            ignore_A = list(range(i1, i2))
            conv = list(range(i2, i3))
            return convolve(B, A, axes=axes[::-1], dot_axes=dot_axes[::-1], mode=mode).transpose(ignore_A + ignore_B + conv)
    if mode == 'full':
        # Zero-pad B so the 'valid' sliding-window pass below produces the
        # 'full' convolution.
        B = pad_to_full(B, A, axes[::-1])
    # Build a zero-copy strided view of B containing every A-sized sliding
    # window, turning the convolution into one big tensor dot product.
    B_view_shape = list(B.shape)
    B_view_strides = list(B.strides)
    flipped_idxs = [slice(None)] * A.ndim
    for ax_A, ax_B in zip(*axes):
        # One new trailing dimension per convolved axis: the window positions.
        B_view_shape.append(abs(B.shape[ax_B] - A.shape[ax_A]) + 1)
        B_view_strides.append(B.strides[ax_B])
        B_view_shape[ax_B] = A.shape[ax_A]
        # Flip A along each convolved axis: convolution, not correlation.
        flipped_idxs[ax_A] = slice(None, None, -1)
    B_view = as_strided(B, B_view_shape, B_view_strides)
    A_view = A[tuple(flipped_idxs)]
    all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]]
    return einsum_tensordot(A_view, B_view, all_axes)
def einsum_tensordot(A, B, axes, reverse=False):
    """Tensor-dot product of A and B over paired axes, via einsum.

    Using einsum (rather than tensordot) avoids copying the operands.
    `axes` is a pair of equal-length axis sequences; axes at matching
    positions are contracted.  `reverse` is accepted for interface
    compatibility but unused.
    """
    n_A, n_B = A.ndim, B.ndim
    # Start with a distinct einsum subscript for every axis of A and of B.
    subs_A = list(range(n_A))
    subs_B = list(range(n_A, n_A + n_B))
    # Give each contracted pair a fresh shared subscript so einsum sums it out.
    shared = n_A + n_B
    for ax_A, ax_B in zip(*axes):
        subs_A[ax_A] = shared
        subs_B[ax_B] = shared
        shared += 1
    return npo.einsum(A, subs_A, B, subs_B)
def pad_to_full(A, B, axes):
    """Zero-pad A by (B's size - 1) on both sides of each convolved axis.

    A subsequent 'valid'-mode sliding-window pass over the padded array
    then yields the 'full' convolution.
    """
    pad_widths = [(0, 0) for _ in range(A.ndim)]
    for ax_A, ax_B in zip(*axes):
        margin = B.shape[ax_B] - 1
        pad_widths[ax_A] = (margin, margin)
    return npo.pad(A, pad_widths, mode='constant')
def parse_axes(A_shape, B_shape, conv_axes, dot_axes, mode):
    """Classify each axis of A and B as convolved / dotted / ignored and
    compute the corresponding output axis layout and sizes.

    Parameters
    ----------
    A_shape, B_shape : shapes of the two operands.
    conv_axes : pair (axes_of_A, axes_of_B) to convolve, or None for all
        axes of each operand (the None default uses A's ndim for both,
        presumably assuming equal ndim -- TODO confirm).
    dot_axes : pair of axis tuples contracted as in a tensor dot product.
    mode : 'full', 'same' or 'valid'; sets the convolved output sizes.

    Returns
    -------
    axes : dict with keys 'A' and 'B' (each mapping 'conv'/'dot'/'ignore'
        to axis tuples) and 'out' (positions of A's ignored, B's ignored
        and the convolved axes in the output array).
    shapes : dict of the matching size tuples, same key structure.

    Fix: sizes are now materialized as tuples.  The previous implementation
    stored one-shot generator expressions, and shapes['out']['ignore_A'] /
    ['ignore_B'] aliased the very same generator objects as
    shapes['A']['ignore'] / shapes['B']['ignore'] -- iterating any of them
    once silently exhausted its alias too.  Tuples are reusable and remain
    backward-compatible for iteration.
    """
    A_ndim, B_ndim = len(A_shape), len(B_shape)
    if conv_axes is None:
        conv_axes = (tuple(range(A_ndim)), tuple(range(A_ndim)),)
    axes = {'A' : {'conv' : tuple(conv_axes[0]),
                   'dot' : tuple(dot_axes[0]),
                   'ignore' : tuple(i for i in range(A_ndim)
                                    if i not in conv_axes[0] and i not in dot_axes[0])},
            'B' : {'conv' : tuple(conv_axes[1]),
                   'dot' : tuple(dot_axes[1]),
                   'ignore' : tuple(i for i in range(B_ndim)
                                    if i not in conv_axes[1] and i not in dot_axes[1])}}
    assert len(axes['A']['dot']) == len(axes['B']['dot'])
    assert len(axes['A']['conv']) == len(axes['B']['conv'])
    # Output axis order: A's ignored axes, then B's ignored axes, then the
    # convolved axes (matches the transpose in convolve's swapped branch).
    i1 = len(axes['A']['ignore'])
    i2 = i1 + len(axes['B']['ignore'])
    i3 = i2 + len(axes['A']['conv'])
    axes['out'] = {'ignore_A' : tuple(range(i1)),
                   'ignore_B' : tuple(range(i1, i2)),
                   'conv' : tuple(range(i2, i3))}
    conv_shape = tuple(compute_conv_size(A_shape[i], B_shape[j], mode)
                       for i, j in zip(axes['A']['conv'], axes['B']['conv']))
    shapes = {'A' : {s : tuple(A_shape[i] for i in ax) for s, ax in axes['A'].items()},
              'B' : {s : tuple(B_shape[i] for i in ax) for s, ax in axes['B'].items()}}
    shapes['out'] = {'ignore_A' : shapes['A']['ignore'],
                     'ignore_B' : shapes['B']['ignore'],
                     'conv' : conv_shape}
    return axes, shapes
def compute_conv_size(A_size, B_size, mode):
    """Length of the convolution output along one axis for the given mode."""
    if mode == 'full':
        return A_size + B_size - 1
    if mode == 'same':
        return A_size
    if mode == 'valid':
        return abs(A_size - B_size) + 1
    raise Exception("Mode {0} not recognized".format(mode))
def flipped_idxs(ndim, axes):
    """Index tuple reversing the listed axes and passing the rest through.

    Indexing an ndim-dimensional array with the result flips it along
    every axis in `axes` and leaves the other axes untouched.
    """
    idxs = [slice(None) for _ in range(ndim)]
    for ax in axes:
        idxs[ax] = slice(None, None, -1)
    return tuple(idxs)
def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(),()], mode='full'):
    """VJP maker for `convolve` w.r.t. argument `argnum` (0 for A, 1 for B).

    The gradient of a convolution is itself a convolution: the incoming
    cotangent `g` is convolved with the *other* operand (flipped along the
    convolved axes), then transposed back into the axis layout of the
    argument being differentiated.
    NOTE(review): dot_axes has a mutable default; never mutated here.
    """
    assert mode in ['valid', 'full'], "Grad for mode {0} not yet implemented".format(mode)
    axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode)
    # X is the argument we differentiate with respect to; Y is the other one.
    if argnum == 0:
        X, Y = A, B
        _X_, _Y_ = 'A', 'B'
        ignore_Y = 'ignore_B'
    elif argnum == 1:
        X, Y = B, A
        _X_, _Y_ = 'B', 'A'
        ignore_Y = 'ignore_A'
    else:
        raise NotImplementedError("Can't take grad of convolve w.r.t. arg {0}".format(argnum))
    # The adjoint of a 'full' convolution is 'valid'; for a forward 'valid'
    # convolution it depends on which operand was larger along the conv axes.
    if mode == 'full':
        new_mode = 'valid'
    else:
        if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]['conv'], shapes[_Y_]['conv'])]):
            new_mode = 'full'
        else:
            new_mode = 'valid'
    def vjp(g):
        # Convolve g with the flipped other operand; g's axes that came from
        # Y's ignored dimensions are contracted against those dimensions of Y.
        result = convolve(g, Y[flipped_idxs(Y.ndim, axes[_Y_]['conv'])],
                          axes = [axes['out']['conv'], axes[_Y_]['conv']],
                          dot_axes = [axes['out'][ignore_Y], axes[_Y_]['ignore']],
                          mode = new_mode)
        # convolve returned (ignore, dot, conv) order; argsort gives the
        # permutation restoring X's original axis order.
        new_order = npo.argsort(axes[_X_]['ignore'] + axes[_X_]['dot'] + axes[_X_]['conv'])
        return np.transpose(result, new_order)
    return vjp
# Register the VJPs with autograd: argnum 0 differentiates w.r.t. A,
# argnum 1 w.r.t. B.
defvjp(convolve, partial(grad_convolve, 0), partial(grad_convolve, 1))