diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 9d3ea762e3acb24afe33bd471e9f3d743bb4109d..b853aeecbe591a18ff417fcac56f87f128b4dede 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta): return HookHandler(self._forward_hooks, hook) def __call__(self, *inputs, **kwargs): - AutoNaming.push_scope( - self.name - if self.name is not None - else (self._short_name if hasattr(self, "_short_name") else self._name) - ) + AutoNaming.push_scope(self.name if self.name is not None else self._short_name) for hook in self._forward_pre_hooks.values(): modified_inputs = hook(self, inputs) if modified_inputs is not None: @@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta): set_name(self, prefix, k, v) super().__setattr__(name, value) + def __setstate__(self, state): + if "_short_name" not in state: + state["_short_name"] = state["_name"] + state["_name"] = None + self.__dict__.update(state) + def __delattr__(self, name: str): if name in self.__dict__ and _is_module(self.__dict__[name]): modules = self.__dict__.get("_modules") diff --git a/imperative/python/megengine/traced_module/serialization.py b/imperative/python/megengine/traced_module/serialization.py index 7762a40e93598edbe79c6eb64c01b94c4b767d5c..d86854e6ef7e28d93e8e2f8432b7766335cc7a7a 100644 --- a/imperative/python/megengine/traced_module/serialization.py +++ b/imperative/python/megengine/traced_module/serialization.py @@ -50,7 +50,7 @@ class _ModuleState: if self.obj is None: typem = getattr(import_module(self.module[0]), self.module[1]) m_obj = typem.__new__(typem) - m_obj.__dict__.update(self.state) + m_obj.__setstate__(self.state) self.obj = m_obj return self.obj diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 10798776000399c99b20c829e633cc929a1ee063..6f50b31369f6bbd85d2f8520c5ee6d9971baad32 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin): if isinstance(wrapped, TracedModuleBuilder): if not isinstance(mod_attr, (List, Dict, QATModule)): - assert mod_attr is wrapped._mod - else: + assert ( + mod_attr is wrapped._mod + ), "TracedModule do not support modify module attributes, please check your code." + if isinstance(wrapped, RawTensor): assert ( mod_attr is wrapped - ), "TracedModule do not support modify attributes, please check your code." + ), "TracedModule do not support modify tensor attributes, please check your code." if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( @@ -2296,7 +2298,7 @@ class TracedModule(Module): for k, v in state.items(): if isinstance(v, _ModuleState): state[k] = v.to_module() - self.__dict__.update(state) + super().__setstate__(state) self._update_ref() for _, graph in self.argdef_graph_map.items(): diff --git a/imperative/python/test/unit/core/test_serialization.py b/imperative/python/test/unit/core/test_serialization.py index 15f47eb8394e391d07f3eadc54fd8c80517f11a3..509be51f85313523e24bd6da753b41b3e12ee7af 100644 --- a/imperative/python/test/unit/core/test_serialization.py +++ b/imperative/python/test/unit/core/test_serialization.py @@ -87,3 +87,17 @@ def test_compatibility(): test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_2.mge") + + t = mge.tensor([1]) + getattr(t, "qparams") + new_args = t.__getnewargs__() + assert ( + len(new_args) == 3 + and isinstance(new_args[0], np.ndarray) + and new_args[1] == np.int32 + and isinstance(new_args[2], str) + ), "Modify Tensor __getnewargs__ may break pickle serialization compatible" + state = t.__getstate__() + assert set(state.keys()) == set( + ["qparams"] + ), "Modify Tensor __getstate__ may break pickle serialization compatible" diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index c573e224d79be6109ff6df9a00279178766bdead..095bf6e82182c2082e3a6e1de33cb7c348dc748b 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -681,3 +681,27 @@ def test_repr_module_reset_attr(): m1 = ResetAttrModule(False) output = [m0.__repr__(), m1.__repr__()] assert output == ground_truth + + +def test_module_compatible(): + class Empty(Module): + def forward(self): + pass + + empty_module = Empty() + old_attributes = set( + [ + "_modules", + "name", + "training", + "quantize_disabled", + "_forward_pre_hooks", + "_forward_hooks", + "_name", + "_short_name", + ] + ) + current_attributes = set(empty_module.__dict__.keys()) + assert ( + old_attributes == current_attributes + ), "Add or delete attributes in Module class may break compatibility of pickle serialization"