未验证 提交 2dc3d40c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support NHWC layout in GroupNorm (#47533)

* Support NHWC layout in GroupNorm

* fix cteset
上级 c79ae02b
...@@ -78,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase): ...@@ -78,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
def attr_data_format(): def attr_data_format():
out = paddle.nn.GroupNorm( out = paddle.nn.GroupNorm(
num_groups=2, num_channels=2, data_format="NHWC" num_groups=2, num_channels=2, data_format="CNHW"
) )
self.assertRaises(ValueError, attr_data_format) self.assertRaises(ValueError, attr_data_format)
......
...@@ -377,8 +377,9 @@ class GroupNorm(Layer): ...@@ -377,8 +377,9 @@ class GroupNorm(Layer):
self._epsilon = epsilon self._epsilon = epsilon
self._num_channels = num_channels self._num_channels = num_channels
self._num_groups = num_groups self._num_groups = num_groups
if data_format != 'NCHW': if data_format not in ['NCHW', 'NHWC']:
raise ValueError("unsupported data layout:" + data_format) raise ValueError("unsupported data layout:" + data_format)
self._data_format = data_format
param_shape = [self._num_channels] param_shape = [self._num_channels]
...@@ -430,7 +431,7 @@ class GroupNorm(Layer): ...@@ -430,7 +431,7 @@ class GroupNorm(Layer):
self.bias, self.bias,
self._epsilon, self._epsilon,
self._num_groups, self._num_groups,
"NCHW", self._data_format,
) )
return dygraph_utils._append_activation_in_dygraph( return dygraph_utils._append_activation_in_dygraph(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册