Commit a226f02e authored by Megvii Engine Team

fix(mge/imperative): fix syncbn in symbolic mode

GitOrigin-RevId: a9794318a7e28aa262d6047b90e39f2904c35d8c
Parent 34333593
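The change replaces the shape-driven per-channel reduction (a builtin.Reduce whose target shape, Tensor(_param_shape), is built from concrete shape values) with a plain axis-based sum. The likely reason, judging only from the diff, is that under symbolic tracing the concrete shape values are not available as Python numbers, while the reduction axes depend only on the number of dimensions, which is static. A minimal NumPy sketch of the reduction the new _sum_on_channel performs (illustration only, not MegEngine code):

import numpy as np

inp = np.random.rand(4, 8, 16, 16).astype("float32")       # (N, C, H, W)
reduce_axis = tuple([0] + [i for i in range(2, inp.ndim)])  # (0, 2, 3)

# per-channel sums of x and x**2, kept broadcastable against the input
channel_x1s = inp.sum(axis=reduce_axis, keepdims=True)          # shape (1, 8, 1, 1)
channel_x2s = (inp ** 2).sum(axis=reduce_axis, keepdims=True)   # shape (1, 8, 1, 1)

reduce_size = inp.shape[0]
for i in range(2, inp.ndim):
    reduce_size = reduce_size * inp.shape[i]                    # N * H * W

channel_mean = channel_x1s / reduce_size
channel_var = channel_x2s / reduce_size - channel_mean ** 2     # E[x^2] - E[x]^2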
@@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import add_axis, broadcast, concat, full, remove_axis, reshape
+from .tensor import add_axis, broadcast, concat, remove_axis, reshape
 from .types import _pair, _pair_nonzero

 __all__ = [
@@ -692,7 +692,7 @@ def batch_norm2d(
 def sync_batch_norm(
-    input: Tensor,
+    inp: Tensor,
     running_mean: Tensor,
     running_var: Tensor,
     weight: Optional[Tensor] = None,
@@ -723,25 +723,30 @@ def sync_batch_norm(
         Default: 1e-5.
     """
     assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
-    _channels = input.shape[1]
-    _ndim = input.ndim
+    _channels = inp.shape[1]
+    _ndim = inp.ndim
+    _device = inp.device
+    _dtype = inp.dtype
     _param_shape = (1, _channels) + (1,) * (_ndim - 2)
+    _reduce_axis = [0] + [i for i in range(2, _ndim)]

     if training:

-        def _sum_on_channel(input):
-            return apply(builtin.Reduce(mode="SUM"), input, Tensor(_param_shape))[0]
+        def _sum_on_channel(inp):
+            return inp.sum(axis=_reduce_axis, keepdims=True)

-        reduce_size = input.shape[0]
+        reduce_size = inp.shape[0]
         for i in range(2, _ndim):
-            reduce_size = reduce_size * input.shape[i]
-        channel_x1s = _sum_on_channel(input)
-        channel_x2s = _sum_on_channel(input ** 2)
+            reduce_size = reduce_size * inp.shape[i]
+        channel_x1s = _sum_on_channel(inp)
+        channel_x2s = _sum_on_channel(inp ** 2)

         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            reduce_size = full([1 for _ in range(_ndim)], reduce_size)
-            stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+            reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+            stat = concat(
+                [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
+            )
             stat = all_reduce_sum(stat, group)
             reduce_size = stat[:, :1].reshape(1)
             channel_x1s = stat[:, 1 : 1 + _channels]
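In the distributed branch, each worker packs its element count and the two per-channel sums into one tensor along the channel axis, so a single all_reduce_sum aggregates all statistics; the result is then sliced back apart. A rough NumPy illustration with two hypothetical workers (the addition stands in for the collective, not the actual MegEngine call):

import numpy as np

def pack_stats(x):
    # x: (N, C, H, W) on one worker
    reduce_axis = tuple([0] + [i for i in range(2, x.ndim)])
    x1s = x.sum(axis=reduce_axis, keepdims=True)         # (1, C, 1, 1)
    x2s = (x ** 2).sum(axis=reduce_axis, keepdims=True)  # (1, C, 1, 1)
    size = np.full((1, 1, 1, 1), x.shape[0] * x.shape[2] * x.shape[3], dtype=x.dtype)
    return np.concatenate([size, x1s, x2s], axis=1)      # (1, 1 + 2*C, 1, 1)

a = np.random.rand(2, 8, 4, 4).astype("float32")   # hypothetical worker 0
b = np.random.rand(3, 8, 4, 4).astype("float32")   # hypothetical worker 1
stat = pack_stats(a) + pack_stats(b)                # stands in for all_reduce_sum

C = 8
reduce_size = stat[:, :1]
channel_x1s = stat[:, 1 : 1 + C]
channel_x2s = stat[:, 1 + C :]
channel_mean = channel_x1s / reduce_size
channel_var = channel_x2s / reduce_size - channel_mean ** 2

# matches statistics computed over the concatenated global batch
full_batch = np.concatenate([a, b], axis=0)
assert np.allclose(channel_mean.squeeze(), full_batch.mean(axis=(0, 2, 3)), atol=1e-5)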
@@ -775,11 +780,11 @@ def sync_batch_norm(
         inv_var_wt = invsqrt_channel_variance * weight
         neg_channel_mean = -channel_mean
         if bias is not None:
-            outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
+            outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
         else:
-            outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
+            outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
     else:
-        outvar = input * invsqrt_channel_variance + (
+        outvar = inp * invsqrt_channel_variance + (
             -channel_mean * invsqrt_channel_variance
         )
         if bias is not None:
......
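The output path is the standard batch-norm affine transform, y = (x - mean) * invstd * weight + bias, regrouped so that the per-channel constants are folded together before being applied to the full tensor: x * (invstd * weight) + (-mean * invstd * weight + bias). A quick NumPy check of that identity (illustrative only):

import numpy as np

x = np.random.rand(2, 8, 4, 4).astype("float32")
mean = x.mean(axis=(0, 2, 3), keepdims=True)
invstd = 1.0 / np.sqrt(x.var(axis=(0, 2, 3), keepdims=True) + 1e-5)
weight = np.random.rand(1, 8, 1, 1).astype("float32")
bias = np.random.rand(1, 8, 1, 1).astype("float32")

reference = (x - mean) * invstd * weight + bias
inv_var_wt = invstd * weight                              # per-channel scale
regrouped = x * inv_var_wt + (-mean * inv_var_wt + bias)  # same result, fewer full-tensor ops

assert np.allclose(reference, regrouped, atol=1e-5)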
@@ -27,7 +27,6 @@ from megengine.test import assertTensorClose
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn():
     nr_chan = 8

@@ -154,7 +153,6 @@ def test_batchnorm():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn1d():
     nr_chan = 8

@@ -257,7 +255,6 @@ def test_batchnorm2d():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn2d():
     nr_chan = 8

@@ -336,7 +333,6 @@ def test_batchnorm_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn_no_stats():
     nr_chan = 8

@@ -393,7 +389,6 @@ def test_batchnorm2d_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn2d_no_stats():
     nr_chan = 8
......