// device_context.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

13
#include <future>  // NOLINT
D
dzhwinter 已提交
14
#include <memory>
Y
yuyang18 已提交
15
#include <mutex>  // NOLINT
16
#include <string>
D
dzhwinter 已提交
17
#include <unordered_map>
18
#include <utility>
19
#include <vector>
W
wanghuancoder 已提交
20

W
Wilber 已提交
21 22 23
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/device_context.h"

Y
Yu Yang 已提交
24
#include "paddle/fluid/memory/malloc.h"
25
#ifdef PADDLE_WITH_CUDA
26
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
Y
Yi Wang 已提交
27 28
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
29
#include "paddle/fluid/platform/dynload/cusolver.h"
30
#include "paddle/fluid/platform/dynload/cusparse.h"
31
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
32
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
33
#endif
34
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
Q
QI JUN 已提交
35
#endif
D
dzhwinter 已提交
36

37
#ifdef PADDLE_WITH_HIP
38
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"  // NOLINT
39 40 41 42 43
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
44
#include "paddle/fluid/platform/device/gpu/gpu_info.h"  // NOLINT
45 46
#endif

47 48 49 50
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

T
tensor-tang 已提交
51
#ifdef PADDLE_WITH_MKLDNN
52
#include "dnnl.hpp"
53
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
54 55
#endif

56
#include <map>
W
wanghuancoder 已提交
57

58
#include "glog/logging.h"
Y
Yi Wang 已提交
59 60
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
61
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
62
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
63
#endif
64
#ifdef PADDLE_WITH_ASCEND_CL
65 66
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
67
#endif
Q
qijun 已提交
68
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
69

W
wanghuancoder 已提交
70 71 72 73 74
namespace Eigen {
struct DefaultDevice;
struct GpuDevice;
}  // namespace Eigen

75
#ifdef PADDLE_WITH_XPU
76 77
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
W
Wilber 已提交
78
#include "paddle/pten/backends/xpu/xpu_context.h"
79 80
#endif

81 82
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
83
#include "paddle/fluid/platform/device/npu/npu_info.h"
84 85
#endif

Q
QI JUN 已提交
86 87 88
namespace paddle {
namespace platform {

89
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/*! \brief Set the value of the global variable allow_tf32_cublas. */
void SetAllowTF32Cublas(bool active);
/*! \brief Get the value of the global variable allow_tf32_cublas. */
bool AllowTF32Cublas();

// Unlike allow_tf32_cublas, the cudnn flag is also exposed as a global
// variable; prefer the accessor functions below in new code.
extern bool allow_tf32_cudnn;
/*! \brief Set the value of the global variable allow_tf32_cudnn. */
void SetAllowTF32Cudnn(bool active);
/*! \brief Get the value of the global variable allow_tf32_cudnn. */
bool AllowTF32Cudnn();
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

101 102 103 104
/*! \brief Kinds of hardware a device context can target.
 *  Enumerators are consecutive from 0, so MAX_DEVICE_TYPES equals the number
 *  of real device kinds listed before it. */
enum DeviceType {
  CPU = 0,               // host CPU
  CUDA = 1,              // NVIDIA / HIP GPU
  XPU = 2,               // Baidu XPU
  NPU = 3,               // Ascend NPU
  IPU = 4,               // Graphcore IPU
  MLU = 5,               // MLU
  MAX_DEVICE_TYPES = 6,  // count of the kinds above; keep it last
};

112 113
/*! \brief Translate a Place into the DeviceType of the hardware it names. */
DeviceType Place2DeviceType(const platform::Place& place);

// Short k-prefixed aliases, one per DeviceType enumerator.
constexpr DeviceType kCPU = DeviceType::CPU;    // host
constexpr DeviceType kCUDA = DeviceType::CUDA;  // GPU
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kNPU = DeviceType::NPU;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU;
120

W
Wilber 已提交
121
using DeviceContext = pten::DeviceContext;
Q
QI JUN 已提交
122

W
Wilber 已提交
123 124 125 126
// using CPUDeviceContext = pten::CPUContext;
// TODO(wilber): The place constructor is used in many places, it is more
// difficult to use CPUDeviceContext = pten::CPUContext directly.
class CPUDeviceContext : public pten::CPUContext {
Q
qijun 已提交
127
 public:
128
  CPUDeviceContext();
Q
qijun 已提交
129
  explicit CPUDeviceContext(CPUPlace place);
Q
QI JUN 已提交
130 131
};

Y
Yang Yu 已提交
132 133 134 135 136 137 138 139
/*! \brief Trait mapping a Place type to its default device-context type,
 *  exposed as the nested alias TYPE. The primary template is only declared,
 *  so asking for a Place without a specialization is a compile-time error. */
template <typename Place>
struct DefaultDeviceContextType;

// CPUPlace -> CPUDeviceContext
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

J
jianghaicheng 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
// Graphcore IPU
#ifdef PADDLE_WITH_IPU
/*! \brief Device context for a Graphcore IPU; must be built from an IPUPlace. */
class IPUDeviceContext : public DeviceContext {
 public:
  IPUDeviceContext() = delete;
  explicit IPUDeviceContext(IPUPlace place);
  virtual ~IPUDeviceContext();

  /*! \brief IPU exposes no Eigen device, so this always yields nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief Return the place of this context. */
  Place GetPlace() const override;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

 private:
  IPUPlace place_;
};

// IPUPlace -> IPUDeviceContext
template <>
struct DefaultDeviceContextType<platform::IPUPlace> {
  using TYPE = IPUDeviceContext;
};
#endif
J
jianghaicheng 已提交
160

F
fwenguang 已提交
161 162 163 164 165
#ifdef PADDLE_WITH_MLU
// Declarations only; the class and the trait specialization are defined
// elsewhere (NOTE(review): presumably an MLU-specific header — confirm).
class MLUDeviceContext;

template <>
struct DefaultDeviceContextType<platform::MLUPlace>;
#endif

168
#ifdef PADDLE_WITH_XPU
// Short alias for the XPU API namespace.
namespace xpu = baidu::xpu::api;

/*! \brief Fluid-side XPU device context: a thin subclass of
 *  pten::XPUContext that adds the eigen_device() accessor. */
class XPUDeviceContext : public pten::XPUContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();

  /*! \brief XPU exposes no Eigen device, so this always yields nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
};

// XPUPlace -> XPUDeviceContext
template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif

184 185 186 187 188 189 190 191
#ifdef PADDLE_WITH_ASCEND_CL
/*! \brief Device context for an Ascend NPU. Holds an NPU stream and an
 *  optional HCCL communicator. */
class NPUDeviceContext : public DeviceContext {
 public:
  explicit NPUDeviceContext(NPUPlace place);
  virtual ~NPUDeviceContext();

  /*! \brief NPU exposes no Eigen device, so this always yields nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief Return the place of this context. */
  Place GetPlace() const override;

  /*! \brief Return the ACL context of this device context. */
  aclrtContext context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return npu stream in the device context. */
  aclrtStream stream() const;

  /*! \brief Register `callback` on this context's NPU stream
   *  (delegates to NPUStream::AddCallback). */
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    return stream_->AddCallback(callback);
  }

  /*! \brief Delegates to NPUStream::WaitCallback on this context's stream. */
  void WaitStreamCallback() const { return stream_->WaitCallback(); }

  /*! \brief  Return hccl communicators. */
  HcclComm hccl_comm() const { return hccl_comm_; }

  /*! \brief  Set hccl communicators. */
  void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }

 private:
  NPUPlace place_;
  aclrtContext context_;

  // HCCL communicator; nullptr until set_hccl_comm() is called.
  // (The whole class is already inside PADDLE_WITH_ASCEND_CL, so no extra
  // guard is needed here.)
  HcclComm hccl_comm_{nullptr};

  // Kept to match the layout of the other DeviceContexts, even though
  // eigen_device_ is not used on NPU.
  // NOTE(zhiqiu): why need?
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  std::shared_ptr<stream::NPUStream> stream_;

  DISABLE_COPY_AND_ASSIGN(NPUDeviceContext);
};

// NPUPlace -> NPUDeviceContext
template <>
struct DefaultDeviceContextType<platform::NPUPlace> {
  using TYPE = NPUDeviceContext;
};

// Currently, NPUPinnedDeviceContext is only used for data copying.
class NPUPinnedDeviceContext : public DeviceContext {
 public:
  NPUPinnedDeviceContext();
  explicit NPUPinnedDeviceContext(NPUPinnedPlace place);

  /*! \brief Return the place of this context. */
  Place GetPlace() const override;

  /*! \brief Return the Eigen DefaultDevice of this context. */
  Eigen::DefaultDevice* eigen_device() const;

 private:
  NPUPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

// NPUPinnedPlace -> NPUPinnedDeviceContext
template <>
struct DefaultDeviceContextType<platform::NPUPinnedPlace> {
  using TYPE = NPUPinnedDeviceContext;
};

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
267
class CudnnWorkspaceHandle;
W
wanghuancoder 已提交
268
class EigenCudaStreamDevice;
S
sneaxiy 已提交
269

270 271 272 273 274
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
275 276
      const stream::Priority& priority = stream::Priority::kNormal,
      const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

292 293 294 295 296 297
  stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
    auto* old_stream_ptr = stream_.release();
    stream_.reset(new_stream_ptr);
    return old_stream_ptr;
  }

W
Wilber 已提交
298 299
  void SetStream(gpuStream_t stream);

300
  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
301

302 303 304
#ifdef PADDLE_WITH_HIP
  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
#else
305
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
306
#endif
307

308
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
309 310 311
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }
312
#endif
G
Guo Sheng 已提交
313

314 315 316 317 318 319 320 321
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

Z
zhangkaihuo 已提交
322 323 324 325 326 327
#ifndef PADDLE_WITH_HIP
  const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
    return cusparse_handle_;
  }
#endif

328 329 330
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
331 332 333 334 335
    if (cublas_tf32_tensor_core_handle_) {
      cublas_tf32_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
336 337
  }

Z
zhangkaihuo 已提交
338 339 340 341 342 343 344 345
#ifndef PADDLE_WITH_HIP
  /*! \brief  Call cusparse function safely. */
  template <typename Callback>
  inline void CusparseCall(Callback&& callback) const {
    cusparse_handle_->Call(std::forward<Callback>(callback));
  }
#endif

346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362
  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

363 364 365 366 367
#ifdef PADDLE_WITH_HIP
  void InitCuBlasContext() {
    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
  }
#else
368 369 370 371 372 373 374
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
375 376 377 378 379
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
380 381
    }
  }
382
#endif
383

Z
zhangkaihuo 已提交
384 385 386 387 388 389
#ifndef PADDLE_WITH_HIP
  void InitCuSparseContext() {
    cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
  }
#endif

390 391
  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
392 393
#ifdef PADDLE_WITH_HIP
      size_t miopen_major, miopen_minor, miopen_patch;
394
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
395 396
          &miopen_major, &miopen_minor, &miopen_patch));
      auto local_miopen_version =
397 398
          (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
      auto compile_miopen_version = MIOPEN_VERSION / 10;
399 400 401 402
      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with MIOPEN "
403 404
            << compile_miopen_version / 100 << "."
            << compile_miopen_version % 100
405
            << ", but MIOPEN version in your machine is "
406
            << local_miopen_version / 100 << "." << local_miopen_version % 100
407 408 409 410
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible MIOPEN "
               "version.";
      }
411 412
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
      PADDLE_ENFORCE_GPU_SUCCESS(
413 414
          dynload::miopenSetStream(cudnn_handle_, RawStream()));
#else
415 416 417 418 419 420 421 422 423 424 425 426 427
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
428 429
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
430
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
431
#endif
432 433 434 435 436
    } else {
      cudnn_handle_ = nullptr;
    }
  }

437
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
438
  void InitCuSolverContext() {
439 440
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
441 442
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }
443
#endif
G
Guo Sheng 已提交
444

445 446
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
447
#ifdef PADDLE_WITH_HIP
448
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
449
#else
450
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
451
#endif
452 453 454 455 456 457 458
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
459
    cublas_tf32_tensor_core_handle_.reset();
460 461
  }

Z
zhangkaihuo 已提交
462 463 464 465
#ifndef PADDLE_WITH_HIP
  void DestoryCuSparseContext() { cusparse_handle_.reset(); }
#endif

466
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
467 468
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
469
      PADDLE_ENFORCE_GPU_SUCCESS(
G
Guo Sheng 已提交
470 471 472
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }
473
#endif
G
Guo Sheng 已提交
474

475 476 477 478
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
479 480 481
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle_;
#else
482
  cudnnHandle_t cudnn_handle_;
483
#endif
484 485
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
486
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
487
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
488
  cusolverDnHandle_t cusolver_dn_handle_;
Z
zhangkaihuo 已提交
489
  std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
490
#endif
491 492 493
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

494
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
495
 public:
D
dzhwinter 已提交
496
  explicit CUDADeviceContext(CUDAPlace place);
497
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
498

499
  /*! \brief  Wait for all operations completion in the stream. */
500
  void Wait() const override;
Q
QI JUN 已提交
501

502
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
503
  Place GetPlace() const override;
504

K
Kexin Zhao 已提交
505
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
506 507
  int GetComputeCapability() const;

508 509 510
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

511 512 513 514 515 516
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

517 518 519
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

520 521 522
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

523 524 525
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
526
    return context()->CublasCall(callback);
527 528
  }

Z
zhangkaihuo 已提交
529 530 531 532 533 534 535 536
#ifndef PADDLE_WITH_HIP
  /*! \brief  Call cusparse function safely. */
  template <typename Callback>
  inline void CusparseCall(Callback&& callback) const {
    return context()->CusparseCall(callback);
  }
#endif

537 538 539 540 541 542 543
  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
544
    return context()->TensorCoreCublasCallIfAvailable(callback);
545
  }
S
sneaxiy 已提交
546

547 548 549 550
/*! \brief  Return cudnn  handle in the device context. */
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle() const;
#else
551
  cudnnHandle_t cudnn_handle() const;
552
#endif
553

554 555 556 557
/*! \brief  Return cublas handle in the device context. */
#ifdef PADDLE_WITH_HIP
  rocblas_handle cublas_handle() const;
#else
558
  cublasHandle_t cublas_handle() const;
Z
zhangkaihuo 已提交
559
  cusparseHandle_t cusparse_handle() const;
560
#endif
561

S
sneaxiy 已提交
562 563 564 565 566 567 568 569 570
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

571
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
572
  cusolverDnHandle_t cusolver_dn_handle() const;
573
#endif
G
Guo Sheng 已提交
574

Q
init  
qijun 已提交
575
  /*! \brief  Return cuda stream in the device context. */
576
  gpuStream_t stream() const;
Q
QI JUN 已提交
577

578
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Q
qingqing01 已提交
579 580 581 582 583
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
584
#endif
Q
qingqing01 已提交
585

Y
Yu Yang 已提交
586
  template <typename Callback>
587
  void RecordEvent(gpuEvent_t ev, Callback callback) const {
588
    return context()->Stream()->RecordEvent(ev, callback);
Y
Yu Yang 已提交
589 590
  }

S
sneaxiy 已提交
591 592
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
593 594 595 596 597
    return context()->Stream()->AddCallback(callback);
  }

  void WaitStreamCallback() const {
    return context()->Stream()->WaitCallback();
598 599
  }

600
  void ResetDefaultContext(const stream::Priority& priority) {
601 602 603
    default_ctx_.reset(new CUDAContext(place_, priority));
  }

604
  void ResetThreadContext(const stream::Priority& priority) {
605 606 607 608 609 610 611 612 613 614
    std::lock_guard<std::mutex> guard(ctx_mtx_);
    thread_ctx_[this].reset(new CUDAContext(place_, priority));
  }

  std::shared_ptr<CUDAContext> context() const {
    if (!thread_ctx_.count(this)) {
      return default_ctx_;
    }
    return thread_ctx_.at(this);
  }
S
sneaxiy 已提交
615

W
Wilber 已提交
616 617 618 619 620
  // Note: Can only be used under thread_local semantics.
  void SetThreadLocalStream(const gpuStream_t stream) {
    thread_ctx_.at(this)->SetStream(stream);
  }

Q
QI JUN 已提交
621
 private:
D
dzhwinter 已提交
622
  CUDAPlace place_;
623
  std::shared_ptr<CUDAContext> default_ctx_;
Q
QI JUN 已提交
624

625 626 627 628 629 630
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
631

632 633
  mutable std::mutex cudnn_handle_mtx_;

634
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Q
qingqing01 已提交
635 636 637 638 639 640
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
641
#endif
Q
qingqing01 已提交
642

C
chengduo 已提交
643 644 645 646 647
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
648
  int max_threads_per_block_;
649
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
650

651
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
652
};
Q
qijun 已提交
653

654 655
class CudnnWorkspaceHandle {
 public:
656 657
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
658 659 660 661 662 663 664 665

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
666 667 668 669
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
670 671 672 673 674 675 676 677 678 679 680 681 682
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

683
  void ReallocWorkspace(size_t required_workspace_bytes);
684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
700
  std::mutex* mtx_;
701 702
};

Y
Yang Yu 已提交
703 704
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
705
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
706 707
};

C
chengduoZH 已提交
708
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
709 710 711 712 713 714
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
715

C
chengduoZH 已提交
716 717 718 719 720 721 722 723 724 725 726
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
727
#endif
Q
qijun 已提交
728

T
tensor-tang 已提交
729
#ifdef PADDLE_WITH_MKLDNN
730 731 732 733 734 735

/*! \brief Holder of per-thread oneDNN (MKL-DNN) state. The state lives in a
 *  private Body reached only through fetch(); the holder itself cannot be
 *  constructed or copied from outside. */
class MKLDNNDeviceContextThreadLocals {
  using self = MKLDNNDeviceContextThreadLocals;

  struct Body {
    bool said_once = false;
    // mkldnn session id currently in effect on this thread.
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // the cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Recently registered data_format. This is needed to
    // know for converting MKL-DNN Tensor to non MKL-DNN
    paddle::framework::DataLayout cur_paddle_data_layout;
    // MKL-DNN stream used for execution of primitives (per-thread)
    dnnl::engine cur_engine;
    dnnl::stream cur_stream;
    std::string key_suffix;  // Key identifying current Executor
    bool key_attach_thread_id = true;
    void* exec_ptr_ = nullptr;

    Body();
    ~Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id();
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout();
    void log_lib_version();
    const dnnl::engine& get_engine();
    dnnl::stream& get_stream();
    void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
    const std::string& get_key_suffix() const { return key_suffix; }
    void disable_tid_in_key() { key_attach_thread_id = false; }
    bool is_tid_used_in_key() const { return key_attach_thread_id; }
    void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
    void* get_curr_exec() const { return exec_ptr_; }
  };
  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // default mkldnn session id
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // mkldnn session id for cache clearing mode: the largest size_t value,
  // spelled with an explicit cast rather than an implicit -1 conversion.
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      static_cast<size_t>(-1);

  /*! \brief Return the calling thread's Body; it is constructed the first
   *  time a given thread calls fetch(). */
  static Body& fetch() {
    thread_local Body b;
    return b;
  }
};
S
Sylwester Fraczek 已提交
787

T
tensor-tang 已提交
788 789
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
790 791 792 793 794 795 796 797 798 799
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
800
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
801 802 803
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
804
  using ShapeBlob = umap_key_string_t<KeyBlob>;
805 806
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

807 808 809 810
  // Auxillary two-level structure (shape, executor) to easier control
  // clearing cache objects related to specific executor

  using ExecKey = void*;
811
  using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
812 813 814
  using ExecMap =
      std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
  using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
815

T
tensor-tang 已提交
816 817 818
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
819
  const dnnl::engine& GetEngine() const { return tls().get_engine(); }
T
tensor-tang 已提交
820

821
  // Register object to currently used executor's map
822 823
  void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
  void RemoveShapeEntriesWithExecutor(void) const;
824

825
  // Remove all entries from the blob map
826
  void ResetBlobMap(void* ptr);
827 828 829

  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();
830

831 832 833
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

834 835
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
836

837
  // Calculate number of oneDNN objects cached
838
  unsigned int GetCachedObjectsNumber(void) const;
839

840 841
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
842

843 844 845 846
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
847
 private:
848
  std::shared_ptr<BlobMap> p_blobmap_;
849 850
  // Map key is pointer of executor and value is a data(iterator in map) needed
  // to erase
851
  std::shared_ptr<ExecShape> p_exec_items_;
852
  std::shared_ptr<std::mutex> p_mutex_;
853
  bool block_next_cache_clearing_ = false;
T
tensor-tang 已提交
854 855 856
};
#endif

D
dzhwinter 已提交
857 858 859 860 861
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
862
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
863 864 865
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
866 867 868 869
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
870
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
871 872 873 874 875 876
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

877 878
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
879
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
880
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
881

Y
Yang Yu 已提交
882 883 884 885 886 887 888
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

889 890
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
891 892
 private:
  static DeviceContextPool* pool;
893 894
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
895 896 897
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
898 899
}  // namespace platform
}  // namespace paddle