From 38f7cbd9aa31d5ea3ecae430033719b4e6785c00 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 12 Jun 2020 12:53:52 +0800 Subject: [PATCH] fix(mge/module): fix redundant recursion in `train()` GitOrigin-RevId: 6b3566930b72c56b571debea1e1901745b0e7cdc --- python_module/megengine/module/module.py | 16 +++++++++------- python_module/megengine/quantization/observer.py | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index d91fa94ec..b55bdd894 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -291,19 +291,21 @@ class Module(metaclass=ABCMeta): if param.grad is not None: param.grad.reset_zero() - def train(self, mode: bool = True) -> None: + def train(self, mode: bool = True, recursive: bool = True) -> None: """Set training mode of all the modules within this module (including itself) to ``mode``. This effectively sets the ``training`` attributes of those modules to ``mode``, but only has effect on certain modules (e.g. - :class:`~.BatchNorm2d`, :class:`~.Dropout`) + :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) - :param mode: The training mode to be set on modules. + :param mode: the training mode to be set on modules. + :param recursive: whether to recursively call submodules' ``train()``. """ - self.training = mode + if not recursive: + self.training = mode + return - def fn(x) -> None: - if x is not self: - x.train(mode=mode) + def fn(module: Module) -> None: + module.train(mode, recursive=False) self.apply(fn) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 847cc8141..bbd7c234f 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -60,8 +60,8 @@ class Observer(Module): def disable(self): self.enabled = False - def train(self, mode: bool = True) -> None: - super().train(mode) + def train(self, mode: bool = True, recursive: bool = True) -> None: + super().train(mode, recursive) if mode: self.enable() else: -- GitLab