/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_context.h"
#include <set>
#include <utility>

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "glog/logging.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

// Allocates `size` bytes on the place owned by `dev_ctx`.
//
// Zero-byte requests, XPU places (when built with XPU) and every place that
// is neither GPU nor XPU go straight to the place-based Alloc overload.
// On GPU places the request is compared against the pool's default context
// for that place: if both use the same stream the place-based overload is
// used, otherwise the request is routed to CUDADeviceContextAllocatorPool
// together with the desired context.
//
// Throws PermissionDenied when a GPU/XPU place is requested but Paddle was
// built without the corresponding backend.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      // Non-default stream: let the per-context allocator pool serve it.
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    // Note the trailing space: adjacent literals are concatenated verbatim.
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA, "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU, "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Process-wide TF32 switches for cuBLAS and cuDNN. Both start enabled and are
// only read/written through the accessors below.
bool allow_tf32_cublas = true;

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

bool allow_tf32_cudnn = true;

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

// Maps a Place to the coarse DeviceType enum. Only CPU, GPU (CUDA) and XPU
// places are recognised; any other place raises Unavailable.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) return platform::DeviceType::CPU;
  if (platform::is_gpu_place(place)) return platform::DeviceType::CUDA;
  if (platform::is_xpu_place(place)) return platform::DeviceType::XPU;
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
93
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
94
  VLOG(4) << "DeviceContextPool Get: " << place;
D
dzhwinter 已提交
95 96
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
97 98
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
99 100
        "with WITH_GPU, WITH_XPU or WITH_ASCEND_CL option or check that "
        "your train process set the correct device id if you use Executor.",
G
GaoWei8 已提交
101
        place));
D
dzhwinter 已提交
102
  }
103
  return it->second.get().get();
D
dzhwinter 已提交
104 105
}

106 107 108 109 110 111 112 113 114
// Registers a DevCtx for place `p` in `map_ptr` without constructing it yet:
// the factory runs under std::launch::deferred, i.e. only when the stored
// shared_future is first waited on (the first DeviceContextPool::Get).
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  auto make_ctx = [p]() -> std::unique_ptr<DeviceContext> {
    return std::unique_ptr<DeviceContext>(
        new DevCtx(BOOST_GET_CONST(PlaceType, p)));
  };
  map_ptr->emplace(p, std::async(std::launch::deferred, make_ctx));
}

// Registers one lazily-constructed device context per distinct place in
// `places` (duplicates are collapsed). CPU places get the MKLDNN context when
// Paddle is built with MKLDNN. A place whose backend was not compiled in
// raises Unimplemented; place kinds not handled below are skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // std::set removes duplicates and iterates in a deterministic order.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Report the place kind that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL option."));
#endif
    }
  }
}

CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
188
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
189 190 191 192 193 194 195
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
196
Place CPUDeviceContext::GetPlace() const { return place_; }
197

198 199 200
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() { context_ = xpu::create_context(); }

XPUDeviceContext::~XPUDeviceContext() {}

// Creates an XPU runtime context for `place`. The current XPU device is
// switched to place.device while the context is created and restored to the
// previous device before returning. If place.device is one of the selected
// XPUs, an L3 buffer for that device (allocated on first use and reused by
// later contexts on the same device) is attached to the new context.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
  const int MAX_XPU_NUM = 16;
  const int l3_size = 13.5 * 1024 * 1024;
  // One slot per device id, shared by every context created in this process.
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  auto selected_xpus = GetXPUSelectedDevices();
  for (unsigned int i = 0; i < selected_xpus.size(); i++) {
    if (place.device == selected_xpus[i]) {
      // place.device indexes the fixed-size l3ptrs table; skip the optional
      // L3 setup instead of reading/writing past the end of the array.
      if (place.device < 0 || place.device >= MAX_XPU_NUM) {
        LOG(WARNING) << "xpu place " << place.device << " is out of range [0, "
                     << MAX_XPU_NUM << "), skip setting l3 cache";
      } else {
        if (l3ptrs[place.device] == nullptr) {
          xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                     XPU_MEM_L3);
        }
        // xpu_malloc's result is not checked; a null slot just means no L3.
        if (l3ptrs[place.device] != nullptr) {
          context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
          VLOG(3) << "xpu place " << place.device << " set l3 size "
                  << l3_size;
        }
      }
      break;
    }
  }

  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

// Makes this context's device current and waits on its XPU stream.
void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  xpu_wait(context_->xpu_stream);
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
// Adopts the ACL context that is current for place.device and creates the
// NPU stream used by this device context.
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtGetCurrentContext(&context_));
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // context_ was obtained with aclrtGetCurrentContext rather than created by
  // this object, so (presumably on purpose) it is not destroyed here:
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

// Waits on this context's NPU stream inside a profiler event.
void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const {
  return stream_->raw_stream();
}

Place NPUDeviceContext::GetPlace() const {
  return place_;
}

aclrtContext NPUDeviceContext::context() const {
  return context_;
}

// Pinned-host NPU memory is handled with a plain synchronous Eigen device.
NPUPinnedDeviceContext::NPUPinnedDeviceContext()
    : NPUPinnedDeviceContext(NPUPinnedPlace()) {}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const {
  return place_;
}

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
309 310 311 312 313 314 315
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

316
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
317 318 319 320 321
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

322
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
323

324 325 326
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
327
  const cudaDeviceProp& deviceProperties() const override {
328
#endif
Q
init  
qijun 已提交
329 330 331 332
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
333 334 335
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
336 337 338
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
339
    void* retv = buf->ptr();
S
sneaxiy 已提交
340 341 342 343
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
344
    return retv;
Q
init  
qijun 已提交
345 346
  }

S
sneaxiy 已提交
347 348 349 350 351 352
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
353 354 355

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
356
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
357 358 359 360 361 362
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
363
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
364
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
365 366 367 368
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
369
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
370
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
371
#endif
Q
init  
qijun 已提交
372 373 374 375 376
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
377
  CUDAPlace place_;
378 379 380 381
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
382
  const cudaDeviceProp* device_prop_;  // not owned;
383
#endif
Q
qijun 已提交
384
  mutable void* scratch_;
Q
init  
qijun 已提交
385
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
386
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
387
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
388 389
};

390 391 392 393 394 395 396 397 398
// Grows the cuDNN workspace to at least `required_workspace_bytes`; a request
// that already fits is a no-op, so the workspace never shrinks.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes > WorkspaceSize()) {
    // Release the old buffer before asking for the new one so both are never
    // held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

// Per-thread map from a CUDADeviceContext to its CUDAContext, and the
// per-thread mutex declared alongside it.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Builds the Eigen stream adapter on this context's raw stream and place_,
// then wraps it in an Eigen::GpuDevice.
void CUDAContext::InitEigenContext() {
  auto* stream_device = new EigenCudaStreamDevice();
  eigen_stream_.reset(stream_device);
  stream_device->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(stream_device));
}

// With `place` as the current device, creates a stream of the requested
// priority and then initializes the Eigen, cuBLAS, cuDNN and (CUDA builds
// only) cuSOLVER contexts.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  // Created first: InitEigenContext binds Eigen to RawStream().
  stream_.reset(new stream::CUDAStream(place_, priority));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

// Tears the library contexts down on the owning device, cuDNN first.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif
}

// Caches the static properties of the GPU behind `place` and logs its compute
// capability plus the driver/runtime/DNN-library versions (first occurrence
// only, via LOG_FIRST_N). It then warns if the local driver version is older
// than the CUDA/HIP version Paddle was compiled with. Finally it creates the
// default CUDAContext for this device.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  const auto dev = place_.device;

  // Hardware limits, queried once and served by the getters.
  compute_capability_ = GetCUDAComputeCapability(dev);
  multi_process_ = GetCUDAMultiProcessors(dev);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(dev);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(dev);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(dev);

  driver_version_ = GetCUDADriverVersion(dev);
  runtime_version_ = GetCUDARuntimeVersion(dev);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << dev
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << dev
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << dev
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif

  // CUDA/HIP version compatibility check; both sides are encoded as
  // major * 10 + minor.
  const auto installed_ver =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  const auto built_ver = (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
  const auto built_ver =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
  if (installed_ver < built_ver) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << dev
        << ". The installed Paddle is compiled with CUDA " << built_ver / 10
        << "." << built_ver % 10
        << ", but CUDA runtime version in your machine is "
        << installed_ver / 10 << "." << installed_ver % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes this context's device current and, in NCCL/RCCL builds, destroys the
// communicator stored in nccl_comm_ if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

Place CUDADeviceContext::GetPlace() const { return place_; }
500

501
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
502

K
Kexin Zhao 已提交
503
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
504
  return compute_capability_;
K
Kexin Zhao 已提交
505 506
}

507
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
508
  return multi_process_ * max_threads_per_mp_;
509 510
}

511 512 513 514 515 516
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

517
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
518
  return context()->EigenDevice().get();
519 520
}

521
bool CUDADeviceContext::tensor_core_available() const {
522
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
523 524
}

525 526 527 528
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

529 530 531
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
532
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
533
#endif
534 535
  return context()->CudnnHandle();
}
536

537 538 539 540 541
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
542 543 544
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
545
#endif
546

S
sneaxiy 已提交
547
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
548
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
549
}
550

551
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
552 553 554
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
555
#endif
G
Guo Sheng 已提交
556

557
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
558

C
chengduoZH 已提交
559 560 561 562 563 564 565 566 567 568 569 570 571 572
// Pinned host memory is computed on with a plain synchronous Eigen device;
// the default constructor delegates with a default CUDAPinnedPlace.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : CUDAPinnedDeviceContext(CUDAPinnedPlace()) {}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
// A CPU context plus the oneDNN blob cache (p_blobmap_), the per-shape /
// per-executor index into it (p_exec_items_) and the mutex guarding them.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place) {
  p_mutex_.reset(new std::mutex());
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
}

// Thread-local oneDNN state: CPU engine 0 with a stream on it, NCHW layout,
// shape-cache capacity 1, empty shape string and the default session id.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(mkldnn::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str.clear();
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

// When a thread finishes, the oneDNN cache is reset through the CPU device
// context, passing this thread's exec_ptr_ (ResetBlobMap clears everything
// when that pointer is null). This is needed when one executor is used by
// many threads, e.g. test_analyzer_detect: the thread ID is not part of the
// caching key (for the naive executor), so the cache must be cleared when one
// thread finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
616 617
  cur_input_shape_str = input_shape_str;
}
618 619
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
620 621
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
622

623 624
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
625 626 627
  cur_paddle_data_layout = dl;
}

628 629
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
630 631 632
  return cur_paddle_data_layout;
}

633 634 635 636 637 638 639 640 641
// Logs the oneDNN version the first time it is called for this Body; every
// later call returns immediately.
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (said_once) return;
  said_once = true;
  auto dv = dnnl::version();
  LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
            << dv->patch;
}

const mkldnn::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
651 652 653
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
654 655 656 657 658 659
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
660 661 662 663 664 665 666 667
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
668 669
      }
    }
670 671 672 673 674 675
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

676 677 678 679
// Drops the per-executor bookkeeping of one input shape (the first entry of
// p_exec_items_). SetBlob calls this when it evicts a shape in
// cache-clearing mode.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  // erase(begin()) on an empty container is undefined behaviour, and nothing
  // here guarantees that p_exec_items_ still holds an entry.
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

// Records blob `it` of `pblob` under the current input shape and the current
// executor (both read from TLS), so that ResetBlobMap(executor) can later
// erase exactly the entries that executor created.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Executor map for the current shape (inserted if this shape is new).
  auto shape_it = p_exec_items_
                      ->insert(std::make_pair(tls().cur_input_shape_str,
                                              std::make_shared<ExecMap>()))
                      .first;
  auto& exec_entries = (*shape_it->second)[tls().get_curr_exec()];
  exec_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << exec_entries.size() << "\n";
}

// Makes the next ResetBlobMap() call a no-op (the flag is consumed there).
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  block_next_cache_clearing_ = true;
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
703
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
704
  BlobMap* pMap = p_blobmap_.get();
705
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
706
  if (map_it == pMap->end()) {
707 708 709
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
710 711 712 713
  }
  return map_it->second->size();
}

714
void MKLDNNDeviceContext::SetBlob(const std::string& name,
715
                                  BlobPtr_t<void> data) const {
716
  BlobMap* pMap = p_blobmap_.get();
717 718
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
719

720
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
721

722
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
723

724 725
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
726 727 728

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
729
    sBlob = std::make_shared<ShapeBlob>();
730 731
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
732
  } else {
733
    sBlob = map_it->second;
734
  }
T
tensor-tang 已提交
735

736
  // Find KeyBlob for current input shape
737
  auto key_it = sBlob->find(tls().cur_input_shape_str);
738

739
  if (key_it == sBlob->end()) {
740 741
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
742 743
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
744
        sBlob->size() &&
745
        (sBlob->size() >=
746
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
747 748 749
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
750
      RemoveShapeEntriesWithExecutor();
751
    }
752 753
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
754
  } else {
755
    pBlob = key_it->second;
756 757
  }

758 759 760
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
761 762 763 764 765
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
766 767 768
  } else {
    blob_it->second = data;  // set data to existing blob
  }
769
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
770
  // lock will be automatically released when out of scope
771
  return;
T
tensor-tang 已提交
772 773
}

774
// Total number of cached blobs across all sessions and input shapes.
// NOTE(review): unlike GetBlob/SetBlob this does not take p_mutex_.
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int total = 0;
  for (auto const& session : *p_blobmap_) {
    for (auto const& shape : *session.second) {
      total += shape.second->size();
    }
  }
  return total;
}

// Looks `name` up in the cache of the current session id and input shape
// (both read from TLS) under p_mutex_. Returns nullptr on a miss at any of
// the three levels.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Print the session id under "sid=" and the missing shape separately, so
    // this is not mistaken for a session-id miss.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
}  // namespace platform
}  // namespace paddle