From ac11c38a977a935dd2230bc6a42f7158feff9c93 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 6 Sep 2020 17:58:34 +0800
Subject: [PATCH] feat(mge/imperative): add graph load and cgtools for imperative

GitOrigin-RevId: ba251f452ae8c6cc9c3dae99d1be92711cbeff5e
---
 imperative/python/megengine/__init__.py      |   1 +
 imperative/python/megengine/core/__init__.py |   2 +
 .../megengine/core/tensor/megbrain_graph.py  |  77 +++++-
 .../megengine/core/utils/comp_graph_tools.py | 253 ++++++++++++++++++
 imperative/python/megengine/jit/tracing.py   |   2 +-
 imperative/python/src/graph_rt.cpp           | 112 +++++++-
 imperative/python/src/graph_rt.h             |   2 +-
 imperative/python/test/unit/test_cgtools.py  |  90 +++++++
 imperative/python/test/unit/test_tracing.py  |  85 +++++-
 sdk/load-and-run/dump_with_testcase_mge.py   |   8 +-
 10 files changed, 612 insertions(+), 20 deletions(-)
 create mode 100644 imperative/python/megengine/core/utils/comp_graph_tools.py
 create mode 100644 imperative/python/test/unit/test_cgtools.py

diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py
index 8ab657e47..fa1d69623 100644
--- a/imperative/python/megengine/__init__.py
+++ b/imperative/python/megengine/__init__.py
@@ -76,6 +76,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
 from .serialization import load, save
 from .tensor import Parameter, Tensor, tensor
 from .version import __version__
+from .core import cgtools
 
 _set_fork_exec_path_for_timed_func(
     sys.executable,
diff --git a/imperative/python/megengine/core/__init__.py b/imperative/python/megengine/core/__init__.py
index e24057552..50d29e9eb 100644
--- a/imperative/python/megengine/core/__init__.py
+++ b/imperative/python/megengine/core/__init__.py
@@ -10,3 +10,5 @@ import os
 import sys
 
 from .tensor import Tensor
+from .tensor.megbrain_graph import Graph
+from .utils import comp_graph_tools as cgtools
diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index da2c6f6c4..3ce75d632 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import json
 import threading
 import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -162,14 +163,42 @@ def optimize_for_inference(dest_vars, **kwargs):
     return [VarNode(i) for i in res_vars]
 
 
-def dump(*args):
+def dump_graph(*args):
     return _imperative_rt.dump_graph([i._node for i in args])
 
 
+CompGraphLoadResult = collections.namedtuple(
+    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
+)
+
+
+def load_graph(fpath):
+    """Load a serialized computing graph from file.
+
+    :param fpath: path or handle of the input file
+    :return: an instance of namedtuple :class:`CompGraphLoadResult`,
+        whose fields are:
+
+        * ``graph``: the loaded CompGraph
+        * ``output_vars_dict``: a Python dict, mapping names to output SymbolVars
+        * ``output_vars_list``: a Python list, containing output vars in the
+          order they were passed to ``serialize_comp_graph_to_file``
+    """
+    output_vars_map = []
+    output_vars_list = []
+    if isinstance(fpath, str):
+        buf = open(fpath, "rb").read()
+    else:
+        buf = fpath.read()
+    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
+    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
+
+
 class VarNode(TensorBase):
     def __init__(self, node: _imperative_rt.VarNode):
         self._node = node
-        self.graph._var_cache[node] = self
+        if hasattr(self.graph, "_var_cache"):
+            self.graph._var_cache[node] = self
 
     @property
     def graph(self) -> Graph:
@@ -177,12 +206,19 @@ class VarNode(TensorBase):
 
     @property
     def op(self):
-        return self.graph._wrap(self._node.owner)
+        if hasattr(self.graph, "_wrap"):
+            return self.graph._wrap(self._node.owner)
+        else:
+            return self._node.owner
 
     @property
     def name(self):
         return self._node.name
 
+    @property
+    def id(self):
+        return self._node.id
+
     @name.setter
     def name(self, name):
         self._node.name = name
@@ -207,7 +243,8 @@ class VarNode(TensorBase):
 class OpNode:
     def __init__(self, node: _imperative_rt.OperatorNode):
         self._node = node
-        self.graph._op_cache[node] = self
+        if hasattr(self.graph, "_op_cache"):
+            self.graph._op_cache[node] = self
 
     @property
     def graph(self) -> Graph:
@@ -217,29 +254,53 @@ class OpNode:
     def name(self):
         return self._node.name
 
+    @property
+    def id(self):
+        return self._node.id
+
     @name.setter
     def name(self, name):
         self._node.name = name
 
     @property
     def inputs(self):
-        return tuple(map(self.graph._wrap, self._node.inputs))
+        if hasattr(self.graph, "_wrap"):
+            return tuple(map(self.graph._wrap, self._node.inputs))
+        else:
+            return self._node.inputs
 
     @property
     def outputs(self):
-        return tuple(map(self.graph._wrap, self._node.outputs))
+        if hasattr(self.graph, "_wrap"):
+            return tuple(map(self.graph._wrap, self._node.outputs))
+        else:
+            return self._node.outputs
+
+    @property
+    def params(self):
+        return json.loads(self._node.params)
+
+    @property
+    def type(self):
+        return self._node.type
 
 
 def _wrap(x):
     if isinstance(x, collections.abc.Sequence):
         return type(x)(map(_wrap, x))
-    return x.graph._wrap(x)
+    if hasattr(x.graph, "_wrap"):
+        return x.graph._wrap(x)
+    else:
+        return x
 
 
 def _unwrap(x):
     if isinstance(x, collections.abc.Sequence):
         return type(x)(map(_unwrap, x))
-    return x._node
+    if isinstance(x, VarNode):
+        return x._node
+    else:
+        return x
 
 
 @apply.register()
diff --git a/imperative/python/megengine/core/utils/comp_graph_tools.py b/imperative/python/megengine/core/utils/comp_graph_tools.py
new file mode 100644
index 000000000..ceffcc7ef
--- /dev/null
+++ b/imperative/python/megengine/core/utils/comp_graph_tools.py
@@ -0,0 +1,253 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import collections
+from typing import Dict, List
+
+from .. import _imperative_rt
+from .._imperative_rt import OperatorNode, VarNode
+
+
+def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
+    """Return a list of :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type``
+    that the given ``var`` depends on. If ``var_type`` is None, return all dependent vars.
+    """
+    outputs = []
+    memo = set()
+
+    if isinstance(var, VarNode):
+        var = [var]
+
+    if isinstance(var_type, str):
+        var_type = [var_type]
+
+    q = list(var)
+    while q:
+        v = q.pop()
+        if v in memo:
+            continue
+        memo.add(v)
+        q.extend(get_owner_opr_inputs(v))
+        if var_type is not None:
+            if get_owner_opr_type(v) in var_type:
+                outputs.append(v)
+        else:
+            outputs.append(v)
+
+    return outputs
+
+
+def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
+    """Get the inputs of the owner opr of a variable.
+    """
+    assert isinstance(var, VarNode)
+    return var.owner.inputs
+
+
+def get_owner_opr_type(var: VarNode) -> str:
+    """Get the type of the owner opr of a variable.
+    """
+    assert isinstance(var, VarNode)
+    return var.owner.type
+
+
+def get_opr_type(opr: OperatorNode) -> str:
+    """Get the type of an opr.
+    """
+    assert isinstance(opr, OperatorNode)
+    return opr.type
+
+
+def graph_traversal(outputs: List[VarNode]):
+    """Helper function to traverse the computing graph and gather useful information.
+
+    :param outputs: model outputs
+    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
+        WHERE
+        map_oprs is a dict mapping opr_id to the actual opr
+        map_vars is a dict mapping var_id to the actual var
+        var2oprs is a dict mapping a var to its dest oprs along with the input index
+        opr2receivers is a dict mapping an opr to the oprs that receive its outputs
+        indegree2opr is a dict mapping an in-degree to the oprs with that in-degree
+        opr2indegree is a dict mapping an opr to its in-degree
+
+        (indegree2opr, opr2indegree) are only used by the topological sort in :func:`get_oprs_seq`
+    """
+    # meta information for comp graph
+    map_oprs = collections.defaultdict(set)
+    map_vars = collections.defaultdict(set)
+
+    var2oprs = collections.defaultdict(list)
+    opr2receivers = collections.defaultdict(list)
+
+    queue = list(map(lambda x: x.owner, outputs))
+    visited = set(map(lambda x: x.id, queue))
+
+    # iterate through whole comp_graph, fill in meta information
+    indegree2opr = collections.defaultdict(set)
+    opr2indegree = {}
+
+    idx = 0
+    while idx < len(queue):
+        cur_opr = queue[idx]
+        map_oprs[cur_opr.id] = cur_opr
+
+        idx += 1
+
+        indegree = 0
+        for var_idx, var in enumerate(cur_opr.inputs):
+            map_vars[var.id] = var
+            var2oprs[var.id].append((cur_opr.id, var_idx))
+
+            pre_opr = var.owner
+
+            if pre_opr.id not in visited:
+                visited.add(pre_opr.id)
+                queue.append(pre_opr)
+
+            indegree += 1
+            opr2receivers[pre_opr.id].append(cur_opr.id)
+
+        indegree2opr[indegree].add(cur_opr.id)
+        opr2indegree[cur_opr.id] = indegree
+
+    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
+
+
+def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
+    """Get the oprs of a dumped model in topological order.
+
+    :param outputs: model outputs
+    :param prune_reshape: whether to prune oprs that are useless at inference time
+        (i.e. oprs that only compute the shape input of Reshape)
+    :return: opr list in a valid execution order
+    """
+
+    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
+        # generate an execution order with topological sort algorithm
+        oprs_seq = []
+        nr_remain = len(map_oprs)
+        while indegree2opr[0]:
+            opr_id = indegree2opr[0].pop()
+            opr = map_oprs[opr_id]
+            nr_remain -= 1
+
+            # skip const value generation operators
+            if get_opr_type(opr) != "ImmutableTensor":
+                oprs_seq.append(opr)
+
+            for post_id in opr2receivers[opr_id]:
+                indegree = opr2indegree[post_id]
+                indegree2opr[indegree].remove(post_id)
+
+                indegree -= 1
+                indegree2opr[indegree].add(post_id)
+                opr2indegree[post_id] = indegree
+
+        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
+            nr_remain
+        )
+        return oprs_seq
+
+    # reshape opr definition: reshape(input_tensor, dest_shape) -> output_tensor
+    # at inference time the shape of output_tensor is already known, so oprs that
+    # only compute dest_shape can be pruned from the loaded graph
+    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
+        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
+            useless = True
+            for oup in cur_opr.outputs:
+                if "workspace" not in oup.name:
+                    var_idx = post_opr.inputs.index(oup)
+                    var2oprs[oup.id].remove((post_opr.id, var_idx))
+                    useless = useless and (len(var2oprs[oup.id]) == 0)
+
+            if useless:
+                marked_opr_ids.append(cur_opr.id)
+
+                for inp in cur_opr.inputs:
+                    iterative_pruning(inp.owner, cur_opr, marked_opr_ids)
+
+        reshape_vars = get_dep_vars(outputs, "Reshape")
+        reshape_oprs = [var.owner for var in reshape_vars]
+
+        marked_opr_ids = []
+        for reshape_opr in reshape_oprs:
+            iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)
+
+        # filter out all marked oprs
+        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
+
+    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
+        outputs
+    )
+    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
+    if prune_reshape is True:
+        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
+    return oprs_seq
+
+
+def replace_vars(dst: List[VarNode], varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
+    """Replace vars in the graph.
+
+    :param dst: target vars representing the graph
+    :param varmap: the map that specifies how to replace the vars
+
+    :return: new vars that correspond to ``dst`` with all the dependencies
+        replaced
+    """
+    dst_vec = []
+    repl_src_vec = []
+    repl_dst_vec = []
+    for i in dst:
+        assert isinstance(i, VarNode)
+        dst_vec.append(i)
+
+    for i, j in getattr(varmap, "items", lambda: varmap)():
+        assert isinstance(i, VarNode)
+        assert isinstance(j, VarNode)
+        repl_src_vec.append(i)
+        repl_dst_vec.append(j)
+
+    return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
+
+
+def replace_oprs(
+    dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
+) -> List[VarNode]:
+    """Replace operators in the graph.
+
+    :param dst: target vars representing the graph
+    :param oprmap: the map that specifies how to replace the operators
+
+    :return: new vars that correspond to ``dst`` with all the dependencies
+        replaced
+    """
+    dst_vec = []
+    repl_src_vec = []
+    repl_dst_vec = []
+    for i in dst:
+        assert isinstance(i, VarNode)
+        dst_vec.append(i)
+
+    for i, j in getattr(oprmap, "items", lambda: oprmap)():
+        assert isinstance(i, OperatorNode)
+        assert isinstance(j, OperatorNode)
+        repl_src_vec.append(i)
+        repl_dst_vec.append(j)
+
+    return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
+
+
+def set_priority_to_id(dest_vars):
+    """For all oprs in the subgraph constructed by ``dest_vars``,
+    set the opr priority to its id if the original priority is zero.
+
+    :param dest_vars: target vars representing the graph
+    """
+    dest_vec = []
+    for i in dest_vars:
+        assert isinstance(i, VarNode)
+        dest_vec.append(i)
+    _imperative_rt.graph._set_priority_to_id(dest_vec)
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 987a96fb8..4fc59de37 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -569,7 +569,7 @@ class trace:
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
             file = open(file, permission)
-        file.write(G.dump(*dest_vars))
+        file.write(G.dump_graph(*dest_vars))
 
     def _process_inputs(self, *args, **kwargs):
         if self._untraced:
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index 6874388b6..ecba0d3ce 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -64,7 +64,60 @@ auto def_rendezvous(py::object m, const char* name) {
 using TensorAttr = LogicalTensorDesc;
 using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
 
+std::vector<mgb::cg::VarNode*> _replace_vars(const std::vector<mgb::cg::VarNode*>& repl_src,
+                                             const std::vector<mgb::cg::VarNode*>& repl_dst,
+                                             const std::vector<mgb::cg::VarNode*>& vars) {
+    mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
+    for (size_t i = 0; i < repl_src.size(); ++i) {
+        varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
+    }
+    SymbolVarArray symvars(vars.begin(), vars.end());
+    auto sym_result = mgb::cg::replace_vars(symvars, varmap);
+    std::vector<mgb::cg::VarNode*> result;
+    for (auto symvar : sym_result) {
+        result.push_back(symvar.node());
+    }
+    return result;
+}
+
+typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
+std::vector<mgb::cg::VarNode*> _replace_oprs(const OperatorArray& repl_src,
+                                             const OperatorArray& repl_dst,
+                                             const std::vector<mgb::cg::VarNode*>& vars) {
+    mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*>
+            oprmap;
+    for (size_t i = 0; i < repl_src.size(); ++i) {
+        oprmap[repl_src[i]] = repl_dst[i];
+    }
+    const SymbolVarArray symvars(vars.begin(), vars.end());
+    auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
+    std::vector<mgb::cg::VarNode*> result;
+    for (auto symvar : sym_result) {
+        result.push_back(symvar.node());
+    }
+    return result;
+}
+
+
+
+void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
+    auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
+        if (opr->node_prop().attribute().priority == 0) {
+            opr->node_prop().attribute().priority = opr->id();
+        }
+    };
+    mgb::cg::DepOprIter dep_iter{on_opr};
+    for (const auto& var : dest_vars) {
+        dep_iter.add(SymbolVar(var));
+    }
+}
+
+
+
 void init_graph_rt(py::module m) {
+
+    static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{std::make_unique<mgb::OprFootprint>()};
+
     def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
 
     def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
@@ -99,7 +152,10 @@ void init_graph_rt(py::module m) {
                 return py::none();
             }
             return py::cast(*val).attr("numpy")();
-        });
+        })
+        .def_property_readonly("id", [](cg::VarNode* v) {
+            return (v->id());
+        });
 
     py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
         .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) { return opr->owner_graph(); })
@@ -110,7 +166,17 @@ void init_graph_rt(py::module m) {
         })
         .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
             return to_tuple(opr->usable_output());
-        });
+        })
+        .def_property_readonly("id", [](cg::OperatorNodeBase* opr) {
+            return opr->id();
+        })
+        .def_property_readonly("params", [](cg::OperatorNodeBase* opr) {
+            return _imperative_sm_opr_footprint_ptr->calc_footprint(opr).param->to_string();
+        })
+        .def_property_readonly("type", [](cg::OperatorNodeBase* opr) {
+            return opr->dyn_typeinfo()->name;
+        });
+
     py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
         .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
@@ -174,6 +240,44 @@ void init_graph_rt(py::module m) {
         });
 
+    m.def("load_graph", [](std::string& buf, py::list& _output_var_map, py::list& _output_var_list) {
+        using namespace mgb::serialization;
+        auto file = InputFile::make_mem_proxy(buf.c_str(), buf.length());
+        auto format = GraphLoader::identify_graph_dump_format(*file);
+        auto loader = GraphLoader::make(std::move(file), format.val());
+        GraphLoader::LoadConfig config;
+        auto rst = loader->load(config);
+        std::vector<std::pair<std::string, SymbolVar>> output_var_map;
+        SymbolVarArray output_var_list;
+        output_var_map = {rst.output_var_map.begin(), rst.output_var_map.end()};
+        output_var_list = std::move(rst.output_var_list);
+        for (auto i : output_var_list) {
+            _output_var_list.append(i.node());
+        }
+        for (auto i : output_var_map) {
+            _output_var_map.append(py::make_tuple(i.first, i.second.node()));
+        }
+        std::unordered_map<HostTensorND*, const std::string*> tensor2name;
+        for (const auto& pair : rst.tensor_map) {
+            tensor2name[pair.second.get()] = &pair.first;
+        }
+        auto cb = [&tensor2name, graph=rst.graph](cg::OperatorNodeBase* opr) {
+            if (!opr->same_type<opr::Host2DeviceCopy>())
+                return;
+            auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
+            auto it = tensor2name.find(h2d.host_data().get());
+            mgb_throw_if(it == tensor2name.end(), GraphError,
+                         "unbound Host2DeviceCopy in loaded graph");
+            h2d.output(0)->name(*it->second);
+        };
+        cg::DepOprIter iter{cb};
+        for (const auto& var : output_var_list) {
+            iter.add(var.node()->owner_opr());
+        }
+        return rst.graph;
+
+    });
+
 #define CURRENT_CLASS cg::ComputingGraph::Options
 
     auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
@@ -287,6 +391,10 @@ void init_graph_rt(py::module m) {
             return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
         }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
 
+    m.def("_replace_vars", &_replace_vars, py::arg(), py::arg(), py::arg());
+    m.def("_replace_oprs", &_replace_oprs, py::arg(), py::arg(), py::arg());
+    m.def("_set_priority_to_id", &_set_priority_to_id, py::arg());
+
     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
                                              const DType& dtype,
diff --git a/imperative/python/src/graph_rt.h b/imperative/python/src/graph_rt.h
index b9116211c..ee2a11da8 100644
--- a/imperative/python/src/graph_rt.h
+++ b/imperative/python/src/graph_rt.h
@@ -16,7 +16,7 @@
 #include 
 #include 
 #include 
-
+#include "megbrain/plugin/opr_footprint.h"
 #include "megbrain/graph.h"
 
 template
diff --git a/imperative/python/test/unit/test_cgtools.py b/imperative/python/test/unit/test_cgtools.py
new file mode 100644
index 000000000..da611edfa
--- /dev/null
+++ b/imperative/python/test/unit/test_cgtools.py
@@ -0,0 +1,90 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+
+import megengine
+import megengine.functional as F
+import megengine.module as M
+from megengine import cgtools
+from megengine.core.tensor import megbrain_graph as mgb_graph
+from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.jit import trace
+
+
+def make_dev_tensor(value, dtype=None, device=None):
+    return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
+
+
+def test_replace_vars():
+    g = mgb_graph.Graph()
+    g.options.async_exec_level = 0b100
+    device = "xpux"
+    dtype = np.float32
+    a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
+    const = g.make_const(1.234)
+    a_plus_a = F.add(a.outputs[0], a.outputs[0])
+    a_plus_a_mul_const = F.mul(a_plus_a, const)
+    rst = F.add(a_plus_a_mul_const, a.outputs[0])
+    (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
+    out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
+    func = g.compile(out.outputs[0])
+    func.execute()
+    x = make_dev_tensor(5.0, device=device)
+    a.set_value(x)
+    res = out.get_value().numpy()
+    np.testing.assert_equal(res, np.array([105.0]))
+
+
+def test_replace_oprs():
+    g = mgb_graph.Graph()
+    g.options.async_exec_level = 0b100
+    device = "xpux"
+    dtype = np.float32
+    a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
+    const = g.make_const(1.25)
+    a_plus_a = F.add(a.outputs[0], a.outputs[0])
+    old_opr = a_plus_a.op
+    a_plus_a_mul_const = F.mul(a_plus_a, const)
+    a_mul_a = F.mul(a.outputs[0], a.outputs[0])
+    new_opr = a_mul_a.op
+    (new,) = cgtools.replace_oprs(
+        [a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
+    )
+    out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
+    func = g.compile(out.outputs[0])
+    func.execute()
+    x = make_dev_tensor(5.0, device=device)
+    a.set_value(x)
+    res = out.get_value().numpy()
+    np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
+
+
+def test_graph_traversal():
+    net = M.Conv2d(3, 32, 3)
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun(data):
+        x = net(data)
+        return x
+
+    data = np.random.random([1, 3, 224, 224]).astype(np.float32)
+    for i in range(3):
+        fun(megengine.tensor(data))
+
+    file = io.BytesIO()
+    fun.dump(file)
+    file.seek(0)
+    cg, _, outputs = mgb_graph.load_graph(file)
+
+    _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
+    input_var = map_vars[1]
+    _, var_idx = var2oprs[input_var.id][0]
+
+    assert var_idx == 0
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index d78d231ce..ce4c1e91c 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -13,6 +13,10 @@ import numpy as np
 import pytest
 
 from megengine import tensor
+import megengine
+import megengine.core.tensor.megbrain_graph as mgb_graph
+import megengine.module as M
+from megengine import cgtools
 from megengine.core.ops import builtin as ops
 from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.core import apply
@@ -21,6 +25,29 @@ from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 
 
+def load_and_inference(file, inp_data):
+    cg, _, out_list = mgb_graph.load_graph(file)
+    inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy")
+    replace_dict = {}
+    inp_node_list = []
+    for i in inputs:
+        inp_node = mgb_graph.InputNode(
+            device="xpux", dtype=i.dtype, graph=i.graph
+        )
+        replace_dict[i] = inp_node.outputs[0]
+        inp_node_list.append(inp_node)
+    new_out = cgtools.replace_vars(out_list, replace_dict)
+    out_node_list = [mgb_graph.OutputNode(i) for i in new_out]
+    new_out_list = [i.outputs[0] for i in out_node_list]
+    new_cg = new_out_list[0].graph
+    func = new_cg.compile(new_out_list)
+    for node, value in zip(inp_node_list, inp_data):
+        node.set_value(as_raw_tensor(value)._dev_tensor())
+    func.execute()
+    out_data_list = [o.get_value().numpy() for o in out_node_list]
+    return out_data_list
+
+
 def test_trace():
     for symbolic in [False, True]:
 
@@ -81,13 +108,58 @@ def test_print_in_trace():
 
 
 def test_dump():
+    @trace(symbolic=True, capture_as_const=True)
+    def f(a, b):
+        op = ops.Elemwise(mode="add")
+        (y,) = apply(op, a, b)
+        return y
+
+    a = as_raw_tensor([2]).numpy()
+    b = as_raw_tensor([4]).numpy()
+    y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()
+
+    for i in range(3):
+        np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)
+
+    file = io.BytesIO()
+    f.dump(file)
+    file.seek(0)
+    result = load_and_inference(file, [a, b])
+    np.testing.assert_equal(result[0], y)
+
+
+def test_capture_dump():
+    a = as_raw_tensor([2])
+
+    @trace(symbolic=True, capture_as_const=True)
+    def f(x):
+        op = ops.Elemwise(mode="mul")
+        (y,) = apply(op, x, a)
+        return y
+
+    x = as_raw_tensor([3]).numpy()
+    y = f.__wrapped__(as_raw_tensor(x)).numpy()
+
+    for i in range(3):
+        np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
+
+    file = io.BytesIO()
+    f.dump(file)
+    file.seek(0)
+    result = load_and_inference(file, [x])
+    np.testing.assert_equal(result[0], y)
+
+
+def test_dump_volatile():
+    p = as_raw_tensor([2])
+
     @trace(symbolic=True, capture_as_const=True)
     def f(x):
-        op = ops.Elemwise(mode="negate")
-        (y,) = apply(op, x)
+        op = ops.Elemwise(mode="mul")
+        (y,) = apply(op, x, p)
         return y
 
-    x = as_raw_tensor([1]).numpy()
+    x = as_raw_tensor([3]).numpy()
     y = f.__wrapped__(as_raw_tensor(x)).numpy()
 
     for i in range(3):
@@ -95,6 +167,13 @@ def test_dump():
 
     file = io.BytesIO()
     f.dump(file)
+    file.seek(0)
+    cg, _, outputs = mgb_graph.load_graph(file)
+    (out,) = outputs
+    assert (
+        cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
+        == "SharedDeviceTensor"
+    )
 
 
 def test_trace_profiler():
diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py
index 57b933b4b..4cbbba89e 100755
--- a/sdk/load-and-run/dump_with_testcase_mge.py
+++ b/sdk/load-and-run/dump_with_testcase_mge.py
@@ -471,11 +471,9 @@ def main():
         assert not testcase, 'extra inputs provided in testcase: {}'.format(
            testcase.keys()
         )
-        mgb.serialize_comp_graph_to_file(
-            args.output,
-            output_mgbvars,
-            append=True,
-            output_strip_info=args.output_strip_info)
+        with open(args.output, "ab") as fout:
+            fout.write(G.dump_graph(*output_mgbvars))
+
 
 if __name__ == '__main__':
     main()
-- 
GitLab
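
Usage sketch (not part of the diff above): the snippet below shows how the pieces
introduced by this patch are meant to fit together: trace and dump a toy function,
reload it with megbrain_graph.load_graph, then inspect the result with the new
cgtools helpers. It assumes a MegEngine build that already contains this patch;
the function f and the buffer buf are illustrative names only.

    import io

    import numpy as np

    import megengine
    from megengine import cgtools
    from megengine.core.tensor import megbrain_graph as G
    from megengine.jit import trace


    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * 2.0


    # run the traced function a few times so the trace is finalized before dumping
    for _ in range(3):
        f(megengine.tensor(np.ones((2, 2), dtype=np.float32)))

    buf = io.BytesIO()
    f.dump(buf)  # trace.dump() now serializes through the renamed G.dump_graph
    buf.seek(0)

    # load_graph returns the CompGraphLoadResult namedtuple:
    # (graph, output_vars_dict, output_vars_list)
    graph, var_map, out_vars = G.load_graph(buf)

    # walk the loaded graph in topological order via the new opr bindings
    for opr in cgtools.get_oprs_seq(out_vars):
        print(opr.id, opr.type, opr.name)

    # locate the data inputs, as load_and_inference() in test_tracing.py does
    inputs = cgtools.get_dep_vars(out_vars, "Host2DeviceCopy")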