Patch summary (free text ahead of the first "diff --git" header; not part of any hunk).

* paddle/framework/lod_tensor.cc, LoDTensor::SplitLoDTensor:
  removes the PADDLE_ENFORCE(lod().empty(), ...) guard so tensors that carry a
  LoD are accepted. The per-place range [begin, end) is computed from
  dims()[0] when lod() is empty and from NumElements(0) otherwise. For a
  non-empty LoD the slice bounds come from the second member of the
  GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0) result, and every level of
  its first member is accumulated into a running sum that starts at 0 before
  being stored with dst.set_lod().
* paddle/framework/lod_tensor_test.cc: adds TEST(LoD, SplitLoDTensor).

NOTE(review): the remaining divisibility check still tests dims()[0], while
the split bounds for LoD input are derived from NumElements(0); confirm that
this is intended.
NOTE(review): batch_size does not depend on place_idx but is recomputed on
every loop iteration; consider hoisting it above the loop.
NOTE(review): template arguments (std::vector<LoDTensor>,
std::vector<platform::Place>, std::vector<size_t>) and line breaks were
reconstructed from how the values are used; confirm against the upstream
commit before applying.

diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 7ae94c646537e0d7c4687b949a1b06cd3a7f3404..3d1b75597e5be7ca3e5c71276db31c88b1a4a488 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -225,20 +225,38 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
 std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
     const std::vector<platform::Place> places) const {
   check_memory_size();
-  PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now");
   PADDLE_ENFORCE(dims()[0] % places.size() == 0,
                  "Batch size should be divided by places size");
 
   std::vector<LoDTensor> lods;
   for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
-    int begin = place_idx * dims()[0] / places.size();
-    int end = (place_idx + 1) * dims()[0] / places.size();
+    size_t batch_size = lod().empty() ? dims()[0] : NumElements(0);
+    size_t begin = place_idx * batch_size / places.size();
+    size_t end = (place_idx + 1) * batch_size / places.size();
 
-    auto src = Slice(begin, end);
-    auto &dst_place = places[place_idx];
     LoDTensor dst;
-    framework::Copy(src, dst_place, &dst);
-
+    if (lod().empty()) {
+      auto src = Slice(begin, end);
+      auto &dst_place = places[place_idx];
+      framework::Copy(src, dst_place, &dst);
+    } else {
+      auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0);
+
+      auto &offset = lod_and_offset.second;
+      auto src = Slice(offset.first, offset.second);
+      auto &dst_place = places[place_idx];
+      framework::Copy(src, dst_place, &dst);
+
+      LoD my_lod;
+      for (auto &l : lod_and_offset.first) {
+        std::vector<size_t> v{0};
+        for (auto &ll : l) {
+          v.push_back(ll + v.back());
+        }
+        my_lod.emplace_back(v);
+      }
+      dst.set_lod(my_lod);
+    }
     lods.emplace_back(dst);
   }
 
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index baad9c6f98ac135c3650fe3113522850328c1298..5ff7dca564aab96f278659a00816497d83468025 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -131,5 +131,33 @@
TEST(LoD, ToAbsOffset) {
   EXPECT_EQ(abs_lod, expected);
 }
 
+TEST(LoD, SplitLoDTensor) {  // splits a two-level LoD tensor over two CPU places
+  LoD lod;
+  lod.push_back(std::vector<size_t>({0, 2, 4, 5, 6}));
+  lod.push_back(std::vector<size_t>({0, 1, 6, 8, 13, 15, 20}));
+
+  platform::CPUPlace place;
+  LoDTensor lod_tensor;
+  lod_tensor.Resize({20, 1});  // 20 rows, equal to the last offset of the second level
+  float* dst_ptr = lod_tensor.mutable_data<float>(place);
+  for (int i = 0; i < lod_tensor.numel(); ++i) {
+    dst_ptr[i] = i;
+  }
+  lod_tensor.set_lod(lod);
+
+  std::vector<platform::Place> places{platform::CPUPlace(),
+                                      platform::CPUPlace()};
+  LoD lod0;  // expected LoD of the first part
+  lod0.push_back(std::vector<size_t>({0, 2, 4}));
+  lod0.push_back(std::vector<size_t>({0, 1, 6, 8, 13}));
+  LoD lod1;  // expected LoD of the second part
+  lod1.push_back(std::vector<size_t>({0, 1, 2}));
+  lod1.push_back(std::vector<size_t>({0, 2, 7}));
+
+  auto lods = lod_tensor.SplitLoDTensor(places);
+  EXPECT_EQ(lods[0].lod(), lod0);  // only the LoD is compared, not the copied data
+  EXPECT_EQ(lods[1].lod(), lod1);
+}
+
 }  // namespace framework
 }  // namespace paddle