From 84fe2de6e246fcc3fb98056873e7140aeb9aaf74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:20:28 +0800 Subject: [PATCH] fix the div 0 error of sequence_concat (#49963) * fix the div 0 error of sequence_concat * Update test_sequence_concat.py --- .../fluid/operators/sequence_ops/sequence_concat_op.cc | 5 +++++ .../tests/unittests/sequence/test_sequence_concat.py | 9 +++++++++ 2 files changed, 14 insertions(+) diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index 63aef4a628..762ca5e42d 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -69,6 +69,11 @@ class SequenceConcatOp : public framework::OperatorWithKernel { out_dims = phi::vectorize(x_dim); } batch_size += x_dim[0]; + PADDLE_ENFORCE_NE( + x_dim[0], + 0, + platform::errors::InvalidArgument( + "The first dim of SequenceConcatOp inputs must not be 0.")); if (feature_size == 0) { feature_size = phi::product(x_dim) / x_dim[0]; } else { diff --git a/python/paddle/fluid/tests/unittests/sequence/test_sequence_concat.py b/python/paddle/fluid/tests/unittests/sequence/test_sequence_concat.py index aa883c9dcf..5d281d1cd9 100644 --- a/python/paddle/fluid/tests/unittests/sequence/test_sequence_concat.py +++ b/python/paddle/fluid/tests/unittests/sequence/test_sequence_concat.py @@ -134,6 +134,15 @@ class TestSequenceConcatOpError(unittest.TestCase): self.assertRaises(TypeError, test_dtype) + def test_0_shape(): + # dtype must be 'float32', 'float64', 'int64' + x4_data = paddle.static.data(name="x4", shape=[0], dtype='float32') + y4_data = paddle.static.data(name="y4", shape=[1], dtype='float32') + input_list = [x4_data, y4_data] + paddle.static.nn.sequence_lod.sequence_concat(input=input_list) + + self.assertRaises(ValueError, test_0_shape) + if __name__ == '__main__': paddle.enable_static() -- GitLab