diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4a93c7d11fd8aedaaf8faf770296e6dd920e8626..b44e249ee9fabf158f44cee9a05f88a419dda6f9 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3574,7 +3574,7 @@ def group_norm(input, Refer to `Group Normalization `_ . Parameters: - input(Tensor): 4-D Tensor, the data type is float32 or float64. + input(Tensor): Tensor with dimension greater than 1, the data type is float32 or float64. groups(int): The number of groups that divided from channels, the data type is int32. epsilon(float, optional): The small value added to the variance to prevent @@ -3591,12 +3591,12 @@ def group_norm(input, data_layout(str, optional): Specify the data format of the input, and the data format of the output will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: - `[batch_size, input_channels, input_height, input_width]`. + `[batch_size, input_channels, *]`. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: - Tensor: A 4-D Tensor has same data type and data format with `input`. + Tensor: A Tensor has same data type and data format with `input`. Examples: .. code-block:: python @@ -3615,6 +3615,10 @@ def group_norm(input, # create intput and parameters inputs = {'X': input} input_shape = input.shape + if len(input_shape) < 2: + raise ValueError( + f"The dimensions of Op(fluid.layers.group_norm)'s input should be more than 1. But received {len(input_shape)}" + ) if data_layout != 'NCHW' and data_layout != 'NHWC': raise ValueError( "Param(data_layout) of Op(fluid.layers.group_norm) got wrong value: received " diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py index fbdf4a1cfd1ac0a23ac0853271042e0a5a090385..a22233dfcc8c6d63884ec6bf82a145facb79c961 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py @@ -152,5 +152,20 @@ class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase): self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5)) +class TestGroupNormDimException(unittest.TestCase): + def test_exception(self): + def test_empty_input_static_API(): + x = paddle.to_tensor([], dtype='float32') + paddle.static.nn.group_norm(x, 3) + + self.assertRaises(ValueError, test_empty_input_static_API) + + def test_one_dim_input_static_API(): + x = paddle.randn((3, ), dtype='float32') + paddle.static.nn.group_norm(x, 3) + + self.assertRaises(ValueError, test_one_dim_input_static_API) + + if __name__ == '__main__': unittest.main()