未验证 提交 0f3ccd14 编写于 作者: F Feiyu Chan 提交者: GitHub

sequence_mask fix: when the input length is an empty tensor, the kernel tries...

sequence_mask fix: when the input length is an empty tensor, the kernel tries to dereference illegal sentinel iterator (#49525)
上级 032da731
......@@ -106,18 +106,23 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
auto *x_data = x->data<Tx>();
auto x_numel = x->numel();
if (maxlen < 0) {
if (x_numel == 0) {
maxlen = 0;
} else {
#if defined(__NVCC__) || defined(__HIPCC__)
VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel,
static_cast<Tx>(0),
thrust::maximum<Tx>()));
VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel,
static_cast<Tx>(0),
thrust::maximum<Tx>()));
#else
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif
}
auto y_dim = phi::vectorize<int>(x->dims());
y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim));
......
......@@ -48,10 +48,14 @@ class SequenceMaskNPUKernel : public framework::OpKernel<T> {
if (maxlen < 0) {
auto x_numel = x->numel();
std::vector<T> x_vec;
framework::TensorToVector(*x, dev_ctx, &x_vec);
auto x_data = x_vec.data();
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
if (x_numel == 0) {
maxlen = 0;
} else {
std::vector<T> x_vec;
framework::TensorToVector(*x, dev_ctx, &x_vec);
auto x_data = x_vec.data();
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
}
}
auto y_dim = phi::vectorize<int>(x->dims());
y_dim.push_back(maxlen);
......
......@@ -17,6 +17,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import (
Program,
......@@ -171,5 +172,14 @@ class TestSequenceMaskOpError(unittest.TestCase):
self.assertRaises(TypeError, test_Variable)
class TestSequenceMaskWithEmptyTensor(unittest.TestCase):
def test_empty(self):
paddle.disable_static()
lengths = paddle.to_tensor(np.array([], dtype=np.int64))
mask = paddle.nn.functional.sequence_mask(lengths)
self.assertEqual(list(mask.shape), [0, 0])
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册