Blob Blame History Raw
"""
Handy functions for flattening nested containers containing numpy
arrays. The main purpose is to make examples and optimizers simpler.
"""
from autograd import make_vjp
from autograd.builtins import type
import autograd.numpy as np

def flatten(value):
    """Flatten a nesting of tuples, lists and dicts holding arrays or scalars.

    The leaves are raveled and concatenated, in container order (dict values
    in sorted-key order), into a single 1D numpy array.

    Returns:
        (flat_value, unflatten): the 1D array, and the vector-Jacobian-product
        function of the internal flattening, which maps a flat array of the
        same length back into the original nesting.

    Mixed numeric types (e.g. floats and ints) are not preserved, and dict
    keys must be sortable.
    """
    vjp_and_value = make_vjp(_flatten)(value)
    unflatten, flat_value = vjp_and_value
    return flat_value, unflatten

def _flatten(value):
    """Recursively flatten *value* into one 1D array.

    Lists and tuples are flattened element by element in order, dicts value
    by value in sorted-key order, and any other value is raveled directly.
    """
    # `type` is autograd's version, so boxed containers report their
    # underlying Python type during tracing.
    kind = type(value)
    if kind is dict:
        pieces = (_flatten(value[key]) for key in sorted(value))
    elif kind in (list, tuple):
        pieces = (_flatten(item) for item in value)
    else:
        return np.ravel(value)
    return _concatenate(pieces)

def _concatenate(lst):
    """Concatenate an iterable of arrays; an empty iterable gives ``np.array([])``."""
    # Materialize first: the input may be a one-shot iterator (map/generator),
    # and np.concatenate rejects an empty sequence.
    pieces = [*lst]
    if not pieces:
        return np.array([])
    return np.concatenate(pieces)

def flatten_func(func, example):
    """Wrap *func* so that it takes and returns flat 1D arrays.

    Args:
        func: callable invoked as ``func(x, *args)``, where ``x`` has the same
            nesting as *example*. Its result may be any nesting accepted by
            ``flatten``.
        example: a value with the nesting ``func`` expects as its first
            argument; it is flattened once to derive the unflatten function.

    Returns:
        (flat_func, unflatten, flat_example) where
        ``flat_func(flat_x, *args)`` unflattens ``flat_x``, calls ``func`` and
        returns its result flattened to a 1D array; ``unflatten`` rebuilds the
        nesting of *example* from a flat array; and ``flat_example`` is
        *example* flattened.
    """
    flat_example, unflatten = flatten(example)

    # A named def instead of an assigned lambda (PEP 8 E731) so the wrapper
    # shows up by name in tracebacks; the parameter name `_x` is kept.
    def _func(_x, *args):
        # Only the flat array is returned; the output's unflatten is dropped.
        return flatten(func(unflatten(_x), *args))[0]

    return _func, unflatten, flat_example