提交 334eda87 编写于 作者: M Megvii Engine Team

refactor(mge): test trace inside opr_test

GitOrigin-RevId: 2cf1135c1ccbdba234238d29465dd1eda6765a59
上级 2b8150ab
...@@ -315,9 +315,9 @@ class GraphInference: ...@@ -315,9 +315,9 @@ class GraphInference:
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
self._inp_dict = OrderedDict() self._inp_dict = OrderedDict()
replace_dict = {} replace_dict = {}
for i in inputs: for idx, i in enumerate(inputs):
inp_node = G.InputNode( inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
) )
self._inp_dict[i.name] = inp_node self._inp_dict[i.name] = inp_node
replace_dict[i] = inp_node.outputs[0] replace_dict[i] = inp_node.outputs[0]
......
import io
import numpy as np import numpy as np
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor from megengine import tensor
from megengine.jit import trace
def _default_compare_fn(x, y): def _default_compare_fn(x, y):
if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6)
else:
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs): def opr_test(
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs
):
""" """
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test. :param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
and the dict should have input, and the dict should have input,
...@@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) ...@@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs)
if not isinstance(results, (tuple, list)): if not isinstance(results, (tuple, list)):
results = (results,) results = (results,)
for r, e in zip(results, expected): for r, e in zip(results, expected):
if not isinstance(r, tensor):
r = tensor(r)
compare_fn(r, e) compare_fn(r, e)
def get_param(cases, idx): def get_param(cases, idx):
...@@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) ...@@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs)
inp, outp = get_param(cases, 0) inp, outp = get_param(cases, 0)
inp_tensor = [tensor(inpi) for inpi in inp] inp_tensor = [tensor(inpi) for inpi in inp]
if test_trace:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
for _ in range(3):
traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)
dumped_func = trace(symbolic=True, capture_as_const=True)(func)
dumped_results = dumped_func(*copied_inp, **kwargs)
check_results(dumped_results, outp)
file = io.BytesIO()
dump_info = dumped_func.dump(file)
file.seek(0)
# arg_name has pattern arg_xxx, xxx is int value
def take_number(arg_name):
return int(arg_name.split("_")[-1])
input_names = dump_info[4]
inps_np = [i.numpy() for i in copied_inp]
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
results = func(*inp_tensor, **kwargs) results = func(*inp_tensor, **kwargs)
check_results(results, outp) check_results(results, outp)
...@@ -36,7 +36,7 @@ def test_where(): ...@@ -36,7 +36,7 @@ def test_where():
{"input": [maskv0, xv0, yv0]}, {"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]}, {"input": [maskv1, xv1, yv1]},
] ]
opr_test(cases, F.where, ref_fn=np.where) opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
maskv2 = np.array([1, 1, 1], dtype=np.bool_) maskv2 = np.array([1, 1, 1], dtype=np.bool_)
xv2 = np.array([1, 3, 2], dtype=np.float32) xv2 = np.array([1, 3, 2], dtype=np.float32)
...@@ -50,7 +50,7 @@ def test_where(): ...@@ -50,7 +50,7 @@ def test_where():
{"input": [maskv2, xv2, yv2]}, {"input": [maskv2, xv2, yv2]},
{"input": [maskv3, xv3, yv3]}, {"input": [maskv3, xv3, yv3]},
] ]
opr_test(cases, F.where, ref_fn=np.where) opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
def test_dropout(): def test_dropout():
...@@ -115,14 +115,17 @@ def test_matmul(): ...@@ -115,14 +115,17 @@ def test_matmul():
{"input": [data4, data5]}, {"input": [data4, data5]},
] ]
for _ in range(0, batch_size): for _ in range(0, batch_size):
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, F.matmul, ref_fn=np.matmul, cases, F.matmul, test_trace=False, ref_fn=np.matmul,
) )
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
[{"input": [data1, data4]}], [{"input": [data1, data4]}],
F.matmul, F.matmul,
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
test_trace=False,
transpose_b=True, transpose_b=True,
) )
......
...@@ -162,20 +162,24 @@ def test_linspace(): ...@@ -162,20 +162,24 @@ def test_linspace():
{"input": [1, 9, 9]}, {"input": [1, 9, 9]},
{"input": [3, 10, 8]}, {"input": [3, 10, 8]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
) )
cases = [ cases = [
{"input": [9, 1, 9]}, {"input": [9, 1, 9]},
{"input": [10, 3, 8]}, {"input": [10, 3, 8]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
) )
...@@ -184,30 +188,36 @@ def test_arange(): ...@@ -184,30 +188,36 @@ def test_arange():
{"input": [1, 9, 1]}, {"input": [1, 9, 1]},
{"input": [2, 10, 2]}, {"input": [2, 10, 2]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )
cases = [ cases = [
{"input": [9, 1, -1]}, {"input": [9, 1, -1]},
{"input": [10, 2, -2]}, {"input": [10, 2, -2]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )
cases = [ cases = [
{"input": [9.3, 1.2, -0.5]}, {"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]}, {"input": [10.3, 2.1, -1.7]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )
...@@ -279,7 +289,8 @@ def test_broadcast(): ...@@ -279,7 +289,8 @@ def test_broadcast():
{"input": [data1, output1_shape], "output": output1_shape}, {"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape}, {"input": [data2, output2_shape], "output": output2_shape},
] ]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) # FIXME: remove test_trace=False in the future
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False)
x = F.ones((2, 1, 3)) x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册