提交 206521cd 编写于 作者: M Megvii Engine Team

fix(mge/quant): avoid updating the scale and zero point in eval mode

GitOrigin-RevId: dfb08cf701aba3df09a8310f18e8d92bf9a1db21
上级 c9986df5
......@@ -294,7 +294,8 @@ class Module(metaclass=ABCMeta):
self.training = mode
def fn(x) -> None:
x.training = mode
if x is not self:
x.train(mode=mode)
self.apply(fn)
......
......@@ -56,6 +56,13 @@ class Observer(Module):
def disable(self):
self.enabled = False
def train(self, mode: bool = True) -> None:
super().train(mode)
if mode:
self.enable()
else:
self.disable()
@abstractmethod
def forward(self, x):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册