From 47ebe435a79ab836649ba11c635129c8a6664ea1 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 1 Feb 2018 20:41:54 +0800 Subject: [PATCH] Fix/vector (#8045) * "clean code" * "clean code" --- paddle/framework/mixed_vector.h | 77 +++++++++++++-------------------- 1 file changed, 29 insertions(+), 48 deletions(-) diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 0e0e23958..85caac8dc 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -34,18 +34,6 @@ namespace framework { template class Vector : public std::vector { - public: - /* NOTE(dzhwinter): - * Data always store and modified on Host. - * If the data is modified when use cuda_data interface, - * You need to call the CopyFromCUDA explicitly to synchronize data. - * - */ - enum class kDataPosition { - kDataOnHost = 0, - kDataOnDevice = 1, - }; - public: using std::vector::vector; @@ -55,11 +43,12 @@ class Vector : public std::vector { virtual ~Vector() { #ifdef PADDLE_WITH_CUDA if (cuda_ptr_ != nullptr) { - memory::Free(place_, static_cast(cuda_ptr_)); + memory::Free(place_, cuda_ptr_); } #endif } + /* Get device vector */ T *cuda_data() { CopyToCUDA(); PADDLE_ENFORCE_NOT_NULL( @@ -67,81 +56,73 @@ class Vector : public std::vector { return static_cast(cuda_ptr_); } + /* Get host vector */ T *data() { return std::vector::data(); } - const T *data() const { return std::vector::data(); } + /* Synchronize host vector to device vector */ void CopyToCUDA(); - + /* Synchronize device vector to host vector */ void CopyFromCUDA(); - + /* Switch device vector location */ void CopyToPeer(platform::Place); private: void *cuda_ptr_ = nullptr; - size_t cuda_size_ = 0; - /*The DataPosition is unused now, - if we want support random access from cpu and cuda, - we need to overload all the vector method */ - - kDataPosition position_ = kDataPosition::kDataOnHost; + size_t cuda_size_ = 0; // device vector numel platform::CUDAPlace place_; }; template void Vector::CopyToCUDA() { #ifdef PADDLE_WITH_CUDA - if (cuda_ptr_ == nullptr) { + if (cuda_size_ < this->size()) { + if (cuda_ptr_ != nullptr) { + memory::Free(place_, cuda_ptr_); + } cuda_ptr_ = memory::Alloc(place_, this->size() * sizeof(T)); } + cuda_size_ = this->size(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *cuda_ctx = pool.GetByPlace(place_); - - memory::Copy(place_, static_cast(cuda_ptr_), platform::CPUPlace(), + auto *ctx = pool.GetByPlace(place_); + memory::Copy(place_, cuda_ptr_, platform::CPUPlace(), static_cast(this->data()), - this->size() * sizeof(T), cuda_ctx->stream()); - cuda_ctx->Wait(); - - cuda_size_ = this->size(); + this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); #endif } template void Vector::CopyFromCUDA() { #ifdef PADDLE_WITH_CUDA - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *cuda_ctx = pool.GetByPlace(place_); if (cuda_ptr_ == nullptr) { - LOG(WARNING) << "No uncommited cuda data."; + LOG(WARNING) << "No uncommitted cuda data."; return; } this->resize(cuda_size_); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *ctx = pool.GetByPlace(place_); memory::Copy(platform::CPUPlace(), static_cast(this->data()), place_, static_cast(cuda_ptr_), this->size() * sizeof(T), - cuda_ctx->stream()); - cuda_ctx->Wait(); - + ctx->stream()); + ctx->Wait(); #endif } template void Vector::CopyToPeer(platform::Place peer_place) { - if (platform::is_cpu_place(peer_place)) { - return; - } #ifdef PADDLE_WITH_CUDA - auto *cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(place_); - void *peer_cuda_ptr_ = memory::Alloc( + auto *ctx = platform::DeviceContextPool::Instance().GetByPlace(place_); + void *peer_cuda_ptr = memory::Alloc( boost::get(peer_place), this->size() * sizeof(T)); - memory::Copy(boost::get(peer_place), - static_cast(peer_cuda_ptr_), place_, - static_cast(cuda_ptr_), this->size() * sizeof(T), - cuda_ctx->stream()); - cuda_ctx->Wait(); - memory::Free(place_, static_cast(cuda_ptr_)); + memory::Copy(boost::get(peer_place), peer_cuda_ptr, + place_, cuda_ptr_, this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); + + memory::Free(place_, cuda_ptr_); place_ = boost::get(peer_place); - cuda_ptr_ = peer_cuda_ptr_; + cuda_ptr_ = peer_cuda_ptr; #endif } -- GitLab