提交 fc212042 编写于 作者: M Megvii Engine Team 提交者: wenjuan

fix(traced_module): fix Module compatible issue and traced module getattr check

GitOrigin-RevId: 62eb3bfb10e8fda942c84a6ce69acaebc85228dc
上级 275b6311
...@@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta): ...@@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook) return HookHandler(self._forward_hooks, hook)
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
AutoNaming.push_scope( AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
self.name
if self.name is not None
else (self._short_name if hasattr(self, "_short_name") else self._name)
)
for hook in self._forward_pre_hooks.values(): for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs) modified_inputs = hook(self, inputs)
if modified_inputs is not None: if modified_inputs is not None:
...@@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta): ...@@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta):
set_name(self, prefix, k, v) set_name(self, prefix, k, v)
super().__setattr__(name, value) super().__setattr__(name, value)
def __setstate__(self, state):
if "_short_name" not in state:
state["_short_name"] = state["_name"]
state["_name"] = None
self.__dict__.update(state)
def __delattr__(self, name: str): def __delattr__(self, name: str):
if name in self.__dict__ and _is_module(self.__dict__[name]): if name in self.__dict__ and _is_module(self.__dict__[name]):
modules = self.__dict__.get("_modules") modules = self.__dict__.get("_modules")
......
...@@ -50,7 +50,7 @@ class _ModuleState: ...@@ -50,7 +50,7 @@ class _ModuleState:
if self.obj is None: if self.obj is None:
typem = getattr(import_module(self.module[0]), self.module[1]) typem = getattr(import_module(self.module[0]), self.module[1])
m_obj = typem.__new__(typem) m_obj = typem.__new__(typem)
m_obj.__dict__.update(self.state) m_obj.__setstate__(self.state)
self.obj = m_obj self.obj = m_obj
return self.obj return self.obj
......
...@@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin): ...@@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin):
if isinstance(wrapped, TracedModuleBuilder): if isinstance(wrapped, TracedModuleBuilder):
if not isinstance(mod_attr, (List, Dict, QATModule)): if not isinstance(mod_attr, (List, Dict, QATModule)):
assert mod_attr is wrapped._mod assert (
else: mod_attr is wrapped._mod
), "TracedModule do not support modify module attributes, please check your code."
if isinstance(wrapped, RawTensor):
assert ( assert (
mod_attr is wrapped mod_attr is wrapped
), "TracedModule do not support modify attributes, please check your code." ), "TracedModule do not support modify tensor attributes, please check your code."
if isinstance(wrapped, (NodeMixin, RawTensor)): if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap( NodeMixin.wrap(
...@@ -2296,7 +2298,7 @@ class TracedModule(Module): ...@@ -2296,7 +2298,7 @@ class TracedModule(Module):
for k, v in state.items(): for k, v in state.items():
if isinstance(v, _ModuleState): if isinstance(v, _ModuleState):
state[k] = v.to_module() state[k] = v.to_module()
self.__dict__.update(state) super().__setstate__(state)
self._update_ref() self._update_ref()
for _, graph in self.argdef_graph_map.items(): for _, graph in self.argdef_graph_map.items():
......
...@@ -87,3 +87,17 @@ def test_compatibility(): ...@@ -87,3 +87,17 @@ def test_compatibility():
test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_1.mge")
test_old_tensor("tensor_v1_2.mge") test_old_tensor("tensor_v1_2.mge")
t = mge.tensor([1])
getattr(t, "qparams")
new_args = t.__getnewargs__()
assert (
len(new_args) == 3
and isinstance(new_args[0], np.ndarray)
and new_args[1] == np.int32
and isinstance(new_args[2], str)
), "Modify Tensor __getnewargs__ may break pickle serialization compatible"
state = t.__getstate__()
assert set(state.keys()) == set(
["qparams"]
), "Modify Tensor __getstate__ may break pickle serialization compatible"
...@@ -681,3 +681,27 @@ def test_repr_module_reset_attr(): ...@@ -681,3 +681,27 @@ def test_repr_module_reset_attr():
m1 = ResetAttrModule(False) m1 = ResetAttrModule(False)
output = [m0.__repr__(), m1.__repr__()] output = [m0.__repr__(), m1.__repr__()]
assert output == ground_truth assert output == ground_truth
def test_module_compatible():
class Empty(Module):
def forward(self):
pass
empty_module = Empty()
old_attributes = set(
[
"_modules",
"name",
"training",
"quantize_disabled",
"_forward_pre_hooks",
"_forward_hooks",
"_name",
"_short_name",
]
)
current_attributes = set(empty_module.__dict__.keys())
assert (
old_attributes == current_attributes
), "Add or delete attributes in Module class may break compatibility of pickle serialization"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册