From 5c333e414380f064696a1c152d26cc6b5d6750e4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Mar 2018 16:21:18 +0800 Subject: [PATCH] Add dctor for dev_ctx --- paddle/fluid/framework/parallel_executor.cc | 27 +++++----------- paddle/fluid/platform/device_context.cc | 34 +++++++++++---------- paddle/fluid/platform/device_context.h | 17 ++--------- paddle/fluid/platform/place.h | 3 +- 4 files changed, 31 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 7064828b21..8c29aacab6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -35,18 +35,18 @@ using details::VarHandleBase; class ParallelExecutorPrivate { public: - explicit ParallelExecutorPrivate(size_t num_threads) - : pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {} + explicit ParallelExecutorPrivate(size_t num_threads, + const std::vector &places) + : places_(places), + fetch_dev_ctxs_(places), + pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {} std::vector places_; - + platform::DeviceContextPool fetch_dev_ctxs_; std::vector local_scopes_; Scope *global_scope_; std::unique_ptr nccl_ctxs_; - std::unordered_map - fetch_dev_ctxs_; platform::Place main_place_; @@ -219,20 +219,9 @@ ParallelExecutor::ParallelExecutor( const std::unordered_set ¶ms, const ProgramDesc &startup_program, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope) - : member_(new ParallelExecutorPrivate(num_threads)) { - member_->places_ = places; + : member_(new ParallelExecutorPrivate(num_threads, places)) { member_->global_scope_ = scope; - if (platform::is_cpu_place(places[0])) { - member_->fetch_dev_ctxs_[places[0]] = const_cast( - platform::DeviceContextPool::Instance().Get(places[0])); - } else { - for (auto &p : member_->places_) { - member_->fetch_dev_ctxs_[p] = - new platform::CUDADeviceContext(boost::get(p)); - } - } - // Step 1. RunStartupProgram and Bcast the params to devs. Executor exe(places[0]); exe.Run(startup_program, scope, 0); @@ -509,7 +498,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, // FIXME: Use new device context for (auto &p : member_->places_) { - op->dev_ctx_[p] = member_->fetch_dev_ctxs_[p]; + op->dev_ctx_[p] = member_->fetch_dev_ctxs_.Get(p); } for (auto *var : vars) { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index ab02a95f26..59b76a1edb 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -10,43 +10,45 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_context.h" +#include #include "paddle/fluid/memory/memory.h" - namespace paddle { namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Get( - const platform::Place& place) { +platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { PADDLE_THROW( "'Place' is not supported, Please re-compile with WITH_GPU " "option"); } - return it->second; + return it->second.get(); } DeviceContextPool::DeviceContextPool( const std::vector& places) { PADDLE_ENFORCE_GT(places.size(), 0); - for (size_t i = 0; i < places.size(); i++) { - if (platform::is_cpu_place(places[i])) { + using PtrType = std::unique_ptr; + std::unordered_set set; + for (auto& p : places) { + set.insert(p); + } + + for (auto& p : set) { + if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN - device_contexts_.emplace(places[i], - new platform::MKLDNNDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new MKLDNNDeviceContext(boost::get(p)))); #else - device_contexts_.emplace(places[i], - new platform::CPUDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CPUDeviceContext(boost::get(p)))); #endif - } else if (platform::is_gpu_place(places[i])) { + } else if (platform::is_gpu_place(p)) { #ifdef PADDLE_WITH_CUDA - device_contexts_.emplace(places[i], - new platform::CUDADeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CUDADeviceContext(boost::get(p)))); #else PADDLE_THROW( "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index df0a427b48..202394c7be 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -160,7 +160,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. */ - const platform::DeviceContext* Get(const platform::Place& place); + platform::DeviceContext* Get(const platform::Place& place); template const typename DefaultDeviceContextType::TYPE* GetByPlace( @@ -173,19 +173,8 @@ class DeviceContextPool { private: static DeviceContextPool* pool; - constexpr static int LEFT_SHIFT = 8; - struct Hash { - std::hash hash_; - size_t operator()(const platform::Place& place) const { - int pre_hash = place.which() << LEFT_SHIFT; - if (platform::is_gpu_place(place)) { - pre_hash += boost::get(place).GetDeviceId(); - } - return hash_(pre_hash); - } - }; - std::unordered_map + std::unordered_map, PlaceHash> device_contexts_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index 633251eb47..4cc8b377b8 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -67,12 +67,13 @@ bool is_same_place(const Place &, const Place &); struct PlaceHash { std::size_t operator()(const Place &p) const { + constexpr size_t num_dev_bits = 4; std::hash ihash; size_t dev_id = 0; if (is_gpu_place(p)) { dev_id = boost::get(p).device; } - return ihash(dev_id << 2 | p.which()); + return ihash(dev_id << num_dev_bits | p.which()); } }; -- GitLab