from __future__ import absolute_import
import types
import warnings
from autograd.extend import primitive, notrace_primitive
import numpy as _np
import autograd.builtins as builtins
from numpy.core.einsumfunc import _parse_einsum_input
# NumPy functions that wrap_namespace wraps with ``notrace_primitive`` instead
# of ``primitive``. Each of them only inspects metadata of its input
# (dimensionality, shape, complexness, result dtype).
notrace_functions = [
    _np.ndim,
    _np.shape,
    _np.iscomplexobj,
    _np.result_type,
]
def wrap_intdtype(cls):
    """Return a subclass of ``cls`` whose constructor is a notrace primitive.

    ``cls`` is one of the integer types selected by ``wrap_namespace``. The
    subclass replaces ``__new__`` with ``notrace_primitive(cls.__new__)`` and
    changes nothing else.
    """
    untraced_new = notrace_primitive(cls.__new__)

    class IntdtypeSubclass(cls):
        __new__ = untraced_new

    return IntdtypeSubclass
def wrap_namespace(old, new):
    """Populate the namespace dict ``new`` with wrapped versions of ``old``'s entries.

    For every ``name, obj`` pair in ``old``:

    * functions listed in ``notrace_functions`` are wrapped with
      ``notrace_primitive``;
    * any other callable whose type is not exactly ``type`` (i.e. not a plain
      class) is wrapped with ``primitive``;
    * classes listed in ``int_types`` are replaced by ``wrap_intdtype(obj)``;
    * objects whose type is ``float``, ``int``, ``NoneType`` or ``type`` are
      copied unchanged;
    * everything else (e.g. submodules) is not copied.
    """
    unchanged_types = {float, int, type(None), type}
    # ``_np.int`` was only an alias of the builtin ``int``. It was deprecated in
    # NumPy 1.20 and removed in 1.24, where accessing it raises AttributeError,
    # so the builtin is named directly (identical behaviour on old NumPy).
    int_types = {int, _np.int8, _np.int16, _np.int32, _np.int64, _np.integer}
    for name, obj in old.items():
        if obj in notrace_functions:
            new[name] = notrace_primitive(obj)
        elif callable(obj) and type(obj) is not type:
            new[name] = primitive(obj)
        elif type(obj) is type and obj in int_types:
            new[name] = wrap_intdtype(obj)
        elif type(obj) in unchanged_types:
            new[name] = obj
# Copy every entry of numpy's module namespace into this module's globals,
# wrapping callables as described in wrap_namespace.
wrap_namespace(vars(_np), globals())
# ----- Special treatment of list-input functions -----
@primitive
def concatenate_args(axis, *args):
    """Concatenate the arrays in ``args`` along ``axis``.

    The arrays arrive as separate positional arguments rather than as one
    list, presumably so that each can be traced individually by autograd.
    The result is viewed as an ``ndarray``.
    """
    joined = _np.concatenate(args, axis)
    return joined.view(ndarray)
def concatenate(arr_list, axis=0):
    """Join the arrays in ``arr_list`` along ``axis`` (default 0).

    Forwards to ``concatenate_args``, unpacking the sequence into separate
    positional arguments.
    """
    return concatenate_args(axis, *arr_list)
def vstack(tup):
    """Stack arrays vertically (row-wise).

    Each input is promoted to at least 2-D with ``atleast_2d``, and the
    results are concatenated along axis 0.
    """
    return concatenate([atleast_2d(_m) for _m in tup], axis=0)


# Alias kept for parity with NumPy's ``row_stack``.
row_stack = vstack
def hstack(tup):
    """Stack arrays horizontally (column-wise).

    Inputs are promoted to at least 1-D. If the first one is 1-D, the arrays
    are joined along axis 0; otherwise they are joined along axis 1.

    Raises:
        ValueError: if ``tup`` is empty (raised by ``concatenate``, matching
            ``np.hstack``) rather than an IndexError from ``arrs[0]``.
    """
    arrs = [atleast_1d(_m) for _m in tup]
    # Guard empty input the same way numpy's own hstack does.
    if arrs and arrs[0].ndim == 1:
        return concatenate(arrs, 0)
    return concatenate(arrs, 1)
def column_stack(tup):
    """Stack inputs as columns of a 2-D array (mirrors ``np.column_stack``).

    Inputs with fewer than two dimensions are promoted to 2-D and transposed,
    so a 1-D input of length n becomes an (n, 1) column. Inputs that are
    already at least 2-D are used as-is. Everything is joined along axis 1.
    """
    def _as_column(v):
        # Turn scalars / 1-D input into a column; leave >= 2-D input untouched.
        arr = array(v)
        if arr.ndim >= 2:
            return arr
        return array(arr, ndmin=2).T

    return concatenate([_as_column(v) for v in tup], 1)
def array(A, *args, **kwargs):
    """Autograd-aware replacement for ``np.array``.

    Lists and tuples are converted recursively: each element goes through
    ``array`` itself, and the results are handed to the ``array_from_args``
    primitive as separate positional arguments. Any other input goes to
    ``_array_from_scalar_or_array``. Extra positional and keyword arguments
    are forwarded to ``np.array`` at the outermost level only.
    """
    if builtins.type(A) not in (list, tuple):
        return _array_from_scalar_or_array(args, kwargs, A)
    converted = [array(item) for item in A]
    return array_from_args(args, kwargs, *converted)
def wrap_if_boxes_inside(raw_array, slow_op_name=None):
    """Rebuild an object-dtype array through ``array_from_args``.

    If ``raw_array`` has object dtype (presumably because numpy built it from
    autograd boxes), its flattened elements are passed to ``array_from_args``
    and the result is reshaped back to ``raw_array.shape``. If
    ``slow_op_name`` is given, a warning is emitted suggesting
    ``np.concatenate()`` instead. Arrays of any other dtype are returned
    unchanged.
    """
    # Compare dtypes with ``==``: an equal object dtype need not be the same
    # instance as ``_np.dtype('O')`` (e.g. after unpickling), so ``is`` can miss it.
    if raw_array.dtype == _np.dtype('O'):
        if slow_op_name:
            warnings.warn("{0} is slow for array inputs. "
                          "np.concatenate() is faster.".format(slow_op_name))
        return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape)
    return raw_array
@primitive
def _array_from_scalar_or_array(array_args, array_kwargs, scalar):
    """Primitive form of ``np.array`` for a single non-list/tuple input.

    ``array_args`` and ``array_kwargs`` are the extra positional and keyword
    arguments forwarded to ``np.array``.
    """
    result = _np.array(scalar, *array_args, **array_kwargs)
    return result
@primitive
def array_from_args(array_args, array_kwargs, *args):
    """Primitive form of ``np.array`` for a sequence passed as varargs.

    The elements arrive as separate positional arguments and are passed to
    ``np.array`` as a tuple, together with the forwarded ``array_args`` and
    ``array_kwargs``.
    """
    elements = args
    return _np.array(elements, *array_args, **array_kwargs)
def select(condlist, choicelist, default=0):
    """Autograd-aware ``np.select``.

    Runs ``np.select`` on the raw inputs. The result is then rebuilt from its
    flattened elements via ``array``, which routes lists through the
    ``array_from_args`` primitive, and restored to its original shape.
    """
    selected = _np.select(list(condlist), list(choicelist), default=default)
    flat_items = list(selected.ravel())
    return array(flat_items).reshape(selected.shape)
def stack(arrays, axis=0):
    """Join a sequence of same-shaped arrays along a new axis.

    Adapted from numpy/core/shape_base.py's ``stack``. It is re-expressed
    with this module's ``array`` and ``concatenate`` so that it runs through
    the primitives defined here.

    Raises:
        ValueError: if ``arrays`` is empty or the input shapes differ.
        IndexError: if ``axis`` is outside ``[-(ndim + 1), ndim + 1)``.
    """
    arrays = [array(a) for a in arrays]
    if not arrays:
        raise ValueError('need at least one array to stack')

    if len({a.shape for a in arrays}) != 1:
        raise ValueError('all input arrays must have the same shape')

    result_ndim = arrays[0].ndim + 1
    if axis < -result_ndim or axis >= result_ndim:
        raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim))
    if axis < 0:
        axis += result_ndim

    # Insert a length-1 axis at position `axis` in every input, then join on it.
    expand = (slice(None),) * axis + (None,)
    return concatenate([a[expand] for a in arrays], axis=axis)
def append(arr, values, axis=None):
    """Append ``values`` to the end of ``arr`` (adapted from numpy's ``append``).

    With ``axis=None``, both inputs are flattened (``arr`` only if it is not
    already 1-D) and joined along the last axis of the flattened ``arr``.
    Otherwise they are concatenated along ``axis`` as given.
    """
    arr = array(arr)
    if axis is not None:
        return concatenate((arr, values), axis=axis)

    if ndim(arr) != 1:
        arr = ravel(arr)
    values = ravel(array(values))
    return concatenate((arr, values), axis=ndim(arr) - 1)
# ----- Enable functions called using [] ----
class r_class:
    """Stand-in for ``np.r_``.

    Indexing delegates to ``np.r_``. Object-dtype results are rebuilt via
    ``wrap_if_boxes_inside``, which warns that ``r_`` is slow for that case.
    """

    def __getitem__(self, args):
        return wrap_if_boxes_inside(_np.r_[args], slow_op_name="r_")


r_ = r_class()
class c_class:
    """Stand-in for ``np.c_``.

    Indexing delegates to ``np.c_``. Object-dtype results are rebuilt via
    ``wrap_if_boxes_inside``, which warns that ``c_`` is slow for that case.
    """

    def __getitem__(self, args):
        return wrap_if_boxes_inside(_np.c_[args], slow_op_name="c_")


c_ = c_class()
# ----- misc -----
@primitive
def make_diagonal(D, offset=0, axis1=0, axis2=1):
    """Build an array holding ``D`` on the diagonal of its last two axes.

    NumPy has no complement to ``np.diagonal`` that creates new diagonal
    arrays with extra dimensions. This function provides one; it is needed
    for the gradient of ``np.diagonal``. For ``D`` of shape ``(..., n)`` it
    returns an array of shape ``(..., n, n)`` with ``D`` on the diagonal of
    the last two axes and zeros elsewhere.

    Only ``offset=0, axis1=-1, axis2=-2`` is supported. Any other combination,
    including the default argument values, raises ``NotImplementedError``.

    The output dtype is ``D.dtype`` promoted with float64. Real inputs still
    produce float64 as before, but complex inputs keep their imaginary part
    instead of being truncated.
    """
    if not (offset == 0 and axis1 == -1 and axis2 == -2):
        raise NotImplementedError("Currently make_diagonal only supports offset=0, axis1=-1, axis2=-2")

    n = D.shape[-1]
    # float64 zeros (the previous behaviour) silently dropped the imaginary
    # part of complex D; promoting keeps real inputs at float64.
    out_dtype = _np.promote_types(D.dtype, _np.float64)
    new_array = _np.zeros(D.shape + (n,), dtype=out_dtype)
    # Write D onto the diagonal with integer-array indexing instead of making
    # the read-only diagonal view writeable; new_array[..., idx, idx] has
    # exactly D's shape.
    idx = _np.arange(n)
    new_array[..., idx, idx] = D
    return new_array
@notrace_primitive
def metadata(A):
    """Return ``(shape, ndim, result_type, iscomplexobj)`` for ``A``.

    Wrapped with ``notrace_primitive``, so, as the name indicates, it is not
    meant to be traced.
    """
    shape = _np.shape(A)
    ndim_ = _np.ndim(A)
    dtype = _np.result_type(A)
    is_complex = _np.iscomplexobj(A)
    return shape, ndim_, dtype, is_complex
@notrace_primitive
def parse_einsum_input(*args):
    """Variadic front-end to numpy's private ``_parse_einsum_input``.

    All positional arguments are collected into one tuple, which is passed as
    the single argument to ``_parse_einsum_input``.
    """
    operands = tuple(args)
    return _parse_einsum_input(operands)
@primitive
def _astype(A, dtype, order='K', casting='unsafe', subok=True, copy=True):
    """Primitive wrapper around ``A.astype``.

    Every conversion option is forwarded unchanged; the defaults match
    ``ndarray.astype``'s.
    """
    return A.astype(dtype, order=order, casting=casting, subok=subok, copy=copy)