提交 b7ed0cb8 编写于 作者: M Megvii Engine Team

fix(mge/quantization): replace `_reset` with "=" in observer

GitOrigin-RevId: ed6af9b98d1f68811a78f44d531c52b60339fbbe
上级 495b2003
......@@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.momentum = Tensor(momentum)
self.momentum = Tensor(momentum, dtype="float32")
self.runtime_momentum = Tensor(0.0)
def set_momentum(self, momentum):
self.momentum._reset(momentum)
self.momentum = Tenosr(momentum, dtype="float32")
def forward(self, x_orig):
if self.enabled:
......@@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver):
self.bins,
)
self.histogram._reset(new_histogram)
self.min_val._reset(new_min)
self.max_val._reset(new_max)
self.histogram = Tensor(new_histogram, dtype="float32")
self.min_val = Tensor(new_min, dtype="float32")
self.max_val = Tensor(new_max, dtype="float32")
def forward(self, x_orig):
self.sideeffect_forward(x_orig)
......
......@@ -8,6 +8,7 @@ import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization.observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
Observer,
PassiveObserver,
......@@ -44,6 +45,16 @@ def test_exponential_moving_average_observer():
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
def test_histogram_observer():
x = np.random.rand(3, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max()
x = mge.tensor(x)
m = HistogramObserver()
m(x)
np.testing.assert_allclose(m.min_val.numpy(), np_min)
np.testing.assert_allclose(m.max_val.numpy(), np_max)
def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册