diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index b7ff90cf5cd6519728a8b94e67bbac8a0c378735..f498460f324080c70f8d9b41b8d99d0e309ea5b0 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1385,6 +1385,11 @@ def sync_batch_norm(
         momentum=momentum,
         eps=eps,
     )
+    if amp._enabled:
+        inp, weight, bias, running_mean, running_var = cast_tensors(
+            inp, weight, bias, running_mean, running_var, promote=True
+        )
+
     _channels = make_shape_tuple(inp.shape)[1]
     _ndim = inp.ndim
     _device = inp.device
@@ -1464,7 +1469,8 @@ def sync_batch_norm(
         channel_x2s,
         channel_mean,
     )
-
+    if amp._enabled:
+        outvar = outvar.astype("float16")
     return outvar
 
 
diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py
index 901b770d114e780684ae9694d70dc0ce26ec3297..911125cc328e37780d3413e877a81e87b5bd371b 100644
--- a/imperative/python/test/unit/module/test_batchnorm.py
+++ b/imperative/python/test/unit/module/test_batchnorm.py
@@ -13,6 +13,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.amp as amp
 import megengine.distributed as dist
 from megengine import Tensor, jit
 from megengine.autodiff.grad_manager import GradManager
@@ -24,7 +25,8 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_syncbn():
+@pytest.mark.parametrize("enable_amp", [False, True])
+def test_syncbn(enable_amp):
     nr_chan = 8
     data_shape = (3, nr_chan, 4, 16)
     momentum = 0.9
@@ -38,12 +40,17 @@ def test_syncbn():
 
     @dist.launcher(n_gpus=2)
     def worker(data, yv_expect, running_mean, running_var):
-        rank = dist.get_rank()
-        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
-        for i in range(steps):
-            yv = bn(Tensor(data[rank][i]))
-
-            _assert_allclose(yv.numpy(), yv_expect[rank])
+        with amp.autocast(enabled=enable_amp):
+            rank = dist.get_rank()
+            bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+            for i in range(steps):
+                yv = bn(Tensor(data[rank][i]))
+                if enable_amp:
+                    np.testing.assert_allclose(
+                        yv.numpy(), yv_expect[rank], atol=5e-4, rtol=5e-4
+                    )
+                else:
+                    _assert_allclose(yv.numpy(), yv_expect[rank])
 
         _assert_allclose(bn.running_mean.numpy(), running_mean)
         _assert_allclose(bn.running_var.numpy(), running_var)