提交 e3210364 编写于 作者: Y yangyaming

Only shrink for the first level LoD.

上级 f947c153
......@@ -48,26 +48,16 @@ class ShrinkRNNMemoryOp : public ArrayOp {
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
// should consider multiple levels
size_t height = dst_num_rows;
auto lod_level = rank_table.level();
if (x_tensor.lod().size() > lod_level &&
x_tensor.lod()[lod_level].size() > static_cast<size_t>(dst_num_rows)) {
auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
x_tensor.lod(), 0, dst_num_rows, lod_level);
// do shrink for the top level LoD
if (x_tensor.lod().size() > 0 &&
x_tensor.lod()[0].size() > static_cast<size_t>(dst_num_rows)) {
auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0,
dst_num_rows, 0);
height = lod_offset.second.second;
auto out_lod = out_tensor.mutable_lod();
auto x_lod = x_tensor.lod();
out_lod->reserve(lod_level + lod_offset.first.size());
for (size_t i = 0; i < lod_level; ++i) {
out_lod->emplace_back(x_lod.at(i));
}
framework::LoD remain;
framework::AppendLoD(&remain, lod_offset.first);
for (size_t j = 0; j < remain.size(); ++j) {
out_lod->emplace_back(remain[j]);
}
framework::AppendLoD(out_lod, lod_offset.first);
}
if (dst_num_rows != 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册