diff --git a/imperative/python/test/helpers/utils.py b/imperative/python/test/helpers/utils.py index 10cd4a8b5168f59834b244c8e2c0fa5e6eda5575..99ce51b2614cef6b4cecf8083b7f2821e1b858e3 100644 --- a/imperative/python/test/helpers/utils.py +++ b/imperative/python/test/helpers/utils.py @@ -94,45 +94,49 @@ def opr_test( return inp, outp - if len(cases) == 0: - raise ValueError("should give one case at least") + def run_index(index): + inp, outp = get_param(cases, index) + inp_tensor = [make_tensor(inpi, network) for inpi in inp] - if not callable(func): - raise ValueError("the input func should be callable") + if test_trace and not network: + copied_inp = inp_tensor.copy() + for symbolic in [False, True]: + traced_func = trace(symbolic=symbolic)(func) - inp, outp = get_param(cases, 0) - inp_tensor = [make_tensor(inpi, network) for inpi in inp] + for _ in range(3): + traced_results = traced_func(*copied_inp, **kwargs) + check_results(traced_results, outp) - if test_trace and not network: - copied_inp = inp_tensor.copy() - for symbolic in [False, True]: - traced_func = trace(symbolic=symbolic)(func) + dumped_func = trace(symbolic=True, capture_as_const=True)(func) + dumped_results = dumped_func(*copied_inp, **kwargs) + check_results(dumped_results, outp) - for _ in range(3): - traced_results = traced_func(*copied_inp, **kwargs) - check_results(traced_results, outp) + file = io.BytesIO() + dump_info = dumped_func.dump(file) + file.seek(0) - dumped_func = trace(symbolic=True, capture_as_const=True)(func) - dumped_results = dumped_func(*copied_inp, **kwargs) - check_results(dumped_results, outp) + # arg_name has pattern arg_xxx, xxx is int value + def take_number(arg_name): + return int(arg_name.split("_")[-1]) - file = io.BytesIO() - dump_info = dumped_func.dump(file) - file.seek(0) + input_names = dump_info[4] + inps_np = [i.numpy() for i in copied_inp] + input_names.sort(key=take_number) + inp_dict = dict(zip(input_names, inps_np)) + infer_cg = cgtools.GraphInference(file) - # arg_name has pattern arg_xxx, xxx is int value - def take_number(arg_name): - return int(arg_name.split("_")[-1]) + # assume #outputs == 1 + loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] + check_results(loaded_results, outp) - input_names = dump_info[4] - inps_np = [i.numpy() for i in copied_inp] - input_names.sort(key=take_number) - inp_dict = dict(zip(input_names, inps_np)) - infer_cg = cgtools.GraphInference(file) + results = func(*inp_tensor, **kwargs) + check_results(results, outp) - # assume #outputs == 1 - loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] - check_results(loaded_results, outp) + if len(cases) == 0: + raise ValueError("should give one case at least") + + if not callable(func): + raise ValueError("the input func should be callable") - results = func(*inp_tensor, **kwargs) - check_results(results, outp) + for index in range(len(cases)): + run_index(index) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index a4366a6c405ca85b8efb13f98162a0fb6e7cb076..612f6612fca4a76912643630ca9cf8e9179e261f 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -79,7 +79,7 @@ def test_matinv(): opr_test( cases, F.matinv, - compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5), + compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4), ref_fn=np.linalg.inv, )