/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

13
#include <future>  // NOLINT
D
dzhwinter 已提交
14
#include <memory>
Y
yuyang18 已提交
15
#include <mutex>  // NOLINT
16
#include <string>
D
dzhwinter 已提交
17
#include <unordered_map>
18
#include <utility>
19
#include <vector>
W
wanghuancoder 已提交
20

Y
Yu Yang 已提交
21
#include "paddle/fluid/memory/malloc.h"
22
#ifdef PADDLE_WITH_CUDA
23
#include "paddle/fluid/platform/cuda_helper.h"
Y
Yi Wang 已提交
24 25
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
26
#include "paddle/fluid/platform/dynload/cusolver.h"
27
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
28
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
29
#endif
Y
Yi Wang 已提交
30
#include "paddle/fluid/platform/gpu_info.h"
Q
QI JUN 已提交
31
#endif
D
dzhwinter 已提交
32

33 34 35 36
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif

T
tensor-tang 已提交
37
#ifdef PADDLE_WITH_MKLDNN
L
luotao1 已提交
38
#include "mkldnn.hpp"
39
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
40 41
#endif

42
#include <map>
W
wanghuancoder 已提交
43

44
#include "glog/logging.h"
Y
Yi Wang 已提交
45 46
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
S
sneaxiy 已提交
47
#ifdef PADDLE_WITH_CUDA
48
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
49
#endif
50 51 52
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/stream/npu_stream.h"
#endif
Q
qijun 已提交
53
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
54

W
wanghuancoder 已提交
55 56 57 58 59
// Forward declarations of the Eigen device types that the device contexts
// in this header refer to by pointer.
namespace Eigen {
struct GpuDevice;
struct DefaultDevice;
}  // namespace Eigen

60 61
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
62
#include "paddle/fluid/platform/xpu_info.h"
63 64
#endif

65 66 67 68 69
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
#include "paddle/fluid/platform/npu_info.h"
#endif

Q
QI JUN 已提交
70 71 72
namespace paddle {
namespace platform {

73 74 75 76 77
#ifdef PADDLE_WITH_CUDA
/*! \brief Set the global flag `allow_tf32_cublas`. */
void SetAllowTF32Cublas(bool active);
/*! \brief Read the global flag `allow_tf32_cublas`. */
bool AllowTF32Cublas();

/*! \brief Set the global flag `allow_tf32_cudnn`. */
void SetAllowTF32Cudnn(bool active);
/*! \brief Read the global flag `allow_tf32_cudnn`. */
bool AllowTF32Cudnn();
#endif  // PADDLE_WITH_CUDA

/*! \brief Kinds of device a DeviceContext may target. The numeric values
 *  are spelled out explicitly. */
enum DeviceType {
  CPU = 0,
  CUDA = 1,
  XPU = 2,
  NPU = 3
};

// Short constant aliases for the DeviceType enumerators.
constexpr DeviceType kCPU = CPU;
constexpr DeviceType kCUDA = CUDA;
constexpr DeviceType kXPU = XPU;
constexpr DeviceType kNPU = NPU;
95

Q
QI JUN 已提交
96 97
class DeviceContext {
 public:
Z
Zeng Jinle 已提交
98
  virtual ~DeviceContext() PADDLE_MAY_THROW {}
L
liaogang 已提交
99
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
100

101
  virtual void Wait() const {}
Q
QI JUN 已提交
102 103
};

Q
qijun 已提交
104 105
class CPUDeviceContext : public DeviceContext {
 public:
106
  CPUDeviceContext();
Q
qijun 已提交
107
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
108

109
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
110

L
liaogang 已提交
111
  Place GetPlace() const override;
Y
Yu Yang 已提交
112

Q
qijun 已提交
113
 private:
D
dzhwinter 已提交
114
  CPUPlace place_;
Q
qijun 已提交
115
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
116 117
};

Y
Yang Yu 已提交
118 119 120 121 122 123 124 125
/*! \brief Type trait mapping a Place type to the DeviceContext subclass used
 *  for it (exposed as `TYPE`). Only the specializations are defined; an
 *  unsupported Place fails to compile. */
template <typename Place>
struct DefaultDeviceContextType;

// CPUPlace maps to CPUDeviceContext.
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

126 127 128 129 130 131 132 133 134 135 136 137 138
#ifdef PADDLE_WITH_XPU
/*! \brief Device context for XPU devices. */
class XPUDeviceContext : public DeviceContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();

  /*! \brief XPU provides no Eigen device; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief Return the XPU place this context is bound to. */
  Place GetPlace() const override;

  /*! \brief Return the underlying xpu::Context. */
  xpu::Context* x_context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

#ifdef PADDLE_WITH_XPU_BKCL
  /*! \brief  Return bkcl context. Value-initialized until
   *  set_bkcl_context() is called. */
  BKCLContext_t bkcl_context() const { return bkcl_context_; }

  /*! \brief  Set bkcl context. */
  void set_bkcl_context(BKCLContext_t context) { bkcl_context_ = context; }
#endif

 private:
  XPUPlace place_;
  // Default member initializer: never indeterminate, even if a constructor
  // does not assign it (a constructor's init list still takes precedence).
  xpu::Context* context_{nullptr};
#ifdef PADDLE_WITH_XPU_BKCL
  // Value-initialized so bkcl_context() is well-defined before
  // set_bkcl_context() has been called.
  BKCLContext_t bkcl_context_{};
#endif

  // Kept for parity with the other DeviceContext classes, even though
  // eigen_device_ is not used for XPU.
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  DISABLE_COPY_AND_ASSIGN(XPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif

166 167 168 169 170 171 172 173 174 175 176 177
#ifdef PADDLE_WITH_ASCEND_CL
/*! \brief Device context for Ascend NPU devices. */
class NPUDeviceContext : public DeviceContext {
 public:
  explicit NPUDeviceContext(NPUPlace place);
  virtual ~NPUDeviceContext();

  /*! \brief NPU provides no Eigen device; always returns nullptr. */
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }

  /*! \brief Return the NPU place this context is bound to. */
  Place GetPlace() const override;

  /*! \brief Return a pointer to the ACL runtime context. */
  aclrtContext* context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return npu stream in the device context. */
  aclrtStream stream() const;

 private:
  NPUPlace place_;
  aclrtContext context_;
#ifdef PADDLE_WITH_ASCEND_HCCL
  HCCLContext_t hccl_context_;
#endif

  // Kept for parity with the other DeviceContext classes, even though
  // eigen_device_ is not used for NPU.
  // NOTE(zhiqiu): why need?
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  std::shared_ptr<stream::NPUStream> stream_;

  DISABLE_COPY_AND_ASSIGN(NPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::NPUPlace> {
  using TYPE = NPUDeviceContext;
};
#endif

203
#ifdef PADDLE_WITH_CUDA
204

205
class CudnnWorkspaceHandle;
W
wanghuancoder 已提交
206
class EigenCudaStreamDevice;
S
sneaxiy 已提交
207

208 209 210 211 212
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
213
      const stream::Priority& priority = stream::Priority::kNormal);
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

  const cudaStream_t& RawStream() { return stream_->raw_stream(); }

  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }

G
Guo Sheng 已提交
233 234 235 236
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }

237 238 239 240 241 242 243 244 245 246 247
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
248 249 250 251 252
    if (cublas_tf32_tensor_core_handle_) {
      cublas_tf32_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
279 280 281 282 283
#if CUDA_VERSION >= 11000
      cublas_tf32_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
#endif  // CUDA_VERSION >= 9000
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
    }
  }

  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
302 303
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
      PADDLE_RETRY_CUDA_SUCCESS(
304
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
305 306 307 308 309
    } else {
      cudnn_handle_ = nullptr;
    }
  }

G
Guo Sheng 已提交
310
  void InitCuSolverContext() {
311 312
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_RETRY_CUDA_SUCCESS(
G
Guo Sheng 已提交
313 314 315
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }

316 317
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
318
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
319 320 321 322 323 324 325
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
326
    cublas_tf32_tensor_core_handle_.reset();
327 328
  }

G
Guo Sheng 已提交
329 330 331 332 333 334 335
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }

336 337 338 339 340 341 342
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
  cudnnHandle_t cudnn_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
343
  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
G
Guo Sheng 已提交
344
  cusolverDnHandle_t cusolver_dn_handle_;
345 346 347
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

348
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
349
 public:
D
dzhwinter 已提交
350
  explicit CUDADeviceContext(CUDAPlace place);
351
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
352

353
  /*! \brief  Wait for all operations completion in the stream. */
354
  void Wait() const override;
Q
QI JUN 已提交
355

356
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
357
  Place GetPlace() const override;
358

K
Kexin Zhao 已提交
359
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
360 361
  int GetComputeCapability() const;

362 363 364
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

365 366 367 368 369 370
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

371 372 373
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

374 375 376
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

377 378 379
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
380
    return context()->CublasCall(callback);
381 382 383 384 385 386 387 388 389
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
390
    return context()->TensorCoreCublasCallIfAvailable(callback);
391
  }
S
sneaxiy 已提交
392

393
  /*! \brief  Return cudnn  handle in the device context. */
394
  cudnnHandle_t cudnn_handle() const;
395

S
sneaxiy 已提交
396 397 398 399 400 401 402 403 404
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

G
Guo Sheng 已提交
405 406
  cusolverDnHandle_t cusolver_dn_handle() const;

Q
init  
qijun 已提交
407
  /*! \brief  Return cuda stream in the device context. */
408
  cudaStream_t stream() const;
Q
QI JUN 已提交
409

410
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
411 412 413 414 415
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
416
#endif
Q
qingqing01 已提交
417

Y
Yu Yang 已提交
418
  template <typename Callback>
419 420
  void RecordEvent(cudaEvent_t ev, Callback callback) const {
    return context()->Stream()->RecordEvent(ev, callback);
Y
Yu Yang 已提交
421 422
  }

S
sneaxiy 已提交
423 424
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
425 426 427 428 429
    return context()->Stream()->AddCallback(callback);
  }

  void WaitStreamCallback() const {
    return context()->Stream()->WaitCallback();
430 431
  }

432
  void ResetDefaultContext(const stream::Priority& priority) {
433 434 435
    default_ctx_.reset(new CUDAContext(place_, priority));
  }

436
  void ResetThreadContext(const stream::Priority& priority) {
437 438 439 440 441 442 443 444 445 446
    std::lock_guard<std::mutex> guard(ctx_mtx_);
    thread_ctx_[this].reset(new CUDAContext(place_, priority));
  }

  std::shared_ptr<CUDAContext> context() const {
    if (!thread_ctx_.count(this)) {
      return default_ctx_;
    }
    return thread_ctx_.at(this);
  }
S
sneaxiy 已提交
447

Q
QI JUN 已提交
448
 private:
D
dzhwinter 已提交
449
  CUDAPlace place_;
450
  std::shared_ptr<CUDAContext> default_ctx_;
Q
QI JUN 已提交
451

452 453 454 455 456 457
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
458

459 460
  mutable std::mutex cudnn_handle_mtx_;

461
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
462 463 464 465 466 467
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
468
#endif
Q
qingqing01 已提交
469

C
chengduo 已提交
470 471 472 473 474
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
475
  int max_threads_per_block_;
476
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
477

478
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
479
};
Q
qijun 已提交
480

481 482
class CudnnWorkspaceHandle {
 public:
483 484
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
485 486 487 488 489 490 491 492

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
493 494 495 496
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
497 498 499 500 501 502 503 504 505 506 507 508 509
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

510
  void ReallocWorkspace(size_t required_workspace_bytes);
511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
527
  std::mutex* mtx_;
528 529
};

Y
Yang Yu 已提交
530 531
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
532
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
533 534
};

C
chengduoZH 已提交
535
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
536 537 538 539 540 541
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
542

C
chengduoZH 已提交
543 544 545 546 547 548 549 550 551 552 553
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
554
#endif
Q
qijun 已提交
555

T
tensor-tang 已提交
556
#ifdef PADDLE_WITH_MKLDNN
557 558 559 560 561 562

/*! \brief Per-thread MKL-DNN state. The class itself cannot be constructed
 *  or copied from outside; the state lives in a function-local thread_local
 *  Body returned by fetch(). */
class MKLDNNDeviceContextThreadLocals {
  using self = MKLDNNDeviceContextThreadLocals;

  struct Body {
    bool said_once = false;
    // Session id currently selected on this thread.
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // The cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Recently registered data_format. This is needed to
    // know for converting MKL-DNN Tensor to non MKL-DNN.
    paddle::framework::DataLayout cur_paddle_data_layout;

    Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id();
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout();
    void log_lib_version();
  };

  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // Default mkldnn session id.
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // Session id for cache-clearing mode (all bits set).
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      static_cast<size_t>(-1);

  /*! \brief Return this thread's Body, constructing it on first use. */
  static Body& fetch() {
    thread_local Body per_thread_body;
    return per_thread_body;
  }
};
S
Sylwester Fraczek 已提交
599

T
tensor-tang 已提交
600 601
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
  using ShapeBlob = umap_key_string_t<KeyBlob>;
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

T
tensor-tang 已提交
619 620 621
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
622
  const mkldnn::engine& GetEngine() const { return engine_; }
T
tensor-tang 已提交
623

624
  // Remove all entries from the blob map
625 626
  void ResetBlobMap();

627 628 629 630
  // Set a suffix to be added to key
  void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
  const std::string& GetKeySuffix(void) const { return key_suffix_; }

631
  // Disable adding  thread ID to the key
632 633
  void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }
  bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }
634

635 636
  // Prevent next ResetBlobMap()
  void BlockNextCacheClearing();
637

638 639 640
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

641 642
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
643

644 645 646
  // Calculate number of oneDNN objects cached
  unsigned int GetCachedObjectsNumber(void);

647 648
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
649

650 651 652 653
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
654
 private:
655
  mkldnn::engine engine_;
656 657
  std::shared_ptr<BlobMap> p_blobmap_;
  std::shared_ptr<std::mutex> p_mutex_;
658
  bool block_next_cache_clearing_ = false;
659
  std::string key_suffix_;  // Key identifying current Executor
660
  bool key_attach_thread_id_ = true;
T
tensor-tang 已提交
661 662 663
};
#endif

D
dzhwinter 已提交
664 665 666 667 668
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
669
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
670 671 672
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
673 674 675 676
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
677
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
678 679 680 681 682 683
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

684 685
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
686
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
687
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
688

Y
Yang Yu 已提交
689 690 691 692 693 694 695
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

696 697
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
698 699
 private:
  static DeviceContextPool* pool;
700 701
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
702 703 704
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
705 706
}  // namespace platform
}  // namespace paddle