diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index c85b945a0d98e2a001819ebd760a5e72341a8b40..04de71f71c1f368e3f6dcadc248572fae89f6e7d 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -79,7 +79,7 @@ class _BatchNorm(Cell):
                 if self.rank_id in self.rank_list[i] and self.group != 1:
                     self.is_global = True
                     management.create_group('group' + str(i), self.rank_list[i])
-                    self.all_reduce = _GlobalBNHelper('group' + str(i))
+                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
         self.shape = P.Shape()
         self.reduce_mean = P.ReduceMean()
         self.square = P.Square()
diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py
index 24f0de85f7d11ae825814636f7217bc2eb54bdf6..b6e27e69502541dfae282c36b4ffa5d04a22a773 100644
--- a/tests/ut/python/nn/test_batchnorm.py
+++ b/tests/ut/python/nn/test_batchnorm.py
@@ -90,5 +90,4 @@ def test_global_bn():
                                       device_num=size, parameter_broadcast=True)
     net = GlobalBNNet()
     input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
-    net.set_train()
-    out = net(input_data)
+    _executor.compile(net,input_data)