提交 a226f02e 作者: Megvii Engine Team

fix(mge/imperative): fix syncbn in symbolic mode

GitOrigin-RevId: a9794318a7e28aa262d6047b90e39f2904c35d8c
上级 34333593
......@@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import add_axis, broadcast, concat, full, remove_axis, reshape
+from .tensor import add_axis, broadcast, concat, remove_axis, reshape
 from .types import _pair, _pair_nonzero
 __all__ = [
......@@ -692,7 +692,7 @@ def batch_norm2d(
 def sync_batch_norm(
-input: Tensor,
+inp: Tensor,
 running_mean: Tensor,
 running_var: Tensor,
 weight: Optional[Tensor] = None,
......@@ -723,25 +723,30 @@ def sync_batch_norm(
 Default: 1e-5.
 """
 assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
-_channels = input.shape[1]
-_ndim = input.ndim
+_channels = inp.shape[1]
+_ndim = inp.ndim
+_device = inp.device
+_dtype = inp.dtype
 _param_shape = (1, _channels) + (1,) * (_ndim - 2)
+_reduce_axis = [0] + [i for i in range(2, _ndim)]
 if training:
-def _sum_on_channel(input):
-return apply(builtin.Reduce(mode="SUM"), input, Tensor(_param_shape))[0]
+def _sum_on_channel(inp):
+return inp.sum(axis=_reduce_axis, keepdims=True)
-reduce_size = input.shape[0]
+reduce_size = inp.shape[0]
 for i in range(2, _ndim):
-reduce_size = reduce_size * input.shape[i]
-channel_x1s = _sum_on_channel(input)
-channel_x2s = _sum_on_channel(input ** 2)
+reduce_size = reduce_size * inp.shape[i]
+channel_x1s = _sum_on_channel(inp)
+channel_x2s = _sum_on_channel(inp ** 2)
 if is_distributed():
 # reduce all nodes' data to calculate mean and variance
-reduce_size = full([1 for _ in range(_ndim)], reduce_size)
-stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+stat = concat(
+[reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
+)
 stat = all_reduce_sum(stat, group)
 reduce_size = stat[:, :1].reshape(1)
 channel_x1s = stat[:, 1 : 1 + _channels]
......@@ -775,11 +780,11 @@ def sync_batch_norm(
 inv_var_wt = invsqrt_channel_variance * weight
 neg_channel_mean = -channel_mean
 if bias is not None:
-outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
+outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
 else:
-outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
+outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
 else:
-outvar = input * invsqrt_channel_variance + (
+outvar = inp * invsqrt_channel_variance + (
 -channel_mean * invsqrt_channel_variance
 )
 if bias is not None:
......
......@@ -27,7 +27,6 @@ from megengine.test import assertTensorClose
 @pytest.mark.skipif(
 platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn():
 nr_chan = 8
......@@ -154,7 +153,6 @@ def test_batchnorm():
 @pytest.mark.skipif(
 platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn1d():
 nr_chan = 8
......@@ -257,7 +255,6 @@ def test_batchnorm2d():
 @pytest.mark.skipif(
 platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn2d():
 nr_chan = 8
......@@ -336,7 +333,6 @@ def test_batchnorm_no_stats():
 @pytest.mark.skipif(
 platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn_no_stats():
 nr_chan = 8
......@@ -393,7 +389,6 @@ def test_batchnorm2d_no_stats():
 @pytest.mark.skipif(
 platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
 )
-@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
 @pytest.mark.isolated_distributed
 def test_syncbn2d_no_stats():
 nr_chan = 8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册