diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 3465e02c8262ffd626298f2672679732396c1c4f..a773c1eb32d8a66999f11f099a1703e3698b01ef 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -129,13 +129,7 @@ class LoDTensor : public Tensor { explicit LoDTensor(const LoD& lod) : lod_(lod) {} - void set_lod(const LoD& lod) { - lod_ = lod; - if (holder_ != nullptr && - !platform::is_same_place(holder_->place(), lod.place())) { - lod_.CopyToPeer(holder_->place()); - } - } + void set_lod(const LoD& lod) { lod_ = lod; } const LoD& lod() const { return lod_; } diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 113234424467c97e2069a884a2d2bb43f2c8e657..30d3dfc1e89f073a8180ceacf77619b36f7079a9 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -42,13 +42,7 @@ class SelectedRows { Vector* mutable_rows() { return &rows_; } - void set_rows(const Vector& rows) { - rows_ = rows; - if (value_ != nullptr && - !platform::is_same_place(value_->place(), rows.place())) { - rows_.mutable_data(value_->place()); - } - } + void set_rows(const Vector& rows) { rows_ = rows; } DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 87678decde14c07f1f98bb8d56624d90c47c729c..0db2fb6238a8251bb1c55aea789a35c6c0e94b9c 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -76,21 +76,26 @@ inline void CopyOrShare(const framework::Variable &src, if (src.IsType()) { if (src.Get().place() == dst_place) { dst->GetMutable()->ShareDataWith(src.Get()); + dst->GetMutable()->set_lod(src.Get().lod()); } else { Copy(src.Get(), dst_place, dst->GetMutable()); + LoD lod(src.Get().lod()); + lod.CopyToPeer(dst_place); + dst->GetMutable()->set_lod(lod); } - dst->GetMutable()->set_lod(src.Get().lod()); } else if (src.IsType()) { auto &src_sr = src.Get(); auto *dst_sr = dst->GetMutable(); - dst_sr->set_rows(src_sr.rows()); dst_sr->set_height(src_sr.height()); if (src_sr.value().place() == dst_place) { dst_sr->mutable_value()->ShareDataWith(src_sr.value()); + dst_sr->set_rows(src_sr.rows()); } else { Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); + LoD lod(src.Get().lod()); + lod.CopyToPeer(dst_place); + dst_sr->set_rows(lod); } - dst_sr->set_rows(src_sr.rows()); } else { PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); }