Commit 406115db authored by Megvii Engine Team

fix(imperative): syncbn fp16 support

GitOrigin-RevId: 6059d5b76b480b65d4088573e9da1f5f608c4ae2
Parent d5ef7923
@@ -1385,6 +1385,11 @@ def sync_batch_norm(
momentum=momentum,
eps=eps,
)
+    if amp._enabled:
+        inp, weight, bias, running_mean, running_var = cast_tensors(
+            inp, weight, bias, running_mean, running_var, promote=True
+        )
_channels = make_shape_tuple(inp.shape)[1]
_ndim = inp.ndim
_device = inp.device
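For context on the hunk above: `cast_tensors(..., promote=True)` is called to upcast any fp16 inputs to fp32 before the batch statistics are computed. A minimal NumPy stand-in for that assumed semantics (`cast_tensors_sketch` is illustrative, not the MegEngine helper itself):

```python
import numpy as np

# Assumed behavior of cast_tensors at this call site: with promote=True,
# upcast every non-None input to fp32 so the mean/variance reductions run
# in full precision; without it, cast down to fp16.
def cast_tensors_sketch(*tensors, promote=False):
    target = np.float32 if promote else np.float16
    return tuple(None if t is None else t.astype(target) for t in tensors)

x = np.ones((2, 8, 4, 4), np.float16)
(x32,) = cast_tensors_sketch(x, promote=True)
assert x32.dtype == np.float32
```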
@@ -1464,7 +1469,8 @@ def sync_batch_norm(
channel_x2s,
channel_mean,
)
+    if amp._enabled:
+        outvar = outvar.astype("float16")
return outvar
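Taken together, the two hunks follow the standard autocast pattern: promote fp16 inputs to fp32, run the numerically sensitive normalization in fp32, then cast the result back to fp16 for the surrounding amp region. A self-contained sketch of that flow in NumPy (`sync_batch_norm_sketch` and `amp_enabled` are illustrative stand-ins, not the actual implementation):

```python
import numpy as np

def sync_batch_norm_sketch(inp, weight, bias, eps=1e-5, amp_enabled=True):
    if amp_enabled:
        # Promote: fp16 accumulation of mean/variance loses precision,
        # so compute the statistics in fp32.
        inp = inp.astype(np.float32)
        weight = weight.astype(np.float32)
        bias = bias.astype(np.float32)
    mean = inp.mean(axis=(0, 2, 3), keepdims=True)
    var = inp.var(axis=(0, 2, 3), keepdims=True)
    out = (inp - mean) / np.sqrt(var + eps) * weight + bias
    if amp_enabled:
        # Demote: hand an fp16 result back to the autocast region,
        # mirroring outvar.astype("float16") in the hunk above.
        out = out.astype(np.float16)
    return out

x = np.random.randn(2, 4, 8, 8).astype(np.float16)
w = np.ones((1, 4, 1, 1), np.float16)
b = np.zeros((1, 4, 1, 1), np.float16)
assert sync_batch_norm_sketch(x, w, b).dtype == np.float16
```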
@@ -13,6 +13,7 @@ import numpy as np
import pytest
import megengine as mge
+import megengine.amp as amp
import megengine.distributed as dist
from megengine import Tensor, jit
from megengine.autodiff.grad_manager import GradManager
@@ -24,7 +25,8 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
-def test_syncbn():
+@pytest.mark.parametrize("enable_amp", [False, True])
+def test_syncbn(enable_amp):
nr_chan = 8
data_shape = (3, nr_chan, 4, 16)
momentum = 0.9
@@ -38,12 +40,17 @@ def test_syncbn():
@dist.launcher(n_gpus=2)
def worker(data, yv_expect, running_mean, running_var):
-        rank = dist.get_rank()
-        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
-        for i in range(steps):
-            yv = bn(Tensor(data[rank][i]))
-            _assert_allclose(yv.numpy(), yv_expect[rank])
+        with amp.autocast(enabled=enable_amp):
+            rank = dist.get_rank()
+            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+            for i in range(steps):
+                yv = bn(Tensor(data[rank][i]))
+                if enable_amp:
+                    np.testing.assert_allclose(
+                        yv.numpy(), yv_expect[rank], atol=5e-4, rtol=5e-4
+                    )
+                else:
+                    _assert_allclose(yv.numpy(), yv_expect[rank])
_assert_allclose(bn.running_mean.numpy(), running_mean)
_assert_allclose(bn.running_var.numpy(), running_var)
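For reference, the single-device shape of the new amp test path; `amp.autocast` and `SyncBatchNorm` are the public APIs the test exercises (assumption: without a distributed launcher, SyncBatchNorm degenerates to ordinary batch norm on one device):

```python
import numpy as np
import megengine.amp as amp
from megengine import Tensor
from megengine.module import SyncBatchNorm

bn = SyncBatchNorm(8, momentum=0.9, eps=1e-5)
data = np.random.normal(size=(3, 8, 4, 16)).astype("float32")
with amp.autocast(enabled=True):
    # With this fix, sync_batch_norm promotes its inputs to fp32
    # internally and returns an fp16 result, which is why the test
    # loosens the tolerances to atol=rtol=5e-4 under amp.
    out = bn(Tensor(data))
```

The looser 5e-4 tolerances reflect fp16's roughly 1e-3 relative precision, versus the 5e-6 used for the fp32 path.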