diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 2a1ca28ed45b184001137f0afcd72eca729eb032..7a102b0bbe9446e6151b0587a3764e66fd783eb4 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -116,15 +116,7 @@ class _BatchNorm(Cell): group_list = [list(i) for i in world_rank_list] return group_list - def _shape_infer(self, x): - """global batch normalization shape and axes infer""" - if len(self.shape(x)) == 4: - axes = (0, 2, 3) - re_shape = (1, self.num_features, 1, 1) - else: - axes = (0,) - re_shape = (1, self.num_features) - return axes, re_shape + def _global_sync(self, x, axes, re_shape): """calculate global batch normalization output""" @@ -150,7 +142,7 @@ class _BatchNorm(Cell): if self.training and self.use_batch_statistics: if self.is_ge_backend: if self.is_global: - axes, re_shape = self._shape_infer(x) + axes, re_shape = _shape_infer(F.shape(x), self.num_features) y = self._global_sync(x, axes, re_shape) else: y, batch_mean, batch_var, _, _ = \ @@ -189,6 +181,17 @@ def _channel_check(channel, num_channel): if channel != num_channel: raise ValueError("the input channel is not equal with num_channel") +@constexpr +def _shape_infer(x_shape, num_feature): + """global batch normalization shape and axes infer""" + if len(x_shape) == 4: + axes = (0, 2, 3) + re_shape = (1, num_feature, 1, 1) + else: + axes = (0,) + re_shape = (1, num_feature) + return axes, re_shape + class BatchNorm1d(_BatchNorm): r""" Batch normalization layer over a 2D input.