Commit d0f70a44 authored by Megvii Engine Team

fix(imperative/utils): module parameters' names do not have scope after dumping

GitOrigin-RevId: 1497272294b125192d2356695259f622eb8d0bc5
Parent b9bbf802
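The root cause in the first hunk below is Python operator precedence: a conditional expression binds looser than `+`, so the scope prefix was only applied when `x.c_name` was set, and tensors falling back to `x._name` (module parameters) lost their scope. A minimal standalone sketch of the pitfall — the values `"backbone"` and `"weight"` are illustrative, not MegEngine APIs:

```python
scope, c_name, _name = "backbone", "", "weight"

# Before the fix: parses as (scope + "." + c_name) if c_name else _name,
# so the fallback branch returns the bare name without the scope prefix.
old = scope + "." + c_name if c_name else _name
assert old == "weight"

# After the fix: the conditional is parenthesized, so the scope prefix
# is applied on both branches.
new = scope + "." + (c_name if c_name else _name)
assert new == "backbone.weight"
```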
@@ -293,7 +293,9 @@ class trace:
         h = getattr(x, "_mixin_handle", -1)
         if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
             h, info = self._new_handle()
-            name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name
+            name = (
+                auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name)
+            )
             info.name = name
             info.external = True
             info.device = x.device
@@ -1123,11 +1125,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
     return outputs


-def apply_const_symbolic_mode(value, dtype, device):
+def apply_const_symbolic_mode(value, dtype, device, name):
     graph = active_trace._lazy_eval_graph
     # don't need to unset tracing
     # because varnode construction will ignore tracing flag
-    ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
+    ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
     if np.array(value).ndim == 0:
         setscalar(ret)
     return (ret,)
@@ -1175,7 +1177,7 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
 def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
     if active_trace._symbolic:
-        outputs = apply_const_symbolic_mode(value, dtype, device)
+        outputs = apply_const_symbolic_mode(value, dtype, device, name)
     else:
         unset_tracing()
         outputs = (RawTensor(value, dtype, device, False, name),)
......
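The two hunks above thread the user-provided `name` through the symbolic-trace constant path; previously `apply_const_with_tracing` received `name` but dropped it before it reached `graph.make_const`. A minimal sketch of the forwarding pattern, using a hypothetical stand-in for the traced graph rather than MegEngine internals:

```python
# Hypothetical stand-in for graph.make_const: records the name it receives.
def make_const(value, dtype=None, device=None, name=None):
    return {"value": value, "name": name}

# After the fix: `name` is accepted and forwarded at every hop.
def apply_const_symbolic_mode(value, dtype, device, name):
    return (make_const(value, dtype=dtype, device=device, name=name),)

def apply_const_with_tracing(value, dtype, device, name):
    return apply_const_symbolic_mode(value, dtype, device, name)

(out,) = apply_const_with_tracing(7, "int32", "cpu0", name="bias")
assert out["name"] == "bias"  # the name now survives tracing
```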
@@ -33,7 +33,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
     _q_dict = None

     def __new__(
-        cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=""
+        cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
     ):
         if device is None:
             cn = get_default_device()
......
@@ -234,7 +234,8 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         CompNode cn = tup[2].cast<CompNode>();
         bool is_const = tup[3].cast<bool>();
         bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
-        std::string name = tup[nargs - 1].cast<std::string>();
+        std::string name;
+        if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>();
         // const op
         if (is_const && is_tracing) {
......
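The C++ guard above pairs with the Python-side default changing from `""` to `None`: casting `Py_None` to `std::string` would fail, so the constructor now leaves `name` empty when no name was passed. In Python terms, the added branch behaves roughly like this sketch (`extract_name` is a hypothetical helper, not part of the codebase):

```python
# Hedged Python rendering of the added C++ branch: only cast when the
# trailing argument is not None; otherwise keep the empty default.
def extract_name(last_arg):
    return str(last_arg) if last_arg is not None else ""

assert extract_name(None) == ""   # tensor(0)           -> name == ""
assert extract_name("x") == "x"   # tensor(0, name="x") -> name == "x"
```

This is also why the new `test_name` below still expects `name == ""` for a tensor constructed without an explicit name.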
@@ -408,6 +408,15 @@ def test_copy_d2d():
     copy_test("gpu0:0", "gpu0:1")


+def test_name():
+    x = tensor(0)
+    assert x.name == ""
+    x.name = "x"
+    assert x.name == "x"
+    x = tensor(0, name="x")
+    assert x.name == "x"
+
+
 def test_q_dict():
     x = tensor(1)
     assert x.q_dict["scale"] is None
......