/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <memory>
#include <set>

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/ipu/ipu_backend.h"
#endif
#include "glog/logging.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}
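
// The overload above routes each allocation by stream: if `dev_ctx` runs on
// the default stream of its place, the ordinary place-based allocator is
// used; otherwise a stream-aware allocator keeps the block alive until the
// work queued on that stream completes. A minimal usage sketch (illustrative
// only; assumes a CUDA build and device 0):
//
//   platform::CUDAPlace place(0);
//   auto* ctx = platform::DeviceContextPool::Instance().Get(place);
//   memory::AllocationPtr buf = memory::Alloc(*ctx, /*size=*/4096);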

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
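
// The four helpers above gate TF32 math in cuBLAS/cuDNN (both default to
// allowed, taking effect only on GPUs that support TF32). A caller that
// needs full FP32 precision can opt out, e.g. (illustrative only):
//
//   platform::SetAllowTF32Cublas(false);
//   platform::SetAllowTF32Cudnn(false);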

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}

DeviceContextPool* DeviceContextPool::pool = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your Paddle is "
        "compiled with the WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or "
        "WITH_ASCEND_CL option, or check "
        "that your train process sets the correct device id if you use "
        "Executor.",
        place));
  }
  return it->second.get().get();
}
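
// Usage sketch (illustrative only): the pool is a process-wide singleton,
// and Get() returns a non-owning pointer that callers must not delete.
//
//   auto& pool = platform::DeviceContextPool::Instance();
//   auto* cpu_ctx = pool.Get(platform::CPUPlace());
//   cpu_ctx->Wait();  // synchronizes the device; effectively a no-op on CPU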

template <typename DevCtx>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // Lazy evaluation: the device context is created only at
                     // the first call to `Get`.
                     return PtrType(new DevCtx(p));
                   }));
}
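
// Because the future above is created with std::launch::deferred, the DevCtx
// constructor does not run at pool-construction time; it runs synchronously
// on the first thread that calls Get() for that place. A minimal sketch of
// the same standard-library idiom (assumes nothing beyond <future>):
//
//   auto fut = std::async(std::launch::deferred, [] { return 42; });
//   // ...nothing has executed yet...
//   int v = fut.get();  // the lambda runs here, on the calling thread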

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}
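
// Sketch of typical startup wiring (illustrative only; assumes the static
// DeviceContextPool::Init helper declared in device_context.h, which the
// framework normally invokes from InitDevices):
//
//   std::vector<platform::Place> places = {platform::CPUPlace()};
//   platform::DeviceContextPool::Init(places);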

CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext() {}

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
  int id = place.GetDeviceId();
  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();
  device_ = ipu_backend->GetDevice(id);
}

Place IPUDeviceContext::GetPlace() const { return place_; }
void IPUDeviceContext::Wait() const {
  /*! \brief Wait for completion of all operations in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
  xpu_version_ = get_xpu_version(place_.device);
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  platform::XPUDeviceGuard guard(place.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place_.device);

  context_ = xpu::create_context();
  const int MAX_XPU_NUM = 16;
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  int l3_size = 13.5 * 1024 * 1024;
  if (std::getenv("XPU_PADDLE_L3_SIZE") != nullptr) {
    l3_size = atoi(std::getenv("XPU_PADDLE_L3_SIZE"));
  }

  auto selected_xpus = GetXPUSelectedDevices();
  for (unsigned int i = 0; i < selected_xpus.size(); i++) {
    if (place.device == selected_xpus[i]) {
      if (l3ptrs[place.device] == nullptr) {
        xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                   XPU_MEM_L3);
      }
      if (l3ptrs[place.device] != nullptr) {
        context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
        VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
      }
      break;
    }
  }
}
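
// Note: the 13.5 MB default L3 size above can be overridden through the
// XPU_PADDLE_L3_SIZE environment variable (in bytes), e.g.
//
//   export XPU_PADDLE_L3_SIZE=14155776
//
// The L3 buffer is allocated at most once per selected device and is reused
// by later contexts created on the same device.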

void XPUDeviceContext::Wait() const {
  platform::SetXPUDeviceId(place_.device);
  xpu_wait(context_->xpu_stream);
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};
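
// EigenCudaStreamDevice adapts Paddle's stream and allocator to the
// Eigen::StreamInterface contract, so Eigen kernels launched through
// Eigen::GpuDevice run on Paddle's stream and draw memory from Paddle's
// allocator (InitEigenContext() below wires the two together).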

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
    InitCuSolverContext();
#endif
  }
}
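
// Note: SetStream() above tears down and rebuilds every stream-bound handle
// (Eigen, cuBLAS, cuDNN, and cuSolver on CUDA builds) so that subsequent
// library calls are issued on the externally supplied stream. Callers should
// make sure no work is still pending on the old stream before swapping.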

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetGPUComputeCapability(place_.device);
  multi_process_ = GetGPUMultiProcessors(place_.device);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(place_.device);

  driver_version_ = GetGPUDriverVersion(place_.device);
  runtime_version_ = GetGPURuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: "
                          << static_cast<int>(place_.device)
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif
  {
    // Check CUDA/CUDNN version compatibility
    auto local_cuda_version =
        (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
    auto compile_cuda_version = (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
    auto compile_cuda_version =
        (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
    if (local_cuda_version < compile_cuda_version) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << static_cast<int>(place_.device)
          << ". The installed Paddle is compiled with CUDA "
          << compile_cuda_version / 10 << "." << compile_cuda_version % 10
          << ", but CUDA runtime version in your machine is "
          << local_cuda_version / 10 << "." << local_cuda_version % 10
          << ", which may cause serious incompatibility issues. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }
  }
  default_ctx_.reset(new CUDAContext(place_));
}

CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_) {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

Place CUDADeviceContext::GetPlace() const { return place_; }

void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }

int CUDADeviceContext::GetComputeCapability() const {
  return compute_capability_;
}

int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return multi_process_ * max_threads_per_mp_;
}

int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return context()->EigenDevice().get();
}

bool CUDADeviceContext::tensor_core_available() const {
  return context()->CublasTensorCoreHandle() != nullptr;
  return context()->CublasTensorCoreHandle() != nullptr;
}

dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  return context()->CudnnHandle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  return context()->CusparseHandle()->GetCusparseHandle();
}
#endif

CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
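
// Usage sketch (illustrative only; CudnnWorkspaceHandle and its RunFunc
// template are declared in device_context.h). The handle serializes users
// of the shared cuDNN workspace via cudnn_handle_mtx_:
//
//   auto workspace = dev_ctx.cudnn_workspace_handle();
//   workspace.RunFunc(
//       [&](void* ws) { /* call a cuDNN API that needs `ws` */ },
//       required_workspace_bytes);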

#ifndef PADDLE_WITH_HIP
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
#endif

gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When Thread finish we clear oneDNN cache
// This is needed when we have one executor used by many threads
// e.g. test_analyzer_detect. Thread ID is not part of caching key
// (for naive executor), so we need to clear the cache when one thread finishes
// and another is about to start inference
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take current input shape from TLS
  // Take current executor address from TLS
  // and for this executor's items add the one defined with arguments
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
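
// BlockNextCacheClearing() acts as a one-shot latch: the next ResetBlobMap()
// call logs and returns without clearing and then re-enables clearing, which
// lets the cache survive a single executor teardown.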

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}
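
// Reading aid for SetBlob/GetBlob below: p_blobmap_ is a three-level map,
// session id -> ShapeBlob, input-shape string -> KeyBlob, and blob name ->
// cached oneDNN object. Both methods walk the levels in that order while
// holding p_mutex_.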

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register the new element in the per-executor map
    // so it can be erased easily when the executor terminates
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

// TODO(jczaja): Replace with C++20 equivalents when applicable
#ifdef _WIN32
#define likely(expr) (expr)
#define unlikely(expr) (expr)
#else
#define likely(expr) (__builtin_expect(!!(expr), 1))
#define unlikely(expr) (__builtin_expect(!!(expr), 0))
#endif

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id first
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape next
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: shape=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
}  // namespace platform
}  // namespace paddle