提交 d002f60a 作者: Yang Yang

merge develop

上级 e0e45c05
...@@ -289,14 +289,15 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -289,14 +289,15 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
std::vector<LoDTensor> LoDTensor::SplitLoDTensor( std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const { const std::vector<platform::Place> places) const {
check_memory_size(); check_memory_size();
PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now"); int batch_size =
size_t result_size = std::min(static_cast<size_t>(dims()[0]), places.size()); lod().empty() ? dims()[0] : static_cast<int>(lod()[0].size()) - 1;
size_t remainder = dims()[0] % places.size(); size_t result_size = std::min(static_cast<size_t>(batch_size), places.size());
size_t remainder = batch_size % places.size();
std::vector<LoDTensor> results; std::vector<LoDTensor> results;
results.reserve(result_size); results.reserve(result_size);
int step_width = static_cast<int>(dims()[0] / result_size); int step_width = static_cast<int>(batch_size / result_size);
for (size_t i = 0; i < result_size; ++i) { for (size_t i = 0; i < result_size; ++i) {
int begin = static_cast<int>(i * step_width); int begin = static_cast<int>(i * step_width);
int end = static_cast<int>((i + 1) * step_width); int end = static_cast<int>((i + 1) * step_width);
...@@ -307,14 +308,14 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -307,14 +308,14 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
LoDTensor dst; LoDTensor dst;
if (lod().empty()) { if (lod().empty()) {
auto src = Slice(begin, end); auto src = Slice(begin, end);
auto &dst_place = places[place_idx]; auto &dst_place = places[i];
framework::Copy(src, dst_place, &dst); framework::Copy(src, dst_place, &dst);
} else { } else {
auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0); auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0);
auto &offset = lod_and_offset.second; auto &offset = lod_and_offset.second;
auto src = Slice(offset.first, offset.second); auto src = Slice(offset.first, offset.second);
auto &dst_place = places[place_idx]; auto &dst_place = places[i];
framework::Copy(src, dst_place, &dst); framework::Copy(src, dst_place, &dst);
LoD my_lod; LoD my_lod;
...@@ -327,7 +328,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -327,7 +328,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
} }
dst.set_lod(my_lod); dst.set_lod(my_lod);
} }
lods.emplace_back(dst); results.emplace_back(dst);
} }
return results; return results;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册