diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index cd0af84b85be5c99c495b436579c4817e1945181..1569e2e359b5c26f0517bad46600f639f3efcddc 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -614,22 +614,26 @@ class Module(metaclass=ABCMeta): return value def __setattr__(self, name: str, value): - if _is_module(value) or ( - isinstance(value, (list, tuple, dict)) and name != "_modules" - ): + is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict)) + if name != "_modules": modules = self.__dict__.get("_modules") - if modules is None: + if modules is None and is_module_like: raise AttributeError( "cannot assign module before Module.__init__() call" ) - if name not in self.__dict__: - modules.append(name) + if is_module_like: + if name not in modules: + modules.append(name) + else: + if modules is not None and name in modules: + modules.remove(name) super().__setattr__(name, value) def __delattr__(self, name: str): if name in self.__dict__ and _is_module(self.__dict__[name]): modules = self.__dict__.get("_modules") - modules.remove(name) + if name in modules: + modules.remove(name) super().__delattr__(name) def _module_info_string(self) -> str: