From 2dc3d40ce2fd50e789b959f20a3cd69ba0fe0e67 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Wed, 2 Nov 2022 19:17:13 +0800 Subject: [PATCH] Support NHWC layout in GroupNorm (#47533) * Support NHWC layout in GroupNorm * fix cteset --- python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py | 2 +- python/paddle/nn/layer/norm.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py index 63b69e6107b..72c2ffa4925 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py @@ -78,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase): def attr_data_format(): out = paddle.nn.GroupNorm( - num_groups=2, num_channels=2, data_format="NHWC" + num_groups=2, num_channels=2, data_format="CNHW" ) self.assertRaises(ValueError, attr_data_format) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index e95cff7b167..503b95f5b42 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -377,8 +377,9 @@ class GroupNorm(Layer): self._epsilon = epsilon self._num_channels = num_channels self._num_groups = num_groups - if data_format != 'NCHW': + if data_format not in ['NCHW', 'NHWC']: raise ValueError("unsupported data layout:" + data_format) + self._data_format = data_format param_shape = [self._num_channels] @@ -430,7 +431,7 @@ class GroupNorm(Layer): self.bias, self._epsilon, self._num_groups, - "NCHW", + self._data_format, ) return dygraph_utils._append_activation_in_dygraph( -- GitLab