提交 27f2ecc4 编写于 作者: M Megvii Engine Team

feat(imperative/opr): add extern c opr and support dump

GitOrigin-RevId: b29e374c6052136d92f7f3eb5ac8debe983de855
上级 91264f37
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import functools
import inspect
import sys
import typing
from abc import ABC
class OpBase:
    """Empty base class; it defines no attributes or methods of its own."""

    pass
class TensorBase:
    """Empty base class; it defines no attributes or methods of its own."""

    pass
class TensorWrapperBase:
    """Empty base class; it defines no attributes or methods of its own."""

    pass
......@@ -20,7 +20,6 @@ from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._wrap import as_device
from ..ops.builtin import OpDef
from .core import TensorBase
def set_priority_to_id(dest_vars):
......@@ -127,7 +126,7 @@ class Graph(_imperative_rt.ComputingGraph):
print("this function should be called after compilation.")
class VarNode(TensorBase):
class VarNode:
def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
self._node = node
self._isscalar = isscalar
......
......@@ -7,12 +7,31 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from typing import Sequence
from typing import Iterable, List, Sequence
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
def extern_opr_subgraph(
    inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, output_dtypes
):
    r"""Load a serialized extern opr subgraph and fake execute the operator.

    Args:
        inputs: list of input tensors. A value that is not iterable is
            treated as a single input.
        output_shapes: The output shapes.
        dump_name: The serialized subgraph name.
        dump_data: The serialized subgraph.
        output_dtypes: The output dtypes, forwarded to the operator unchanged.

    Returns:
        The result of applying the constructed ``ExternOpr`` to the inputs.
    """
    # Wrap a lone, non-iterable input so it can be unpacked into ``apply``.
    if not isinstance(inputs, Iterable):
        inputs = (inputs,)
    # The operator takes the payload together with its explicit length.
    op = builtin.ExternOpr(
        output_shapes, dump_name, dump_data, len(dump_data), output_dtypes
    )
    return apply(op, *inputs)
def tensorrt_runtime_opr(inputs, *, data: bytes = None):
# empty model will give None result
if data is None:
......
......@@ -29,13 +29,13 @@ from ..core._imperative_rt.core2 import (
from ..core._imperative_rt.ops import (
AssertEqual,
CollectiveComm,
ExternOpr,
RemoteRecv,
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
......@@ -129,6 +129,7 @@ class trace:
function: the function will be traced.
symbolic: whether to apply symbolic execution for tracing. Default: False
capture_as_const: capture global vars or closures as const value. Default: False
record_only: if True, the function is only recorded and is not actually run, even when it is called. Default: False
sublinear_memory_config: configuration for sublinear memory optimization.
If not None, it enables sublinear memory optimization with given setting.
profiling: whether to profile compiled trace. Default: False
......@@ -147,6 +148,7 @@ class trace:
function,
symbolic=False,
capture_as_const=False,
record_only=False,
sublinear_memory_config: SublinearMemoryConfig = None,
dtr_config: DTRConfig = None,
profiling: bool = False,
......@@ -155,8 +157,9 @@ class trace:
symbolic_shape: bool = True,
):
self.__wrapped__ = function
self._symbolic = symbolic
self._capture_as_const = capture_as_const
self._symbolic = symbolic or record_only
self._capture_as_const = capture_as_const or record_only
self._record_only = record_only
self._sublinear_memory_config = sublinear_memory_config
self._dtr_config = dtr_config
self._profiling = profiling
......@@ -418,35 +421,40 @@ class trace:
def do_finalize():
escaped_tensors = self._take_escaped_tensors()
if self._untraced:
for x in escaped_tensors:
if x():
info = self._tinfo[x()._mixin_handle]
info.data_read = True
x()._mixin_handle = -1
x()._recording = False
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x._mixin_handle = -1
x._recording = False
if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links
):
# eval lazy eval tensors
self._lazy_eval(
self._lazy_eval_graph,
self._lazy_eval_tensors,
self._lazy_eval_links,
)
if self._record_only:
self._lazy_eval_graph = None
self._lazy_eval_tensors = None
self._lazy_eval_links = None
self._untraced = False
else:
for x in escaped_tensors:
if x():
info = self._tinfo[x()._mixin_handle]
info.data_read = True
x()._mixin_handle = -1
x()._recording = False
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x._mixin_handle = -1
x._recording = False
if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links
):
# eval lazy eval tensors
self._lazy_eval(
self._lazy_eval_graph,
self._lazy_eval_tensors,
self._lazy_eval_links,
)
self._lazy_eval_graph = None
self._lazy_eval_tensors = None
self._lazy_eval_links = None
self._untraced = False
else:
# compiled_tensor leaks
if self._pc == len(self._seq):
for x in escaped_tensors:
try:
assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
x().__init__(RawTensor(x()._dev_tensor()))
except RuntimeError:
# TraceMismatchError thrown in do_exit
pass
......@@ -769,8 +777,8 @@ class trace:
raise ValueError(
"you must specify capture_as_const=True at __init__ to use dump"
)
if self._untraced:
raise RuntimeError("should run at least once before calling dump")
if self._untraced and len(self._seq) == 0:
raise RuntimeError("should do record first before dump")
if self._output_names and output_names:
raise TypeError(
"cannot specify output_names when output is already in dict format"
......@@ -1104,10 +1112,6 @@ class CompiledTensorProxy:
self.__info.data_reader.drop_value()
def assign_raw_tensor(lhs, rhs):
    """Re-initialize ``lhs`` in place by calling its ``__init__`` with ``rhs``.

    Args:
        lhs: the object to re-initialize.
        rhs: the single argument passed to ``lhs.__init__``.
    """
    lhs.__init__(rhs)
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ivars = []
......
......@@ -12,11 +12,55 @@ import numpy as np
from ..functional.external import (
atlas_runtime_opr,
cambricon_runtime_opr,
extern_opr_subgraph,
tensorrt_runtime_opr,
)
from .module import Module
class ExternOprSubgraph(Module):
    r"""Load a serialized ExternOpr subgraph.

    See :func:`~.extern_opr_subgraph` for more details.

    Args:
        output_shapes: the output shapes.
        dump_name: the serialized subgraph name.
        dump_data: the serialized subgraph.
        output_dtypes: the output dtypes. When ``None``, ``np.float32`` is used
            for every entry of ``output_shapes``.
        kwargs: forwarded to the ``Module`` constructor.
    """

    def __init__(
        self, output_shapes, dump_name, dump_data, output_dtypes=None, **kwargs
    ):
        super().__init__(**kwargs)
        self._output_shapes = output_shapes
        self._dump_name = dump_name
        self._dump_data = dump_data
        self._output_dtypes = output_dtypes
        if self._output_dtypes is None:
            # One float32 dtype per declared output shape.
            self._output_dtypes = [np.float32] * len(output_shapes)

    @property
    def data(self):
        """The serialized subgraph currently held by this module."""
        return self._dump_data

    @data.setter
    def data(self, val):
        # NOTE(review): the constructor stores ``dump_data`` as given, while
        # this setter stores a uint8 ndarray view of the buffer; confirm that
        # ExternOpr accepts both representations.
        self._dump_data = np.frombuffer(val, dtype=np.uint8)

    @property
    def name(self):
        """The serialized subgraph name."""
        return self._dump_name

    @name.setter
    def name(self, val):
        self._dump_name = val

    def forward(self, *inputs):
        """Run ``extern_opr_subgraph`` on ``inputs`` with the stored settings."""
        return extern_opr_subgraph(
            inputs,
            output_shapes=self._output_shapes,
            dump_name=self._dump_name,
            dump_data=self._dump_data,
            output_dtypes=self._output_dtypes,
        )
class TensorrtRuntimeSubgraph(Module):
r"""Load a serialized TensorrtRuntime subgraph.
......
......@@ -76,7 +76,7 @@ class XORNet(Module):
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_training_converge(test_traced_module):
net = XORNet()
if test_training_converge:
if test_traced_module:
inp = Tensor(np.random.random((14, 2)))
net = trace_module(net, inp)
......
/**
* \file imperative/src/impl/ops/extern_opr.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h"
#include "megbrain/serialization/extern_c_opr_io.h"
namespace mgb::imperative {
namespace { namespace externopr {
// Convert nested size vectors into a TensorShapeArray: one TensorShape per
// inner vector, in the same order as the input.
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray result;
    for (const auto& dims : shapes) {
        // TensorShape is built from a SmallVector, so copy the extents first.
        SmallVector<size_t> extents(dims.begin(), dims.end());
        result.push_back(TensorShape(extents));
    }
    return result;
}
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ExternOpr&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
SmallVector<DType> output_dtypes(op.output_dtypes.begin(), op.output_dtypes.end());
auto&& output_shapes = get_shapes(op.output_shapes);
cg::OperatorNodeBase* opr = opr::ExternCOprRunner::make_placeholder(
symbol_var_inputs, output_shapes, op.name.c_str(), op.data.c_str(),
op.data_len, {}, output_dtypes);
return opr;
}
// Register the op trait for ExternOpr. Only apply_on_var_node is supplied
// explicitly here.
// NOTE(review): .fallback() presumably fills the remaining trait entries with
// defaults - confirm against op_trait.h.
OP_TRAIT_REG(ExternOpr, ExternOpr, opr::ExternCOprRunner)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // externopr
} // namespace mgb::imperative
......@@ -10,7 +10,7 @@ import argparse
import numpy as np
import yaml
from megengine import jit
from megengine import jit, tensor
from megengine.module.external import ExternOprSubgraph
......@@ -27,12 +27,12 @@ def main():
description="load a .pb model and convert to corresponding "
"load-and-run model"
)
parser.add_argument("input", help="mace model file")
parser.add_argument("param", help="mace param file")
parser.add_argument("--input", help="mace model file")
parser.add_argument("--param", help="mace param file")
parser.add_argument(
"output", help="converted model that can be fed to dump_with_testcase_mge.py"
"--output", help="converted model that can be fed to dump_with_testcase_mge.py"
)
parser.add_argument("config", help="config file with yaml format")
parser.add_argument("--config", help="config file with yaml format")
args = parser.parse_args()
with open(args.config, "r") as f:
......@@ -90,17 +90,17 @@ def main():
+ raw_param
)
net = ExternOprSubgraph(wk_raw_content, "mace", osizes)
net = ExternOprSubgraph(osizes, "mace", wk_raw_content)
net.eval()
@jit.trace(symbolic=True)
@jit.trace(record_only=True)
def inference(inputs):
return net(inputs)
inputs = [
np.random.random(isizes[i]).astype(np.float32) for i in range(len(isizes))
tensor(np.random.random(isizes[i]).astype(np.float32)) for i in range(len(isizes))
]
inference.trace(*inputs)
inference(*inputs)
inference.dump(args.output)
......
......@@ -381,6 +381,24 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;
// Op definition for ExternOpr. Its arguments are the output shapes, a name,
// a data string with an explicit length, and the output dtypes.
def ExternOpr: MgbHashableOp<"ExternOpr"> {
let extraArguments = (ins
MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,
MgbStringAttr:$name,
MgbStringAttr:$data,
MgbSizeTAddr:$data_len,
MgbArrayAttr<MgbDTypeAttr>:$output_dtypes
);
// The custom hash combines the op's dynamic type info with `name` and `data`.
// NOTE(review): output_shapes, data_len and output_dtypes do not contribute
// to the hash - confirm this is intended.
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(
mgb::hash($_self.name),
mgb::hash($_self.data))
);
}];
}
def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;
def Split: MgbHashableOp<"Split", [EmptyParam]> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册