From 91a3580f751ed5ca4d17f1305f8614cec438b87e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 2 Feb 2021 10:09:56 +0800 Subject: [PATCH] refactor(mge/cgtools): remove load_and_inference and use GraphInference GitOrigin-RevId: 0e688ebd59f98d6ea0b12c8b6e4ef45d8a2c9a27 --- .../python/megengine/utils/comp_graph_tools.py | 16 ---------------- imperative/python/test/unit/module/test_qat.py | 3 ++- imperative/python/test/unit/test_tracing.py | 6 ++++-- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index a2529855f..20bc1e7d5 100644 --- a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -27,7 +27,6 @@ __all__ = [ "replace_vars", "replace_oprs", "set_priority_to_id", - "load_and_inference", "GraphInference", ] @@ -274,21 +273,6 @@ def replace_oprs( return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) -def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]: - """ - Loads a serialized computing graph and run inference with input data. - - :param file: path or handle of the input file. - :param inp_data_list: list of input data. - :return: list of inference results. - - """ - graph = GraphInference(file) - result = graph.run(*inp_data_list) - out_data_list = list(result.values()) - return out_data_list - - class GraphInference: """ Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py index 37fdc3c58..9bc60e2bf 100644 --- a/imperative/python/test/unit/module/test_qat.py +++ b/imperative/python/test/unit/module/test_qat.py @@ -201,5 +201,6 @@ def test_quantize_batchmatmul_activation(): file = io.BytesIO() f.dump(file, enable_nchw4=True) file.seek(0) - dumped_outputs = cgtools.load_and_inference(file, [inputs])[0] + infer_cg = cgtools.GraphInference(file)[0] + dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0] np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index f3a570e16..94b3c1899 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -141,7 +141,8 @@ def test_dump(): np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) file.seek(0) - result = cgtools.load_and_inference(file, [a, b]) + infer_cg = cgtools.GraphInference(file) + result = list((infer_cg.run(a, b)).values())[0] np.testing.assert_equal(result[0], y) @@ -161,7 +162,8 @@ def test_capture_dump(): file = io.BytesIO() f.dump(file) file.seek(0) - result = cgtools.load_and_inference(file, [x]) + infer_cg = cgtools.GraphInference(file) + result = list((infer_cg.run(x)).values())[0] np.testing.assert_equal(result[0], y) -- GitLab