提交 38f7cbd9 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/module): fix redundant recursion in `train()`

GitOrigin-RevId: 6b3566930b72c56b571debea1e1901745b0e7cdc
上级 5c232352
...@@ -291,19 +291,21 @@ class Module(metaclass=ABCMeta): ...@@ -291,19 +291,21 @@ class Module(metaclass=ABCMeta):
if param.grad is not None: if param.grad is not None:
param.grad.reset_zero() param.grad.reset_zero()
def train(self, mode: bool = True) -> None: def train(self, mode: bool = True, recursive: bool = True) -> None:
"""Set training mode of all the modules within this module (including itself) to """Set training mode of all the modules within this module (including itself) to
``mode``. This effectively sets the ``training`` attributes of those modules ``mode``. This effectively sets the ``training`` attributes of those modules
to ``mode``, but only has effect on certain modules (e.g. to ``mode``, but only has effect on certain modules (e.g.
:class:`~.BatchNorm2d`, :class:`~.Dropout`) :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`)
:param mode: The training mode to be set on modules. :param mode: the training mode to be set on modules.
:param recursive: whether to recursively call submodules' ``train()``.
""" """
if not recursive:
self.training = mode self.training = mode
return
def fn(x) -> None: def fn(module: Module) -> None:
if x is not self: module.train(mode, recursive=False)
x.train(mode=mode)
self.apply(fn) self.apply(fn)
......
...@@ -60,8 +60,8 @@ class Observer(Module): ...@@ -60,8 +60,8 @@ class Observer(Module):
def disable(self): def disable(self):
self.enabled = False self.enabled = False
def train(self, mode: bool = True) -> None: def train(self, mode: bool = True, recursive: bool = True) -> None:
super().train(mode) super().train(mode, recursive)
if mode: if mode:
self.enable() self.enable()
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册