Blob Blame History Raw
From b420613372ac54d5d6c4576736d9042f97ccdd47 Mon Sep 17 00:00:00 2001
From: Mark Borgerding <mark@borgerding.net>
Date: Sun, 15 Mar 2020 14:53:58 -0400
Subject: [PATCH] I had to fix some python3 incompatibilities and realized how
 embarrassing the code was. I refactored to make it look a little more like it
 was written by someone who knows Python.

---
 test/testkiss.py | 195 +++++++++++++++++++++--------------------------
 1 file changed, 86 insertions(+), 109 deletions(-)

diff --git a/test/testkiss.py b/test/testkiss.py
index 64809bf..81c8e2a 100755
--- a/test/testkiss.py
+++ b/test/testkiss.py
@@ -4,158 +4,135 @@
 #
 # SPDX-License-Identifier: BSD-3-Clause
 #  See COPYING file for more information.
-from __future__ import division,print_function
+from __future__ import absolute_import, division, print_function
 import math
 import sys
 import os
 import random
 import struct
 import getopt
-import numpy
+import numpy as np
 
-pi=math.pi
-e=math.e
-
-doreal=0
-
-datatype = os.environ.get('DATATYPE','float')
+po = math.pi
+e = math.e
+do_real = False
+datatype = os.environ.get('DATATYPE', 'float')
 
 util = '../tools/fft_' + datatype
-minsnr=90
+minsnr = 90
 if datatype == 'double':
-    fmt='d'
-elif datatype=='int16_t':
-    fmt='h'
-    minsnr=10
-elif datatype=='int32_t':
-    fmt='i'
-elif datatype=='simd':
-    fmt='4f'
+    dtype = np.float64
+elif datatype == 'float':
+    dtype = np.float32
+elif datatype == 'int16_t':
+    dtype = np.int16
+    minsnr = 10
+elif datatype == 'int32_t':
+    dtype = np.int32
+elif datatype == 'simd':
     sys.stderr.write('testkiss.py does not yet test simd')
     sys.exit(0)
-elif datatype=='float':
-    fmt='f'
 else:
-    sys.stderr.write('unrecognized datatype %s\n' % datatype)
+    sys.stderr.write('unrecognized datatype {0}\n'.format(datatype))
     sys.exit(1)
- 
 
-def dopack(x,cpx=1):
-    x = numpy.reshape( x, ( numpy.size(x),) )
-    
-    if cpx:
-        s = ''.join( [ struct.pack(fmt*2,c.real,c.imag) for c in x ] )
+def dopack(x):
+    if np.iscomplexobj(x):
+        x = x.astype(np.complex128).view(np.float64)
     else:
-        s = ''.join( [ struct.pack(fmt,c.real) for c in x ] )
-    return s
+        x = x.astype(np.float64)
+    return x.astype(dtype).tobytes()
 
-def dounpack(x,cpx):
-    uf = fmt * ( len(x) // struct.calcsize(fmt) )
-    s = struct.unpack(uf,x)
+def dounpack(x, cpx):
+    x = np.frombuffer(x, dtype).astype(np.float64)
     if cpx:
-        return numpy.array(s[::2]) + numpy.array( s[1::2] )*1j
+        x = x[::2] + 1j * x[1::2]
+    return x
+
+def make_random(shape):
+    'create random uniform (-1,1) data of the given shape'
+    if do_real:
+        return np.random.uniform(-1, 1, shape)
     else:
-        return numpy.array(s )
-
-def make_random(dims=[1]):
-    res = []
-    for i in range(dims[0]):
-        if len(dims)==1:
-            r=random.uniform(-1,1)
-            if doreal:
-                res.append( r )
-            else:
-                i=random.uniform(-1,1)
-                res.append( complex(r,i) )
-        else:
-            res.append( make_random( dims[1:] ) )
-    return numpy.array(res)
-
-def flatten(x):
-    ntotal = numpy.size(x)
-    return numpy.reshape(x,(ntotal,))
-
-def randmat( ndims ):
-    dims=[]
-    for i in range( ndims ):
-        curdim = int( random.uniform(2,5) )
-        if doreal and i==(ndims-1):
-            curdim = int(curdim/2)*2 # force even last dimension if real
-        dims.append( curdim )
-    return make_random(dims )
-
-def test_fft(ndims):
-    x=randmat( ndims )
-
-    if doreal:
-        xver = numpy.fft.rfftn(x)
+        return (np.random.uniform(-1, 1, shape) + 1j * np.random.uniform(-1, 1, shape))
+
+def randmat(ndim):
+    'create a random multidimensional array in range (-1,1)'
+    dims = np.random.randint(2, 5, ndim)
+    if do_real:
+        dims[-1] = (dims[-1] // 2) * 2  # force even last dimension if real
+    return make_random(dims)
+
+def test_fft(ndim):
+    x = randmat(ndim)
+
+    if do_real:
+        xver = np.fft.rfftn(x)
     else:
-        xver = numpy.fft.fftn(x)
-    
-    x2=dofft(x,doreal)
+        xver = np.fft.fftn(x)
+
+    x2 = dofft(x, do_real)
     err = xver - x2
-    errf = flatten(err)
-    xverf = flatten(xver)
-    errpow = numpy.vdot(errf,errf)+1e-10
-    sigpow = numpy.vdot(xverf,xverf)+1e-10
-    snr = 10*math.log10(abs(sigpow/errpow) )
-    print( 'SNR (compared to NumPy) : {0:.1f}dB'.format( float(snr) ) )
-
-    if snr<minsnr:
-        print( 'xver=',xver )
-        print( 'x2=',x2)
-        print( 'err',err)
+    errf = err.ravel()
+    xverf = xver.ravel()
+    errpow = np.vdot(errf, errf) + 1e-10
+    sigpow = np.vdot(xverf, xverf) + 1e-10
+    snr = 10 * math.log10(abs(sigpow / errpow))
+    print('SNR (compared to NumPy) : {0:.1f}dB'.format(float(snr)))
+
+    if snr < minsnr:
+        print('xver=', xver)
+        print('x2=', x2)
+        print('err', err)
         sys.exit(1)
- 
-def dofft(x,isreal):
-    dims=list( numpy.shape(x) )
-    x = flatten(x)
 
-    scale=1
-    if datatype=='int16_t':
+def dofft(x, isreal):
+    dims = list(np.shape(x))
+    x = x.ravel()
+
+    scale = 1
+    if datatype == 'int16_t':
         x = 32767 * x
         scale = len(x) / 32767.0
-    elif datatype=='int32_t':
+    elif datatype == 'int32_t':
         x = 2147483647.0 * x
         scale = len(x) / 2147483647.0
 
-    cmd='%s -n ' % util
+    cmd = util + ' -n '
     cmd += ','.join([str(d) for d in dims])
-    if doreal:
+    if do_real:
         cmd += ' -R '
 
-    print( cmd)
+    print(cmd)
 
-    from subprocess import Popen,PIPE
-    p = Popen(cmd,shell=True,stdin=PIPE,stdout=PIPE )
+    from subprocess import Popen, PIPE
+    p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE)
 
-    p.stdin.write( dopack( x , isreal==False ) )
+    p.stdin.write(dopack(x))
     p.stdin.close()
 
-    res = dounpack( p.stdout.read() , 1 )
-    if doreal:
-        dims[-1] = int( dims[-1]/2 ) + 1
+    res = dounpack(p.stdout.read(), 1)
+    if do_real:
+        dims[-1] = (dims[-1] // 2) + 1
 
     res = scale * res
 
     p.wait()
-    return numpy.reshape(res,dims)
+    return np.reshape(res, dims)
 
 def main():
-    opts,args = getopt.getopt(sys.argv[1:],'r')
-    opts=dict(opts)
-
-    global doreal
-    doreal = opts.has_key('-r')
-
-    if doreal:
-        print( 'Testing multi-dimensional real FFTs')
+    opts, args = getopt.getopt(sys.argv[1:], 'r')
+    opts = dict(opts)
+    global do_real
+    do_real = '-r' in opts
+    if do_real:
+        print('Testing multi-dimensional real FFTs')
     else:
-        print( 'Testing multi-dimensional FFTs')
+        print('Testing multi-dimensional FFTs')
+
+    for dim in range(1, 4):
+        test_fft(dim)
 
-    for dim in range(1,4):
-        test_fft( dim )
 
 if __name__ == "__main__":
     main()
-
-- 
2.25.4