diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index d91fa94ec84e418f24b926f1df0cd6a6b1b9f489..b55bdd89446a52529d8b657f99537975132ddd7f 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -291,19 +291,21 @@ class Module(metaclass=ABCMeta): if param.grad is not None: param.grad.reset_zero() - def train(self, mode: bool = True) -> None: + def train(self, mode: bool = True, recursive: bool = True) -> None: """Set training mode of all the modules within this module (including itself) to ``mode``. This effectively sets the ``training`` attributes of those modules to ``mode``, but only has effect on certain modules (e.g. - :class:`~.BatchNorm2d`, :class:`~.Dropout`) + :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) - :param mode: The training mode to be set on modules. + :param mode: the training mode to be set on modules. + :param recursive: whether to recursively call submodules' ``train()``. """ - self.training = mode + if not recursive: + self.training = mode + return - def fn(x) -> None: - if x is not self: - x.train(mode=mode) + def fn(module: Module) -> None: + module.train(mode, recursive=False) self.apply(fn) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 847cc814106316ec7d6178ab198d92f1e9f7f87a..bbd7c234f05107754a43477cde3c89a6a988944a 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -60,8 +60,8 @@ class Observer(Module): def disable(self): self.enabled = False - def train(self, mode: bool = True) -> None: - super().train(mode) + def train(self, mode: bool = True, recursive: bool = True) -> None: + super().train(mode, recursive) if mode: self.enable() else: