From 8c1025d66f3574e53cc41a0df9620d53c76d3ec5 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Thu, 11 Jan 2018 09:10:29 +0000 Subject: [PATCH] first commit --- paddle/framework/lod_tensor.cc | 32 ++++++++++++++++++++++------- paddle/framework/lod_tensor_test.cc | 28 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 7ae94c64653..3d1b75597e5 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -225,20 +225,38 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, std::vector LoDTensor::SplitLoDTensor( const std::vector places) const { check_memory_size(); - PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now"); PADDLE_ENFORCE(dims()[0] % places.size() == 0, "Batch size should be divided by places size"); std::vector lods; for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) { - int begin = place_idx * dims()[0] / places.size(); - int end = (place_idx + 1) * dims()[0] / places.size(); + size_t batch_size = lod().empty() ? dims()[0] : NumElements(0); + size_t begin = place_idx * batch_size / places.size(); + size_t end = (place_idx + 1) * batch_size / places.size(); - auto src = Slice(begin, end); - auto &dst_place = places[place_idx]; LoDTensor dst; - framework::Copy(src, dst_place, &dst); - + if (lod().empty()) { + auto src = Slice(begin, end); + auto &dst_place = places[place_idx]; + framework::Copy(src, dst_place, &dst); + } else { + auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0); + + auto &offset = lod_and_offset.second; + auto src = Slice(offset.first, offset.second); + auto &dst_place = places[place_idx]; + framework::Copy(src, dst_place, &dst); + + LoD my_lod; + for (auto &l : lod_and_offset.first) { + std::vector v{0}; + for (auto &ll : l) { + v.push_back(ll + v.back()); + } + my_lod.emplace_back(v); + } + dst.set_lod(my_lod); + } lods.emplace_back(dst); } diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index baad9c6f98a..5ff7dca564a 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -131,5 +131,33 @@ TEST(LoD, ToAbsOffset) { EXPECT_EQ(abs_lod, expected); } +TEST(LoD, SplitLoDTensor) { + LoD lod; + lod.push_back(std::vector({0, 2, 4, 5, 6})); + lod.push_back(std::vector({0, 1, 6, 8, 13, 15, 20})); + + platform::CPUPlace place; + LoDTensor lod_tensor; + lod_tensor.Resize({20, 1}); + float* dst_ptr = lod_tensor.mutable_data(place); + for (int i = 0; i < lod_tensor.numel(); ++i) { + dst_ptr[i] = i; + } + lod_tensor.set_lod(lod); + + std::vector places{platform::CPUPlace(), + platform::CPUPlace()}; + LoD lod0; + lod0.push_back(std::vector({0, 2, 4})); + lod0.push_back(std::vector({0, 1, 6, 8, 13})); + LoD lod1; + lod1.push_back(std::vector({0, 1, 2})); + lod1.push_back(std::vector({0, 2, 7})); + + auto lods = lod_tensor.SplitLoDTensor(places); + EXPECT_EQ(lods[0].lod(), lod0); + EXPECT_EQ(lods[1].lod(), lod1); +} + } // namespace framework } // namespace paddle -- GitLab