device_context.h 27.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
QI JUN 已提交
2 3 4 5 6 7 8 9 10 11 12
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

W
Wilber 已提交
13
#include <functional>
14
#include <future>  // NOLINT
D
dzhwinter 已提交
15
#include <memory>
Y
yuyang18 已提交
16
#include <mutex>  // NOLINT
17
#include <string>
D
dzhwinter 已提交
18
#include <unordered_map>
19
#include <utility>
20
#include <vector>
W
wanghuancoder 已提交
21

W
Wilber 已提交
22
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
W
Wilber 已提交
23
#include "paddle/pten/backends/cpu/cpu_context.h"
W
Wilber 已提交
24
#include "paddle/pten/backends/gpu/gpu_decls.h"
W
Wilber 已提交
25 26
#include "paddle/pten/core/device_context.h"

Y
Yu Yang 已提交
27
#include "paddle/fluid/memory/malloc.h"
28
#ifdef PADDLE_WITH_CUDA
29
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
Y
Yi Wang 已提交
30 31
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
32
#include "paddle/fluid/platform/dynload/cusolver.h"
33
#include "paddle/fluid/platform/dynload/cusparse.h"
W
Wilber 已提交
34
#include "paddle/pten/backends/gpu/gpu_context.h"
35
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
36
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
37
#endif
38
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
Q
QI JUN 已提交
39
#endif
D
dzhwinter 已提交
40

41
#ifdef PADDLE_WITH_HIP
42
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"  // NOLINT
43 44
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
W
Wilber 已提交
45
#include "paddle/pten/backends/gpu/gpu_context.h"  // NOLINT
46 47 48
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
49
#include "paddle/fluid/platform/device/gpu/gpu_info.h"  // NOLINT
50 51
#endif

52 53 54 55
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

T
tensor-tang 已提交
56
#ifdef PADDLE_WITH_MKLDNN
57
#include "dnnl.hpp"
58
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
59 60
#endif

61
#include <map>
W
wanghuancoder 已提交
62

63
#include "glog/logging.h"
Y
Yi Wang 已提交
64 65
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
66
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
67
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
68
#endif
69
#ifdef PADDLE_WITH_ASCEND_CL
70 71
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
72
#endif
73 74 75

#include "paddle/fluid/platform/device/device_ext.h"
#include "paddle/fluid/platform/device/stream.h"
Q
qijun 已提交
76
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
77

W
wanghuancoder 已提交
78 79 80 81 82
// Forward declarations of the Eigen device types that the device contexts
// below refer to through pointers and smart pointers.
namespace Eigen {
struct GpuDevice;
struct DefaultDevice;
}  // namespace Eigen

83
#ifdef PADDLE_WITH_XPU
84 85
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
W
Wilber 已提交
86
#include "paddle/pten/backends/xpu/xpu_context.h"
87 88
#endif

89 90
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
91
#include "paddle/fluid/platform/device/npu/npu_info.h"
92 93
#endif

Q
QI JUN 已提交
94 95 96
namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/*! \brief Set the value of the global variable allow_tf32_cublas. */
void SetAllowTF32Cublas(bool active);
/*! \brief Get the value of the global variable allow_tf32_cublas. */
bool AllowTF32Cublas();

extern bool allow_tf32_cudnn;
/*! \brief Set the value of the global variable allow_tf32_cudnn. */
void SetAllowTF32Cudnn(bool active);
/*! \brief Get the value of the global variable allow_tf32_cudnn. */
bool AllowTF32Cudnn();
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

// Kinds of devices known to the platform layer. The enumerator values are
// explicit and contiguous; MAX_DEVICE_TYPES is one past the last valid kind
// (NOTE(review): presumably used as a table size elsewhere — confirm).
enum DeviceType {
  CPU = 0,
  CUDA = 1,
  XPU = 2,
  NPU = 3,
  IPU = 4,
  MLU = 5,

  MAX_DEVICE_TYPES = 6,
};

/*! \brief Map a Place to its DeviceType (defined out of line). */
DeviceType Place2DeviceType(const platform::Place& place);

// Shorthand constants for the DeviceType enumerators.
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kNPU = DeviceType::NPU;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU;

W
Wilber 已提交
129
using DeviceContext = pten::DeviceContext;
Q
QI JUN 已提交
130

W
Wilber 已提交
131 132 133 134
// using CPUDeviceContext = pten::CPUContext;
// TODO(wilber): The place constructor is used in many places, it is more
// difficult to use CPUDeviceContext = pten::CPUContext directly.
class CPUDeviceContext : public pten::CPUContext {
Q
qijun 已提交
135
 public:
136
  CPUDeviceContext();
Q
qijun 已提交
137
  explicit CPUDeviceContext(CPUPlace place);
Q
QI JUN 已提交
138 139
};

Y
Yang Yu 已提交
140 141 142 143 144 145 146 147
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

// Graphcore IPU
#ifdef PADDLE_WITH_IPU
/*! \brief Device context for Graphcore IPU places. */
class IPUDeviceContext : public DeviceContext {
 public:
  IPUDeviceContext() = delete;
  explicit IPUDeviceContext(IPUPlace place);
  virtual ~IPUDeviceContext();

  /*! \brief  No Eigen device is provided; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief  Return the place of this context (defined out of line). */
  const Place& GetPlace() const override;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

 private:
  IPUPlace place_;
};

template <>
struct DefaultDeviceContextType<platform::IPUPlace> {
  using TYPE = IPUDeviceContext;
};
#endif

#ifdef PADDLE_WITH_MLU
// Only declared here; the class and the DefaultDeviceContextType
// specialization are defined elsewhere (NOTE(review): presumably in the
// MLU device-context header — confirm).
class MLUDeviceContext;

template <>
struct DefaultDeviceContextType<platform::MLUPlace>;
#endif

#ifdef PADDLE_WITH_XPU
namespace xpu = baidu::xpu::api;

/*! \brief Device context for Baidu XPU places, built on pten::XPUContext. */
class XPUDeviceContext : public pten::XPUContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();

  /*! \brief  No Eigen device is provided; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
};

template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif

#ifdef PADDLE_WITH_ASCEND_CL
/*! \brief Device context for Ascend NPU places. */
class NPUDeviceContext : public DeviceContext {
 public:
  explicit NPUDeviceContext(NPUPlace place);
  virtual ~NPUDeviceContext();

  /*! \brief  No Eigen device is provided; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  const Place& GetPlace() const override;
  aclrtContext context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return npu stream in the device context. */
  aclrtStream stream() const;

  /*! \brief  Forward `callback` to the stream's AddCallback(). */
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    return stream_->AddCallback(callback);
  }

  /*! \brief  Forward to the stream's WaitCallback(). */
  void WaitStreamCallback() const { return stream_->WaitCallback(); }

  // The whole section is already guarded by PADDLE_WITH_ASCEND_CL, so the
  // hccl accessors need no extra preprocessor guard.
  /*! \brief  Return hccl communicators. */
  HcclComm hccl_comm() const { return hccl_comm_; }

  /*! \brief  Set hccl communicators. */
  void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }

 private:
  NPUPlace place_;
  aclrtContext context_;

  HcclComm hccl_comm_{nullptr};

  // Need to be the same with other DeviceContext,
  // even though eigen_device_ is not used in NPU.
  // NOTE(zhiqiu): why need?
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  std::shared_ptr<stream::NPUStream> stream_;

  DISABLE_COPY_AND_ASSIGN(NPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::NPUPlace> {
  using TYPE = NPUDeviceContext;
};

// Currently, NPUPinnedDeviceContext is only used for data copying.
class NPUPinnedDeviceContext : public DeviceContext {
 public:
  NPUPinnedDeviceContext();
  explicit NPUPinnedDeviceContext(NPUPinnedPlace place);

  const Place& GetPlace() const override;

  Eigen::DefaultDevice* eigen_device() const;

 private:
  NPUPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::NPUPinnedPlace> {
  using TYPE = NPUPinnedDeviceContext;
};

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Forward declarations for types referenced before their definitions.
class CudnnWorkspaceHandle;
class EigenCudaStreamDevice;

278 279 280 281 282
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
283 284
      const stream::Priority& priority = stream::Priority::kNormal,
      const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

300 301 302 303 304 305
  stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
    auto* old_stream_ptr = stream_.release();
    stream_.reset(new_stream_ptr);
    return old_stream_ptr;
  }

W
Wilber 已提交
306 307
  void SetStream(gpuStream_t stream);

308
  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
309

310 311 312
#ifdef PADDLE_WITH_HIP
  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
#else
313
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
314
#endif
315

316
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
317 318 319
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }
320
#endif
G
Guo Sheng 已提交
321

322 323 324 325 326 327 328 329
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

Z
zhangkaihuo 已提交
330 331 332 333 334 335
#ifndef PADDLE_WITH_HIP
  const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
    return cusparse_handle_;
  }
#endif

336
  /*! \brief  Call cublas function safely. */
W
Wilber 已提交
337 338
  inline void CublasCall(
      const std::function<void(blasHandle_t)>& callback) const {
339
    if (cublas_tf32_tensor_core_handle_) {
W
Wilber 已提交
340
      cublas_tf32_tensor_core_handle_->Call(callback);
341
    } else {
W
Wilber 已提交
342
      cublas_handle_->Call(callback);
343
    }
344 345
  }

Z
zhangkaihuo 已提交
346 347
#ifndef PADDLE_WITH_HIP
  /*! \brief  Call cusparse function safely. */
W
Wilber 已提交
348 349 350
  inline void CusparseCall(
      const std::function<void(pten::sparseHandle_t)>& callback) const {
    cusparse_handle_->Call(callback);
Z
zhangkaihuo 已提交
351 352 353
  }
#endif

354 355 356 357 358
  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
W
Wilber 已提交
359 360
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) const {
361
    if (cublas_tensor_core_handle_) {
W
Wilber 已提交
362
      cublas_tensor_core_handle_->Call(callback);
363
    } else {
W
Wilber 已提交
364
      cublas_handle_->Call(callback);
365 366 367 368 369 370
    }
  }

 private:
  void InitEigenContext();

371 372 373 374 375
#ifdef PADDLE_WITH_HIP
  void InitCuBlasContext() {
    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
  }
#else
376 377 378 379 380 381 382
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
383 384 385 386 387
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
388 389
    }
  }
390
#endif
391

Z
zhangkaihuo 已提交
392 393 394 395 396 397
#ifndef PADDLE_WITH_HIP
  void InitCuSparseContext() {
    cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
  }
#endif

398 399
  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
400 401
#ifdef PADDLE_WITH_HIP
      size_t miopen_major, miopen_minor, miopen_patch;
402
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
403 404
          &miopen_major, &miopen_minor, &miopen_patch));
      auto local_miopen_version =
405 406
          (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
      auto compile_miopen_version = MIOPEN_VERSION / 10;
407 408 409 410
      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with MIOPEN "
411 412
            << compile_miopen_version / 100 << "."
            << compile_miopen_version % 100
413
            << ", but MIOPEN version in your machine is "
414
            << local_miopen_version / 100 << "." << local_miopen_version % 100
415 416 417 418
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible MIOPEN "
               "version.";
      }
419 420
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
      PADDLE_ENFORCE_GPU_SUCCESS(
421 422
          dynload::miopenSetStream(cudnn_handle_, RawStream()));
#else
423 424 425 426 427 428 429 430 431 432 433 434 435
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
436 437
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
438
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
439
#endif
440 441 442 443 444
    } else {
      cudnn_handle_ = nullptr;
    }
  }

445
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
446
  void InitCuSolverContext() {
447 448
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
449 450
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }
451
#endif
G
Guo Sheng 已提交
452

453 454
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
455
#ifdef PADDLE_WITH_HIP
456
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
457
#else
458
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
459
#endif
460 461 462 463 464 465 466
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
467
    cublas_tf32_tensor_core_handle_.reset();
468 469
  }

Z
zhangkaihuo 已提交
470 471 472 473
#ifndef PADDLE_WITH_HIP
  void DestoryCuSparseContext() { cusparse_handle_.reset(); }
#endif

474
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
475 476
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
477
      PADDLE_ENFORCE_GPU_SUCCESS(
G
Guo Sheng 已提交
478 479 480
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }
481
#endif
G
Guo Sheng 已提交
482

483 484 485 486
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
487 488 489
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle_;
#else
490
  cudnnHandle_t cudnn_handle_;
491
#endif
492 493
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
494
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
495
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
496
  cusolverDnHandle_t cusolver_dn_handle_;
Z
zhangkaihuo 已提交
497
  std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
498
#endif
499 500 501
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

W
Wilber 已提交
502
class CUDADeviceContext : public pten::GPUContext {
Q
QI JUN 已提交
503
 public:
D
dzhwinter 已提交
504
  explicit CUDADeviceContext(CUDAPlace place);
505
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
506

507
  /*! \brief  Wait for all operations completion in the stream. */
508
  void Wait() const override;
Q
QI JUN 已提交
509

510 511 512
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

513
  /*! \brief  Call cublas function safely. */
W
Wilber 已提交
514 515 516 517 518 519
  inline void CublasCall(
      const std::function<void(blasHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
      pten::GPUContext::CublasCall(callback);
      return;
    }
520
    return context()->CublasCall(callback);
521 522
  }

Z
zhangkaihuo 已提交
523 524
#ifndef PADDLE_WITH_HIP
  /*! \brief  Call cusparse function safely. */
W
Wilber 已提交
525 526 527 528 529 530 531
  inline void CusparseCall(
      const std::function<void(pten::sparseHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
      pten::GPUContext::CusparseCall(callback);
      return;
    }
    context()->CusparseCall(callback);
Z
zhangkaihuo 已提交
532 533 534
  }
#endif

535 536
  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
W
Wilber 已提交
537 538 539 540 541 542 543
  inline void TensorCoreCublasCallIfAvailable(
      const std::function<void(blasHandle_t)>& callback) const {
    if (!thread_ctx_.count(this)) {
      pten::GPUContext::TensorCoreCublasCallIfAvailable(callback);
      return;
    }
    context()->TensorCoreCublasCallIfAvailable(callback);
544
  }
S
sneaxiy 已提交
545

546 547 548 549
/*! \brief  Return cudnn  handle in the device context. */
#ifdef PADDLE_WITH_HIP
  miopenHandle_t cudnn_handle() const;
#else
550
  cudnnHandle_t cudnn_handle() const;
551
#endif
552

553 554 555 556
/*! \brief  Return cublas handle in the device context. */
#ifdef PADDLE_WITH_HIP
  rocblas_handle cublas_handle() const;
#else
557
  cublasHandle_t cublas_handle() const;
Z
zhangkaihuo 已提交
558
  cusparseHandle_t cusparse_handle() const;
559
#endif
560

W
Wilber 已提交
561 562 563 564
#ifndef PADDLE_WITH_HIP
  cusolverDnHandle_t cusolver_dn_handle() const;
#endif

S
sneaxiy 已提交
565 566 567 568 569 570 571
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
W
Wilber 已提交
572
  pten::DnnWorkspaceHandle cudnn_workspace_handle() const;
S
sneaxiy 已提交
573

Q
init  
qijun 已提交
574
  /*! \brief  Return cuda stream in the device context. */
575
  gpuStream_t stream() const;
Q
QI JUN 已提交
576

W
Wilber 已提交
577
  void RecordEvent(gpuEvent_t ev, const std::function<void()>& callback) const;
578

W
Wilber 已提交
579
  void AddStreamCallback(const std::function<void()>& callback) const;
580

W
Wilber 已提交
581
  void WaitStreamCallback() const;
582

583
  void ResetThreadContext(const stream::Priority& priority) {
584
    std::lock_guard<std::mutex> guard(ctx_mtx_);
W
Wilber 已提交
585
    thread_ctx_[this].reset(new CUDAContext(this->GetPlace(), priority));
586 587
  }

W
Wilber 已提交
588
  std::shared_ptr<CUDAContext> context() const;
S
sneaxiy 已提交
589

W
Wilber 已提交
590 591 592 593 594
  // Note: Can only be used under thread_local semantics.
  void SetThreadLocalStream(const gpuStream_t stream) {
    thread_ctx_.at(this)->SetStream(stream);
  }

W
Wilber 已提交
595 596 597 598
  // NOTE: Just for compatibility with the past, please delete if there is an
  // elegant way.
  stream::CUDAStream* GetCudaStream() const;
  stream::CUDAStream* SetCudaStream(stream::CUDAStream*);
Q
QI JUN 已提交
599

W
Wilber 已提交
600
 private:
601 602 603 604 605 606
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
607

608 609
  mutable std::mutex cudnn_handle_mtx_;

W
Wilber 已提交
610 611 612
  // NOTE: Just for compatibility with the past, please delete if there is an
  // elegant way.
  std::unique_ptr<stream::CUDAStream> cuda_stream_;
W
Wilber 已提交
613
  std::unique_ptr<pten::DnnWorkspaceHandle> workspace_{nullptr};
Y
yuyang18 已提交
614

615
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
616
};
Q
qijun 已提交
617

618 619
class CudnnWorkspaceHandle {
 public:
620 621
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
622 623 624 625 626 627 628 629

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
630 631 632 633
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
634 635 636 637 638 639 640 641 642 643 644 645 646
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

647
  void ReallocWorkspace(size_t required_workspace_bytes);
648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
664
  std::mutex* mtx_;
665 666
};

Y
Yang Yu 已提交
667 668
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
669
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
670 671
};

C
chengduoZH 已提交
672
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
673 674 675 676 677
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

W
Wilber 已提交
678
  const Place& GetPlace() const override;
C
chengduoZH 已提交
679

C
chengduoZH 已提交
680 681 682 683 684 685 686 687 688 689 690
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
691
#endif
Q
qijun 已提交
692

#ifdef PADDLE_WITH_MKLDNN

/*! \brief Holder of per-thread oneDNN (MKL-DNN) state. fetch() returns a
 *  thread_local Body, so each thread sees its own instance. */
class MKLDNNDeviceContextThreadLocals {
  typedef MKLDNNDeviceContextThreadLocals self;
  struct Body {
    bool said_once = false;
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // the cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Recently registered data_format. This is needed when
    // converting an MKL-DNN Tensor to a non-MKL-DNN one.
    paddle::framework::DataLayout cur_paddle_data_layout;
    // MKL-DNN engine and stream used for execution of primitives
    // (per-thread).
    dnnl::engine cur_engine;
    dnnl::stream cur_stream;
    std::string key_suffix;  // Key identifying current Executor
    bool key_attach_thread_id = true;
    void* exec_ptr_ = nullptr;

    Body();
    ~Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id(void);
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout(void);
    void log_lib_version(void);
    const dnnl::engine& get_engine(void);
    dnnl::stream& get_stream(void);
    void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
    const std::string& get_key_suffix(void) const { return key_suffix; }
    void disable_tid_in_key(void) { key_attach_thread_id = false; }
    bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
    void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
    void* get_curr_exec(void) const { return exec_ptr_; }
  };
  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // default mkldnn session id
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // mkldnn session id for cache clearing mode (the largest size_t value;
  // the cast makes the intentional wrap-around of -1 explicit).
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      static_cast<size_t>(-1);
  static Body& fetch() {
    thread_local Body b;
    return b;
  }
};

T
tensor-tang 已提交
752 753
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
754 755 756 757 758 759 760 761 762 763
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
764
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
765 766 767
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
768
  using ShapeBlob = umap_key_string_t<KeyBlob>;
769 770
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

771 772 773 774
  // Auxillary two-level structure (shape, executor) to easier control
  // clearing cache objects related to specific executor

  using ExecKey = void*;
775
  using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
776 777 778
  using ExecMap =
      std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
  using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
779

T
tensor-tang 已提交
780 781 782
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
783
  const dnnl::engine& GetEngine() const { return tls().get_engine(); }
T
tensor-tang 已提交
784

785
  // Register object to currently used executor's map
786 787
  void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
  void RemoveShapeEntriesWithExecutor(void) const;
788

789
  // Remove all entries from the blob map
790
  void ResetBlobMap(void* ptr);
791 792 793

  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();
794

795 796 797
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

798 799
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
800

801
  // Calculate number of oneDNN objects cached
802
  unsigned int GetCachedObjectsNumber(void) const;
803

804 805
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
806

807 808 809 810
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
811
 private:
812
  std::shared_ptr<BlobMap> p_blobmap_;
813 814
  // Map key is pointer of executor and value is a data(iterator in map) needed
  // to erase
815
  std::shared_ptr<ExecShape> p_exec_items_;
816
  std::shared_ptr<std::mutex> p_mutex_;
817
  bool block_next_cache_clearing_ = false;
T
tensor-tang 已提交
818 819 820
};
#endif

821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860
#ifdef PADDLE_WITH_CUSTOM_DEVICE
class CustomDeviceContext : public DeviceContext {
 public:
  explicit CustomDeviceContext(CustomPlace place);
  virtual ~CustomDeviceContext();

  const Place& GetPlace() const override;
  void Wait() const override;
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
  C_Stream stream() const {
    return reinterpret_cast<C_Stream>(stream_->raw_stream());
  }

  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    return stream_->AddCallback(callback);
  }

  void WaitStreamCallback() const { return stream_->WaitCallback(); }

 private:
  std::string device_type_;

  CustomPlace place_;

  std::shared_ptr<platform::stream::Stream> stream_;

  CustomDeviceContext();
};
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
  using TYPE = CustomDeviceContext;
};
#else
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
  using TYPE = DeviceContext;
};
#endif

D
dzhwinter 已提交
861 862 863 864 865
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
866
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
867 868 869
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
870 871 872 873
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
874
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
875 876 877 878 879 880
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

881 882
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
883
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
884
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
885

Y
Yang Yu 已提交
886 887 888 889 890 891 892
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

893 894
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
895 896
 private:
  static DeviceContextPool* pool;
897 898
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
899 900 901
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
902 903
}  // namespace platform
}  // namespace paddle