提交 17e27824 编写于 作者: Z zhaojichen

add global batch normalization

上级 c5120e77
......@@ -73,21 +73,3 @@ def test_compile_groupnorm():
net = nn.GroupNorm(16, 64)
input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
_executor.compile(net, input_data)
class GlobalBNNet(nn.Cell):
def __init__(self):
super(GlobalBNNet, self).__init__()
self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2)
def construct(self, x):
return self.bn(x)
def test_global_bn():
init("hccl")
size = 4
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=size, parameter_broadcast=True)
net = GlobalBNNet()
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
_executor.compile(net,input_data)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册