diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index b6e27e69502541dfae282c36b4ffa5d04a22a773..10b4cb00a1e326271ed569ecb7e3d06663de8f56 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -73,21 +73,3 @@ def test_compile_groupnorm(): net = nn.GroupNorm(16, 64) input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32)) _executor.compile(net, input_data) - -class GlobalBNNet(nn.Cell): - def __init__(self): - super(GlobalBNNet, self).__init__() - self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2) - def construct(self, x): - return self.bn(x) - -def test_global_bn(): - init("hccl") - size = 4 - context.set_context(mode=context.GRAPH_MODE) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, - device_num=size, parameter_broadcast=True) - net = GlobalBNNet() - input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32)) - _executor.compile(net,input_data)