diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py
index 9c8a0e0da5f9f7c8609584653f68b1d3ab584c85..d8be24ee60fc17b05c558b49ef31925f99ba7596 100644
--- a/imperative/python/megengine/quantization/__init__.py
+++ b/imperative/python/megengine/quantization/__init__.py
@@ -15,6 +15,7 @@ from .qconfig import (
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    sync_ema_fakequant_qconfig,
     tqt_quant_qconfig,
 )
 from .utils import QuantMode
diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py
index 48b103c6e75e2c129c8068912a14a5966e7da332..b6239bafb8cae92d1cb97388a85602b92ca1717d 100644
--- a/imperative/python/megengine/quantization/observer.py
+++ b/imperative/python/megengine/quantization/observer.py
@@ -12,6 +12,8 @@ import numpy as np
 
 from .. import functional as F
 from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
+from ..distributed import WORLD, get_rank, is_distributed
+from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..module import Module
 from ..tensor import Tensor
 from .utils import QuantMode, Round, get_qparam_dict
@@ -123,6 +125,21 @@ class MinMaxObserver(Observer):
         return x_orig
 
 
+class SyncMinMaxObserver(MinMaxObserver):
+    def forward(self, x_orig):
+        if self.enabled:
+            x = x_orig.detach()
+            if is_distributed():
+                min_x = all_reduce_min(x.min(), WORLD)
+                max_x = all_reduce_max(x.max(), WORLD)
+            else:
+                min_x = x.min()
+                max_x = x.max()
+            self.min_val._reset(F.minimum(self.min_val, min_x))
+            self.max_val._reset(F.maximum(self.max_val, max_x))
+        return x_orig
+
+
 class ExponentialMovingAverageObserver(MinMaxObserver):
     def __init__(
         self,
@@ -157,6 +174,28 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
         return x_orig
 
 
+class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
+    def forward(self, x_orig):
+        if self.enabled:
+            x = x_orig.detach()
+            if is_distributed():
+                min_x = all_reduce_min(x.min(), WORLD)
+                max_x = all_reduce_max(x.max(), WORLD)
+            else:
+                min_x = x.min()
+                max_x = x.max()
+            self.min_val._reset(
+                self.min_val * self.runtime_momentum
+                + (1 - self.runtime_momentum) * min_x
+            )
+            self.max_val._reset(
+                self.max_val * self.runtime_momentum
+                + (1 - self.runtime_momentum) * max_x
+            )
+            self.runtime_momentum = self.momentum
+        return x_orig
+
+
 class HistogramObserver(MinMaxObserver):
     def __init__(
         self,
diff --git a/imperative/python/megengine/quantization/qconfig.py b/imperative/python/megengine/quantization/qconfig.py
index 6606c1a513be2cf3d1a766a7c044f550b6c8480d..74757c7dbf08d6bd6da711ad2346d6e2b78a0045 100644
--- a/imperative/python/megengine/quantization/qconfig.py
+++ b/imperative/python/megengine/quantization/qconfig.py
@@ -13,6 +13,8 @@ from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
     MinMaxObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
 )
 
 
@@ -92,6 +94,15 @@ ema_fakequant_qconfig = QConfig(
     act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )
 
+sync_ema_fakequant_qconfig = QConfig(
+    weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(
+        SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
+    ),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+)
+
 ema_lowbit_fakequant_qconfig = QConfig(
     weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
     act_observer=partial(
diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py
index f9bef4d2561f341ed625115f89eee8e22f88a599..bb14f2cb3884ece343dd46af8968eae8d0fd23d3 100644
--- a/imperative/python/test/unit/module/test_batchnorm.py
+++ b/imperative/python/test/unit/module/test_batchnorm.py
@@ -143,7 +143,6 @@ def test_batchnorm():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn1d():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -234,7 +233,6 @@ def test_batchnorm2d():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn2d():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -305,7 +303,6 @@ def test_batchnorm_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -354,7 +351,6 @@ def test_batchnorm2d_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn2d_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py
new file mode 100644
index 0000000000000000000000000000000000000000..497ed461e5582629b95772ee5e297227b1b5e9bf
--- /dev/null
+++ b/imperative/python/test/unit/quantization/test_observer.py
@@ -0,0 +1,52 @@
+import multiprocessing as mp
+import platform
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.distributed as dist
+import megengine.quantization.observer as ob
+from megengine.distributed.helper import get_device_count_by_fork
+
+
+def test_min_max_observer():
+    x = np.random.rand(3, 3, 3, 3).astype("float32")
+    np_min, np_max = x.min(), x.max()
+    x = mge.tensor(x)
+    m = ob.MinMaxObserver()
+    m(x)
+    assert m.min_val == np_min and m.max_val == np_max
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_sync_min_max_observer():
+    x = np.random.rand(6, 3, 3, 3).astype("float32")
+    np_min, np_max = x.min(), x.max()
+    world_size = 2
+    port = dist.get_free_ports(1)[0]
+    server = dist.Server(port)
+
+    def worker(rank, slc):
+        dist.init_process_group("localhost", port, world_size, rank, rank)
+        m = ob.SyncMinMaxObserver()
+        y = mge.tensor(x[slc])
+        m(y)
+        assert m.min_val == np_min and m.max_val == np_max
+
+    procs = []
+    for rank in range(world_size):
+        slc = slice(rank * 3, (rank + 1) * 3)
+        p = mp.Process(target=worker, args=(rank, slc,), daemon=True)
+        p.start()
+        procs.append(p)
+    for p in procs:
+        p.join(20)
+        assert p.exitcode == 0