/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <future>  // NOLINT
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/memory/malloc.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"
#endif

#ifdef PADDLE_WITH_MKLDNN
#include "mkldnn.hpp"
#include "paddle/fluid/framework/data_layout.h"
#endif

#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream/cuda_stream.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace platform {

class DeviceContext {
 public:
Z
Zeng Jinle 已提交
50
  virtual ~DeviceContext() PADDLE_MAY_THROW {}
L
liaogang 已提交
51
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
52

53
  virtual void Wait() const {}
Q
QI JUN 已提交
54 55
};

Q
qijun 已提交
56 57
class CPUDeviceContext : public DeviceContext {
 public:
58
  CPUDeviceContext();
Q
qijun 已提交
59
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
60

61
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
62

L
liaogang 已提交
63
  Place GetPlace() const override;
Y
Yu Yang 已提交
64

Q
qijun 已提交
65
 private:
D
dzhwinter 已提交
66
  CPUPlace place_;
Q
qijun 已提交
67
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
68 69
};

Y
Yang Yu 已提交
70 71 72 73 74 75 76 77
/*! \brief Maps a Place type to the DeviceContext subclass that
 *  DeviceContextPool::GetByPlace() returns for it. Only the specializations
 *  below are defined; other Place types are intentionally incomplete. */
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

#ifdef PADDLE_WITH_CUDA
// Forward declarations. CudnnWorkspaceHandle is defined later in this header;
// EigenCudaStreamDevice is only held through std::unique_ptr here, so a
// declaration is sufficient.
class EigenCudaStreamDevice;
class CudnnWorkspaceHandle;

class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
      const stream::Priority& priority = stream::Priority::kNormal);

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

  const cudaStream_t& RawStream() { return stream_->raw_stream(); }

  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }

  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
    cublas_handle_->Call(std::forward<Callback>(callback));
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
#endif
    }
  }

  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
    } else {
      cudnn_handle_ = nullptr;
    }
  }

  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
  }

  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
  cudnnHandle_t cudnn_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

195
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
196
 public:
D
dzhwinter 已提交
197
  explicit CUDADeviceContext(CUDAPlace place);
198
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
199

200
  /*! \brief  Wait for all operations completion in the stream. */
201
  void Wait() const override;
Q
QI JUN 已提交
202

203
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
204
  Place GetPlace() const override;
205

K
Kexin Zhao 已提交
206
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
207 208
  int GetComputeCapability() const;

209 210 211
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

212 213 214 215 216 217
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

218 219 220
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

221 222 223
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

224 225 226
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
227
    return context()->CublasCall(callback);
228 229 230 231 232 233 234 235 236
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
237
    return context()->TensorCoreCublasCallIfAvailable(callback);
238
  }
S
sneaxiy 已提交
239

240
  /*! \brief  Return cudnn  handle in the device context. */
241
  cudnnHandle_t cudnn_handle() const;
242

S
sneaxiy 已提交
243 244 245 246 247 248 249 250 251
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

Q
init  
qijun 已提交
252
  /*! \brief  Return cuda stream in the device context. */
253
  cudaStream_t stream() const;
Q
QI JUN 已提交
254

255
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
256 257 258 259 260
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
261
#endif
Q
qingqing01 已提交
262

Y
Yu Yang 已提交
263
  template <typename Callback>
264 265
  void RecordEvent(cudaEvent_t ev, Callback callback) const {
    return context()->Stream()->RecordEvent(ev, callback);
Y
Yu Yang 已提交
266 267
  }

S
sneaxiy 已提交
268 269
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
270 271 272 273 274
    return context()->Stream()->AddCallback(callback);
  }

  void WaitStreamCallback() const {
    return context()->Stream()->WaitCallback();
275 276
  }

277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
  void ResetDefaultContext(const stream::Priority& priority) {
    default_ctx_.reset(new CUDAContext(place_, priority));
  }

  void ResetThreadContext(const stream::Priority& priority) {
    std::lock_guard<std::mutex> guard(ctx_mtx_);
    thread_ctx_[this].reset(new CUDAContext(place_, priority));
  }

  std::shared_ptr<CUDAContext> context() const {
    if (!thread_ctx_.count(this)) {
      return default_ctx_;
    }
    return thread_ctx_.at(this);
  }
S
sneaxiy 已提交
292

Q
QI JUN 已提交
293
 private:
D
dzhwinter 已提交
294
  CUDAPlace place_;
295
  std::shared_ptr<CUDAContext> default_ctx_;
Q
QI JUN 已提交
296

297 298 299 300 301 302
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
303

304 305
  mutable std::mutex cudnn_handle_mtx_;

306
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
307 308 309 310 311 312
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
313
#endif
Q
qingqing01 已提交
314

C
chengduo 已提交
315 316 317 318 319
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
320
  int max_threads_per_block_;
321
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
322

323
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
324
};
Q
qijun 已提交
325

326 327
class CudnnWorkspaceHandle {
 public:
328 329
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
330 331 332 333 334 335 336 337

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
338 339 340 341
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
342 343 344 345 346 347 348 349 350 351 352 353 354
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

355
  void ReallocWorkspace(size_t required_workspace_bytes);
356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
372
  std::mutex* mtx_;
373 374
};

Y
Yang Yu 已提交
375 376
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
377
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
378 379
};

C
chengduoZH 已提交
380
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
381 382 383 384 385 386
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
387

C
chengduoZH 已提交
388 389 390 391 392 393 394 395 396 397 398
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

// GetByPlace(CUDAPinnedPlace) yields a CUDAPinnedDeviceContext.
template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
#endif
#ifdef PADDLE_WITH_MKLDNN
// The following three maps are used to cache MKLDNN primitives.
// Their relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob  = Map<blob_name, blob>
// NOTE(review): MKLDNNDeviceContext::GetShapeBlobSize() is documented in
// terms of cur_mkldnn_session_id, so the BlobMap key may be the session id
// rather than a thread id — confirm.
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;

// Session id used for MKLDNN primitive caching unless another one is set.
constexpr size_t kMKLDNNSessionID_Default = 0;
// Session id that selects the cache-clearing mode: the largest size_t value
// (-1 converted to size_t), written as an explicit cast.
constexpr size_t kMKLDNNSessionID_CacheClearing = static_cast<size_t>(-1);

void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void);
419
void set_cur_input_shape_str(std::string input_shape_str);
420
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
421 422
void set_cur_paddle_data_layout(framework::DataLayout);
framework::DataLayout get_cur_paddle_data_layout(void);
S
Sylwester Fraczek 已提交
423

T
tensor-tang 已提交
424 425 426 427 428
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
429
  const mkldnn::engine& GetEngine() const { return engine_; }
T
tensor-tang 已提交
430

431 432 433
  // Remove all entries from the blob map
  void ResetBlobMap() const;

434 435 436
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

437 438
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
439

440 441
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
442 443

 private:
444
  mkldnn::engine engine_;
445 446
  std::shared_ptr<BlobMap> p_blobmap_;
  std::shared_ptr<std::mutex> p_mutex_;
T
tensor-tang 已提交
447 448 449
};
#endif

/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
455
  static DeviceContextPool& Instance() {
D
dzhwinter 已提交
456 457 458 459 460
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
461
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
462 463 464 465 466 467
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

468 469
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
470
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
471
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
472

Y
Yang Yu 已提交
473 474 475 476 477 478 479
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

480 481
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
482 483
 private:
  static DeviceContextPool* pool;
484 485
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
486 487 488
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
489 490
}  // namespace platform
}  // namespace paddle