diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 19ce1d23e8611f27403304d0b0de98f8efa2be00..ca820068c4d2f89b76306df81bac757918195ec1 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -283,6 +283,21 @@ std::vector LoDTensor::SplitLoDTensor( std::vector results; results.reserve(result_size); + // if result_size(batch_size) is 0, just return #places.size() copys of empty + // tensors. + if (result_size == 0) { + for (size_t i = 0; i < places.size(); ++i) { + LoDTensor dst; + dst.Resize(dims()); + dst.mutable_data(places[i], type()); + if (!lod().empty()) { + dst.set_lod(lod()); + } + results.emplace_back(dst); + } + return results; + } + int step_width = static_cast(batch_size / result_size); for (size_t i = 0; i < result_size; ++i) { int begin = static_cast(i * step_width); diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index a9f75ec2a9c36ac55f09ef48f3f6a4a52f14ccf9..c93c3f2673b1d80ef1e1a9dd68ad50501ba16f42 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -155,6 +155,26 @@ TEST(LoD, SplitLoDTensor) { EXPECT_EQ(lods[1].lod(), lod1); } +TEST(LoD, SplitLoDTensorWithZeroBatchSize) { + LoD lod; + lod.push_back(std::vector({0})); + + platform::CPUPlace place; + LoDTensor lod_tensor; + lod_tensor.Resize({0, 5}); + lod_tensor.mutable_data(place); + lod_tensor.set_lod(lod); + + std::vector places{platform::CPUPlace(), + platform::CPUPlace()}; + LoD lod_res; + lod_res.push_back(std::vector({0})); + + auto lods = lod_tensor.SplitLoDTensor(places); + EXPECT_EQ(lods[0].lod(), lod_res); + EXPECT_EQ(lods[1].lod(), lod_res); +} + TEST(LoD, MergeLoDTensor) { LoD lod; lod.push_back(std::vector({0, 2, 4, 5, 6}));