From da11aa40efa0e6ef4bfbcd72c9e3f8f86c39cd06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=AD=A3=E6=B5=B7?= <65210872+ccsuzzh@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:02:11 +0800 Subject: [PATCH] Fix Python IndexError of case13: paddle.static.nn.batch_norm (#50011) * add channel_num check for paddle.static.nn.batch_norm * fix bugs * fix bugs --- python/paddle/fluid/tests/unittests/test_batch_norm_op.py | 4 ++++ python/paddle/fluid/tests/unittests/test_fold_op.py | 2 +- python/paddle/static/nn/common.py | 6 ++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index c2a6c468e5..02171db3fc 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -768,6 +768,10 @@ class TestBatchNormOpError(unittest.TestCase): ) self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2) + # the first dimension of input for batch_norm must between [2d, 5d]. + x3 = paddle.static.data("", shape=[0], dtype="float32") + self.assertRaises(ValueError, paddle.static.nn.batch_norm, x3) + class TestDygraphBatchNormAPIError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/fluid/tests/unittests/test_fold_op.py b/python/paddle/fluid/tests/unittests/test_fold_op.py index 1f3193fa1f..a86161cc45 100644 --- a/python/paddle/fluid/tests/unittests/test_fold_op.py +++ b/python/paddle/fluid/tests/unittests/test_fold_op.py @@ -179,7 +179,7 @@ class TestFoldOpError(unittest.TestCase): with program_guard(Program(), Program()): def test_input_shape(): - # input_shpae must be 3-D + # input_shape must be 3-D x = paddle.randn(shape=[2, 3, 6, 7], dtype="float32") out = fold(x, output_sizes=[2, 3], kernel_sizes=[2, 2]) diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 3b40153cbb..c43385a8e9 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2731,6 +2731,12 @@ def batch_norm( dtype = core.VarDesc.VarType.FP32 input_shape = input.shape + if len(input.shape) < 2 or len(input.shape) > 5: + raise ValueError( + 'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format( + len(input.shape), input_shape + ) + ) if data_layout == 'NCHW': channel_num = input_shape[1] else: -- GitLab