/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <set>

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/ipu/ipu_backend.h"
#endif
#include "glog/logging.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}
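// Usage sketch (illustrative): allocating through a DeviceContext binds the
// allocation to that context's stream; with the default-stream context this
// is equivalent to a plain Alloc(place, size).
//
//   auto* ctx = platform::DeviceContextPool::Instance().Get(place);
//   memory::AllocationPtr buf = memory::Alloc(*ctx, /*size=*/4096);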

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
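// Process-wide switches controlling whether TF32 math is allowed in cuBLAS
// and cuDNN kernels (TF32 takes effect on NVIDIA Ampere and newer GPUs).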
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}

DeviceContextPool* DeviceContextPool::pool = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle is "
        "compiled with WITH_GPU, WITH_XPU, WITH_IPU or WITH_ASCEND_CL, or "
        "that your train process sets the correct device id if you use "
        "Executor.",
        place));
  }
  return it->second.get().get();
}

template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
                   }));
}
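// Note: with std::launch::deferred above, the DeviceContext is constructed on
// the first get() of the shared_future, i.e. on the first
// DeviceContextPool::Get() for that place, not when the pool is built.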

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext, IPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}
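// Typical lookup (illustrative): the pool is constructed once with every
// place the process will use; contexts are then fetched, and lazily built,
// through the singleton accessor.
//
//   auto& pool = platform::DeviceContextPool::Instance();
//   auto* cpu_ctx = pool.Get(platform::CPUPlace());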

CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CPUDeviceContext::GetPlace() const { return place_; }

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
  int id = place.GetDeviceId();
  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();
  device_ = ipu_backend->GetDevice(id);
}

Place IPUDeviceContext::GetPlace() const { return place_; }
void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
  xpu_version_ = get_xpu_version(place_.device);
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
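  // Each device's L3 scratch buffer is allocated once per process (the
  // static l3ptrs below) and re-attached to every context created for that
  // device.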
  const int MAX_XPU_NUM = 16;
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  int l3_size = 13.5 * 1024 * 1024;
  if (std::getenv("XPU_PADDLE_L3_SIZE") != nullptr) {
    l3_size = atoi(std::getenv("XPU_PADDLE_L3_SIZE"));
  }

  auto selected_xpus = GetXPUSelectedDevices();
  for (unsigned int i = 0; i < selected_xpus.size(); i++) {
    if (place.device == selected_xpus[i]) {
      if (l3ptrs[place.device] == nullptr) {
        xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                   XPU_MEM_L3);
      }
      if (l3ptrs[place.device] != nullptr) {
        context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
        VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
      }
      break;
    }
  }

  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  xpu_wait(context_->xpu_stream);
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
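// Adapts a Paddle CUDA/HIP stream to Eigen's StreamInterface so that Eigen
// GPU kernels run on the wrapped stream and draw scratch memory from
// Paddle's allocator instead of raw device mallocs.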
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size " << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
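      // semaphore_ lives in the tail of the scratch buffer: scratchpad()
      // reserves kGpuScratchSize bytes plus one trailing unsigned int.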
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetGPUComputeCapability(place_.device);
  multi_process_ = GetGPUMultiProcessors(place_.device);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(place_.device);

  driver_version_ = GetGPUDriverVersion(place_.device);
  runtime_version_ = GetGPURuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif
  {
    // Check CUDA/cuDNN version compatibility
    auto local_cuda_version =
        (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
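    // e.g. driver_version_ == 11020 maps to 112, i.e. CUDA 11.2.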
#ifdef PADDLE_WITH_HIP
    auto compile_cuda_version = (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
    auto compile_cuda_version =
        (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
    if (local_cuda_version < compile_cuda_version) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place_.device
          << ". The installed Paddle is compiled with CUDA "
          << compile_cuda_version / 10 << "." << compile_cuda_version % 10
          << ", but CUDA runtime version in your machine is "
          << local_cuda_version / 10 << "." << local_cuda_version % 10
          << ", which may cause serious incompatibility issues. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }
  }
  default_ctx_.reset(new CUDAContext(place_));
}

CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_) {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

Place CUDADeviceContext::GetPlace() const { return place_; }

void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }

int CUDADeviceContext::GetComputeCapability() const {
  return compute_capability_;
}

int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return multi_process_ * max_threads_per_mp_;
}

int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return context()->EigenDevice().get();
}

bool CUDADeviceContext::tensor_core_available() const {
  return context()->CublasTensorCoreHandle() != nullptr;
}

dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  return context()->CudnnHandle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#endif

CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
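// The handle is constructed with cudnn_handle_mtx_ so that concurrent
// workspace use on this context can be serialized by the handle.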

#ifndef PADDLE_WITH_HIP
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
#endif

gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}
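// The blob cache owned by this context is a three-level map:
//   BlobMap:   session id         -> ShapeBlob
//   ShapeBlob: input shape string -> KeyBlob
//   KeyBlob:   object name        -> cached oneDNN object
// SetBlob()/GetBlob() below walk this hierarchy under p_mutex_.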

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When Thread finish we clear oneDNN cache
// This is needed when we have one executor used by many threads
// e.g. test_analyzer_detect. Thread ID is not part of caching key
// (for naive executor) so we need to clear cache when one thread finish
// and other is to start inference
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take current input shape from TLS
  // Take current executor address from TLS
  // and for this executor's items add the one defined with arguments
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
762 763 764

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    VLOG(2) << "GetBlob: input_shape_str=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
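// Usage sketch (illustrative; "key" is a caller-chosen cache key):
//   dev_ctx->SetBlob(key, primitive);   // cache a oneDNN object
//   auto blob = dev_ctx->GetBlob(key);  // returns nullptr on a cache miss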

#endif
}  // namespace platform
}  // namespace paddle