From 63320f722cc718e69ddaa4aa5921e7fd047097df Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 5 Feb 2018 01:17:00 -0800 Subject: [PATCH] "add some interfaces" --- paddle/framework/lod_tensor.h | 22 ++++++- paddle/framework/mixed_vector.h | 102 ++++++++++++++++++++------------ paddle/memory/memory.h | 18 ++++++ 3 files changed, 103 insertions(+), 39 deletions(-) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index d0ab640485b..ab289241610 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -48,12 +48,26 @@ namespace framework { */ struct LoD : public std::vector> { using std::vector>::vector; + platform::Place place() const { + if (this->size() == 0) { + // Not Initialze Yet. + return platform::CPUPlace(); + } else { + return this->front().place(); + } + } void CopyFromCUDA() { for (auto it = this->begin(); it != this->end(); ++it) { it->CopyFromCUDA(); } } + + void CopyToPeer(platform::Place place) { + for (auto it = this->begin(); it != this->end(); ++it) { + it->mutable_data(place); + } + } }; std::ostream& operator<<(std::ostream& os, const LoD& lod); @@ -115,7 +129,13 @@ class LoDTensor : public Tensor { explicit LoDTensor(const LoD& lod) : lod_(lod) {} - void set_lod(const LoD& lod) { lod_ = lod; } + void set_lod(const LoD& lod) { + lod_ = lod; + if (holder_ != nullptr && + platform::is_same_place(holder_->place(), lod.place())) { + lod_.CopyToPeer(holder_->place()); + } + } const LoD& lod() const { return lod_; } diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 85caac8dcd9..d86899bc631 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -40,14 +40,15 @@ class Vector : public std::vector { Vector() {} Vector(const std::vector &v) : std::vector(v) {} // NOLINT - virtual ~Vector() { -#ifdef PADDLE_WITH_CUDA - if (cuda_ptr_ != nullptr) { - memory::Free(place_, cuda_ptr_); - } -#endif - } + inline platform::Place place() const { return place_; } + /*! Return a pointer to constant memory block. */ + inline const T *data(platform::Place place) const; + + /*! Return a pointer to mutable memory block. */ + inline T *mutable_data(platform::Place place); + + // TODO(dzhwinter): below interfaces should be removed /* Get device vector */ T *cuda_data() { CopyToCUDA(); @@ -68,25 +69,71 @@ class Vector : public std::vector { void CopyToPeer(platform::Place); private: - void *cuda_ptr_ = nullptr; + std::shared_ptr cuda_ptr_; size_t cuda_size_ = 0; // device vector numel platform::CUDAPlace place_; }; template -void Vector::CopyToCUDA() { +inline const T *Vector::data(platform::Place place) const { + if (platform::is_cpu_place(place)) { + return std::vector::data(); + } else if (platform::is_gpu_place(place)) { + if (cuda_ptr_ == nullptr) { + return nullptr; + } + if (platform::is_same_place(place, place_)) { + return static_cast(cuda_ptr_.get()); + } else { + PADDLE_THROW( + "Unmatched place. Please use `mutable_data` copy lod to the target " + "Place first."); + } + } else { + PADDLE_THROW("Unsupport Place."); + } +} + +template +inline T *Vector::mutable_data(platform::Place place) { + if (platform::is_cpu_place(place)) { + return std::vector::data(); + } else if (platform::is_gpu_place(place)) { + if (!platform::is_same_place(place, place_)) { + place_ = boost::get(place); + } #ifdef PADDLE_WITH_CUDA - if (cuda_size_ < this->size()) { - if (cuda_ptr_ != nullptr) { - memory::Free(place_, cuda_ptr_); + if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { + cuda_ptr_.reset( + memory::Alloc(place_, this->size() * sizeof(T)), + memory::PlainDeleter(place_)); } - cuda_ptr_ = - memory::Alloc(place_, this->size() * sizeof(T)); + cuda_size_ = this->size(); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *ctx = pool.GetByPlace(place_); + memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), + static_cast(this->data()), + this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); + return static_cast(cuda_ptr_.get()); +#endif + } else { + PADDLE_THROW("Unsupport Place."); + } +} + +template +void Vector::CopyToCUDA() { +#ifdef PADDLE_WITH_CUDA + if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) { + cuda_ptr_.reset( + memory::Alloc(this->size() * sizeof(T)), + memory::PlainDeleter(place_)); } cuda_size_ = this->size(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto *ctx = pool.GetByPlace(place_); - memory::Copy(place_, cuda_ptr_, platform::CPUPlace(), + memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(), static_cast(this->data()), this->size() * sizeof(T), ctx->stream()); ctx->Wait(); @@ -104,32 +151,11 @@ void Vector::CopyFromCUDA() { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto *ctx = pool.GetByPlace(place_); memory::Copy(platform::CPUPlace(), static_cast(this->data()), place_, - static_cast(cuda_ptr_), this->size() * sizeof(T), - ctx->stream()); - ctx->Wait(); -#endif -} - -template -void Vector::CopyToPeer(platform::Place peer_place) { -#ifdef PADDLE_WITH_CUDA - auto *ctx = platform::DeviceContextPool::Instance().GetByPlace(place_); - void *peer_cuda_ptr = memory::Alloc( - boost::get(peer_place), this->size() * sizeof(T)); - memory::Copy(boost::get(peer_place), peer_cuda_ptr, - place_, cuda_ptr_, this->size() * sizeof(T), ctx->stream()); + static_cast(cuda_ptr_.get()), + this->size() * sizeof(T), ctx->stream()); ctx->Wait(); - - memory::Free(place_, cuda_ptr_); - place_ = boost::get(peer_place); - cuda_ptr_ = peer_cuda_ptr; #endif } -template class Vector; -template class Vector; -template class Vector; -template class Vector; - } // namespace framework } // namespace paddle diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 7012b6d331d..30ed68c6e0e 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -81,5 +81,23 @@ class PODDeleter { Place place_; }; +/** + * \brief Free memory block in one place does not meet POD + * + * \note In some cases, custom deleter is used to + * deallocate the memory automatically for + * std::unique_ptr in tensor.h. + * + */ +template +class PlainDeleter { + public: + explicit PlainDeleter(Place place) : place_(place) {} + void operator()(T* ptr) { Free(place_, reinterpret_cast(ptr)); } + + private: + Place place_; +}; + } // namespace memory } // namespace paddle -- GitLab