refactor(mge): test trace inside opr_test

GitOrigin-RevId: 2cf1135c1ccbdba234238d29465dd1eda6765a59
上级 2b8150ab
无相关合并请求
......@@ -315,9 +315,9 @@ class GraphInference:
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
self._inp_dict = OrderedDict()
replace_dict = {}
for i in inputs:
for idx, i in enumerate(inputs):
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
)
self._inp_dict[i.name] = inp_node
replace_dict[i] = inp_node.outputs[0]
......
import io
import numpy as np
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.jit import trace
def _default_compare_fn(x, y):
if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6)
else:
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs):
def opr_test(
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs
):
"""
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
and the dict should have input,
......@@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs)
if not isinstance(results, (tuple, list)):
results = (results,)
for r, e in zip(results, expected):
if not isinstance(r, tensor):
r = tensor(r)
compare_fn(r, e)
def get_param(cases, idx):
......@@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs)
inp, outp = get_param(cases, 0)
inp_tensor = [tensor(inpi) for inpi in inp]
if test_trace:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
for _ in range(3):
traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)
dumped_func = trace(symbolic=True, capture_as_const=True)(func)
dumped_results = dumped_func(*copied_inp, **kwargs)
check_results(dumped_results, outp)
file = io.BytesIO()
dump_info = dumped_func.dump(file)
file.seek(0)
# arg_name has pattern arg_xxx, xxx is int value
def take_number(arg_name):
return int(arg_name.split("_")[-1])
input_names = dump_info[4]
inps_np = [i.numpy() for i in copied_inp]
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
results = func(*inp_tensor, **kwargs)
check_results(results, outp)
......@@ -36,7 +36,7 @@ def test_where():
{"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]},
]
opr_test(cases, F.where, ref_fn=np.where)
opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
maskv2 = np.array([1, 1, 1], dtype=np.bool_)
xv2 = np.array([1, 3, 2], dtype=np.float32)
......@@ -50,7 +50,7 @@ def test_where():
{"input": [maskv2, xv2, yv2]},
{"input": [maskv3, xv3, yv3]},
]
opr_test(cases, F.where, ref_fn=np.where)
opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
def test_dropout():
......@@ -115,14 +115,17 @@ def test_matmul():
{"input": [data4, data5]},
]
for _ in range(0, batch_size):
# FIXME: remove test_trace=False in the future
opr_test(
cases, F.matmul, ref_fn=np.matmul,
cases, F.matmul, test_trace=False, ref_fn=np.matmul,
)
# FIXME: remove test_trace=False in the future
opr_test(
[{"input": [data1, data4]}],
F.matmul,
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
test_trace=False,
transpose_b=True,
)
......
......@@ -162,20 +162,24 @@ def test_linspace():
{"input": [1, 9, 9]},
{"input": [3, 10, 8]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [9, 1, 9]},
{"input": [10, 3, 8]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
)
......@@ -184,30 +188,36 @@ def test_arange():
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [9, 1, -1]},
{"input": [10, 2, -2]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
)
......@@ -279,7 +289,8 @@ def test_broadcast():
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
# FIXME: remove test_trace=False in the future
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False)
x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部