/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

13
#include <future>  // NOLINT
D
dzhwinter 已提交
14
#include <memory>
Y
yuyang18 已提交
15
#include <mutex>  // NOLINT
16
#include <string>
D
dzhwinter 已提交
17
#include <unordered_map>
18
#include <utility>
19
#include <vector>
Y
Yu Yang 已提交
20
#include "paddle/fluid/memory/malloc.h"
21
#ifdef PADDLE_WITH_CUDA
22
#include "paddle/fluid/platform/cuda_helper.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
G
Guo Sheng 已提交
25
#include "paddle/fluid/platform/dynload/cusolver.h"
26
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
W
Wu Yi 已提交
27
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
28
#endif
Y
Yi Wang 已提交
29
#include "paddle/fluid/platform/gpu_info.h"
Q
QI JUN 已提交
30
#endif
D
dzhwinter 已提交
31

T
tensor-tang 已提交
32
#ifdef PADDLE_WITH_MKLDNN
L
luotao1 已提交
33
#include "mkldnn.hpp"
34
#include "paddle/fluid/framework/data_layout.h"
T
tensor-tang 已提交
35 36
#endif

37 38
#include <map>
#include "glog/logging.h"
Y
Yi Wang 已提交
39 40
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
S
sneaxiy 已提交
41
#ifdef PADDLE_WITH_CUDA
42
#include "paddle/fluid/platform/stream/cuda_stream.h"
S
sneaxiy 已提交
43
#endif
Q
qijun 已提交
44
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
45

46 47 48 49
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
#endif

Q
QI JUN 已提交
50 51 52 53 54
namespace paddle {
namespace platform {

class DeviceContext {
 public:
Z
Zeng Jinle 已提交
55
  virtual ~DeviceContext() PADDLE_MAY_THROW {}
L
liaogang 已提交
56
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
57

58
  virtual void Wait() const {}
Q
QI JUN 已提交
59 60
};

Q
qijun 已提交
61 62
class CPUDeviceContext : public DeviceContext {
 public:
63
  CPUDeviceContext();
Q
qijun 已提交
64
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
65

66
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
67

L
liaogang 已提交
68
  Place GetPlace() const override;
Y
Yu Yang 已提交
69

Q
qijun 已提交
70
 private:
D
dzhwinter 已提交
71
  CPUPlace place_;
Q
qijun 已提交
72
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
73 74
};

Y
Yang Yu 已提交
75 76 77 78 79 80 81 82
/*! \brief Compile-time map from a Place type to the DeviceContext subclass
 *  that serves it, exposed as the nested alias `TYPE`. The primary template
 *  is only declared, so an unmapped Place fails to compile. */
template <typename Place>
struct DefaultDeviceContextType;

/*! \brief CPUPlace is served by CPUDeviceContext. */
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
#ifdef PADDLE_WITH_XPU
class XPUDeviceContext : public DeviceContext {
 public:
  XPUDeviceContext();
  explicit XPUDeviceContext(XPUPlace place);
  virtual ~XPUDeviceContext();
  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
  Place GetPlace() const override;
  xpu::Context* x_context() const;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

 private:
  XPUPlace place_;
  xpu::Context* context_;

  // Need to be the same with other DeviceContext,
  // Eventhough eigen_device_ is not used in XPU
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
  DISABLE_COPY_AND_ASSIGN(XPUDeviceContext);
};

template <>
struct DefaultDeviceContextType<platform::XPUPlace> {
  using TYPE = XPUDeviceContext;
};
#endif

112
#ifdef PADDLE_WITH_CUDA
113

Q
qijun 已提交
114
class EigenCudaStreamDevice;
115
class CudnnWorkspaceHandle;
S
sneaxiy 已提交
116

117 118 119 120 121
class CUDAContext {
 public:
  CUDAContext() = default;
  explicit CUDAContext(
      const CUDAPlace& place,
122
      const stream::Priority& priority = stream::Priority::kNormal);
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141

  ~CUDAContext();

  const CUDAPlace& Place() const { return place_; }

  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

  const cudaStream_t& RawStream() { return stream_->raw_stream(); }

  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }

G
Guo Sheng 已提交
142 143 144 145
  const cusolverDnHandle_t& CusolverDnHandle() const {
    return cusolver_dn_handle_;
  }

146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
    cublas_handle_->Call(std::forward<Callback>(callback));
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
#endif
    }
  }

  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
203
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
204
      PADDLE_ENFORCE_CUDA_SUCCESS(
205
          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
206 207 208 209 210
    } else {
      cudnn_handle_ = nullptr;
    }
  }

G
Guo Sheng 已提交
211 212 213 214 215 216 217
  void InitCuSolverContext() {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cusolverDnCreate(&cusolver_dn_handle_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
  }

218 219
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
220
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
221 222 223 224 225 226 227 228 229
    }
    cudnn_handle_ = nullptr;
  }

  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
  }

G
Guo Sheng 已提交
230 231 232 233 234 235 236
  void DestoryCuSolverContext() {
    if (cusolver_dn_handle_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cusolverDnDestroy(cusolver_dn_handle_));
    }
  }

237 238 239 240 241 242 243
  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
  cudnnHandle_t cudnn_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
G
Guo Sheng 已提交
244
  cusolverDnHandle_t cusolver_dn_handle_;
245 246 247
  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};

248
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
249
 public:
D
dzhwinter 已提交
250
  explicit CUDADeviceContext(CUDAPlace place);
251
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
252

253
  /*! \brief  Wait for all operations completion in the stream. */
254
  void Wait() const override;
Q
QI JUN 已提交
255

256
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
257
  Place GetPlace() const override;
258

K
Kexin Zhao 已提交
259
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
260 261
  int GetComputeCapability() const;

262 263 264
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

265 266 267 268 269 270
  /*! \brief  Return the SM count in the device context */
  int GetSMCount() const;

  /*! \brief  Return the Max thread num of block in the device context */
  int GetMaxThreadsPerBlock() const;

271 272 273
  /*! \brief  Return the max grid dim size in the device context */
  dim3 GetCUDAMaxGridDimSize() const;

274 275 276
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

277 278 279
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
280
    return context()->CublasCall(callback);
281 282 283 284 285 286 287 288 289
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
290
    return context()->TensorCoreCublasCallIfAvailable(callback);
291
  }
S
sneaxiy 已提交
292

293
  /*! \brief  Return cudnn  handle in the device context. */
294
  cudnnHandle_t cudnn_handle() const;
295

S
sneaxiy 已提交
296 297 298 299 300 301 302 303 304
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

G
Guo Sheng 已提交
305 306
  cusolverDnHandle_t cusolver_dn_handle() const;

Q
init  
qijun 已提交
307
  /*! \brief  Return cuda stream in the device context. */
308
  cudaStream_t stream() const;
Q
QI JUN 已提交
309

310
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
311 312 313 314 315
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
316
#endif
Q
qingqing01 已提交
317

Y
Yu Yang 已提交
318
  template <typename Callback>
319 320
  void RecordEvent(cudaEvent_t ev, Callback callback) const {
    return context()->Stream()->RecordEvent(ev, callback);
Y
Yu Yang 已提交
321 322
  }

S
sneaxiy 已提交
323 324
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
325 326 327 328 329
    return context()->Stream()->AddCallback(callback);
  }

  void WaitStreamCallback() const {
    return context()->Stream()->WaitCallback();
330 331
  }

332
  void ResetDefaultContext(const stream::Priority& priority) {
333 334 335
    default_ctx_.reset(new CUDAContext(place_, priority));
  }

336
  void ResetThreadContext(const stream::Priority& priority) {
337 338 339 340 341 342 343 344 345 346
    std::lock_guard<std::mutex> guard(ctx_mtx_);
    thread_ctx_[this].reset(new CUDAContext(place_, priority));
  }

  std::shared_ptr<CUDAContext> context() const {
    if (!thread_ctx_.count(this)) {
      return default_ctx_;
    }
    return thread_ctx_.at(this);
  }
S
sneaxiy 已提交
347

Q
QI JUN 已提交
348
 private:
D
dzhwinter 已提交
349
  CUDAPlace place_;
350
  std::shared_ptr<CUDAContext> default_ctx_;
Q
QI JUN 已提交
351

352 353 354 355 356 357
  // The thread_local static variable will be released before the
  // global static variable, so avoid using it in dtor.
  static thread_local std::unordered_map<const CUDADeviceContext*,
                                         std::shared_ptr<CUDAContext>>
      thread_ctx_;
  static thread_local std::mutex ctx_mtx_;
358

359 360
  mutable std::mutex cudnn_handle_mtx_;

361
#if defined(PADDLE_WITH_NCCL)
Q
qingqing01 已提交
362 363 364 365 366 367
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
368
#endif
Q
qingqing01 已提交
369

C
chengduo 已提交
370 371 372 373 374
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
375
  int max_threads_per_block_;
376
  dim3 max_grid_dim_size_;
Y
yuyang18 已提交
377

378
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
379
};
Q
qijun 已提交
380

381 382
class CudnnWorkspaceHandle {
 public:
383 384
  inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
      : device_context_(dev_ctx), mtx_(mtx) {}
385 386 387 388 389 390 391 392

  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
    if (required_workspace_bytes > WorkspaceSize()) {
      ReallocWorkspace(required_workspace_bytes);
    }
    VLOG(2) << "Cudnn workspace size at RunFunc: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
393 394 395 396
    {
      std::lock_guard<std::mutex> guard(*mtx_);
      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
    }
397 398 399 400 401 402 403 404 405 406 407 408 409
  }

  /*! \brief Thread which call RunFuncSync() would release gpu memory after
   *  running the function. Currently this function is only used when cudnn
   *  exhaustive searching and callers have to guarantee that the input function
   *  is host blocking */
  template <typename Callback>
  inline void RunFuncSync(Callback&& cudnn_func,
                          size_t required_workspace_bytes) {
    RunFunc(cudnn_func, required_workspace_bytes);
    ResetWorkspace();
  }

410
  void ReallocWorkspace(size_t required_workspace_bytes);
411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426

  inline void ResetWorkspace() { allocation_ = nullptr; }

  inline size_t WorkspaceSize() {
    if (allocation_ == nullptr) {
      return 0;
    }
    return allocation_->size();
  }

  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;

 private:
  memory::allocation::AllocationPtr allocation_;
  const CUDADeviceContext& device_context_;
427
  std::mutex* mtx_;
428 429
};

Y
Yang Yu 已提交
430 431
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
432
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
433 434
};

C
chengduoZH 已提交
435
// Currently, CUDAPinnedDeviceContext is only used for data copying.
C
chengduoZH 已提交
436 437 438 439 440 441
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
442

C
chengduoZH 已提交
443 444 445 446 447 448 449 450 451 452 453
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
454
#endif
Q
qijun 已提交
455

T
tensor-tang 已提交
456
#ifdef PADDLE_WITH_MKLDNN
457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496

/*! \brief Holder of per-thread MKL-DNN state. The class cannot be
 *  constructed or copied from outside; each thread's state is reached only
 *  through fetch(). */
class MKLDNNDeviceContextThreadLocals {
  using self = MKLDNNDeviceContextThreadLocals;

  // Per-thread MKL-DNN settings; the member functions are defined out of line.
  struct Body {
    // Current MKL-DNN session id (see the kMKLDNNSessionID_* constants).
    size_t cur_mkldnn_session_id;
    // Current data input shape string.
    // - For fixed-shape, it's a null string in default.
    // - For dynamic-shape, it's user specific.
    std::string cur_input_shape_str;
    // the cache capacity of different input shapes for MKLDNN.
    // Default 1 means fixed input shape, not dynamic shape.
    int cur_input_shape_cache_capacity;
    // Most recently registered Paddle data layout; needed when converting an
    // MKL-DNN tensor back to a non-MKL-DNN layout.
    paddle::framework::DataLayout cur_paddle_data_layout;

    Body();
    void set_cur_mkldnn_session_id(size_t sid);
    size_t get_cur_mkldnn_session_id();
    void set_cur_input_shape_str(std::string input_shape_str);
    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
    void set_cur_paddle_data_layout(framework::DataLayout dl);
    framework::DataLayout get_cur_paddle_data_layout();
  };

  MKLDNNDeviceContextThreadLocals() = default;
  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
      delete;

 public:
  // default mkldnn session id
  static constexpr size_t kMKLDNNSessionID_Default = 0;
  // mkldnn session id for cache clearing mode (the maximum size_t value)
  static constexpr size_t kMKLDNNSessionID_CacheClearing =
      static_cast<size_t>(-1);

  /*! \brief Return the calling thread's Body, constructed on first use in
   *  that thread. */
  static Body& fetch() {
    thread_local Body body;
    return body;
  }
};
S
Sylwester Fraczek 已提交
497

T
tensor-tang 已提交
498 499
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516
  template <class T>
  using BlobPtr_t = std::shared_ptr<T>;
  template <class P1, class P2>
  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
  template <class T>
  using umap_key_string_t = umap_value_smart_t<std::string, T>;

  // Following three maps are used to cache MKLDNN primitives.
  // There relations are:
  // - BlobMap = Map<cur_thread_id, ShapeBlob>
  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
  // - KeyBlob  = Map<blob_name, blob>

  using KeyBlob = umap_key_string_t<void>;
  using ShapeBlob = umap_key_string_t<KeyBlob>;
  using BlobMap = umap_value_smart_t<int, ShapeBlob>;

T
tensor-tang 已提交
517 518 519
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
520
  const mkldnn::engine& GetEngine() const { return engine_; }
T
tensor-tang 已提交
521

522 523 524
  // Remove all entries from the blob map
  void ResetBlobMap() const;

525 526 527
  // Get the ShapeBlob size in cur_mkldnn_session_id.
  size_t GetShapeBlobSize() const;

528 529
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
530

531 532
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
533

534 535 536 537
  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }

T
tensor-tang 已提交
538
 private:
539
  mkldnn::engine engine_;
540 541
  std::shared_ptr<BlobMap> p_blobmap_;
  std::shared_ptr<std::mutex> p_mutex_;
T
tensor-tang 已提交
542 543 544
};
#endif

D
dzhwinter 已提交
545 546 547 548 549
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
550
  static DeviceContextPool& Instance() {
G
GaoWei8 已提交
551 552 553
    PADDLE_ENFORCE_NOT_NULL(pool,
                            platform::errors::PreconditionNotMet(
                                "Need to Create DeviceContextPool firstly!"));
D
dzhwinter 已提交
554 555 556 557
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
558
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
559 560 561 562 563 564
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

565 566
  static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }

D
dzhwinter 已提交
567
  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
568
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
569

Y
Yang Yu 已提交
570 571 572 573 574 575 576
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

577 578
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
579 580
 private:
  static DeviceContextPool* pool;
581 582
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
583 584 585
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
586 587
}  // namespace platform
}  // namespace paddle