提交 a8309889 编写于 作者: M Megvii Engine Team

test(mge/utils): cover all test data

GitOrigin-RevId: e676476b9d59cc0ec717d8acb03d83cc35a35293
上级 dd1fecdf
...@@ -94,45 +94,49 @@ def opr_test( ...@@ -94,45 +94,49 @@ def opr_test(
return inp, outp return inp, outp
if len(cases) == 0: def run_index(index):
raise ValueError("should give one case at least") inp, outp = get_param(cases, index)
inp_tensor = [make_tensor(inpi, network) for inpi in inp]
if not callable(func): if test_trace and not network:
raise ValueError("the input func should be callable") copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
inp, outp = get_param(cases, 0) for _ in range(3):
inp_tensor = [make_tensor(inpi, network) for inpi in inp] traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)
if test_trace and not network: dumped_func = trace(symbolic=True, capture_as_const=True)(func)
copied_inp = inp_tensor.copy() dumped_results = dumped_func(*copied_inp, **kwargs)
for symbolic in [False, True]: check_results(dumped_results, outp)
traced_func = trace(symbolic=symbolic)(func)
for _ in range(3): file = io.BytesIO()
traced_results = traced_func(*copied_inp, **kwargs) dump_info = dumped_func.dump(file)
check_results(traced_results, outp) file.seek(0)
dumped_func = trace(symbolic=True, capture_as_const=True)(func) # arg_name has pattern arg_xxx, xxx is int value
dumped_results = dumped_func(*copied_inp, **kwargs) def take_number(arg_name):
check_results(dumped_results, outp) return int(arg_name.split("_")[-1])
file = io.BytesIO() input_names = dump_info[4]
dump_info = dumped_func.dump(file) inps_np = [i.numpy() for i in copied_inp]
file.seek(0) input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
# arg_name has pattern arg_xxx, xxx is int value # assume #outputs == 1
def take_number(arg_name): loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
return int(arg_name.split("_")[-1]) check_results(loaded_results, outp)
input_names = dump_info[4] results = func(*inp_tensor, **kwargs)
inps_np = [i.numpy() for i in copied_inp] check_results(results, outp)
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
# assume #outputs == 1 if len(cases) == 0:
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] raise ValueError("should give one case at least")
check_results(loaded_results, outp)
if not callable(func):
raise ValueError("the input func should be callable")
results = func(*inp_tensor, **kwargs) for index in range(len(cases)):
check_results(results, outp) run_index(index)
...@@ -79,7 +79,7 @@ def test_matinv(): ...@@ -79,7 +79,7 @@ def test_matinv():
opr_test( opr_test(
cases, cases,
F.matinv, F.matinv,
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5), compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
ref_fn=np.linalg.inv, ref_fn=np.linalg.inv,
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册