Commit 33e8879a authored by Megvii Engine Team

feat(mge/quantization): support distributed qat

GitOrigin-RevId: c915c843b865d2c462fbdeb09ec525cd2a73ad90
Parent 8e11204a
......@@ -15,6 +15,7 @@ from .qconfig import (
    ema_fakequant_qconfig,
    ema_lowbit_fakequant_qconfig,
    min_max_fakequant_qconfig,
    sync_ema_fakequant_qconfig,
    tqt_quant_qconfig,
)
from .utils import QuantMode
......@@ -12,6 +12,8 @@ import numpy as np
from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..module import Module
from ..tensor import Tensor
from .utils import QuantMode, Round, get_qparam_dict
......@@ -123,6 +125,21 @@ class MinMaxObserver(Observer):
        return x_orig


class SyncMinMaxObserver(MinMaxObserver):
    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            # Reduce the per-rank extrema over the WORLD group so every
            # rank tracks identical statistics (and thus identical scales).
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val._reset(F.minimum(self.min_val, min_x))
            self.max_val._reset(F.maximum(self.max_val, max_x))
        return x_orig
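The point of the subclass is the all-reduce: each rank's local extrema are reduced over the WORLD group before they touch the running statistics, so every rank derives the same scale and zero point. Below is a minimal single-process sketch of the reduction semantics (plain numpy stand-in, no real communication; shapes and values are illustrative):

```python
# What all_reduce_min / all_reduce_max compute for SyncMinMaxObserver,
# simulated without communication: each "rank" holds a local batch, and
# every rank ends up with the global extrema over all of them.
import numpy as np

rank_batches = [np.random.rand(3, 3).astype("float32") for _ in range(2)]
local_mins = [b.min() for b in rank_batches]  # per-rank x.min()
local_maxs = [b.max() for b in rank_batches]  # per-rank x.max()

global_min = min(local_mins)  # all_reduce_min(x.min(), WORLD) on any rank
global_max = max(local_maxs)  # all_reduce_max(x.max(), WORLD) on any rank

# Every rank folds the same global stats into min_val / max_val, so the
# derived quantization range is identical across the group.
assert all(global_min <= m for m in local_mins)
assert all(global_max >= m for m in local_maxs)
```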


class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(
        self,
......@@ -157,6 +174,28 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
        return x_orig


class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            # All-reduce first, so the EMA is driven by the global batch
            # extrema rather than this rank's shard.
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val._reset(
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * min_x
            )
            self.max_val._reset(
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * max_x
            )
            self.runtime_momentum = self.momentum
        return x_orig
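The EMA update itself is unchanged from the non-sync observer; only its inputs are the globally reduced extrema. One worked update step with made-up numbers follows (note that `runtime_momentum` is only bumped to `momentum` at the end of forward; assuming it starts at 0 as in the parent class, the first batch's statistics are copied verbatim):

```python
# One EMA step as in the forward above, with illustrative numbers.
# min_x / max_x play the role of the globally reduced batch extrema.
momentum = 0.9
min_val, max_val = 0.0, 1.0  # running statistics
min_x, max_x = -0.2, 1.3     # reduced batch extrema

min_val = min_val * momentum + (1 - momentum) * min_x  # -> -0.02
max_val = max_val * momentum + (1 - momentum) * max_x  # ->  1.03
```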


class HistogramObserver(MinMaxObserver):
    def __init__(
        self,
......
......@@ -13,6 +13,8 @@ from .observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)
......@@ -92,6 +94,15 @@ ema_fakequant_qconfig = QConfig(
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)

sync_ema_fakequant_qconfig = QConfig(
    weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
    act_observer=partial(
        SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
    ),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
ema_lowbit_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
    act_observer=partial(
......
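How `sync_ema_fakequant_qconfig` would be consumed is not part of this diff; the following is a hedged sketch assuming MegEngine's standard `quantize_qat` entry point and a toy float module:

```python
# Hypothetical usage sketch (not from this commit): convert a float module
# to QAT with the synchronized observers, so all ranks share one range.
import megengine.module as M
from megengine.quantization import quantize_qat
from megengine.quantization.qconfig import sync_ema_fakequant_qconfig

net = M.Sequential(M.Conv2d(3, 8, 3), M.ReLU())
qat_net = quantize_qat(net, qconfig=sync_ema_fakequant_qconfig)
```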
......@@ -143,7 +143,6 @@ def test_batchnorm():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn1d():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
......@@ -234,7 +233,6 @@ def test_batchnorm2d():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn2d():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
......@@ -305,7 +303,6 @@ def test_batchnorm_no_stats():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 4)
......@@ -354,7 +351,6 @@ def test_batchnorm2d_no_stats():
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.isolated_distributed
def test_syncbn2d_no_stats():
    nr_chan = 8
    data_shape = (3, nr_chan, 16, 16)
......
import multiprocessing as mp
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.quantization.observer as ob
from megengine.distributed.helper import get_device_count_by_fork


def test_min_max_observer():
    x = np.random.rand(3, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()
    x = mge.tensor(x)
    m = ob.MinMaxObserver()
    m(x)
    assert m.min_val == np_min and m.max_val == np_max
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_sync_min_max_observer():
    x = np.random.rand(6, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()
    world_size = 2
    port = dist.get_free_ports(1)[0]
    server = dist.Server(port)

    def worker(rank, slc):
        # Each rank observes only its slice of x; the all-reduce inside
        # SyncMinMaxObserver should still surface the global min/max.
        dist.init_process_group("localhost", port, world_size, rank, rank)
        m = ob.SyncMinMaxObserver()
        y = mge.tensor(x[slc])
        m(y)
        assert m.min_val == np_min and m.max_val == np_max

    procs = []
    for rank in range(world_size):
        slc = slice(rank * 3, (rank + 1) * 3)
        p = mp.Process(target=worker, args=(rank, slc), daemon=True)
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0
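The commit only tests SyncMinMaxObserver. A hypothetical companion check for the EMA variant could exercise its single-process fallback (is_distributed() returning False); this sketch assumes `runtime_momentum` starts at 0 in the parent ExponentialMovingAverageObserver, so the first forward reproduces the batch extrema exactly:

```python
def test_sync_ema_observer_single_proc():
    # Hypothetical test, not in this commit. Without a process group the
    # observer falls back to local min/max; with runtime_momentum assumed
    # to start at 0, the first step yields exactly the batch extrema.
    x = np.random.rand(3, 3, 3, 3).astype("float32")
    m = ob.SyncExponentialMovingAverageObserver()
    m(mge.tensor(x))
    np.testing.assert_allclose(m.min_val.numpy(), x.min(), rtol=1e-6)
    np.testing.assert_allclose(m.max_val.numpy(), x.max(), rtol=1e-6)
```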