/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include // NOLINT #include #include // NOLINT #include #include #include #include #include "paddle/fluid/memory/malloc.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #if !defined(__APPLE__) && !defined(_WIN32) #include "paddle/fluid/platform/dynload/nccl.h" #endif #include "paddle/fluid/platform/gpu_info.h" #endif #ifdef PADDLE_WITH_MKLDNN #include "mkldnn.hpp" #endif #include #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/stream_callback_manager.h" #endif #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { class DeviceContext { public: virtual ~DeviceContext() {} virtual Place GetPlace() const = 0; virtual void Wait() const {} }; class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); explicit CPUDeviceContext(CPUPlace place); Eigen::DefaultDevice* eigen_device() const; Place GetPlace() const override; private: CPUPlace place_; std::unique_ptr eigen_device_; }; template struct DefaultDeviceContextType; template <> struct DefaultDeviceContextType { using TYPE = CPUDeviceContext; }; #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; class CudnnWorkspaceHandle; class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ void Wait() const override; /*! \brief Return place in the device context. */ Place GetPlace() const override; /*! \brief Return compute capability in the device context. */ int GetComputeCapability() const; /*! \brief Return the max physical thread count in the device context */ int GetMaxPhysicalThreadCount() const; /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; /*! \brief Call cublas function safely. */ template inline void CublasCall(Callback&& callback) const { cublas_handle_->Call(std::forward(callback)); } /*! \brief Check whether tensor core is supported */ bool tensor_core_available() const; /*! \brief Call cublas function with Tensor Core safely. If Tensor Core is not available, use DEFAULT_MATH instead. */ template inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const { if (cublas_tensor_core_handle_) { cublas_tensor_core_handle_->Call(std::forward(callback)); } else { cublas_handle_->Call(std::forward(callback)); } } /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; /*! \brief Return a cudnn workspace handle to call multiple cudnn * functions without interrupting by other threads. * Once the first cudnn function is called by the handle, a lock * would be acquired to prevent other threads from accessing the * workspace. Once the handle is destructed, the lock would be released. * CudnnWorkspaceHandle is an RAII object to implement thread-safe * sequential cudnn function calls. */ CudnnWorkspaceHandle cudnn_workspace_handle() const; /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; #if !defined(_WIN32) /*! \brief Return nccl communicators. */ ncclComm_t nccl_comm() const { return nccl_comm_; } /*! \brief Set nccl communicators. */ void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; } #endif template void RecordEvent(cudaEvent_t ev, Callback callback) { callback(); PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_)); } template void AddStreamCallback(Callback&& callback) const { callback_manager_->AddCallback(callback); } void WaitStreamCallback() const { callback_manager_->Wait(); } private: CUDAPlace place_; mutable std::once_flag init_cudnn_; std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; cudaStream_t stream_; cudnnHandle_t cudnn_handle_; mutable std::mutex cudnn_handle_mtx_; std::unique_ptr cublas_handle_; std::unique_ptr cublas_tensor_core_handle_; #if !defined(_WIN32) // NCCL communicator (single process version) for NCCL collective operations. // NCCL collective operations provides fast collectives over multiple GPUs // both within and across nodes. // But, this collectives is used for collectives over multiple GPUs within // nodes. ncclComm_t nccl_comm_{nullptr}; #endif int compute_capability_; int runtime_version_; int driver_version_; int multi_process_; int max_threads_per_mp_; // StreamCallbackManager is thread-safe std::unique_ptr callback_manager_; DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; class CudnnWorkspaceHandle { public: inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx) : device_context_(dev_ctx), mtx_(mtx) {} template inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) { if (required_workspace_bytes > WorkspaceSize()) { ReallocWorkspace(required_workspace_bytes); } VLOG(2) << "Cudnn workspace size at RunFunc: " << static_cast(WorkspaceSize()) / (1 << 20) << " MB"; { std::lock_guard guard(*mtx_); cudnn_func(allocation_ ? allocation_->ptr() : nullptr); } } /*! \brief Thread which call RunFuncSync() would release gpu memory after * running the function. Currently this function is only used when cudnn * exhaustive searching and callers have to guarantee that the input function * is host blocking */ template inline void RunFuncSync(Callback&& cudnn_func, size_t required_workspace_bytes) { RunFunc(cudnn_func, required_workspace_bytes); ResetWorkspace(); } inline void ReallocWorkspace(size_t required_workspace_bytes) { if (required_workspace_bytes <= WorkspaceSize()) { return; } // reset allocation first before re-allocate to save memory allocation_.reset(); allocation_ = memory::Alloc(device_context_, required_workspace_bytes); } inline void ResetWorkspace() { allocation_ = nullptr; } inline size_t WorkspaceSize() { if (allocation_ == nullptr) { return 0; } return allocation_->size(); } CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default; CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete; private: memory::allocation::AllocationPtr allocation_; const CUDADeviceContext& device_context_; std::mutex* mtx_; }; template <> struct DefaultDeviceContextType { using TYPE = CUDADeviceContext; }; // Currently, CUDAPinnedDeviceContext is only used to data copying. class CUDAPinnedDeviceContext : public DeviceContext { public: CUDAPinnedDeviceContext(); explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place); Place GetPlace() const override; Eigen::DefaultDevice* eigen_device() const; private: CUDAPinnedPlace place_; std::unique_ptr eigen_device_; }; template <> struct DefaultDeviceContextType { using TYPE = CUDAPinnedDeviceContext; }; #endif #ifdef PADDLE_WITH_MKLDNN // Following three maps are used to cache MKLDNN primitives. // There relations are: // - BlobMap = Map // - ShapeBlob = Map // - KeyBlob = Map // Where: using KeyBlob = std::unordered_map>; using ShapeBlob = std::unordered_map>; using BlobMap = std::unordered_map>; // default mkldnn session id constexpr size_t kMKLDNNSessionID_Default = 0; // mkldnn session id for cache clearing mode constexpr size_t kMKLDNNSessionID_CacheClearing = -1; void set_cur_mkldnn_session_id(size_t); size_t get_cur_mkldnn_session_id(void); void set_cur_input_shape_str(std::string input_shape_str); void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity); class MKLDNNDeviceContext : public CPUDeviceContext { public: explicit MKLDNNDeviceContext(CPUPlace place); /* \brief Get the active engine */ const mkldnn::engine& GetEngine() const { return engine_; } // Remove all entries from the blob map void ResetBlobMap() const; // Get the ShapeBlob size in cur_mkldnn_session_id. size_t GetShapeBlobSize() const; // Set data to blob (i.e. name/data pair). Create blob if not existing void SetBlob(const std::string& name, std::shared_ptr data) const; // Find a saved blob. Return nullptr if not found std::shared_ptr GetBlob(const std::string& name) const; private: mkldnn::engine engine_; std::shared_ptr p_blobmap_; std::shared_ptr p_mutex_; }; #endif /*! \brief device context pool singleton */ class DeviceContextPool { public: explicit DeviceContextPool(const std::vector& places); static DeviceContextPool& Instance() { PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!"); return *pool; } /*! \brief Create should only called by Init function */ static DeviceContextPool& Init(const std::vector& places) { if (pool == nullptr) { pool = new DeviceContextPool(places); } return *pool; } /*! \brief Return handle of single device context. */ platform::DeviceContext* Get(const platform::Place& place); template const typename DefaultDeviceContextType::TYPE* GetByPlace( const Place& place) { return reinterpret_cast< const typename DefaultDeviceContextType::TYPE*>(Get(place)); } size_t size() const { return device_contexts_.size(); } private: static DeviceContextPool* pool; std::map>> device_contexts_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; } // namespace platform } // namespace paddle