diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py index 63b69e6107b7828d80c01086fab3e828c8d02ec2..72c2ffa492559e6259da84f0fc9eb5ac6cc42797 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py @@ -78,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase): def attr_data_format(): out = paddle.nn.GroupNorm( - num_groups=2, num_channels=2, data_format="NHWC" + num_groups=2, num_channels=2, data_format="CNHW" ) self.assertRaises(ValueError, attr_data_format) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index e95cff7b167327f8a3b997c3f43a6dc6ba8a944c..503b95f5b4257fb31450c0671dbbb0f2ce5fa44d 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -377,8 +377,9 @@ class GroupNorm(Layer): self._epsilon = epsilon self._num_channels = num_channels self._num_groups = num_groups - if data_format != 'NCHW': + if data_format not in ['NCHW', 'NHWC']: raise ValueError("unsupported data layout:" + data_format) + self._data_format = data_format param_shape = [self._num_channels] @@ -430,7 +431,7 @@ class GroupNorm(Layer): self.bias, self._epsilon, self._num_groups, - "NCHW", + self._data_format, ) return dygraph_utils._append_activation_in_dygraph(