From f947c1537830c71e6698e65fd5ddf32781075fd7 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 4 Jan 2018 12:53:25 +0800 Subject: [PATCH] Consider multiple levels of LoD. --- paddle/operators/shrink_rnn_memory_op.cc | 19 ++++++++++++++++--- .../v2/fluid/tests/test_shrink_rnn_memory.py | 4 +++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index cc9e3f90b..29f88896e 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/lod_rank_table.h" +#include "paddle/framework/lod_tensor.h" #include "paddle/operators/array_operator.h" #include "paddle/operators/math/math_function.h" @@ -49,12 +50,24 @@ class ShrinkRNNMemoryOp : public ArrayOp { // should consider multiple levels size_t height = dst_num_rows; - auto lod_level = lod_rank_table.level(); + auto lod_level = rank_table.level(); + if (x_tensor.lod().size() > lod_level && - x_tensor.lod()[lod_level].size() < dst_num_rows) { + x_tensor.lod()[lod_level].size() > static_cast(dst_num_rows)) { auto lod_offset = framework::GetSubLoDAndAbsoluteOffset( - x_tensor.lod(), 0, dst_num_rows + 1, lod_level); + x_tensor.lod(), 0, dst_num_rows, lod_level); height = lod_offset.second.second; + auto out_lod = out_tensor.mutable_lod(); + auto x_lod = x_tensor.lod(); + out_lod->reserve(lod_level + lod_offset.first.size()); + for (size_t i = 0; i < lod_level; ++i) { + out_lod->emplace_back(x_lod.at(i)); + } + framework::LoD remain; + framework::AppendLoD(&remain, lod_offset.first); + for (size_t j = 0; j < remain.size(); ++j) { + out_lod->emplace_back(remain[j]); + } } if (dst_num_rows != 0) { diff --git a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py index 707dbd793..9d8565b16 100644 --- a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py +++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py @@ -29,7 +29,9 @@ class TestShrinkRNNMemory(unittest.TestCase): tensor_np = numpy.random.random(size=(6, 100)).astype('float32') tensor.set(tensor_np, cpu) exe = Executor(cpu) - outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]) + outs = exe.run(feed={'x': tensor}, + fetch_list=[mem1, mem2, mem3], + return_numpy=False) self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0])) self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1])) self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2])) -- GitLab