From 709c157a2ff4d51846c373b465d021be93033363 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 5 Feb 2018 23:59:41 -0800 Subject: [PATCH] "fix ci" --- paddle/framework/lod_tensor.h | 8 +------- paddle/framework/selected_rows.h | 8 +------- paddle/operators/parallel_do_op.cc | 11 ++++++++--- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 3465e02c8..a773c1eb3 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -129,13 +129,7 @@ class LoDTensor : public Tensor { explicit LoDTensor(const LoD& lod) : lod_(lod) {} - void set_lod(const LoD& lod) { - lod_ = lod; - if (holder_ != nullptr && - !platform::is_same_place(holder_->place(), lod.place())) { - lod_.CopyToPeer(holder_->place()); - } - } + void set_lod(const LoD& lod) { lod_ = lod; } const LoD& lod() const { return lod_; } diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 113234424..30d3dfc1e 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -42,13 +42,7 @@ class SelectedRows { Vector* mutable_rows() { return &rows_; } - void set_rows(const Vector& rows) { - rows_ = rows; - if (value_ != nullptr && - !platform::is_same_place(value_->place(), rows.place())) { - rows_.mutable_data(value_->place()); - } - } + void set_rows(const Vector& rows) { rows_ = rows; } DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 87678decd..0db2fb623 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -76,21 +76,26 @@ inline void CopyOrShare(const framework::Variable &src, if (src.IsType()) { if (src.Get().place() == dst_place) { dst->GetMutable()->ShareDataWith(src.Get()); + dst->GetMutable()->set_lod(src.Get().lod()); } else { Copy(src.Get(), dst_place, dst->GetMutable()); + LoD lod(src.Get().lod()); + lod.CopyToPeer(dst_place); + dst->GetMutable()->set_lod(lod); } - dst->GetMutable()->set_lod(src.Get().lod()); } else if (src.IsType()) { auto &src_sr = src.Get(); auto *dst_sr = dst->GetMutable(); - dst_sr->set_rows(src_sr.rows()); dst_sr->set_height(src_sr.height()); if (src_sr.value().place() == dst_place) { dst_sr->mutable_value()->ShareDataWith(src_sr.value()); + dst_sr->set_rows(src_sr.rows()); } else { Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); + LoD lod(src.Get().lod()); + lod.CopyToPeer(dst_place); + dst_sr->set_rows(lod); } - dst_sr->set_rows(src_sr.rows()); } else { PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); } -- GitLab