提交 dbca3270 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix module dict

GitOrigin-RevId: d7baf00e3c1b982fe4dba2d9834f2d3097359e96
上级 01d2473c
...@@ -147,13 +147,14 @@ class _ModuleDict(Module, MutableMapping): ...@@ -147,13 +147,14 @@ class _ModuleDict(Module, MutableMapping):
def __init__(self, modules: Optional[Dict[str, Module]] = None): def __init__(self, modules: Optional[Dict[str, Module]] = None):
super().__init__() super().__init__()
self._size = 0 self._module_keys = []
if modules is not None: if modules is not None:
self.update(modules) self.update(modules)
def __delitem__(self, key): def __delitem__(self, key):
delattr(self, key) delattr(self, key)
self._size -= 1 assert key in self._module_keys
self._module_keys.remove(key)
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)
...@@ -162,22 +163,23 @@ class _ModuleDict(Module, MutableMapping): ...@@ -162,22 +163,23 @@ class _ModuleDict(Module, MutableMapping):
if not isinstance(value, Module): if not isinstance(value, Module):
raise ValueError("invalid sub-module") raise ValueError("invalid sub-module")
setattr(self, key, value) setattr(self, key, value)
self._size += 1 if key not in self._module_keys:
self._module_keys.append(key)
def __iter__(self): def __iter__(self):
return iter(self.keys()) return iter(self.keys())
def __len__(self): def __len__(self):
return self._size return len(self._module_keys)
def items(self): def items(self):
return dict(self.named_children()).items() return [(key, getattr(self, key)) for key in self._module_keys]
def values(self): def values(self):
return dict(self.named_children()).values() return [getattr(self, key) for key in self._module_keys]
def keys(self): def keys(self):
return dict(self.named_children()).keys() return self._module_keys
def forward(self): def forward(self):
raise RuntimeError("ModuleList is not callable") raise RuntimeError("ModuleList is not callable")
from collections import OrderedDict
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
...@@ -29,14 +31,19 @@ class MyModule3(M.Module): ...@@ -29,14 +31,19 @@ class MyModule3(M.Module):
self.modules = [ self.modules = [
M.Elemwise("ADD"), M.Elemwise("ADD"),
M.Elemwise("ADD"), M.Elemwise("ADD"),
{"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")}, OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
M.Elemwise("RELU"),
M.Elemwise("RELU"),
] ]
def forward(self, a, b): def forward(self, a, b):
x = self.modules[0](a, b) x = self.modules[0](a, b)
y = self.modules[1](a, b) y = self.modules[1](a, b)
y = self.modules[2]["a"](x, y) assert list(self.modules[2].keys()) == ["a", "b"]
y = self.modules[2]["b"](x, y) for _, m in self.modules[2].items():
y = m(x, y)
for m in self.modules[3:]:
y = m(y)
return y return y
...@@ -78,6 +85,7 @@ def test_trace_module(): ...@@ -78,6 +85,7 @@ def test_trace_module():
assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
assert isinstance(tm3.modules.__dict__["2"], TracedModule) assert isinstance(tm3.modules.__dict__["2"], TracedModule)
assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)
m4 = MyModule4() m4 = MyModule4()
tm4 = trace_module(m4, a, b) tm4 = trace_module(m4, a, b)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册