Commit 6fb19b66 authored by Megvii Engine Team

feat(imperative/src): name operators automatically when tracing

GitOrigin-RevId: ff8eb003c5e2ee17de7d5ebd55a62391e64a48b1
Parent 09de5a07
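In short, operators traced inside a Module now get names prefixed by the module's scope, and those names survive ``trace.dump`` when ``keep_opr_name`` / ``keep_var_name`` are set. A minimal sketch of the intended usage, condensed from the tests added in this commit (the expected names such as "simple.ADD" follow those tests):

```python
import io
import numpy as np

import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine import Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace


class Simple(M.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name        # user-visible module name, used as the scope prefix

    def forward(self, x):
        return x + x


func = trace(Simple("simple"), symbolic=True, capture_as_const=True)
func(Tensor(np.ones((2, 3)))).numpy()

file = io.BytesIO()
func.dump(file, arg_names="x", keep_opr_name=True, keep_var_name=2,
          optimize_for_inference=False)
file.seek(0)

*_, outputs = G.load_graph(file)
op = cgtools.get_oprs_seq(outputs)[-1]
assert op.name == "simple.ADD"              # scope "simple" + Elemwise mode name
assert op.outputs[0].name == "simple.ADD"
```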
......@@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph):
data = data.numpy()
return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype))
def make_const(self, data, dtype=None, device=None):
def make_const(self, data, dtype=None, device=None, name=None):
if isinstance(data, _imperative_rt.DeviceTensorND):
assert dtype is None and device is None
return self._wrap(_imperative_rt.make_shared(self, data))
......@@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph):
elif data.dtype == np.int64:
data = data.astype(np.int32)
device = as_device(device).to_c()
return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
return self._wrap(
_imperative_rt.make_const(self, data, device, dtype, name)
)
def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
......@@ -305,7 +307,7 @@ def dump_graph(
output_vars: Union[Dict[str, VarNode], List[VarNode]],
*,
keep_var_name: int = 1,
keep_op_name: bool = True,
keep_opr_name: bool = False,
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
......@@ -326,7 +328,7 @@ def dump_graph(
* 0: none of the names are kept
* 1: (default) keep names of output vars
* 2: keep names of all (output and internal) vars
:param keep_op_name: whether to keep operator names.
:param keep_opr_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading the model
:param keep_opr_priority: whether to keep priority setting for operators
......@@ -370,7 +372,7 @@ def dump_graph(
dump_content = _imperative_rt.dump_graph(
ov,
keep_var_name,
keep_op_name,
keep_opr_name,
keep_param_name,
keep_opr_priority,
stat,
......
......@@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import auto_naming
from .sublinear_memory_config import SublinearMemoryConfig
......@@ -77,6 +78,7 @@ def exclude_from_trace():
class TensorInfo:
__slots__ = (
# collected attributes
"name",
"external",
"data_read",
"shape_read",
......@@ -96,6 +98,7 @@ class TensorInfo:
)
def __init__(self):
self.name = None
self.exported = None
self.data_read = None
self.shape_read = None
......@@ -290,12 +293,16 @@ class trace:
h = getattr(x, "_mixin_handle", -1)
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle()
name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name
info.name = name
info.external = True
info.device = x.device
info.dtype = x.dtype
info.shape = x.shape
if self._capture_as_const:
info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
info.bound_data = RawTensor(
x.numpy(), x.dtype, x.device, False, name
)
ihandles.append(h)
......@@ -669,6 +676,12 @@ class trace:
arg_names=None,
output_names=None,
append=False,
keep_var_name: int = 1,
keep_opr_name: bool = False,
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False,
optimize_for_inference=True,
**kwargs
):
......@@ -681,6 +694,20 @@ class trace:
use the default name if not specified.
:param append: whether output is appended to ``file``.
Only works when ``file`` is str.
:param keep_var_name: level for keeping variable names:
* 0: none of the names are kept
* 1: (default) keep names of output vars
* 2: keep names of all (output and internal) vars
:param keep_opr_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading the model
:param keep_opr_priority: whether to keep priority setting for operators
:param strip_info_file: a path string or a file handler. If not None,
the dump information for code strip will be written to ``strip_info_file``
:param append_json: only takes effect when ``strip_info_file`` is not None. If True,
the code-strip information is appended to ``strip_info_file``;
if False, ``strip_info_file`` is overwritten
:param optimize_for_inference: enable optimizations;
all optimize options are skipped if this is False. Default: True
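For reference, a hedged sketch of passing the new keywords together; ``net`` stands for any traced-and-executed callable (as in the example near the top of this diff) and ``strip_info.json`` is a placeholder path:

```python
import io

file = io.BytesIO()
dump_info = net.dump(
    file,
    arg_names="x",
    keep_var_name=2,                    # keep names of all vars, not only outputs
    keep_opr_name=True,                 # keep operator names, e.g. "simple.ADD"
    keep_param_name=True,               # keep param names for later manipulation
    keep_opr_priority=False,
    strip_info_file="strip_info.json",  # write code-strip info to this file
    append_json=False,                  # overwrite rather than append
    optimize_for_inference=False,
)
```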
......@@ -785,7 +812,10 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=info.device,
info.bound_data.numpy(),
dtype=info.dtype,
device=info.device,
name=info.name,
)
continue
ivars = []
......@@ -795,13 +825,26 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
info.bound_data.numpy(),
dtype=info.dtype,
device=dumped_device,
name=info.name,
)
ivars.append(h2v[h])
ovars = G.apply_normal_varnode(op, *ivars)
auto_naming.record_opnode(ovars[0].op)
assert len(ovars) == len(ohandles)
h2v.update(zip(ohandles, ovars))
for i in ohandles:
name = auto_naming.get_var_name(i)
if name is not None:
h2v[i].name = name
auto_naming.remove_duplicate_names()
dest_vars = []
for i, h in enumerate(self._output_bindings):
v = h2v[h]
......@@ -815,7 +858,15 @@ class trace:
if isinstance(file, str):
permission = "wb" if append == False else "ab"
file = open(file, permission)
dump_content, dump_info = G.dump_graph(dest_vars)
dump_content, dump_info = G.dump_graph(
dest_vars,
keep_var_name=keep_var_name,
keep_opr_name=keep_opr_name,
keep_param_name=keep_param_name,
keep_opr_priority=keep_opr_priority,
strip_info_file=strip_info_file,
append_json=append_json,
)
file.write(dump_content)
return dump_info
......@@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
return active_trace._apply_op(op, args)
def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
if skip_tracing:
args = [
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
for x in args
]
unset_tracing()
ret = RawTensor(value, dtype, device, False)
ret = RawTensor(value, dtype, device, False, name)
set_tracing()
return ret
return active_trace._apply_const(value, dtype, device)
def apply_with_tracing(op: OpDef, *args: RawTensor):
if hasattr(op, "scope"):
op.scope = auto_naming.get_scope()
if active_trace._symbolic:
outputs = apply_symbolic_mode(op, *args)
else:
......@@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
return list(outputs)
def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
if active_trace._symbolic:
outputs = apply_const_symbolic_mode(value, dtype, device)
else:
unset_tracing()
outputs = (RawTensor(value, dtype, device, False),)
outputs = (RawTensor(value, dtype, device, False, name),)
set_tracing()
active_trace._record_const(outputs)
return list(outputs)
......@@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
import numpy as np
from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
from ..utils.hook import HookHandler
from ..utils.naming import auto_naming
logger = get_logger(__name__)
......@@ -69,7 +69,9 @@ class Module(metaclass=ABCMeta):
Base Module class.
"""
def __init__(self):
def __init__(self, name=""):
self.name = name
# runtime attributes
self.training = True
self.quantize_disabled = False
......@@ -79,6 +81,8 @@ class Module(metaclass=ABCMeta):
self._forward_hooks = OrderedDict()
self._modules = []
# used for profiler and automatic naming
self._name = "{anonymous}"
@abstractmethod
......@@ -105,7 +109,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook)
def __call__(self, *inputs, **kwargs):
push_scope(self._name)
auto_naming.push_scope(self.name if self.name else self._name)
for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs)
if modified_inputs is not None:
......@@ -119,7 +123,7 @@ class Module(metaclass=ABCMeta):
modified_outputs = hook(self, inputs, outputs)
if modified_outputs is not None:
outputs = modified_outputs
pop_scope(self._name)
auto_naming.pop_scope()
return outputs
def _flatten(
......@@ -579,7 +583,7 @@ class Module(metaclass=ABCMeta):
value = super().__getattribute__(name)
if name == "_name":
return value
if _is_module(value):
if isinstance(value, (Tensor, Module)):
value._name = name
return value
......
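The new ``name`` argument on ``Module`` also propagates through nested modules, so an explicitly named submodule contributes its own scope segment. A short sketch based on ``test_named_submodule`` from this commit, reusing the trace/dump/reload flow shown earlier:

```python
import megengine.module as M


class Simple(M.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.linear = M.Linear(3, 3, name="x")  # explicitly named submodule

    def forward(self, x):
        return self.linear(x)


# After tracing, dumping with keep_opr_name=True / keep_var_name=2 and reloading,
# the last operator is expected to be "simple.x.ADD", with its MatrixMul input
# named "simple.x.MatrixMul" (see test_named_submodule below).
```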
......@@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import auto_naming
class Tensor(_Tensor, ArrayMethodMixin):
......@@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
dmap_callback = None
_q_dict = None
def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
def __new__(
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=""
):
if device is None:
cn = get_default_device()
elif isinstance(device, str):
......@@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
if isinstance(data, np.ndarray):
if 0 in data.strides:
data = data.squeeze().reshape(data.shape)
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache)
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
return obj
@property
......@@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
piece += ", device={}".format(self.device) + ")"
return piece
@property
def name(self):
return self.c_name
@name.setter
def name(self, name):
self.c_name = name
auto_naming.record_var_name(self._mixin_handle, name)
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value):
if not isinstance(value, _Tensor):
......
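With the new ``name`` property backed by ``c_name``, a tensor renamed inside ``forward`` keeps that name in the dumped graph, and a ``Parameter`` created with ``name=...`` is prefixed by the module scope. A sketch condensed from ``test_user_named_tensor`` and ``test_user_named_param`` (the "o_y" output name is an assumed analogue of "o_x" in the test):

```python
import megengine.module as M
from megengine import Parameter


class Simple(M.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.k = Parameter(2.0, name="k")  # named parameter -> dumped var "simple.k"

    def forward(self, x):
        y = self.k * x
        y.name = "o_y"                     # user-assigned name for the output var
        return y


# After trace + dump (keep_var_name=2) + reload, the MUL operator's inputs are
# expected to be named "x" and "simple.k", and its output "o_y".
```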
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import pop_scope, push_scope
class AutoNaming:
r"""
Name all executed operators automatically during tracing and record all tensors
renamed by the user.
"""
def __init__(self):
self.scopes = []
self.c_ops = []
self.name2ops = {}
self.handle2names = {}
def clear(self):
for var in vars(self).values():
var.clear()
def push_scope(self, scope):
push_scope(scope)
self.scopes.append(scope)
def pop_scope(self):
scope = self.scopes.pop()
pop_scope(scope)
def get_scope(self):
return ".".join(self.scopes)
def record_var_name(self, handle, name):
self.handle2names[handle] = name
def get_var_name(self, handle):
return self.handle2names.pop(handle, None)
def record_opnode(self, op):
ops = self.name2ops.get(op.name, [])
ops.append(op)
self.name2ops[op.name] = ops
def remove_duplicate_names(self):
for key, ops in self.name2ops.items():
if len(ops) == 1:
continue
for i, op in enumerate(ops):
op.name = key + "[%s]" % str(i)
if len(op.outputs) == 1:
continue
for var in op.outputs:
var.name = var.name.replace(key, op.name)
self.name2ops.clear()
auto_naming = AutoNaming()
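For reference, a tiny sketch of the scope stack on its own; ``Module.__call__`` does the push/pop automatically during forward, so this only illustrates how ``get_scope`` joins nested names:

```python
from megengine.utils.naming import auto_naming

auto_naming.push_scope("simple")
auto_naming.push_scope("linear")
assert auto_naming.get_scope() == "simple.linear"  # prefix attached to op names
auto_naming.pop_scope()
auto_naming.pop_scope()
assert auto_naming.get_scope() == ""
```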
......@@ -294,7 +294,7 @@ void init_graph_rt(py::module m) {
m.def("dump_graph", [](
const std::vector<VarNode*>& dest_vars,
int keep_var_name,
bool keep_op_name,
bool keep_opr_name,
bool keep_param_name,
bool keep_opr_priority,
py::list& stat,
......@@ -307,7 +307,7 @@ void init_graph_rt(py::module m) {
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
keep_opr_priority, keep_op_name};
keep_opr_priority, keep_opr_name};
auto rst = dumper->dump(symvars, config);
for (auto i : rst.inputs) {
......@@ -457,13 +457,17 @@ void init_graph_rt(py::module m) {
return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
});
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) {
if (!cn.valid()) {
cn = CompNode::load(get_default_device());
}
OperatorNodeConfig config(cn);
if (name) {
config.name(*name);
}
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
});
return opr::ImmutableTensor::make(*graph, hv, config).node();
}, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());
m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
if (!cn.valid()) {
......
......@@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T>
PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
// T: PyOpXXX inst(): return XXX in opdef.h.inl
auto& op = reinterpret_cast<T*>(obj)->inst();
return pyobj_convert_generic<std::string>::to(op.scope());
}
#define py_get_scope(class) py_get_scope_impl<PyOp(class)>
template<typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
......@@ -121,6 +129,27 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
#define py_set_generic(name, attr) \
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T>
int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.set_scope(pyobj_convert_generic<std::string>::from(value));
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
}
#define py_set_scope(class) py_set_scope_impl<PyOp(class)>
struct PyOpDef {
PyObject_HEAD
std::shared_ptr<OpDef> op;
......
......@@ -24,6 +24,7 @@
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <range/v3/all.hpp>
#include <string>
#include <unordered_map>
......@@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
} else {
py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
if (nargs != 4 && nargs != 5) {
throw py::type_error("expect 4 or 5 arguments");
if (nargs != 5 && nargs != 6) {
throw py::type_error("expect 5 or 6 arguments");
}
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
bool is_const = tup[3].cast<bool>();
bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;
bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
std::string name = tup[nargs - 1].cast<std::string>();
// const op
if (is_const && is_tracing) {
......@@ -259,6 +261,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
m_tensor = std::make_shared<Tensor>(handle);
m_tensor->user_custom_name = name;
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
......@@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)
#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
#define SET_GET_NAME(member) \
PyObject* TensorWrapper::member() { \
return py::cast(m_tensor->member).release().ptr(); \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
auto py_dest = py::reinterpret_borrow<py::object>(dest); \
m_tensor->member = py_dest.cast<std::string>(); \
}
SET_GET_NAME(user_custom_name)
SET_GET_NAME(automatic_name)
#undef SET_GET_NAME
PyObject* TensorWrapper::handle() {
return py::cast(m_tensor->m_handle).release().ptr();
}
......@@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) {
if (!t) {
throw py::type_error("expect Tensor");
}
std::string user_custom_name = m_tensor->user_custom_name;
std::string automatic_name = m_tensor->automatic_name;
m_tensor = t->m_tensor;
m_tensor->user_custom_name = user_custom_name;
m_tensor->automatic_name = automatic_name;
}
void TensorWrapper::reset_varnode() {
......@@ -785,6 +805,8 @@ void init_tensor(py::module m) {
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
.def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
.def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
......
......@@ -15,6 +15,7 @@
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include <string>
#include "./pyext17.h"
......@@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
GradInfo m_grad_info;
TraceInfo m_trace_info;
SharedHandle m_handle;
std::string user_custom_name;
std::string automatic_name;
cg::VarNode* m_var;
using Handle = interpreter::Interpreter::Handle;
......@@ -170,6 +173,10 @@ struct TensorWrapper {
void set_compiled_info(PyObject *);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *);
PyObject* user_custom_name();
void set_user_custom_name(PyObject *);
PyObject* automatic_name();
void set_automatic_name(PyObject *);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
};
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import numpy as np
import pytest
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace
from megengine.utils.naming import auto_naming
def _dump_and_load(func, symbolic, keep_opr_name=True):
auto_naming.clear()
func = trace(func, symbolic=symbolic, capture_as_const=True)
x = Tensor(np.ones(shape=(2, 3)))
func(x).numpy()
file = io.BytesIO()
func.dump(
file,
optimize_for_inference=False,
arg_names="x",
keep_opr_name=keep_opr_name,
keep_var_name=2,
)
file.seek(0)
*_, outputs = G.load_graph(file)
op = cgtools.get_oprs_seq(outputs)[-1]
return op
@pytest.mark.parametrize("symbolic", [False, True])
def test_auto_naming(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
def forward(self, x):
return x + x
m = Simple("simple")
op = _dump_and_load(m, symbolic)
assert op.name == "simple.ADD"
assert op.outputs[0].name == "simple.ADD"
@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_tensor(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.k = Parameter(1.0, name="k")
def forward(self, x):
x = x + x
x.name = "o_x"
return x
m = Simple("simple")
op = _dump_and_load(m, symbolic)
assert op.name == "simple.ADD"
assert op.outputs[0].name == "o_x"
@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_param(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.k = Parameter(2.0, name="k")
def forward(self, x):
return self.k * x
m = Simple("simple")
op = _dump_and_load(m, symbolic)
assert op.inputs[0].name == "x"
assert op.inputs[1].name == "simple.k"
@pytest.mark.parametrize("symbolic", [False, True])
def test_without_module(symbolic):
def f(x):
return 2 * x
op = _dump_and_load(f, symbolic)
assert op.name == "MUL"
@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.linear = M.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
m = Simple("simple")
op = _dump_and_load(m, symbolic)
assert op.name == "simple.linear.ADD"
assert op.inputs[0].owner.name == "simple.linear.MatrixMul"
assert op.outputs[0].name == "simple.linear.ADD"
@pytest.mark.parametrize("symbolic", [False, True])
def test_named_submodule(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.linear = M.Linear(3, 3, name="x")
def forward(self, x):
x = self.linear(x)
return x
m = Simple("simple")
op = _dump_and_load(m, symbolic)
assert op.name == "simple.x.ADD"
assert op.inputs[0].owner.name == "simple.x.MatrixMul"
assert op.outputs[0].name == "simple.x.ADD"
@pytest.mark.parametrize("symbolic", [False, True])
def test_with_same_operators(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
def forward(self, x):
x = F.relu(x)
x = F.relu(x)
return x
m = Simple("simple")
op = _dump_and_load(m, symbolic)
assert op.name == "simple.RELU[1]"
assert op.inputs[0].owner.name == "simple.RELU[0]"
def test_not_keep_opr_name():
def f(x):
return 2 * x
op = _dump_and_load(f, True, False)
assert op.name == "MUL(x,2[2])[4]"
......@@ -148,7 +148,7 @@ def test_dump():
dump_info = f.dump(file)
assert dump_info.nr_opr == 3
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
np.testing.assert_equal(dump_info.outputs, ["ADD"])
file.seek(0)
infer_cg = cgtools.GraphInference(file)
result = list((infer_cg.run(a, b)).values())[0]
......
......@@ -75,10 +75,6 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
return def.trait()->props(def);
}
const char* OpDef::name() const {
return trait()->name;
}
std::string OpDef::to_string() const {
std::string builder = "{";
for (auto&& [name, value]: props(*this)) {
......@@ -107,6 +103,20 @@ const OpTrait* OpDef::trait() const {
return m_trait;
}
const std::string OpDef::scope() const {
return m_scope;
}
void OpDef::set_scope(const std::string& scope) {
m_scope = scope;
}
const std::string OpDef::make_name() const {
if (m_scope.empty())
return trait()->make_name(*this);
return m_scope + "." + trait()->make_name(*this);
}
} // namespace imperative
} // namespace mgb
......
......@@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth<
using Props = detail::OpMeth<decltype(OpDef::props)>;
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;
struct OpTrait {
const char* name;
......@@ -88,6 +89,7 @@ struct OpTrait {
Props props;
HashFunc hash;
IsSame is_same_st;
MakeNameFunc make_name;
OpTrait(const char* name);
static OpTrait* find_by_name(const char* name);
static OpTrait* find_by_typeinfo(Typeinfo* type);
......@@ -104,7 +106,8 @@ struct OpTrait {
cb(make_backward_graph) \
cb(props) \
cb(hash) \
cb(is_same_st)
cb(is_same_st) \
cb(make_name)
struct OpTraitRegistry {
OpTrait* trait;
......
......@@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node(
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 3 ||nr_inp == 5,
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
OperatorNodeConfig config{bn_opr.make_name()};
if (nr_inp == 3) {
return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2], bn_opr.param())[0]
inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0]
.node()->owner_opr();
} else {
return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0]
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0]
.node()->owner_opr();
}
}
......
......@@ -27,10 +27,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
def.cast_final_safe<Broadcast>();
auto&& op = def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
OperatorNodeConfig config{op.make_name()};
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
}
bool valid_broadcast(const TensorShape& src_shape,
......@@ -96,7 +97,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const Reshape&>(def);
mgb_assert(inputs.size() == 2);
return opr::Reshape::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......
......@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node(
auto disable = std::make_shared<DTypeScalar>();
disable->set(0);
cg::OperatorNodeConfig config;
OperatorNodeConfig config{comm.make_name()};
if (comm.comp_node.size() > 0) {
config.comp_node(CompNode::load(comm.comp_node));
}
......
......@@ -23,12 +23,12 @@ namespace {
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
def.cast_final_safe<CondTake>();
auto&& op = def.cast_final_safe<CondTake>();
auto&& graph = inputs[0]->owner_graph();
opr::CondTake::Param param;
param.val = 1;
cg::OperatorNodeConfig config;
OperatorNodeConfig config{op.make_name()};
cg::OperatorNodeBase* opr = graph->insert_opr(
std::make_unique<opr::CondTake>(
inputs[0], inputs[1], param, config));
......
......@@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
OperatorNodeConfig config{elemwise_opr.make_name()};
return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr();
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......
......@@ -23,7 +23,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const CvtColor&>(def);
mgb_assert(inputs.size() == 1);
return opr::CvtColor::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::CvtColor::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(CvtColor, CvtColor)
.apply_on_var_node(apply_on_var_node)
......
......@@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
ssprintf("%s:%d", send.addr.data(), send.port));
auto&& graph = inputs[0]->owner_graph();
cg::OperatorNodeConfig config;
OperatorNodeConfig config{send.make_name()};
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
send.key, inputs[0], group_client, true, config));
......@@ -42,11 +42,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
cg::OperatorNodeBase* apply_on_var_node_remote_recv(
const OpDef& def, const VarNodeArray& inputs) {
auto&& recv = def.cast_final_safe<RemoteRecv>();
OperatorNodeConfig config{recv.cn};
config.name(recv.make_name());
auto group_client = std::make_shared<GroupClientProxy>(
ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn},
recv.key, inputs[0], *graph, group_client, config,
recv.shape, recv.dtype));
}
......
......@@ -21,8 +21,10 @@ namespace {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<MatrixInverse>();
mgb_assert(inputs.size() == 1);
return opr::MatrixInverse::make(inputs[0]);
OperatorNodeConfig config{op.make_name()};
return opr::MatrixInverse::make(inputs[0], {}, config);
}
OP_TRAIT_REG(MatrixInverse, MatrixInverse)
.apply_on_var_node(apply_on_var_node)
......
......@@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node(
param.iou_thresh = nms_keep.iou_thresh;
param.max_output = nms_keep.max_output;
return NMSKeepOpr::make(inputs[0], param).node()->owner_opr();
OperatorNodeConfig config{nms_keep.make_name()};
return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr();
}
OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr)
......
......@@ -79,11 +79,13 @@ public:
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
auto&& attr = def.cast_final_safe<OprAttr>();
auto config = attr.config;
config.name(attr.make_name());
mgb_assert(!inputs.empty());
auto registry = serialization::OprRegistry::find_by_name(attr.type);
mgb_assert(registry, "operator %s not found", attr.type.c_str());
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
return registry->loader(ctx, inputs, attr.config);
return registry->loader(ctx, inputs, config);
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
......@@ -99,10 +101,15 @@ std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
return {};
}
std::string make_name(const OpDef& def) {
return "OprAttr";
}
OP_TRAIT_REG(OprAttr, OprAttr)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.props(props)
.make_name(make_name)
.fallback();
} // anonymous namespace
......
......@@ -24,7 +24,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const Resize&>(def);
mgb_assert(inputs.size() == 2);
return opr::Resize::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::Resize::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Resize, Resize)
......
......@@ -46,7 +46,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution&>(def);
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
OperatorNodeConfig config{conv.make_name()};
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
......@@ -60,7 +61,7 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
cg::OperatorNodeConfig config;
OperatorNodeConfig config{conv.make_name()};
if (inputs.size() == 2) {
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else {
......@@ -88,7 +89,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& ds = static_cast<const Dimshuffle&>(def);
return opr::Dimshuffle::make(inputs[0], ds.pattern);
OperatorNodeConfig config{ds.make_name()};
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
......@@ -107,7 +109,8 @@ auto apply_on_var_node(
for (auto&& i : add_axis.axis) {
param.push_back(Desc::make_add(i));
}
return opr::AxisAddRemove::make(inputs[0], param);
OperatorNodeConfig config{add_axis.make_name()};
return opr::AxisAddRemove::make(inputs[0], param, config);
}
OP_TRAIT_REG(AddAxis, AddAxis)
......@@ -125,7 +128,8 @@ auto apply_on_var_node(
for (auto&& i : remove_axis.axis) {
param.push_back(Desc::make_remove(i));
}
return opr::AxisAddRemove::make(inputs[0], param);
OperatorNodeConfig config{remove_axis.make_name()};
return opr::AxisAddRemove::make(inputs[0], param, config);
}
OP_TRAIT_REG(RemoveAxis, RemoveAxis)
......@@ -138,7 +142,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& topk = static_cast<const TopK&>(def);
return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
OperatorNodeConfig config{topk.make_name()};
return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
.node()->owner_opr();
}
......@@ -152,10 +157,12 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
OperatorNodeConfig config{reduce.make_name()};
if (inputs.size() > 1) {
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
} else {
return opr::Reduce::make(inputs[0], reduce.param());
return opr::Reduce::make(
inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
}
}
......@@ -175,7 +182,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& pool = static_cast<const AdaptivePooling&>(def);
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
OperatorNodeConfig config{pool.make_name()};
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
}
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
......@@ -189,6 +197,7 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvBias&>(def);
cg::OperatorNodeConfig config{conv.dtype};
config.name(conv.make_name());
if (inputs.size() == 2) {
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else if (inputs.size() == 3) {
......@@ -210,6 +219,7 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& conv = static_cast<const BatchConvBias&>(def);
cg::OperatorNodeConfig config{conv.dtype};
config.name(conv.make_name());
if (inputs.size() == 2) {
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else if (inputs.size() == 3) {
......@@ -230,7 +240,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
return opr::Pooling::make(inputs[0], pool.param());
OperatorNodeConfig config{pool.make_name()};
return opr::Pooling::make(inputs[0], pool.param(), config);
}
OP_TRAIT_REG(Pooling, Pooling)
.apply_on_var_node(apply_on_var_node)
......@@ -243,8 +254,9 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
matmul.policy(), config);
}
OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node)
......@@ -257,8 +269,9 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
matmul.policy(), config);
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
......@@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
namespace { namespace dot {
auto apply_on_var_node(
const OpDef&,
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Dot>();
mgb_assert(inputs.size() == 2);
return opr::Dot::make(inputs[0], inputs[1]);
OperatorNodeConfig config{op.make_name()};
return opr::Dot::make(inputs[0], inputs[1], config);
}
OP_TRAIT_REG(Dot, Dot)
.apply_on_var_node(apply_on_var_node)
......@@ -282,7 +297,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argsort = static_cast<const Argsort&>(def);
return opr::Argsort::make(inputs[0], argsort.param());
OperatorNodeConfig config{argsort.make_name()};
return opr::Argsort::make(inputs[0], argsort.param(), config);
}
OP_TRAIT_REG(Argsort, Argsort)
.apply_on_var_node(apply_on_var_node)
......@@ -294,7 +310,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argmax = static_cast<const Argmax&>(def);
return opr::Argmax::make(inputs[0], argmax.param());
OperatorNodeConfig config{argmax.make_name()};
return opr::Argmax::make(inputs[0], argmax.param(), config);
}
OP_TRAIT_REG(Argmax, Argmax)
.apply_on_var_node(apply_on_var_node)
......@@ -306,7 +323,8 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argmin = static_cast<const Argmin&>(def);
return opr::Argmin::make(inputs[0], argmin.param());
OperatorNodeConfig config{argmin.make_name()};
return opr::Argmin::make(inputs[0], argmin.param(), config);
}
OP_TRAIT_REG(Argmin, Argmin)
.apply_on_var_node(apply_on_var_node)
......@@ -318,11 +336,13 @@ auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& warp = static_cast<const WarpPerspective&>(def);
OperatorNodeConfig config{warp.make_name()};
if (inputs.size() == 3) {
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param());
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config);
} else {
mgb_assert(inputs.size() == 4);
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param());
return opr::WarpPerspective::make(
inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
}
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
......@@ -336,7 +356,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& local = static_cast<const GroupLocal&>(def);
mgb_assert(inputs.size() == 2);
return opr::GroupLocal::make(inputs[0], inputs[1], local.param());
OperatorNodeConfig config{local.make_name()};
return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
}
OP_TRAIT_REG(GroupLocal, GroupLocal)
.apply_on_var_node(apply_on_var_node)
......@@ -349,7 +370,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
mgb_assert(inputs.size() == 2);
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.apply_on_var_node(apply_on_var_node)
......@@ -362,7 +384,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.apply_on_var_node(apply_on_var_node)
......@@ -375,7 +398,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const TypeCvt&>(def);
mgb_assert(inputs.size() == 1);
return opr::TypeCvt::make(inputs[0], op.dtype);
OperatorNodeConfig config{op.make_name()};
return opr::TypeCvt::make(inputs[0], op.dtype, config);
}
OP_TRAIT_REG(TypeCvt, TypeCvt)
.apply_on_var_node(apply_on_var_node)
......@@ -388,6 +412,7 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const Concat&>(def);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Concat::make(inputs, op.axis, config);
}
OP_TRAIT_REG(Concat, Concat)
......@@ -402,6 +427,7 @@ auto apply_on_var_node(
auto&& op = static_cast<const Copy&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Copy::make(inputs[0], config);
}
OP_TRAIT_REG(Copy, Copy)
......@@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy)
namespace { namespace identity {
auto apply_on_var_node(
const OpDef&,
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Identity>();
mgb_assert(inputs.size() == 1);
return opr::Identity::make(inputs[0]);
OperatorNodeConfig config{op.make_name()};
return opr::Identity::make(inputs[0], config);
}
OP_TRAIT_REG(Identity, Identity)
.apply_on_var_node(apply_on_var_node)
......@@ -427,7 +455,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const AssertEqual&>(def);
mgb_assert(inputs.size() == 2);
return opr::AssertEqual::make(inputs[0],inputs[1],op.param());
OperatorNodeConfig config{op.make_name()};
return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);
}
......@@ -443,7 +472,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const UniformRNG&>(def);
mgb_assert(inputs.size() == 1);
return opr::UniformRNG::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::UniformRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(UniformRNG, UniformRNG)
.apply_on_var_node(apply_on_var_node)
......@@ -456,7 +486,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const GaussianRNG&>(def);
mgb_assert(inputs.size() == 1);
return opr::GaussianRNG::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::GaussianRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(GaussianRNG, GaussianRNG)
.apply_on_var_node(apply_on_var_node)
......@@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIAlign&>(def);
mgb_assert(inputs.size() == 2);
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
OperatorNodeConfig config{op.make_name()};
auto* opr = opr::ROIAlign::make(
inputs[0], inputs[1], op.param(), config).node()->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIAlign, ROIAlign)
......@@ -484,7 +517,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const NvOf&>(def);
mgb_assert(inputs.size() == 1);
return opr::NvOf::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::NvOf::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(NvOf, NvOf)
.apply_on_var_node(apply_on_var_node)
......@@ -499,6 +533,7 @@ auto apply_on_var_node(
auto&& op = static_cast<const Linspace&>(def);
mgb_assert(inputs.size() == 3);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(Linspace, Linspace)
......@@ -513,6 +548,7 @@ auto apply_on_var_node(
auto&& op = static_cast<const Eye&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
opr::Eye::Param param{op.k, op.dtype.enumv()};
return opr::Eye::make(inputs[0], param, config);
}
......@@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIPooling&>(def);
mgb_assert(inputs.size() == 3);
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
OperatorNodeConfig config{op.make_name()};
auto* opr = opr::ROIPooling::make(
inputs[0], inputs[1], inputs[2], op.param(), config
).node()->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling)
......@@ -541,7 +580,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const Remap&>(def);
mgb_assert(inputs.size() == 2);
return opr::Remap::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Remap, Remap)
.apply_on_var_node(apply_on_var_node)
......@@ -578,7 +618,8 @@ auto apply_on_var_node( \
const OpDef& def, \
const VarNodeArray& inputs) { \
auto&& op = static_cast<const NAME&>(def); \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
OperatorNodeConfig config{op.make_name()}; \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
} \
OP_TRAIT_REG(NAME, NAME) \
.apply_on_var_node(apply_on_var_node) \
......@@ -609,30 +650,35 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const FakeQuant&>(def);
mgb_assert(inputs.size() == 3);
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(FakeQuant, FakeQuant)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // fake_quant
namespace { namespace tqt {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const TQT&>(def);
mgb_assert(inputs.size() == 2);
return opr::TQT::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(TQT, TQT)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // tqt
namespace { namespace elemwise_multi_type {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ElemwiseMultiType&>(def);
OperatorNodeConfig config{op.dtype};
config.name(op.make_name());
return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
......@@ -646,7 +692,9 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const SVD&>(def);
mgb_assert(inputs.size() == 1);
return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
OperatorNodeConfig config{op.make_name()};
return opr::SVD::make(inputs[0], op.param(), config)[0]
.node()->owner_opr()->usable_output();
}
OP_TRAIT_REG(SVD, SVD)
.apply_on_var_node(apply_on_var_node)
......
......@@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op_def = def.cast_final_safe<GetVarShape>();
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
OperatorNodeConfig config{op_def.make_name()};
return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}
DispatchMode decide_dispatch_mode(
......@@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
auto&& graph = inputs[0]->owner_graph();
auto&& shapes = get_shapes(param.shapes);
cg::OperatorNodeConfig config;
OperatorNodeConfig config(param.make_name());
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
inputs[0], param.offsets, shapes, config));
......@@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
auto&& graph = inputs[0]->owner_graph();
VarNodeArray inps(inputs.begin(), inputs.end() - 1);
cg::OperatorNodeConfig config;
OperatorNodeConfig config{param.make_name()};
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
inps, inputs.back(), param.offsets, config));
......
......@@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime {
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const TensorRTRuntime&>(def);
OperatorNodeConfig config{op.make_name()};
SymbolVarArray sinputs(inputs.begin(), inputs.end());
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config);
}
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
.apply_on_var_node(apply_on_var_node)
......
......@@ -21,7 +21,8 @@ namespace { namespace warp_affine {
const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 3);
auto&& op = static_cast<const WarpAffine&>(def);
return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(WarpAffine, WarpAffine)
......
......@@ -36,6 +36,7 @@ class OpDef : public Hashable,
public NonCopyableObj,
public std::enable_shared_from_this<OpDef> {
mutable const OpTrait* m_trait = nullptr;
std::string m_scope;
public:
virtual ~OpDef() = default;
......@@ -86,10 +87,14 @@ public:
const OpTrait* trait() const;
const char* name() const;
std::string to_string() const;
const std::string scope() const;
const std::string make_name() const;
void set_scope(const std::string& scope);
virtual size_t hash() const;
virtual bool is_same_st(const Hashable&) const;
......
......@@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
"{0}({0}_)", i.name
));
}
paramList.push_back("std::string scope_ = {}");
gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "),
" {}");
" { set_scope(scope_); }");
}
auto packedParams = op.getPackedParams();
......@@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";
// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
os << "}\n";
os << "} // anonymous namespace\n";
methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
methods.push_back("make_name");
}
if (!methods.empty()) {
os << formatv(
......@@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
targs.push_back(i.attr.getReturnType());
}
os << llvm::join(targs, ", ");
os << ">()";
os << ", std::string>()";
for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue();
......@@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
hasDefaultCtor = true;
}
}
os << ")";
os << ", py::arg(\"scope\") = {})";
}
if (hasDefaultCtor) {
os << "\n .def(py::init<>())";
......@@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{
className, i.name));
}
getsetters.push_back(formatv(
"{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
className));
// generate tp_init
std::string initBody;
if (!op.getMgbAttributes().empty()) {
......@@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr.name);
});
initBody += "\"scope\", ";
initBody += "NULL};\n";
initBody += " PyObject ";
std::vector<std::string> attrs;
......@@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{
attrs.push_back(formatv("*{0} = NULL", attr.name));
});
initBody += llvm::join(attrs, ", ") + ";\n";
initBody += " PyObject *scope = NULL;\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
initBody += std::string(op.getMgbAttributes().size(), 'O');
// an extra slot created for name
initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(" ,&{0}", attr.name);
initBody += formatv(", &{0}", attr.name);
});
initBody += ", &scope";
initBody += "))\n";
initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
......@@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{
}
)", className, attr.name);
});
initBody += formatv(R"(
if (scope) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
pyobj_convert_generic<std::string>::from(scope));
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
}
)", className);
}
initBody += "\n return 0;";
......
......@@ -241,6 +241,30 @@ private:
body += " return props_;\n";
return body;
}
std::string getModeName() const {
std::string body = formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
getCppClassName()
);
for (auto&& it : getMgbAttributes()) {
if (it.name == "mode") {
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr);
body += " switch (op_.mode){\n";
for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(" return \"{0}\";\n", enumMember);
}
body += formatv(
" default: return \"{0}::Unknown\";\n", getCppClassName());
body += " }\n";
}
}
return body;
}
public:
static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbHashableOpMixin");
......@@ -264,6 +288,12 @@ public:
}
return getDefaultPropsFunction();
}
std::string getNameFunctionTemplate() const {
if (getDef().getValueAsBit("usingModeName")) {
return getModeName();
}
return formatv(" return \"{0}\";\n", getCppClassName());
}
};
} // namespace tblgen
......
......@@ -476,6 +476,7 @@ def main():
output_mgbvars = feeds["outputs"]
output_mgbvars = optimize_for_inference(args, output_mgbvars)
output_mgbvars = [var._node for var in output_mgbvars]
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
inputs = sorted((i.name, i.dtype) for i in inputs)
......
......@@ -242,6 +242,7 @@ class MgbPackedParamBase<string className, string accessor>:
class MgbHashableOpMixin {
string hashFunction = ?;
string cmpFunction = ?;
bit usingModeName = 0;
}
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
......
......@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType);
let usingModeName = 1;
}
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
......@@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
let usingModeName = 1;
}
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
......