未验证 提交 0f3ccd14 编写于 作者: F Feiyu Chan 提交者: GitHub

sequence_mask fix: when the input length is an empty tensor, the kernel tries...

sequence_mask fix: when the input length tensor is empty, the kernel tries to dereference an invalid end (sentinel) iterator returned by max_element (#49525)
上级 032da731
...@@ -106,18 +106,23 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> { ...@@ -106,18 +106,23 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
auto *x_data = x->data<Tx>(); auto *x_data = x->data<Tx>();
auto x_numel = x->numel(); auto x_numel = x->numel();
if (maxlen < 0) { if (maxlen < 0) {
if (x_numel == 0) {
maxlen = 0;
} else {
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
VLOG(10) VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided."; << "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>( maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data), thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel, thrust::device_pointer_cast(x_data) + x_numel,
static_cast<Tx>(0), static_cast<Tx>(0),
thrust::maximum<Tx>())); thrust::maximum<Tx>()));
#else #else
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel)); maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif #endif
}
auto y_dim = phi::vectorize<int>(x->dims()); auto y_dim = phi::vectorize<int>(x->dims());
y_dim.push_back(maxlen); y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim)); y->Resize(phi::make_ddim(y_dim));
......
...@@ -48,10 +48,14 @@ class SequenceMaskNPUKernel : public framework::OpKernel<T> { ...@@ -48,10 +48,14 @@ class SequenceMaskNPUKernel : public framework::OpKernel<T> {
if (maxlen < 0) { if (maxlen < 0) {
auto x_numel = x->numel(); auto x_numel = x->numel();
std::vector<T> x_vec; if (x_numel == 0) {
framework::TensorToVector(*x, dev_ctx, &x_vec); maxlen = 0;
auto x_data = x_vec.data(); } else {
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel)); std::vector<T> x_vec;
framework::TensorToVector(*x, dev_ctx, &x_vec);
auto x_data = x_vec.data();
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
}
} }
auto y_dim = phi::vectorize<int>(x->dims()); auto y_dim = phi::vectorize<int>(x->dims());
y_dim.push_back(maxlen); y_dim.push_back(maxlen);
......
...@@ -17,6 +17,7 @@ import unittest ...@@ -17,6 +17,7 @@ import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import ( from paddle.fluid.framework import (
Program, Program,
...@@ -171,5 +172,14 @@ class TestSequenceMaskOpError(unittest.TestCase): ...@@ -171,5 +172,14 @@ class TestSequenceMaskOpError(unittest.TestCase):
self.assertRaises(TypeError, test_Variable) self.assertRaises(TypeError, test_Variable)
class TestSequenceMaskWithEmptyTensor(unittest.TestCase):
def test_empty(self):
paddle.disable_static()
lengths = paddle.to_tensor(np.array([], dtype=np.int64))
mask = paddle.nn.functional.sequence_mask(lengths)
self.assertEqual(list(mask.shape), [0, 0])
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册