From 334eda871768e5a14962bba4cfb545a75530f3eb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 25 Jan 2021 10:12:46 +0800 Subject: [PATCH] refactor(mge): test trace inside opr_test GitOrigin-RevId: 2cf1135c1ccbdba234238d29465dd1eda6765a59 --- .../megengine/utils/comp_graph_tools.py | 4 +- imperative/python/test/helpers/utils.py | 46 ++++++++++++++++++- .../test/unit/functional/test_functional.py | 9 ++-- .../test/unit/functional/test_tensor.py | 13 +++++- 4 files changed, 64 insertions(+), 8 deletions(-) diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index de3a57365..a2529855f 100644 --- a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -315,9 +315,9 @@ class GraphInference: inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") self._inp_dict = OrderedDict() replace_dict = {} - for i in inputs: + for idx, i in enumerate(inputs): inp_node = G.InputNode( - device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph + device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph ) self._inp_dict[i.name] = inp_node replace_dict[i] = inp_node.outputs[0] diff --git a/imperative/python/test/helpers/utils.py b/imperative/python/test/helpers/utils.py index 4724fd26a..63de93469 100644 --- a/imperative/python/test/helpers/utils.py +++ b/imperative/python/test/helpers/utils.py @@ -1,13 +1,22 @@ +import io + import numpy as np +import megengine.utils.comp_graph_tools as cgtools from megengine import tensor +from megengine.jit import trace def _default_compare_fn(x, y): - np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) + if isinstance(x, np.ndarray): + np.testing.assert_allclose(x, y, rtol=1e-6) + else: + np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) -def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs): +def opr_test( + cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs +): """ :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. and the dict should have input, @@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) if not isinstance(results, (tuple, list)): results = (results,) for r, e in zip(results, expected): + if not isinstance(r, tensor): + r = tensor(r) compare_fn(r, e) def get_param(cases, idx): @@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) inp, outp = get_param(cases, 0) inp_tensor = [tensor(inpi) for inpi in inp] + if test_trace: + copied_inp = inp_tensor.copy() + for symbolic in [False, True]: + traced_func = trace(symbolic=symbolic)(func) + + for _ in range(3): + traced_results = traced_func(*copied_inp, **kwargs) + check_results(traced_results, outp) + + dumped_func = trace(symbolic=True, capture_as_const=True)(func) + dumped_results = dumped_func(*copied_inp, **kwargs) + check_results(dumped_results, outp) + + file = io.BytesIO() + dump_info = dumped_func.dump(file) + file.seek(0) + + # arg_name has pattern arg_xxx, xxx is int value + def take_number(arg_name): + return int(arg_name.split("_")[-1]) + + input_names = dump_info[4] + inps_np = [i.numpy() for i in copied_inp] + input_names.sort(key=take_number) + inp_dict = dict(zip(input_names, inps_np)) + infer_cg = cgtools.GraphInference(file) + + # assume #outputs == 1 + loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] + check_results(loaded_results, outp) + results = func(*inp_tensor, **kwargs) check_results(results, outp) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 503e5d34c..21d22c777 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -36,7 +36,7 @@ def test_where(): {"input": [maskv0, xv0, yv0]}, {"input": [maskv1, xv1, yv1]}, ] - opr_test(cases, F.where, ref_fn=np.where) + opr_test(cases, F.where, ref_fn=np.where, test_trace=False) maskv2 = np.array([1, 1, 1], dtype=np.bool_) xv2 = np.array([1, 3, 2], dtype=np.float32) @@ -50,7 +50,7 @@ def test_where(): {"input": [maskv2, xv2, yv2]}, {"input": [maskv3, xv3, yv3]}, ] - opr_test(cases, F.where, ref_fn=np.where) + opr_test(cases, F.where, ref_fn=np.where, test_trace=False) def test_dropout(): @@ -115,14 +115,17 @@ def test_matmul(): {"input": [data4, data5]}, ] for _ in range(0, batch_size): + # FIXME: remove test_trace=False in the future opr_test( - cases, F.matmul, ref_fn=np.matmul, + cases, F.matmul, test_trace=False, ref_fn=np.matmul, ) + # FIXME: remove test_trace=False in the future opr_test( [{"input": [data1, data4]}], F.matmul, ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), + test_trace=False, transpose_b=True, ) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index e9343c4f0..c24cafc37 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -162,20 +162,24 @@ def test_linspace(): {"input": [1, 9, 9]}, {"input": [3, 10, 8]}, ] + # FIXME: remove test_trace=False in the future opr_test( cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + test_trace=False, ) cases = [ {"input": [9, 1, 9]}, {"input": [10, 3, 8]}, ] + # FIXME: remove test_trace=False in the future opr_test( cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + test_trace=False, ) @@ -184,30 +188,36 @@ def test_arange(): {"input": [1, 9, 1]}, {"input": [2, 10, 2]}, ] + # FIXME: remove test_trace=False in the future opr_test( cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + test_trace=False, ) cases = [ {"input": [9, 1, -1]}, {"input": [10, 2, -2]}, ] + # FIXME: remove test_trace=False in the future opr_test( cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + test_trace=False, ) cases = [ {"input": [9.3, 1.2, -0.5]}, {"input": [10.3, 2.1, -1.7]}, ] + # FIXME: remove test_trace=False in the future opr_test( cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + test_trace=False, ) @@ -279,7 +289,8 @@ def test_broadcast(): {"input": [data1, output1_shape], "output": output1_shape}, {"input": [data2, output2_shape], "output": output2_shape}, ] - opr_test(cases, F.broadcast_to, compare_fn=compare_fn) + # FIXME: remove test_trace=False in the future + opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False) x = F.ones((2, 1, 3)) with pytest.raises(RuntimeError): -- GitLab