提交 d0f70a44 编写于 作者: M Megvii Engine Team

fix(imperative/utils): module parameters' name do not have scope after dumping

GitOrigin-RevId: 1497272294b125192d2356695259f622eb8d0bc5
上级 b9bbf802
......@@ -293,7 +293,9 @@ class trace:
h = getattr(x, "_mixin_handle", -1)
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle()
name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name
name = (
auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name)
)
info.name = name
info.external = True
info.device = x.device
......@@ -1123,11 +1125,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
return outputs
def apply_const_symbolic_mode(value, dtype, device):
def apply_const_symbolic_mode(value, dtype, device, name):
graph = active_trace._lazy_eval_graph
# don't need to unset tracing
# because varnode construction will ignore tracing flag
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
if np.array(value).ndim == 0:
setscalar(ret)
return (ret,)
......@@ -1175,7 +1177,7 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
if active_trace._symbolic:
outputs = apply_const_symbolic_mode(value, dtype, device)
outputs = apply_const_symbolic_mode(value, dtype, device, name)
else:
unset_tracing()
outputs = (RawTensor(value, dtype, device, False, name),)
......
......@@ -33,7 +33,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
_q_dict = None
def __new__(
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=""
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
):
if device is None:
cn = get_default_device()
......
......@@ -234,7 +234,8 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
CompNode cn = tup[2].cast<CompNode>();
bool is_const = tup[3].cast<bool>();
bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
std::string name = tup[nargs - 1].cast<std::string>();
std::string name;
if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>();
// const op
if (is_const && is_tracing) {
......
......@@ -408,6 +408,15 @@ def test_copy_d2d():
copy_test("gpu0:0", "gpu0:1")
def test_name():
x = tensor(0)
assert x.name == ""
x.name = "x"
assert x.name == "x"
x = tensor(0, name="x")
assert x.name == "x"
def test_q_dict():
x = tensor(1)
assert x.q_dict["scale"] is None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册