From b7ed0cb850276a9288c03bb696c2cc04d99cf4ad Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 Jan 2021 15:37:11 +0800 Subject: [PATCH] fix(mge/quantization): replace `_reset` with "=" in observer GitOrigin-RevId: ed6af9b98d1f68811a78f44d531c52b60339fbbe --- imperative/python/megengine/quantization/observer.py | 10 +++++----- .../python/test/unit/quantization/test_observer.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index 4862d79e..466ae918 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): **kwargs ): super().__init__(mode, eps, dtype, narrow_range, **kwargs) - self.momentum = Tensor(momentum) + self.momentum = Tensor(momentum, dtype="float32") self.runtime_momentum = Tensor(0.0) def set_momentum(self, momentum): - self.momentum._reset(momentum) + self.momentum = Tenosr(momentum, dtype="float32") def forward(self, x_orig): if self.enabled: @@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver): self.bins, ) - self.histogram._reset(new_histogram) - self.min_val._reset(new_min) - self.max_val._reset(new_max) + self.histogram = Tensor(new_histogram, dtype="float32") + self.min_val = Tensor(new_min, dtype="float32") + self.max_val = Tensor(new_max, dtype="float32") def forward(self, x_orig): self.sideeffect_forward(x_orig) diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py index e1a0a091..d643c3e1 100644 --- a/imperative/python/test/unit/quantization/test_observer.py +++ b/imperative/python/test/unit/quantization/test_observer.py @@ -8,6 +8,7 @@ import megengine.distributed as dist from megengine.distributed.helper import get_device_count_by_fork from megengine.quantization.observer import ( ExponentialMovingAverageObserver, + HistogramObserver, MinMaxObserver, Observer, PassiveObserver, @@ -44,6 +45,16 @@ def test_exponential_moving_average_observer(): np.testing.assert_allclose(m.max_val.numpy(), expected_max) +def test_histogram_observer(): + x = np.random.rand(3, 3, 3, 3).astype("float32") + np_min, np_max = x.min(), x.max() + x = mge.tensor(x) + m = HistogramObserver() + m(x) + np.testing.assert_allclose(m.min_val.numpy(), np_min) + np.testing.assert_allclose(m.max_val.numpy(), np_max) + + def test_passive_observer(): q_dict = {"scale": mge.tensor(1.0)} m = PassiveObserver(q_dict, "qint8") -- GitLab