device_context.h 22.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

13
#include <future>  // NOLINT
D
dzhwinter 已提交
14
#include <memory>
Y
yuyang18 已提交
15
#include <mutex>  // NOLINT
16
#include <string>
D
dzhwinter 已提交
17
#include <unordered_map>
18
#include <utility>
19
#include <vector>
W
wanghuancoder 已提交
20

Y
Yu Yang 已提交
21
#include "paddle/fluid/memory/malloc.h"
22
#ifdef PADDLE_WITH_CUDA
23
#include "paddle/fluid/platform/cuda_helper.h"
Y
Yi Wang 已提交
24 25
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
26
#include "paddle/fluid/platform/dynload/cusolver.h"
27
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
28
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
29
#endif
Y
Yi Wang 已提交
30
#include "paddle/fluid/platform/gpu_info.h"
Q
QI JUN 已提交
31
#endif
D
dzhwinter 已提交
32

33 34 35 36 37 38 39 40 41 42
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/cuda_helper.h"  // NOLINT
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"  // NOLINT
#endif

43 44 45 46
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

T
tensor-tang 已提交
47
#ifdef PADDLE_WITH_MKLDNN
L
luotao1 已提交
48
#include "mkldnn.hpp"
49
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
50 51
#endif

52
#include <map>
W
wanghuancoder 已提交
53

54
#include "glog/logging.h"
Y
Yi Wang 已提交
55 56
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
57
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
58
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
59
#endif
Q
qijun 已提交
60
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
61

W
wanghuancoder 已提交
62 63 64 65 66
namespace Eigen {
struct DefaultDevice;
struct GpuDevice;
}  // namespace Eigen

67 68
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
69
#include "paddle/fluid/platform/xpu_info.h"
70 71
#endif

Q
QI JUN 已提交
72 73 74
namespace paddle {
namespace platform {

75
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/*! \brief Set the value of the global flag allow_tf32_cublas.
 *  (Defined out of line; only declared here.) */
void SetAllowTF32Cublas(bool active);
/*! \brief Return the current value of the global flag allow_tf32_cublas. */
bool AllowTF32Cublas();

/*! \brief Set the value of the global flag allow_tf32_cudnn.
 *  (Defined out of line; only declared here.) */
void SetAllowTF32Cudnn(bool active);
/*! \brief Return the current value of the global flag allow_tf32_cudnn. */
bool AllowTF32Cudnn();

#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

86 87 88 89 90 91 92 93 94 95
/*! \brief Kind of device a context is bound to; the numeric values are
 *  spelled out explicitly. */
enum DeviceType { CPU = 0, CUDA = 1, XPU = 2 };

// Named constants mirroring each enumerator of DeviceType.
constexpr DeviceType kCPU{DeviceType::CPU};
constexpr DeviceType kCUDA{DeviceType::CUDA};
constexpr DeviceType kXPU{DeviceType::XPU};

Q
QI JUN 已提交
96 97
class DeviceContext {
 public:
Z
Zeng Jinle 已提交
98
  virtual ~DeviceContext() PADDLE_MAY_THROW {}
L
liaogang 已提交
99
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
100

101
  virtual void Wait() const {}
Q
QI JUN 已提交
102 103
};

Q
qijun 已提交
104 105
class CPUDeviceContext : public DeviceContext {
 public:
106
  CPUDeviceContext();
Q
qijun 已提交
107
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
108

109
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
110

L
liaogang 已提交
111
  Place GetPlace() const override;
Y
Yu Yang 已提交
112

Q
qijun 已提交
113
 private:
D
dzhwinter 已提交
114
  CPUPlace place_;
Q
qijun 已提交
115
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
116 117
};

Y
Yang Yu 已提交
118 119 120 121 122 123 124 125
/*! \brief Trait mapping a Place type to the DeviceContext subclass that
 *  serves it (exposed as the nested alias TYPE). The primary template is only
 *  declared, so using it with a Place that has no specialization does not
 *  compile. */
template <typename Place>
struct DefaultDeviceContextType;

// CPUPlace -> CPUDeviceContext.
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

126 127 128 129 130 131 132 133 134 135 136 137 138
#ifdef PADDLE_WITH_XPU
/*! \brief Device context for XPU devices. */
class XPUDeviceContext : public DeviceContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();

  /*! \brief Eigen is not used on XPU; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  Place GetPlace() const override;

  /*! \brief Return the xpu::Context held by this device context. */
  xpu::Context* x_context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

#ifdef PADDLE_WITH_XPU_BKCL
  /*! \brief  Return bkcl context. Until set_bkcl_context() is called this is
   *  the value-initialized BKCLContext_t rather than an indeterminate value. */
  BKCLContext_t bkcl_context() const { return bkcl_context_; }

  /*! \brief  Set bkcl context. */
  void set_bkcl_context(BKCLContext_t context) { bkcl_context_ = context; }
#endif

 private:
  XPUPlace place_;
  // Null until the constructor installs a context.
  xpu::Context* context_{nullptr};
#ifdef PADDLE_WITH_XPU_BKCL
  // Nothing in this header assigns bkcl_context_ before set_bkcl_context(),
  // so value-initialize it to keep bkcl_context() well-defined.
  BKCLContext_t bkcl_context_{};
#endif

  // Kept for uniformity with the other DeviceContext types, even though
  // eigen_device_ is not used on XPU.
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  DISABLE_COPY_AND_ASSIGN(XPUDeviceContext);
};

// XPUPlace -> XPUDeviceContext.
template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif

166
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
167

168
class CudnnWorkspaceHandle;
W
wanghuancoder 已提交
169
class EigenCudaStreamDevice;
S
sneaxiy 已提交
170

171 172 173 174 175
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
176
      const stream::Priority& priority = stream::Priority::kNormal);
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

192
  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
193

194 195 196
#ifdef PADDLE_WITH_HIP
  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
#else
197
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
198
#endif
199

200
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
201 202 203
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }
204
#endif
G
Guo Sheng 已提交
205

206 207 208 209 210 211 212 213 214 215 216
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
217 218 219 220 221
    if (cublas_tf32_tensor_core_handle_) {
      cublas_tf32_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

241 242 243 244 245
#ifdef PADDLE_WITH_HIP
  void InitCuBlasContext() {
    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
  }
#else
246 247 248 249 250 251 252
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
253 254 255 256 257
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
258 259
    }
  }
260
#endif
261 262 263

  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
#ifdef PADDLE_WITH_HIP
      size_t miopen_major, miopen_minor, miopen_patch;
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
          &miopen_major, &miopen_minor, &miopen_patch));
      auto local_miopen_version =
          (miopen_major * 1000 + miopen_minor * 100 + miopen_patch) / 100;
      auto compile_miopen_version = MIOPEN_VERSION / 100;
      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with MIOPEN "
            << compile_miopen_version / 10 << "." << compile_miopen_version % 10
            << ", but MIOPEN version in your machine is "
            << local_miopen_version / 10 << "." << local_miopen_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible MIOPEN "
               "version.";
      }
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::miopenSetStream(cudnn_handle_, RawStream()));
#else
286 287 288 289 290 291 292 293 294 295 296 297 298
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
299 300
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
301
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
302
#endif
303 304 305 306 307
    } else {
      cudnn_handle_ = nullptr;
    }
  }

308
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
309
  void InitCuSolverContext() {
310 311
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
312 313
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }
314
#endif
G
Guo Sheng 已提交
315

316 317
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
318 319 320
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
#else
321
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
322
#endif
323 324 325 326 327 328 329
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
330
    cublas_tf32_tensor_core_handle_.reset();
331 332
  }

333
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
334 335 336 337 338 339
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }
340
#endif
G
Guo Sheng 已提交
341

342 343 344 345
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
346 347 348
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle_;
#else
349
  cudnnHandle_t cudnn_handle_;
350
#endif
351 352
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
353
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
354
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
355
  cusolverDnHandle_t cusolver_dn_handle_;
356
#endif
357 358 359
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

360
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
361
 public:
D
dzhwinter 已提交
362
  explicit CUDADeviceContext(CUDAPlace place);
363
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
364

365
  /*! \brief  Wait for all operations completion in the stream. */
366
  void Wait() const override;
Q
QI JUN 已提交
367

368
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
369
  Place GetPlace() const override;
370

K
Kexin Zhao 已提交
371
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
372 373
  int GetComputeCapability() const;

374 375 376
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

377 378 379 380 381 382
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

383 384 385
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

386 387 388
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

389 390 391
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
392
    return context()->CublasCall(callback);
393 394 395 396 397 398 399 400 401
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
402
    return context()->TensorCoreCublasCallIfAvailable(callback);
403
  }
S
sneaxiy 已提交
404

405 406 407 408
/*! \brief  Return cudnn  handle in the device context. */
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle() const;
#else
409
  cudnnHandle_t cudnn_handle() const;
410
#endif
411

S
sneaxiy 已提交
412 413 414 415 416 417 418 419 420
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

421
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
422
  cusolverDnHandle_t cusolver_dn_handle() const;
423
#endif
G
Guo Sheng 已提交
424

Q
init  
qijun 已提交
425
  /*! \brief  Return cuda stream in the device context. */
426
  gpuStream_t stream() const;
Q
QI JUN 已提交
427

428
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Q
qingqing01 已提交
429 430 431 432 433
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
434
#endif
Q
qingqing01 已提交
435

Y
Yu Yang 已提交
436
  template <typename Callback>
437
  void RecordEvent(gpuEvent_t ev, Callback callback) const {
438
    return context()->Stream()->RecordEvent(ev, callback);
Y
Yu Yang 已提交
439 440
  }

S
sneaxiy 已提交
441 442
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
443 444 445 446 447
    return context()->Stream()->AddCallback(callback);
  }

  void WaitStreamCallback() const {
    return context()->Stream()->WaitCallback();
448 449
  }

450
  void ResetDefaultContext(const stream::Priority& priority) {
451 452 453
    default_ctx_.reset(new CUDAContext(place_, priority));
  }

454
  void ResetThreadContext(const stream::Priority& priority) {
455 456 457 458 459 460 461 462 463 464
    std::lock_guard<std::mutex> guard(ctx_mtx_);
    thread_ctx_[this].reset(new CUDAContext(place_, priority));
  }

  std::shared_ptr<CUDAContext> context() const {
    if (!thread_ctx_.count(this)) {
      return default_ctx_;
    }
    return thread_ctx_.at(this);
  }
S
sneaxiy 已提交
465

Q
QI JUN 已提交
466
 private:
D
dzhwinter 已提交
467
  CUDAPlace place_;
468
  std::shared_ptr<CUDAContext> default_ctx_;
Q
QI JUN 已提交
469

470 471 472 473 474 475
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
476

477 478
  mutable std::mutex cudnn_handle_mtx_;

479
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Q
qingqing01 已提交
480 481 482 483 484 485
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
486
#endif
Q
qingqing01 已提交
487

C
chengduo 已提交
488 489 490 491 492
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
493
  int max_threads_per_block_;
494
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
495

496
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
497
};
Q
qijun 已提交
498

499 500
class CudnnWorkspaceHandle {
 public:
501 502
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
503 504 505 506 507 508 509 510

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
511 512 513 514
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
515 516 517 518 519 520 521 522 523 524 525 526 527
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

528
  void ReallocWorkspace(size_t required_workspace_bytes);
529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
545
  std::mutex* mtx_;
546 547
};

Y
Yang Yu 已提交
548 549
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
550
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
551 552
};

C
chengduoZH 已提交
553
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
554 555 556 557 558 559
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
560

C
chengduoZH 已提交
561 562 563 564 565 566 567 568 569 570 571
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
572
#endif
Q
qijun 已提交
573

T
tensor-tang 已提交
574
#ifdef PADDLE_WITH_MKLDNN
575 576 577 578 579 580

/*! \brief Holder for per-thread MKL-DNN state. The state lives in a
 *  function-local thread_local Body returned by fetch(), so each thread gets
 *  its own instance. */
class MKLDNNDeviceContextThreadLocals {
  using self = MKLDNNDeviceContextThreadLocals;
  struct Body {
    bool said_once = false;
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // the cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Most recently registered data_format. It is needed when converting an
    // MKL-DNN Tensor to a non-MKL-DNN one.
    paddle::framework::DataLayout cur_paddle_data_layout;
    // MKL-DNN stream used for execution of primitives (per-thread)
    mkldnn::engine cur_engine;
    mkldnn::stream cur_stream;

    Body();
    ~Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id(void);
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout(void);
    void log_lib_version(void);
    const mkldnn::engine& get_engine(void);
    mkldnn::stream& get_stream(void);
  };
  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // default mkldnn session id
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // mkldnn session id for cache clearing mode: the largest size_t value.
  // The cast makes the intended wrap-around of -1 explicit.
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      static_cast<size_t>(-1);
  static Body& fetch() {
    thread_local Body b;
    return b;
  }
};
S
Sylwester Fraczek 已提交
623

T
tensor-tang 已提交
624 625
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
  using ShapeBlob = umap_key_string_t<KeyBlob>;
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

T
tensor-tang 已提交
643 644 645
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
646
  const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
T
tensor-tang 已提交
647

648
  // Remove all entries from the blob map
649 650
  void ResetBlobMap();

651 652 653 654
  // Set a suffix to be added to key
  void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
  const std::string& GetKeySuffix(void) const { return key_suffix_; }

655
  // Disable adding  thread ID to the key
656 657
  void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }
  bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }
658

659 660
  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();
661

662 663 664
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

665 666
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
667

668 669 670
  // Calculate number of oneDNN objects cached
  unsigned int GetCachedObjectsNumber(void);

671 672
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
673

674 675 676 677
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
678
 private:
679 680
  std::shared_ptr<BlobMap> p_blobmap_;
  std::shared_ptr<std::mutex> p_mutex_;
681
  bool block_next_cache_clearing_ = false;
682
  std::string key_suffix_;  // Key identifying current Executor
683
  bool key_attach_thread_id_ = true;
T
tensor-tang 已提交
684 685 686
};
#endif

D
dzhwinter 已提交
687 688 689 690 691
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
692
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
693 694 695
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
696 697 698 699
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
700
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
701 702 703 704 705 706
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

707 708
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
709
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
710
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
711

Y
Yang Yu 已提交
712 713 714 715 716 717 718
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

719 720
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
721 722
 private:
  static DeviceContextPool* pool;
723 724
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
725 726 727
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
728 729
}  // namespace platform
}  // namespace paddle