diff --git a/imperative/python/test/helpers/utils.py b/imperative/python/test/helpers/utils.py index 99ce51b2614cef6b4cecf8083b7f2821e1b858e3..ef19085f3d344524be18502651f30b8d5397c122 100644 --- a/imperative/python/test/helpers/utils.py +++ b/imperative/python/test/helpers/utils.py @@ -11,12 +11,12 @@ from megengine.utils.network_node import VarNode def _default_compare_fn(x, y): - if isinstance(x, np.ndarray): - np.testing.assert_allclose(x, y, rtol=1e-6) - elif isinstance(x, tensor): - np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) - else: - np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6) + if isinstance(x, tensor): + x = x.numpy() + elif not isinstance(x, np.ndarray): + x = get_var_value(x) + assert isinstance(x, np.ndarray) + np.testing.assert_allclose(x, y, rtol=1e-6) def make_tensor(x, network=None, device=None): @@ -69,12 +69,16 @@ def opr_test( """ - def check_results(results, expected): + def check_results(results, expected, check_shape=True): if not isinstance(results, (tuple, list)): results = (results,) for r, e in zip(results, expected): if not isinstance(r, (tensor, VarNode)): r = tensor(r) + if check_shape: + r_shape = r.numpy().shape + e_shape = e.shape if isinstance(e, np.ndarray) else () + assert r_shape == e_shape compare_fn(r, e) def get_param(cases, idx): @@ -127,10 +131,10 @@ def opr_test( # assume #outputs == 1 loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] - check_results(loaded_results, outp) + check_results(loaded_results, outp, check_shape=False) # scalar info lost results = func(*inp_tensor, **kwargs) - check_results(results, outp) + check_results(results, outp, check_shape=(network is None)) if len(cases) == 0: raise ValueError("should give one case at least") diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index ec1ff2fab77732350ad6002ee44ec33d16a0ed3b..a1bd6553be9090e42bf226e443fa916b696f2a20 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -39,12 +39,6 @@ def test_where(): xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) - cases = [ - {"input": [maskv0, xv0, yv0]}, - {"input": [maskv1, xv1, yv1]}, - ] - opr_test(cases, F.where, ref_fn=np.where, test_trace=False) - maskv2 = np.array([1, 1, 1], dtype=np.bool_) xv2 = np.array([1, 3, 2], dtype=np.float32) yv2 = np.array([5, 6, 9], dtype=np.float32) @@ -53,11 +47,18 @@ def test_where(): xv3 = np.array([1, 3, 2], dtype=np.float32) yv3 = np.array([5, 6, 9], dtype=np.float32) + maskv4 = np.array(1, dtype=np.bool_) + xv4 = np.array(1, dtype=np.float32) + yv4 = np.array(0, dtype=np.float32) + cases = [ + {"input": [maskv0, xv0, yv0]}, + {"input": [maskv1, xv1, yv1]}, {"input": [maskv2, xv2, yv2]}, {"input": [maskv3, xv3, yv3]}, + {"input": [maskv4, xv4, yv4]}, ] - opr_test(cases, F.where, ref_fn=np.where, test_trace=False) + opr_test(cases, F.where, ref_fn=np.where, test_trace=True) def test_dropout(): @@ -618,12 +619,12 @@ def test_binary_cross_entropy(): np.random.seed(123) data1 = np.random.uniform(size=data1_shape).astype(np.float32) label1 = np.random.uniform(size=label1_shape).astype(np.float32) - expect1 = np.array([0.6361], dtype=np.float32) + expect1 = np.array(0.6361, dtype=np.float32) np.random.seed(123) data2 = np.random.uniform(size=data2_shape).astype(np.float32) label2 = np.random.uniform(size=label2_shape).astype(np.float32) - expect2 = np.array([0.6750], dtype=np.float32) + expect2 = np.array(0.6750, dtype=np.float32) cases = [ {"input": [data1, label1], "output": expect1,}, diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index b103d659c77f0ad7533ee06a0a74e989d12537aa..9470c35ee700cfe2f6617568524feabddc7edc23 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -335,18 +335,18 @@ def test_reshape_shape_inference(is_varnode): source = output.shape if isinstance(source, tensor): source = source.numpy() - np.testing.assert_equal(source, target) + np.testing.assert_equal(source, target.shape) def func(x, target_shape): return x.reshape(target_shape) cases = [ - {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]}, - {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]}, - {"input": [x_shape_known, tshp_known], "output": [(2, 2),]}, - {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]}, - {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, - {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, + {"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]}, + {"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]}, + {"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]}, + {"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]}, + {"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]}, + {"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]}, ] opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) if is_varnode: @@ -533,46 +533,30 @@ def test_flatten(is_varnode): data0 = np.random.random(data0_shape).astype(np.float32) data1 = np.random.random(data1_shape).astype(np.float32) - def compare_fn(x, y): - assert x._tuple_shape[0] == y - - output0 = (2 * 3 * 4 * 5,) - output1 = (4 * 5 * 6 * 7,) cases = [ - {"input": data0, "output": output0}, - {"input": data1, "output": output1}, + {"input": data0, "output": data0.flatten()}, + {"input": data1, "output": data1.flatten()}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) + opr_test(cases, F.flatten, network=network) - output0 = (2, 3 * 4 * 5) - output1 = (4, 5 * 6 * 7) cases = [ - {"input": data0, "output": output0}, - {"input": data1, "output": output1}, + {"input": data0, "output": data0.reshape(2, -1)}, + {"input": data1, "output": data1.reshape(4, -1)}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) + opr_test(cases, F.flatten, start_axis=1, network=network) - output0 = (2, 3, 4 * 5) - output1 = (4, 5, 6 * 7) cases = [ - {"input": data0, "output": output0}, - {"input": data1, "output": output1}, + {"input": data0, "output": data0.reshape(2, 3, -1)}, + {"input": data1, "output": data1.reshape(4, 5, -1)}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) + opr_test(cases, F.flatten, start_axis=2, network=network) - output0 = (2, 3 * 4, 5) - output1 = (4, 5 * 6, 7) cases = [ - {"input": data0, "output": output0}, - {"input": data1, "output": output1}, + {"input": data0, "output": data0.reshape(2, -1, 5)}, + {"input": data1, "output": data1.reshape(4, -1, 7)}, ] opr_test( - cases, - F.flatten, - compare_fn=compare_fn, - start_axis=1, - end_axis=2, - network=network, + cases, F.flatten, start_axis=1, end_axis=2, network=network, ) @@ -595,15 +579,22 @@ def test_broadcast(is_varnode): output3_shape = (10, 10) data3 = np.random.random(input3_shape).astype(np.float32) - def compare_fn(x, y): - assert x._tuple_shape[0] == y - cases = [ - {"input": [data1, output1_shape], "output": output1_shape}, - {"input": [data2, output2_shape], "output": output2_shape}, - {"input": [data3, output3_shape], "output": output3_shape}, + { + "input": [data1, output1_shape], + "output": np.broadcast_to(data1, output1_shape), + }, + { + "input": [data2, output2_shape], + "output": np.broadcast_to(data2, output2_shape), + }, + { + "input": [data3, output3_shape], + "output": np.broadcast_to(data3, output3_shape), + }, ] - opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) + + opr_test(cases, F.broadcast_to, network=network) x = F.ones((2, 1, 3)) with pytest.raises(RuntimeError):