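"""N-dimensional convolution as an autograd primitive.

convolve distinguishes convolved axes, tensor-dot axes, and leftover
('ignore') axes that pass straight through to the output; vector-Jacobian
products for both arguments are registered via defvjp at the bottom."""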
from __future__ import absolute_import
from builtins import range, zip
from functools import partial
import autograd.numpy as np
import numpy as npo # original numpy
from autograd.extend import primitive, defvjp

from numpy.lib.stride_tricks import as_strided
from future.utils import iteritems

@primitive
def convolve(A, B, axes=None, dot_axes=[(),()], mode='full'):
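    """Differentiable N-dimensional convolution.

    axes     -- pair (axes_A, axes_B) of axis lists convolved against each
                other; defaults to all axes of A against the like-numbered
                axes of B.
    dot_axes -- pair of axis lists summed over (tensor-dot) rather than
                convolved.
    mode     -- 'full' or 'valid'.

    Axes appearing in neither list are kept ('ignore' axes); the output is
    ordered as A's ignore axes, then B's ignore axes, then one axis per
    convolved pair.
    """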
    assert mode in ['valid', 'full'], "Mode {0} not yet implemented".format(mode)
    if axes is None:
        axes = [list(range(A.ndim)), list(range(A.ndim))]
    wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)])
    if wrong_order:
        if mode == 'valid' and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]):
            raise Exception("One array must be larger than the other along all convolved dimensions")
        elif mode != 'full' or B.size <= A.size: # Tie breaker
            i1 =      B.ndim - len(dot_axes[1]) - len(axes[1]) # B ignore
            i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0]) # A ignore
            i3 = i2 + len(axes[0])
            ignore_B = list(range(i1))
            ignore_A = list(range(i1, i2))
            conv     = list(range(i2, i3))
            return convolve(B, A, axes=axes[::-1], dot_axes=dot_axes[::-1], mode=mode).transpose(ignore_A + ignore_B + conv)

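    # Strategy: for 'full' mode, zero-pad B so a sliding-window pass of A's
    # size covers the whole output. Then take a strided view of B with one
    # extra window axis per convolved pair; the convolution reduces to a
    # tensor dot of those windows against A flipped along the conv axes.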
    if mode == 'full':
        B = pad_to_full(B, A, axes[::-1])
    B_view_shape = list(B.shape)
    B_view_strides = list(B.strides)
    flipped = [slice(None)] * A.ndim  # avoid shadowing the flipped_idxs helper below
    for ax_A, ax_B in zip(*axes):
        B_view_shape.append(abs(B.shape[ax_B] - A.shape[ax_A]) + 1)
        B_view_strides.append(B.strides[ax_B])
        B_view_shape[ax_B] = A.shape[ax_A]
        flipped[ax_A] = slice(None, None, -1)
    B_view = as_strided(B, B_view_shape, B_view_strides)
    A_view = A[tuple(flipped)]
    all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]]
    return einsum_tensordot(A_view, B_view, all_axes)

def einsum_tensordot(A, B, axes):
    # Tensor dot product via einsum's integer-subscript form, which
    # (unlike np.tensordot) shouldn't require copying its inputs.
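    # Shared integer labels mark summed axes; e.g. for 2-D inputs,
    # npo.einsum(A, [0, 2], B, [2, 1]) computes the matrix product A @ B.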
    A_axnums = list(range(A.ndim))
    B_axnums = list(range(A.ndim, A.ndim + B.ndim))
    sum_axnum = A.ndim + B.ndim
    for i_sum, (i_A, i_B) in enumerate(zip(*axes)):
        A_axnums[i_A] = sum_axnum + i_sum
        B_axnums[i_B] = sum_axnum + i_sum
    return npo.einsum(A, A_axnums, B, B_axnums)

def pad_to_full(A, B, axes):
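    """Zero-pad A along each convolved axis by (B's size - 1) on both sides,
    so that a 'valid' convolution of the result equals a 'full' convolution
    of the originals."""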
    A_pad = [(0, 0)] * A.ndim
    for ax_A, ax_B in zip(*axes):
        A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2
    return npo.pad(A, A_pad, mode='constant')

def parse_axes(A_shape, B_shape, conv_axes, dot_axes, mode):
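    """Classify each axis of A and B as 'conv', 'dot', or 'ignore', and work
    out where each group lands in the output. Returns (axes, shapes): nested
    dicts mapping 'A', 'B', and 'out' to axis numbers and their sizes."""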
    A_ndim, B_ndim = len(A_shape), len(B_shape)
    if conv_axes is None:
        conv_axes = (tuple(range(A_ndim)), tuple(range(A_ndim)),)
    axes = {'A' : {'conv' : tuple(conv_axes[0]),
                   'dot'  : tuple(dot_axes[0]),
                   'ignore' : tuple(i for i in range(A_ndim)
                             if i not in conv_axes[0] and i not in dot_axes[0])},
            'B' : {'conv' : tuple(conv_axes[1]),
                   'dot'  : tuple(dot_axes[1]),
                   'ignore' : tuple(i for i in range(B_ndim)
                               if i not in conv_axes[1] and i not in dot_axes[1])}}
    assert len(axes['A']['dot'])  == len(axes['B']['dot'])
    assert len(axes['A']['conv']) == len(axes['B']['conv'])
    i1 =      len(axes['A']['ignore'])
    i2 = i1 + len(axes['B']['ignore'])
    i3 = i2 + len(axes['A']['conv'])
    axes['out'] = {'ignore_A' : tuple(range(i1)),
                   'ignore_B' : tuple(range(i1, i2)),
                   'conv'     : tuple(range(i2, i3))}
    # Materialize sizes as tuples (a generator here could only be consumed once).
    conv_shape = tuple(compute_conv_size(A_shape[i], B_shape[j], mode)
                       for i, j in zip(axes['A']['conv'], axes['B']['conv']))
    shapes = {'A'   : {s : tuple(A_shape[i] for i in ax) for s, ax in iteritems(axes['A'])},
              'B'   : {s : tuple(B_shape[i] for i in ax) for s, ax in iteritems(axes['B'])}}
    shapes['out'] = {'ignore_A' : shapes['A']['ignore'],
                     'ignore_B' : shapes['B']['ignore'],
                     'conv'     : conv_shape}
    return axes, shapes

def compute_conv_size(A_size, B_size, mode):
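    """Output length along a single convolved axis pair for the given mode."""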
    if mode == 'full':
        return A_size + B_size - 1
    elif mode == 'same':
        return A_size
    elif mode == 'valid':
        return abs(A_size - B_size) + 1
    else:
        raise Exception("Mode {0} not recognized".format(mode))

def flipped_idxs(ndim, axes):
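    """Index tuple that reverses the listed axes and leaves the rest alone."""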
    new_idxs = [slice(None)] * ndim
    for ax in axes:
        new_idxs[ax] = slice(None, None, -1)
    return tuple(new_idxs)

def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(),()], mode='full'):
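    """Build the VJP of convolve with respect to argument number argnum.
    The gradient of a convolution is itself a convolution: the cotangent g
    is convolved with the *other* input, flipped along the convolved axes,
    with 'full' and 'valid' modes exchanged where the shapes require it."""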
    assert mode in ['valid', 'full'], "Grad for mode {0} not yet implemented".format(mode)
    axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode)
    if argnum == 0:
        X, Y = A, B
        _X_, _Y_ = 'A', 'B'
        ignore_Y = 'ignore_B'
    elif argnum == 1:
        X, Y = B, A
        _X_, _Y_ = 'B', 'A'
        ignore_Y = 'ignore_A'
    else:
        raise NotImplementedError("Can't take grad of convolve w.r.t. arg {0}".format(argnum))

    if mode == 'full':
        new_mode = 'valid'
    else:
        if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]['conv'], shapes[_Y_]['conv'])]):
            new_mode = 'full'
        else:
            new_mode = 'valid'

    def vjp(g):
        result = convolve(g, Y[flipped_idxs(Y.ndim, axes[_Y_]['conv'])],
                          axes     = [axes['out']['conv'],   axes[_Y_]['conv']],
                          dot_axes = [axes['out'][ignore_Y], axes[_Y_]['ignore']],
                          mode     = new_mode)
        new_order = npo.argsort(axes[_X_]['ignore'] + axes[_X_]['dot'] + axes[_X_]['conv'])
        return np.transpose(result, new_order)
    return vjp

defvjp(convolve, partial(grad_convolve, 0), partial(grad_convolve, 1))
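
# Minimal self-check sketch, not part of the original module: for 1-D inputs
# with default axes this primitive should agree with numpy.convolve, and its
# registered VJP should be usable through autograd.grad.
if __name__ == '__main__':
    from autograd import grad
    a = npo.random.randn(5)
    b = npo.random.randn(3)
    assert npo.allclose(convolve(a, b, mode='full'),  npo.convolve(a, b, mode='full'))
    assert npo.allclose(convolve(a, b, mode='valid'), npo.convolve(a, b, mode='valid'))
    # Differentiate a scalar function of the convolution w.r.t. its first argument.
    g = grad(lambda x: np.sum(convolve(x, b, mode='full')))(a)
    assert g.shape == a.shape
    print("convolve matches numpy.convolve and is differentiable")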