/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

W
Wilber 已提交
15
#include <functional>
16
#include <future>  // NOLINT
D
dzhwinter 已提交
17
#include <memory>
Y
yuyang18 已提交
18
#include <mutex>  // NOLINT
19
#include <string>
D
dzhwinter 已提交
20
#include <unordered_map>
21
#include <utility>
22
#include <vector>
W
wanghuancoder 已提交
23

24
#include "paddle/fluid/memory/malloc.h"
W
Wilber 已提交
25
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
26
#include "paddle/phi/backends/cpu/cpu_context.h"
27
#include "paddle/phi/backends/custom/custom_context.h"
28 29
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/device_context.h"
30
#ifdef PADDLE_WITH_CUDA
31
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
Y
Yi Wang 已提交
32
#include "paddle/fluid/platform/dynload/cublas.h"
33
#include "paddle/fluid/platform/dynload/cublasLt.h"
Y
Yi Wang 已提交
34
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
35
#include "paddle/fluid/platform/dynload/cusolver.h"
36
#include "paddle/fluid/platform/dynload/cusparse.h"
37
#include "paddle/phi/backends/gpu/gpu_context.h"
38
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
39
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
40
#endif
41
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
Q
QI JUN 已提交
42
#endif
D
dzhwinter 已提交
43

44
#ifdef PADDLE_WITH_HIP
45
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"  // NOLINT
46 47
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
48
#include "paddle/phi/backends/gpu/gpu_context.h"  // NOLINT
49 50 51
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
52
#include "paddle/fluid/platform/device/gpu/gpu_info.h"  // NOLINT
53 54
#endif

55 56 57 58
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

T
tensor-tang 已提交
59
#ifdef PADDLE_WITH_MKLDNN
60
#include "dnnl.hpp"  // NOLINT
61
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
62 63
#endif

64
#include <map>
W
wanghuancoder 已提交
65

66
#include "glog/logging.h"
Y
Yi Wang 已提交
67 68
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
69
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
70
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
71
#endif
72
#ifdef PADDLE_WITH_ASCEND_CL
73 74
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
75
#endif
76

77 78
#include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/stream.h"
79 80

#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
Q
qijun 已提交
81
#include "unsupported/Eigen/CXX11/Tensor"
82
#endif
Q
QI JUN 已提交
83

// Forward declarations so this header can refer to Eigen device types by
// pointer without pulling in the full Eigen headers at every use site.
namespace Eigen {
struct DefaultDevice;
struct GpuDevice;
}  // namespace Eigen

89
#ifdef PADDLE_WITH_XPU
90 91
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
92
#include "paddle/phi/backends/xpu/xpu_context.h"
93 94
#endif

95 96
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
97
#include "paddle/fluid/platform/device/npu/npu_info.h"
98 99
#endif

Q
QI JUN 已提交
100 101 102
namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/*Set the value of the global variable allow_tf32_cublas*/
void SetAllowTF32Cublas(bool active);
/*Get the global variable allow_tf32_cublas value*/
bool AllowTF32Cublas();

// Global cuDNN TF32 flag, exposed directly. Presumably the variable read and
// written by SetAllowTF32Cudnn / AllowTF32Cudnn (definitions are not visible
// here; TODO confirm).
extern bool allow_tf32_cudnn;

/*Set the value of the global variable allow_tf32_cudnn*/
void SetAllowTF32Cudnn(bool active);
/*Get the global variable allow_tf32_cudnn value*/
bool AllowTF32Cudnn();
#endif  // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

// Kinds of devices known to the fluid platform layer. Every enumerator has an
// explicit value; MAX_DEVICE_TYPES is one past the largest device value.
enum DeviceType {
  CPU = 0,
  CUDA = 1,
  XPU = 2,
  NPU = 3,
  IPU = 4,
  MLU = 5,

  MAX_DEVICE_TYPES = 6,
};

/*! \brief  Map a Place to its DeviceType enumerator (defined out of line). */
DeviceType Place2DeviceType(const platform::Place& place);

// Short constant aliases for the DeviceType enumerators.
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kNPU = DeviceType::NPU;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU;

135
using DeviceContext = phi::DeviceContext;
L
Leo Chen 已提交
136
using CPUDeviceContext = phi::CPUContext;
Q
QI JUN 已提交
137

Y
Yang Yu 已提交
138 139 140 141 142
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
L
Leo Chen 已提交
143
  using TYPE = phi::CPUContext;
Y
Yang Yu 已提交
144 145
};

// Graphcore IPU
#ifdef PADDLE_WITH_IPU
class IPUDeviceContext : public DeviceContext {
 public:
  IPUDeviceContext() = delete;
  explicit IPUDeviceContext(IPUPlace place);
  virtual ~IPUDeviceContext();

  /*! \brief  IPU has no Eigen device; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  const Place& GetPlace() const override;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

 private:
  IPUPlace place_;
};

template <>
struct DefaultDeviceContextType<platform::IPUPlace> {
  using TYPE = IPUDeviceContext;
};
#endif

#ifdef PADDLE_WITH_MLU
// Declarations only; the class and the specialization are defined outside
// this header (presumably in an MLU-specific header; TODO confirm).
class MLUDeviceContext;

template <>
struct DefaultDeviceContextType<platform::MLUPlace>;
#endif

#ifdef PADDLE_WITH_XPU
// Short alias for the XPU API namespace.
namespace xpu = baidu::xpu::api;

/*! \brief  fluid-side device context for XPU, layered on phi::XPUContext. */
class XPUDeviceContext : public phi::XPUContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();

  /*! \brief  XPU has no Eigen device; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief  Return the stream stored in the underlying x_context(). */
  xpuStream stream() const { return XPUContext::x_context()->xpu_stream; }
};

template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif

#ifdef PADDLE_WITH_ASCEND_CL
/*! \brief  Device context for Ascend NPUs driven through the ACL runtime. */
class NPUDeviceContext : public DeviceContext {
 public:
  explicit NPUDeviceContext(NPUPlace place);
  virtual ~NPUDeviceContext();

  /*! \brief  NPU kernels use no Eigen device; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  const Place& GetPlace() const override;

  /*! \brief  Return the ACL context associated with this device context. */
  aclrtContext context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return npu stream in the device context. */
  aclrtStream stream() const;

  /*! \brief  Register a callback on the owned NPUStream (forwards to
   *  NPUStream::AddCallback). */
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    return stream_->AddCallback(callback);
  }

  /*! \brief  Forward to NPUStream::WaitCallback on the owned stream. */
  void WaitStreamCallback() const { return stream_->WaitCallback(); }

  /*! \brief  Return hccl communicators. */
  HcclComm hccl_comm() const { return hccl_comm_; }

  /*! \brief  Set hccl communicators. */
  void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }

 private:
  NPUPlace place_;
  aclrtContext context_;

  // HCCL communicator; null until set_hccl_comm() is called.
  HcclComm hccl_comm_{nullptr};

  // Need to be the same with other DeviceContext,
  // even though eigen_device_ is not used in NPU
  // NOTE(zhiqiu): why need?
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  std::shared_ptr<stream::NPUStream> stream_;

  DISABLE_COPY_AND_ASSIGN(NPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::NPUPlace> {
  using TYPE = NPUDeviceContext;
};

// Currently, NPUPinnedDeviceContext is only used to data copying.
class NPUPinnedDeviceContext : public DeviceContext {
 public:
  NPUPinnedDeviceContext();
  explicit NPUPinnedDeviceContext(NPUPinnedPlace place);

  const Place& GetPlace() const override;

  Eigen::DefaultDevice* eigen_device() const;

 private:
  NPUPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::NPUPinnedPlace> {
  using TYPE = NPUPinnedDeviceContext;
};

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
274
class CudnnWorkspaceHandle;
W
wanghuancoder 已提交
275
class EigenCudaStreamDevice;
S
sneaxiy 已提交
276

277 278 279 280 281
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
282 283
      const stream::Priority& priority = stream::Priority::kNormal,
      const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

299 300 301 302 303 304
  stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
    auto* old_stream_ptr = stream_.release();
    stream_.reset(new_stream_ptr);
    return old_stream_ptr;
  }

W
Wilber 已提交
305 306
  void SetStream(gpuStream_t stream);

307
  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
308

309 310 311
#ifdef PADDLE_WITH_HIP
  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
#else
312
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
313
#endif
314

315
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
316 317 318
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }
319
#endif
G
Guo Sheng 已提交
320

321 322 323 324 325 326 327 328
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

Z
zhangkaihuo 已提交
329
#ifndef PADDLE_WITH_HIP
330 331 332 333 334 335
#if CUDA_VERSION >= 11060
  const std::unique_ptr<CublasLtHandleHolder>& CublasLtHandle() const {
    return cublaslt_handle_;
  }
#endif

Z
zhangkaihuo 已提交
336 337 338 339 340
  const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
    return cusparse_handle_;
  }
#endif

341
  /*! \brief  Call cublas function safely. */
W
Wilber 已提交
342 343
  inline void CublasCall(
      const std::function<void(blasHandle_t)>& callback) const {
344
    if (cublas_tf32_tensor_core_handle_) {
W
Wilber 已提交
345
      cublas_tf32_tensor_core_handle_->Call(callback);
346
    } else {
W
Wilber 已提交
347
      cublas_handle_->Call(callback);
348
    }
349 350
  }

Z
zhangkaihuo 已提交
351
#ifndef PADDLE_WITH_HIP
352 353 354 355 356 357 358 359
#if CUDA_VERSION >= 11060
  /*! \brief  Call cublasLt function safely. */
  inline void CublasLtCall(
      const std::function<void(blasLtHandle_t)>& callback) const {
    cublaslt_handle_->Call(callback);
  }
#endif

Z
zhangkaihuo 已提交
360
  /*! \brief  Call cusparse function safely. */
W
Wilber 已提交
361
  inline void CusparseCall(
362
      const std::function<void(phi::sparseHandle_t)>& callback) const {
W
Wilber 已提交
363
    cusparse_handle_->Call(callback);
Z
zhangkaihuo 已提交
364 365 366
  }
#endif

367 368 369 370 371
  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
W
Wilber 已提交
372 373
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) const {
374
    if (cublas_tensor_core_handle_) {
W
Wilber 已提交
375
      cublas_tensor_core_handle_->Call(callback);
376
    } else {
W
Wilber 已提交
377
      cublas_handle_->Call(callback);
378 379 380 381 382 383
    }
  }

 private:
  void InitEigenContext();

384 385 386 387 388
#ifdef PADDLE_WITH_HIP
  void InitCuBlasContext() {
    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
  }
#else
389 390 391 392 393 394 395
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
396 397 398 399 400
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
401 402
    }
  }
403
#endif
404

Z
zhangkaihuo 已提交
405
#ifndef PADDLE_WITH_HIP
406 407 408 409 410 411
#if CUDA_VERSION >= 11060
  void InitCuBlasLtContext() {
    cublaslt_handle_.reset(new CublasLtHandleHolder());
  }
#endif

Z
zhangkaihuo 已提交
412 413 414 415 416
  void InitCuSparseContext() {
    cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
  }
#endif

417 418
  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
419 420
#ifdef PADDLE_WITH_HIP
      size_t miopen_major, miopen_minor, miopen_patch;
421
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
422 423
          &miopen_major, &miopen_minor, &miopen_patch));
      auto local_miopen_version =
424 425
          (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
      auto compile_miopen_version = MIOPEN_VERSION / 10;
426 427 428 429
      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with MIOPEN "
430 431
            << compile_miopen_version / 100 << "."
            << compile_miopen_version % 100
432
            << ", but MIOPEN version in your machine is "
433
            << local_miopen_version / 100 << "." << local_miopen_version % 100
434 435 436 437
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible MIOPEN "
               "version.";
      }
438 439
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
      PADDLE_ENFORCE_GPU_SUCCESS(
440 441
          dynload::miopenSetStream(cudnn_handle_, RawStream()));
#else
442 443 444 445 446 447 448 449 450 451 452 453 454
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
455 456
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
457
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
458
#endif
459 460 461 462 463
    } else {
      cudnn_handle_ = nullptr;
    }
  }

464
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
465
  void InitCuSolverContext() {
466 467
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
468 469
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }
470
#endif
G
Guo Sheng 已提交
471

472 473
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
474
#ifdef PADDLE_WITH_HIP
475
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
476
#else
477
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
478
#endif
479 480 481 482 483 484 485
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
486
    cublas_tf32_tensor_core_handle_.reset();
487 488
  }

Z
zhangkaihuo 已提交
489
#ifndef PADDLE_WITH_HIP
490 491 492 493
#if CUDA_VERSION >= 11060
  void DestoryCuBlasLtContext() { cublaslt_handle_.reset(); }
#endif

Z
zhangkaihuo 已提交
494 495 496
  void DestoryCuSparseContext() { cusparse_handle_.reset(); }
#endif

497
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
498 499
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
500
      PADDLE_ENFORCE_GPU_SUCCESS(
G
Guo Sheng 已提交
501 502 503
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }
504
#endif
G
Guo Sheng 已提交
505

506 507 508 509
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
510 511 512
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle_;
#else
513
  cudnnHandle_t cudnn_handle_;
514
#endif
515 516
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
517
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
518
#ifndef PADDLE_WITH_HIP
519 520 521
#if CUDA_VERSION >= 11060
  std::unique_ptr<CublasLtHandleHolder> cublaslt_handle_;
#endif
G
Guo Sheng 已提交
522
  cusolverDnHandle_t cusolver_dn_handle_;
Z
zhangkaihuo 已提交
523
  std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
524
#endif
525 526 527
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

528
class CUDADeviceContext : public phi::GPUContext {
Q
QI JUN 已提交
529
 public:
D
dzhwinter 已提交
530
  explicit CUDADeviceContext(CUDAPlace place);
531
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
532

533
  /*! \brief  Wait for all operations completion in the stream. */
534
  void Wait() const override;
Q
QI JUN 已提交
535

536 537 538
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

539
  /*! \brief  Call cublas function safely. */
W
Wilber 已提交
540 541 542
  inline void CublasCall(
      const std::function<void(blasHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
543
      phi::GPUContext::CublasCall(callback);
W
Wilber 已提交
544 545
      return;
    }
546
    return context()->CublasCall(callback);
547 548
  }

Z
zhangkaihuo 已提交
549 550
#ifndef PADDLE_WITH_HIP
  /*! \brief  Call cusparse function safely. */
W
Wilber 已提交
551
  inline void CusparseCall(
552
      const std::function<void(phi::sparseHandle_t)>& callback) const {
W
Wilber 已提交
553
    if (!thread_ctx_.count(this)) {
554
      phi::GPUContext::CusparseCall(callback);
W
Wilber 已提交
555 556 557
      return;
    }
    context()->CusparseCall(callback);
Z
zhangkaihuo 已提交
558 559 560
  }
#endif

561 562
  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
W
Wilber 已提交
563 564 565
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
566
      phi::GPUContext::TensorCoreCublasCallIfAvailable(callback);
W
Wilber 已提交
567 568 569
      return;
    }
    context()->TensorCoreCublasCallIfAvailable(callback);
570
  }
S
sneaxiy 已提交
571

572 573 574 575
/*! \brief  Return cudnn  handle in the device context. */
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle() const;
#else
576
  cudnnHandle_t cudnn_handle() const;
577
#endif
578

579 580 581 582
/*! \brief  Return cublas handle in the device context. */
#ifdef PADDLE_WITH_HIP
  rocblas_handle cublas_handle() const;
#else
583
  cublasHandle_t cublas_handle() const;
584
  cublasLtHandle_t cublaslt_handle() const;
Z
zhangkaihuo 已提交
585
  cusparseHandle_t cusparse_handle() const;
586
#endif
587

W
Wilber 已提交
588 589 590 591
#ifndef PADDLE_WITH_HIP
  cusolverDnHandle_t cusolver_dn_handle() const;
#endif

S
sneaxiy 已提交
592 593 594 595 596 597 598
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
599
  phi::DnnWorkspaceHandle cudnn_workspace_handle() const;
S
sneaxiy 已提交
600

Q
init  
qijun 已提交
601
  /*! \brief  Return cuda stream in the device context. */
602
  gpuStream_t stream() const;
Q
QI JUN 已提交
603

W
Wilber 已提交
604
  void RecordEvent(gpuEvent_t ev, const std::function<void()>& callback) const;
605

W
Wilber 已提交
606
  void AddStreamCallback(const std::function<void()>& callback) const;
607

W
Wilber 已提交
608
  void WaitStreamCallback() const;
609

610
  void ResetThreadContext(const stream::Priority& priority) {
611
    std::lock_guard<std::mutex> guard(ctx_mtx_);
W
Wilber 已提交
612
    thread_ctx_[this].reset(new CUDAContext(this->GetPlace(), priority));
613 614
  }

W
Wilber 已提交
615
  std::shared_ptr<CUDAContext> context() const;
S
sneaxiy 已提交
616

W
Wilber 已提交
617 618 619 620 621
  // Note: Can only be used under thread_local semantics.
  void SetThreadLocalStream(const gpuStream_t stream) {
    thread_ctx_.at(this)->SetStream(stream);
  }

W
Wilber 已提交
622 623 624 625
  // NOTE: Just for compatibility with the past, please delete if there is an
  // elegant way.
  stream::CUDAStream* GetCudaStream() const;
  stream::CUDAStream* SetCudaStream(stream::CUDAStream*);
Q
QI JUN 已提交
626

W
Wilber 已提交
627
 private:
628 629 630 631 632 633
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
634

635 636
  mutable std::mutex cudnn_handle_mtx_;

W
Wilber 已提交
637 638 639
  // NOTE: Just for compatibility with the past, please delete if there is an
  // elegant way.
  std::unique_ptr<stream::CUDAStream> cuda_stream_;
Y
yuyang18 已提交
640

641
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
642
};
Q
qijun 已提交
643

644 645
class CudnnWorkspaceHandle {
 public:
646 647
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
648 649 650 651 652 653 654 655

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
656 657 658 659
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
660 661 662 663 664 665 666 667 668 669 670 671 672
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

673
  void ReallocWorkspace(size_t required_workspace_bytes);
674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
690
  std::mutex* mtx_;
691 692
};

Y
Yang Yu 已提交
693 694
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
695
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
696 697
};

C
chengduoZH 已提交
698
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
699 700 701 702 703
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

W
Wilber 已提交
704
  const Place& GetPlace() const override;
C
chengduoZH 已提交
705

C
chengduoZH 已提交
706 707 708 709 710 711 712 713 714 715 716
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
717
#endif
#ifdef PADDLE_WITH_MKLDNN

/*! \brief  Per-thread oneDNN (MKL-DNN) state. Each thread gets its own Body
 *  through fetch(); the class itself cannot be constructed from outside
 *  because its constructors are private. */
class MKLDNNDeviceContextThreadLocals {
  using self = MKLDNNDeviceContextThreadLocals;

  struct Body {
    bool said_once = false;
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // the cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Most recently registered data layout; needed when converting an
    // MKL-DNN tensor back to a non-MKL-DNN tensor.
    paddle::framework::DataLayout cur_paddle_data_layout;
    // MKL-DNN engine and stream used for execution of primitives (per-thread)
    dnnl::engine cur_engine;
    dnnl::stream cur_stream;
    std::string key_suffix;  // Key identifying current Executor
    bool key_attach_thread_id = true;
    void* exec_ptr_ = nullptr;

    Body();
    ~Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id(void);
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout(void);
    void log_lib_version(void);
    const dnnl::engine& get_engine(void);
    dnnl::stream& get_stream(void);
    void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
    const std::string& get_key_suffix(void) const { return key_suffix; }
    void disable_tid_in_key(void) { key_attach_thread_id = false; }
    bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
    void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
    void* get_curr_exec(void) const { return exec_ptr_; }
  };
  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // default mkldnn session id
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // mkldnn session id for cache clearing mode (-1 wraps to SIZE_MAX)
  static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
  /*! \brief  Return this thread's Body, constructing it on first use. */
  static Body& fetch() {
    thread_local Body b;
    return b;
  }
};

/*! \brief  CPU device context that additionally caches oneDNN primitives. */
class MKLDNNDeviceContext : public phi::CPUContext {
 public:
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // Their relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
  using ShapeBlob = umap_key_string_t<KeyBlob>;
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

  // Auxiliary two-level structure (shape, executor) to make it easier to
  // clear cache objects related to a specific executor

  using ExecKey = void*;
  using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
  using ExecMap =
      std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
  using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;

  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
  const dnnl::engine& GetEngine() const { return tls().get_engine(); }

  // Register object to currently used executor's map
  void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
  void RemoveShapeEntriesWithExecutor(void) const;

  // Remove all entries from the blob map
  void ResetBlobMap(void* ptr);

  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();

  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;

  // Calculate number of oneDNN objects cached
  unsigned int GetCachedObjectsNumber(void) const;

  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;

  // Shortcut to the calling thread's MKLDNNDeviceContextThreadLocals::Body.
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

 private:
  std::shared_ptr<BlobMap> p_blobmap_;
  // Map key is pointer of executor and value is a data(iterator in map) needed
  // to erase
  std::shared_ptr<ExecShape> p_exec_items_;
  std::shared_ptr<std::mutex> p_mutex_;
  // 0 - clearing is allowed. x > 0 do not clear.
  unsigned int block_next_cache_clearing_ = 0;
};
#endif

848
#ifdef PADDLE_WITH_CUSTOM_DEVICE
849
class CustomDeviceContext : public phi::CustomContext {
850 851 852 853 854 855 856 857 858 859 860 861 862 863
 public:
  explicit CustomDeviceContext(CustomPlace place);
  virtual ~CustomDeviceContext();

  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    return stream_->AddCallback(callback);
  }

  void WaitStreamCallback() const { return stream_->WaitCallback(); }

 private:
864
  std::shared_ptr<phi::stream::Stream> stream_;
865 866 867 868 869 870 871 872 873 874 875 876
};
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
  using TYPE = CustomDeviceContext;
};
#else
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
  using TYPE = DeviceContext;
};
#endif

877 878 879 880 881 882
void EmplaceDeviceContexts(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        place_to_device_context,
    const std::vector<platform::Place>& places,
    bool disable_setting_default_stream_for_allocator);

D
dzhwinter 已提交
883 884 885
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
Y
Yang Yu 已提交
886
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
887 888 889
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
890 891 892 893
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
894
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
895 896 897 898 899 900
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

901 902
  static bool IsInitialized() { return pool != nullptr; }

903 904
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
905
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
906
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
907

Y
Yang Yu 已提交
908 909 910 911 912 913 914
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

915
  size_t size() const;
916

917
  const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
918 919 920 921 922
  device_contexts() const;

  static void SetDeviceContexts(
      const std::map<Place,
                     std::shared_future<std::unique_ptr<DeviceContext>>>*);
923

D
dzhwinter 已提交
924
 private:
925 926
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

D
dzhwinter 已提交
927
  static DeviceContextPool* pool;
928 929
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
930 931 932
  static thread_local const std::
      map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
          external_device_contexts_;  // not owned
D
dzhwinter 已提交
933 934 935
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
936 937
}  // namespace platform
}  // namespace paddle