# Gradient checks for the np.fft functions exposed through autograd.numpy.
from __future__ import absolute_import
from functools import partial
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.test_util import check_grads
from autograd import grad
import pytest
# Seed ``npr`` once at import time so the random test inputs are the same
# from run to run (for a given test order).
npr.seed(1)

# Forward mode is not yet implemented for these gradients, so every
# ``check_grads`` call in this module is restricted to reverse mode.
check_grads = partial(check_grads, modes=["rev"])

def test_fft():
    """Reverse-mode gradient check of ``np.fft.fft`` on a 5x5 input."""
    def fun(x):
        return np.fft.fft(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def test_fft_ortho():
    """Gradient check of ``np.fft.fft`` with orthonormal scaling."""
    def fun(x):
        return np.fft.fft(x, norm='ortho')

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def test_fft_axis():
    """Gradient check of ``np.fft.fft`` taken along axis 0."""
    def fun(x):
        return np.fft.fft(x, axis=0)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def match_complex(fft_fun, mat):
    """Return a suitable test input for ``fft_fun`` built from ``mat``.

    Inverse real transforms (function name starting with ``ir``, e.g.
    ``irfft2``) are fed the output of the matching forward transform (the same
    name without the leading ``i``), so their input is Hermitian-symmetric.
    For every other transform ``mat`` itself is returned unchanged.
    """
    name = fft_fun.__name__
    if not name.startswith('ir'):
        return mat
    forward = getattr(np.fft, name[1:])
    return forward(mat)

def check_fft_n(fft_fun, D, n):
    """Gradient-check ``fft_fun`` with an explicit transform length.

    ``fft_fun`` is called as ``fft_fun(x, D + n)`` on a (D, D) random input,
    which is first passed through ``match_complex`` so inverse real
    transforms receive Hermitian input. Negative ``n`` requests a shorter
    length, positive ``n`` a longer one.
    """
    length = D + n

    def fun(x):
        return fft_fun(x, length)

    check_grads(fun)(match_complex(fft_fun, npr.randn(D, D)))

def test_fft_n_smaller():
    """``fft`` with transform length 5 - 2 = 3."""
    check_fft_n(np.fft.fft, D=5, n=-2)

def test_fft_n_bigger():
    """``fft`` with transform length 5 + 2 = 7."""
    check_fft_n(np.fft.fft, D=5, n=2)

def test_ifft_n_smaller():
    """``ifft`` with transform length 5 - 2 = 3."""
    check_fft_n(np.fft.ifft, D=5, n=-2)

def test_ifft_n_bigger():
    """``ifft`` with transform length 5 + 2 = 7."""
    check_fft_n(np.fft.ifft, D=5, n=2)

def test_rfft_n_smaller():
    """``rfft`` with transform length 4 - 2 = 2."""
    check_fft_n(np.fft.rfft, D=4, n=-2)

def test_rfft_n_bigger():
    """``rfft`` with transform length 4 + 2 = 6."""
    check_fft_n(np.fft.rfft, D=4, n=2)

def test_irfft_n_smaller():
    """``irfft`` with transform length 4 - 2 = 2."""
    check_fft_n(np.fft.irfft, D=4, n=-2)

def test_irfft_n_bigger():
    """``irfft`` with transform length 4 + 2 = 6."""
    check_fft_n(np.fft.irfft, D=4, n=2)

def check_fft_s(fft_fun, D):
    """Gradient-check a multi-axis ``fft_fun`` with explicit ``s`` and ``axes``.

    The input is a (D, D, D) random array scaled by 1/10, passed through
    ``match_complex`` so inverse real transforms receive Hermitian input.
    The transform runs over axes 0 and 2 with lengths ``D + 2`` and ``D - 2``.
    """
    # One requested length above D and one below, so both a lengthened and a
    # shortened axis are exercised in the same call.
    s = [D + 2, D - 2]
    axes = [0, 2]

    def fun(x):
        return fft_fun(x, s=s, axes=axes)

    mat = npr.randn(D, D, D) / 10.0
    mat = match_complex(fft_fun, mat)
    check_grads(fun)(mat)

def test_fft2_s():
    """``fft2`` with explicit ``s``/``axes`` on a 5x5x5 input."""
    check_fft_s(np.fft.fft2, D=5)

def test_ifft2_s():
    """``ifft2`` with explicit ``s``/``axes`` on a 5x5x5 input."""
    check_fft_s(np.fft.ifft2, D=5)

def test_fftn_s():
    """``fftn`` with explicit ``s``/``axes`` on a 5x5x5 input."""
    check_fft_s(np.fft.fftn, D=5)

def test_ifftn_s():
    """``ifftn`` with explicit ``s``/``axes`` on a 5x5x5 input."""
    check_fft_s(np.fft.ifftn, D=5)

def test_rfft2_s():
    """``rfft2`` with explicit ``s``/``axes`` on a 4x4x4 input."""
    check_fft_s(np.fft.rfft2, D=4)

def test_irfft2_s():
    """``irfft2`` with explicit ``s``/``axes`` on a 4x4x4-derived input."""
    check_fft_s(np.fft.irfft2, D=4)

def test_rfftn_s():
    """``rfftn`` with explicit ``s``/``axes`` on a 4x4x4 input."""
    check_fft_s(np.fft.rfftn, D=4)

def test_irfftn_s():
    """``irfftn`` with explicit ``s``/``axes`` on a 4x4x4-derived input."""
    check_fft_s(np.fft.irfftn, D=4)

## TODO: fft gradient not implemented for repeated axes
# def test_fft_repeated_axis():
#     D = 5
#     for fft_fun in (np.fft.fft2, np.fft.ifft2, np.fft.fftn, np.fft.ifftn):
#         def fun(x): return fft_fun(x, s=s, axes=axes)
#
#         mat = npr.randn(D, D, D) / 10.0
#         s = [D + 2, D - 2]
#         axes = [0, 0]
#
#         check_grads(fun)(mat)

def test_ifft():
    """Reverse-mode gradient check of ``np.fft.ifft`` on a 5x5 input."""
    def fun(x):
        return np.fft.ifft(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def test_fft2():
    """Gradient check of ``np.fft.fft2`` on a scaled 5x5 input."""
    def fun(x):
        return np.fft.fft2(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_ifft2():
    """Gradient check of ``np.fft.ifft2`` on a 5x5 input."""
    def fun(x):
        return np.fft.ifft2(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def test_fftn():
    """Gradient check of ``np.fft.fftn`` on a scaled 5x5 input."""
    def fun(x):
        return np.fft.fftn(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_ifftn():
    """Gradient check of ``np.fft.ifftn`` on a 5x5 input."""
    def fun(x):
        return np.fft.ifftn(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def test_rfft():
    """Gradient check of ``np.fft.rfft`` on a scaled 4x4 input."""
    def fun(x):
        return np.fft.rfft(x)

    dim = 4
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_rfft_ortho():
    """Gradient check of ``np.fft.rfft`` with orthonormal scaling."""
    def fun(x):
        return np.fft.rfft(x, norm='ortho')

    dim = 4
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_rfft_axes():
    """Gradient check of ``np.fft.rfft`` taken along axis 0."""
    def fun(x):
        return np.fft.rfft(x, axis=0)

    dim = 4
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_irfft():
    """Gradient check of ``np.fft.irfft``.

    The input is the ``rfft`` of a real array, so it is Hermitian-symmetric.
    """
    def fun(x):
        return np.fft.irfft(x)

    dim = 4
    spectrum = np.fft.rfft(npr.randn(dim, dim) / 10.0)
    check_grads(fun)(spectrum)

def test_irfft_ortho():
    """Gradient check of ``np.fft.irfft`` with orthonormal scaling.

    The input is the ``rfft`` of a real array, so it is Hermitian-symmetric.
    """
    def fun(x):
        return np.fft.irfft(x, norm='ortho')

    dim = 4
    spectrum = np.fft.rfft(npr.randn(dim, dim) / 10.0)
    check_grads(fun)(spectrum)

def test_rfft2():
    """Gradient check of ``np.fft.rfft2`` on a scaled 4x4 input."""
    def fun(x):
        return np.fft.rfft2(x)

    dim = 4
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_irfft2():
    """Gradient check of ``np.fft.irfft2``.

    The input is the ``rfft2`` of a real array, so it is Hermitian-symmetric.
    """
    def fun(x):
        return np.fft.irfft2(x)

    dim = 4
    spectrum = np.fft.rfft2(npr.randn(dim, dim) / 10.0)
    check_grads(fun)(spectrum)

def test_rfftn():
    """Gradient check of ``np.fft.rfftn`` on a scaled 4x4x4 input."""
    def fun(x):
        return np.fft.rfftn(x)

    dim = 4
    check_grads(fun)(npr.randn(dim, dim, dim) / 10.0)

def test_rfftn_odd_not_implemented():
    """``rfftn`` on an odd-sized (5x5x5) input must raise NotImplementedError."""
    def fun(x):
        return np.fft.rfftn(x)

    dim = 5
    cube = npr.randn(dim, dim, dim) / 10.0
    with pytest.raises(NotImplementedError):
        check_grads(fun)(cube)

def test_rfftn_subset():
    """Gradient check through a fancy-indexed subset of ``rfftn``'s output."""
    rows, cols = (0, 1, 0), (3, 3, 2)

    def fun(x):
        return np.fft.rfftn(x)[rows, cols]

    dim = 4
    check_grads(fun)(npr.randn(dim, dim, dim) / 10.0)

def test_rfftn_axes():
    """Gradient check of ``np.fft.rfftn`` restricted to axes 0 and 2."""
    def fun(x):
        return np.fft.rfftn(x, axes=(0, 2))

    dim = 4
    check_grads(fun)(npr.randn(dim, dim, dim) / 10.0)

def test_irfftn():
    """Gradient check of ``np.fft.irfftn``.

    The input is the ``rfftn`` of a real 4x4x4 array, so it is
    Hermitian-symmetric.
    """
    def fun(x):
        return np.fft.irfftn(x)

    dim = 4
    spectrum = np.fft.rfftn(npr.randn(dim, dim, dim) / 10.0)
    check_grads(fun)(spectrum)

def test_irfftn_subset():
    """Gradient check through a fancy-indexed subset of ``irfftn``'s output.

    The input is the ``rfftn`` of a real array, so it is Hermitian-symmetric.
    """
    rows, cols = (0, 1, 0), (3, 3, 2)

    def fun(x):
        return np.fft.irfftn(x)[rows, cols]

    dim = 4
    spectrum = np.fft.rfftn(npr.randn(dim, dim, dim) / 10.0)
    check_grads(fun)(spectrum)

def test_fftshift():
    """Gradient check of ``np.fft.fftshift`` on an odd-sized (5x5) input."""
    def fun(x):
        return np.fft.fftshift(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_fftshift_even():
    """Gradient check of ``np.fft.fftshift`` on an even-sized (4x4) input."""
    def fun(x):
        return np.fft.fftshift(x)

    dim = 4
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_fftshift_axes():
    """Gradient check of ``np.fft.fftshift`` applied only along axis 1."""
    def fun(x):
        return np.fft.fftshift(x, axes=1)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim) / 10.0)

def test_ifftshift():
    """Gradient check of ``np.fft.ifftshift`` on an odd-sized (5x5) input."""
    def fun(x):
        return np.fft.ifftshift(x)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))

def test_ifftshift_even():
    """Gradient check of ``np.fft.ifftshift`` on an even-sized (4x4) input."""
    def fun(x):
        return np.fft.ifftshift(x)

    dim = 4
    check_grads(fun)(npr.randn(dim, dim))

def test_ifftshift_axes():
    """Gradient check of ``np.fft.ifftshift`` applied only along axis 1."""
    def fun(x):
        return np.fft.ifftshift(x, axes=1)

    dim = 5
    check_grads(fun)(npr.randn(dim, dim))