提交 ac11c38a 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add graph load and cgtools for imperative

GitOrigin-RevId: ba251f452ae8c6cc9c3dae99d1be92711cbeff5e
上级 76f36796
...@@ -76,6 +76,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level ...@@ -76,6 +76,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Parameter, Tensor, tensor from .tensor import Parameter, Tensor, tensor
from .version import __version__ from .version import __version__
from .core import cgtools
_set_fork_exec_path_for_timed_func( _set_fork_exec_path_for_timed_func(
sys.executable, sys.executable,
......
...@@ -10,3 +10,5 @@ import os ...@@ -10,3 +10,5 @@ import os
import sys import sys
from .tensor import Tensor from .tensor import Tensor
from .tensor.megbrain_graph import Graph
from .utils import comp_graph_tools as cgtools
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import json
import threading import threading
import weakref import weakref
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
...@@ -162,14 +163,42 @@ def optimize_for_inference(dest_vars, **kwargs): ...@@ -162,14 +163,42 @@ def optimize_for_inference(dest_vars, **kwargs):
return [VarNode(i) for i in res_vars] return [VarNode(i) for i in res_vars]
def dump(*args): def dump_graph(*args):
return _imperative_rt.dump_graph([i._node for i in args]) return _imperative_rt.dump_graph([i._node for i in args])
CompGraphLoadResult = collections.namedtuple(
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)
def load_graph(fpath):
"""Load a serialized computing graph from file.
:parma fpath: Path or Handle for the output file
:return: An instance of namedtuple :class:`CompGraphLoadResult`,
whose fields are:
* ``graph`` loaded CompGraph
* ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
* ``output_vars_list`` A Python list, containing output vars in the
order passed to serialize_comp_graph_to_file
"""
output_vars_map = []
output_vars_list = []
if isinstance(fpath, str):
buf = open(fpath, "rb").read()
else:
buf = fpath.read()
cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
class VarNode(TensorBase): class VarNode(TensorBase):
def __init__(self, node: _imperative_rt.VarNode): def __init__(self, node: _imperative_rt.VarNode):
self._node = node self._node = node
self.graph._var_cache[node] = self if hasattr(self.graph, "_var_cache"):
self.graph._var_cache[node] = self
@property @property
def graph(self) -> Graph: def graph(self) -> Graph:
...@@ -177,12 +206,19 @@ class VarNode(TensorBase): ...@@ -177,12 +206,19 @@ class VarNode(TensorBase):
@property @property
def op(self): def op(self):
return self.graph._wrap(self._node.owner) if hasattr(self.graph, "_wrap"):
return self.graph._wrap(self._node.owner)
else:
return self._node.owner
@property @property
def name(self): def name(self):
return self._node.name return self._node.name
@property
def id(self):
return self._node.id
@name.setter @name.setter
def name(self, name): def name(self, name):
self._node.name = name self._node.name = name
...@@ -207,7 +243,8 @@ class VarNode(TensorBase): ...@@ -207,7 +243,8 @@ class VarNode(TensorBase):
class OpNode: class OpNode:
def __init__(self, node: _imperative_rt.OperatorNode): def __init__(self, node: _imperative_rt.OperatorNode):
self._node = node self._node = node
self.graph._op_cache[node] = self if hasattr(self.graph, "_op_cache"):
self.graph._op_cache[node] = self
@property @property
def graph(self) -> Graph: def graph(self) -> Graph:
...@@ -217,29 +254,53 @@ class OpNode: ...@@ -217,29 +254,53 @@ class OpNode:
def name(self): def name(self):
return self._node.name return self._node.name
@property
def id(self):
return self._node.id
@name.setter @name.setter
def name(self, name): def name(self, name):
self._node.name = name self._node.name = name
@property @property
def inputs(self): def inputs(self):
return tuple(map(self.graph._wrap, self._node.inputs)) if hasattr(self.graph, "_wrap"):
return tuple(map(self.graph._wrap, self._node.inputs))
else:
return self._node.inputs
@property @property
def outputs(self): def outputs(self):
return tuple(map(self.graph._wrap, self._node.outputs)) if hasattr(self.graph, "_wrap"):
return tuple(map(self.graph._wrap, self._node.outputs))
else:
return self._node.outputs
@property
def params(self):
return json.loads(self._node.params)
@property
def type(self):
return self._node.type
def _wrap(x): def _wrap(x):
if isinstance(x, collections.abc.Sequence): if isinstance(x, collections.abc.Sequence):
return type(x)(map(_wrap, x)) return type(x)(map(_wrap, x))
return x.graph._wrap(x) if hasattr(x.graph, "_wrap"):
return x.graph._wrap(x)
else:
return x
def _unwrap(x): def _unwrap(x):
if isinstance(x, collections.abc.Sequence): if isinstance(x, collections.abc.Sequence):
return type(x)(map(_unwrap, x)) return type(x)(map(_unwrap, x))
return x._node if isinstance(x, VarNode):
return x._node
else:
return x
@apply.register() @apply.register()
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Dict, List
from .. import _imperative_rt
from .._imperative_rt import OperatorNode, VarNode
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
"""return :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
depands on. If ``var_type`` is None, return all types.
"""
outputs = []
memo = set()
if isinstance(var, VarNode):
var = [var]
if isinstance(var_type, str):
var_type = [var_type]
q = list(var)
while q:
v = q.pop()
if v in memo:
continue
memo.add(v)
q.extend(get_owner_opr_inputs(v))
if var_type is not None:
if get_owner_opr_type(v) in var_type:
outputs.append(v)
else:
outputs.append(v)
return outputs
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
"""get the inputs of owner opr of a variable
"""
assert isinstance(var, VarNode)
return var.owner.inputs
def get_owner_opr_type(var: VarNode) -> str:
"""get the type of owner opr of a variable
"""
assert isinstance(var, VarNode)
return var.owner.type
def get_opr_type(opr: OperatorNode) -> str:
"""get the type of a opr
"""
assert isinstance(opr, OperatorNode)
return opr.type
def graph_traversal(outputs: VarNode):
"""helper function to traverse the computing graph and return enough useful information
:param outputs: model outputs
:return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
WHERE
map_oprs is dict from opr_id to actual opr
map_vars is dict from var_id to actual var
var2oprs is dict from var to dest oprs along with index
opr2receivers is dict from current opr to next opr
indegree2opr is dict from in_degree to opr in computing graph
opr2indegree is dict from opr in computing graph to in_degree
(indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
"""
# meta information for comp graph
map_oprs = collections.defaultdict(set)
map_vars = collections.defaultdict(set)
var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list)
queue = list(map(lambda x: x.owner, outputs))
visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information
indegree2opr = collections.defaultdict(set)
opr2indegree = {}
idx = 0
while idx < len(queue):
cur_opr = queue[idx]
map_oprs[cur_opr.id] = cur_opr
idx += 1
indegree = 0
for var_idx, var in enumerate(cur_opr.inputs):
map_vars[var.id] = var
var2oprs[var.id].append((cur_opr.id, var_idx))
pre_opr = var.owner
if pre_opr.id not in visited:
visited.add(pre_opr.id)
queue.append(pre_opr)
indegree += 1
opr2receivers[pre_opr.id].append(cur_opr.id)
indegree2opr[indegree].add(cur_opr.id)
opr2indegree[cur_opr.id] = indegree
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
"""get oprs in some topological order for a dumped model
:param outputs: model outputs
:param prune_reshape: whether to prune the operators useless during inference
:return: opr list with some correct execution order
"""
def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
# generate an execution order with topological sort algorithm
oprs_seq = []
nr_remain = len(map_oprs)
while indegree2opr[0]:
opr_id = indegree2opr[0].pop()
opr = map_oprs[opr_id]
nr_remain -= 1
# skip const value generation operator
if get_opr_type(opr) != "ImmutableTensor":
oprs_seq.append(opr)
for post_id in opr2receivers[opr_id]:
indegree = opr2indegree[post_id]
indegree2opr[indegree].remove(post_id)
indegree -= 1
indegree2opr[indegree].add(post_id)
opr2indegree[post_id] = indegree
assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
nr_remain
)
return oprs_seq
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
useless = True
for oup in cur_opr.outputs:
if "workspace" not in oup.name:
var_idx = post_opr.inputs.index(oup)
var2oprs[oup.id].remove((post_opr.id, var_idx))
useless = useless and (len(var2oprs[oup.id]) == 0)
if useless:
marked_opr_ids.append(cur_opr.id)
for inp in cur_opr.inputs:
iterative_pruning(inp.owner, cur_opr, marked_opr_ids)
reshape_vars = get_dep_vars(outputs, "Reshape")
reshape_oprs = [var.owner for var in reshape_vars]
marked_opr_ids = []
for reshape_opr in reshape_oprs:
iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)
# filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
outputs
)
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
if prune_reshape is True:
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
return oprs_seq
def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
"""replace vars in the graph
:param dst: target vars representing the graph
:param varmap: the map that specifies how to replace the vars
:return: new vars that correspond to ``dst`` with all the dependencies
replaced
"""
dst_vec = []
repl_src_vec = []
repl_dst_vec = []
for i in dst:
assert isinstance(i, VarNode)
dst_vec.append(i)
for i, j in getattr(varmap, "items", lambda: varmap)():
assert isinstance(i, VarNode)
assert isinstance(j, VarNode)
repl_src_vec.append(i)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
def replace_oprs(
dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
) -> List[VarNode]:
"""Replace operators in the graph.
:param dst: target vars representing the graph
:param oprmap: the map that specifies how to replace the operators
:return: new vars that correspond to ``dst`` with all the dependencies
replaced
"""
dst_vec = []
repl_src_vec = []
repl_dst_vec = []
for i in dst:
assert isinstance(i, VarNode)
dst_vec.append(i)
for i, j in getattr(oprmap, "items", lambda: oprmap)():
assert isinstance(i, OperatorNode)
assert isinstance(j, OperatorNode)
repl_src_vec.append(i)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def set_priority_to_id(dest_vars):
"""For all oprs in the subgraph constructed by dest_vars
set its priority to id if its original priority is zero
:param dest_vars: target vars representing the graph
"""
dest_vec = []
for i in dest_vars:
assert isinstance(i, VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
...@@ -569,7 +569,7 @@ class trace: ...@@ -569,7 +569,7 @@ class trace:
if isinstance(file, str): if isinstance(file, str):
permission = "wb" if append == False else "ab" permission = "wb" if append == False else "ab"
file = open(file, permission) file = open(file, permission)
file.write(G.dump(*dest_vars)) file.write(G.dump_graph(*dest_vars))
def _process_inputs(self, *args, **kwargs): def _process_inputs(self, *args, **kwargs):
if self._untraced: if self._untraced:
......
...@@ -64,7 +64,60 @@ auto def_rendezvous(py::object m, const char* name) { ...@@ -64,7 +64,60 @@ auto def_rendezvous(py::object m, const char* name) {
using TensorAttr = LogicalTensorDesc; using TensorAttr = LogicalTensorDesc;
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>; using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
std::vector<mgb::cg::VarNode*> _replace_vars(const std::vector<mgb::cg::VarNode*>& repl_src,
const std::vector<mgb::cg::VarNode*>& repl_dst,
const std::vector<mgb::cg::VarNode*>& vars) {
mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
for (size_t i = 0; i < repl_src.size(); ++i) {
varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
}
SymbolVarArray symvars(vars.begin(), vars.end());
auto sym_result = mgb::cg::replace_vars(symvars, varmap);
std::vector<mgb::cg::VarNode*> result;
for (auto symvar : sym_result){
result.push_back(symvar.node());
}
return result;
}
typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
std::vector<mgb::cg::VarNode*> _replace_oprs(const OperatorArray& repl_src,
const OperatorArray& repl_dst,
const std::vector<mgb::cg::VarNode*>& vars) {
mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*>
oprmap;
for (size_t i = 0; i < repl_src.size(); ++i) {
oprmap[repl_src[i]] = repl_dst[i];
}
const SymbolVarArray symvars(vars.begin(), vars.end());
auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
std::vector<mgb::cg::VarNode*> result;
for (auto symvar : sym_result){
result.push_back(symvar.node());
}
return result;
}
void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
if (opr->node_prop().attribute().priority == 0) {
opr->node_prop().attribute().priority = opr->id();
}
};
mgb::cg::DepOprIter dep_iter{on_opr};
for (const auto& var : dest_vars) {
dep_iter.add(SymbolVar(var));
}
}
void init_graph_rt(py::module m) { void init_graph_rt(py::module m) {
static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{std::make_unique<mgb::OprFootprint>()};
def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous"); def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous"); def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
...@@ -99,7 +152,10 @@ void init_graph_rt(py::module m) { ...@@ -99,7 +152,10 @@ void init_graph_rt(py::module m) {
return py::none(); return py::none();
} }
return py::cast(*val).attr("numpy")(); return py::cast(*val).attr("numpy")();
}); })
.def_property_readonly("id",[](cg::VarNode* v){
return (v->id());
});
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode") py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
.def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();}) .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
...@@ -110,7 +166,17 @@ void init_graph_rt(py::module m) { ...@@ -110,7 +166,17 @@ void init_graph_rt(py::module m) {
}) })
.def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) { .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
return to_tuple(opr->usable_output()); return to_tuple(opr->usable_output());
}); })
.def_property_readonly("id",[](cg::OperatorNodeBase* opr){
return opr->id();
})
.def_property_readonly("params",[](cg::OperatorNodeBase* opr){
return _imperative_sm_opr_footprint_ptr->calc_footprint(opr).param->to_string();
})
.def_property_readonly("type",[](cg::OperatorNodeBase* opr){
return opr->dyn_typeinfo()->name;
});
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
.def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
...@@ -174,6 +240,44 @@ void init_graph_rt(py::module m) { ...@@ -174,6 +240,44 @@ void init_graph_rt(py::module m) {
}); });
m.def("load_graph", [](std::string& buf, py::list& _output_var_map, py::list& _output_var_list) {
using namespace mgb::serialization;
auto file = InputFile::make_mem_proxy(buf.c_str(), buf.length());
auto format = GraphLoader::identify_graph_dump_format(*file);
auto loader = GraphLoader::make(std::move(file), format.val());
GraphLoader::LoadConfig config;
auto rst = loader->load(config);
std::vector<std::pair<std::string, SymbolVar>> output_var_map;
SymbolVarArray output_var_list;
output_var_map = {rst.output_var_map.begin(), rst.output_var_map.end()};
output_var_list = std::move(rst.output_var_list);
for (auto i : output_var_list){
_output_var_list.append(i.node());
}
for (auto i : output_var_map){
_output_var_map.append(py::make_tuple(i.first,i.second.node()));
}
std::unordered_map<HostTensorND*, const std::string*> tensor2name;
for (const auto& pair : rst.tensor_map) {
tensor2name[pair.second.get()] = &pair.first;
}
auto cb = [&tensor2name, graph=rst.graph](cg::OperatorNodeBase* opr) {
if (!opr->same_type<opr::Host2DeviceCopy>())
return;
auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
auto it = tensor2name.find(h2d.host_data().get());
mgb_throw_if(it == tensor2name.end(), GraphError,
"unbound Host2DeviceCopy in loaded graph");
h2d.output(0)->name(*it->second);
};
cg::DepOprIter iter{cb};
for (const auto& var : output_var_list) {
iter.add(var.node()->owner_opr());
}
return rst.graph;
});
#define CURRENT_CLASS cg::ComputingGraph::Options #define CURRENT_CLASS cg::ComputingGraph::Options
auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options") auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
...@@ -287,6 +391,10 @@ void init_graph_rt(py::module m) { ...@@ -287,6 +391,10 @@ void init_graph_rt(py::module m) {
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node(); return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none()); }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
m.def("_replace_vars", &_replace_vars,py::arg(),py::arg(),py::arg());
m.def("_replace_oprs", &_replace_oprs,py::arg(),py::arg(),py::arg());
m.def("_set_priority_to_id",&_set_priority_to_id,py::arg());
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback, m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
const CompNode& comp_node, const CompNode& comp_node,
const DType& dtype, const DType& dtype,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <future> #include <future>
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/graph.h" #include "megbrain/graph.h"
template<typename T> template<typename T>
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import numpy as np
import megengine
import megengine.functional as F
import megengine.module as M
from megengine import cgtools
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
def test_replace_vars():
g = mgb_graph.Graph()
g.options.async_exec_level = 0b100
device = "xpux"
dtype = np.float32
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
const = g.make_const(1.234)
a_plus_a = F.add(a.outputs[0], a.outputs[0])
a_plus_a_mul_const = F.mul(a_plus_a, const)
rst = F.add(a_plus_a_mul_const, a.outputs[0])
(new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
func = g.compile(out.outputs[0])
func.execute()
x = make_dev_tensor(5.0, device=device)
a.set_value(x)
res = out.get_value().numpy()
np.testing.assert_equal(res, np.array([105.0]))
def test_replace_oprs():
g = mgb_graph.Graph()
g.options.async_exec_level = 0b100
device = "xpux"
dtype = np.float32
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
const = g.make_const(1.25)
a_plus_a = F.add(a.outputs[0], a.outputs[0])
old_opr = a_plus_a.op
a_plus_a_mul_const = F.mul(a_plus_a, const)
a_mul_a = F.mul(a.outputs[0], a.outputs[0])
new_opr = a_mul_a.op
(new,) = cgtools.replace_oprs(
[a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
)
out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
func = g.compile(out.outputs[0])
func.execute()
x = make_dev_tensor(5.0, device=device)
a.set_value(x)
res = out.get_value().numpy()
np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
def test_graph_traversal():
net = M.Conv2d(3, 32, 3)
@trace(symbolic=True, capture_as_const=True)
def fun(data):
x = net(data)
return x
data = np.random.random([1, 3, 224, 224]).astype(np.float32)
for i in range(3):
fun(megengine.tensor(data))
file = io.BytesIO()
fun.dump(file)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
_, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
input_var = map_vars[1]
_, var_idx = var2oprs[input_var.id][0]
assert var_idx == 0
...@@ -13,6 +13,10 @@ import numpy as np ...@@ -13,6 +13,10 @@ import numpy as np
import pytest import pytest
from megengine import tensor from megengine import tensor
import megengine
import megengine.core.tensor.megbrain_graph as mgb_graph
import megengine.module as M
from megengine import cgtools
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply from megengine.core.tensor.core import apply
...@@ -21,6 +25,29 @@ from megengine.functional import exp, log ...@@ -21,6 +25,29 @@ from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace from megengine.jit import exclude_from_trace, trace
def load_and_inference(file, inp_data):
cg, _, out_list = mgb_graph.load_graph(file)
inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = mgb_graph.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = cgtools.replace_vars(out_list, replace_dict)
out_node_list = [mgb_graph.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
new_cg = new_out_list[0].graph
func = new_cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data):
node.set_value(as_raw_tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
return out_data_list
def test_trace(): def test_trace():
for symbolic in [False, True]: for symbolic in [False, True]:
...@@ -81,13 +108,58 @@ def test_print_in_trace(): ...@@ -81,13 +108,58 @@ def test_print_in_trace():
def test_dump(): def test_dump():
@trace(symbolic=True, capture_as_const=True)
def f(a, b):
op = ops.Elemwise(mode="add")
(y,) = apply(op, a, b)
return y
a = as_raw_tensor([2]).numpy()
b = as_raw_tensor([4]).numpy()
y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)
file = io.BytesIO()
f.dump(file)
file.seek(0)
result = load_and_inference(file, [a, b])
np.testing.assert_equal(result[0], y)
def test_capture_dump():
a = as_raw_tensor([2])
@trace(symbolic=True, capture_as_const=True)
def f(x):
op = ops.Elemwise(mode="mul")
(y,) = apply(op, x, a)
return y
x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
file = io.BytesIO()
f.dump(file)
file.seek(0)
result = load_and_inference(file, [x])
np.testing.assert_equal(result[0], y)
def test_dump_volatile():
p = as_raw_tensor([2])
@trace(symbolic=True, capture_as_const=True) @trace(symbolic=True, capture_as_const=True)
def f(x): def f(x):
op = ops.Elemwise(mode="negate") op = ops.Elemwise(mode="mul")
(y,) = apply(op, x) (y,) = apply(op, x, p)
return y return y
x = as_raw_tensor([1]).numpy() x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy() y = f.__wrapped__(as_raw_tensor(x)).numpy()
for i in range(3): for i in range(3):
...@@ -95,6 +167,13 @@ def test_dump(): ...@@ -95,6 +167,13 @@ def test_dump():
file = io.BytesIO() file = io.BytesIO()
f.dump(file) f.dump(file)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
(out,) = outputs
assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
== "SharedDeviceTensor"
)
def test_trace_profiler(): def test_trace_profiler():
......
...@@ -471,11 +471,9 @@ def main(): ...@@ -471,11 +471,9 @@ def main():
assert not testcase, 'extra inputs provided in testcase: {}'.format( assert not testcase, 'extra inputs provided in testcase: {}'.format(
testcase.keys() testcase.keys()
) )
mgb.serialize_comp_graph_to_file( with open(args.output, "ab") as fout:
args.output, fout.write(G.dump_graph(*output_mgbvars))
output_mgbvars,
append=True,
output_strip_info=args.output_strip_info)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册