未验证 提交 da11aa40 编写于 作者: 张正海 提交者: GitHub

Fix Python IndexError of case13: paddle.static.nn.batch_norm (#50011)

* add channel_num check for paddle.static.nn.batch_norm

* fix bugs

* fix bugs
上级 0d32f554
......@@ -768,6 +768,10 @@ class TestBatchNormOpError(unittest.TestCase):
)
self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2)
# the first dimension of input for batch_norm must between [2d, 5d].
x3 = paddle.static.data("", shape=[0], dtype="float32")
self.assertRaises(ValueError, paddle.static.nn.batch_norm, x3)
class TestDygraphBatchNormAPIError(unittest.TestCase):
def test_errors(self):
......
......@@ -179,7 +179,7 @@ class TestFoldOpError(unittest.TestCase):
with program_guard(Program(), Program()):
def test_input_shape():
# input_shpae must be 3-D
# input_shape must be 3-D
x = paddle.randn(shape=[2, 3, 6, 7], dtype="float32")
out = fold(x, output_sizes=[2, 3], kernel_sizes=[2, 2])
......
......@@ -2731,6 +2731,12 @@ def batch_norm(
dtype = core.VarDesc.VarType.FP32
input_shape = input.shape
if len(input.shape) < 2 or len(input.shape) > 5:
raise ValueError(
'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format(
len(input.shape), input_shape
)
)
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册