未验证 提交 84fe2de6 编写于 作者: 张春乔 提交者: GitHub

fix the div 0 error of sequence_concat (#49963)

* fix the div 0 error of sequence_concat
* Update test_sequence_concat.py
上级 20075bf3
...@@ -69,6 +69,11 @@ class SequenceConcatOp : public framework::OperatorWithKernel { ...@@ -69,6 +69,11 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
out_dims = phi::vectorize(x_dim); out_dims = phi::vectorize(x_dim);
} }
batch_size += x_dim[0]; batch_size += x_dim[0];
PADDLE_ENFORCE_NE(
x_dim[0],
0,
platform::errors::InvalidArgument(
"The first dim of SequenceConcatOp inputs must not be 0."));
if (feature_size == 0) { if (feature_size == 0) {
feature_size = phi::product(x_dim) / x_dim[0]; feature_size = phi::product(x_dim) / x_dim[0];
} else { } else {
......
...@@ -134,6 +134,15 @@ class TestSequenceConcatOpError(unittest.TestCase): ...@@ -134,6 +134,15 @@ class TestSequenceConcatOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype) self.assertRaises(TypeError, test_dtype)
def test_0_shape():
# dtype must be 'float32', 'float64', 'int64'
x4_data = paddle.static.data(name="x4", shape=[0], dtype='float32')
y4_data = paddle.static.data(name="y4", shape=[1], dtype='float32')
input_list = [x4_data, y4_data]
paddle.static.nn.sequence_lod.sequence_concat(input=input_list)
self.assertRaises(ValueError, test_0_shape)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册