/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <future>  // NOLINT
#include <limits>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/memory/malloc.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"
#endif

#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/cuda_helper.h"  // NOLINT
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"  // NOLINT
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

#ifdef PADDLE_WITH_MKLDNN
#include "mkldnn.hpp"
#include "paddle/fluid/framework/data_layout.h"
#endif

#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/stream/cuda_stream.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/stream/npu_stream.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"

namespace Eigen {
struct DefaultDevice;
struct GpuDevice;
}  // namespace Eigen

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
#include "paddle/fluid/platform/xpu_info.h"
#endif

#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
#include "paddle/fluid/platform/npu_info.h"
#endif

Q
QI JUN 已提交
80 81 82
namespace paddle {
namespace platform {

83
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
84 85 86 87
/*Set the value of the global variable allow_tf32_cublas*/
void SetAllowTF32Cublas(bool active);
/*Get the global variable allow_tf32_cublas value*/
bool AllowTF32Cublas();
A
AshburnLee 已提交
88
extern bool allow_tf32_cudnn;
A
AshburnLee 已提交
89 90 91 92
/*Set the value of the global variable allow_tf32_cudnn*/
void SetAllowTF32Cudnn(bool active);
/*Get the global variable allow_tf32_cudnn value*/
bool AllowTF32Cudnn();
93 94
#endif  // PADDLE_WITH_CUDA

95 96 97 98
/*! \brief  Kind of device a context runs on. The numeric values are
 *  spelled out explicitly, so they should be treated as stable. */
enum DeviceType {
  CPU = 0,
  CUDA = 1,
  XPU = 2,
  NPU = 3,
};

// Short aliases for the DeviceType enumerators.
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kNPU = DeviceType::NPU;

class DeviceContext {
 public:
Z
Zeng Jinle 已提交
109
  virtual ~DeviceContext() PADDLE_MAY_THROW {}
L
liaogang 已提交
110
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
111

112
  virtual void Wait() const {}
Q
QI JUN 已提交
113 114
};

Q
qijun 已提交
115 116
class CPUDeviceContext : public DeviceContext {
 public:
117
  CPUDeviceContext();
Q
qijun 已提交
118
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
119

120
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
121

L
liaogang 已提交
122
  Place GetPlace() const override;
Y
Yu Yang 已提交
123

Q
qijun 已提交
124
 private:
D
dzhwinter 已提交
125
  CPUPlace place_;
Q
qijun 已提交
126
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
127 128
};

Y
Yang Yu 已提交
129 130 131 132 133 134 135 136
/*! \brief  Maps a place type to the DeviceContext subclass that
 *  DeviceContextPool::GetByPlace returns for it. Specialized per place. */
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

#ifdef PADDLE_WITH_XPU
/*! \brief  Device context for Baidu XPU devices. */
class XPUDeviceContext : public DeviceContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();
  // XPU does not execute through Eigen; kept for interface parity.
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
  Place GetPlace() const override;
  xpu::Context* x_context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

#ifdef PADDLE_WITH_XPU_BKCL
  /*! \brief  Return bkcl context. */
  BKCLContext_t bkcl_context() const { return bkcl_context_; }

  /*! \brief  Set bkcl context. */
  void set_bkcl_context(BKCLContext_t context) { bkcl_context_ = context; }
#endif

 private:
  XPUPlace place_;
  // Default to nullptr so an unset handle is never read as garbage.
  xpu::Context* context_{nullptr};
#ifdef PADDLE_WITH_XPU_BKCL
  BKCLContext_t bkcl_context_{nullptr};
#endif

  // Need to be the same with other DeviceContext,
  // even though eigen_device_ is not used in XPU.
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  DISABLE_COPY_AND_ASSIGN(XPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif  // PADDLE_WITH_XPU

#ifdef PADDLE_WITH_ASCEND_CL
/*! \brief  Device context for Ascend NPU devices. */
class NPUDeviceContext : public DeviceContext {
 public:
  explicit NPUDeviceContext(NPUPlace place);
  virtual ~NPUDeviceContext();
  // NPU does not execute through Eigen; kept for interface parity.
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
  Place GetPlace() const override;
  aclrtContext context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return npu stream in the device context. */
  aclrtStream stream() const;

  /*! \brief  Enqueue a host callback on this context's NPU stream. */
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    return stream_->AddCallback(callback);
  }

  /*! \brief  Wait for callbacks added through AddStreamCallback. */
  void WaitStreamCallback() const { return stream_->WaitCallback(); }

  /*! \brief  Return hccl communicators. */
  HcclComm hccl_comm() const { return hccl_comm_; }

  /*! \brief  Set hccl communicators. */
  void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }

 private:
  NPUPlace place_;
  aclrtContext context_{nullptr};

  HcclComm hccl_comm_{nullptr};

  // Need to be the same with other DeviceContext,
  // even though eigen_device_ is not used in NPU.
  // NOTE(zhiqiu): why need?
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  std::shared_ptr<stream::NPUStream> stream_;

  DISABLE_COPY_AND_ASSIGN(NPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::NPUPlace> {
  using TYPE = NPUDeviceContext;
};

// Currently, NPUPinnedDeviceContext is only used for data copying.
class NPUPinnedDeviceContext : public DeviceContext {
 public:
  NPUPinnedDeviceContext();
  explicit NPUPinnedDeviceContext(NPUPinnedPlace place);

  Place GetPlace() const override;

  Eigen::DefaultDevice* eigen_device() const;

 private:
  NPUPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::NPUPinnedPlace> {
  using TYPE = NPUPinnedDeviceContext;
};

#endif  // PADDLE_WITH_ASCEND_CL

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Forward declarations. CudnnWorkspaceHandle is defined further down in this
// header; EigenCudaStreamDevice is only held through std::unique_ptr here and
// is not defined in this header.
class CudnnWorkspaceHandle;
class EigenCudaStreamDevice;

class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
268
      const stream::Priority& priority = stream::Priority::kNormal);
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

284
  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
285

286 287 288
#ifdef PADDLE_WITH_HIP
  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
#else
289
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
290
#endif
291

292
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
293 294 295
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }
296
#endif
G
Guo Sheng 已提交
297

298 299 300 301 302 303 304 305 306 307 308
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
309 310 311 312 313
    if (cublas_tf32_tensor_core_handle_) {
      cublas_tf32_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

333 334 335 336 337
#ifdef PADDLE_WITH_HIP
  void InitCuBlasContext() {
    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
  }
#else
338 339 340 341 342 343 344
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
345 346 347 348 349
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
350 351
    }
  }
352
#endif
353 354 355

  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377
#ifdef PADDLE_WITH_HIP
      size_t miopen_major, miopen_minor, miopen_patch;
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
          &miopen_major, &miopen_minor, &miopen_patch));
      auto local_miopen_version =
          (miopen_major * 1000 + miopen_minor * 100 + miopen_patch) / 100;
      auto compile_miopen_version = MIOPEN_VERSION / 100;
      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with MIOPEN "
            << compile_miopen_version / 10 << "." << compile_miopen_version % 10
            << ", but MIOPEN version in your machine is "
            << local_miopen_version / 10 << "." << local_miopen_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible MIOPEN "
               "version.";
      }
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::miopenSetStream(cudnn_handle_, RawStream()));
#else
378 379 380 381 382 383 384 385 386 387 388 389 390
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
391 392
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
393
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
394
#endif
395 396 397 398 399
    } else {
      cudnn_handle_ = nullptr;
    }
  }

400
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
401
  void InitCuSolverContext() {
402 403
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
404 405
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }
406
#endif
G
Guo Sheng 已提交
407

408 409
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
410 411 412
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
#else
413
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
414
#endif
415 416 417 418 419 420 421
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
422
    cublas_tf32_tensor_core_handle_.reset();
423 424
  }

425
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
426 427 428 429 430 431
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }
432
#endif
G
Guo Sheng 已提交
433

434 435 436 437
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
438 439 440
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle_;
#else
441
  cudnnHandle_t cudnn_handle_;
442
#endif
443 444
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
445
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
446
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
447
  cusolverDnHandle_t cusolver_dn_handle_;
448
#endif
449 450 451
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

452
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
453
 public:
D
dzhwinter 已提交
454
  explicit CUDADeviceContext(CUDAPlace place);
455
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
456

457
  /*! \brief  Wait for all operations completion in the stream. */
458
  void Wait() const override;
Q
QI JUN 已提交
459

460
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
461
  Place GetPlace() const override;
462

K
Kexin Zhao 已提交
463
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
464 465
  int GetComputeCapability() const;

466 467 468
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

469 470 471 472 473 474
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

475 476 477
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

478 479 480
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

481 482 483
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
484
    return context()->CublasCall(callback);
485 486 487 488 489 490 491 492 493
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
494
    return context()->TensorCoreCublasCallIfAvailable(callback);
495
  }
S
sneaxiy 已提交
496

497 498 499 500
/*! \brief  Return cudnn  handle in the device context. */
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle() const;
#else
501
  cudnnHandle_t cudnn_handle() const;
502
#endif
503

504 505 506 507
/*! \brief  Return cublas handle in the device context. */
#ifdef PADDLE_WITH_HIP
  rocblas_handle cublas_handle() const;
#else
508
  cublasHandle_t cublas_handle() const;
509
#endif
510

S
sneaxiy 已提交
511 512 513 514 515 516 517 518 519
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

520
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
521
  cusolverDnHandle_t cusolver_dn_handle() const;
522
#endif
G
Guo Sheng 已提交
523

Q
init  
qijun 已提交
524
  /*! \brief  Return cuda stream in the device context. */
525
  gpuStream_t stream() const;
Q
QI JUN 已提交
526

527
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Q
qingqing01 已提交
528 529 530 531 532
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
533
#endif
Q
qingqing01 已提交
534

Y
Yu Yang 已提交
535
  template <typename Callback>
536
  void RecordEvent(gpuEvent_t ev, Callback callback) const {
537
    return context()->Stream()->RecordEvent(ev, callback);
Y
Yu Yang 已提交
538 539
  }

S
sneaxiy 已提交
540 541
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
542 543 544 545 546
    return context()->Stream()->AddCallback(callback);
  }

  void WaitStreamCallback() const {
    return context()->Stream()->WaitCallback();
547 548
  }

549
  void ResetDefaultContext(const stream::Priority& priority) {
550 551 552
    default_ctx_.reset(new CUDAContext(place_, priority));
  }

553
  void ResetThreadContext(const stream::Priority& priority) {
554 555 556 557 558 559 560 561 562 563
    std::lock_guard<std::mutex> guard(ctx_mtx_);
    thread_ctx_[this].reset(new CUDAContext(place_, priority));
  }

  std::shared_ptr<CUDAContext> context() const {
    if (!thread_ctx_.count(this)) {
      return default_ctx_;
    }
    return thread_ctx_.at(this);
  }
S
sneaxiy 已提交
564

Q
QI JUN 已提交
565
 private:
D
dzhwinter 已提交
566
  CUDAPlace place_;
567
  std::shared_ptr<CUDAContext> default_ctx_;
Q
QI JUN 已提交
568

569 570 571 572 573 574
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
575

576 577
  mutable std::mutex cudnn_handle_mtx_;

578
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Q
qingqing01 已提交
579 580 581 582 583 584
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
585
#endif
Q
qingqing01 已提交
586

C
chengduo 已提交
587 588 589 590 591
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
592
  int max_threads_per_block_;
593
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
594

595
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
596
};
Q
qijun 已提交
597

598 599
class CudnnWorkspaceHandle {
 public:
600 601
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
602 603 604 605 606 607 608 609

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
610 611 612 613
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
614 615 616 617 618 619 620 621 622 623 624 625 626
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

627
  void ReallocWorkspace(size_t required_workspace_bytes);
628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
644
  std::mutex* mtx_;
645 646
};

Y
Yang Yu 已提交
647 648
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
649
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
650 651
};

C
chengduoZH 已提交
652
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
653 654 655 656 657 658
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
659

C
chengduoZH 已提交
660 661 662 663 664 665 666 667 668 669 670
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
671
#endif
Q
qijun 已提交
672

T
tensor-tang 已提交
673
#ifdef PADDLE_WITH_MKLDNN

/*! \brief  Per-thread MKL-DNN state, reached through fetch(). The class
 *  itself cannot be constructed or copied by users. */
class MKLDNNDeviceContextThreadLocals {
  typedef MKLDNNDeviceContextThreadLocals self;
  struct Body {
    bool said_once = false;
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // the cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Recently registered data_format. This is needed to know how to
    // convert an MKL-DNN tensor to a non-MKL-DNN one.
    paddle::framework::DataLayout cur_paddle_data_layout;
    // MKL-DNN stream used for execution of primitives (per-thread)
    mkldnn::engine cur_engine;
    mkldnn::stream cur_stream;
    std::string key_suffix;  // Key identifying current Executor
    bool key_attach_thread_id = true;
    void* exec_ptr_ = nullptr;

    Body();
    ~Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id(void);
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout(void);
    void log_lib_version(void);
    const mkldnn::engine& get_engine(void);
    mkldnn::stream& get_stream(void);
    void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
    const std::string& get_key_suffix(void) const { return key_suffix; }
    void disable_tid_in_key(void) { key_attach_thread_id = false; }
    bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
    void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
    void* get_curr_exec(void) const { return exec_ptr_; }
  };
  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // default mkldnn session id
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // mkldnn session id for cache clearing mode; the largest size_t value
  // (identical to the previous implicit conversion of -1).
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      std::numeric_limits<size_t>::max();
  /*! \brief  Return the calling thread's Body (one instance per thread). */
  static Body& fetch() {
    thread_local Body b;
    return b;
  }
};

class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
  using ShapeBlob = umap_key_string_t<KeyBlob>;
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

751 752 753
  using ExecMap = std::unordered_map<
      void*, std::vector<std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>>>;

T
tensor-tang 已提交
754 755 756
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
757
  const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
T
tensor-tang 已提交
758

759 760 761
  // Register object to currently used executor's map
  void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;

762
  // Remove all entries from the blob map
763
  void ResetBlobMap(void* ptr);
764 765 766

  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();
767

768 769 770
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

771 772
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
773

774 775 776
  // Calculate number of oneDNN objects cached
  unsigned int GetCachedObjectsNumber(void);

777 778
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
779

780 781 782 783
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
784
 private:
785
  std::shared_ptr<BlobMap> p_blobmap_;
786 787 788
  // Map key is pointer of executor and value is a data(iterator in map) needed
  // to erase
  std::shared_ptr<ExecMap> p_exec_items_;
789
  std::shared_ptr<std::mutex> p_mutex_;
790
  bool block_next_cache_clearing_ = false;
T
tensor-tang 已提交
791 792 793
};
#endif

D
dzhwinter 已提交
794 795 796 797 798
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
799
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
800 801 802
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
803 804 805 806
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
807
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
808 809 810 811 812 813
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

814 815
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
816
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
817
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
818

Y
Yang Yu 已提交
819 820 821 822 823 824 825
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

826 827
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
828 829
 private:
  static DeviceContextPool* pool;
830 831
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
832 833 834
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
835 836
}  // namespace platform
}  // namespace paddle