diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index 197c602ef3860bcdd689bc9a8d735115607c3eed..f9c3c46e59a0144bd5678be96e10ce0c4e0ce2e7 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -147,13 +147,14 @@ class _ModuleDict(Module, MutableMapping): def __init__(self, modules: Optional[Dict[str, Module]] = None): super().__init__() - self._size = 0 + self._module_keys = [] if modules is not None: self.update(modules) def __delitem__(self, key): delattr(self, key) - self._size -= 1 + assert key in self._module_keys + self._module_keys.remove(key) def __getitem__(self, key): return getattr(self, key) @@ -162,22 +163,23 @@ class _ModuleDict(Module, MutableMapping): if not isinstance(value, Module): raise ValueError("invalid sub-module") setattr(self, key, value) - self._size += 1 + if key not in self._module_keys: + self._module_keys.append(key) def __iter__(self): return iter(self.keys()) def __len__(self): - return self._size + return len(self._module_keys) def items(self): - return dict(self.named_children()).items() + return [(key, getattr(self, key)) for key in self._module_keys] def values(self): - return dict(self.named_children()).values() + return [getattr(self, key) for key in self._module_keys] def keys(self): - return dict(self.named_children()).keys() + return self._module_keys def forward(self): raise RuntimeError("ModuleList is not callable") diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index 49432d66c7c3a9963205a0b44e68a5d467b92d7d..d18d08351240d7995e7b0dc512dee9f563b8b227 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import numpy as np import megengine.functional as F @@ -29,14 +31,19 @@ class MyModule3(M.Module): self.modules = [ M.Elemwise("ADD"), M.Elemwise("ADD"), - {"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")}, + OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]), + M.Elemwise("RELU"), + M.Elemwise("RELU"), ] def forward(self, a, b): x = self.modules[0](a, b) y = self.modules[1](a, b) - y = self.modules[2]["a"](x, y) - y = self.modules[2]["b"](x, y) + assert list(self.modules[2].keys()) == ["a", "b"] + for _, m in self.modules[2].items(): + y = m(x, y) + for m in self.modules[3:]: + y = m(y) return y @@ -78,6 +85,7 @@ def test_trace_module(): assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) assert isinstance(tm3.modules.__dict__["2"], TracedModule) assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) + assert isinstance(tm3.modules.__dict__["3"], M.Elemwise) m4 = MyModule4() tm4 = trace_module(m4, a, b)