From 578a2f5da3b46876e2b84a8bdc90fbf91fc6a6ad Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 19 Sep 2019 15:39:17 +0800 Subject: [PATCH] fix SplitLodTensor when batch_size = 0, test=develop (#19866) --- paddle/fluid/framework/lod_tensor.cc | 15 +++++++++++++++ paddle/fluid/framework/lod_tensor_test.cc | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 19ce1d23e8..ca820068c4 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -283,6 +283,21 @@ std::vector LoDTensor::SplitLoDTensor( std::vector results; results.reserve(result_size); + // if result_size(batch_size) is 0, just return #places.size() copys of empty + // tensors. + if (result_size == 0) { + for (size_t i = 0; i < places.size(); ++i) { + LoDTensor dst; + dst.Resize(dims()); + dst.mutable_data(places[i], type()); + if (!lod().empty()) { + dst.set_lod(lod()); + } + results.emplace_back(dst); + } + return results; + } + int step_width = static_cast(batch_size / result_size); for (size_t i = 0; i < result_size; ++i) { int begin = static_cast(i * step_width); diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index a9f75ec2a9..c93c3f2673 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -155,6 +155,26 @@ TEST(LoD, SplitLoDTensor) { EXPECT_EQ(lods[1].lod(), lod1); } +TEST(LoD, SplitLoDTensorWithZeroBatchSize) { + LoD lod; + lod.push_back(std::vector({0})); + + platform::CPUPlace place; + LoDTensor lod_tensor; + lod_tensor.Resize({0, 5}); + lod_tensor.mutable_data(place); + lod_tensor.set_lod(lod); + + std::vector places{platform::CPUPlace(), + platform::CPUPlace()}; + LoD lod_res; + lod_res.push_back(std::vector({0})); + + auto lods = lod_tensor.SplitLoDTensor(places); + EXPECT_EQ(lods[0].lod(), lod_res); + EXPECT_EQ(lods[1].lod(), lod_res); +} + TEST(LoD, MergeLoDTensor) { LoD lod; lod.push_back(std::vector({0, 2, 4, 5, 6})); -- GitLab