diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 73eb4b984972c29d4341e4626fecb3aae45ee852..afd916113b41fa48a6e2cae854c2bf2c2b0ae3c6 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph): data = data.numpy() return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype)) - def make_const(self, data, dtype=None, device=None): + def make_const(self, data, dtype=None, device=None, name=None): if isinstance(data, _imperative_rt.DeviceTensorND): assert dtype is None and device is None return self._wrap(_imperative_rt.make_shared(self, data)) @@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph): elif data.dtype == np.int64: data = data.astype(np.int32) device = as_device(device).to_c() - return self._wrap(_imperative_rt.make_const(self, data, device, dtype)) + return self._wrap( + _imperative_rt.make_const(self, data, device, dtype, name) + ) def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None): opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self) @@ -305,7 +307,7 @@ def dump_graph( output_vars: Union[Dict[str, VarNode], List[VarNode]], *, keep_var_name: int = 1, - keep_op_name: bool = True, + keep_opr_name: bool = False, keep_param_name: bool = False, keep_opr_priority: bool = False, strip_info_file=None, @@ -326,7 +328,7 @@ def dump_graph( * 0: none of the names are kept * 1: (default)keep names of output vars * 2: keep names of all (output and internal) vars - :param keep_op_name: whether to keep operator names. + :param keep_opr_name: whether to keep operator names. :param keep_param_name: whether to keep param names, so param values can be easily manipulated after loading model :param keep_opr_priority: whether to keep priority setting for operators @@ -370,7 +372,7 @@ def dump_graph( dump_content = _imperative_rt.dump_graph( ov, keep_var_name, - keep_op_name, + keep_opr_name, keep_param_name, keep_opr_priority, stat, diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index dca50c47fb300090f1f091744ccedfddc8bbe484..73817a0a6172fb6bbbf75a92dba1c5836ffd728b 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.utils import setscalar +from ..utils.naming import auto_naming from .sublinear_memory_config import SublinearMemoryConfig @@ -77,6 +78,7 @@ def exclude_from_trace(): class TensorInfo: __slots__ = ( # collected attributes + "name", "external", "data_read", "shape_read", @@ -96,6 +98,7 @@ class TensorInfo: ) def __init__(self): + self.name = None self.exported = None self.data_read = None self.shape_read = None @@ -290,12 +293,16 @@ class trace: h = getattr(x, "_mixin_handle", -1) if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): h, info = self._new_handle() + name = auto_naming.get_scope() + "." 
+ x.c_name if x.c_name else x._name + info.name = name info.external = True info.device = x.device info.dtype = x.dtype info.shape = x.shape if self._capture_as_const: - info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False) + info.bound_data = RawTensor( + x.numpy(), x.dtype, x.device, False, name + ) ihandles.append(h) @@ -669,6 +676,12 @@ class trace: arg_names=None, output_names=None, append=False, + keep_var_name: int = 1, + keep_opr_name: bool = False, + keep_param_name: bool = False, + keep_opr_priority: bool = False, + strip_info_file=None, + append_json=False, optimize_for_inference=True, **kwargs ): @@ -681,6 +694,20 @@ class trace: use the default name if not specified. :param append: whether output is appended to ``file``. Only works when ``file`` is str. + :param keep_var_name: level for keeping variable names: + + * 0: none of the names are kept + * 1: (default)keep names of output vars + * 2: keep names of all (output and internal) vars + :param keep_opr_name: whether to keep operator names. + :param keep_param_name: whether to keep param names, so param values can be + easily manipulated after loading model + :param keep_opr_priority: whether to keep priority setting for operators + :param strip_info_file: a string for path or a file handler. if it is not None, + then the dump information for code strip would be written to ``strip_info_file`` + :param append_json: will be checked when `strip_info_file` is not None. if set + true, the information for code strip will be appended to strip_info_file. + if set false, strip_info_file will be rewritten :param optimize_for_inference: enbale optmizations, will skip all optimize options if this is False. Default: True @@ -785,7 +812,10 @@ class trace: assert info.external assert info.bound_data h2v[h] = graph.make_const( - info.bound_data.numpy(), dtype=info.dtype, device=info.device, + info.bound_data.numpy(), + dtype=info.dtype, + device=info.device, + name=info.name, ) continue ivars = [] @@ -795,13 +825,26 @@ class trace: assert info.external assert info.bound_data h2v[h] = graph.make_const( - info.bound_data.numpy(), dtype=info.dtype, device=dumped_device + info.bound_data.numpy(), + dtype=info.dtype, + device=dumped_device, + name=info.name, ) ivars.append(h2v[h]) ovars = G.apply_normal_varnode(op, *ivars) + + auto_naming.record_opnode(ovars[0].op) + assert len(ovars) == len(ohandles) h2v.update(zip(ohandles, ovars)) + for i in ohandles: + name = auto_naming.get_var_name(i) + if name is not None: + h2v[i].name = name + + auto_naming.remove_duplicate_names() + dest_vars = [] for i, h in enumerate(self._output_bindings): v = h2v[h] @@ -815,7 +858,15 @@ class trace: if isinstance(file, str): permission = "wb" if append == False else "ab" file = open(file, permission) - dump_content, dump_info = G.dump_graph(dest_vars) + dump_content, dump_info = G.dump_graph( + dest_vars, + keep_var_name=keep_var_name, + keep_opr_name=keep_opr_name, + keep_param_name=keep_param_name, + keep_opr_priority=keep_opr_priority, + strip_info_file=strip_info_file, + append_json=append_json, + ) file.write(dump_content) return dump_info @@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): return active_trace._apply_op(op, args) -def apply_const_compiled_mode(value, dtype, device, is_const, no_cache): +def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): if skip_tracing: args = [ RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x for x in args ] unset_tracing() - ret = RawTensor(value, 
dtype, device, False) + ret = RawTensor(value, dtype, device, False, name) set_tracing() return ret return active_trace._apply_const(value, dtype, device) def apply_with_tracing(op: OpDef, *args: RawTensor): + if hasattr(op, "scope"): + op.scope = auto_naming.get_scope() if active_trace._symbolic: outputs = apply_symbolic_mode(op, *args) else: @@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): return list(outputs) -def apply_const_with_tracing(value, dtype, device, is_const, no_cache): +def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name): if active_trace._symbolic: outputs = apply_const_symbolic_mode(value, dtype, device) else: unset_tracing() - outputs = (RawTensor(value, dtype, device, False),) + outputs = (RawTensor(value, dtype, device, False, name),) set_tracing() active_trace._record_const(outputs) return list(outputs) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 28dcd76064f8ac5af9f8320b60328af089abe22d..1861ee287062d89e630ff716aefb73ded5b8ea8c 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union import numpy as np -from ..core._imperative_rt.core2 import pop_scope, push_scope from ..core.tensor.utils import make_shape_tuple from ..logger import get_logger from ..tensor import Parameter, Tensor from ..utils.deprecation import deprecated from ..utils.hook import HookHandler +from ..utils.naming import auto_naming logger = get_logger(__name__) @@ -69,7 +69,9 @@ class Module(metaclass=ABCMeta): Base Module class. """ - def __init__(self): + def __init__(self, name=""): + self.name = name + # runtime attributes self.training = True self.quantize_disabled = False @@ -79,6 +81,8 @@ class Module(metaclass=ABCMeta): self._forward_hooks = OrderedDict() self._modules = [] + + # used for profiler and automatic naming self._name = "{anonymous}" @abstractmethod @@ -105,7 +109,7 @@ class Module(metaclass=ABCMeta): return HookHandler(self._forward_hooks, hook) def __call__(self, *inputs, **kwargs): - push_scope(self._name) + auto_naming.push_scope(self.name if self.name else self._name) for hook in self._forward_pre_hooks.values(): modified_inputs = hook(self, inputs) if modified_inputs is not None: @@ -119,7 +123,7 @@ class Module(metaclass=ABCMeta): modified_outputs = hook(self, inputs, outputs) if modified_outputs is not None: outputs = modified_outputs - pop_scope(self._name) + auto_naming.pop_scope() return outputs def _flatten( @@ -579,7 +583,7 @@ class Module(metaclass=ABCMeta): value = super().__getattribute__(name) if name == "_name": return value - if _is_module(value): + if isinstance(value, (Tensor, Module)): value._name = name return value diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 2da9c774ce971beef275afdd94c4411fa0797a74..fc0ca2d351ab8d146f8992f1644066725d10cfe3 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin from .device import _valid_device, get_default_device from .logger import get_logger from .utils.deprecation import deprecated +from .utils.naming import auto_naming class Tensor(_Tensor, ArrayMethodMixin): @@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin): dmap_callback = None _q_dict = None - def __new__(cls, data, dtype=None, device=None, 
is_const=False, no_cache=False): + def __new__( + cls, data, dtype=None, device=None, is_const=False, no_cache=False, name="" + ): if device is None: cn = get_default_device() elif isinstance(device, str): @@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin): if isinstance(data, np.ndarray): if 0 in data.strides: data = data.squeeze().reshape(data.shape) - - obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache) + obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name) return obj @property @@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin): piece += ", device={}".format(self.device) + ")" return piece + @property + def name(self): + return self.c_name + + @name.setter + def name(self, name): + self.c_name = name + auto_naming.record_var_name(self._mixin_handle, name) + @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") def set_value(self, value): if not isinstance(value, _Tensor): diff --git a/imperative/python/megengine/utils/naming.py b/imperative/python/megengine/utils/naming.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eb71cf78cbf37e90abe3aae7b78b81d3b9abcd --- /dev/null +++ b/imperative/python/megengine/utils/naming.py @@ -0,0 +1,63 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from ..core._imperative_rt.core2 import pop_scope, push_scope + + +class AutoNaming: + r""" + Name all executed operators automatically during tracing and record all tensors + renamed by the user. 
+ """ + + def __init__(self): + self.scopes = [] + self.c_ops = [] + self.name2ops = {} + self.handle2names = {} + + def clear(self): + for var in vars(self).values(): + var.clear() + + def push_scope(self, scope): + push_scope(scope) + self.scopes.append(scope) + + def pop_scope(self): + scope = self.scopes.pop() + pop_scope(scope) + + def get_scope(self): + return ".".join(self.scopes) + + def record_var_name(self, handle, name): + self.handle2names[handle] = name + + def get_var_name(self, handle): + return self.handle2names.pop(handle, None) + + def record_opnode(self, op): + ops = self.name2ops.get(op.name, []) + ops.append(op) + self.name2ops[op.name] = ops + + def remove_duplicate_names(self): + for key, ops in self.name2ops.items(): + if len(ops) == 1: + continue + for i, op in enumerate(ops): + op.name = key + "[%s]" % str(i) + if len(op.outputs) == 1: + continue + for var in op.outputs: + var.name = var.name.replace(key, op.name) + self.name2ops.clear() + + +auto_naming = AutoNaming() diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 4a913449b42bd20de6c73b7607a2845c59341958..1483ed3f8307f4a216725087a2340f274254baaf 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -294,7 +294,7 @@ void init_graph_rt(py::module m) { m.def("dump_graph", []( const std::vector& dest_vars, int keep_var_name, - bool keep_op_name, + bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, py::list& stat, @@ -307,7 +307,7 @@ void init_graph_rt(py::module m) { SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, - keep_opr_priority, keep_op_name}; + keep_opr_priority, keep_opr_name}; auto rst = dumper->dump(symvars, config); for (auto i : rst.inputs) { @@ -457,13 +457,17 @@ void init_graph_rt(py::module m) { return opr::SharedDeviceTensor::make(*graph, std::make_shared(data)).node(); }); - m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { + m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional name) { if (!cn.valid()) { cn = CompNode::load(get_default_device()); } + OperatorNodeConfig config(cn); + if (name) { + config.name(*name); + } auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); - return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); - }); + return opr::ImmutableTensor::make(*graph, hv, config).node(); + }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none()); m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional name) { if (!cn.valid()) { diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index edd7f43d47e843e936aeed1b1455ed2afd8e4cb5..c1790eade59ebdc35ffad0cfbb076f72d7df08d0 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { #define py_get_generic(name, attr) \ py_get_generic_impl().attr), &name::attr> +template +PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) { + // T: PyOpXXX inst(): return XXX in opdef.h.inl + auto& op = reinterpret_cast(obj)->inst(); + return pyobj_convert_generic::to(op.scope()); +} +#define py_get_scope(class) py_get_scope_impl + template int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { if (value == NULL) { @@ -121,6 +129,27 @@ 
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { #define py_set_generic(name, attr) \ py_set_generic_impl().attr), &name::attr> +template +int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) { + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); + return -1; + } + auto& op = reinterpret_cast(obj)->inst(); + try { + op.set_scope(pyobj_convert_generic::from(value)); + return 0; + } catch(py::error_already_set& e) { + e.restore(); + } catch(py::builtin_exception& e) { + e.set_error(); + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); + } + return -1; +} +#define py_set_scope(class) py_set_scope_impl + struct PyOpDef { PyObject_HEAD std::shared_ptr op; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 2ab6a27f7c2ac490df9707a8d6af87613b600c58..2fb3cefdbc49f73d4282c4a94fcbf925d95ceb09 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include @@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } } else { py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType - if (nargs != 4 && nargs != 5) { - throw py::type_error("expect 4 or 5 arguments"); + if (nargs != 5 && nargs != 6) { + throw py::type_error("expect 5 or 6 arguments"); } auto data = tup[0].cast(); DType dtype = tup[1].cast(); CompNode cn = tup[2].cast(); bool is_const = tup[3].cast(); - bool no_cache = nargs == 5 ? tup[4].cast() : false; + bool no_cache = nargs == 6 ? tup[4].cast() : false; + std::string name = tup[nargs - 1].cast(); // const op if (is_const && is_tracing) { @@ -259,6 +261,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } m_tensor = std::make_shared(handle); + m_tensor->user_custom_name = name; if (data.ndim() == 0) { m_tensor->m_flags |= Tensor::Flags::SCALAR; @@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info) #undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC +#define SET_GET_NAME(member) \ + PyObject* TensorWrapper::member() { \ + return py::cast(m_tensor->member).release().ptr(); \ + } \ + void TensorWrapper::set_##member(PyObject* dest) { \ + auto py_dest = py::reinterpret_borrow(dest); \ + m_tensor->member = py_dest.cast(); \ + } +SET_GET_NAME(user_custom_name) +SET_GET_NAME(automatic_name) +#undef SET_GET_NAME + + PyObject* TensorWrapper::handle() { return py::cast(m_tensor->m_handle).release().ptr(); } @@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) { if (!t) { throw py::type_error("expect Tensor"); } + std::string user_custom_name = m_tensor->user_custom_name; + std::string automatic_name = m_tensor->automatic_name; m_tensor = t->m_tensor; + m_tensor->user_custom_name = user_custom_name; + m_tensor->automatic_name = automatic_name; } void TensorWrapper::reset_varnode() { @@ -785,6 +805,8 @@ void init_tensor(py::module m) { .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info") .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") + .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name") + .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name") .finalize(); if (!tensor_type) throw py::error_already_set(); py::setattr(m, "Tensor", 
tensor_type); diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 84ac2efb7138e8458260c5ba434f9a70a8e67680..57c95069cc729ae79bcef966a2035ff4762af58b 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -15,6 +15,7 @@ #include "megbrain/imperative/interpreter.h" #include "pybind11/pybind11.h" +#include #include "./pyext17.h" @@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this, NonCopyableObj { GradInfo m_grad_info; TraceInfo m_trace_info; SharedHandle m_handle; + std::string user_custom_name; + std::string automatic_name; cg::VarNode* m_var; using Handle = interpreter::Interpreter::Handle; @@ -170,6 +173,10 @@ struct TensorWrapper { void set_compiled_info(PyObject *); PyObject* trace_mixin_info(); void set_trace_mixin_info(PyObject *); + PyObject* user_custom_name(); + void set_user_custom_name(PyObject *); + PyObject* automatic_name(); + void set_automatic_name(PyObject *); PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; }; diff --git a/imperative/python/test/unit/test_dump_naming.py b/imperative/python/test/unit/test_dump_naming.py new file mode 100644 index 0000000000000000000000000000000000000000..019d885479e2d1bbd8f8fd2ea3f7363e39948d6e --- /dev/null +++ b/imperative/python/test/unit/test_dump_naming.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import io + +import numpy as np +import pytest + +import megengine.functional as F +import megengine.module as M +import megengine.utils.comp_graph_tools as cgtools +from megengine import Parameter, Tensor +from megengine.core.tensor import megbrain_graph as G +from megengine.jit.tracing import trace +from megengine.utils.naming import auto_naming + + +def _dump_and_load(func, symbolic, keep_opr_name=True): + auto_naming.clear() + func = trace(func, symbolic=symbolic, capture_as_const=True) + x = Tensor(np.ones(shape=(2, 3))) + func(x).numpy() + file = io.BytesIO() + func.dump( + file, + optimize_for_inference=False, + arg_names="x", + keep_opr_name=keep_opr_name, + keep_var_name=2, + ) + file.seek(0) + *_, outputs = G.load_graph(file) + op = cgtools.get_oprs_seq(outputs)[-1] + return op + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_auto_naming(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__() + self.name = name + + def forward(self, x): + return x + x + + m = Simple("simple") + op = _dump_and_load(m, symbolic) + assert op.name == "simple.ADD" + assert op.outputs[0].name == "simple.ADD" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_user_named_tensor(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__() + self.name = name + self.k = Parameter(1.0, name="k") + + def forward(self, x): + x = x + x + x.name = "o_x" + return x + + m = Simple("simple") + + op = _dump_and_load(m, symbolic) + assert op.name == "simple.ADD" + assert op.outputs[0].name == "o_x" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_user_named_param(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__() + self.name = name + self.k = Parameter(2.0, name="k") + + def 
forward(self, x): + return self.k * x + + m = Simple("simple") + + op = _dump_and_load(m, symbolic) + assert op.inputs[0].name == "x" + assert op.inputs[1].name == "simple.k" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_without_module(symbolic): + def f(x): + return 2 * x + + op = _dump_and_load(f, symbolic) + assert op.name == "MUL" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_with_submodule(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__() + self.name = name + self.linear = M.Linear(3, 3) + + def forward(self, x): + x = self.linear(x) + return x + + m = Simple("simple") + + op = _dump_and_load(m, symbolic) + assert op.name == "simple.linear.ADD" + assert op.inputs[0].owner.name == "simple.linear.MatrixMul" + assert op.outputs[0].name == "simple.linear.ADD" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_named_submodule(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__() + self.name = name + self.linear = M.Linear(3, 3, name="x") + + def forward(self, x): + x = self.linear(x) + return x + + m = Simple("simple") + + op = _dump_and_load(m, symbolic) + assert op.name == "simple.x.ADD" + assert op.inputs[0].owner.name == "simple.x.MatrixMul" + assert op.outputs[0].name == "simple.x.ADD" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_with_same_operators(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__() + self.name = name + + def forward(self, x): + x = F.relu(x) + x = F.relu(x) + return x + + m = Simple("simple") + + op = _dump_and_load(m, symbolic) + assert op.name == "simple.RELU[1]" + assert op.inputs[0].owner.name == "simple.RELU[0]" + + +def test_not_keep_opr_name(): + def f(x): + return 2 * x + + op = _dump_and_load(f, True, False) + assert op.name == "MUL(x,2[2])[4]" diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 4a8b135297a54197eeedf9ed8f75984da3075b17..d838149d8ef425b48b6592555e5938f1c6425043 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -148,7 +148,7 @@ def test_dump(): dump_info = f.dump(file) assert dump_info.nr_opr == 3 np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) - np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) + np.testing.assert_equal(dump_info.outputs, ["ADD"]) file.seek(0) infer_cg = cgtools.GraphInference(file) result = list((infer_cg.run(a, b)).values())[0] diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index fcfd8fa8cd621fc223f413e637e85ae74a48067f..96d2812dda8f35f73546576d474da4da8f40f652 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -75,10 +75,6 @@ std::vector> OpDef::props( return def.trait()->props(def); } -const char* OpDef::name() const { - return trait()->name; -} - std::string OpDef::to_string() const { std::string builder = "{"; for (auto&& [name, value]: props(*this)) { @@ -107,6 +103,20 @@ const OpTrait* OpDef::trait() const { return m_trait; } +const std::string OpDef::scope() const { + return m_scope; +} + +void OpDef::set_scope(const std::string& scope) { + m_scope = scope; +} + +const std::string OpDef::make_name() const { + if (m_scope.empty()) + return trait()->make_name(*this); + return m_scope + "." 
+ trait()->make_name(*this); +} + } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h index 3c49fc6944a3122b4145f6dafdc12e37f3996d3d..8abd638d4a7771d40e615f58f7e89cec1eca6e5b 100644 --- a/imperative/src/impl/op_trait.h +++ b/imperative/src/impl/op_trait.h @@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth< using Props = detail::OpMeth; using HashFunc = detail::OpMeth; using IsSame = detail::OpMeth; +using MakeNameFunc = detail::OpMeth; struct OpTrait { const char* name; @@ -88,6 +89,7 @@ struct OpTrait { Props props; HashFunc hash; IsSame is_same_st; + MakeNameFunc make_name; OpTrait(const char* name); static OpTrait* find_by_name(const char* name); static OpTrait* find_by_typeinfo(Typeinfo* type); @@ -104,7 +106,8 @@ struct OpTrait { cb(make_backward_graph) \ cb(props) \ cb(hash) \ - cb(is_same_st) + cb(is_same_st) \ + cb(make_name) struct OpTraitRegistry { OpTrait* trait; diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp index 6acbe826e431131239885ba45dedab2a15a1edab..cad0b7dad069719ca078aa0ceb95f3f580f54a47 100644 --- a/imperative/src/impl/ops/batch_norm.cpp +++ b/imperative/src/impl/ops/batch_norm.cpp @@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node( size_t nr_inp = inputs.size(); mgb_assert(nr_inp == 3 ||nr_inp == 5, "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); + OperatorNodeConfig config{bn_opr.make_name()}; if (nr_inp == 3) { return opr::BatchNorm::make( - inputs[0], inputs[1], inputs[2], bn_opr.param())[0] + inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0] .node()->owner_opr(); } else { return opr::BatchNorm::make( - inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0] + inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0] .node()->owner_opr(); } } diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index 71a019ff5c4f55672589426cba8b73f075b3f97d..dac2282ad2c215e6547fc031d78cf0b0ea24cdc1 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -27,10 +27,11 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { - def.cast_final_safe(); + auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); - return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr(); + OperatorNodeConfig config{op.make_name()}; + return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); } bool valid_broadcast(const TensorShape& src_shape, @@ -96,7 +97,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::Reshape::make(inputs[0], inputs[1], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::Reshape::make(inputs[0], inputs[1], op.param(), config); } std::tuple, bool> infer_output_attrs_fallible( diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 03786b830fd00999d1719ef86e24626ef8c4b835..57e351de7aaca09f02e396cf9e7c1a990f40d6a4 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node( auto disable = std::make_shared(); disable->set(0); - 
cg::OperatorNodeConfig config; + OperatorNodeConfig config{comm.make_name()}; if (comm.comp_node.size() > 0) { config.comp_node(CompNode::load(comm.comp_node)); } diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp index 04083afdff6ada9814c65e0c4ba19940a68841eb..bdab57db95599fef09b90087adbc35a9fa1c6082 100644 --- a/imperative/src/impl/ops/cond_take.cpp +++ b/imperative/src/impl/ops/cond_take.cpp @@ -23,12 +23,12 @@ namespace { cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { - def.cast_final_safe(); + auto&& op = def.cast_final_safe(); auto&& graph = inputs[0]->owner_graph(); opr::CondTake::Param param; param.val = 1; - cg::OperatorNodeConfig config; + OperatorNodeConfig config{op.make_name()}; cg::OperatorNodeBase* opr = graph->insert_opr( std::make_unique( inputs[0], inputs[1], param, config)); diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index f1aab9de22932d89225b947dbcd84afec8dd2c55..15f54af7788831c570c0ea6ca407e1036aaa5e5b 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& elemwise_opr = def.cast_final_safe(); - return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr(); + OperatorNodeConfig config{elemwise_opr.make_name()}; + return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr(); } std::tuple, bool> infer_output_attrs_fallible( diff --git a/imperative/src/impl/ops/img_proc.cpp b/imperative/src/impl/ops/img_proc.cpp index 38497f7d2ba702d248a49da8331d82d316528136..2e1aacebcce18b2740fbd5f7c96595a6e7d10d88 100644 --- a/imperative/src/impl/ops/img_proc.cpp +++ b/imperative/src/impl/ops/img_proc.cpp @@ -23,7 +23,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::CvtColor::make(inputs[0], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::CvtColor::make(inputs[0], op.param(), config); } OP_TRAIT_REG(CvtColor, CvtColor) .apply_on_var_node(apply_on_var_node) diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 99f83ea43acba670ac3d285aa409ee4155ef2880..ed0398d74522cc8e76539c755e611a36235064fd 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send( ssprintf("%s:%d", send.addr.data(), send.port)); auto&& graph = inputs[0]->owner_graph(); - cg::OperatorNodeConfig config; + OperatorNodeConfig config{send.make_name()}; cg::OperatorNodeBase* opr = graph->insert_opr(std::make_unique( send.key, inputs[0], group_client, true, config)); @@ -42,11 +42,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send( cg::OperatorNodeBase* apply_on_var_node_remote_recv( const OpDef& def, const VarNodeArray& inputs) { auto&& recv = def.cast_final_safe(); + OperatorNodeConfig config{recv.cn}; + config.name(recv.make_name()); auto group_client = std::make_shared( ssprintf("%s:%d", recv.addr.data(), recv.port)); auto&& graph = inputs[0]->owner_graph(); return graph->insert_opr(std::make_unique( - recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, + recv.key, inputs[0], *graph, group_client, config, recv.shape, recv.dtype)); } diff --git a/imperative/src/impl/ops/matrix_inverse.cpp b/imperative/src/impl/ops/matrix_inverse.cpp 
index b20794e09a08686ead046facdaf95f030d543217..b9b7ce2f8d1504220f0d83e739e4c88ec9c5750a 100644 --- a/imperative/src/impl/ops/matrix_inverse.cpp +++ b/imperative/src/impl/ops/matrix_inverse.cpp @@ -21,8 +21,10 @@ namespace { auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); mgb_assert(inputs.size() == 1); - return opr::MatrixInverse::make(inputs[0]); + OperatorNodeConfig config{op.make_name()}; + return opr::MatrixInverse::make(inputs[0], {}, config); } OP_TRAIT_REG(MatrixInverse, MatrixInverse) .apply_on_var_node(apply_on_var_node) diff --git a/imperative/src/impl/ops/nms.cpp b/imperative/src/impl/ops/nms.cpp index 5898d31547be1db8eaf6c666aa349135fc40f74a..7f4aacb8002727a9dfdb29964f713171f0b9dfad 100644 --- a/imperative/src/impl/ops/nms.cpp +++ b/imperative/src/impl/ops/nms.cpp @@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node( param.iou_thresh = nms_keep.iou_thresh; param.max_output = nms_keep.max_output; - return NMSKeepOpr::make(inputs[0], param).node()->owner_opr(); + OperatorNodeConfig config{nms_keep.make_name()}; + + return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr(); } OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) diff --git a/imperative/src/impl/ops/opr_attr.cpp b/imperative/src/impl/ops/opr_attr.cpp index 8905954a679bed24ede98e67c8e9a4cd36c23063..c06cdc10dd06d173738f78f7e08df60a18699ff9 100644 --- a/imperative/src/impl/ops/opr_attr.cpp +++ b/imperative/src/impl/ops/opr_attr.cpp @@ -79,11 +79,13 @@ public: cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& attr = def.cast_final_safe(); + auto config = attr.config; + config.name(attr.make_name()); mgb_assert(!inputs.empty()); auto registry = serialization::OprRegistry::find_by_name(attr.type); mgb_assert(registry, "operator %s not found", attr.type.c_str()); OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; - return registry->loader(ctx, inputs, attr.config); + return registry->loader(ctx, inputs, config); } std::shared_ptr make_from_op_node(cg::OperatorNodeBase* opr) { @@ -99,10 +101,15 @@ std::vector> props(const OpDef& def) { return {}; } +std::string make_name(const OpDef& def) { + return "OprAttr"; +} + OP_TRAIT_REG(OprAttr, OprAttr) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .props(props) + .make_name(make_name) .fallback(); } // anonymous namespace diff --git a/imperative/src/impl/ops/resize.cpp b/imperative/src/impl/ops/resize.cpp index 8beaff73d2403406823aefd6766a8fd00cf8eb31..599f63e015766695568576549bcc0e48d4627c89 100644 --- a/imperative/src/impl/ops/resize.cpp +++ b/imperative/src/impl/ops/resize.cpp @@ -24,7 +24,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::Resize::make(inputs[0], inputs[1], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::Resize::make(inputs[0], inputs[1], op.param(), config); } OP_TRAIT_REG(Resize, Resize) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 54c084cb76262504a30c4961277287cd74f93e33..36a27e871a3516cc6ccab3d105f7990e81d7b51d 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -46,7 +46,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& conv = static_cast(def); - return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy()); + 
OperatorNodeConfig config{conv.make_name()}; + return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); } OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) @@ -60,7 +61,7 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& conv = static_cast(def); - cg::OperatorNodeConfig config; + OperatorNodeConfig config{conv.make_name()}; if (inputs.size() == 2) { return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); } else { @@ -88,7 +89,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& ds = static_cast(def); - return opr::Dimshuffle::make(inputs[0], ds.pattern); + OperatorNodeConfig config{ds.make_name()}; + return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); } OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) @@ -107,7 +109,8 @@ auto apply_on_var_node( for (auto&& i : add_axis.axis) { param.push_back(Desc::make_add(i)); } - return opr::AxisAddRemove::make(inputs[0], param); + OperatorNodeConfig config{add_axis.make_name()}; + return opr::AxisAddRemove::make(inputs[0], param, config); } OP_TRAIT_REG(AddAxis, AddAxis) @@ -125,7 +128,8 @@ auto apply_on_var_node( for (auto&& i : remove_axis.axis) { param.push_back(Desc::make_remove(i)); } - return opr::AxisAddRemove::make(inputs[0], param); + OperatorNodeConfig config{remove_axis.make_name()}; + return opr::AxisAddRemove::make(inputs[0], param, config); } OP_TRAIT_REG(RemoveAxis, RemoveAxis) @@ -138,7 +142,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& topk = static_cast(def); - return opr::TopK::make(inputs[0], inputs[1], topk.param())[0] + OperatorNodeConfig config{topk.make_name()}; + return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] .node()->owner_opr(); } @@ -152,10 +157,12 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& reduce = static_cast(def); + OperatorNodeConfig config{reduce.make_name()}; if (inputs.size() > 1) { - return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]); + return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); } else { - return opr::Reduce::make(inputs[0], reduce.param()); + return opr::Reduce::make( + inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); } } @@ -175,7 +182,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& pool = static_cast(def); - return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param()); + OperatorNodeConfig config{pool.make_name()}; + return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config); } OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) @@ -189,6 +197,7 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& conv = static_cast(def); cg::OperatorNodeConfig config{conv.dtype}; + config.name(conv.make_name()); if (inputs.size() == 2) { return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); } else if (inputs.size() == 3) { @@ -210,6 +219,7 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& conv = static_cast(def); cg::OperatorNodeConfig config{conv.dtype}; + config.name(conv.make_name()); if (inputs.size() == 2) { return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); } else if (inputs.size() == 3) { @@ -230,7 +240,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& pool = static_cast(def); - return opr::Pooling::make(inputs[0], 
pool.param()); + OperatorNodeConfig config{pool.make_name()}; + return opr::Pooling::make(inputs[0], pool.param(), config); } OP_TRAIT_REG(Pooling, Pooling) .apply_on_var_node(apply_on_var_node) @@ -243,8 +254,9 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& matmul = static_cast(def); mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{matmul.make_name()}; return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(), - matmul.policy()); + matmul.policy(), config); } OP_TRAIT_REG(MatrixMul, MatrixMul) .apply_on_var_node(apply_on_var_node) @@ -257,8 +269,9 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& matmul = static_cast(def); mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{matmul.make_name()}; return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(), - matmul.policy()); + matmul.policy(), config); } OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) .apply_on_var_node(apply_on_var_node) @@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) namespace { namespace dot { auto apply_on_var_node( - const OpDef&, + const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); mgb_assert(inputs.size() == 2); - return opr::Dot::make(inputs[0], inputs[1]); + OperatorNodeConfig config{op.make_name()}; + return opr::Dot::make(inputs[0], inputs[1], config); } OP_TRAIT_REG(Dot, Dot) .apply_on_var_node(apply_on_var_node) @@ -282,7 +297,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& argsort = static_cast(def); - return opr::Argsort::make(inputs[0], argsort.param()); + OperatorNodeConfig config{argsort.make_name()}; + return opr::Argsort::make(inputs[0], argsort.param(), config); } OP_TRAIT_REG(Argsort, Argsort) .apply_on_var_node(apply_on_var_node) @@ -294,7 +310,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& argmax = static_cast(def); - return opr::Argmax::make(inputs[0], argmax.param()); + OperatorNodeConfig config{argmax.make_name()}; + return opr::Argmax::make(inputs[0], argmax.param(), config); } OP_TRAIT_REG(Argmax, Argmax) .apply_on_var_node(apply_on_var_node) @@ -306,7 +323,8 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& argmin = static_cast(def); - return opr::Argmin::make(inputs[0], argmin.param()); + OperatorNodeConfig config{argmin.make_name()}; + return opr::Argmin::make(inputs[0], argmin.param(), config); } OP_TRAIT_REG(Argmin, Argmin) .apply_on_var_node(apply_on_var_node) @@ -318,11 +336,13 @@ auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& warp = static_cast(def); + OperatorNodeConfig config{warp.make_name()}; if (inputs.size() == 3) { - return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param()); + return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config); } else { mgb_assert(inputs.size() == 4); - return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param()); + return opr::WarpPerspective::make( + inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config); } } OP_TRAIT_REG(WarpPerspective, WarpPerspective) @@ -336,7 +356,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& local = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::GroupLocal::make(inputs[0], inputs[1], local.param()); + OperatorNodeConfig config{local.make_name()}; + return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); } 
OP_TRAIT_REG(GroupLocal, GroupLocal) .apply_on_var_node(apply_on_var_node) @@ -349,7 +370,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); } OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) .apply_on_var_node(apply_on_var_node) @@ -362,7 +384,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); - return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config); } OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) .apply_on_var_node(apply_on_var_node) @@ -375,7 +398,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::TypeCvt::make(inputs[0], op.dtype); + OperatorNodeConfig config{op.make_name()}; + return opr::TypeCvt::make(inputs[0], op.dtype, config); } OP_TRAIT_REG(TypeCvt, TypeCvt) .apply_on_var_node(apply_on_var_node) @@ -388,6 +412,7 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); cg::OperatorNodeConfig config{op.comp_node}; + config.name(op.make_name()); return opr::Concat::make(inputs, op.axis, config); } OP_TRAIT_REG(Concat, Concat) @@ -402,6 +427,7 @@ auto apply_on_var_node( auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); cg::OperatorNodeConfig config{op.comp_node}; + config.name(op.make_name()); return opr::Copy::make(inputs[0], config); } OP_TRAIT_REG(Copy, Copy) @@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy) namespace { namespace identity { auto apply_on_var_node( - const OpDef&, + const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); mgb_assert(inputs.size() == 1); - return opr::Identity::make(inputs[0]); + OperatorNodeConfig config{op.make_name()}; + return opr::Identity::make(inputs[0], config); } OP_TRAIT_REG(Identity, Identity) .apply_on_var_node(apply_on_var_node) @@ -427,7 +455,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config); } @@ -443,7 +472,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::UniformRNG::make(inputs[0], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::UniformRNG::make(inputs[0], op.param(), config); } OP_TRAIT_REG(UniformRNG, UniformRNG) .apply_on_var_node(apply_on_var_node) @@ -456,7 +486,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::GaussianRNG::make(inputs[0], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::GaussianRNG::make(inputs[0], op.param(), config); } OP_TRAIT_REG(GaussianRNG, GaussianRNG) .apply_on_var_node(apply_on_var_node) @@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], 
op.param()).node()->owner_opr(); + OperatorNodeConfig config{op.make_name()}; + auto* opr = opr::ROIAlign::make( + inputs[0], inputs[1], op.param(), config).node()->owner_opr(); return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIAlign, ROIAlign) @@ -484,7 +517,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::NvOf::make(inputs[0], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::NvOf::make(inputs[0], op.param(), config); } OP_TRAIT_REG(NvOf, NvOf) .apply_on_var_node(apply_on_var_node) @@ -499,6 +533,7 @@ auto apply_on_var_node( auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); cg::OperatorNodeConfig config{op.comp_node}; + config.name(op.make_name()); return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); } OP_TRAIT_REG(Linspace, Linspace) @@ -513,6 +548,7 @@ auto apply_on_var_node( auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); cg::OperatorNodeConfig config{op.comp_node}; + config.name(op.make_name()); opr::Eye::Param param{op.k, op.dtype.enumv()}; return opr::Eye::make(inputs[0], param, config); } @@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); - auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr(); + OperatorNodeConfig config{op.make_name()}; + auto* opr = opr::ROIPooling::make( + inputs[0], inputs[1], inputs[2], op.param(), config + ).node()->owner_opr(); return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIPooling, ROIPooling) @@ -541,7 +580,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::Remap::make(inputs[0], inputs[1], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::Remap::make(inputs[0], inputs[1], op.param(), config); } OP_TRAIT_REG(Remap, Remap) .apply_on_var_node(apply_on_var_node) @@ -578,7 +618,8 @@ auto apply_on_var_node( \ const OpDef& def, \ const VarNodeArray& inputs) { \ auto&& op = static_cast(def); \ - return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \ + OperatorNodeConfig config{op.make_name()}; \ + return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \ } \ OP_TRAIT_REG(NAME, NAME) \ .apply_on_var_node(apply_on_var_node) \ @@ -609,30 +650,35 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); - return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config); } OP_TRAIT_REG(FakeQuant, FakeQuant) .apply_on_var_node(apply_on_var_node) .fallback(); }} // fake_quant + namespace { namespace tqt { auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::TQT::make(inputs[0], inputs[1], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::TQT::make(inputs[0], inputs[1], op.param(), config); } OP_TRAIT_REG(TQT, TQT) .apply_on_var_node(apply_on_var_node) .fallback(); }} // tqt + namespace { namespace elemwise_multi_type { auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); OperatorNodeConfig config{op.dtype}; + config.name(op.make_name()); 
return opr::ElemwiseMultiType::make(inputs, op.param(), config); } OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) @@ -646,7 +692,9 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output(); + OperatorNodeConfig config{op.make_name()}; + return opr::SVD::make(inputs[0], op.param(), config)[0] + .node()->owner_opr()->usable_output(); } OP_TRAIT_REG(SVD, SVD) .apply_on_var_node(apply_on_var_node) diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 3dd5ed942b53b02920f32c7c65a98bf6f312f008..4931c97d6a4f2e1d2414278d332bf018c88069d6 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op_def = def.cast_final_safe(); - return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); + OperatorNodeConfig config{op_def.make_name()}; + return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr(); } DispatchMode decide_dispatch_mode( @@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( auto&& graph = inputs[0]->owner_graph(); auto&& shapes = get_shapes(param.shapes); - cg::OperatorNodeConfig config; + OperatorNodeConfig config(param.make_name()); cg::OperatorNodeBase* opr = graph->insert_opr(std::make_unique( inputs[0], param.offsets, shapes, config)); @@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( auto&& graph = inputs[0]->owner_graph(); VarNodeArray inps(inputs.begin(), inputs.end() - 1); - cg::OperatorNodeConfig config; + OperatorNodeConfig config{param.make_name()}; cg::OperatorNodeBase* opr = graph->insert_opr(std::make_unique( inps, inputs.back(), param.offsets, config)); diff --git a/imperative/src/impl/ops/tensorrt_runtime.cpp b/imperative/src/impl/ops/tensorrt_runtime.cpp index a5eef4591316789cd5ef0b786f87084a222423c5..8e666a706fe04eba199a6a5f8f776a11d2851b1c 100644 --- a/imperative/src/impl/ops/tensorrt_runtime.cpp +++ b/imperative/src/impl/ops/tensorrt_runtime.cpp @@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime { const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); + OperatorNodeConfig config{op.make_name()}; SymbolVarArray sinputs(inputs.begin(), inputs.end()); - return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs); + return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config); } OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime) .apply_on_var_node(apply_on_var_node) diff --git a/imperative/src/impl/ops/warp_affine.cpp b/imperative/src/impl/ops/warp_affine.cpp index 53237d69429dc3686980decb3d7ed6aac938f8ad..c4352ee4e574d7e6115366efafa1d1ece3e09014 100644 --- a/imperative/src/impl/ops/warp_affine.cpp +++ b/imperative/src/impl/ops/warp_affine.cpp @@ -21,7 +21,8 @@ namespace { namespace warp_affine { const VarNodeArray& inputs) { mgb_assert(inputs.size() == 3); auto&& op = static_cast(def); - return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param()); + OperatorNodeConfig config{op.make_name()}; + return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config); } OP_TRAIT_REG(WarpAffine, WarpAffine) diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index 
8d8e62a193bff993a711ecb59e6e8c1d9e6fcbec..61e286345d27b230fa8f0adcdd402c9e7e5ebcf4 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -36,6 +36,7 @@ class OpDef : public Hashable, public NonCopyableObj, public std::enable_shared_from_this { mutable const OpTrait* m_trait = nullptr; + std::string m_scope; public: virtual ~OpDef() = default; @@ -86,10 +87,14 @@ public: const OpTrait* trait() const; - const char* name() const; - std::string to_string() const; + const std::string scope() const; + + const std::string make_name() const; + + void set_scope(const std::string& scope); + virtual size_t hash() const; virtual bool is_same_st(const Hashable&) const; diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index a31cff185ddb56c9af9887a434bc4d8da1d05497..a86788adf44db564dc58d32b6e9a2f347c6010f0 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { "{0}({0}_)", i.name )); } + paramList.push_back("std::string scope_ = {}"); gen_ctor(llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), - " {}"); + " { set_scope(scope_); }"); } auto packedParams = op.getPackedParams(); @@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); os << "}\n"; + // generate make_name() + os << formatv( + "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") + ); + os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); + os << "}\n"; + os << "} // anonymous namespace\n"; methods.push_back("hash"); methods.push_back("is_same_st"); methods.push_back("props"); + methods.push_back("make_name"); } if (!methods.empty()) { os << formatv( @@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& targs.push_back(i.attr.getReturnType()); } os << llvm::join(targs, ", "); - os << ">()"; + os << ", std::string>()"; for (auto &&i : op.getMgbAttributes()) { os << formatv(", py::arg(\"{0}\")", i.name); auto defaultValue = i.attr.getDefaultValue(); @@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& hasDefaultCtor = true; } } - os << ")"; + os << ", py::arg(\"scope\") = {})"; } if (hasDefaultCtor) { os << "\n .def(py::init<>())"; @@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{ className, i.name)); } + getsetters.push_back(formatv( + "{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},", + className)); + // generate tp_init std::string initBody; if (!op.getMgbAttributes().empty()) { @@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{ llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { initBody += formatv("\"{0}\", ", attr.name); }); + initBody += "\"scope\", "; initBody += "NULL};\n"; initBody += " PyObject "; std::vector attrs; @@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{ attrs.push_back(formatv("*{0} = NULL", attr.name)); }); initBody += llvm::join(attrs, ", ") + ";\n"; + initBody += " PyObject *scope = NULL;\n"; initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; - initBody += std::string(op.getMgbAttributes().size(), 'O'); + // an extra slot created for name + initBody += std::string(op.getMgbAttributes().size() + 1, 'O'); initBody += "\", const_cast(kwlist)"; llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { - 
initBody += formatv(" ,&{0}", attr.name); + initBody += formatv(", &{0}", attr.name); }); + initBody += ", &scope"; initBody += "))\n"; initBody += " return -1;\n"; llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { @@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{ } )", className, attr.name); }); + + initBody += formatv(R"( + if (scope) {{ + try {{ + reinterpret_cast(self)->inst().set_scope( + pyobj_convert_generic::from(scope)); + } catch(py::error_already_set& e) {{ + e.restore(); + return -1; + } catch(py::builtin_exception& e) {{ + e.set_error(); + return -1; + } catch(...) {{ + PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); + return -1; + } + } +)", className); + } initBody += "\n return 0;"; diff --git a/imperative/tablegen/helper.h b/imperative/tablegen/helper.h index 091d05270566e0f2920c57a187ecfce8598f367e..9145d387ed1a65cdba1777c4340538e117fc8e45 100644 --- a/imperative/tablegen/helper.h +++ b/imperative/tablegen/helper.h @@ -241,6 +241,30 @@ private: body += " return props_;\n"; return body; } + std::string getModeName() const { + std::string body = formatv( + " auto&& op_ = def_.cast_final_safe<{0}>();\n" + " static_cast(op_);\n", + getCppClassName() + ); + for (auto&& it : getMgbAttributes()) { + if (it.name == "mode") { + auto* enumAttr = llvm::dyn_cast(&it.attr); + body += " switch (op_.mode){\n"; + for (auto&& enumMember: enumAttr->getEnumMembers()) { + body += formatv( + " case {0}::{1}::{2}:\n", + getCppClassName(), enumAttr->getEnumName(), enumMember + ); + body += formatv(" return \"{0}\";\n", enumMember); + } + body += formatv( + " default: return \"{0}::Unknown\";\n", getCppClassName()); + body += " }\n"; + } + } + return body; + } public: static bool classof(const Operator* op) { return op->getDef().isSubClassOf("MgbHashableOpMixin"); @@ -264,6 +288,12 @@ public: } return getDefaultPropsFunction(); } + std::string getNameFunctionTemplate() const { + if (getDef().getValueAsBit("usingModeName")) { + return getModeName(); + } + return formatv(" return \"{0}\";\n", getCppClassName()); + } }; } // namespace tblgen diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index 02010e75020fb4e19e6fb75c5c4caa192fed6cf4..9aafe3f564d5c3afef45733ecfac4a26a5fcc188 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -476,6 +476,7 @@ def main(): output_mgbvars = feeds["outputs"] output_mgbvars = optimize_for_inference(args, output_mgbvars) + output_mgbvars = [var._node for var in output_mgbvars] inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") inputs = sorted((i.name, i.dtype) for i in inputs) diff --git a/src/core/include/megbrain/ir/base.td b/src/core/include/megbrain/ir/base.td index b9c59d0fde56f823d0ede04c2e8d25bf541da160..2b11392d1c7d827f0a6c5312e794473bc2cdcd7e 100644 --- a/src/core/include/megbrain/ir/base.td +++ b/src/core/include/megbrain/ir/base.td @@ -242,6 +242,7 @@ class MgbPackedParamBase: class MgbHashableOpMixin { string hashFunction = ?; string cmpFunction = ?; + bit usingModeName = 0; } class MgbOp params=[], list traits=[]>: diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index fced2ca449aa2664b973d9c2b37fb27b9a796254..674acc7e129cd9e4c132af9527c3fcf2cca0cb22 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> 
{ let inputs = (ins Variadic:$input); let results = (outs AnyType); + let usingModeName = 1; } def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; @@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara let extraArguments = (ins MgbDTypeAttr:$dtype ); + let usingModeName = 1; } def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
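
Usage sketch (not part of the patch): the snippet below is assembled from the new unit tests in imperative/python/test/unit/test_dump_naming.py above. The module name "simple", the parameter name "k", the user-assigned var name "o_x" and the operator names in the trailing comment are illustrative values taken from those tests, not guaranteed output; only APIs touched by this diff (Module.name, Parameter(..., name=...), Tensor.name, trace.dump(keep_var_name=..., keep_opr_name=...)) are used.

import io

import numpy as np

import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace


class Simple(M.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name                   # consumed by auto_naming.push_scope() in Module.__call__
        self.k = Parameter(2.0, name="k")  # expected to be dumped as "simple.k"

    def forward(self, x):
        x = self.k * x
        x.name = "o_x"                     # user-assigned var name overrides the automatic one
        return x


m = Simple("simple")
func = trace(m, symbolic=True, capture_as_const=True)

x = Tensor(np.ones((2, 3), dtype="float32"))
func(x).numpy()

file = io.BytesIO()
func.dump(
    file,
    arg_names="x",
    keep_var_name=2,      # keep names of all vars, not only outputs
    keep_opr_name=True,   # keep auto-generated operator names such as "simple.MUL"
    optimize_for_inference=False,
)
file.seek(0)
*_, outputs = G.load_graph(file)
op = cgtools.get_oprs_seq(outputs)[-1]
print(op.name, [v.name for v in op.inputs], op.outputs[0].name)
# e.g. "simple.MUL" ["x", "simple.k"] "o_x"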