from __future__ import absolute_import
from builtins import zip
import numpy.fft as ffto
from .numpy_wrapper import wrap_namespace
from .numpy_vjps import match_complex
from . import numpy_wrapper as anp
from autograd.extend import primitive, defvjp, vspace
# Populate this module's namespace from numpy.fft via wrap_namespace.
# The transforms referenced below (fft, ifft, rfft, irfft, fftshift, ...) are
# not defined anywhere else in this file, so they come from here. Presumably
# wrap_namespace turns them into autograd primitives so that defvjp can attach
# gradients to them -- NOTE(review): confirm against numpy_wrapper.
wrap_namespace(ffto.__dict__, globals())
# TODO: make fft gradient work for a repeated axis,
# e.g. by replacing fftn with repeated calls to 1d fft along each axis
def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
    """Build the VJP for a complex FFT-family primitive.

    Args:
        get_args: callable returning ``(axes, s, norm)`` for the forward
            call's arguments; only ``axes`` is used here (for the
            repeated-axis check).
        fft_fun: transform applied to the cotangent; the registrations in
            this module pass the same function whose gradient is defined.
        ans: output of the forward call (unused).
        x: input of the forward call.
        *args, **kwargs: remaining forward-call arguments, forwarded
            unchanged to ``get_args`` and ``fft_fun``.

    Returns:
        ``vjp(g)``, which applies ``fft_fun`` to ``g`` with the forward
        arguments, truncates/zero-pads the result to ``x``'s shape and
        matches ``x``'s real/complex type.

    Raises:
        NotImplementedError: if ``axes`` contains repeated entries.
    """
    axes, _s, _norm = get_args(x, *args, **kwargs)
    check_no_repeated_axes(axes)
    x_space = vspace(x)

    def vjp(g):
        # The DFT matrix is symmetric, so the same transform maps the
        # cotangent back; truncate_pad restores x's shape when the
        # transform length differed from x's extent.
        transformed = fft_fun(g, *args, **kwargs)
        return match_complex(x, truncate_pad(transformed, x_space.shape))

    return vjp
# VJPs for the complex transforms: each VJP applies the same transform (with
# the same arguments) to the cotangent, see fft_grad.
# The argument getter must match the wrapped function's signature so that the
# repeated-axis check sees the real axes: get_fft_args parses the 1-D
# (a, n, axis, norm) form, get_fft2_args / get_fftn_args parse the
# (a, s, axes, norm) form of the 2-D and n-D transforms.
# The lambdas resolve the getters lazily because they are defined further
# down in this module.
defvjp(fft, lambda *args, **kwargs:
       fft_grad(get_fft_args, fft, *args, **kwargs))
defvjp(ifft, lambda *args, **kwargs:
       fft_grad(get_fft_args, ifft, *args, **kwargs))
defvjp(fft2, lambda *args, **kwargs:
       fft_grad(get_fft2_args, fft2, *args, **kwargs))
defvjp(ifft2, lambda *args, **kwargs:
       fft_grad(get_fft2_args, ifft2, *args, **kwargs))
defvjp(fftn, lambda *args, **kwargs:
       fft_grad(get_fftn_args, fftn, *args, **kwargs))
defvjp(ifftn, lambda *args, **kwargs:
       fft_grad(get_fftn_args, ifftn, *args, **kwargs))
def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
    """Build the VJP for a forward real FFT primitive (rfft, rfft2, rfftn).

    Args:
        get_args: callable returning ``(axes, s, norm)`` for the forward
            call's arguments.
        irfft_fun: the matching inverse real transform, applied to the
            reweighted, conjugated cotangent.
        ans: output of the forward call; only its shape is used.
        x: input of the forward call.
        *args, **kwargs: remaining forward-call arguments, forwarded to
            ``get_args`` and ``irfft_fun``.

    Returns:
        ``vjp(g)`` mapping the output cotangent to a cotangent with ``x``'s
        shape and real/complex type.

    Raises:
        NotImplementedError: for repeated axes, or an odd transform length
            along the last transformed axis.
    """
    axes, full_shape, norm = get_args(x, *args, **kwargs)
    x_space = vspace(x)
    out_space = vspace(ans)
    check_no_repeated_axes(axes)
    if full_shape is None:
        # No explicit s/n: the transform spans x's extent on each axis.
        full_shape = [x_space.shape[ax] for ax in axes]
    check_even_shape(full_shape)
    # full_shape is the full (real-space) transform shape; half_shape is the
    # compressed spectrum shape, with n // 2 + 1 bins on the last axis.
    half_shape = list(full_shape[:-1]) + [full_shape[-1] // 2 + 1]
    fac = make_rfft_factors(axes, out_space.shape, half_shape, full_shape, norm)

    def vjp(g):
        weighted = anp.conj(g / fac)
        back = irfft_fun(weighted, *args, **kwargs)
        return match_complex(x, truncate_pad(back, x_space.shape))

    return vjp
def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
    """Build the VJP for an inverse real FFT primitive (irfft, irfft2, irfftn).

    Args:
        get_args: callable returning ``(axes, s, norm)`` for the forward
            call's arguments; here ``s`` is the full (real-space) shape.
        rfft_fun: the matching forward real transform, applied to the
            cotangent.
        ans: output of the forward call; only its shape is used.
        x: input of the forward call (the compressed spectrum).
        *args, **kwargs: remaining forward-call arguments, forwarded to
            ``get_args`` and ``rfft_fun``.

    Returns:
        ``vjp(g)`` mapping the output cotangent to a cotangent with ``x``'s
        shape and real/complex type.

    Raises:
        NotImplementedError: for repeated axes, or an odd full transform
            length along the last transformed axis.
    """
    axes, gs, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    # gs is the full fft shape; default to the output's extent on each axis.
    if gs is None:
        gs = [gvs.shape[i] for i in axes]
    check_even_shape(gs)
    # s is the compressed shape (n // 2 + 1 bins on the last axis).
    s = list(gs)
    s[-1] = s[-1] // 2 + 1
    # The weights depend only on shapes and norm, never on g, so build them
    # once here instead of on every cotangent (as rfft_grad does).
    fac = make_rfft_factors(axes, vs.shape, s, gs, norm)

    def vjp(g):
        r = match_complex(x, truncate_pad(rfft_fun(g, *args, **kwargs), vs.shape))
        return anp.conj(r) * fac

    return vjp
# VJPs for the real transforms. Forward real transforms (rfft*) use
# rfft_grad with their matching inverse; inverse real transforms (irfft*)
# use irfft_grad with their matching forward transform. Each pair gets the
# argument getter for its signature: get_fft_args for the 1-D functions,
# get_fft2_args for the 2-D ones, get_fftn_args for the n-D ones.
# The lambdas look the getters up only when a gradient is requested; the
# getters are defined further down in this module, so they must not be
# referenced eagerly here.
defvjp(rfft, lambda *args, **kwargs:
       rfft_grad(get_fft_args, irfft, *args, **kwargs))
defvjp(irfft, lambda *args, **kwargs:
       irfft_grad(get_fft_args, rfft, *args, **kwargs))
defvjp(rfft2, lambda *args, **kwargs:
       rfft_grad(get_fft2_args, irfft2, *args, **kwargs))
defvjp(irfft2, lambda *args, **kwargs:
       irfft_grad(get_fft2_args, rfft2, *args, **kwargs))
defvjp(rfftn, lambda *args, **kwargs:
       rfft_grad(get_fftn_args, irfftn, *args, **kwargs))
defvjp(irfftn, lambda *args, **kwargs:
       irfft_grad(get_fftn_args, rfftn, *args, **kwargs))
# fftshift and ifftshift only permute elements, so each one's VJP is the
# inverse permutation applied to the cotangent: ifftshift undoes fftshift and
# vice versa. A permutation commutes with complex conjugation, so the former
# conj(shift(conj(g))) wrapping was redundant and has been dropped.
defvjp(fftshift, lambda ans, x, axes=None: lambda g:
       match_complex(x, ifftshift(g, axes)))
defvjp(ifftshift, lambda ans, x, axes=None: lambda g:
       match_complex(x, fftshift(g, axes)))
@primitive
def truncate_pad(x, shape):
    """Return ``x`` zero-padded at the end and/or truncated to ``shape``.

    Along every axis, missing trailing entries are filled with zeros and
    surplus trailing entries are dropped.
    """
    shortfall = anp.maximum(0, anp.array(shape) - anp.array(x.shape))
    leading = anp.zeros(len(shape), dtype=int)
    pad_widths = tuple(zip(leading, shortfall))
    padded = anp.pad(x, pad_widths, 'constant')
    return padded[tuple(slice(n) for n in shape)]
# Truncation and trailing zero-padding are linear; the adjoint of one is the
# other, so the VJP truncates/pads the cotangent back to the input's shape.
defvjp(truncate_pad, lambda ans, x, shape: lambda g:
       match_complex(x, truncate_pad(g, vspace(x).shape)))
## TODO: could be made less stringent, to fail only when repeated axis has different values of s
def check_no_repeated_axes(axes):
    """Raise NotImplementedError if ``axes`` lists any axis more than once.

    Axes are compared by their raw values, so a negative index and the
    equivalent non-negative index are not recognized as the same axis.
    """
    distinct = set(axes)
    if len(distinct) < len(axes):
        raise NotImplementedError("FFT gradient for repeated axes not implemented.")
def check_even_shape(shape):
    """Raise NotImplementedError unless the last entry of ``shape`` is even.

    The real-FFT gradients in this module only support an even transform
    length along the last transformed axis.
    """
    last = shape[-1]
    if last % 2:
        raise NotImplementedError("Real FFT gradient for odd lengthed last axes is not implemented.")
def get_fft_args(a, d=None, axis=-1, norm=None, *args, **kwargs):
    """Extract ``(axes, s, norm)`` from the arguments of a 1-D transform.

    Mirrors the ``(a, n, axis, norm)`` argument order of the 1-D numpy.fft
    functions. A positional length binds to ``d``. A length passed by
    keyword as ``n`` lands in ``kwargs`` and is picked up from there.

    Returns:
        ``([axis], [n], norm)``, or ``([axis], None, norm)`` when no length
        was given, so 1-D transforms can share the n-D gradient code.
    """
    if d is None:
        # e.g. rfft(x, n=8): without this the transform length would be
        # silently dropped and the gradient built for x's own length.
        d = kwargs.get('n')
    axes = [axis]
    s = None if d is None else [d]
    return axes, s, norm
def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
    """Extract ``(axes, s, norm)`` from the arguments of a 2-D transform.

    Mirrors the ``(a, s, axes, norm)`` argument order of the 2-D numpy.fft
    functions; any further arguments are accepted and ignored.
    """
    result = (axes, s, norm)
    return result
def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
    """Extract ``(axes, s, norm)`` from the arguments of an n-D transform.

    Mirrors the ``(a, s, axes, norm)`` argument order of numpy.fft.fftn and
    friends. When ``axes`` is omitted it is filled in the way numpy does:
    every axis of ``a`` if ``s`` is also omitted, otherwise the last
    ``len(s)`` axes.
    """
    if axes is None:
        if s is None:
            axes = list(range(a.ndim))
        else:
            # numpy transforms only the last len(s) axes in this case.
            axes = list(range(-len(s), 0))
    return axes, s, norm
def make_rfft_factors(axes, resshape, facshape, normshape, norm):
    """Make the compression factors and normalization for rfft/irfft VJPs.

    Returns a float array of shape ``resshape`` that is 2 everywhere except
    along ``axes[-1]`` at index 0 and index ``facshape[-1] - 1`` (only index 0
    when ``facshape[-1]`` exceeds that axis' length), where it is 1. The
    array is then rescaled according to ``norm``, with N = prod(normshape):

    * ``None`` or ``"backward"`` (numpy's default, unscaled forward
      transform): divided by N;
    * ``"forward"`` (forward transform scaled by 1/N): multiplied by N;
    * ``"ortho"``: left as is.

    Raises:
        ValueError: if ``norm`` is not one of the values above.
    """
    N = 1.0
    for n in normshape:
        N = N * n
    # inplace modification is fine because we produce a constant
    # which doesn't go into autograd.
    # For the same reason we could have used numpy rather than anp,
    # but anp is already imported, so use it instead.
    fac = anp.zeros(resshape)
    fac[...] = 2
    index = [slice(None)] * len(resshape)
    last_axis = axes[-1]
    if facshape[-1] <= resshape[last_axis]:
        index[last_axis] = (0, facshape[-1] - 1)
    else:
        index[last_axis] = (0,)
    fac[tuple(index)] = 1
    # "backward" is numpy's explicit name for the default (None) convention,
    # so it must be scaled identically; "forward" moves the 1/N onto the
    # forward transform, which inverts the required correction.
    if norm is None or norm == "backward":
        fac /= N
    elif norm == "forward":
        fac *= N
    elif norm != "ortho":
        raise ValueError("Unsupported FFT norm: %r" % (norm,))
    return fac