未验证 提交 2dc3d40c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support NHWC layout in GroupNorm (#47533)

* Support NHWC layout in GroupNorm

* fix cteset
上级 c79ae02b
......@@ -78,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
def attr_data_format():
out = paddle.nn.GroupNorm(
num_groups=2, num_channels=2, data_format="NHWC"
num_groups=2, num_channels=2, data_format="CNHW"
)
self.assertRaises(ValueError, attr_data_format)
......
......@@ -377,8 +377,9 @@ class GroupNorm(Layer):
self._epsilon = epsilon
self._num_channels = num_channels
self._num_groups = num_groups
if data_format != 'NCHW':
if data_format not in ['NCHW', 'NHWC']:
raise ValueError("unsupported data layout:" + data_format)
self._data_format = data_format
param_shape = [self._num_channels]
......@@ -430,7 +431,7 @@ class GroupNorm(Layer):
self.bias,
self._epsilon,
self._num_groups,
"NCHW",
self._data_format,
)
return dygraph_utils._append_activation_in_dygraph(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册