提交 b7ed0cb8 编写于 作者: M Megvii Engine Team

fix(mge/quantization): replace `_reset` with "=" in observer

GitOrigin-RevId: ed6af9b98d1f68811a78f44d531c52b60339fbbe
上级 495b2003
...@@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
**kwargs **kwargs
): ):
super().__init__(mode, eps, dtype, narrow_range, **kwargs) super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.momentum = Tensor(momentum) self.momentum = Tensor(momentum, dtype="float32")
self.runtime_momentum = Tensor(0.0) self.runtime_momentum = Tensor(0.0)
def set_momentum(self, momentum): def set_momentum(self, momentum):
self.momentum._reset(momentum) self.momentum = Tenosr(momentum, dtype="float32")
def forward(self, x_orig): def forward(self, x_orig):
if self.enabled: if self.enabled:
...@@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver): ...@@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver):
self.bins, self.bins,
) )
self.histogram._reset(new_histogram) self.histogram = Tensor(new_histogram, dtype="float32")
self.min_val._reset(new_min) self.min_val = Tensor(new_min, dtype="float32")
self.max_val._reset(new_max) self.max_val = Tensor(new_max, dtype="float32")
def forward(self, x_orig): def forward(self, x_orig):
self.sideeffect_forward(x_orig) self.sideeffect_forward(x_orig)
......
...@@ -8,6 +8,7 @@ import megengine.distributed as dist ...@@ -8,6 +8,7 @@ import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization.observer import ( from megengine.quantization.observer import (
ExponentialMovingAverageObserver, ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver, MinMaxObserver,
Observer, Observer,
PassiveObserver, PassiveObserver,
...@@ -44,6 +45,16 @@ def test_exponential_moving_average_observer(): ...@@ -44,6 +45,16 @@ def test_exponential_moving_average_observer():
np.testing.assert_allclose(m.max_val.numpy(), expected_max) np.testing.assert_allclose(m.max_val.numpy(), expected_max)
def test_histogram_observer():
x = np.random.rand(3, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max()
x = mge.tensor(x)
m = HistogramObserver()
m(x)
np.testing.assert_allclose(m.min_val.numpy(), np_min)
np.testing.assert_allclose(m.max_val.numpy(), np_max)
def test_passive_observer(): def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)} q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8") m = PassiveObserver(q_dict, "qint8")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册