未验证 提交 6908550e 编写于 作者: L LoneRanger 提交者: GitHub

Fix Python IndexError of case11: paddle.static.nn.data_norm (#50097)

* 为data_norm函数的input参数增加维度的判断并增加单侧

* 完善data_norm测试样例

* fix data_norm
上级 4b4d92ea
......@@ -523,6 +523,11 @@ class TestDataNormOpErrorr(unittest.TestCase):
input=x2, param_attr={}, enable_scale_and_shift=True
)
# Test input with dimension 1
paddle.enable_static()
x3 = paddle.static.data("", shape=[0], dtype="float32")
self.assertRaises(ValueError, paddle.static.nn.data_norm, x3)
if __name__ == '__main__':
unittest.main()
......@@ -514,6 +514,12 @@ def data_norm(
dtype = helper.input_dtype()
input_shape = input.shape
if len(input_shape) < 2:
raise ValueError(
"The shape pf Input < 2 (got {}D input, input shape is: {})".format(
len(input_shape), input_shape
)
)
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册