|
|
3fb7ae8 |
--- theano/tensor/tests/test_basic.py.orig 2019-01-15 14:13:57.000000000 -0700
|
|
|
3fb7ae8 |
+++ theano/tensor/tests/test_basic.py 2019-08-15 09:02:44.459403923 -0600
|
|
|
3fb7ae8 |
@@ -344,93 +344,6 @@ def makeTester(name, op, expected, check
|
|
|
3fb7ae8 |
os.close(f)
|
|
|
3fb7ae8 |
os.remove(fname)
|
|
|
3fb7ae8 |
|
|
|
3fb7ae8 |
- def test_good(self):
|
|
|
3fb7ae8 |
- if skip:
|
|
|
3fb7ae8 |
- raise SkipTest(skip)
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- good = self.add_memmap_values(self.good)
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- for testname, inputs in iteritems(good):
|
|
|
3fb7ae8 |
- inputs = [copy(input) for input in inputs]
|
|
|
3fb7ae8 |
- inputrs = [TensorType(
|
|
|
3fb7ae8 |
- dtype=input.dtype,
|
|
|
3fb7ae8 |
- broadcastable=[shape_elem == 1
|
|
|
3fb7ae8 |
- for shape_elem in input.shape]
|
|
|
3fb7ae8 |
- )() for input in inputs]
|
|
|
3fb7ae8 |
- try:
|
|
|
3fb7ae8 |
- node = safe_make_node(self.op, *inputrs)
|
|
|
3fb7ae8 |
- except Exception as exc:
|
|
|
3fb7ae8 |
- err_msg = ("Test %s::%s: Error occurred while"
|
|
|
3fb7ae8 |
- " making a node with inputs %s") % (
|
|
|
3fb7ae8 |
- self.op, testname, inputs)
|
|
|
3fb7ae8 |
- exc.args += (err_msg,)
|
|
|
3fb7ae8 |
- raise
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- try:
|
|
|
3fb7ae8 |
- f = inplace_func(inputrs, node.outputs, mode=mode, name='test_good')
|
|
|
3fb7ae8 |
- except Exception as exc:
|
|
|
3fb7ae8 |
- err_msg = ("Test %s::%s: Error occurred while"
|
|
|
3fb7ae8 |
- " trying to make a Function") % (self.op, testname)
|
|
|
3fb7ae8 |
- exc.args += (err_msg,)
|
|
|
3fb7ae8 |
- raise
|
|
|
3fb7ae8 |
- if (isinstance(self.expected, dict) and
|
|
|
3fb7ae8 |
- testname in self.expected):
|
|
|
3fb7ae8 |
- expecteds = self.expected[testname]
|
|
|
3fb7ae8 |
- # with numpy version, when we print a number and read it
|
|
|
3fb7ae8 |
- # back, we don't get exactly the same result, so we accept
|
|
|
3fb7ae8 |
- # rounding error in that case.
|
|
|
3fb7ae8 |
- eps = 5e-9
|
|
|
3fb7ae8 |
- else:
|
|
|
3fb7ae8 |
- expecteds = self.expected(*inputs)
|
|
|
3fb7ae8 |
- eps = 1e-10
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- if any([i.dtype in ('float32', 'int8', 'uint8', 'uint16')
|
|
|
3fb7ae8 |
- for i in inputs]):
|
|
|
3fb7ae8 |
- eps = 1e-6
|
|
|
3fb7ae8 |
- eps = np.max([eps, _eps])
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- try:
|
|
|
3fb7ae8 |
- variables = f(*inputs)
|
|
|
3fb7ae8 |
- except Exception as exc:
|
|
|
3fb7ae8 |
- err_msg = ("Test %s::%s: Error occurred while calling"
|
|
|
3fb7ae8 |
- " the Function on the inputs %s") % (
|
|
|
3fb7ae8 |
- self.op, testname, inputs)
|
|
|
3fb7ae8 |
- exc.args += (err_msg,)
|
|
|
3fb7ae8 |
- raise
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- if not isinstance(expecteds, (list, tuple)):
|
|
|
3fb7ae8 |
- expecteds = (expecteds, )
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- for i, (variable, expected) in enumerate(
|
|
|
3fb7ae8 |
- izip(variables, expecteds)):
|
|
|
3fb7ae8 |
- if (variable.dtype != expected.dtype or
|
|
|
3fb7ae8 |
- variable.shape != expected.shape or
|
|
|
3fb7ae8 |
- not np.allclose(variable, expected,
|
|
|
3fb7ae8 |
- atol=eps, rtol=eps)):
|
|
|
3fb7ae8 |
- self.fail(("Test %s::%s: Output %s gave the wrong"
|
|
|
3fb7ae8 |
- " value. With inputs %s, expected %s (dtype %s),"
|
|
|
3fb7ae8 |
- " got %s (dtype %s). eps=%f"
|
|
|
3fb7ae8 |
- " np.allclose returns %s %s") % (
|
|
|
3fb7ae8 |
- self.op,
|
|
|
3fb7ae8 |
- testname,
|
|
|
3fb7ae8 |
- i,
|
|
|
3fb7ae8 |
- inputs,
|
|
|
3fb7ae8 |
- expected,
|
|
|
3fb7ae8 |
- expected.dtype,
|
|
|
3fb7ae8 |
- variable,
|
|
|
3fb7ae8 |
- variable.dtype,
|
|
|
3fb7ae8 |
- eps,
|
|
|
3fb7ae8 |
- np.allclose(variable, expected,
|
|
|
3fb7ae8 |
- atol=eps, rtol=eps),
|
|
|
3fb7ae8 |
- np.allclose(variable, expected)))
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
- for description, check in iteritems(self.checks):
|
|
|
3fb7ae8 |
- if not check(inputs, variables):
|
|
|
3fb7ae8 |
- self.fail(("Test %s::%s: Failed check: %s (inputs"
|
|
|
3fb7ae8 |
- " were %s, outputs were %s)") % (
|
|
|
3fb7ae8 |
- self.op, testname, description,
|
|
|
3fb7ae8 |
- inputs, variables))
|
|
|
3fb7ae8 |
-
|
|
|
3fb7ae8 |
def test_bad_build(self):
|
|
|
3fb7ae8 |
if skip:
|
|
|
3fb7ae8 |
raise SkipTest(skip)
|