device_context.h 28.8 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

W
Wilber 已提交
15
#include <functional>
16
#include <future>  // NOLINT
D
dzhwinter 已提交
17
#include <memory>
Y
yuyang18 已提交
18
#include <mutex>  // NOLINT
19
#include <string>
D
dzhwinter 已提交
20
#include <unordered_map>
21
#include <utility>
22
#include <vector>
W
wanghuancoder 已提交
23

24
#include "paddle/fluid/memory/malloc.h"
W
Wilber 已提交
25
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
26
#include "paddle/phi/backends/cpu/cpu_context.h"
27
#include "paddle/phi/backends/custom/custom_context.h"
28 29
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/device_context.h"
30
#ifdef PADDLE_WITH_CUDA
31
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
Y
Yi Wang 已提交
32
#include "paddle/fluid/platform/dynload/cublas.h"
33
#include "paddle/fluid/platform/dynload/cublasLt.h"
Y
Yi Wang 已提交
34
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
35
#include "paddle/fluid/platform/dynload/cusolver.h"
36
#include "paddle/fluid/platform/dynload/cusparse.h"
37
#include "paddle/phi/backends/gpu/gpu_context.h"
38
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
39
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
40
#endif
41
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
Q
QI JUN 已提交
42
#endif
D
dzhwinter 已提交
43

44
#ifdef PADDLE_WITH_HIP
45
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"  // NOLINT
46 47
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
48
#include "paddle/phi/backends/gpu/gpu_context.h"  // NOLINT
49 50 51
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
52
#include "paddle/fluid/platform/device/gpu/gpu_info.h"  // NOLINT
53 54
#endif

55 56 57 58
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

T
tensor-tang 已提交
59
#ifdef PADDLE_WITH_MKLDNN
60
#include "dnnl.hpp"  // NOLINT
61
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
62 63
#endif

64
#include <map>
W
wanghuancoder 已提交
65

66
#include "glog/logging.h"
Y
Yi Wang 已提交
67 68
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
69
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
70
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
71
#endif
72
#ifdef PADDLE_WITH_ASCEND_CL
73 74
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
75
#endif
76

77 78
#include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/stream.h"
79 80

#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
Q
qijun 已提交
81
#include "unsupported/Eigen/CXX11/Tensor"
82
#endif
Q
QI JUN 已提交
83

W
wanghuancoder 已提交
84 85 86 87 88
namespace Eigen {
struct DefaultDevice;
struct GpuDevice;
}  // namespace Eigen

89
#ifdef PADDLE_WITH_XPU
90 91
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
92
#include "paddle/phi/backends/xpu/xpu_context.h"
93 94
#endif

95 96
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
97
#include "paddle/fluid/platform/device/npu/npu_info.h"
98 99
#endif

Q
QI JUN 已提交
100 101 102
namespace paddle {
namespace platform {

103
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// --- TF32 switches --------------------------------------------------------
// Setter and getter for the global flag `allow_tf32_cublas`.
void SetAllowTF32Cublas(bool active);
bool AllowTF32Cublas();

// Global flag read/written by SetAllowTF32Cudnn() / AllowTF32Cudnn(); it is
// also exported directly.
extern bool allow_tf32_cudnn;
// Setter and getter for `allow_tf32_cudnn`.
void SetAllowTF32Cudnn(bool active);
bool AllowTF32Cudnn();
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

115 116 117 118
// Device families distinguished by the fluid platform layer. Every value is
// spelled out explicitly so the numbering is stable.
enum DeviceType {
  CPU = 0,
  CUDA = 1,
  XPU = 2,
  NPU = 3,
  IPU = 4,
  MLU = 5,

  // Sentinel equal to the number of entries listed above.
  MAX_DEVICE_TYPES = 6,
};

126 127
// Declared here, defined outside this header. Presumably maps `place` onto
// its DeviceType family -- confirm against the definition.
DeviceType Place2DeviceType(const platform::Place& place);

// k-prefixed constants mirroring the DeviceType enumerators one-to-one.
constexpr DeviceType kCPU{DeviceType::CPU};
constexpr DeviceType kCUDA{DeviceType::CUDA};
constexpr DeviceType kXPU{DeviceType::XPU};
constexpr DeviceType kNPU{DeviceType::NPU};
constexpr DeviceType kIPU{DeviceType::IPU};
constexpr DeviceType kMLU{DeviceType::MLU};
134

135
using DeviceContext = phi::DeviceContext;
Q
QI JUN 已提交
136

Y
Yang Yu 已提交
137 138 139 140 141
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
L
Leo Chen 已提交
142
  using TYPE = phi::CPUContext;
Y
Yang Yu 已提交
143 144
};

J
jianghaicheng 已提交
145 146 147 148 149 150 151 152
// Graphcore IPU
#ifdef PADDLE_WITH_IPU
/*! \brief Device context for IPUPlace. It must be built from a place (the
 *  default constructor is deleted) and exposes no Eigen device. */
class IPUDeviceContext : public DeviceContext {
 public:
  IPUDeviceContext() = delete;
  explicit IPUDeviceContext(IPUPlace place);
  virtual ~IPUDeviceContext();

  /*! \brief No Eigen device on IPU; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const {
    return nullptr;
  }

  const Place& GetPlace() const override;

  /*! \brief  Block until every operation queued on the stream has finished. */
  void Wait() const override;

 private:
  IPUPlace place_;
};

// IPUPlace -> IPUDeviceContext.
template <>
struct DefaultDeviceContextType<platform::IPUPlace> {
  using TYPE = IPUDeviceContext;
};
#endif  // PADDLE_WITH_IPU
J
jianghaicheng 已提交
165

F
fwenguang 已提交
166 167 168 169 170
#ifdef PADDLE_WITH_MLU
// Declarations only: the MLU context class and its DefaultDeviceContextType
// specialization are completed outside this header.
class MLUDeviceContext;

template <>
struct DefaultDeviceContextType<platform::MLUPlace>;
#endif  // PADDLE_WITH_MLU

173
#ifdef PADDLE_WITH_XPU
// Short alias for baidu::xpu::api.
namespace xpu = baidu::xpu::api;

/*! \brief Fluid-side XPU device context layered on phi::XPUContext. */
class XPUDeviceContext : public phi::XPUContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();

  /*! \brief No Eigen device on XPU; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief The `xpu_stream` stored in the underlying x_context(). */
  xpuStream stream() const {
    auto raw_ctx = XPUContext::x_context();
    return raw_ctx->xpu_stream;
  }
};

// XPUPlace -> XPUDeviceContext.
template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif  // PADDLE_WITH_XPU

190 191 192 193 194 195
#ifdef PADDLE_WITH_ASCEND_CL
/*! \brief Device context for Ascend NPU (NPUPlace). It holds the ACL
 *  context, an NPU stream and an HCCL communicator handle.
 *
 *  The whole class already sits inside `#ifdef PADDLE_WITH_ASCEND_CL`, so the
 *  HCCL members need no extra guard. */
class NPUDeviceContext : public DeviceContext {
 public:
  explicit NPUDeviceContext(NPUPlace place);
  virtual ~NPUDeviceContext();

  /*! \brief No Eigen device on NPU; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  const Place& GetPlace() const override;

  /*! \brief ACL runtime context of this device context (defined out of line). */
  aclrtContext context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return npu stream in the device context. */
  aclrtStream stream() const;

  /*! \brief Register `callback` on the NPU stream via NPUStream::AddCallback. */
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    stream_->AddCallback(std::forward<Callback>(callback));
  }

  /*! \brief Wait for the stream's callbacks via NPUStream::WaitCallback. */
  void WaitStreamCallback() const { stream_->WaitCallback(); }

  /*! \brief  Return hccl communicators. */
  HcclComm hccl_comm() const { return hccl_comm_; }

  /*! \brief  Set hccl communicators. */
  void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }

 private:
  NPUPlace place_;
  aclrtContext context_;

  HcclComm hccl_comm_{nullptr};

  // Kept for parity with the other device contexts even though NPU never
  // uses it (eigen_device() returns nullptr).
  // NOTE(zhiqiu): it is unclear whether this member is needed at all.
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  std::shared_ptr<stream::NPUStream> stream_;

  DISABLE_COPY_AND_ASSIGN(NPUDeviceContext);
};

// NPUPlace -> NPUDeviceContext.
template <>
struct DefaultDeviceContextType<platform::NPUPlace> {
  using TYPE = NPUDeviceContext;
};
249 250 251 252 253 254 255

// Currently, NPUPinnedDeviceContext is only used to data copying.
class NPUPinnedDeviceContext : public DeviceContext {
 public:
  NPUPinnedDeviceContext();
  explicit NPUPinnedDeviceContext(NPUPinnedPlace place);

W
Wilber 已提交
256
  const Place& GetPlace() const override;
257 258 259 260 261 262 263 264 265 266 267 268 269

  Eigen::DefaultDevice* eigen_device() const;

 private:
  NPUPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::NPUPinnedPlace> {
  using TYPE = NPUPinnedDeviceContext;
};

270 271 272
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Forward declarations for classes referenced in this header before (or
// without) their full definitions.
class EigenCudaStreamDevice;
class CudnnWorkspaceHandle;
S
sneaxiy 已提交
275

276 277 278 279 280
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
281 282
      const stream::Priority& priority = stream::Priority::kNormal,
      const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);
283 284 285 286 287 288 289 290 291 292 293 294 295 296 297

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

298 299 300 301 302 303
  stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
    auto* old_stream_ptr = stream_.release();
    stream_.reset(new_stream_ptr);
    return old_stream_ptr;
  }

W
Wilber 已提交
304 305
  void SetStream(gpuStream_t stream);

306
  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
307

308 309 310
#ifdef PADDLE_WITH_HIP
  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
#else
311
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
312
#endif
313

314
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
315 316 317
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }
318
#endif
G
Guo Sheng 已提交
319

320 321 322 323 324 325 326 327
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

Z
zhangkaihuo 已提交
328
#ifndef PADDLE_WITH_HIP
329 330 331 332 333 334
#if CUDA_VERSION >= 11060
  const std::unique_ptr<CublasLtHandleHolder>& CublasLtHandle() const {
    return cublaslt_handle_;
  }
#endif

Z
zhangkaihuo 已提交
335 336 337 338 339
  const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
    return cusparse_handle_;
  }
#endif

340
  /*! \brief  Call cublas function safely. */
W
Wilber 已提交
341 342
  inline void CublasCall(
      const std::function<void(blasHandle_t)>& callback) const {
343
    if (cublas_tf32_tensor_core_handle_) {
W
Wilber 已提交
344
      cublas_tf32_tensor_core_handle_->Call(callback);
345
    } else {
W
Wilber 已提交
346
      cublas_handle_->Call(callback);
347
    }
348 349
  }

Z
zhangkaihuo 已提交
350
#ifndef PADDLE_WITH_HIP
351 352 353 354 355 356 357 358
#if CUDA_VERSION >= 11060
  /*! \brief  Call cublasLt function safely. */
  inline void CublasLtCall(
      const std::function<void(blasLtHandle_t)>& callback) const {
    cublaslt_handle_->Call(callback);
  }
#endif

Z
zhangkaihuo 已提交
359
  /*! \brief  Call cusparse function safely. */
W
Wilber 已提交
360
  inline void CusparseCall(
361
      const std::function<void(phi::sparseHandle_t)>& callback) const {
W
Wilber 已提交
362
    cusparse_handle_->Call(callback);
Z
zhangkaihuo 已提交
363 364 365
  }
#endif

366 367 368 369 370
  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
W
Wilber 已提交
371 372
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) const {
373
    if (cublas_tensor_core_handle_) {
W
Wilber 已提交
374
      cublas_tensor_core_handle_->Call(callback);
375
    } else {
W
Wilber 已提交
376
      cublas_handle_->Call(callback);
377 378 379 380 381 382
    }
  }

 private:
  void InitEigenContext();

383 384 385 386 387
#ifdef PADDLE_WITH_HIP
  void InitCuBlasContext() {
    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
  }
#else
388 389 390 391 392 393 394
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
395 396 397 398 399
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
400 401
    }
  }
402
#endif
403

Z
zhangkaihuo 已提交
404
#ifndef PADDLE_WITH_HIP
405 406 407 408 409 410
#if CUDA_VERSION >= 11060
  void InitCuBlasLtContext() {
    cublaslt_handle_.reset(new CublasLtHandleHolder());
  }
#endif

Z
zhangkaihuo 已提交
411 412 413 414 415
  void InitCuSparseContext() {
    cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
  }
#endif

416 417
  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
418 419
#ifdef PADDLE_WITH_HIP
      size_t miopen_major, miopen_minor, miopen_patch;
420
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
421 422
          &miopen_major, &miopen_minor, &miopen_patch));
      auto local_miopen_version =
423 424
          (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
      auto compile_miopen_version = MIOPEN_VERSION / 10;
425 426 427 428
      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with MIOPEN "
429 430
            << compile_miopen_version / 100 << "."
            << compile_miopen_version % 100
431
            << ", but MIOPEN version in your machine is "
432
            << local_miopen_version / 100 << "." << local_miopen_version % 100
433 434 435 436
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible MIOPEN "
               "version.";
      }
437 438
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
      PADDLE_ENFORCE_GPU_SUCCESS(
439 440
          dynload::miopenSetStream(cudnn_handle_, RawStream()));
#else
441 442 443 444 445 446 447 448 449 450 451 452 453
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
454 455
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
456
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
457
#endif
458 459 460 461 462
    } else {
      cudnn_handle_ = nullptr;
    }
  }

463
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
464
  void InitCuSolverContext() {
465 466
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
467 468
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }
469
#endif
G
Guo Sheng 已提交
470

471 472
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
473
#ifdef PADDLE_WITH_HIP
474
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
475
#else
476
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
477
#endif
478 479 480 481 482 483 484
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
485
    cublas_tf32_tensor_core_handle_.reset();
486 487
  }

Z
zhangkaihuo 已提交
488
#ifndef PADDLE_WITH_HIP
489 490 491 492
#if CUDA_VERSION >= 11060
  void DestoryCuBlasLtContext() { cublaslt_handle_.reset(); }
#endif

Z
zhangkaihuo 已提交
493 494 495
  void DestoryCuSparseContext() { cusparse_handle_.reset(); }
#endif

496
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
497 498
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
499
      PADDLE_ENFORCE_GPU_SUCCESS(
G
Guo Sheng 已提交
500 501 502
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }
503
#endif
G
Guo Sheng 已提交
504

505 506 507 508
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
509 510 511
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle_;
#else
512
  cudnnHandle_t cudnn_handle_;
513
#endif
514 515
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
516
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
517
#ifndef PADDLE_WITH_HIP
518 519 520
#if CUDA_VERSION >= 11060
  std::unique_ptr<CublasLtHandleHolder> cublaslt_handle_;
#endif
G
Guo Sheng 已提交
521
  cusolverDnHandle_t cusolver_dn_handle_;
Z
zhangkaihuo 已提交
522
  std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
523
#endif
524 525 526
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

527
class CUDADeviceContext : public phi::GPUContext {
Q
QI JUN 已提交
528
 public:
D
dzhwinter 已提交
529
  explicit CUDADeviceContext(CUDAPlace place);
530
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
531

532
  /*! \brief  Wait for all operations completion in the stream. */
533
  void Wait() const override;
Q
QI JUN 已提交
534

535 536 537
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

538
  /*! \brief  Call cublas function safely. */
W
Wilber 已提交
539 540 541
  inline void CublasCall(
      const std::function<void(blasHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
542
      phi::GPUContext::CublasCall(callback);
W
Wilber 已提交
543 544
      return;
    }
545
    return context()->CublasCall(callback);
546 547
  }

Z
zhangkaihuo 已提交
548 549
#ifndef PADDLE_WITH_HIP
  /*! \brief  Call cusparse function safely. */
W
Wilber 已提交
550
  inline void CusparseCall(
551
      const std::function<void(phi::sparseHandle_t)>& callback) const {
W
Wilber 已提交
552
    if (!thread_ctx_.count(this)) {
553
      phi::GPUContext::CusparseCall(callback);
W
Wilber 已提交
554 555 556
      return;
    }
    context()->CusparseCall(callback);
Z
zhangkaihuo 已提交
557 558 559
  }
#endif

560 561
  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
W
Wilber 已提交
562 563 564
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
565
      phi::GPUContext::TensorCoreCublasCallIfAvailable(callback);
W
Wilber 已提交
566 567 568
      return;
    }
    context()->TensorCoreCublasCallIfAvailable(callback);
569
  }
S
sneaxiy 已提交
570

571 572 573 574
/*! \brief  Return cudnn  handle in the device context. */
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle() const;
#else
575
  cudnnHandle_t cudnn_handle() const;
576
#endif
577

578 579 580 581
/*! \brief  Return cublas handle in the device context. */
#ifdef PADDLE_WITH_HIP
  rocblas_handle cublas_handle() const;
#else
582
  cublasHandle_t cublas_handle() const;
583
  cublasLtHandle_t cublaslt_handle() const;
Z
zhangkaihuo 已提交
584
  cusparseHandle_t cusparse_handle() const;
585
#endif
586

W
Wilber 已提交
587 588 589 590
#ifndef PADDLE_WITH_HIP
  cusolverDnHandle_t cusolver_dn_handle() const;
#endif

S
sneaxiy 已提交
591 592 593 594 595 596 597
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
598
  phi::DnnWorkspaceHandle cudnn_workspace_handle() const;
S
sneaxiy 已提交
599

Q
init  
qijun 已提交
600
  /*! \brief  Return cuda stream in the device context. */
601
  gpuStream_t stream() const;
Q
QI JUN 已提交
602

W
Wilber 已提交
603
  void RecordEvent(gpuEvent_t ev, const std::function<void()>& callback) const;
604

W
Wilber 已提交
605
  void AddStreamCallback(const std::function<void()>& callback) const;
606

W
Wilber 已提交
607
  void WaitStreamCallback() const;
608

609
  void ResetThreadContext(const stream::Priority& priority) {
610
    std::lock_guard<std::mutex> guard(ctx_mtx_);
W
Wilber 已提交
611
    thread_ctx_[this].reset(new CUDAContext(this->GetPlace(), priority));
612 613
  }

W
Wilber 已提交
614
  std::shared_ptr<CUDAContext> context() const;
S
sneaxiy 已提交
615

W
Wilber 已提交
616 617 618 619 620
  // Note: Can only be used under thread_local semantics.
  void SetThreadLocalStream(const gpuStream_t stream) {
    thread_ctx_.at(this)->SetStream(stream);
  }

W
Wilber 已提交
621 622 623 624
  // NOTE: Just for compatibility with the past, please delete if there is an
  // elegant way.
  stream::CUDAStream* GetCudaStream() const;
  stream::CUDAStream* SetCudaStream(stream::CUDAStream*);
Q
QI JUN 已提交
625

W
Wilber 已提交
626
 private:
627 628 629 630 631 632
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
633

634 635
  mutable std::mutex cudnn_handle_mtx_;

W
Wilber 已提交
636 637 638
  // NOTE: Just for compatibility with the past, please delete if there is an
  // elegant way.
  std::unique_ptr<stream::CUDAStream> cuda_stream_;
Y
yuyang18 已提交
639

640
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
641
};
Q
qijun 已提交
642

643 644
class CudnnWorkspaceHandle {
 public:
645 646
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
647 648 649 650 651 652 653 654

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
655 656 657 658
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
659 660 661 662 663 664 665 666 667 668 669 670 671
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

672
  void ReallocWorkspace(size_t required_workspace_bytes);
673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
689
  std::mutex* mtx_;
690 691
};

Y
Yang Yu 已提交
692 693
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
694
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
695 696
};

C
chengduoZH 已提交
697
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
698 699 700 701 702
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

W
Wilber 已提交
703
  const Place& GetPlace() const override;
C
chengduoZH 已提交
704

C
chengduoZH 已提交
705 706 707 708 709 710 711 712 713 714 715
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
716
#endif
Q
qijun 已提交
717

T
tensor-tang 已提交
718
#ifdef PADDLE_WITH_MKLDNN

/*! \brief Per-thread oneDNN (MKL-DNN) state.
 *
 *  All state lives in a function-local `thread_local Body` reached through
 *  fetch(). The class itself is not meant to be instantiated: its default
 *  constructor is private and copying is deleted. */
class MKLDNNDeviceContextThreadLocals {
  using self = MKLDNNDeviceContextThreadLocals;

  struct Body {
    bool said_once = false;
    size_t cur_mkldnn_session_id;
    // Current data input shape string: empty by default for fixed-shape
    // models, user-specified for dynamic-shape models.
    std::string cur_input_shape_str;
    // Cache capacity for distinct input shapes in MKLDNN. The default of 1
    // means a fixed (non-dynamic) input shape.
    int cur_input_shape_cache_capacity;
    // Most recently registered data_format; needed to know how to convert an
    // MKL-DNN Tensor into a non-MKL-DNN one.
    paddle::framework::DataLayout cur_paddle_data_layout;
    // oneDNN engine and stream used to execute primitives (per thread).
    dnnl::engine cur_engine;
    dnnl::stream cur_stream;
    // Key identifying the current Executor.
    std::string key_suffix;
    bool key_attach_thread_id = true;
    void* exec_ptr_ = nullptr;

    Body();
    ~Body();

    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id(void);
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout(void);
    void log_lib_version(void);
    const dnnl::engine& get_engine(void);
    dnnl::stream& get_stream(void);

    void set_key_suffix(const std::string& suffix) {
      key_suffix = suffix;
    }
    const std::string& get_key_suffix(void) const {
      return key_suffix;
    }
    // Clears key_attach_thread_id.
    void disable_tid_in_key(void) {
      key_attach_thread_id = false;
    }
    bool is_tid_used_in_key(void) const {
      return key_attach_thread_id;
    }

    void set_curr_exec(void* exec_ptr) {
      exec_ptr_ = exec_ptr;
    }
    void* get_curr_exec(void) const {
      return exec_ptr_;
    }
  };

  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals&) =
      delete;

 public:
  // Default mkldnn session id.
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // Session id for cache-clearing mode: all bits set, i.e. size_t(-1).
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      static_cast<size_t>(-1);

  /*! \brief The calling thread's Body, constructed on first use per thread. */
  static Body& fetch() {
    thread_local Body per_thread_body;
    return per_thread_body;
  }
};
S
Sylwester Fraczek 已提交
776

L
Leo Chen 已提交
777
class MKLDNNDeviceContext : public phi::CPUContext {
T
tensor-tang 已提交
778
 public:
779 780 781 782 783 784 785 786 787 788
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
789
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
790 791 792
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
793
  using ShapeBlob = umap_key_string_t<KeyBlob>;
794 795
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

796 797 798 799
  // Auxillary two-level structure (shape, executor) to easier control
  // clearing cache objects related to specific executor

  using ExecKey = void*;
800
  using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
801 802 803
  using ExecMap =
      std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
  using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
804

T
tensor-tang 已提交
805 806 807
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
808
  const dnnl::engine& GetEngine() const { return tls().get_engine(); }
T
tensor-tang 已提交
809

810
  // Register object to currently used executor's map
811 812
  void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
  void RemoveShapeEntriesWithExecutor(void) const;
813

814
  // Remove all entries from the blob map
815
  void ResetBlobMap(void* ptr);
816 817 818

  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();
819

820 821 822
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

823 824
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
825

826
  // Calculate number of oneDNN objects cached
827
  unsigned int GetCachedObjectsNumber(void) const;
828

829 830
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
831

832 833 834 835
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
836
 private:
837
  std::shared_ptr<BlobMap> p_blobmap_;
838 839
  // Map key is pointer of executor and value is a data(iterator in map) needed
  // to erase
840
  std::shared_ptr<ExecShape> p_exec_items_;
841
  std::shared_ptr<std::mutex> p_mutex_;
842 843
  // 0 - clearing is allowed. x > 0 do not clear.
  unsigned int block_next_cache_clearing_ = 0;
T
tensor-tang 已提交
844 845 846
};
#endif

847
#ifdef PADDLE_WITH_CUSTOM_DEVICE
/*! \brief Fluid-side context for custom (plugged-in) devices, layered on
 *  phi::CustomContext. */
class CustomDeviceContext : public phi::CustomContext {
 public:
  explicit CustomDeviceContext(CustomPlace place);
  virtual ~CustomDeviceContext();

  /*! \brief No Eigen device for custom devices; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief Register `callback` on `stream_` via Stream::AddCallback. */
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    stream_->AddCallback(callback);
  }

  /*! \brief Wait for `stream_` callbacks via Stream::WaitCallback. */
  void WaitStreamCallback() const { stream_->WaitCallback(); }

 private:
  std::shared_ptr<phi::stream::Stream> stream_;
};

// CustomPlace -> CustomDeviceContext.
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
  using TYPE = CustomDeviceContext;
};
#else
// Without custom-device support, CustomPlace maps to the base DeviceContext.
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
  using TYPE = DeviceContext;
};
#endif  // PADDLE_WITH_CUSTOM_DEVICE

876 877 878 879 880 881
// Declared here, defined out of line.
//   place_to_device_context: map to populate, from Place to a shared_future
//       of an owned DeviceContext.
//   places: the places to cover.
//   disable_setting_default_stream_for_allocator: NOTE(review): semantics
//       inferred from the name only (skip registering the context's stream as
//       the allocator default) -- confirm in the definition.
void EmplaceDeviceContexts(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        place_to_device_context,
    const std::vector<platform::Place>& places,
    bool disable_setting_default_stream_for_allocator);

D
dzhwinter 已提交
882 883 884
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
Y
Yang Yu 已提交
885
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
886 887 888
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
889 890 891 892
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
893
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
894 895 896 897 898 899
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

900 901
  static bool IsInitialized() { return pool != nullptr; }

902 903
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
904
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
905
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
906

Y
Yang Yu 已提交
907 908 909 910 911 912 913
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

914
  size_t size() const;
915

916
  const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
917 918 919 920 921
  device_contexts() const;

  static void SetDeviceContexts(
      const std::map<Place,
                     std::shared_future<std::unique_ptr<DeviceContext>>>*);
922

D
dzhwinter 已提交
923
 private:
924 925
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

D
dzhwinter 已提交
926
  static DeviceContextPool* pool;
927 928
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
929 930 931
  static thread_local const std::
      map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
          external_device_contexts_;  // not owned
D
dzhwinter 已提交
932 933 934
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
935 936
}  // namespace platform
}  // namespace paddle