device_context.cc 29.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <memory>
14
#include <set>
15

16
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
17
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
18
#include "paddle/fluid/platform/cuda_device_guard.h"
19
#endif
F
fwenguang 已提交
20 21 22 23
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
J
jianghaicheng 已提交
24 25 26
#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/ipu/ipu_backend.h"
#endif
27
#include "glog/logging.h"
28
#include "paddle/fluid/framework/expect.h"
29
#include "paddle/fluid/platform/profiler.h"
30

31 32 33 34 35
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place that `dev_ctx` is bound to.
//
// Zero-sized requests always use the place-based Alloc(place, size).
// For GPU and MLU places, if `dev_ctx` uses the same stream as the pooled
// default context for that place, the place-based overload is used as well;
// otherwise the request goes to the matching *DeviceContextAllocatorPool,
// which receives `dev_ctx` itself (presumably so the memory follows that
// context's stream — confirm in the allocator pool).
// Any other place uses the place-based overload.
//
// Throws PermissionDenied if the place's backend was not compiled in.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    // The literals are concatenated, so the sentence break must live inside
    // them (previously this read "...with CUDA,Please recompile...").
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
91 92 93
namespace paddle {
namespace platform {

94
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
95 96 97
// Process-wide switch read by AllowTF32Cublas() and written by
// SetAllowTF32Cublas(); starts out enabled.
bool allow_tf32_cublas = true;

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

// Process-wide switch read by AllowTF32Cudnn() and written by
// SetAllowTF32Cudnn(); starts out enabled.
bool allow_tf32_cudnn = true;

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
102 103
#endif  // PADDLE_WITH_CUDA

104 105 106 107 108 109 110
// Maps `place` to its platform::DeviceType. Only CPU, GPU (reported as CUDA),
// XPU and MLU places are recognized; any other place raises Unavailable.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) return platform::DeviceType::CPU;
  if (platform::is_gpu_place(place)) return platform::DeviceType::CUDA;
  if (platform::is_xpu_place(place)) return platform::DeviceType::XPU;
  if (platform::is_mlu_place(place)) return platform::DeviceType::MLU;

  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
119 120
// Static pool pointer behind DeviceContextPool; starts out null and is
// presumably populated by the pool's initialization routine (not in view).
DeviceContextPool* DeviceContextPool::pool{nullptr};

Y
Yu Yang 已提交
121
// Returns the context registered for `place`. Entries added through
// EmplaceDeviceContext hold a deferred shared_future, so the context object is
// only constructed the first time it is fetched here.
// Throws Unimplemented when no context was registered for `place`.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto found = device_contexts_.find(place);
  if (found != device_contexts_.end()) {
    const auto& lazy_ctx = found->second;
    return lazy_ctx.get().get();
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Place %s is not supported. Please check that your paddle compiles "
      "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
      "or check "
      "that your train process set the correct device id if you use "
      "Executor.",
      place));
}

W
Wilber 已提交
136
template <typename DevCtx>
137 138 139 140 141 142 143 144
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
145
                     return PtrType(new DevCtx(p));
146
                   }));
C
chengduozh 已提交
147 148
}

D
dzhwinter 已提交
149 150
// Builds the pool for `places`.
//
// Duplicate places are collapsed, and each distinct place is registered with
// a lazily-constructed context of the matching type (see
// EmplaceDeviceContext).
//
// Throws InvalidArgument when `places` is empty, and Unimplemented when a
// place's backend was not compiled in. A place that matches none of the
// branches below is skipped without being registered.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // std::set removes duplicate places so each one is registered once.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      // Name the place kind that was actually requested (was "CUDAPlace").
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

W
Wilber 已提交
228
// Both constructors only initialize the pten::CPUContext base. The CPUPlace
// argument of the second overload is accepted but unused.
CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {}

CPUDeviceContext::CPUDeviceContext(CPUPlace /*place*/) : pten::CPUContext() {}
231

J
jianghaicheng 已提交
232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
#ifdef PADDLE_WITH_IPU
// Binds the context to the IPU device whose id matches `place`, as returned
// by IpuBackend::GetInstance()->GetDevice().
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
  const int device_id = place.GetDeviceId();
  std::shared_ptr<platform::ipu::IpuBackend> backend =
      platform::ipu::IpuBackend::GetInstance();
  device_ = backend->GetDevice(device_id);
}

Place IPUDeviceContext::GetPlace() const { return place_; }

// Intended to wait for all operations in the stream to complete; currently
// it does nothing.
void IPUDeviceContext::Wait() const {}

IPUDeviceContext::~IPUDeviceContext() = default;

#endif
248
#ifdef PADDLE_WITH_XPU
Q
QingshuChen 已提交
249 250 251 252
// Default-constructed XPU context: creates a new xpu::Context and records the
// XPU version of the device held by the default-constructed place_.
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
  xpu_version_ = get_xpu_version(place_.device);
}

// NOTE(review): context_ obtained from xpu::create_context() is not released
// here; confirm whether that is intentional (process-lifetime contexts).
XPUDeviceContext::~XPUDeviceContext() = default;
255 256

// Creates an xpu::Context for `place` on that device.
//
// If `place.device` is one of the selected XPUs, an L3 buffer is attached to
// the context. The buffer is allocated at most once per device id and cached
// in a function-local static table, so later contexts for the same device
// reuse it.
//
// The L3 size defaults to 13.5 MiB and can be overridden with the
// XPU_PADDLE_L3_SIZE environment variable (parsed with atoi, in bytes).
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  platform::XPUDeviceGuard guard(place.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place_.device);

  context_ = xpu::create_context();
  // The default constructor records the XPU version too; do the same here so
  // xpu_version_ is set on this path as well.
  xpu_version_ = get_xpu_version(place_.device);

  const int MAX_XPU_NUM = 16;
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  int l3_size = static_cast<int>(13.5 * 1024 * 1024);
  const char* l3_size_env = std::getenv("XPU_PADDLE_L3_SIZE");
  if (l3_size_env != nullptr) {
    l3_size = atoi(l3_size_env);
  }

  auto selected_xpus = GetXPUSelectedDevices();
  for (auto selected : selected_xpus) {
    if (place.device != selected) {
      continue;
    }
    // l3ptrs is indexed by device id; refuse ids outside the table instead
    // of reading/writing past its end.
    if (place.device < 0 || place.device >= MAX_XPU_NUM) {
      LOG(WARNING) << "xpu place " << place.device
                   << " is out of the L3 cache table range [0, " << MAX_XPU_NUM
                   << "), skip setting l3 for it.";
      break;
    }
    void*& l3ptr = l3ptrs[place.device];
    if (l3ptr == nullptr) {
      // Best effort: on failure l3ptr stays null and L3 is simply not set.
      xpu_malloc(&l3ptr, l3_size, XPU_MEM_L3);
    }
    if (l3ptr != nullptr) {
      context_->_l3_mgr.set(l3ptr, l3_size);
      VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
    }
    break;
  }
}

void XPUDeviceContext::Wait() const {
288
  platform::SetXPUDeviceId(place_.device);
289
  xpu_wait(context_->xpu_stream);
290 291 292 293 294 295 296
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

297 298 299 300 301 302 303
#ifdef PADDLE_WITH_ASCEND_CL
// Binds the context to `place`'s NPU.
//
// No ACL context is created explicitly. After aclrtSetDevice, ACL creates a
// default context containing one default stream and one sync stream; this
// constructor just fetches the current one and creates a dedicated NPUStream.
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // Explicit creation is intentionally disabled:
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

// context_ was not created here, so it is not destroyed here either.
// Explicit teardown stays disabled:
//   NPUDeviceGuard guard(place_.device);
//   PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
NPUDeviceContext::~NPUDeviceContext() {}
312

313
void NPUDeviceContext::Wait() const {
314 315 316
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
317 318 319 320 321 322 323
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }
324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339

// Context for NPUPinnedPlace. It owns an Eigen::DefaultDevice; the default
// constructor leaves place_ default-constructed.
NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const {
  return place_;
}

340 341 342
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
343 344 345 346 347 348 349
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

350
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
351 352 353 354 355
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

356
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
357

358 359 360
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
361
  const cudaDeviceProp& deviceProperties() const override {
362
#endif
Q
init  
qijun 已提交
363 364 365 366
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
367 368 369
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
370 371 372
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
373
    void* retv = buf->ptr();
S
sneaxiy 已提交
374 375 376 377
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
378
    return retv;
Q
init  
qijun 已提交
379 380
  }

S
sneaxiy 已提交
381 382 383 384 385 386
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
387 388 389

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
390
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
391 392 393 394 395 396
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
397
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
398
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
399
#ifdef PADDLE_WITH_HIP
400
      PADDLE_ENFORCE_GPU_SUCCESS(
401 402
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
403
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
404
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
405
#endif
Q
init  
qijun 已提交
406 407 408 409 410
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
411
  CUDAPlace place_;
412 413 414 415
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
416
  const cudaDeviceProp* device_prop_;  // not owned;
417
#endif
Q
qijun 已提交
418
  mutable void* scratch_;
Q
init  
qijun 已提交
419
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
420
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
421
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
422 423
};

424 425 426 427 428 429 430 431 432
// Grows the workspace to at least `required_workspace_bytes`; never shrinks.
// The old allocation is released before the new one is requested so both are
// never held at the same time.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

433 434 435 436 437 438 439 440 441 442 443 444
// Per-thread map from a CUDADeviceContext to a shared CUDAContext, together
// with the per-thread mutex declared next to it.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Rebuilds the Eigen stream adapter on this context's raw stream and place_,
// then rebuilds the Eigen::GpuDevice on top of that adapter.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  auto* gpu_device = new Eigen::GpuDevice(eigen_stream_.get());
  eigen_device_.reset(gpu_device);
}

// Creates a stream on `place` with the given priority and flag, then builds
// the Eigen, cuBLAS and cuDNN state on it; non-HIP builds also create the
// cuSPARSE and cuSOLVER handles. Everything runs with `place` as the current
// device.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag)
    : place_(place) {
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478
// Rebinds this context to an externally supplied raw stream.
//
// If `stream` differs from the current one, every library handle is
// destroyed, the underlying CUDAStream is switched, and the handles plus the
// Eigen device are rebuilt. All of this happens with place_ as the current
// device. Passing the current stream is a no-op.
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
    // cuSPARSE is created in the constructor and destroyed in the destructor
    // alongside the other handles, so recreate it here too rather than
    // leaving it on the previous stream (presumably it is bound to the stream
    // like the cuBLAS/cuSOLVER handles).
    DestoryCuSparseContext();
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
    InitCuSparseContext();
    InitCuSolverContext();
#endif
  }
}

479 480 481 482
// Releases the library handles with place_ as the current device. The order
// is cuDNN, then cuBLAS, then (non-HIP builds only) cuSPARSE and cuSOLVER.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard device_guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

489
// Builds the GPU device context for `place`. In order, it:
//   1. caches the device limits and the driver/runtime versions;
//   2. logs (LOG_FIRST_N, once per call site) the device info and the
//      cuDNN or MIOpen version;
//   3. warns if the CUDA version derived from the driver is older than the
//      one Paddle was compiled with;
//   4. creates the default CUDAContext for this device.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);

  // Device limits and versions, all queried for place_.device.
  compute_capability_ = GetGPUComputeCapability(place_.device);
  multi_process_ = GetGPUMultiProcessors(place_.device);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(place_.device);
  driver_version_ = GetGPUDriverVersion(place_.device);
  runtime_version_ = GetGPURuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: "
                          << static_cast<int>(place_.device)
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;

#ifdef PADDLE_WITH_HIP
  size_t miopen_major, miopen_minor, miopen_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(
      dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
                          << ", MIOpen Version: " << miopen_major << "."
                          << miopen_minor << "." << miopen_patch;
#else
  const size_t cudnn_version = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
                          << ", cuDNN Version: " << cudnn_version / 1000 << "."
                          << (cudnn_version % 1000) / 100 << ".";
#endif

  // CUDA/CUDNN version compatibility check, both versions encoded as
  // major * 10 + minor.
  const auto driver_cuda_version =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  const auto built_cuda_version = (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
  const auto built_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
  if (driver_cuda_version < built_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << static_cast<int>(place_.device)
        << ". The installed Paddle is compiled with CUDA "
        << built_cuda_version / 10 << "." << built_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << driver_cuda_version / 10 << "." << driver_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Switches to place_'s device and destroys the NCCL/RCCL communicator if one
// was set.
//
// A destroy failure is logged rather than thrown. Destructors are normally
// implicitly noexcept, so an exception escaping here (as
// PADDLE_ENFORCE_GPU_SUCCESS would raise) would call std::terminate.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_) {
    const auto status = dynload::ncclCommDestroy(nccl_comm_);
    if (status != ncclSuccess) {
      LOG(WARNING) << "ncclCommDestroy failed on device "
                   << static_cast<int>(place_.device) << " with error code "
                   << static_cast<int>(status);
    }
  }
#endif
}

L
liaogang 已提交
557
Place CUDADeviceContext::GetPlace() const { return place_; }
558

559
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
560

K
Kexin Zhao 已提交
561
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
562
  return compute_capability_;
K
Kexin Zhao 已提交
563 564
}

565
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
566
  return multi_process_ * max_threads_per_mp_;
567 568
}

569 570 571 572 573 574
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

575
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
576
  return context()->EigenDevice().get();
577 578
}

579
bool CUDADeviceContext::tensor_core_available() const {
580
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
581 582
}

583 584 585 586
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

587 588 589
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
590
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
591
#endif
592 593
  return context()->CudnnHandle();
}
594

595 596 597 598 599
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
600 601 602
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
Z
zhangkaihuo 已提交
603 604 605
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  return context()->CusparseHandle()->GetCusparseHandle();
}
606
#endif
607

S
sneaxiy 已提交
608
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
609
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
610
}
611

612
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
613 614 615
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
616
#endif
G
Guo Sheng 已提交
617

618
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
619

C
chengduoZH 已提交
620 621 622 623 624 625 626 627 628 629 630 631 632 633
// Context for CUDAPinnedPlace. It owns an Eigen::DefaultDevice; the default
// constructor leaves place_ default-constructed.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
634
#endif
Q
qijun 已提交
635

T
tensor-tang 已提交
636 637
#ifdef PADDLE_WITH_MKLDNN
// CPU context plus the oneDNN blob cache. It creates the blob map, the
// per-shape executor bookkeeping (ExecShape) and the mutex that guards them.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_(new BlobMap()) {
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

644
// Initial state:
//   - a CPU engine (index 0) with a stream on it;
//   - the default session id and an empty input-shape key;
//   - a shape-cache capacity of 1;
//   - NCHW as the Paddle data layout.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str = "";
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

652 653 654 655 656 657 658 659 660 661 662 663
// When a thread finishes, we clear the oneDNN cache for its executor
// (exec_ptr_; a null pointer clears the whole cache, see ResetBlobMap).
// This is needed when one executor is used by many threads, e.g.
// test_analyzer_detect: the thread id is not part of the caching key (for the
// naive executor), so the cache must be cleared when one thread finishes and
// another starts inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // The pool registers an MKLDNNDeviceContext for CPU places when built with
  // MKLDNN; use a checked-at-compile-time downcast instead of a C-style cast.
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

667 668 669 670 671 672 673 674 675 676
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
677 678
  cur_input_shape_str = input_shape_str;
}
679 680
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
681 682
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
683

684 685
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
686 687 688
  cur_paddle_data_layout = dl;
}

689 690
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
691 692 693
  return cur_paddle_data_layout;
}

694 695 696 697 698 699 700 701 702
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

703
// Accessors for the engine and stream built in Body's constructor.
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  const dnnl::engine& engine = cur_engine;
  return engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  dnnl::stream& stream = cur_stream;
  return stream;
}

711
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
712 713 714
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
715 716 717 718 719 720
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
721 722 723 724 725
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
726
          (v.first)->erase(v.second);
727 728
        }
        s.second->erase(ptr);
729 730
      }
    }
731 732 733 734 735 736
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

737 738
// Drops the executor bookkeeping for the first input shape in p_exec_items_.
// SetBlob() calls this when it evicts a shape in cache-clearing mode.
// It does nothing when there is no bookkeeping left: erase(begin()) on an
// empty container is undefined behavior. Callers in this file invoke it
// while holding p_mutex_.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

741 742
// Records (pblob, it) under the current input shape and the current executor
// address, both taken from TLS. ResetBlobMap() can later erase exactly that
// entry when that executor is reset. Called from SetBlob() with p_mutex_ held.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  auto shape_slot = p_exec_items_->insert(
      std::make_pair(tls().cur_input_shape_str, std::make_shared<ExecMap>()));
  auto& exec_map = *shape_slot.first->second;
  auto& exec_entries = exec_map[tls().get_curr_exec()];
  exec_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << exec_entries.size() << "\n";
}

757 758 759 760
// Makes the next ResetBlobMap() call skip clearing. The flag is one-shot:
// ResetBlobMap() resets it.
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> guard(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
762

763
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
764
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
765
  BlobMap* pMap = p_blobmap_.get();
766
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
767
  if (map_it == pMap->end()) {
768 769 770
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
771 772 773 774
  }
  return map_it->second->size();
}

775
void MKLDNNDeviceContext::SetBlob(const std::string& name,
776
                                  BlobPtr_t<void> data) const {
777
  BlobMap* pMap = p_blobmap_.get();
778
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
779
  BlobPtr_t<KeyBlob> pBlob = nullptr;
780

781
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
782

783
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
784

785 786
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
787 788 789

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
790
    sBlob = std::make_shared<ShapeBlob>();
791 792
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
793
  } else {
794
    sBlob = map_it->second;
795
  }
T
tensor-tang 已提交
796

797
  // Find KeyBlob for current input shape
798
  auto key_it = sBlob->find(tls().cur_input_shape_str);
799

800
  if (key_it == sBlob->end()) {
801 802
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
803 804
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
805
        sBlob->size() &&
806
        (sBlob->size() >=
807
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
808 809 810 811
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
812
    }
813
    pBlob = std::make_shared<KeyBlob>();
814
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
815
  } else {
816
    pBlob = key_it->second;
817 818
  }

819
  // Find Blob via name
820 821 822 823
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
824 825 826
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
827 828 829
  } else {
    blob_it->second = data;  // set data to existing blob
  }
830
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
831
  // lock will be automatically released when out of scope
832
  return;
T
tensor-tang 已提交
833 834
}

835
// Counts cached blobs across every session and input shape.
// NOTE(review): unlike SetBlob/GetBlob this does not take p_mutex_; confirm
// callers never run it concurrently with cache mutation.
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int total = 0;
  for (const auto& session_entry : *p_blobmap_) {
    for (const auto& shape_entry : *session_entry.second) {
      total += shape_entry.second->size();
    }
  }
  return total;
}

845
// Looks up the blob `name` for the current session id and input shape (both
// read from TLS), under p_mutex_. Returns nullptr on any miss: unknown
// session, unknown shape, or unknown name.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  // References into the maps are safe while p_mutex_ is held and avoid
  // shared_ptr refcount traffic on this hot path.
  const auto& sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    // Report the session id and the missing shape key under their own labels
    // (the shape string used to be printed as "sid=").
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  const auto& pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
Q
qijun 已提交
889
}  // namespace platform
Q
qijun 已提交
890
}  // namespace paddle