From f7872774f3ebb44e345af364bac9230adf6aac32 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Fri, 17 Apr 2020 08:33:40 -0400 Subject: [PATCH] add global batch normalization --- mindspore/nn/layer/normalization.py | 2 +- tests/ut/python/nn/test_batchnorm.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index c85b945a0..04de71f71 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -79,7 +79,7 @@ class _BatchNorm(Cell): if self.rank_id in self.rank_list[i] and self.group != 1: self.is_global = True management.create_group('group' + str(i), self.rank_list[i]) - self.all_reduce = _GlobalBNHelper('group' + str(i)) + self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) self.shape = P.Shape() self.reduce_mean = P.ReduceMean() self.square = P.Square() diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index 24f0de85f..b6e27e695 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -90,5 +90,4 @@ def test_global_bn(): device_num=size, parameter_broadcast=True) net = GlobalBNNet() input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32)) - net.set_train() - out = net(input_data) + _executor.compile(net,input_data) -- GitLab