diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc
index cc9e3f90b42cdc1e4cfd1663b85bf0b4ad66332b..29f88896e74b126517dd0bc71d300af29b151301 100644
--- a/paddle/operators/shrink_rnn_memory_op.cc
+++ b/paddle/operators/shrink_rnn_memory_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/lod_rank_table.h"
+#include "paddle/framework/lod_tensor.h"
 #include "paddle/operators/array_operator.h"
 #include "paddle/operators/math/math_function.h"
@@ -49,12 +50,24 @@ class ShrinkRNNMemoryOp : public ArrayOp {
 
     // should consider multiple levels
     size_t height = dst_num_rows;
-    auto lod_level = lod_rank_table.level();
+    auto lod_level = rank_table.level();
+
     if (x_tensor.lod().size() > lod_level &&
-        x_tensor.lod()[lod_level].size() < dst_num_rows) {
+        x_tensor.lod()[lod_level].size() > static_cast<size_t>(dst_num_rows)) {
       auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
-          x_tensor.lod(), 0, dst_num_rows + 1, lod_level);
+          x_tensor.lod(), 0, dst_num_rows, lod_level);
       height = lod_offset.second.second;
+      auto out_lod = out_tensor.mutable_lod();
+      auto x_lod = x_tensor.lod();
+      out_lod->reserve(lod_level + lod_offset.first.size());
+      for (size_t i = 0; i < lod_level; ++i) {
+        out_lod->emplace_back(x_lod.at(i));
+      }
+      framework::LoD remain;
+      framework::AppendLoD(&remain, lod_offset.first);
+      for (size_t j = 0; j < remain.size(); ++j) {
+        out_lod->emplace_back(remain[j]);
+      }
     }
 
     if (dst_num_rows != 0) {
diff --git a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py
index 707dbd793a01c7af813cb98f2d0d7302446e5e59..9d8565b1681e17f1690490ae3d700492c35a8a78 100644
--- a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py
+++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py
@@ -29,7 +29,9 @@ class TestShrinkRNNMemory(unittest.TestCase):
         tensor_np = numpy.random.random(size=(6, 100)).astype('float32')
         tensor.set(tensor_np, cpu)
         exe = Executor(cpu)
-        outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3])
+        outs = exe.run(feed={'x': tensor},
+                       fetch_list=[mem1, mem2, mem3],
+                       return_numpy=False)
         self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0]))
         self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1]))
         self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2]))
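
Reviewer note: the C++ hunk does two things. It fixes the shrink branch (the inverted `<` comparison is flipped to `>` with a proper `size_t` cast, and the off-by-one `dst_num_rows + 1` end index becomes `dst_num_rows`), and it makes the op attach a LoD to its output: levels above `rank_table.level()` are copied verbatim from the input, while the levels from that point down are rebuilt from the sub-LoD that `GetSubLoDAndAbsoluteOffset` returns (the `remain`/`AppendLoD` step converts it into offset form). The sketch below models that logic in plain Python for illustration only; `sub_lod_and_offset`, `lengths_to_offsets`, and `shrink` are hypothetical names, and the list-of-offset-lists representation is my reading of `framework::LoD`.

```python
# Minimal, self-contained model of the shrink logic. A LoD is a list of
# levels; each level is a list of cumulative offsets, so level[i] is the
# begin offset of sequence i (mirroring paddle::framework::LoD).

def sub_lod_and_offset(lod, start_idx, end_idx, start_level):
    """Model of framework::GetSubLoDAndAbsoluteOffset: per-sequence
    *lengths* from start_level down for sequences [start_idx, end_idx),
    plus the absolute (begin, end) row range in the tensor."""
    sub_lod = []
    for level in lod[start_level:]:
        assert end_idx < len(level)
        sub_lod.append([level[i + 1] - level[i]
                        for i in range(start_idx, end_idx)])
        start_idx, end_idx = level[start_idx], level[end_idx]
    return sub_lod, (start_idx, end_idx)

def lengths_to_offsets(length_lod):
    """Model of framework::AppendLoD applied to an empty LoD (the
    `remain` step above): turn lengths back into cumulative offsets."""
    offset_lod = []
    for lengths in length_lod:
        offsets = [0]
        for n in lengths:
            offsets.append(offsets[-1] + n)
        offset_lod.append(offsets)
    return offset_lod

def shrink(x_lod, lod_level, dst_num_rows):
    """Keep the first dst_num_rows sequences at lod_level, as the
    patched ShrinkRNNMemoryOp does; return (out_lod, height)."""
    if not (len(x_lod) > lod_level and len(x_lod[lod_level]) > dst_num_rows):
        return x_lod, dst_num_rows  # the patched guard condition
    sub, (_, height) = sub_lod_and_offset(x_lod, 0, dst_num_rows, lod_level)
    out_lod = [level[:] for level in x_lod[:lod_level]]  # levels above: copied
    out_lod += lengths_to_offsets(sub)  # levels from lod_level down: shrunk
    return out_lod, height

# Two top-level sequences over three sub-sequences over 10 rows;
# keep only the first top-level sequence.
x_lod = [[0, 2, 3], [0, 2, 5, 10]]
out_lod, height = shrink(x_lod, lod_level=0, dst_num_rows=1)
assert out_lod == [[0, 2], [0, 2, 5]]
assert height == 5  # the first top-level sequence covers rows [0, 5)
```

`height` is exactly `lod_offset.second.second` in the C++ code: the absolute row index where the kept sequences end, which is what the subsequent `Slice(0, height)` needs.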
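
On the test side, `return_numpy=False` is needed because the fetched outputs now carry a LoD: with the default conversion to numpy arrays, the executor would (as far as I can tell) either drop or refuse to convert a tensor with a non-empty LoD. Fetching the raw `LoDTensor`s sidesteps that, and `numpy.allclose` can still compare their contents against the expected slices of `tensor_np`.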