提交 92afcf3e 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): move cgtools to megengine.utils and add load_and_inference in cgtools

GitOrigin-RevId: abfee3d4faf5768ba742a53357be5d34a473f540
上级 66f2dbd7
......@@ -76,7 +76,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .version import __version__
from .core import cgtools
from .utils import comp_graph_tools as cgtools
_set_fork_exec_path_for_timed_func(
sys.executable,
......
......@@ -11,4 +11,3 @@ import sys
from .tensor import Tensor
from .tensor.megbrain_graph import Graph
from .utils import comp_graph_tools as cgtools
......@@ -358,7 +358,7 @@ CompGraphLoadResult = collections.namedtuple(
def load_graph(fpath):
"""Load a serialized computing graph from file.
:parma fpath: Path or Handle for the output file
:param fpath: Path or Handle of the input file
:return: An instance of namedtuple :class:`CompGraphLoadResult`,
whose fields are:
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .comp_graph_tools import *
......@@ -8,8 +8,12 @@
import collections
from typing import Dict, List
from .. import _imperative_rt
from .._imperative_rt import OperatorNode, VarNode
import numpy
from ..core import _imperative_rt
from ..core._imperative_rt import OperatorNode, VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.raw_tensor import as_raw_tensor
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
......@@ -251,3 +255,33 @@ def set_priority_to_id(dest_vars):
assert isinstance(i, VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
"""Load a serialized computing graph and run inference with input data.
:param file: Path or Handle of the input file.
:param inp_data_list: List of input data.
:return: List of inference results.
"""
*_, out_list = G.load_graph(file)
inputs = get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = replace_vars(out_list, replace_dict)
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
cg = new_out_list[0].graph
func = cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data_list):
node.set_value(as_raw_tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
return out_data_list
......@@ -22,29 +22,6 @@ from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
def load_and_inference(file, inp_data):
cg, _, out_list = G.load_graph(file)
inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = cgtools.replace_vars(out_list, replace_dict)
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
new_cg = new_out_list[0].graph
func = new_cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data):
node.set_value(as_raw_tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
return out_data_list
def test_trace():
for symbolic in [False, True]:
......@@ -124,7 +101,7 @@ def test_dump():
np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"])
np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"])
file.seek(0)
result = load_and_inference(file, [a, b])
result = cgtools.load_and_inference(file, [a, b])
np.testing.assert_equal(result[0], y)
......@@ -146,7 +123,7 @@ def test_capture_dump():
file = io.BytesIO()
f.dump(file)
file.seek(0)
result = load_and_inference(file, [x])
result = cgtools.load_and_inference(file, [x])
np.testing.assert_equal(result[0], y)
......
......@@ -17,6 +17,7 @@ import numpy as np
import megengine as mge
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine.core.tensor.megbrain_graph import VarNode
from megengine import cgtools
from megengine.core.ops import builtin
from megengine.core.tensor.core import apply
......@@ -488,7 +489,8 @@ def main():
with open(args.output, "wb") as fout:
fout.write(b"mgbtest0")
fout.write(struct.pack("I", len(feeds["testcases"])))
fout.write(rt.dump_graph(output_mgbvars))
dump_content, _ = G.dump_graph([VarNode(i) for i in output_mgbvars])
fout.write(dump_content)
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
......@@ -507,7 +509,8 @@ def main():
testcase.keys()
)
with open(args.output, "ab") as fout:
fout.write(G.dump_graph(*output_mgbvars))
dump_content, _ = G.dump_graph(output_mgbvars)
fout.write(dump_content)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册