diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 98b4178177b0a8bafd6fe34a92be2a07a2fbc5a7..59b76a1edb5ec5900520fbccb6a6f8f6e7a70aa4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -10,43 +10,45 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_context.h" +#include #include "paddle/fluid/memory/memory.h" - namespace paddle { namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Get( - const platform::Place& place) { +platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { PADDLE_THROW( "'Place' is not supported, Please re-compile with WITH_GPU " "option"); } - return it->second; + return it->second.get(); } DeviceContextPool::DeviceContextPool( const std::vector& places) { PADDLE_ENFORCE_GT(places.size(), 0); - for (size_t i = 0; i < places.size(); i++) { - if (platform::is_cpu_place(places[i])) { + using PtrType = std::unique_ptr; + std::unordered_set set; + for (auto& p : places) { + set.insert(p); + } + + for (auto& p : set) { + if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN - device_contexts_.emplace(places[i], - new platform::MKLDNNDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new MKLDNNDeviceContext(boost::get(p)))); #else - device_contexts_.emplace(places[i], - new platform::CPUDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CPUDeviceContext(boost::get(p)))); #endif - } else if (platform::is_gpu_place(places[i])) { + } else if (platform::is_gpu_place(p)) { #ifdef PADDLE_WITH_CUDA - device_contexts_.emplace(places[i], - new platform::CUDADeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CUDADeviceContext(boost::get(p)))); #else PADDLE_THROW( "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " @@ -159,6 +161,7 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { + std::lock_guard guard(mutex_); PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaGetLastError()); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 603b890af13b529c490c29112a73a09cc815d07a..202394c7be7e103a609dd0999fc883c794ef0edd 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; + mutable std::mutex mutex_; cudaStream_t stream_; cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; @@ -159,7 +160,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. */ - const platform::DeviceContext* Get(const platform::Place& place); + platform::DeviceContext* Get(const platform::Place& place); template const typename DefaultDeviceContextType::TYPE* GetByPlace( @@ -172,19 +173,8 @@ class DeviceContextPool { private: static DeviceContextPool* pool; - constexpr static int LEFT_SHIFT = 8; - struct Hash { - std::hash hash_; - size_t operator()(const platform::Place& place) const { - int pre_hash = place.which() << LEFT_SHIFT; - if (platform::is_gpu_place(place)) { - pre_hash += boost::get(place).GetDeviceId(); - } - return hash_(pre_hash); - } - }; - std::unordered_map + std::unordered_map, PlaceHash> device_contexts_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index 501bddfc6ec8b5d0bf554b0911c32e47fd51ec15..4cc8b377b8b671eb5a446ecbae21ba9628fbd2c8 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -65,6 +65,18 @@ bool is_cpu_place(const Place &); bool places_are_same_class(const Place &, const Place &); bool is_same_place(const Place &, const Place &); +struct PlaceHash { + std::size_t operator()(const Place &p) const { + constexpr size_t num_dev_bits = 4; + std::hash ihash; + size_t dev_id = 0; + if (is_gpu_place(p)) { + dev_id = boost::get(p).device; + } + return ihash(dev_id << num_dev_bits | p.which()); + } +}; + std::ostream &operator<<(std::ostream &, const Place &); template