From 179b78934a81c7935b3a3d6fa22f9596170a31dc Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 6 Feb 2018 00:24:13 -0800 Subject: [PATCH] "fix CopyToPeer" --- paddle/framework/lod_tensor.h | 2 +- paddle/framework/mixed_vector.h | 25 +++++++++++++++++++++++-- paddle/operators/parallel_do_op.cc | 4 ++-- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index a773c1eb32d..be2b3016196 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -65,7 +65,7 @@ struct LoD : public std::vector> { void CopyToPeer(platform::Place place) { for (auto it = this->begin(); it != this->end(); ++it) { - it->mutable_data(place); + it->CopyToPeer(place); } } }; diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 1fc7622e9b2..cdb968e3cb7 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -82,7 +82,7 @@ inline const T *Vector::data(platform::Place place) const { if (cuda_ptr_ == nullptr) { return nullptr; } - if (platform::is_same_place(place, place_)) { + if (boost::get(place) == place_) { return static_cast(cuda_ptr_.get()); } else { PADDLE_THROW( @@ -99,7 +99,7 @@ inline T *Vector::mutable_data(platform::Place place) { if (platform::is_cpu_place(place)) { return std::vector::data(); } else if (platform::is_gpu_place(place)) { - if (!platform::is_same_place(place, place_)) { + if (boost::get(place) != place_) { place_ = boost::get(place); } #ifdef PADDLE_WITH_CUDA @@ -159,5 +159,26 @@ void Vector::CopyFromCUDA() { #endif } +template +void Vector::CopyToPeer(platform::Place place) { +#ifdef PADDLE_WITH_CUDA + if (boost::get(place) != place_) { + place_ = boost::get(place); + } + if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { + cuda_ptr_.reset( + memory::Alloc(place_, this->size() * sizeof(T)), + memory::PlainDeleter(place_)); + } + cuda_size_ = this->size(); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *ctx = pool.GetByPlace(place_); + memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), + static_cast(this->data()), + this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 0db2fb6238a..eb6308d306a 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src, dst->GetMutable()->set_lod(src.Get().lod()); } else { Copy(src.Get(), dst_place, dst->GetMutable()); - LoD lod(src.Get().lod()); + framework::LoD lod(src.Get().lod()); lod.CopyToPeer(dst_place); dst->GetMutable()->set_lod(lod); } @@ -92,7 +92,7 @@ inline void CopyOrShare(const framework::Variable &src, dst_sr->set_rows(src_sr.rows()); } else { Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); - LoD lod(src.Get().lod()); + framework::Vector lod(src_sr.rows()); lod.CopyToPeer(dst_place); dst_sr->set_rows(lod); } -- GitLab