From 206521cd3334270758da0d47542ea8748d0074ad Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 8 May 2020 17:16:44 +0800 Subject: [PATCH] fix(mge/quant): avoid updating the scale and zero point in eval mode GitOrigin-RevId: dfb08cf701aba3df09a8310f18e8d92bf9a1db21 --- python_module/megengine/module/module.py | 3 ++- python_module/megengine/quantization/observer.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 7041b93fd..8c7bf8cdb 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -294,7 +294,8 @@ class Module(metaclass=ABCMeta): self.training = mode def fn(x) -> None: - x.training = mode + if x is not self: + x.train(mode=mode) self.apply(fn) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 64e6addab..a1d9b84bc 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -56,6 +56,13 @@ class Observer(Module): def disable(self): self.enabled = False + def train(self, mode: bool = True) -> None: + super().train(mode) + if mode: + self.enable() + else: + self.disable() + @abstractmethod def forward(self, x): pass -- GitLab