提交 8c1025d6 编写于 作者: Y Yang Yang

first commit

上级 377424bf
...@@ -225,20 +225,38 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -225,20 +225,38 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
std::vector<LoDTensor> LoDTensor::SplitLoDTensor( std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const { const std::vector<platform::Place> places) const {
check_memory_size(); check_memory_size();
PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now");
PADDLE_ENFORCE(dims()[0] % places.size() == 0, PADDLE_ENFORCE(dims()[0] % places.size() == 0,
"Batch size should be divided by places size"); "Batch size should be divided by places size");
std::vector<LoDTensor> lods; std::vector<LoDTensor> lods;
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) { for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
int begin = place_idx * dims()[0] / places.size(); size_t batch_size = lod().empty() ? dims()[0] : NumElements(0);
int end = (place_idx + 1) * dims()[0] / places.size(); size_t begin = place_idx * batch_size / places.size();
size_t end = (place_idx + 1) * batch_size / places.size();
LoDTensor dst;
if (lod().empty()) {
auto src = Slice(begin, end); auto src = Slice(begin, end);
auto &dst_place = places[place_idx]; auto &dst_place = places[place_idx];
LoDTensor dst; framework::Copy(src, dst_place, &dst);
} else {
auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0);
auto &offset = lod_and_offset.second;
auto src = Slice(offset.first, offset.second);
auto &dst_place = places[place_idx];
framework::Copy(src, dst_place, &dst); framework::Copy(src, dst_place, &dst);
LoD my_lod;
for (auto &l : lod_and_offset.first) {
std::vector<size_t> v{0};
for (auto &ll : l) {
v.push_back(ll + v.back());
}
my_lod.emplace_back(v);
}
dst.set_lod(my_lod);
}
lods.emplace_back(dst); lods.emplace_back(dst);
} }
......
...@@ -131,5 +131,33 @@ TEST(LoD, ToAbsOffset) { ...@@ -131,5 +131,33 @@ TEST(LoD, ToAbsOffset) {
EXPECT_EQ(abs_lod, expected); EXPECT_EQ(abs_lod, expected);
} }
// Checks that LoDTensor::SplitLoDTensor divides a two-level LoD tensor across
// two places and attaches the expected per-piece LoD to each result.
TEST(LoD, SplitLoDTensor) {
  // Two-level source LoD whose innermost level ends at 20, matching the
  // 20 rows of the tensor built below.
  LoD lod;
  lod.push_back(std::vector<size_t>({0, 2, 4, 5, 6}));
  lod.push_back(std::vector<size_t>({0, 1, 6, 8, 13, 15, 20}));

  // Fill a 20x1 float tensor with 0, 1, ..., 19 on the CPU.
  platform::CPUPlace place;
  LoDTensor lod_tensor;
  lod_tensor.Resize({20, 1});
  float* dst_ptr = lod_tensor.mutable_data<float>(place);
  for (int64_t i = 0; i < lod_tensor.numel(); ++i) {
    dst_ptr[i] = static_cast<float>(i);
  }
  lod_tensor.set_lod(lod);

  std::vector<platform::Place> places{platform::CPUPlace(),
                                      platform::CPUPlace()};

  // Expected LoD of each piece; every level restarts at offset 0.
  LoD lod0;
  lod0.push_back(std::vector<size_t>({0, 2, 4}));
  lod0.push_back(std::vector<size_t>({0, 1, 6, 8, 13}));
  LoD lod1;
  lod1.push_back(std::vector<size_t>({0, 1, 2}));
  lod1.push_back(std::vector<size_t>({0, 2, 7}));

  auto lods = lod_tensor.SplitLoDTensor(places);
  // Fatal check: stop here instead of indexing out of range below if the
  // split did not yield one tensor per place.
  ASSERT_EQ(lods.size(), places.size());
  EXPECT_EQ(lods[0].lod(), lod0);
  EXPECT_EQ(lods[1].lod(), lod1);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册