From 84d99d1cc4b074b5eaf1ef1fbd679a845fbcfaf9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 27 Jan 2022 17:09:43 +0800 Subject: [PATCH] fix(traced_module): fix Module compatible issue and traced module getattr check GitOrigin-RevId: 62eb3bfb10e8fda942c84a6ce69acaebc85228dc --- imperative/python/megengine/module/module.py | 12 ++++++---- .../megengine/traced_module/serialization.py | 2 +- .../megengine/traced_module/traced_module.py | 10 ++++---- .../test/unit/core/test_serialization.py | 14 +++++++++++ .../python/test/unit/module/test_module.py | 24 +++++++++++++++++++ 5 files changed, 52 insertions(+), 10 deletions(-) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 9d3ea762e..b853aeecb 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta): return HookHandler(self._forward_hooks, hook) def __call__(self, *inputs, **kwargs): - AutoNaming.push_scope( - self.name - if self.name is not None - else (self._short_name if hasattr(self, "_short_name") else self._name) - ) + AutoNaming.push_scope(self.name if self.name is not None else self._short_name) for hook in self._forward_pre_hooks.values(): modified_inputs = hook(self, inputs) if modified_inputs is not None: @@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta): set_name(self, prefix, k, v) super().__setattr__(name, value) + def __setstate__(self, state): + if "_short_name" not in state: + state["_short_name"] = state["_name"] + state["_name"] = None + self.__dict__.update(state) + def __delattr__(self, name: str): if name in self.__dict__ and _is_module(self.__dict__[name]): modules = self.__dict__.get("_modules") diff --git a/imperative/python/megengine/traced_module/serialization.py b/imperative/python/megengine/traced_module/serialization.py index 7762a40e9..d86854e6e 100644 --- a/imperative/python/megengine/traced_module/serialization.py +++ b/imperative/python/megengine/traced_module/serialization.py @@ -50,7 +50,7 @@ class _ModuleState: if self.obj is None: typem = getattr(import_module(self.module[0]), self.module[1]) m_obj = typem.__new__(typem) - m_obj.__dict__.update(self.state) + m_obj.__setstate__(self.state) self.obj = m_obj return self.obj diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 107987760..6f50b3136 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin): if isinstance(wrapped, TracedModuleBuilder): if not isinstance(mod_attr, (List, Dict, QATModule)): - assert mod_attr is wrapped._mod - else: + assert ( + mod_attr is wrapped._mod + ), "TracedModule do not support modify module attributes, please check your code." + if isinstance(wrapped, RawTensor): assert ( mod_attr is wrapped - ), "TracedModule do not support modify attributes, please check your code." + ), "TracedModule do not support modify tensor attributes, please check your code." if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( @@ -2296,7 +2298,7 @@ class TracedModule(Module): for k, v in state.items(): if isinstance(v, _ModuleState): state[k] = v.to_module() - self.__dict__.update(state) + super().__setstate__(state) self._update_ref() for _, graph in self.argdef_graph_map.items(): diff --git a/imperative/python/test/unit/core/test_serialization.py b/imperative/python/test/unit/core/test_serialization.py index 15f47eb83..509be51f8 100644 --- a/imperative/python/test/unit/core/test_serialization.py +++ b/imperative/python/test/unit/core/test_serialization.py @@ -87,3 +87,17 @@ def test_compatibility(): test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_2.mge") + + t = mge.tensor([1]) + getattr(t, "qparams") + new_args = t.__getnewargs__() + assert ( + len(new_args) == 3 + and isinstance(new_args[0], np.ndarray) + and new_args[1] == np.int32 + and isinstance(new_args[2], str) + ), "Modify Tensor __getnewargs__ may break pickle serialization compatible" + state = t.__getstate__() + assert set(state.keys()) == set( + ["qparams"] + ), "Modify Tensor __getstate__ may break pickle serialization compatible" diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index c573e224d..095bf6e82 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -681,3 +681,27 @@ def test_repr_module_reset_attr(): m1 = ResetAttrModule(False) output = [m0.__repr__(), m1.__repr__()] assert output == ground_truth + + +def test_module_compatible(): + class Empty(Module): + def forward(self): + pass + + empty_module = Empty() + old_attributes = set( + [ + "_modules", + "name", + "training", + "quantize_disabled", + "_forward_pre_hooks", + "_forward_hooks", + "_name", + "_short_name", + ] + ) + current_attributes = set(empty_module.__dict__.keys()) + assert ( + old_attributes == current_attributes + ), "Add or delete attributes in Module class may break compatibility of pickle serialization" -- GitLab