未验证 提交 7b743ba2 编写于 作者: J JYChen 提交者: GitHub

catch dimentions error when input is empty in static.nn.group_norm (#35613)

上级 dc3c845a
......@@ -3574,7 +3574,7 @@ def group_norm(input,
Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_ .
Parameters:
input(Tensor): 4-D Tensor, the data type is float32 or float64.
input(Tensor): Tensor with dimension greater than 1, the data type is float32 or float64.
groups(int): The number of groups that divided from channels, the data type
is int32.
epsilon(float, optional): The small value added to the variance to prevent
......@@ -3591,12 +3591,12 @@ def group_norm(input,
data_layout(str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
`[batch_size, input_channels, *]`.
name (str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: A 4-D Tensor has same data type and data format with `input`.
Tensor: A Tensor has same data type and data format with `input`.
Examples:
.. code-block:: python
......@@ -3615,6 +3615,10 @@ def group_norm(input,
# create intput and parameters
inputs = {'X': input}
input_shape = input.shape
if len(input_shape) < 2:
raise ValueError(
f"The dimensions of Op(fluid.layers.group_norm)'s input should be more than 1. But received {len(input_shape)}"
)
if data_layout != 'NCHW' and data_layout != 'NHWC':
raise ValueError(
"Param(data_layout) of Op(fluid.layers.group_norm) got wrong value: received "
......
......@@ -152,5 +152,20 @@ class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5))
class TestGroupNormDimException(unittest.TestCase):
def test_exception(self):
def test_empty_input_static_API():
x = paddle.to_tensor([], dtype='float32')
paddle.static.nn.group_norm(x, 3)
self.assertRaises(ValueError, test_empty_input_static_API)
def test_one_dim_input_static_API():
x = paddle.randn((3, ), dtype='float32')
paddle.static.nn.group_norm(x, 3)
self.assertRaises(ValueError, test_one_dim_input_static_API)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册