提交 a8309889 编写于 作者: M Megvii Engine Team

test(mge/utils): cover all test data

GitOrigin-RevId: e676476b9d59cc0ec717d8acb03d83cc35a35293
上级 dd1fecdf
......@@ -94,45 +94,49 @@ def opr_test(
return inp, outp
if len(cases) == 0:
raise ValueError("should give one case at least")
def run_index(index):
inp, outp = get_param(cases, index)
inp_tensor = [make_tensor(inpi, network) for inpi in inp]
if not callable(func):
raise ValueError("the input func should be callable")
if test_trace and not network:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
inp, outp = get_param(cases, 0)
inp_tensor = [make_tensor(inpi, network) for inpi in inp]
for _ in range(3):
traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)
if test_trace and not network:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
dumped_func = trace(symbolic=True, capture_as_const=True)(func)
dumped_results = dumped_func(*copied_inp, **kwargs)
check_results(dumped_results, outp)
for _ in range(3):
traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)
file = io.BytesIO()
dump_info = dumped_func.dump(file)
file.seek(0)
dumped_func = trace(symbolic=True, capture_as_const=True)(func)
dumped_results = dumped_func(*copied_inp, **kwargs)
check_results(dumped_results, outp)
# arg_name has pattern arg_xxx, xxx is int value
def take_number(arg_name):
return int(arg_name.split("_")[-1])
file = io.BytesIO()
dump_info = dumped_func.dump(file)
file.seek(0)
input_names = dump_info[4]
inps_np = [i.numpy() for i in copied_inp]
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
# arg_name has pattern arg_xxx, xxx is int value
def take_number(arg_name):
return int(arg_name.split("_")[-1])
# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
input_names = dump_info[4]
inps_np = [i.numpy() for i in copied_inp]
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
results = func(*inp_tensor, **kwargs)
check_results(results, outp)
# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
if len(cases) == 0:
raise ValueError("should give one case at least")
if not callable(func):
raise ValueError("the input func should be callable")
results = func(*inp_tensor, **kwargs)
check_results(results, outp)
for index in range(len(cases)):
run_index(index)
......@@ -79,7 +79,7 @@ def test_matinv():
opr_test(
cases,
F.matinv,
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
ref_fn=np.linalg.inv,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册