From a226f02e7aea666e70ffefccf0b870c5801be1c1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 2 Sep 2020 10:33:32 +0800 Subject: [PATCH] fix(mge/imperative): fix syncbn in symbolic mode GitOrigin-RevId: a9794318a7e28aa262d6047b90e39f2904c35d8c --- imperative/python/megengine/functional/nn.py | 35 +++++++++++-------- .../python/test/unit/module/test_batchnorm.py | 5 --- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index bc8f259ba..14e4fe83a 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy from .distributed import all_reduce_sum from .elemwise import exp, floor, log, log1p, maximum, minimum, relu from .math import argsort, max, sum -from .tensor import add_axis, broadcast, concat, full, remove_axis, reshape +from .tensor import add_axis, broadcast, concat, remove_axis, reshape from .types import _pair, _pair_nonzero __all__ = [ @@ -692,7 +692,7 @@ def batch_norm2d( def sync_batch_norm( - input: Tensor, + inp: Tensor, running_mean: Tensor, running_var: Tensor, weight: Optional[Tensor] = None, @@ -723,25 +723,30 @@ def sync_batch_norm( Default: 1e-5. """ assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode) - _channels = input.shape[1] - _ndim = input.ndim + _channels = inp.shape[1] + _ndim = inp.ndim + _device = inp.device + _dtype = inp.dtype _param_shape = (1, _channels) + (1,) * (_ndim - 2) + _reduce_axis = [0] + [i for i in range(2, _ndim)] if training: - def _sum_on_channel(input): - return apply(builtin.Reduce(mode="SUM"), input, Tensor(_param_shape))[0] + def _sum_on_channel(inp): + return inp.sum(axis=_reduce_axis, keepdims=True) - reduce_size = input.shape[0] + reduce_size = inp.shape[0] for i in range(2, _ndim): - reduce_size = reduce_size * input.shape[i] - channel_x1s = _sum_on_channel(input) - channel_x2s = _sum_on_channel(input ** 2) + reduce_size = reduce_size * inp.shape[i] + channel_x1s = _sum_on_channel(inp) + channel_x2s = _sum_on_channel(inp ** 2) if is_distributed(): # reduce all nodes' data to calculate mean and variance - reduce_size = full([1 for _ in range(_ndim)], reduce_size) - stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1) + reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim) + stat = concat( + [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1 + ) stat = all_reduce_sum(stat, group) reduce_size = stat[:, :1].reshape(1) channel_x1s = stat[:, 1 : 1 + _channels] @@ -775,11 +780,11 @@ def sync_batch_norm( inv_var_wt = invsqrt_channel_variance * weight neg_channel_mean = -channel_mean if bias is not None: - outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias) + outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias) else: - outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt + outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt else: - outvar = input * invsqrt_channel_variance + ( + outvar = inp * invsqrt_channel_variance + ( -channel_mean * invsqrt_channel_variance ) if bias is not None: diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py index 0f91fa4b6..213b6fc39 100644 --- a/imperative/python/test/unit/module/test_batchnorm.py +++ b/imperative/python/test/unit/module/test_batchnorm.py @@ -27,7 +27,6 @@ from megengine.test import assertTensorClose @pytest.mark.skipif( platform.system() == "Windows", reason="do not imp GPU mode at Windows now" ) -@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") @pytest.mark.isolated_distributed def test_syncbn(): nr_chan = 8 @@ -154,7 +153,6 @@ def test_batchnorm(): @pytest.mark.skipif( platform.system() == "Windows", reason="do not imp GPU mode at Windows now" ) -@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") @pytest.mark.isolated_distributed def test_syncbn1d(): nr_chan = 8 @@ -257,7 +255,6 @@ def test_batchnorm2d(): @pytest.mark.skipif( platform.system() == "Windows", reason="do not imp GPU mode at Windows now" ) -@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") @pytest.mark.isolated_distributed def test_syncbn2d(): nr_chan = 8 @@ -336,7 +333,6 @@ def test_batchnorm_no_stats(): @pytest.mark.skipif( platform.system() == "Windows", reason="do not imp GPU mode at Windows now" ) -@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") @pytest.mark.isolated_distributed def test_syncbn_no_stats(): nr_chan = 8 @@ -393,7 +389,6 @@ def test_batchnorm2d_no_stats(): @pytest.mark.skipif( platform.system() == "Windows", reason="do not imp GPU mode at Windows now" ) -@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") @pytest.mark.isolated_distributed def test_syncbn2d_no_stats(): nr_chan = 8 -- GitLab