提交 91a3580f 编写于 作者: M Megvii Engine Team

refactor(mge/cgtools): remove load_and_inference and use GraphInference

GitOrigin-RevId: 0e688ebd59f98d6ea0b12c8b6e4ef45d8a2c9a27
上级 4485e780
...@@ -27,7 +27,6 @@ __all__ = [ ...@@ -27,7 +27,6 @@ __all__ = [
"replace_vars", "replace_vars",
"replace_oprs", "replace_oprs",
"set_priority_to_id", "set_priority_to_id",
"load_and_inference",
"GraphInference", "GraphInference",
] ]
...@@ -274,21 +273,6 @@ def replace_oprs( ...@@ -274,21 +273,6 @@ def replace_oprs(
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
    """
    Loads a serialized computing graph and run inference with input data.

    :param file: path or handle of the input file.
    :param inp_data_list: list of input data.
    :return: list of inference results.
    """
    # Delegate graph loading/execution to GraphInference; its run() returns
    # an ordered mapping of output var -> ndarray, so the values are the results.
    inference = GraphInference(file)
    return list(inference.run(*inp_data_list).values())
class GraphInference: class GraphInference:
""" """
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
......
...@@ -201,5 +201,6 @@ def test_quantize_batchmatmul_activation(): ...@@ -201,5 +201,6 @@ def test_quantize_batchmatmul_activation():
file = io.BytesIO() file = io.BytesIO()
f.dump(file, enable_nchw4=True) f.dump(file, enable_nchw4=True)
file.seek(0) file.seek(0)
dumped_outputs = cgtools.load_and_inference(file, [inputs])[0] infer_cg = cgtools.GraphInference(file)
dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)
...@@ -141,7 +141,8 @@ def test_dump(): ...@@ -141,7 +141,8 @@ def test_dump():
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
file.seek(0) file.seek(0)
result = cgtools.load_and_inference(file, [a, b]) infer_cg = cgtools.GraphInference(file)
result = list((infer_cg.run(a, b)).values())[0]
np.testing.assert_equal(result[0], y) np.testing.assert_equal(result[0], y)
...@@ -161,7 +162,8 @@ def test_capture_dump(): ...@@ -161,7 +162,8 @@ def test_capture_dump():
file = io.BytesIO() file = io.BytesIO()
f.dump(file) f.dump(file)
file.seek(0) file.seek(0)
result = cgtools.load_and_inference(file, [x]) infer_cg = cgtools.GraphInference(file)
result = list((infer_cg.run(x)).values())[0]
np.testing.assert_equal(result[0], y) np.testing.assert_equal(result[0], y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册