提交 f7872774 编写于 作者: Z zhaojichen

add global batch normalization

上级 d8bd5a09
......@@ -79,7 +79,7 @@ class _BatchNorm(Cell):
if self.rank_id in self.rank_list[i] and self.group != 1:
self.is_global = True
management.create_group('group' + str(i), self.rank_list[i])
self.all_reduce = _GlobalBNHelper('group' + str(i))
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean()
self.square = P.Square()
......
......@@ -90,5 +90,4 @@ def test_global_bn():
device_num=size, parameter_broadcast=True)
net = GlobalBNNet()
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
net.set_train()
out = net(input_data)
_executor.compile(net,input_data)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册