Commit ae3123b3 authored by Megvii Engine Team

feat(mge): add python graph for mgb graph editing

GitOrigin-RevId: 6a9d5beba2eb0ebce3908e785f041ba0aa7085a4
Parent c82d8875
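This commit adds a Python-level graph (megengine.utils.network.Network) for editing dumped mgb graphs: load a serialized model, query and rewrite its vars/oprs, then dump it back. A condensed usage sketch, distilled from test_make_const added below (orig_model is assumed to be a BytesIO holding a model dumped via trace(...).dump(...)):

    import io
    import numpy as np
    from megengine.utils.network import Network as Net
    from megengine.utils.comp_graph_tools import GraphInference

    graph = Net.load(orig_model)                                 # load the dumped graph for editing
    varb = graph.var_filter.name("b").as_unique()                # find the input var named "b"
    const_b = graph.make_const(np.array([0.0, 0.0]), name="b")   # bake it into a constant
    graph.replace_vars({varb: const_b})                          # rewrite the graph in place
    modified_model = io.BytesIO()
    graph.dump(modified_model)                                   # serialize the edited graph
    modified_model.seek(0)
    out = GraphInference(modified_model).run(a)                  # only input "a" remains; a is a Tensor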
......@@ -11,6 +11,7 @@ from typing import Iterable, Union
import numpy as np
from .._imperative_rt import VarNode
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
from ..ops import builtin
from ..ops.special import Const
......@@ -59,7 +60,7 @@ def astype(x, dtype):
def convert_single_value(v, *, dtype=None, device=None):
if isinstance(v, Tensor):
if isinstance(v, (Tensor, VarNode)):
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
......
......@@ -12,11 +12,12 @@ import functools
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.graph import VarNode
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import isscalar, setscalar
from ..core.tensor.utils import astype, isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
......@@ -77,7 +78,7 @@ __all__ = [
def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, Tensor), args))
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
......@@ -109,7 +110,7 @@ def _elwise(*args, mode):
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: x.astype("float32"), args))
args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)
......
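Because _elwise (and convert_single_value above) now accept VarNode, functional ops can be applied directly to graph variables while editing a loaded network. Condensed from test_replace_var below (graph is a loaded Network, F is megengine.functional):

    vara = graph.var_filter.name("a").as_unique()
    varb = graph.var_filter.name("b").as_unique()
    out = F.relu(F.mul(vara.var, varb.var))   # VarNode operands, no Tensor wrapping needed
    var_list = graph.add_dep_oprs(out)        # pull the newly created oprs into the editable graph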
......@@ -65,7 +65,6 @@ def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
"""
Gets the inputs of owner opr of a variable.
"""
assert isinstance(var, VarNode)
return var.owner.inputs
......@@ -74,7 +73,6 @@ def get_owner_opr_type(var: VarNode) -> str:
Gets the type of owner opr of a variable.
"""
assert isinstance(var, VarNode)
return var.owner.type
......@@ -109,7 +107,7 @@ def graph_traversal(outputs: VarNode):
var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list)
queue = list(map(lambda x: x.owner, outputs))
queue = list(set(map(lambda x: x.owner, outputs)))
visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information
......@@ -143,12 +141,15 @@ def graph_traversal(outputs: VarNode):
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
def get_oprs_seq(
outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
) -> List[OperatorNode]:
"""
Gets oprs in some topological order for a dumped model.
:param outputs: model outputs.
:param prune_reshape: whether to prune the useless operators during inference.
:param prune_reshape: whether to prune useless operators that are only used by Reshape oprs during inference.
:param prune_immtensor: whether to prune ImmutableTensor oprs.
:return: opr list in a correct (topological) execution order.
"""
......@@ -160,9 +161,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
opr_id = indegree2opr[0].pop()
opr = map_oprs[opr_id]
nr_remain -= 1
# skip const value generation operator
if get_opr_type(opr) != "ImmutableTensor":
if opr.type != "ImmutableTensor" or not prune_immtensor:
oprs_seq.append(opr)
for post_id in opr2receivers[opr_id]:
......
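A short sketch of the extended get_oprs_seq (the module path megengine.utils.comp_graph_tools is assumed; load_graph/output_vars_list usage follows test_optimize_for_inference below):

    import megengine.core.tensor.megbrain_graph as G
    from megengine.utils.comp_graph_tools import get_oprs_seq

    res = G.load_graph(model_file)   # model_file: a dumped model (path or BytesIO), assumed available
    # keep ImmutableTensor (const) oprs in the topological sequence instead of pruning them
    oprs = get_oprs_seq(res.output_vars_list, prune_reshape=False, prune_immtensor=False)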
This diff is collapsed.
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import sys
from typing import Callable
import numpy as np
from ..core import _imperative_rt as rt
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
class NetworkNode:
pass
class VarNode(NetworkNode):
def __init__(self, owner_opr=None, name=None):
self.var = None
self.owner = owner_opr
self.name = name
self.id = id(self)
@classmethod
def load(cls, sym_var, owner_opr):
obj = cls()
obj.var = sym_var # mgb varnode
obj.name = sym_var.name
obj.owner = owner_opr
return obj
@property
def shape(self):
rst = None
if self.var:
try:
rst = self.var.shape
except Exception:
rst = None
return rst
@property
def dtype(self):
return self.var.dtype if self.var else None
def set_owner_opr(self, owner_opr):
self.owner = owner_opr
class OpNode(NetworkNode):
opdef = None
type = None
def __init__(self):
self.inputs = []
self.outputs = []
self.params = {}
self._opr = None # mgb opnode
self.id = id(self)
@classmethod
def load(cls, opr):
obj = cls()
obj.params = json.loads(opr.params)
obj.name = opr.name
obj._opr = opr
return obj
def compile(self, graph=None):
op = self.opdef(**self.params)
args = [i.var for i in self.inputs]
outputs = rt.invoke_op(op, args)
assert len(outputs) == len(self.outputs)
self._opr = outputs[0].owner
for i in range(len(self.outputs)):
self.outputs[i].var = outputs[i]
self.outputs[i].var.name = self.outputs[i].name
assert self.outputs[i].owner is self
def add_inp_var(self, x):
self.inputs.append(x)
def add_out_var(self, x):
self.outputs.append(x)
def str_to_mge_class(classname):
# TODO: use megbrain C++ RTTI to replace type string
if classname == "RNGOpr<MegDNNOpr>":
classname = "RNGOpr"
oprcls = getattr(sys.modules[__name__], classname, None)
return oprcls if oprcls else ReadOnlyOpNode
class Host2DeviceCopy(OpNode):
type = "Host2DeviceCopy"
def __init__(self, shape=None, dtype=None, name=None, device=None):
super().__init__()
self.shape = shape
self.dtype = dtype
self.name = name
self.device = Device(device).to_c() if device else Device("xpux").to_c()
self.outputs = []
@classmethod
def load(cls, opr):
self = cls()
self.outputs = []
assert len(opr.outputs) == 1, "wrong number of outputs"
self.shape = opr.outputs[0].shape
self.dtype = opr.outputs[0].dtype
self.name = opr.outputs[0].name
self.device = opr.outputs[0].comp_node
self._opr = opr
return self
def compile(self, graph):
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
self._opr = outputs.owner
if len(self.outputs) == 0:
self.outputs.append(VarNode(self, self.name))
self.outputs[0].var = outputs
assert self.outputs[0].owner is self
class ImmutableTensor(OpNode):
type = "ImmutableTensor"
def __init__(self, data=None, name=None, device=None, graph=None):
super().__init__()
self.name = name
self.outputs = []
self.graph = graph
if data is not None:
self.set_value(data, device)
@property
def device(self):
return self._opr.outputs[0].comp_node if self._opr else None
@device.setter
def device(self, device):
self.set_value(self.numpy(), device)
@property
def shape(self):
return self.outputs[0].shape
@property
def dtype(self):
return self._opr.outputs[0].dtype if self._opr else None
def numpy(self):
return self._opr.outputs[0].value if self._opr else None
def set_value(self, data, device=None):
assert self.graph is not None
cn = device if device else self.device
assert isinstance(data, (int, float, np.ndarray))
if isinstance(data, (int, float)):
data = np.array(data)
if data.dtype == np.float64:
data = data.astype(np.float32)
elif data.dtype == np.int64:
data = data.astype(np.int32)
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
if len(self.outputs) == 0:
self.outputs.append(VarNode(self, self.name))
self.outputs[0].var = varnode
self._opr = varnode.owner
@classmethod
def load(cls, opr):
self = cls()
self.outputs = []
self._opr = opr
self.name = opr.outputs[0].name
self.graph = opr.graph
return self
def compile(self, graph):
assert self.outputs[0].var is self._opr.outputs[0]
assert self.outputs[0].owner is self
if self.graph != graph:
self.graph = graph
self.set_value(self.numpy())
if self.name is not None:
self.outputs[0].var.name = self.name
class ReadOnlyOpNode(OpNode):
@classmethod
def load(cls, opr):
obj = super(ReadOnlyOpNode, cls).load(opr)
obj.type = opr.type
return obj
def compile(self):
assert self._opr is not None
assert len(self.inputs) == len(self._opr.inputs)
assert len(self.outputs) == len(self._opr.outputs)
repl_dict = {}
for ind, i in enumerate(self.inputs):
if i.var != self._opr.inputs[ind]:
repl_dict[self._opr.inputs[ind]] = i.var
if bool(repl_dict):
out_vars = replace_vars(self._opr.outputs, repl_dict)
for ind, o in enumerate(self.outputs):
o.var = out_vars[ind]
class Elemwise(OpNode):
type = "Elemwise"
opdef = builtin.Elemwise
class Reduce(OpNode):
type = "Reduce"
opdef = builtin.Reduce
class TypeCvt(OpNode):
type = "TypeCvt"
opdef = builtin.TypeCvt
@classmethod
def load(cls, opr):
obj = super(TypeCvt, cls).load(opr)
t_dtype = opr.outputs[0].dtype
obj.params["dtype"] = t_dtype
return obj
class MatrixInverse(OpNode):
type = "MatrixInverse"
opdef = builtin.MatrixInverse
class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul
class BatchedMatrixMul(OpNode):
type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul
class Dot(OpNode):
type = "Dot"
opdef = builtin.Dot
class SVD(OpNode):
type = "SVD"
opdef = builtin.SVD
class ConvolutionForward(OpNode):
type = "Convolution"
opdef = builtin.Convolution
class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
opdef = builtin.ConvolutionBackwardData
class DeformableConvForward(OpNode):
type = "DeformableConv"
opdef = builtin.DeformableConv
class GroupLocalForward(OpNode):
type = "GroupLocal"
opdef = builtin.GroupLocal
class PoolingForward(OpNode):
type = "Pooling"
opdef = builtin.Pooling
class AdaptivePoolingForward(OpNode):
type = "AdaptivePooling"
opdef = builtin.AdaptivePooling
class ROIPoolingForward(OpNode):
type = "ROIPooling"
opdef = builtin.ROIPooling
class DeformablePSROIPoolingForward(OpNode):
type = "DeformablePSROIPooling"
opdef = builtin.DeformablePSROIPooling
class ConvBiasForward(OpNode):
type = "ConvBias"
opdef = builtin.ConvBias
@classmethod
def load(cls, opr):
obj = super(ConvBiasForward, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
class BatchConvBiasForward(OpNode):
type = "BatchConvBias"
opdef = builtin.BatchConvBias
@classmethod
def load(cls, opr):
obj = super(BatchConvBiasForward, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
class BatchNormForward(OpNode):
type = "BatchNorm"
opdef = builtin.BatchNorm
class ROIAlignForward(OpNode):
type = "ROIAlign"
opdef = builtin.ROIAlign
class WarpPerspectiveForward(OpNode):
type = "WarpPerspective"
opdef = builtin.WarpPerspective
class WarpAffineForward(OpNode):
type = "WarpAffine"
opdef = builtin.WarpAffine
class RemapForward(OpNode):
type = "Remap"
opdef = builtin.Remap
class ResizeForward(OpNode):
type = "Resize"
opdef = builtin.Resize
class IndexingOneHot(OpNode):
type = "IndexingOneHot"
opdef = builtin.IndexingOneHot
class IndexingSetOneHot(OpNode):
type = "IndexingSetOneHot"
opdef = builtin.IndexingSetOneHot
class Copy(OpNode):
type = "Copy"
opdef = builtin.Copy
@classmethod
def load(cls, opr):
obj = super(Copy, cls).load(opr)
obj.params["comp_node"] = opr.outputs[0].comp_node
return obj
class ArgsortForward(OpNode):
type = "Argsort"
opdef = builtin.Argsort
class Argmax(OpNode):
type = "Argmax"
opdef = builtin.Argmax
class Argmin(OpNode):
type = "Argmin"
opdef = builtin.Argmin
class CondTake(OpNode):
type = "CondTake"
opdef = builtin.CondTake
class TopK(OpNode):
type = "TopK"
opdef = builtin.TopK
class NvOf(OpNode):
type = "NvOf"
opdef = builtin.NvOf
class RNGOpr(OpNode):
@classmethod
def load(cls, opr):
obj = super(RNGOpr, cls).load(opr)
if len(obj.params) == 3:
obj.opdef = builtin.GaussianRNG
obj.type = "GaussianRNG"
else:
obj.opdef = builtin.UniformRNG
obj.type = "UniformRNG"
return obj
class Linspace(OpNode):
type = "Linspace"
opdef = builtin.Linspace
@classmethod
def load(cls, opr):
obj = super(Linspace, cls).load(opr)
obj.params["comp_node"] = opr.outputs[0].comp_node
return obj
class Eye(OpNode):
type = "Eye"
opdef = builtin.Eye
@classmethod
def load(cls, opr):
obj = super(Eye, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
obj.params["comp_node"] = opr.outputs[0].comp_node
return obj
class GetVarShape(OpNode):
type = "GetVarShape"
opdef = builtin.GetVarShape
class Concat(OpNode):
type = "Concat"
opdef = builtin.Concat
@classmethod
def load(cls, opr):
obj = super(Concat, cls).load(opr)
obj.params["comp_node"] = Device("xpux").to_c()
return obj
class Broadcast(OpNode):
type = "Broadcast"
opdef = builtin.Broadcast
class Identity(OpNode):
type = "Identity"
opdef = builtin.Identity
class NMSKeep(OpNode):
type = "NMSKeep"
opdef = builtin.NMSKeep
# class ParamPackSplit
# class ParamPackConcat
class Dimshuffle(OpNode):
type = "Dimshuffle"
opdef = builtin.Dimshuffle
@classmethod
def load(cls, opr):
obj = super(Dimshuffle, cls).load(opr)
del obj.params["ndim"]
return obj
class Reshape(OpNode):
type = "Reshape"
opdef = builtin.Reshape
class AxisAddRemove(OpNode):
type = "AxisAddRemove"
@classmethod
def load(cls, opr):
obj = cls()
obj.name = opr.name
obj._opr = opr
params = json.loads(opr.params)
desc = params["desc"]
method = None
axis = []
for i in desc:
if method is None:
method = i["method"]
assert method == i["method"]
axis.append(i["axisnum"])
obj.params = {"axis": axis}
obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
return obj
class IndexingBase(OpNode):
@classmethod
def load(cls, opr):
obj = cls()
obj.name = opr.name
obj._opr = opr
params = json.loads(opr.params)
items = [
[
p["axis"],
bool(p["begin"]),
bool(p["end"]),
bool(p["step"]),
bool(p["idx"]),
]
for p in params
]
obj.params["items"] = items
return obj
class Subtensor(IndexingBase):
type = "Subtensor"
opdef = builtin.Subtensor
class SetSubtensor(IndexingBase):
type = "SetSubtensor"
opdef = builtin.SetSubtensor
class IncrSubtensor(IndexingBase):
type = "IncrSubtensor"
opdef = builtin.IncrSubtensor
class IndexingMultiAxisVec(IndexingBase):
type = "IndexingMultiAxisVec"
opdef = builtin.IndexingMultiAxisVec
class IndexingSetMultiAxisVec(IndexingBase):
type = "IndexingSetMultiAxisVec"
opdef = builtin.IndexingSetMultiAxisVec
class IndexingIncrMultiAxisVec(IndexingBase):
type = "IndexingIncrMultiAxisVec"
opdef = builtin.IndexingIncrMultiAxisVec
class MeshIndexing(IndexingBase):
type = "MeshIndexing"
opdef = builtin.MeshIndexing
class SetMeshIndexing(IndexingBase):
type = "SetMeshIndexing"
opdef = builtin.SetMeshIndexing
class IncrMeshIndexing(IndexingBase):
type = "IncrMeshIndexing"
opdef = builtin.IncrMeshIndexing
class BatchedMeshIndexing(IndexingBase):
type = "BatchedMeshIndexing"
opdef = builtin.BatchedMeshIndexing
class BatchedSetMeshIndexing(IndexingBase):
type = "BatchedSetMeshIndexing"
opdef = builtin.BatchedSetMeshIndexing
class BatchedIncrMeshIndexing(IndexingBase):
type = "BatchedIncrMeshIndexing"
opdef = builtin.BatchedIncrMeshIndexing
# class CollectiveComm
# class RemoteSend
# class RemoteRecv
# class TQT
# class FakeQuant
# class InplaceAdd
class AssertEqual(OpNode):
type = "AssertEqual"
opdef = builtin.AssertEqual
class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType
@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
class CvtColorForward(OpNode):
type = "CvtColor"
opdef = builtin.CvtColor
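Every OpNode subclass above follows the same pattern: pair the serialized opr type string with its imperative OpDef, and override load() only when extra params (dtype, comp_node, ...) must be recovered from the dumped opr. A purely hypothetical sketch of that pattern (MyOpr and builtin.MyOpr are placeholders, not real MegEngine names):

    class MyOprForward(OpNode):
        type = "MyOpr"               # mgb type string resolved by str_to_mge_class
        opdef = builtin.MyOpr        # placeholder OpDef; a real wrapper uses the matching builtin op

        @classmethod
        def load(cls, opr):
            obj = super().load(opr)
            obj.params["dtype"] = opr.outputs[0].dtype   # recover params not stored in the json blob
            return obj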
......@@ -160,6 +160,16 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
if (ctx.op->same_type<BackwardGraph>()) {
ctx.backward = true;
}
if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>();
}
auto op = ctx.op.get();
return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
}
for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
......@@ -675,6 +685,16 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
tensors.emplace_back(descr);
continue;
}
if (py::isinstance<cg::VarNode>(py::handle(handle))){
auto var = py::handle(handle).cast<cg::VarNode *>();
mgb::DType type = var->dtype();
auto && descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
continue;
}
PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
......@@ -719,12 +739,14 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
if (tw || is_var) {
if (!valid) {
cn = tw->m_tensor->comp_node();
cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
valid = true;
} else {
CompNode cn1 = tw->m_tensor->comp_node();
CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf("ambiguous device: %s vs %s",
cn.to_string().c_str(), cn1.to_string().c_str()));
......
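With the py_apply branch above, apply() dispatches on megbrain VarNode inputs through OpDef::apply_on_var_node and returns output VarNodes. A minimal Python sketch (some_var is assumed to be a cg::VarNode taken from a loaded graph):

    from megengine.core._imperative_rt.core2 import apply
    from megengine.core.ops import builtin

    op = builtin.GetVarShape()
    (shape_var,) = apply(op, some_var)   # VarNode in, VarNode out (tuple of outputs)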
import io
import numpy as np
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.module as M
import megengine.utils.network_node as N
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
from megengine.utils.network import as_oprnode
from megengine.utils.network_node import Host2DeviceCopy, VarNode
def test_replace_var():
a = Tensor([1, 2])
b = Tensor([3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
graph = Net.load(orig_model)
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()
out = F.mul(vara.var, varb.var)
out = F.relu(out)
var_list = graph.add_dep_oprs(out)
opnode = list(graph.opr_filter.has_input(vara))
repl_dict = {opnode[0].outputs[0]: var_list[0]}
graph.replace_vars(repl_dict)
modified_model = io.BytesIO()
graph.dump(modified_model)
modified_model.seek(0)
load_graph = GraphInference(modified_model)
out = load_graph.run(a, b)
np.testing.assert_equal(out["o"], [6, 16])
def test_replace_opr():
a = Tensor([1, 2])
b = Tensor([3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
graph = Net.load(orig_model)
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()
out1 = F.sub(vara.var, varb.var)
out1 = F.relu(out1)
var_list = graph.add_dep_oprs(out1)
repl_opr = as_oprnode(var_list)
orig_opr = graph.opr_filter.has_input(vara).as_unique()
repl_dict = {orig_opr: repl_opr}
graph.replace_oprs(repl_dict)
modified_model1 = io.BytesIO()
graph.dump(modified_model1)
modified_model1.seek(0)
load_graph = GraphInference(modified_model1)
out = load_graph.run(a, b)
np.testing.assert_equal(out["o"], [0, 0])
def test_modify_params():
a = Tensor([1, 2])
b = Tensor([3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
graph = Net.load(orig_model)
param_const = graph.params_filter.as_unique()
param_const.set_value(3)
modified_model = io.BytesIO()
graph.dump(modified_model)
modified_model.seek(0)
load_graph = GraphInference(modified_model)
out = load_graph.run(a, b)
np.testing.assert_equal(out["o"], [12, 18])
def test_make_const():
a = Tensor([1, 2])
b = Tensor([3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
graph = Net.load(orig_model)
const_b = graph.make_const(np.array([0.0, 0.0]), name="b")
varb = graph.var_filter.name("b").as_unique()
repl_dict = {varb: const_b}
graph.replace_vars(repl_dict)
modified_model = io.BytesIO()
graph.dump(modified_model)
modified_model.seek(0)
load_graph = GraphInference(modified_model)
out = load_graph.run(a)
np.testing.assert_equal(out["o"], [2, 4])
def test_add_input():
a = Tensor([1, 2])
b = Tensor([3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
graph = Net.load(orig_model)
inp_c = graph.make_input_node((2,), np.int32, name="c")
varo = graph.var_filter.name("o").as_unique()
out = F.add(varo.var, inp_c.var)
out = graph.add_dep_oprs(out)[0]
out.name = "o1"
graph.remove_output(varo)
graph.add_output(out)
modified_model = io.BytesIO()
graph.dump(modified_model)
modified_model.seek(0)
load_graph = GraphInference(modified_model)
out = load_graph.run(a, b, a)
np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())
def test_add_output():
a = Tensor([1.0, 2.0])
b = Tensor([3.0, 4.0])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
net = Net.load(orig_model)
var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()
y = F.add(var_a.var, var_b.var)
y = F.sigmoid(y)
new_vars = net.add_dep_oprs(y)[0]
new_vars.name = "o1"
net.add_output(new_vars)
modified_model = io.BytesIO()
net.dump(modified_model)
modified_model.seek(0)
g = GraphInference(modified_model)
out = g.run(a.numpy(), b.numpy())
np.testing.assert_equal(out["o"], ((a + b) * 2).numpy())
np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy())
def test_query():
class Model(M.Module):
def __init__(self):
super().__init__()
self.conv1 = M.Conv2d(3, 32, 3)
self.conv2 = M.Conv2d(32, 32, 3)
self.conv3 = M.Conv2d(32, 32, 3)
def forward(self, data):
x = self.conv1(data)
x = self.conv2(x)
x = self.conv3(x)
return x
n = Model()
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return n(data)
fwd(Tensor(np.random.random((1, 3, 224, 224))))
orig_model = io.BytesIO()
fwd.dump(
orig_model,
arg_names=["data"],
output_names="o",
keep_opr_name=True,
keep_var_name=True,
optimize_for_inference=False,
)
orig_model.seek(0)
graph = Net.load(orig_model)
r = graph.data_providers_filter.as_count()
assert r == 1
opr = graph.get_opr_by_type(Host2DeviceCopy)
assert isinstance(opr, Host2DeviceCopy)
r1 = graph.params_filter.as_count()
assert r1 == 6
r2 = graph.opr_filter.type(N.ConvolutionForward).as_count()
assert r2 == 3
r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count()
assert r3 == len(graph.all_oprs) - r2
var = graph.var_filter.name("data").as_unique()
r4 = graph.opr_filter.has_input(var).as_count()
assert r4 == 1
r5 = graph.opr_filter.name("data").as_count()
assert r5 == 1
opr = graph.get_opr_by_name("data")
assert isinstance(opr, Host2DeviceCopy)
var = graph.get_var_by_name("data")
assert isinstance(var, VarNode)
r6 = graph.var_filter.name("*bias").as_count()
assert r6 == 3
def test_optimize_for_inference():
@trace(symbolic=True, capture_as_const=True)
def f(x):
return F.exp(x)
orig_model = io.BytesIO()
f(Tensor(5.0))
f.dump(orig_model, optimize_for_inference=False)
orig_model.seek(0)
optimize_model = io.BytesIO()
net = Net.load(orig_model)
net.dump(optimize_model, enable_io16xc32=True)
optimize_model.seek(0)
res = G.load_graph(optimize_model)
computing_input = res.output_vars_list[0].owner.inputs[0]
assert computing_input.dtype == np.float16
def test_reset_batchsize():
@trace(symbolic=True, capture_as_const=True)
def f(x):
return F.exp(x)
orig_model = io.BytesIO()
f(Tensor(np.random.random((3, 3, 224, 224))))
f.dump(orig_model, optimize_for_inference=False)
orig_model.seek(0)
modified_model = io.BytesIO()
net = Net.load(orig_model)
net.reset_batch_size(1)
net.dump(modified_model, optimize_for_inference=False)
modified_model.seek(0)
net1 = Net.load(modified_model)
assert net1.data_providers_filter.as_unique().shape[0] == 1
def test_modify_opr_name():
@trace(symbolic=True, capture_as_const=True)
def f(x):
return F.exp(x)
orig_model = io.BytesIO()
f(Tensor(np.random.random((3, 3, 224, 224))))
f.dump(orig_model, arg_names=["a"], optimize_for_inference=False)
orig_model.seek(0)
modified_model = io.BytesIO()
net = Net.load(orig_model)
net.modify_opr_names("net")
net.modify_opr_names(lambda x: "net1." + x)
net.dump(modified_model, optimize_for_inference=False)
modified_model.seek(0)
net1 = Net.load(modified_model)
assert net1.data_providers_filter.as_unique().name == "net1.net.a"
import io
import os
import platform
import numpy as np
import pytest
import megengine.core.tensor.dtype as dtype
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.module as M
import megengine.random as rand
from megengine.core._imperative_rt.core2 import apply
from megengine.core._wrap import Device
from megengine.core.ops import builtin
from megengine.device import is_cuda_available
from megengine.functional.external import tensorrt_runtime_opr
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
def check_pygraph_dump(trace_func, inp_data, expect_results):
orig_model = io.BytesIO()
inp_size = len(inp_data)
out_size = len(expect_results)
arg_names = ["arg_{}".format(i) for i in range(inp_size)]
output_names = ["out_{}".format(i) for i in range(out_size)]
trace_func.dump(
orig_model,
arg_names=arg_names,
output_names=output_names,
optimize_for_inference=False,
)
orig_model.seek(0)
net = Net.load(orig_model)
file = io.BytesIO()
net.dump(file, optimize_for_inference=False)
file.seek(0)
graph = GraphInference(file)
inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)])
results = graph.run(inp_dict=inp_dict)
for ind, tensor in enumerate(expect_results):
np.testing.assert_equal(tensor.numpy(), results[output_names[ind]])
assert tensor.dtype == results[output_names[ind]].dtype
def test_elemwise():
@trace(symbolic=True, capture_as_const=True)
def fwd(x, y):
z1 = x * y
z2 = x + y
z3 = z1 / z2
z3 = z3 ** 3
return z3
x = Tensor([1.0, 2.0])
y = Tensor([3.0, 5.0])
result = fwd(x, y)
check_pygraph_dump(fwd, [x, y], [result])
def test_reduce():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
x = data.sum(axis=2)
x = x.mean(axis=1)
return x
data = Tensor(np.random.random((1, 32, 32)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_typecvt():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return data.astype(dtype.qint8(0.8))
x = Tensor(np.random.random((2, 3)) * 255)
result = fwd(x)
check_pygraph_dump(fwd, [x], [result])
def test_matinv():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return F.matinv(data)
data = Tensor(np.random.random((5, 5)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_matmul():
@trace(symbolic=True, capture_as_const=True)
def fwd(data1, data2):
return F.matmul(data1, data2)
data1 = Tensor(np.random.random((32, 64)))
data2 = Tensor(np.random.random((64, 16)))
result = fwd(data1, data2)
check_pygraph_dump(fwd, [data1, data2], [result])
def test_batchmatmul():
@trace(symbolic=True, capture_as_const=True)
def fwd(x, y):
return F.matmul(x, y)
x = Tensor(np.random.random((3, 3, 5)))
y = Tensor(np.random.random((3, 5, 3)))
result = fwd(x, y)
check_pygraph_dump(fwd, [x, y], [result])
def test_dot():
@trace(symbolic=True, capture_as_const=True)
def fwd(x, y):
return F.dot(x, y)
x = Tensor([1.0, 2.0, 3.0])
y = Tensor([3.0, 4.0, 5.0])
result = fwd(x, y)
check_pygraph_dump(fwd, [x, y], [result])
def test_svd():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
_, out, _ = F.svd(data)
return out
input = Tensor(np.random.random((1, 1, 3, 3)))
result = fwd(input)
check_pygraph_dump(fwd, [input], [result])
def test_conv():
conv = M.Conv2d(3, 32, 3)
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return conv(data)
data = Tensor(np.random.random((1, 3, 32, 32)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_deformable_conv():
if not is_cuda_available():
return
conv = M.DeformableConv2d(3, 32, 3)
@trace(symbolic=True, capture_as_const=True)
def fwd(data, offset, mask):
return conv(data, offset, mask)
data = Tensor(np.random.random((1, 3, 32, 32)))
offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5)
mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32"))
out = fwd(data, offset, mask)
check_pygraph_dump(fwd, [data, offset, mask], [out])
def test_convtranspose():
deconv = M.ConvTranspose2d(32, 32, 3)
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return deconv(data)
data = Tensor(np.random.random((1, 32, 32, 32)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
@pytest.mark.skip(reason="pytest aborted")
def test_grouplocal():
n = M.LocalConv2d(3, 32, 32, 32, 3)
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return n(data)
input = Tensor(np.random.random((1, 3, 32, 32)))
result = fwd(input)
check_pygraph_dump(fwd, [input], [result])
def test_pooling():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
out = F.max_pool2d(data, 2, 2)
out = F.avg_pool2d(out, 2, 2)
return out
data = Tensor(np.random.random((1, 3, 64, 64)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_adaptivepooling():
pool1 = M.AdaptiveMaxPool2d((2, 2))
pool2 = M.AdaptiveAvgPool2d((2, 2))
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
out = pool1(data)
out = pool2(out)
return out
input = Tensor(np.random.random((1, 3, 32, 32)))
result = fwd(input)
check_pygraph_dump(fwd, [input], [result])
def test_roipooling():
inp = Tensor(np.random.random((1, 1, 128, 128)))
rois = Tensor(np.random.random((4, 5)))
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, rois):
return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0)
output = fwd(inp, rois)
check_pygraph_dump(fwd, [inp, rois], [output])
def test_deformable_ps_roi_pooling():
inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32"))
rois = Tensor(np.random.random((1, 5)).astype("float32"))
trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32"))
pooled_h = 7
pooled_w = 7
sample_per_part = 4
no_trans = False
part_size = 7
spatial_scale = 1.0 / 64
trans_std = 0.1
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, rois, trans):
y = F.deformable_psroi_pooling(
inp,
rois,
trans,
no_trans,
part_size,
pooled_h,
pooled_w,
sample_per_part,
spatial_scale,
trans_std,
)
return y
result = fwd(inp, rois, trans)
check_pygraph_dump(fwd, [inp, rois, trans], [result])
def test_convbias():
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, weight, bias):
return F.quantized.conv_bias_activation(
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
)
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
result = fwd(inp, weight, bias)
check_pygraph_dump(fwd, [inp, weight, bias], [result])
def test_batch_convbias():
if is_cuda_available():
return
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, weight, bias):
return F.quantized.batch_conv_bias_activation(
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
)
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
result = fwd(inp, weight, bias)
check_pygraph_dump(fwd, [inp, weight, bias], [result])
def test_batchnorm():
bn = M.BatchNorm2d(32)
bn.eval()
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return bn(data)
data = Tensor(np.random.random((1, 32, 32, 32)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_roialign():
inp = Tensor(np.random.randn(1, 1, 128, 128))
rois = Tensor(np.random.random((4, 5)))
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, rois):
return F.nn.roi_align(inp, rois, (2, 2))
output = fwd(inp, rois)
check_pygraph_dump(fwd, [inp, rois], [output])
def test_warpperspective():
inp_shape = (1, 1, 4, 4)
x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
# M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1)
M = Tensor(
np.array(
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
).reshape(M_shape)
)
@trace(symbolic=True, capture_as_const=True)
def fwd(x, M):
return F.warp_perspective(x, M, (2, 2))
result = fwd(x, M)
check_pygraph_dump(fwd, [x, M], [result])
def test_warpaffine():
inp_shape = (1, 3, 3, 3)
x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]])
@trace(symbolic=True, capture_as_const=True)
def fwd(x, weightv):
return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP")
outp = fwd(x, weightv)
check_pygraph_dump(fwd, [x, weightv], [outp])
def test_remap():
inp_shape = (1, 1, 4, 4)
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
map_xy_shape = (1, 2, 2, 2)
map_xy = Tensor(
np.array(
[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
).reshape(map_xy_shape)
)
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, map_xy):
return F.remap(inp, map_xy)
out = fwd(inp, map_xy)
check_pygraph_dump(fwd, [inp, map_xy], [out])
def test_resize():
x = Tensor(np.random.randn(10, 3, 32, 32))
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR")
out = fwd(x)
check_pygraph_dump(fwd, [x], [out])
def test_index_onehot():
src = Tensor([[1.0, 2.0]])
index = Tensor([0])
@trace(symbolic=True, capture_as_const=True)
def fwd(src, index):
return F.indexing_one_hot(src, index)
out = fwd(src, index)
check_pygraph_dump(fwd, [src, index], [out])
def test_set_onehot():
x = Tensor(np.arange(1, 4, dtype=np.int32))
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return F.one_hot(x, num_classes=4)
out = fwd(x)
check_pygraph_dump(fwd, [x], [out])
def test_copy():
x = Tensor([1, 2, 3])
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return x.to("cpu0:0")
o = fwd(x)
check_pygraph_dump(fwd, [x], [o])
def test_argsort():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return F.argsort(data, True)
data = Tensor([1.0, 2.0, 3.0, 5.0])
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_argmax_min():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return F.argmax(data), F.argmin(data)
data = Tensor(np.random.random((10, 10)))
result = fwd(data)
check_pygraph_dump(fwd, [data], result)
def test_condtake():
mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))
@trace(symbolic=True, capture_as_const=True)
def fwd(mask, x):
v, index = F.cond_take(mask, x)
return v, index
v, index = fwd(mask, x)
check_pygraph_dump(fwd, [mask, x], [v, index])
def test_topk():
x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
top, indices = F.topk(x, 5)
return top, indices
top, indices = fwd(x)
check_pygraph_dump(fwd, [x], [top, indices])
def test_random():
@trace(symbolic=True, capture_as_const=True)
def fwd():
x = rand.uniform(size=(2, 2))
y = rand.normal(size=(1, 3, 3, 3))
return x, y
x, y = fwd()
check_pygraph_dump(fwd, [], [x, y])
def test_tensor_gen():
@trace(symbolic=True, capture_as_const=True)
def fwd():
a = F.linspace(3, 10, 3, device=Device("xpux").to_c())
b = F.eye(3, device=Device("xpux").to_c())
return a, b
a, b = fwd()
check_pygraph_dump(fwd, [], [a, b])
def test_getvarshape():
op = builtin.GetVarShape(axis=1)
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return apply(op, data)[0]
data = Tensor(np.random.random((1, 2, 3, 4)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_concat():
@trace(symbolic=True, capture_as_const=True)
def fwd(data1, data2):
return F.concat([data1, data2], axis=1)
x = Tensor(np.random.random((2, 3)))
y = Tensor(np.random.random((2, 5)))
result = fwd(x, y)
check_pygraph_dump(fwd, [x, y], [result])
def test_broadcast():
inp = Tensor([[1], [2], [3], [4]])
@trace(symbolic=True, capture_as_const=True)
def fwd(inp):
return F.broadcast_to(inp, (4, 4))
out = fwd(inp)
check_pygraph_dump(fwd, [inp], [out])
def test_identity():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return F.copy(data)
data = Tensor([1.0, 2.0])
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
@pytest.mark.skip(reason="advance indexing trace error")
def test_nms():
x = np.zeros((100, 4))
np.random.seed(42)
x[:, :2] = np.random.rand(100, 2) * 20
x[:, 2:] = np.random.rand(100, 2) * 20 + 100
scores = Tensor(np.random.rand(100))
inp = Tensor(x)
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, scores):
return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3)
result = fwd(inp, scores)
check_pygraph_dump(fwd, [inp, scores], [result])
def test_dimshuffle():
inp = Tensor([1, 2, 3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(inp):
return inp.T
out = fwd(inp)
check_pygraph_dump(fwd, [inp], [out])
def test_reshape():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
return data.reshape((1, 8))
data = Tensor(np.random.random((1, 2, 2, 2)))
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
def test_add_remove_axis():
@trace(symbolic=True, capture_as_const=True)
def fwd(data):
x = F.expand_dims(data, [0, 0])
y = F.squeeze(x, 0)
return y
data = Tensor([1.0, 2.0])
result = fwd(data)
check_pygraph_dump(fwd, [data], [result])
@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_subtensor(mode):
items = [[0, True, True, True, False], [1, False, False, False, True]]
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))]
if mode == "get":
op = builtin.Subtensor(items)
data = data[:1]
if mode == "set":
op = builtin.SetSubtensor(items)
if mode == "inc":
op = builtin.IncrSubtensor(items)
tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)]
@trace(symbolic=True, capture_as_const=True)
def fwd(*tensors):
return apply(op, *tensors)[0]
result = fwd(*data, *tensors)
check_pygraph_dump(fwd, data + tensors, [result])
@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_advance_indexing(mode):
items = [[0, False, False, False, True]]
tensors = [Tensor([0, 4, 2])]
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))]
if mode == "get":
op = builtin.IndexingMultiAxisVec(items)
data = data[:1]
if mode == "set":
op = builtin.IndexingSetMultiAxisVec(items)
if mode == "inc":
op = builtin.IndexingIncrMultiAxisVec(items)
@trace(symbolic=True, capture_as_const=True)
def fwd(*tensors):
return apply(op, *tensors)[0]
result = fwd(*data, *tensors)
check_pygraph_dump(fwd, data + tensors, [result])
@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_mesh_indexing(mode):
items = [[0, True, True, True, False], [1, False, False, False, True]]
tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]
if mode == "get":
op = builtin.IndexingMultiAxisVec(items)
data = data[:1]
if mode == "set":
op = builtin.IndexingSetMultiAxisVec(items)
if mode == "inc":
op = builtin.IndexingIncrMultiAxisVec(items)
@trace(symbolic=True, capture_as_const=True)
def fwd(*tensors):
return apply(op, *tensors)[0]
result = fwd(*data, *tensors)
check_pygraph_dump(fwd, data + tensors, [result])
@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_batch_mesh_indexing(mode):
items = [[1, False, False, False, True], [2, False, False, False, True]]
tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))]
if mode == "get":
op = builtin.BatchedMeshIndexing(items)
data = data[:1]
if mode == "set":
op = builtin.BatchedSetMeshIndexing(items)
if mode == "inc":
op = builtin.BatchedIncrMeshIndexing(items)
@trace(symbolic=True, capture_as_const=True)
def fwd(*tensors):
return apply(op, *tensors)[0]
result = fwd(*data, *tensors)
check_pygraph_dump(fwd, data + tensors, [result])
@pytest.mark.skip(reason="tmp skip")
def test_assert_equal():
g = G.Graph()
inp1 = g.make_h2d(dtype=np.float32, device="xpux")
inp2 = g.make_h2d(dtype=np.float32, device="xpux")
op = builtin.AssertEqual(maxerr=1e-5)
out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
print(out)
g.compile(out)
file = io.BytesIO()
out_model = G.dump_graph([out])
file.write(out_model[0])
file.seek(0)
net = Net.load(file)
dump_file = io.BytesIO()
net.dump(dump_file)
dump_file.seek(0)
g = GraphInference(dump_file)
g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0]))
def test_elemwise_multitype():
op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0))
@trace(symbolic=True, capture_as_const=True)
def fwd(x, y):
return apply(op, x, y)[0]
x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
result = fwd(x, y)
check_pygraph_dump(fwd, [x, y], [result])
def test_cvtcolor():
inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
x = Tensor(inp)
@trace(symbolic=True, capture_as_const=True)
def fwd(inp):
return F.img_proc.cvt_color(inp, mode="RGB2GRAY")
result = fwd(x)
check_pygraph_dump(fwd, [x], [result])
......@@ -17,9 +17,20 @@
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/standalone/nms_opr.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_gen.h"
#if MGB_ENABLE_JSON
#include "megdnn/opr_param_json.h"
#endif
......@@ -354,7 +365,7 @@ uint64_t opr_footprint_func<opr::DeformableConvForward>(
auto&& out_shape = opr->output()[0]->shape();
auto&& filter_shape = opr->input()[1]->shape();
using Param = opr::DeformableConvForward::Param;
auto&& param = opr->cast_final_safe<opr::Convolution>().param();
auto&& param = opr->cast_final_safe<opr::DeformableConvForward>().param();
size_t fh, fw, icpg;
mgb_assert(param.format == Param::Format::NCHW);
if (param.sparse == Param::Sparse::GROUP) {
......@@ -425,9 +436,11 @@ uint64_t opr_footprint_func<opr::BatchConvBiasForward>(
auto&& filter_shape = opr->input()[1]->shape();
using Param = opr::BatchConvBiasForward::Param;
auto&& param = opr->cast_final_safe<opr::BatchConvBiasForward>().param();
mgb_assert(param.format == Param::Format::NCHW4);
size_t packed_channels = 4;
size_t packed_channels = 1;
size_t kern_spatial_pos = 3;
if (param.format == Param::Format::NCHW4) {
packed_channels = 4;
}
size_t fh = filter_shape[kern_spatial_pos],
fw = filter_shape[kern_spatial_pos + 1];
return out_shape.total_nr_elems() * fh * fw * src_shape[1] *
......@@ -508,7 +521,29 @@ REGISTE_PARAM_JSON_FUNC(LocalShareBackwardFilter)
REGISTE_PARAM_JSON_FUNC(DeformableConvForward)
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter)
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData)
REGISTE_PARAM_JSON_FUNC(DeformablePSROIPoolingForward)
REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward)
REGISTE_PARAM_JSON_FUNC(BatchNormForward)
REGISTE_PARAM_JSON_FUNC(ElemwiseMultiType)
REGISTE_PARAM_JSON_FUNC(Argsort)
REGISTE_PARAM_JSON_FUNC(Argmax)
REGISTE_PARAM_JSON_FUNC(Argmin)
REGISTE_PARAM_JSON_FUNC(AdaptivePooling)
REGISTE_PARAM_JSON_FUNC(ROIPooling)
REGISTE_PARAM_JSON_FUNC(ROIAlign)
REGISTE_PARAM_JSON_FUNC(WarpPerspective)
REGISTE_PARAM_JSON_FUNC(WarpAffine)
REGISTE_PARAM_JSON_FUNC(Remap)
REGISTE_PARAM_JSON_FUNC(Resize)
REGISTE_PARAM_JSON_FUNC(IndexingOneHot)
REGISTE_PARAM_JSON_FUNC(IndexingSetOneHot)
REGISTE_PARAM_JSON_FUNC(TopK)
REGISTE_PARAM_JSON_FUNC(UniformRNG)
REGISTE_PARAM_JSON_FUNC(GaussianRNG)
REGISTE_PARAM_JSON_FUNC(Linspace)
REGISTE_PARAM_JSON_FUNC(Eye)
REGISTE_PARAM_JSON_FUNC(CvtColor)
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
......@@ -547,24 +582,83 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
});
}
std::shared_ptr<json::Value> indexing_param_to_json(
const std::vector<opr::indexing::AxisIndexer>& indices) {
auto desc = json::Array::make();
for (auto& index : indices) {
desc->add(json::Object::make({
{"axis", json::NumberInt::make(index.axis.get_raw())},
{"begin",
json::NumberInt::make(index.begin.node() != nullptr)},
{"end", json::NumberInt::make(index.end.node() != nullptr)},
{"step",
json::NumberInt::make(index.step.node() != nullptr)},
{"idx", json::NumberInt::make(index.idx.node() != nullptr)},
}));
}
return desc;
}
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \
return indexing_param_to_json(indices); \
}
REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor);
REGISTE_INDEXING_PARAM_JSON_FUNC(SetSubtensor);
REGISTE_INDEXING_PARAM_JSON_FUNC(IncrSubtensor);
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingMultiAxisVec);
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingSetMultiAxisVec);
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingIncrMultiAxisVec);
REGISTE_INDEXING_PARAM_JSON_FUNC(MeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(IncrMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(SetMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing);
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>(
std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
cg::OperatorNodeBase * opr) {
auto desc = json::Array::make();
auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc();
for (auto &index : indices){
desc->add(
json::Object::make({
{"axis", json::NumberInt::make(index.axis.get_raw())},
{"begin", json::NumberInt::make(index.begin.node() != nullptr)},
{"end", json::NumberInt::make(index.end.node() != nullptr)},
{"step", json::NumberInt::make(index.step.node() != nullptr)},
{"idx", json::NumberInt::make(index.idx.node() != nullptr)},
}));
auto axis_param = opr->cast_final_safe<opr::Reshape>().param();
if (axis_param.axis != axis_param.MAX_NDIM){
return json::Object::make({
{"axis", json::NumberInt::make(axis_param.axis)},
});
} else {
return json::Object::make();
}
}
return desc;
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
cg::OperatorNodeBase * opr) {
auto desc = json::Array::make();
auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param();
if (axis_param.axis != axis_param.MAX_NDIM){
return json::Object::make({
{"axis", json::NumberInt::make(axis_param.axis)},
});
} else {
return json::Object::make();
}
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
cg::OperatorNodeBase * opr) {
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param();
return json::Object::make({
{"iou_thresh", json::Number::make(nms_param.iou_thresh)},
{"max_output", json::Number::make(nms_param.max_output)},
});
}
#endif // MGB_ENABLE_JSON
} // namespace
......@@ -632,6 +726,17 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Dimshuffle>();
add_single_param_json<opr::AxisAddRemove>();
add_single_param_json<opr::Subtensor>();
add_single_param_json<opr::SetSubtensor>();
add_single_param_json<opr::IncrSubtensor>();
add_single_param_json<opr::IndexingMultiAxisVec>();
add_single_param_json<opr::IndexingSetMultiAxisVec>();
add_single_param_json<opr::IndexingIncrMultiAxisVec>();
add_single_param_json<opr::MeshIndexing>();
add_single_param_json<opr::SetMeshIndexing>();
add_single_param_json<opr::IncrMeshIndexing>();
add_single_param_json<opr::BatchedMeshIndexing>();
add_single_param_json<opr::BatchedSetMeshIndexing>();
add_single_param_json<opr::BatchedIncrMeshIndexing>();
add_single_param_json<opr::Reduce>();
add_single_param_json<opr::LocalShareForward>();
add_single_param_json<opr::LocalShareBackwardData>();
......@@ -639,7 +744,31 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::DeformableConvForward>();
add_single_param_json<opr::DeformableConvBackwardFilter>();
add_single_param_json<opr::DeformableConvBackwardData>();
add_single_param_json<opr::DeformablePSROIPoolingForward>();
add_single_param_json<opr::BatchConvBiasForward>();
add_single_param_json<opr::BatchNormForward>();
add_single_param_json<opr::Reshape>();
add_single_param_json<opr::GetVarShape>();
add_single_param_json<opr::Argsort>();
add_single_param_json<opr::Argmin>();
add_single_param_json<opr::Argmax>();
add_single_param_json<opr::ElemwiseMultiType>();
add_single_param_json<opr::AdaptivePooling>();
add_single_param_json<opr::ROIPooling>();
add_single_param_json<opr::ROIAlign>();
add_single_param_json<opr::WarpPerspective>();
add_single_param_json<opr::Remap>();
add_single_param_json<opr::Resize>();
add_single_param_json<opr::IndexingOneHot>();
add_single_param_json<opr::IndexingSetOneHot>();
add_single_param_json<opr::WarpAffine>();
add_single_param_json<opr::TopK>();
add_single_param_json<opr::UniformRNG>();
add_single_param_json<opr::GaussianRNG>();
add_single_param_json<opr::Linspace>();
add_single_param_json<opr::Eye>();
add_single_param_json<opr::standalone::NMSKeep>();
add_single_param_json<opr::CvtColor>();
#endif
}
......