Blob Blame History Raw
from __future__ import absolute_import
from builtins import zip
import numpy.fft as ffto
from .numpy_wrapper import wrap_namespace
from .numpy_vjps import match_complex
from . import numpy_wrapper as anp
from autograd.extend import primitive, defvjp, vspace

# Hand numpy.fft's namespace dict and this module's globals to wrap_namespace.
# NOTE(review): the bare names used below (fft, ifft, rfft, fftshift, ...) are
# not defined anywhere in this file, so they presumably are injected into
# globals() by this call -- confirm against numpy_wrapper.wrap_namespace.
wrap_namespace(ffto.__dict__, globals())

# TODO: make fft gradient work for a repeated axis,
# e.g. by replacing fftn with repeated calls to 1d fft along each axis
def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
    """Build the VJP of a complex FFT-style primitive.

    get_args -- parser called as get_args(x, *args, **kwargs); it must return
                a 3-tuple (axes, s, norm).  Only `axes` is used here.
    fft_fun  -- transform applied to the incoming cotangent, with the same
                extra positional/keyword arguments as the original call.
    ans, x   -- output and input of the original call (`ans` is unused).

    Raises NotImplementedError (via check_no_repeated_axes) when `axes`
    contains a repeated entry.

    Returns a function g -> cotangent for x: fft_fun(g) is truncated/padded to
    x's shape and then passed through match_complex together with x.
    """
    axes, _s, _norm = get_args(x, *args, **kwargs)
    check_no_repeated_axes(axes)
    x_vs = vspace(x)

    def vjp(g):
        transformed = fft_fun(g, *args, **kwargs)
        return match_complex(x, truncate_pad(transformed, x_vs.shape))
    return vjp

# VJP registrations for the complex transforms.
#
# Each transform is paired with the argument parser that matches its own call
# signature, so that check_no_repeated_axes (called inside fft_grad) sees the
# axes the caller actually passed:
#   * fft / ifft     -> get_fft_args   (length, single axis)
#   * fft2 / ifft2   -> get_fft2_args  (s, axes; axes defaults to (-2, -1))
#   * fftn / ifftn   -> get_fftn_args  (s, axes; axes=None means every axis)
# Using the 1-D parser for the 2-D / n-D transforms would misread `axes`
# (it would land in the `axis` slot or in **kwargs) and the repeated-axes
# guard would never trigger for them.
defvjp(fft, lambda *args, **kwargs:
        fft_grad(get_fft_args, fft, *args, **kwargs))
defvjp(ifft, lambda *args, **kwargs:
        fft_grad(get_fft_args, ifft, *args, **kwargs))

defvjp(fft2, lambda *args, **kwargs:
        fft_grad(get_fft2_args, fft2, *args, **kwargs))
defvjp(ifft2, lambda *args, **kwargs:
        fft_grad(get_fft2_args, ifft2, *args, **kwargs))

defvjp(fftn, lambda *args, **kwargs:
        fft_grad(get_fftn_args, fftn, *args, **kwargs))
defvjp(ifftn, lambda *args, **kwargs:
        fft_grad(get_fftn_args, ifftn, *args, **kwargs))

def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
    """Build the VJP of a real-input FFT primitive (rfft / rfft2 / rfftn).

    get_args  -- parser called as get_args(x, *args, **kwargs); returns
                 (axes, s, norm) where `s` is the full transform shape or None.
    irfft_fun -- inverse real transform applied to the rescaled cotangent,
                 with the same extra arguments as the original call.
    ans, x    -- output and input of the original call.

    Raises NotImplementedError (via the check_* helpers) when `axes` has a
    repeated entry or when the last transform length is odd.

    Returns a function g -> cotangent for x.
    """
    axes, s, norm = get_args(x, *args, **kwargs)
    in_vs = vspace(x)
    out_vs = vspace(ans)
    check_no_repeated_axes(axes)
    if s is None:
        # No explicit size given: use x's extent along each transformed axis.
        s = [in_vs.shape[ax] for ax in axes]
    check_even_shape(s)

    # `s` is the full fft shape; `half_shape` is the compressed shape, whose
    # last entry is s[-1] // 2 + 1.
    half_shape = list(s)
    half_shape[-1] = half_shape[-1] // 2 + 1
    fac = make_rfft_factors(axes, out_vs.shape, half_shape, s, norm)

    def vjp(g):
        rescaled = anp.conj(g / fac)
        back = irfft_fun(rescaled, *args, **kwargs)
        return match_complex(x, truncate_pad(back, in_vs.shape))
    return vjp

def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
    """Build the VJP of an inverse real FFT primitive (irfft / irfft2 / irfftn).

    get_args -- parser called as get_args(x, *args, **kwargs); returns
                (axes, gs, norm) where `gs` is the full (output) transform
                shape or None.
    rfft_fun -- forward real transform applied to the incoming cotangent,
                with the same extra arguments as the original call.
    ans, x   -- output and input of the original call.

    Raises NotImplementedError (via the check_* helpers) when `axes` has a
    repeated entry or when the last transform length is odd.

    Returns a function g -> cotangent for x.
    """
    axes, gs, norm = get_args(x, *args, **kwargs)
    in_vs = vspace(x)
    out_vs = vspace(ans)
    check_no_repeated_axes(axes)
    if gs is None:
        # No explicit size given: use the output's extent along each axis.
        gs = [out_vs.shape[ax] for ax in axes]
    check_even_shape(gs)

    # `gs` is the full fft shape; `s` is the compressed shape, whose last
    # entry is gs[-1] // 2 + 1.
    s = list(gs)
    s[-1] = s[-1] // 2 + 1

    def vjp(g):
        spectrum = rfft_fun(g, *args, **kwargs)
        matched = match_complex(x, truncate_pad(spectrum, in_vs.shape))
        fac = make_rfft_factors(axes, in_vs.shape, s, gs, norm)
        return anp.conj(matched) * fac
    return vjp

# VJP registrations for the real-input transforms and their inverses.
# Each forward transform (rfft, rfft2, rfftn) gets its VJP from rfft_grad,
# which is handed the matching inverse transform; each inverse (irfft, irfft2,
# irfftn) gets its VJP from irfft_grad, which is handed the matching forward
# transform.  The first argument selects the argument parser for that
# signature: get_fft_args (1-D), get_fft2_args (2-D), get_fftn_args (n-D).
defvjp(rfft, lambda *args, **kwargs:
        rfft_grad(get_fft_args, irfft, *args, **kwargs))

defvjp(irfft, lambda *args, **kwargs:
        irfft_grad(get_fft_args, rfft, *args, **kwargs))

defvjp(rfft2, lambda *args, **kwargs:
        rfft_grad(get_fft2_args, irfft2, *args, **kwargs))

defvjp(irfft2, lambda *args, **kwargs:
        irfft_grad(get_fft2_args, rfft2, *args, **kwargs))

defvjp(rfftn, lambda *args, **kwargs:
        rfft_grad(get_fftn_args, irfftn, *args, **kwargs))

defvjp(irfftn, lambda *args, **kwargs:
        irfft_grad(get_fftn_args, rfftn, *args, **kwargs))

# VJPs of the shift helpers: the cotangent of fftshift is routed through
# ifftshift (and vice versa) using the same `axes`.  The cotangent is
# conjugated before the shift and the result is conjugated again afterwards,
# then passed through match_complex together with x.
# NOTE(review): if the shift only reorders elements, the two conj calls cancel
# and could be dropped -- confirm before simplifying.
defvjp(fftshift,  lambda ans, x, axes=None : lambda g:
                 match_complex(x, anp.conj(ifftshift(anp.conj(g), axes))))
defvjp(ifftshift, lambda ans, x, axes=None : lambda g:
                 match_complex(x, anp.conj(fftshift(anp.conj(g), axes))))

@primitive
def truncate_pad(x, shape):
    """Return x truncated and/or zero-padded at the end of each axis so that
    the result has the requested `shape`.

    Axes where x is shorter than `shape` are padded with constant values after
    the existing data; axes where x is longer are cut by keeping the leading
    entries.
    """
    keep = tuple([slice(n) for n in shape])
    # Never pad in front; pad behind by however much x falls short (>= 0).
    leading = anp.zeros(len(shape), dtype=int)
    trailing = anp.maximum(0, anp.array(shape) - anp.array(x.shape))
    pad_widths = tuple(zip(leading, trailing))
    padded = anp.pad(x, pad_widths, 'constant')
    return padded[keep]
# The VJP of truncate_pad is truncate_pad itself, applied to the cotangent
# with x's own shape as the target.
defvjp(truncate_pad,
       lambda ans, x, shape: lambda g: match_complex(x, truncate_pad(g, vspace(x).shape)))

## TODO: could be made less stringent, to fail only when repeated axis has different values of s
def check_no_repeated_axes(axes):
    """Raise NotImplementedError if any entry of `axes` occurs more than once.

    Returns None when all entries are distinct.
    """
    distinct = set(axes)
    if len(distinct) == len(axes):
        return
    raise NotImplementedError("FFT gradient for repeated axes not implemented.")

def check_even_shape(shape):
    """Raise NotImplementedError if the last entry of `shape` is odd.

    Only shape[-1] is inspected; returns None when it is even.
    """
    last_len = shape[-1]
    if last_len % 2 == 0:
        return
    raise NotImplementedError("Real FFT gradient for odd lengthed last axes is not implemented.")

def get_fft_args(a, d=None, axis=-1, norm=None, *args, **kwargs):
    """Parse the arguments of a 1-D transform call into (axes, s, norm).

    a    -- the array argument of the call (not inspected here).
    d    -- transform length when passed positionally (or as `d=`).
    axis -- the single axis being transformed.
    norm -- normalization mode, passed through unchanged.

    numpy.fft's 1-D transforms name the length parameter `n`, so a call such
    as rfft(x, n=8) delivers it through **kwargs rather than through `d`.
    It is picked up from there so the explicit length is not silently lost.

    Returns (axes, s, norm) where axes is [axis] and s is [length] or None.
    """
    if d is None:
        # Length given by keyword under numpy's own parameter name.
        d = kwargs.get('n')
    axes = [axis]
    s = None if d is None else [d]
    return axes, s, norm

def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
    """Parse the arguments of a 2-D transform call into (axes, s, norm).

    The values are returned exactly as received; `a` and any extra positional
    or keyword arguments are ignored.
    """
    parsed = (axes, s, norm)
    return parsed

def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
    """Parse the arguments of an n-D transform call into (axes, s, norm).

    When `axes` is None it is replaced by the list of all axes of `a`
    (0 .. a.ndim - 1); otherwise it is returned as given.  `s` and `norm` are
    passed through unchanged and extra arguments are ignored.
    """
    resolved_axes = list(range(a.ndim)) if axes is None else axes
    return resolved_axes, s, norm

def make_rfft_factors(axes, resshape, facshape, normshape, norm):
    """ make the compression factors and compute the normalization
        for irfft and rfft.

        axes      -- transformed axes; only axes[-1] (the compressed axis) is used.
        resshape  -- shape of the returned factor array.
        facshape  -- compressed fft shape; only facshape[-1] is used.
        normshape -- full fft shape; the product of its entries is N.
        norm      -- normalization mode of the transform: None or "backward"
                     (factors divided by N), "forward" (factors multiplied
                     by N), anything else, e.g. "ortho" (factors left as is).

        Returns an array of shape `resshape` that is 2 everywhere except for
        1 at index 0 of the compressed axis and, when it lies inside
        `resshape`, at index facshape[-1] - 1 of that axis; the whole array is
        then scaled according to `norm`.
    """
    N = 1.0
    for n in normshape: N = N * n

    # inplace modification is fine because we produce a constant
    # which doesn't go into autograd.
    # For same reason could have used numpy rather than anp.
    # but we already imported anp, so use it instead.
    fac = anp.zeros(resshape)
    fac[...] = 2
    index = [slice(None)] * len(resshape)
    if facshape[-1] <= resshape[axes[-1]]:
        # Both ends of the half spectrum are present in the result.
        index[axes[-1]] = (0, facshape[-1] - 1)
    else:
        # The result is truncated before the last half-spectrum bin.
        index[axes[-1]] = (0,)
    fac[tuple(index)] = 1
    if norm is None or norm == "backward":
        # "backward" is numpy's explicit spelling of the default mode:
        # the inverse transform carries the 1/N scaling.
        fac /= N
    elif norm == "forward":
        # The forward transform carries the 1/N scaling instead, so the
        # factor has to move the other way relative to the unscaled case.
        fac *= N
    return fac