/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

13
#include <future>  // NOLINT
D
dzhwinter 已提交
14
#include <memory>
Y
yuyang18 已提交
15
#include <mutex>  // NOLINT
16
#include <string>
D
dzhwinter 已提交
17
#include <unordered_map>
18
#include <utility>
19
#include <vector>
Y
Yu Yang 已提交
20
#include "paddle/fluid/memory/malloc.h"
21
#ifdef PADDLE_WITH_CUDA
22
#include "paddle/fluid/platform/cuda_helper.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
25
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
26
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
27
#endif
Y
Yi Wang 已提交
28
#include "paddle/fluid/platform/gpu_info.h"
Q
QI JUN 已提交
29
#endif
D
dzhwinter 已提交
30

T
tensor-tang 已提交
31
#ifdef PADDLE_WITH_MKLDNN
L
luotao1 已提交
32
#include "mkldnn.hpp"
33
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
34 35
#endif

36 37
#include <map>
#include "glog/logging.h"
Y
Yi Wang 已提交
38 39
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
S
sneaxiy 已提交
40 41 42
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
Q
qijun 已提交
43
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
44 45 46 47 48 49

namespace paddle {
namespace platform {

class DeviceContext {
 public:
Z
Zeng Jinle 已提交
50
  virtual ~DeviceContext() PADDLE_MAY_THROW {}
L
liaogang 已提交
51
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
52

53
  virtual void Wait() const {}
Q
QI JUN 已提交
54 55
};

Q
qijun 已提交
56 57
class CPUDeviceContext : public DeviceContext {
 public:
58
  CPUDeviceContext();
Q
qijun 已提交
59
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
60

61
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
62

L
liaogang 已提交
63
  Place GetPlace() const override;
Y
Yu Yang 已提交
64

Q
qijun 已提交
65
 private:
D
dzhwinter 已提交
66
  CPUPlace place_;
Q
qijun 已提交
67
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
68 69
};

Y
Yang Yu 已提交
70 71 72 73 74 75 76 77
/*! \brief Compile-time map from a Place type to the DeviceContext subclass
 *  that serves it (exposed as the nested alias TYPE).
 *
 *  The primary template is declared but never defined, so naming TYPE for a
 *  Place without a specialization is a compile error. */
template <typename Place>
struct DefaultDeviceContextType;

/*! \brief CPUPlace is served by CPUDeviceContext. */
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

78
#ifdef PADDLE_WITH_CUDA
79

Q
qijun 已提交
80
class EigenCudaStreamDevice;
81
class CudnnWorkspaceHandle;
S
sneaxiy 已提交
82

83
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
84
 public:
D
dzhwinter 已提交
85
  explicit CUDADeviceContext(CUDAPlace place);
86
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
87

88
  /*! \brief  Wait for all operations completion in the stream. */
89
  void Wait() const override;
Q
QI JUN 已提交
90

91
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
92
  Place GetPlace() const override;
93

K
Kexin Zhao 已提交
94
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
95 96
  int GetComputeCapability() const;

97 98 99
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

100 101 102 103 104 105
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

106 107 108
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

109 110 111
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

112 113 114
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
115
    cublas_handle_->Call(std::forward<Callback>(callback));
116 117 118 119 120 121 122 123 124
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
125 126 127 128 129
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
130
  }
S
sneaxiy 已提交
131

132
  /*! \brief  Return cudnn  handle in the device context. */
133
  cudnnHandle_t cudnn_handle() const;
134

S
sneaxiy 已提交
135 136 137 138 139 140 141 142 143
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

Q
init  
qijun 已提交
144
  /*! \brief  Return cuda stream in the device context. */
145
  cudaStream_t stream() const;
Q
QI JUN 已提交
146

147
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
148 149 150 151 152
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
153
#endif
Q
qingqing01 已提交
154

Y
Yu Yang 已提交
155 156
  template <typename Callback>
  void RecordEvent(cudaEvent_t ev, Callback callback) {
157 158
    callback();
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_));
Y
Yu Yang 已提交
159 160
  }

S
sneaxiy 已提交
161 162
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
163
    callback_manager_->AddCallback(callback);
164 165
  }

166
  void WaitStreamCallback() const { callback_manager_->Wait(); }
S
sneaxiy 已提交
167

Q
QI JUN 已提交
168
 private:
D
dzhwinter 已提交
169
  CUDAPlace place_;
Q
QI JUN 已提交
170

171 172 173 174 175
  mutable std::once_flag init_cudnn_;

  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  cudaStream_t stream_;
176

177
  cudnnHandle_t cudnn_handle_;
178 179
  mutable std::mutex cudnn_handle_mtx_;

180 181 182
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;

183
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
184 185 186 187 188 189
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
190
#endif
Q
qingqing01 已提交
191

C
chengduo 已提交
192 193 194 195 196
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
197
  int max_threads_per_block_;
198
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
199

200 201 202
  // StreamCallbackManager is thread-safe
  std::unique_ptr<StreamCallbackManager> callback_manager_;

203
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
204
};
Q
qijun 已提交
205

206 207
class CudnnWorkspaceHandle {
 public:
208 209
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
210 211 212 213 214 215 216 217

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
218 219 220 221
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
222 223 224 225 226 227 228 229 230 231 232 233 234
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

235
  void ReallocWorkspace(size_t required_workspace_bytes);
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
252
  std::mutex* mtx_;
253 254
};

Y
Yang Yu 已提交
255 256
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
257
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
258 259
};

C
chengduoZH 已提交
260
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
261 262 263 264 265 266
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
267

C
chengduoZH 已提交
268 269 270 271 272 273 274 275 276 277 278
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

/*! \brief CUDAPinnedPlace is served by CUDAPinnedDeviceContext. */
template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
279
#endif
Q
qijun 已提交
280

T
tensor-tang 已提交
281
#ifdef PADDLE_WITH_MKLDNN
282 283 284 285 286 287
// The following three maps cache MKLDNN primitives. Their relations are:
// - BlobMap   = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob   = Map<blob_name, blob>
// NOTE(review): BlobMap's key is a plain int; the out-of-line code may key
// it by MKLDNN session id rather than thread id. TODO confirm.
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;

// Default MKLDNN session id.
constexpr size_t kMKLDNNSessionID_Default = 0;
// MKLDNN session id for cache-clearing mode: the largest size_t value. It is
// written as an explicit cast so -1 is not implicitly converted from int.
constexpr size_t kMKLDNNSessionID_CacheClearing = static_cast<size_t>(-1);

// Set / get the current MKLDNN session id (defined out of line).
void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id();
// Set the current input shape string (presumably the ShapeBlob key).
void set_cur_input_shape_str(std::string input_shape_str);
// Set the input-shape cache capacity (semantics defined out of line).
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
301 302
// Set / get the current Paddle data layout (per the function names;
// definitions are out of line).
void set_cur_paddle_data_layout(framework::DataLayout);
framework::DataLayout get_cur_paddle_data_layout();
S
Sylwester Fraczek 已提交
303

T
tensor-tang 已提交
304 305 306 307 308
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
309
  const mkldnn::engine& GetEngine() const { return engine_; }
T
tensor-tang 已提交
310

311 312 313
  // Remove all entries from the blob map
  void ResetBlobMap() const;

314 315 316
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

317 318
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
319

320 321
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
322 323

 private:
324
  mkldnn::engine engine_;
325 326
  std::shared_ptr<BlobMap> p_blobmap_;
  std::shared_ptr<std::mutex> p_mutex_;
T
tensor-tang 已提交
327 328 329
};
#endif

D
dzhwinter 已提交
330 331 332 333 334
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
335
  static DeviceContextPool& Instance() {
D
dzhwinter 已提交
336 337 338 339 340
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
341
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
342 343 344 345 346 347
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

348 349
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
350
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
351
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
352

Y
Yang Yu 已提交
353 354 355 356 357 358 359
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

360 361
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
362 363
 private:
  static DeviceContextPool* pool;
364 365
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
366 367 368
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
369 370
}  // namespace platform
}  // namespace paddle