device_context.cc 28.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <memory>
14
#include <set>
15

16
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
17
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
18
#include "paddle/fluid/platform/cuda_device_guard.h"
19
#endif
F
fwenguang 已提交
20 21 22 23
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
24
#include "glog/logging.h"
25
#include "paddle/fluid/framework/expect.h"
26
#include "paddle/fluid/platform/profiler.h"
27

28 29 30 31 32
namespace paddle {
namespace memory {

// Allocates `size` bytes for the place that `dev_ctx` reports.
//
// Zero-sized requests, and places that need no stream handling, are forwarded
// to the place-based Alloc overload. For GPU and MLU places the stream of
// `dev_ctx` is compared with the stream of the pooled context for the same
// place: equal streams use the place-based overload, different streams go
// through the per-device-context allocator pool. Asking for a GPU, XPU or MLU
// place in a build without that backend throws PermissionDenied.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* pooled_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& requested_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (pooled_ctx->stream() == requested_ctx.stream()) {
      return Alloc(place, size);
    }
    // The caller runs on a different stream than the pooled context.
    return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
        requested_ctx, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA,"
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  }

  if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU,"
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  }

  if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* pooled_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& requested_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (pooled_ctx->stream() == requested_ctx.stream()) {
      return Alloc(place, size);
    }
    // The caller runs on a different stream than the pooled context.
    return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
        requested_ctx, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU,"
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  }

  // Every other place needs no stream-specific handling.
  return Alloc(place, size);
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
88 89 90
namespace paddle {
namespace platform {

91
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
92 93 94
// Process-wide switch read by AllowTF32Cublas(); starts enabled.
bool allow_tf32_cublas = true;

// Stores `active` as the new value of the cuBLAS TF32 switch.
void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

// Returns the current value of the cuBLAS TF32 switch.
bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}
A
AshburnLee 已提交
95 96 97 98

// Process-wide switch read by AllowTF32Cudnn(); starts enabled.
bool allow_tf32_cudnn = true;

// Stores `active` as the new value of the cuDNN TF32 switch.
void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

// Returns the current value of the cuDNN TF32 switch.
bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
99 100
#endif  // PADDLE_WITH_CUDA

101 102 103 104 105 106 107
// Maps a Place onto the matching platform::DeviceType.
// CPU, GPU, XPU and MLU places are recognised; any other place throws
// Unavailable.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  }
  if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  }
  if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  }
  if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  }
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
116 117
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
118
// Returns the device context registered for `place`.
// The pool keeps ownership; the returned pointer is non-owning.
// Throws Unimplemented when no entry exists for `place`.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto entry = device_contexts_.find(place);
  if (entry != device_contexts_.end()) {
    // First get() resolves the stored future, second get() unwraps the
    // owning pointer it holds.
    const auto& owned_ctx = entry->second.get();
    return owned_ctx.get();
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Place %s is not supported. Please check that your paddle compiles "
      "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
      "or check "
      "that your train process set the correct device id if you use "
      "Executor.",
      place));
}

W
Wilber 已提交
133
template <typename DevCtx>
134 135 136 137 138 139 140 141
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
142
                     return PtrType(new DevCtx(p));
143
                   }));
C
chengduozh 已提交
144 145
}

D
dzhwinter 已提交
146 147
// Builds the pool from `places`.
// `places` must be non-empty (InvalidArgument otherwise). Duplicates are
// removed, then one context entry is registered per distinct place through
// EmplaceDeviceContext. A place whose backend was not compiled in throws
// Unimplemented; a place matching none of the branches is skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // De-duplicate the requested places.
  std::set<Place> unique_places(places.begin(), places.end());

  for (auto& target : unique_places) {
    if (platform::is_cpu_place(target)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, target);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, target);
#endif
    } else if (platform::is_gpu_place(target)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(target)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(target)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(target)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(target)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(target)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(target)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, target);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

W
Wilber 已提交
225
CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {}
226

W
Wilber 已提交
227
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext() {}
228

J
jianghaicheng 已提交
229
#ifdef PADDLE_WITH_IPU
A
Allen Guo 已提交
230
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}
J
jianghaicheng 已提交
231 232

Place IPUDeviceContext::GetPlace() const { return place_; }
A
Allen Guo 已提交
233

J
jianghaicheng 已提交
234 235 236 237 238 239 240
void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
241
#ifdef PADDLE_WITH_XPU
W
Wilber 已提交
242
XPUDeviceContext::XPUDeviceContext() : pten::XPUContext() {}
243

244
XPUDeviceContext::~XPUDeviceContext() {}
245

W
Wilber 已提交
246
// Constructs the context for `place` and logs the device index once per
// process (LOG_FIRST_N with N = 1).
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : pten::XPUContext(place) {
  const int device_index = static_cast<int>(place.device);
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << device_index;
}
#endif

252 253 254 255 256 257 258
#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
259
  platform::GetCurrentNPUContext(&context_);
260 261 262 263 264 265 266
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}
267

268
void NPUDeviceContext::Wait() const {
269 270 271
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
272 273 274 275 276 277 278
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }
279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const { return place_; }

295 296 297
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
298 299 300 301 302 303 304
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

305
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
306 307 308 309 310
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

311
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
312

313 314 315
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
316
  const cudaDeviceProp& deviceProperties() const override {
317
#endif
Q
init  
qijun 已提交
318 319 320 321
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
322 323 324
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
325 326 327
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
328
    void* retv = buf->ptr();
S
sneaxiy 已提交
329 330 331 332
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
333
    return retv;
Q
init  
qijun 已提交
334 335
  }

S
sneaxiy 已提交
336 337 338 339 340 341
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
342 343 344

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
345
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
346 347 348 349 350 351
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
352
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
353
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
354
#ifdef PADDLE_WITH_HIP
355
      PADDLE_ENFORCE_GPU_SUCCESS(
356 357
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
358
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
359
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
360
#endif
Q
init  
qijun 已提交
361 362 363 364 365
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
366
  CUDAPlace place_;
367 368 369 370
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
371
  const cudaDeviceProp* device_prop_;  // not owned;
372
#endif
Q
qijun 已提交
373
  mutable void* scratch_;
Q
init  
qijun 已提交
374
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
375
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
376
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
377 378
};

379 380 381 382 383 384 385 386 387
// Grows the workspace to at least `required_workspace_bytes`.
// Nothing happens when the current workspace is already large enough.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Release the old buffer before requesting the new one to save memory.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

388 389 390 391 392 393 394 395 396 397 398 399
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
400 401
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
402 403
  place_ = place;
  CUDADeviceGuard guard(place_.device);
404
  stream_.reset(new stream::CUDAStream(place, priority, flag));
405 406 407
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
408
#ifndef PADDLE_WITH_HIP
Z
zhangkaihuo 已提交
409
  InitCuSparseContext();
G
Guo Sheng 已提交
410
  InitCuSolverContext();
411
#endif
412 413
}

W
Wilber 已提交
414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
// Switches this context to `stream`.
// When `stream` differs from the current raw stream, the cuDNN, cuBLAS and
// (non-HIP) cuSolver contexts are destroyed, the stream is replaced, and the
// Eigen, cuBLAS, cuDNN and (non-HIP) cuSolver contexts are initialised again.
// NOTE(review): the constructor also calls InitCuSparseContext(), but the
// cuSPARSE context is not rebuilt here - confirm whether that is intended.
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() == stream) {
    return;  // already on the requested stream
  }

  CUDADeviceGuard guard(place_.device);

  // Tear down the library contexts before the stream is replaced.
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif

  stream_->SetStream(stream);

  // Re-create the contexts after the stream switch.
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

434 435 436 437
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
438
#ifndef PADDLE_WITH_HIP
Z
zhangkaihuo 已提交
439
  DestoryCuSparseContext();
G
Guo Sheng 已提交
440
  DestoryCuSolverContext();
441
#endif
442 443
}

444
// Constructs the context for `place`: caches the device properties, logs the
// device / driver / runtime / DNN-library versions once per process, warns
// when the driver-reported version is older than the toolkit version used at
// build time, and finally creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);

  // Cache the device properties.
  compute_capability_ = GetGPUComputeCapability(place_.device);
  multi_process_ = GetGPUMultiProcessors(place_.device);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(place_.device);

  driver_version_ = GetGPUDriverVersion(place_.device);
  runtime_version_ = GetGPURuntimeVersion(place_.device);

  const int device_index = static_cast<int>(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << device_index
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << device_index
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << device_index
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif

  {
    // Check CUDA/CUDNN version compatibility.
    // Both values are encoded as major * 10 + minor.
    auto machine_version =
        (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
    auto build_version = (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
    auto build_version =
        (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
    if (machine_version < build_version) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << device_index
          << ". The installed Paddle is compiled with CUDA "
          << build_version / 10 << "." << build_version % 10
          << ", but CUDA runtime version in your machine is "
          << machine_version / 10 << "." << machine_version % 10
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }
  }

  default_ctx_.reset(new CUDAContext(place_));
}

CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
505
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
506
  if (nccl_comm_) {
507
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
508 509
  }
#endif
510 511
}

L
liaogang 已提交
512
Place CUDADeviceContext::GetPlace() const { return place_; }
513

514
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
515

K
Kexin Zhao 已提交
516
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
517
  return compute_capability_;
K
Kexin Zhao 已提交
518 519
}

520
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
521
  return multi_process_ * max_threads_per_mp_;
522 523
}

524 525 526 527 528 529
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

530
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
531
  return context()->EigenDevice().get();
532 533
}

534
bool CUDADeviceContext::tensor_core_available() const {
535
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
536 537
}

538 539 540 541
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

542 543 544
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
545
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
546
#endif
547 548
  return context()->CudnnHandle();
}
549

550 551 552 553 554
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
555 556 557
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
Z
zhangkaihuo 已提交
558 559 560
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  return context()->CusparseHandle()->GetCusparseHandle();
}
561
#endif
562

S
sneaxiy 已提交
563
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
564
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
565
}
566

567
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
568 569 570
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
571
#endif
G
Guo Sheng 已提交
572

573
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
574

C
chengduoZH 已提交
575 576 577 578 579 580 581 582 583 584 585 586 587 588
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
L
Luo Tao 已提交
589
#endif
Q
qijun 已提交
590

T
tensor-tang 已提交
591 592
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
593
    : CPUDeviceContext(place), p_blobmap_() {
594
  p_blobmap_.reset(new BlobMap());
595
  p_exec_items_.reset(new ExecShape());
596
  p_mutex_.reset(new std::mutex());
T
tensor-tang 已提交
597 598
}

599
MKLDNNDeviceContextThreadLocals::Body::Body()
600
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
601 602 603 604 605 606
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

607 608 609 610 611 612 613 614 615 616 617 618
// When Thread finish we clear oneDNN cache
// This is needed when we have one executor used by many threads
// e.g. test_analyzer_detect. Thread ID is not part of caching key
// (for naive executor) so we need to clear cache when one thread finish
// and other is to start inference
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
// Looks up the CPU device context in the global pool and asks it to reset
// the blob map for this thread's executor pointer (`exec_ptr_`).
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  const auto cpu_place = paddle::platform::CPUPlace();
  auto& ctx_pool = platform::DeviceContextPool::Instance();
  auto* mkldnn_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(ctx_pool.Get(cpu_place));
  mkldnn_ctx->ResetBlobMap(exec_ptr_);
}

622 623 624 625 626 627 628 629 630 631
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
632 633
  cur_input_shape_str = input_shape_str;
}
634 635
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
636 637
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
638

639 640
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
641 642 643
  cur_paddle_data_layout = dl;
}

644 645
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
646 647 648
  return cur_paddle_data_layout;
}

649 650 651 652 653 654 655 656 657
// Logs the oneDNN library version; only the first call on this object logs,
// later calls return immediately because `said_once` is already set.
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (said_once) {
    return;
  }
  said_once = true;
  auto lib_version = dnnl::version();
  LOG(INFO) << "oneDNN v" << lib_version->major << "." << lib_version->minor
            << "." << lib_version->patch;
}

658
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
659 660 661
  return cur_engine;
}

662
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
663 664 665
  return cur_stream;
}

666
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
667 668 669
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
670 671 672 673 674 675
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
676 677 678 679 680
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
681
          (v.first)->erase(v.second);
682 683
        }
        s.second->erase(ptr);
684 685
      }
    }
686 687 688 689 690 691
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

692 693
// Removes the first entry of the per-shape executor map (`p_exec_items_`).
// Does nothing when the map is empty: erasing begin() of an empty container
// would pass a non-dereferenceable iterator, which is undefined behaviour.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

696 697
// Records that the cache entry (`pblob`, `it`) belongs to the executor that
// is current in thread-local state, filed under the current input shape.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take the current input shape from TLS; insert() leaves an existing
  // per-shape map untouched and returns its position.
  auto shape_slot = p_exec_items_->insert(std::make_pair(
      tls().cur_input_shape_str, std::make_shared<ExecMap>()));
  auto& per_executor = *shape_slot.first->second;

  // Take the current executor address from TLS and append the new entry to
  // that executor's list.
  auto& executor_entries = per_executor[tls().get_curr_exec()];
  executor_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << executor_entries.size() << "\n";
}

712 713 714 715
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
716
}
717

718
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
719
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
720
  BlobMap* pMap = p_blobmap_.get();
721
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
722
  if (map_it == pMap->end()) {
723 724 725
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
726 727 728 729
  }
  return map_it->second->size();
}

730
void MKLDNNDeviceContext::SetBlob(const std::string& name,
731
                                  BlobPtr_t<void> data) const {
732
  BlobMap* pMap = p_blobmap_.get();
733
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
734
  BlobPtr_t<KeyBlob> pBlob = nullptr;
735

736
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
737

738
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
739

740 741
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
742 743 744

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
745
    sBlob = std::make_shared<ShapeBlob>();
746 747
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
748
  } else {
749
    sBlob = map_it->second;
750
  }
T
tensor-tang 已提交
751

752
  // Find KeyBlob for current input shape
753
  auto key_it = sBlob->find(tls().cur_input_shape_str);
754

755
  if (key_it == sBlob->end()) {
756 757
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
758 759
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
760
        sBlob->size() &&
761
        (sBlob->size() >=
762
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
763 764 765 766
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
767
    }
768
    pBlob = std::make_shared<KeyBlob>();
769
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
770
  } else {
771
    pBlob = key_it->second;
772 773
  }

774
  // Find Blob via name
775 776 777 778
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
779 780 781
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
782 783 784
  } else {
    blob_it->second = data;  // set data to existing blob
  }
785
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
786
  // lock will be automatically released when out of scope
787
  return;
T
tensor-tang 已提交
788 789
}

790
// Counts every blob stored in the cache by summing the sizes of all
// innermost (per-shape) maps across all sessions.
// NOTE(review): unlike the other accessors here, this one does not take
// `p_mutex_` - confirm callers do not race with SetBlob/ResetBlobMap.
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int total = 0;
  for (auto const& session_entry : *p_blobmap_) {
    auto const& shape_blobs = *(session_entry.second);
    for (auto const& shape_entry : shape_blobs) {
      total += (shape_entry.second)->size();
    }
  }
  return total;
}

800
// Looks up the blob stored under `name` for the current session id and the
// current input shape (both read from thread-local state).
// Returns nullptr when any of the three cache levels
// (session id -> input shape -> blob name) has no matching entry.
// The lookup runs under `p_mutex_`.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    // Report the real session id and label the missing shape string
    // separately, so the log line is not mistaken for a session-id miss.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
Q
qijun 已提交
844
}  // namespace platform
Q
qijun 已提交
845
}  // namespace paddle