From 8e50a6daa771a621d445c9513a9553ec46478cdd Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 20 Jan 2021 14:43:09 +0800
Subject: [PATCH] feat(mge/utils): add GraphInference in cgtools

GitOrigin-RevId: 72f22011691a23f50a966480829ac6744753d563
---
 .../megengine/utils/comp_graph_tools.py     | 98 +++++++++++++++----
 imperative/python/test/unit/test_cgtools.py | 45 +++++++++
 2 files changed, 122 insertions(+), 21 deletions(-)

diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py
index f4f1cfdc..de3a5736 100644
--- a/imperative/python/megengine/utils/comp_graph_tools.py
+++ b/imperative/python/megengine/utils/comp_graph_tools.py
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
-from typing import Dict, List
+from collections import OrderedDict
+from typing import Dict, List, Optional
 
 import numpy
 
@@ -27,6 +28,7 @@ __all__ = [
     "replace_oprs",
     "set_priority_to_id",
     "load_and_inference",
+    "GraphInference",
 ]
 
 
@@ -46,7 +48,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
 
     q = list(var)
     while q:
-        v = q.pop()
+        v = q.pop(0)
         if v in memo:
             continue
         memo.add(v)
@@ -281,23 +283,77 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n
 
     :return: list of inference results.
     """
-    *_, out_list = G.load_graph(file)
-    inputs = get_dep_vars(out_list, "Host2DeviceCopy")
-    replace_dict = {}
-    inp_node_list = []
-    for i in inputs:
-        inp_node = G.InputNode(
-            device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
-        )
-        replace_dict[i] = inp_node.outputs[0]
-        inp_node_list.append(inp_node)
-    new_out = replace_vars(out_list, replace_dict)
-    out_node_list = [G.OutputNode(i) for i in new_out]
-    new_out_list = [i.outputs[0] for i in out_node_list]
-    cg = new_out_list[0].graph
-    func = cg.compile(new_out_list)
-    for node, value in zip(inp_node_list, inp_data_list):
-        node.set_value(Tensor(value)._dev_tensor())
-    func.execute()
-    out_data_list = [o.get_value().numpy() for o in out_node_list]
+    graph = GraphInference(file)
+    result = graph.run(*inp_data_list)
+    out_data_list = list(result.values())
     return out_data_list
+
+
+class GraphInference:
+    """
+    Loads a serialized computing graph as a `GraphInference` object that can be used to execute it.
+    `GraphInference.run()` accepts positional inputs `inp_args` or a dict `inp_dict` of {input_name: input_value} pairs and returns a dict of {output_name: output_value} pairs.
+
+    :param file: a file object or a filename.
+    :param outputs: only compile the subgraph with these outputs as its endpoints.
+    """
+
+    def __init__(self, file, outputs: Optional[List[str]] = None):
+        *_, output_nodes = G.load_graph(file)
+        if outputs is not None:
+            output_name = outputs.copy()
+            all_vars = get_dep_vars(output_nodes) + output_nodes
+            new_outputs = {}
+            for i in all_vars:
+                if i.name in output_name:
+                    new_outputs[i.name] = i
+                    output_name.remove(i.name)
+            assert (
+                len(output_name) == 0
+            ), "Cannot find varnode {} in this model".format(output_name)
+            output_nodes = [new_outputs[i] for i in outputs]
+        inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
+        self._inp_dict = OrderedDict()
+        replace_dict = {}
+        for i in inputs:
+            inp_node = G.InputNode(
+                device="xpux", dtype=i.dtype, graph=i.graph
+            )
+            self._inp_dict[i.name] = inp_node
+            replace_dict[i] = inp_node.outputs[0]
+        new_output_nodes = replace_vars(output_nodes, replace_dict)
+        for old, new in zip(output_nodes, new_output_nodes):
+            new.name = old.name
+        self._out_dict = OrderedDict(
+            [(i.name, G.OutputNode(i)) for i in new_output_nodes]
+        )
+        new_out_list = [i.outputs[0] for i in self._out_dict.values()]
+        cg = new_out_list[0].graph
+        self._func = cg.compile(new_out_list)
+
+    def run(
+        self,
+        *inp_args: numpy.ndarray,
+        inp_dict: Optional[Dict[str, numpy.ndarray]] = None
+    ):
+        assert len(inp_args) <= len(
+            self._inp_dict
+        ), "This model expects at most {} inputs".format(len(self._inp_dict))
+        inputs = {}
+        inp_keys = list(self._inp_dict.keys())
+        for ind, data in enumerate(inp_args):
+            inputs[inp_keys[ind]] = data
+        if inp_dict is not None:
+            inputs.update(inp_dict)
+        assert (
+            inputs.keys() == self._inp_dict.keys()
+        ), "This model expects inputs {}, but got inputs {}".format(
+            list(self._inp_dict.keys()), list(inputs.keys())
+        )
+        for key in self._inp_dict:
+            self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
+        self._func.execute()
+        result = OrderedDict()
+        for key in self._out_dict:
+            result[key] = self._out_dict[key].get_value().numpy()
+        return result
diff --git a/imperative/python/test/unit/test_cgtools.py b/imperative/python/test/unit/test_cgtools.py
index 0c240dce..406a9c08 100644
--- a/imperative/python/test/unit/test_cgtools.py
+++ b/imperative/python/test/unit/test_cgtools.py
@@ -139,3 +139,48 @@ def test_get_opr_seq():
     seq_2 = cgtools.get_oprs_seq(outputs, False)
 
     assert len(seq_2) == 6
+
+
+def test_graph_function():
+    class Net(M.Module):
+        def forward(self, a, b):
+            return a - b, a * b
+
+    net = Net()
+
+    @trace(symbolic=True, capture_as_const=True)
+    def function(a, b, *, net=None):
+        return net(a, b)
+
+    a = np.array([1, 2, 3])
+    b = np.array([3])
+    x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)
+
+    file = io.BytesIO()
+    function.dump(
+        file,
+        arg_names=["a", "b"],
+        output_names=["x", "y"],
+        optimize_for_inference=False,
+    )
+    file.seek(0)
+
+    graph = cgtools.GraphInference(file)
+    results = graph.run(inp_dict={"a": a, "b": b})
+    np.testing.assert_equal(x.numpy(), results["x"])
+    np.testing.assert_equal(y.numpy(), results["y"])
+
+    results = graph.run(a, inp_dict={"b": b})
+    np.testing.assert_equal(x.numpy(), results["x"])
+    np.testing.assert_equal(y.numpy(), results["y"])
+
+    results = graph.run(a, b)
+    np.testing.assert_equal(x.numpy(), results["x"])
+    np.testing.assert_equal(y.numpy(), results["y"])
+
+    file.seek(0)
+
+    graph1 = cgtools.GraphInference(file, outputs=["x"])
+    results = graph1.run(inp_dict={"a": a, "b": b})
+    np.testing.assert_equal(x.numpy(), results["x"])
+    assert "y" not in results
-- 
GitLab
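
Usage sketch: the snippet below shows one way the GraphInference class added by this patch can be driven from user code, mirroring the test above. The model path "model.mge" and the names "a", "b", "x", "y" are hypothetical placeholders; they must match the arg_names/output_names passed to trace.dump() when the graph was serialized.

    import numpy as np
    from megengine.utils import comp_graph_tools as cgtools

    # "model.mge" is a placeholder for a graph serialized with trace.dump(),
    # assumed to use arg_names=["a", "b"] and output_names=["x", "y"].
    graph = cgtools.GraphInference("model.mge")

    # Inputs may be passed positionally, via inp_dict, or mixed;
    # run() returns an OrderedDict mapping output names to numpy arrays.
    results = graph.run(np.array([1, 2, 3]), inp_dict={"b": np.array([3])})
    print(results["x"], results["y"])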