Unverified commit 45bd5898, authored by Yibing Liu, committed by GitHub

Fix the bug of sequence_unpad op (#18290) (#18305)

* Use TensorCopySync for sequence_unpad op

* Fix the tensor memory alloc bug

test=release/1.5
Parent 129f2717
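The first change replaces the asynchronous framework::TensorCopy of the "Length" tensor with the blocking framework::TensorCopySync, because the kernel reads the copied lengths on the CPU immediately afterwards to build the output LoD. As a minimal sketch of that ordering requirement, here is a standalone CUDA program (not Paddle code; the names dev_len, host_len and the stream are made up for illustration):

// Minimal standalone CUDA sketch of why the host must not read the destination
// of a device-to-host copy before the copy has completed. The commit replaces
// an asynchronous copy of the length tensor with a blocking one for the same
// reason.
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>

int main() {
  const int n = 3;
  const int64_t src[n] = {3, 1, 2};

  // Device-side stand-in for the "Length" tensor.
  int64_t* dev_len = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dev_len), n * sizeof(int64_t));
  cudaMemcpy(dev_len, src, n * sizeof(int64_t), cudaMemcpyHostToDevice);

  // Pinned host buffer so the async copy below can really run asynchronously.
  int64_t* host_len = nullptr;
  cudaMallocHost(reinterpret_cast<void**>(&host_len), n * sizeof(int64_t));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Unsafe pattern (what an async copy amounts to): the transfer is only
  // enqueued on the stream, so reading host_len right here would race with it
  // and could observe stale or uninitialized lengths.
  cudaMemcpyAsync(host_len, dev_len, n * sizeof(int64_t),
                  cudaMemcpyDeviceToHost, stream);

  // Safe pattern (what a synchronous copy guarantees): block until the
  // transfer is done before the CPU consumes the values.
  cudaStreamSynchronize(stream);
  for (int i = 0; i < n; ++i) {
    std::printf("len[%d] = %lld\n", i, static_cast<long long>(host_len[i]));
  }

  cudaFreeHost(host_len);
  cudaFree(dev_len);
  cudaStreamDestroy(stream);
  return 0;
}

With an asynchronous copy the host-side loop over the lengths races with the transfer; synchronizing first, which is effectively what the synchronous copy does, removes the race.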
@@ -34,28 +34,26 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
     auto* len_t = ctx.Input<LoDTensor>("Length");
     auto* out_t = ctx.Output<LoDTensor>("Out");
-    const int64_t* seq_len_ptr = nullptr;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    framework::Tensor seq_len_cpu =
+        ctx.AllocateTmpTensor<T, DeviceContext>(len_t->dims(), dev_ctx);
     if (platform::is_gpu_place(ctx.GetPlace())) {
-      LoDTensor seq_len_cpu;
-      seq_len_cpu.Resize(len_t->dims());
-      seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
-      framework::TensorCopy(*len_t, platform::CPUPlace(),
-                            ctx.template device_context<DeviceContext>(),
-                            &seq_len_cpu);
+      seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
+      framework::TensorCopySync(*len_t, platform::CPUPlace(), &seq_len_cpu);
     } else {
-      seq_len_ptr = len_t->data<int64_t>();
+      seq_len_cpu = *len_t;
     }
-    size_t batch_size = x_t->dims()[0];
+    const int64_t* seq_len_ptr = seq_len_cpu.data<int64_t>();
+    int64_t batch_size = len_t->dims()[0];
     std::vector<size_t> out_lod0(batch_size + 1, 0);
-    for (size_t i = 0; i < batch_size; ++i) {
-      out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
+    for (int64_t i = 0; i < batch_size; ++i) {
+      out_lod0[i + 1] = out_lod0[i] + static_cast<size_t>(seq_len_ptr[i]);
     }
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
     out_t->set_lod(out_lod);
     std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
     if (x_t->dims().size() == 2) {
       out_dims_vec.push_back(1);
@@ -71,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
     int64_t padded_length = x_t->dims()[1];
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *x_t, out_t,
-        padded_length, 0, false, math::kBatchLengthWidth);
+        dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
   }
 };
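The second change appears to be what the commit message calls the "tensor memory alloc bug": the old code kept seq_len_ptr pointing into a LoDTensor declared inside the if block, so the pointer dangled once that block ended, while the new code allocates the temporary via ctx.AllocateTmpTensor at function scope and only takes the raw pointer after both branches have filled it. Below is a plain-C++ sketch of that lifetime issue under this reading of the diff; Buffer is a made-up stand-in for a tensor that owns its memory.

// Plain C++ sketch (no Paddle types) of a raw pointer outliving the
// block-scoped object it points into, and the fix of hoisting the owning
// object to the enclosing scope.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Buffer {                   // stand-in for a tensor owning host memory
  std::vector<int64_t> data;
};

int main() {
  bool on_gpu = true;

  // Buggy shape: the pointer outlives the object it points into.
  const int64_t* dangling = nullptr;
  {
    Buffer tmp;
    tmp.data = {3, 1, 2};         // imagine this being filled by a copy
    dangling = tmp.data.data();
  }                               // tmp and its storage are destroyed here
  // Dereferencing dangling now would be undefined behavior -- that is the bug.
  (void)dangling;

  // Fixed shape: the owning object lives in the enclosing scope, and the raw
  // pointer is taken only after both branches have populated it.
  Buffer seq_len_cpu;
  if (on_gpu) {
    seq_len_cpu.data = {3, 1, 2};  // would be a synchronous D2H copy
  } else {
    seq_len_cpu.data = {3, 1, 2};  // would alias the CPU tensor directly
  }
  const int64_t* seq_len_ptr = seq_len_cpu.data.data();
  for (size_t i = 0; i < seq_len_cpu.data.size(); ++i) {
    std::printf("len[%zu] = %lld\n", i,
                static_cast<long long>(seq_len_ptr[i]));
  }
  return 0;
}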