提交 8e50a6da 编写于 作者: M Megvii Engine Team

feat(mge/utils): add GraphInference in cgtools

GitOrigin-RevId: 72f22011691a23f50a966480829ac6744753d563
上级 97207d00
......@@ -6,7 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Dict, List
from collections import OrderedDict
from typing import Dict, List, Optional
import numpy
......@@ -27,6 +28,7 @@ __all__ = [
"replace_oprs",
"set_priority_to_id",
"load_and_inference",
"GraphInference",
]
......@@ -46,7 +48,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
q = list(var)
while q:
v = q.pop()
v = q.pop(0)
if v in memo:
continue
memo.add(v)
......@@ -281,23 +283,77 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n
:return: list of inference results.
"""
*_, out_list = G.load_graph(file)
inputs = get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = replace_vars(out_list, replace_dict)
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
cg = new_out_list[0].graph
func = cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data_list):
node.set_value(Tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
graph = GraphInference(file)
result = graph.run(*inp_data_list)
out_data_list = list(result.values())
return out_data_list
class GraphInference:
"""
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.
:param file: could be file object or filename.
:param outputs: only compile the subgraph with outputs as its endpoints.
"""
def __init__(self, file, outputs: Optional[List[str]] = None):
*_, output_nodes = G.load_graph(file)
if outputs is not None:
output_name = outputs.copy()
all_vars = get_dep_vars(output_nodes) + output_nodes
new_outputs = {}
for i in all_vars:
if i.name in output_name:
new_outputs[i.name] = i
output_name.remove(i.name)
assert (
len(output_name) == 0
), "Can not find varnode {} in this model".format(output_name)
output_nodes = [new_outputs[i] for i in outputs]
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
self._inp_dict = OrderedDict()
replace_dict = {}
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
self._inp_dict[i.name] = inp_node
replace_dict[i] = inp_node.outputs[0]
new_output_nodes = replace_vars(output_nodes, replace_dict)
for old, new in zip(output_nodes, new_output_nodes):
new.name = old.name
self._out_dict = OrderedDict(
[(i.name, G.OutputNode(i)) for i in new_output_nodes]
)
new_out_list = [i.outputs[0] for i in self._out_dict.values()]
cg = new_out_list[0].graph
self._func = cg.compile(new_out_list)
def run(
self,
*inp_args: numpy.ndarray,
inp_dict: Optional[Dict[str, numpy.ndarray]] = None
):
assert len(inp_args) <= len(
self._inp_dict
), "This model expects {} inputs".format(len(self._inp_dict))
inputs = {}
inp_keys = list(self._inp_dict.keys())
for ind, data in enumerate(inp_args):
inputs[inp_keys[ind]] = data
if inp_dict is not None:
inputs.update(inp_dict)
assert (
inputs.keys() == self._inp_dict.keys()
), "This model expects inputs {}, but gets inputs {}".format(
list(self._inp_dict.keys()), list(inputs.keys())
)
for key in self._inp_dict:
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
self._func.execute()
result = OrderedDict()
for key in self._out_dict:
result[key] = self._out_dict[key].get_value().numpy()
return result
......@@ -139,3 +139,48 @@ def test_get_opr_seq():
seq_2 = cgtools.get_oprs_seq(outputs, False)
assert len(seq_2) == 6
def test_graph_function():
class Net(M.Module):
def forward(self, a, b):
return a - b, a * b
net = Net()
@trace(symbolic=True, capture_as_const=True)
def function(a, b, *, net=None):
return net(a, b)
a = np.array([1, 2, 3])
b = np.array([3])
x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)
file = io.BytesIO()
function.dump(
file,
arg_names=["a", "b"],
output_names=["x", "y"],
optimize_for_inference=False,
)
file.seek(0)
graph = cgtools.GraphInference(file)
results = graph.run(inp_dict={"a": a, "b": b})
np.testing.assert_equal(x.numpy(), results["x"])
np.testing.assert_equal(y.numpy(), results["y"])
results = graph.run(a, inp_dict={"b": b})
np.testing.assert_equal(x.numpy(), results["x"])
np.testing.assert_equal(y.numpy(), results["y"])
results = graph.run(a, b)
np.testing.assert_equal(x.numpy(), results["x"])
np.testing.assert_equal(y.numpy(), results["y"])
file.seek(0)
graph1 = cgtools.GraphInference(file, outputs=["x"])
results = graph1.run(inp_dict={"a": a, "b": b})
np.testing.assert_equal(x.numpy(), results["x"])
assert "y" not in results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册