提交 b07ca1de 编写于 作者: Y yangyaming

resize before computing LoD.

上级 08cb472a
...@@ -46,8 +46,8 @@ class SequenceReshapeKernel : public framework::OpKernel<T> { ...@@ -46,8 +46,8 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
} else { } else {
auto& out_lod = *out->mutable_lod(); auto& out_lod = *out->mutable_lod();
out_lod.resize(1); out_lod.resize(1);
out_lod[0].clear(); out_lod[0].resize(seq_num + 1);
out_lod[0].push_back(0); out_lod[0][0] = 0;
for (int i = 0; i < seq_num; ++i) { for (int i = 0; i < seq_num; ++i) {
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i]; size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
size_t offset = 0; size_t offset = 0;
...@@ -57,11 +57,10 @@ class SequenceReshapeKernel : public framework::OpKernel<T> { ...@@ -57,11 +57,10 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
"be divided by new_dim with no remainder for each " "be divided by new_dim with no remainder for each "
"sequence. The %dth sequence is invalid.", "sequence. The %dth sequence is invalid.",
i + 1); i + 1);
out_lod[0].push_back(out_lod[0].back() + offset); out_lod[0][i + 1] = out_lod[0][i] + offset;
} }
} }
out->mutable_data<T>(context.GetPlace());
framework::Copy(*in, context.GetPlace(), out); framework::Copy(*in, context.GetPlace(), out);
out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width}); out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册