Commit aed681d3 authored by Megvii Engine Team

feat(imperative/utils): optimize the naming rules

GitOrigin-RevId: 329bac640aa6e2e3c981aa294a361684b982892e
Parent c6bbc478
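This commit replaces the module-level singleton `auto_naming = AutoNaming()` with classmethods on `AutoNaming` itself, so every call site switches from the instance to the class. A minimal before/after sketch of the call-site change (the scope name "conv1" is illustrative):

```python
# Before this commit: calls went through a module-level singleton instance
#   from megengine.utils.naming import auto_naming
#   auto_naming.push_scope("conv1")

# After: the classmethods are called on AutoNaming directly; no instance needed
from megengine.utils.naming import AutoNaming

AutoNaming.push_scope("conv1")           # enter a naming scope
assert AutoNaming.get_scope() == "conv1"
AutoNaming.pop_scope()                   # leave it again
```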
@@ -40,7 +40,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
-from ..utils.naming import auto_naming
+from ..utils.naming import AutoNaming
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -297,9 +297,7 @@ class trace:
         h = getattr(x, "_mixin_handle", -1)
         if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
             h, info = self._new_handle()
-            name = (
-                auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name)
-            )
+            name = AutoNaming.gen_name(x)
             info.name = name
             info.external = True
             info.device = x.device
@@ -845,17 +843,17 @@ class trace:
                 ivars.append(h2v[h])
             ovars = G.apply_normal_varnode(op, *ivars)
-            auto_naming.record_opnode(ovars[0].op)
+            AutoNaming.record_opnode(ovars[0].op)
             assert len(ovars) == len(ohandles)
             h2v.update(zip(ohandles, ovars))
             for i in ohandles:
-                name = auto_naming.get_var_name(i)
+                name = AutoNaming.get_var_name(i)
                 if name is not None:
                     h2v[i].name = name
-        auto_naming.remove_duplicate_names()
+        AutoNaming.remove_duplicate_names()
         dest_vars = []
         for i, h in enumerate(self._output_bindings):
@@ -1173,7 +1171,7 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
 def apply_with_tracing(op: OpDef, *args: RawTensor):
     if hasattr(op, "scope"):
-        op.scope = auto_naming.get_scope()
+        op.scope = AutoNaming.get_scope()
     if active_trace._symbolic:
         outputs = apply_symbolic_mode(op, *args)
     else:
...
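The inline scope-joining expression removed above now lives in `AutoNaming.gen_name` (see the naming.py hunk further down). A standalone sketch of its behavior, reimplemented here for illustration only:

```python
def gen_name_sketch(scope: str, c_name: str, auto_name: str) -> str:
    # Prefer the user-assigned c_name, fall back to the automatic name,
    # and only prepend the scope when one is active -- unlike the old
    # expression, which always emitted a leading "." on an empty scope.
    name = c_name if c_name else auto_name
    return scope + "." + name if scope else name

assert gen_name_sketch("", "", "var_0") == "var_0"
assert gen_name_sketch("net.conv1", "", "var_0") == "net.conv1.var_0"
assert gen_name_sketch("net.conv1", "weight", "var_0") == "net.conv1.weight"
```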
@@ -16,7 +16,7 @@ from ..logger import get_logger
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
 from ..utils.hook import HookHandler
-from ..utils.naming import auto_naming
+from ..utils.naming import AutoNaming

 logger = get_logger(__name__)
@@ -111,7 +111,7 @@ class Module(metaclass=ABCMeta):
         self._forward_hooks = OrderedDict()
         # used for profiler and automatic naming
-        self._name = "{anonymous}"
+        self._name = None

     @abstractmethod
     def forward(self, inputs):
@@ -137,7 +137,7 @@ class Module(metaclass=ABCMeta):
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        auto_naming.push_scope(self.name if self.name is not None else self._name)
+        AutoNaming.push_scope(self.name if self.name is not None else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
@@ -151,7 +151,7 @@ class Module(metaclass=ABCMeta):
             modified_outputs = hook(self, inputs, outputs)
             if modified_outputs is not None:
                 outputs = modified_outputs
-        auto_naming.pop_scope()
+        AutoNaming.pop_scope()
         return outputs

     def _flatten(
...
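Changing the default `_name` from `"{anonymous}"` to `None` works together with the `push_scope`/`pop_scope` changes in naming.py below: an unnamed module still pushes an entry in `__call__` (so the scope stack stays balanced), but `None` entries are skipped when the scope string is built. A rough illustration of the bookkeeping:

```python
# Hypothetical forward pass: an unnamed top-level module wrapping a
# submodule named "conv1". The None entry keeps push/pop balanced but
# no longer contributes an "{anonymous}" segment to generated names.
scopes = []
scopes.append(None)      # top-level module, _name is None
scopes.append("conv1")   # named submodule
assert ".".join(s for s in scopes if s is not None) == "conv1"
scopes.pop()
scopes.pop()
```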
@@ -20,7 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
 from .logger import get_logger
 from .utils.deprecation import deprecated
-from .utils.naming import auto_naming
+from .utils.naming import AutoNaming

 logger = get_logger(__name__)
@@ -168,7 +168,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
     @name.setter
     def name(self, name):
         self.c_name = name
-        auto_naming.record_var_name(self._mixin_handle, name)
+        AutoNaming.record_var_name(self._mixin_handle, name)

     @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
     def set_value(self, value):
...
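For context, this setter is the path by which user-assigned tensor names reach the tracer; a short usage sketch (the tensor value and name are arbitrary):

```python
import numpy as np
import megengine as mge

x = mge.Tensor(np.ones((2, 3), dtype="float32"))
x.name = "input_x"  # stores c_name and records it via AutoNaming.record_var_name
```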
@@ -15,40 +15,57 @@ class AutoNaming:
     renamed by the user.
     """

-    def __init__(self):
-        self.scopes = []
-        self.c_ops = []
-        self.name2ops = {}
-        self.handle2names = {}
-
-    def clear(self):
-        for var in vars(self).values():
-            var.clear()
-
-    def push_scope(self, scope):
-        push_scope(scope)
-        self.scopes.append(scope)
-
-    def pop_scope(self):
-        scope = self.scopes.pop()
-        pop_scope(scope)
-
-    def get_scope(self):
-        return ".".join(self.scopes)
-
-    def record_var_name(self, handle, name):
-        self.handle2names[handle] = name
-
-    def get_var_name(self, handle):
-        return self.handle2names.pop(handle, None)
-
-    def record_opnode(self, op):
-        ops = self.name2ops.get(op.name, [])
-        ops.append(op)
-        self.name2ops[op.name] = ops
-
-    def remove_duplicate_names(self):
-        for key, ops in self.name2ops.items():
+    scopes = []
+    c_ops = []
+    name2ops = {}
+    handle2names = {}
+    __cls_attributes__ = {"scopes", "c_ops", "name2ops", "handle2names"}
+
+    @classmethod
+    def clear(cls):
+        for attr in cls.__cls_attributes__:
+            getattr(cls, attr).clear()
+
+    @classmethod
+    def push_scope(cls, scope):
+        if scope is not None:
+            push_scope(scope)
+        cls.scopes.append(scope)
+
+    @classmethod
+    def pop_scope(cls):
+        scope = cls.scopes.pop()
+        if scope is not None:
+            pop_scope(scope)
+
+    @classmethod
+    def get_scope(cls):
+        return ".".join(s for s in cls.scopes if s is not None)
+
+    @classmethod
+    def gen_name(cls, x) -> str:
+        scope = cls.get_scope()
+        name = x.c_name if x.c_name else x._name
+        return scope + "." + name if len(scope) else name
+
+    @classmethod
+    def record_var_name(cls, handle, name):
+        cls.handle2names[handle] = name
+
+    @classmethod
+    def get_var_name(cls, handle):
+        return cls.handle2names.pop(handle, None)
+
+    @classmethod
+    def record_opnode(cls, op):
+        ops = cls.name2ops.get(op.name, [])
+        if op not in ops:
+            ops.append(op)
+        cls.name2ops[op.name] = ops
+
+    @classmethod
+    def remove_duplicate_names(cls):
+        for key, ops in cls.name2ops.items():
             if len(ops) == 1:
                 continue
             for i, op in enumerate(ops):

@@ -57,7 +74,4 @@ class AutoNaming:
                     continue
                 for var in op.outputs:
                     var.name = var.name.replace(key, op.name)
-        self.name2ops.clear()
-
-
-auto_naming = AutoNaming()
+        cls.name2ops.clear()
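Two details worth noting in the rewritten class. First, `clear()` now iterates an explicit `__cls_attributes__` set because `vars(cls)` would also enumerate the classmethods and docstring, not just the four mutable containers. Second, `record_opnode` de-duplicates, so recording the same op node twice is a no-op; a quick sketch (`FakeOp` is a hypothetical stand-in for a graph opnode):

```python
from megengine.utils.naming import AutoNaming

class FakeOp:  # stand-in for an opnode; only .name is needed here
    name = "ADD"

op = FakeOp()
AutoNaming.record_opnode(op)
AutoNaming.record_opnode(op)  # already recorded, not appended again
assert AutoNaming.name2ops["ADD"] == [op]
AutoNaming.clear()
```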
@@ -28,7 +28,7 @@ from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 from megengine.module import Module
 from megengine.random import normal, uniform
-from megengine.utils.naming import auto_naming
+from megengine.utils.naming import AutoNaming

 @pytest.mark.parametrize("trace_mode", [False, True])
@@ -141,7 +141,7 @@ def test_dump():
         return a + b

     # prevent from remaining scope from exception test
-    auto_naming.clear()
+    AutoNaming.clear()
     a = tensor([2])
     b = tensor([4])
     y = f(a, b).numpy()
...
@@ -18,11 +18,11 @@ from megengine import Parameter, Tensor
 from megengine.core.tensor import megbrain_graph as G
 from megengine.jit.tracing import trace
 from megengine.quantization.quantize import quantize, quantize_qat
-from megengine.utils.naming import auto_naming
+from megengine.utils.naming import AutoNaming

 def _dump_and_load(func, symbolic, keep_opr_name=True):
-    auto_naming.clear()
+    AutoNaming.clear()
     func = trace(func, symbolic=symbolic, capture_as_const=True)
     x = Tensor(np.ones(shape=(2, 3)))
     func(x).numpy()
@@ -103,6 +103,18 @@ def test_without_module(symbolic):
     assert op.name == "MUL"


+@pytest.mark.parametrize("symbolic", [False, True])
+def test_ignore_top_module(symbolic):
+    class Simple(M.Module):
+        def forward(self, x):
+            return x + x
+
+    m = Simple()
+    op = _dump_and_load(m, symbolic)[-1]
+    assert op.name == "ADD"
+    assert op.outputs[0].name == "ADD"
+
+
 @pytest.mark.parametrize("symbolic", [False, True])
 def test_with_submodule(symbolic):
     class Simple(M.Module):
@@ -196,7 +208,7 @@ def test_not_keep_opr_name():
         return 2 * x

     op = _dump_and_load(f, True, False)[-1]
-    assert op.name == "MUL(x,2[2])[4]"
+    assert op.name == "MUL(x,const<2>[2])[4]"

 @pytest.mark.parametrize("symbolic", [False, True])
...
@@ -419,7 +419,7 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) {
     if (one_elem(val.shape())) {
         float v;
         static_cast_dtype(&v, val.dtype(), val.raw_ptr());
-        m_summary = ssprintf("%.3g", v);
+        m_summary = ssprintf("const<%.3g>", v);
         if (val.shape().ndim != 1) {
             m_summary += val.shape().to_string();
         }
...
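This C++ change is the counterpart of the updated assertion in the test hunk above: single-element immutable tensors are now summarized as `const<value>` instead of the bare value, so a multiply by the constant 2 (shape [2]) dumps with the unambiguous name `MUL(x,const<2>[2])[4]` rather than `MUL(x,2[2])[4]`.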