提交 a8309889 编写于 作者: M Megvii Engine Team

test(mge/utils): cover all test data

GitOrigin-RevId: e676476b9d59cc0ec717d8acb03d83cc35a35293
上级 dd1fecdf
......@@ -94,13 +94,8 @@ def opr_test(
return inp, outp
if len(cases) == 0:
raise ValueError("should give one case at least")
if not callable(func):
raise ValueError("the input func should be callable")
inp, outp = get_param(cases, 0)
def run_index(index):
inp, outp = get_param(cases, index)
inp_tensor = [make_tensor(inpi, network) for inpi in inp]
if test_trace and not network:
......@@ -136,3 +131,12 @@ def opr_test(
results = func(*inp_tensor, **kwargs)
check_results(results, outp)
if len(cases) == 0:
raise ValueError("should give one case at least")
if not callable(func):
raise ValueError("the input func should be callable")
for index in range(len(cases)):
run_index(index)
......@@ -79,7 +79,7 @@ def test_matinv():
opr_test(
cases,
F.matinv,
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
ref_fn=np.linalg.inv,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册