diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index a773c1eb32d8a66999f11f099a1703e3698b01ef..be2b301619639106ac7b578e5a79cf33f4379e48 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -65,7 +65,7 @@ struct LoD : public std::vector> { void CopyToPeer(platform::Place place) { for (auto it = this->begin(); it != this->end(); ++it) { - it->mutable_data(place); + it->CopyToPeer(place); } } }; diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 1fc7622e9b299a67c43bcc4560a66c4ca3ca697e..cdb968e3cb7f2546ccfa68d0a3c5490e4f0ff760 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -82,7 +82,7 @@ inline const T *Vector::data(platform::Place place) const { if (cuda_ptr_ == nullptr) { return nullptr; } - if (platform::is_same_place(place, place_)) { + if (boost::get(place) == place_) { return static_cast(cuda_ptr_.get()); } else { PADDLE_THROW( @@ -99,7 +99,7 @@ inline T *Vector::mutable_data(platform::Place place) { if (platform::is_cpu_place(place)) { return std::vector::data(); } else if (platform::is_gpu_place(place)) { - if (!platform::is_same_place(place, place_)) { + if (boost::get(place) != place_) { place_ = boost::get(place); } #ifdef PADDLE_WITH_CUDA @@ -159,5 +159,26 @@ void Vector::CopyFromCUDA() { #endif } +template +void Vector::CopyToPeer(platform::Place place) { +#ifdef PADDLE_WITH_CUDA + if (boost::get(place) != place_) { + place_ = boost::get(place); + } + if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { + cuda_ptr_.reset( + memory::Alloc(place_, this->size() * sizeof(T)), + memory::PlainDeleter(place_)); + } + cuda_size_ = this->size(); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *ctx = pool.GetByPlace(place_); + memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), + static_cast(this->data()), + this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 0db2fb6238a8251bb1c55aea789a35c6c0e94b9c..eb6308d306af24894207705d6413ac1fec55577b 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src, dst->GetMutable()->set_lod(src.Get().lod()); } else { Copy(src.Get(), dst_place, dst->GetMutable()); - LoD lod(src.Get().lod()); + framework::LoD lod(src.Get().lod()); lod.CopyToPeer(dst_place); dst->GetMutable()->set_lod(lod); } @@ -92,7 +92,7 @@ inline void CopyOrShare(const framework::Variable &src, dst_sr->set_rows(src_sr.rows()); } else { Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); - LoD lod(src.Get().lod()); + framework::Vector lod(src_sr.rows()); lod.CopyToPeer(dst_place); dst_sr->set_rows(lod); }