/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <set>

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif

#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/ipu/ipu_backend.h"
#endif

#include "glog/logging.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle
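
// Illustrative usage sketch (an editorial addition, not part of the original
// file): given a device context whose stream differs from the default one,
// Alloc() above routes through the stream-aware allocator pool:
//
//   auto* ctx = static_cast<platform::CUDADeviceContext*>(
//       platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
//   auto buf = paddle::memory::Alloc(*ctx, 1024);  // returns AllocationPtr
//
// The construction details are assumptions here; the key point is that the
// plain per-place path is taken only when the context's stream matches the
// default context's stream for that place.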

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
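
// Usage note (editorial sketch): these flags gate TensorFloat-32 math on
// Ampere-class GPUs for the whole process, e.g.
//
//   platform::SetAllowTF32Cublas(false);  // force full-precision cuBLAS math
//   bool tf32_on = platform::AllowTF32Cublas();  // now false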

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}
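
// For example (editorial note): Place2DeviceType(platform::CUDAPlace(0))
// returns platform::DeviceType::CUDA, while any place outside the branches
// above throws errors::Unavailable.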

DeviceContextPool* DeviceContextPool::pool = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with the WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL "
        "option, and that your train process sets the correct device id if "
        "you use Executor.",
        place));
  }
  return it->second.get().get();
}
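
// Typical lookup (editorial sketch): the pool is a process-wide singleton, so
// callers fetch contexts by place rather than constructing them directly:
//
//   auto& pool = platform::DeviceContextPool::Instance();
//   auto* cpu_ctx = pool.Get(platform::CPUPlace());  // created on first use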

template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // Lazy evaluation: only create the device context on the
                     // first `Get`.
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
                   }));
}
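
// Editorial note on the deferred std::async above: with std::launch::deferred
// no thread is spawned; the lambda runs lazily inside the first
// shared_future::get() (i.e. the first DeviceContextPool::Get() for that
// place), and the result is then cached by the shared_future. A minimal
// standalone illustration of the same idiom:
//
//   std::shared_future<int> f =
//       std::async(std::launch::deferred, [] { return heavy_init(); });
//   int v = f.get();  // heavy_init() runs here, exactly once
//
// (heavy_init is a hypothetical placeholder.)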

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext, MLUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext, IPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL option."));
#endif
    }
  }
}

CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CPUDeviceContext::GetPlace() const { return place_; }

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
  int id = place.GetDeviceId();
  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();
  device_ = ipu_backend->GetDevice(id);
}

Place IPUDeviceContext::GetPlace() const { return place_; }
void IPUDeviceContext::Wait() const {
  /*! \brief Wait for the completion of all operations on the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
  xpu_version_ = get_xpu_version(place_.device);
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  platform::XPUDeviceGuard guard(place.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
  const int MAX_XPU_NUM = 16;
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  int l3_size = 13.5 * 1024 * 1024;
  if (std::getenv("XPU_PADDLE_L3_SIZE") != nullptr) {
    l3_size = atoi(std::getenv("XPU_PADDLE_L3_SIZE"));
  }
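  // Editorial note: the default above is 13.5 MB (14155776 bytes); it can be
  // overridden at runtime with, e.g., `export XPU_PADDLE_L3_SIZE=14155776`.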

  auto selected_xpus = GetXPUSelectedDevices();
  for (unsigned int i = 0; i < selected_xpus.size(); i++) {
    if (place.device == selected_xpus[i]) {
      if (l3ptrs[place.device] == nullptr) {
        xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                   XPU_MEM_L3);
      }
      if (l3ptrs[place.device] != nullptr) {
        context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
        VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
      }
      break;
    }
  }
}

void XPUDeviceContext::Wait() const {
  platform::SetXPUDeviceId(place_.device);
  xpu_wait(context_->xpu_stream);
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, there is no need to create a context explicitly;
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size " << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};
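
// Editorial note: Eigen::GpuDevice drives this StreamInterface. allocate()/
// deallocate() back temporary buffers for Eigen expressions, while
// scratchpad()/semaphore() serve Eigen's internal reductions; holding each
// buffer in allocations_ until deallocate() keeps it alive while the stream
// may still be using it asynchronously.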

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
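
// Illustrative usage (editorial sketch): callers normally go through
// CudnnWorkspaceHandle::RunFunc so the workspace grows on demand, e.g.
//
//   auto workspace = dev_ctx.cudnn_workspace_handle();
//   workspace.RunFunc(
//       [&](void* ws) { /* launch a cuDNN call that consumes ws */ },
//       required_workspace_bytes);
//
// RunFunc is assumed from the accompanying header; the lambda body and size
// are placeholders.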

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
    InitCuSolverContext();
#endif
  }
}
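
// Editorial note: rebinding the raw stream tears down and re-creates every
// stream-bound handle (Eigen device, cuBLAS, cuDNN, and cuSolver on CUDA),
// since each handle was created against the previous stream.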

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetGPUComputeCapability(place_.device);
  multi_process_ = GetGPUMultiProcessors(place_.device);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(place_.device);

  driver_version_ = GetGPUDriverVersion(place_.device);
  runtime_version_ = GetGPURuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif
  {
    // Check CUDA/CUDNN version compatibility
    auto local_cuda_version =
        (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
    auto compile_cuda_version = (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
    auto compile_cuda_version =
        (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
    if (local_cuda_version < compile_cuda_version) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place_.device
          << ". The installed Paddle is compiled with CUDA "
          << compile_cuda_version / 10 << "." << compile_cuda_version % 10
          << ", but CUDA runtime version in your machine is "
          << local_cuda_version / 10 << "." << local_cuda_version % 10
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }
  }
  default_ctx_.reset(new CUDAContext(place_));
}
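
// Editorial note: the check above folds major/minor versions into one
// integer, e.g. a driver reporting 11020 becomes 11 * 10 + 2 = 112 (CUDA
// 11.2), so a plain `<` comparison detects a runtime older than the build
// toolkit.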

CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_) {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

Place CUDADeviceContext::GetPlace() const { return place_; }

void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }

int CUDADeviceContext::GetComputeCapability() const {
  return compute_capability_;
}

int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return multi_process_ * max_threads_per_mp_;
}

int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return context()->EigenDevice().get();
}

bool CUDADeviceContext::tensor_core_available() const {
  return context()->CublasTensorCoreHandle() != nullptr;
}

dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  return context()->CudnnHandle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}

cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  return context()->CusparseHandle()->GetCusparseHandle();
}
#endif

CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}

#ifndef PADDLE_WITH_HIP
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
#endif

gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When a thread finishes, we clear the oneDNN cache.
// This is needed when one executor is used by many threads,
// e.g. test_analyzer_detect. The thread ID is not part of the caching key
// (for the naive executor), so we need to clear the cache when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take the current input shape from TLS.
  // Take the current executor address from TLS
  // and, for this executor's items, add the one defined by the arguments.
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}
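
// Editorial note on the cache layout (derived from the code below): BlobMap
// is a three-level map, session id -> input shape string -> blob name ->
// blob. SetBlob creates the chain under p_mutex_ as needed; GetBlob only
// walks it.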

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // the max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register the new element in the per-executor map
    // so it can easily be erased when the executor terminates
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

// TODO(jczaja): Replace with C++20 equivalents when applicable
#ifdef _WIN32
#define likely(expr) (expr)
#define unlikely(expr) (expr)
#else
#define likely(expr) (__builtin_expect(!!(expr), 1))
#define unlikely(expr) (__builtin_expect(!!(expr), 0))
#endif

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After the first iteration of a model's execution we
  // should have all elements cached (mostly), so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: shape=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
}  // namespace platform
}  // namespace paddle