未验证 提交 45bd5898 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix the bug of sequence_unpad op (#18290) (#18305)

* Use TensorCopySync for sequence_unpad op

* Fix the tensor memory alloc bug

test=release/1.5
上级 129f2717
...@@ -34,28 +34,26 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> { ...@@ -34,28 +34,26 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
auto* len_t = ctx.Input<LoDTensor>("Length"); auto* len_t = ctx.Input<LoDTensor>("Length");
auto* out_t = ctx.Output<LoDTensor>("Out"); auto* out_t = ctx.Output<LoDTensor>("Out");
const int64_t* seq_len_ptr = nullptr; auto& dev_ctx = ctx.template device_context<DeviceContext>();
framework::Tensor seq_len_cpu =
ctx.AllocateTmpTensor<T, DeviceContext>(len_t->dims(), dev_ctx);
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
LoDTensor seq_len_cpu; seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
seq_len_cpu.Resize(len_t->dims()); framework::TensorCopySync(*len_t, platform::CPUPlace(), &seq_len_cpu);
seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
framework::TensorCopy(*len_t, platform::CPUPlace(),
ctx.template device_context<DeviceContext>(),
&seq_len_cpu);
} else { } else {
seq_len_ptr = len_t->data<int64_t>(); seq_len_cpu = *len_t;
} }
size_t batch_size = x_t->dims()[0]; const int64_t* seq_len_ptr = seq_len_cpu.data<int64_t>();
int64_t batch_size = len_t->dims()[0];
std::vector<size_t> out_lod0(batch_size + 1, 0); std::vector<size_t> out_lod0(batch_size + 1, 0);
for (size_t i = 0; i < batch_size; ++i) { for (int64_t i = 0; i < batch_size; ++i) {
out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i]; out_lod0[i + 1] = out_lod0[i] + static_cast<size_t>(seq_len_ptr[i]);
} }
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(out_lod0); out_lod.push_back(out_lod0);
out_t->set_lod(out_lod); out_t->set_lod(out_lod);
std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())}; std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
if (x_t->dims().size() == 2) { if (x_t->dims().size() == 2) {
out_dims_vec.push_back(1); out_dims_vec.push_back(1);
...@@ -71,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> { ...@@ -71,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
int64_t padded_length = x_t->dims()[1]; int64_t padded_length = x_t->dims()[1];
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()( math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x_t, out_t, dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, math::kBatchLengthWidth);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册