Commit 7b81ca68 authored by Z zhaojichen

fix globalbatchnorm bug

Parent 6c9a54af
@@ -116,15 +116,7 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
-    def _shape_infer(self, x):
-        """global batch normalization shape and axes infer"""
-        if len(self.shape(x)) == 4:
-            axes = (0, 2, 3)
-            re_shape = (1, self.num_features, 1, 1)
-        else:
-            axes = (0,)
-            re_shape = (1, self.num_features)
-        return axes, re_shape
-
     def _global_sync(self, x, axes, re_shape):
         """calculate global batch normalization output"""
@@ -150,7 +142,7 @@ class _BatchNorm(Cell):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    axes, re_shape = self._shape_infer(x)
+                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                     y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
@@ -189,6 +181,17 @@ def _channel_check(channel, num_channel):
     if channel != num_channel:
         raise ValueError("the input channel is not equal with num_channel")
 
+
+@constexpr
+def _shape_infer(x_shape, num_feature):
+    """global batch normalization shape and axes infer"""
+    if len(x_shape) == 4:
+        axes = (0, 2, 3)
+        re_shape = (1, num_feature, 1, 1)
+    else:
+        axes = (0,)
+        re_shape = (1, num_feature)
+    return axes, re_shape
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Batch normalization layer over a 2D input.
...
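
A minimal, self-contained sketch (not part of the commit) of the pattern this fix relies on: a function decorated with @constexpr is evaluated at graph compile time, so it can use plain Python logic such as len() and tuple construction on the static shape tuple returned by F.shape(x) inside construct(). The Demo cell and _infer_axes_and_reshape helper below are hypothetical names for illustration, not code from the repository.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import constexpr
from mindspore.ops import operations as P
from mindspore.ops import functional as F


@constexpr
def _infer_axes_and_reshape(x_shape, num_features):
    """Derive reduction axes and a broadcastable shape from a static input shape."""
    if len(x_shape) == 4:                          # NCHW input
        return (0, 2, 3), (1, num_features, 1, 1)
    return (0,), (1, num_features)                 # NC input


class Demo(nn.Cell):
    """Reduce over the non-channel axes and reshape, mirroring the fixed call site."""
    def __init__(self, num_features):
        super(Demo, self).__init__()
        self.num_features = num_features
        self.reduce_mean = P.ReduceMean(keep_dims=False)

    def construct(self, x):
        # F.shape(x) is a compile-time constant tuple in graph mode, which is
        # what allows passing it into the constexpr helper above instead of
        # calling a shape-inference method on the cell at run time.
        axes, re_shape = _infer_axes_and_reshape(F.shape(x), self.num_features)
        mean = self.reduce_mean(x, axes)
        return F.reshape(mean, re_shape)


print(Demo(3)(Tensor(np.ones((2, 3, 4, 4), np.float32))))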