/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14 15 16 17 18 19
#include <utility>
#ifdef _WIN32
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
20

21
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
22
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
23
#include "paddle/fluid/platform/cuda_device_guard.h"
24
#endif
25
#include "glog/logging.h"
26
#include "paddle/fluid/platform/profiler.h"
27

28 29 30 31 32
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// For a GPU place, when `dev_ctx` runs on a stream different from the one of
// the pooled (default) context for the same place, the request is served by
// CUDADeviceContextAllocatorPool for that context; otherwise, and for
// zero-size requests, XPU places and all other places, it falls through to
// Alloc(place, size).
// Throws PermissionDenied for GPU/XPU places when the corresponding backend
// was not compiled in.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();

  // Nothing to bind to a stream for an empty request.
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      // A non-default stream needs a stream-aware allocator.
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
71 72 73
namespace paddle {
namespace platform {

74
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Process-wide TF32 switches, both enabled by default. They are plain
// globals with no synchronization; presumably read by the cuBLAS / cuDNN
// call sites to decide whether TF32 math is allowed (confirm at call sites).
bool allow_tf32_cublas{true};

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

bool allow_tf32_cudnn{true};

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
#endif  // PADDLE_WITH_CUDA

84 85 86 87 88 89 90 91 92 93 94 95 96
// Maps a Place onto the DeviceType enum: CPU, CUDA (GPU places) or XPU.
// Any other kind of place raises Unavailable.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  }
  if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  }
  if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  }
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
97 98
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
99
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
100
  VLOG(4) << "DeviceContextPool Get: " << place;
D
dzhwinter 已提交
101 102
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
103 104
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
105 106
        "with WITH_GPU, WITH_XPU or WITH_ASCEND_CL option or check that "
        "your train process set the correct device id if you use Executor.",
G
GaoWei8 已提交
107
        place));
D
dzhwinter 已提交
108
  }
109
  return it->second.get().get();
D
dzhwinter 已提交
110 111
}

112 113 114 115 116 117 118 119 120
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
121
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
122
                   }));
C
chengduozh 已提交
123 124
}

D
dzhwinter 已提交
125 126
// Builds the pool for `places`.
//
// Duplicate places are collapsed. For each distinct place, a factory for the
// matching DeviceContext subclass is registered through EmplaceDeviceContext;
// the contexts themselves are only constructed on first Get(). CPU places get
// an MKLDNNDeviceContext when built with MKLDNN. A place whose backend was
// not compiled in raises Unimplemented. Place kinds not handled below are
// silently skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Deduplicate so every distinct place gets exactly one context.
  std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

190 191 192 193
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
194
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
195 196 197 198 199 200 201
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
202
Place CPUDeviceContext::GetPlace() const { return place_; }
203

204
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
  xpu_version_ = get_xpu_version(place_.device);
}

XPUDeviceContext::~XPUDeviceContext() {}

// Creates an XPU runtime context for `place`.
//
// The current XPU device is switched to `place.device` while the context is
// built, then switched back to the device that was current before the call.
// If `place.device` is listed by GetXPUSelectedDevices(), an L3 buffer of
// `l3_size` bytes is requested via xpu_malloc once per device id. The pointer
// is cached in a function-local static table shared by later contexts. The
// buffer is attached to the new context's L3 manager if the cached pointer
// is non-null.
//
// NOTE(review): unlike the default constructor, this constructor does not
// assign xpu_version_; confirm it is initialized elsewhere (e.g. in the
// header) before relying on it.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
  const int MAX_XPU_NUM = 16;
  const int l3_size = 13.5 * 1024 * 1024;
  // One cached L3 buffer per device id.
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  auto selected_xpus = GetXPUSelectedDevices();
  for (unsigned int i = 0; i < selected_xpus.size(); i++) {
    if (place.device == selected_xpus[i]) {
      // l3ptrs is indexed by device id. An id outside [0, MAX_XPU_NUM) would
      // access memory past the table, so skip the L3 setup for it instead.
      if (place.device < 0 || place.device >= MAX_XPU_NUM) {
        LOG(WARNING) << "xpu place " << place.device
                     << " is outside the supported device id range [0, "
                     << MAX_XPU_NUM << "), skip setting l3 cache";
        break;
      }
      if (l3ptrs[place.device] == nullptr) {
        xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                   XPU_MEM_L3);
      }
      if (l3ptrs[place.device] != nullptr) {
        context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
        VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
      }
      break;
    }
  }

  // Restore the device that was current on entry.
  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

// Makes this context's device current, then waits on its XPU stream.
void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  xpu_wait(context_->xpu_stream);
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
#ifdef PADDLE_WITH_ASCEND_CL
// NPU context bound to `place`.
//
// Instead of creating a new ACL context, it adopts whatever context
// aclrtGetCurrentContext reports while NPUDeviceGuard is active for
// `place_.device`. It also creates a dedicated NPUStream for `place`.
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice. The explicit alternative was:
  //   PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_,
  //                                                 place_.device));
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtGetCurrentContext(&context_));
  this->stream_.reset(new stream::NPUStream(place));
}

// The context was not created here, so it is not destroyed here either.
// Explicit teardown, kept for reference:
//   NPUDeviceGuard guard(place_.device);
//   PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
NPUDeviceContext::~NPUDeviceContext() {}

// Records a profiler event, then waits on this context's stream.
void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  this->stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const {
  return this->stream_->raw_stream();
}

Place NPUDeviceContext::GetPlace() const { return this->place_; }

aclrtContext NPUDeviceContext::context() const { return this->context_; }

// Default NPU-pinned context with its own Eigen DefaultDevice.
NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  this->eigen_device_.reset(new Eigen::DefaultDevice());
}

// NPU-pinned context bound to `place`, with its own Eigen DefaultDevice.
NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  this->eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return this->eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const { return this->place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
318 319 320 321 322 323 324
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

325
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
326 327 328 329 330
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

331
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
332

333 334 335
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
336
  const cudaDeviceProp& deviceProperties() const override {
337
#endif
Q
init  
qijun 已提交
338 339 340 341
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
342 343 344
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
345 346 347
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
348
    void* retv = buf->ptr();
S
sneaxiy 已提交
349 350 351 352
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
353
    return retv;
Q
init  
qijun 已提交
354 355
  }

S
sneaxiy 已提交
356 357 358 359 360 361
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
362 363 364

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
365
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
366 367 368 369 370 371
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
372
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
373
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
374 375 376 377
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
378
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
379
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
380
#endif
Q
init  
qijun 已提交
381 382 383 384 385
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
386
  CUDAPlace place_;
387 388 389 390
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
391
  const cudaDeviceProp* device_prop_;  // not owned;
392
#endif
Q
qijun 已提交
393
  mutable void* scratch_;
Q
init  
qijun 已提交
394
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
395
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
396
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
397 398
};

399 400 401 402 403 404 405 406 407
// Grows the cuDNN workspace to at least `required_workspace_bytes`; it never
// shrinks an already large enough workspace.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool already_big_enough = required_workspace_bytes <= WorkspaceSize();
  if (!already_big_enough) {
    // Release the old buffer before requesting the new one, so the two are
    // never held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

408 409 410 411 412 413 414 415 416 417 418 419
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

// Creates a CUDA/HIP execution context on `place`. It consists of a new
// stream with the given priority and flags, plus the Eigen device and the
// cuBLAS/cuDNN handles bound to it (and the cuSOLVER handle, except on HIP).
// Everything is created while CUDADeviceGuard keeps `place.device` current.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag)
    : place_(place) {
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

// Destroys the cuDNN handle, then the cuBLAS handle, then (except on HIP)
// the cuSOLVER handle, with this context's device made current.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif
}

442
// CUDA/HIP device context for `place`.
//
// Caches the device's static properties: compute capability, SM count,
// thread and grid limits, and driver/runtime versions. Logs them once,
// together with the cuDNN (or MIOpen) version. Warns once when the CUDA
// version reported through driver_version_ is older than the one Paddle was
// built with. Finally, creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  const auto dev = place_.device;

  compute_capability_ = GetCUDAComputeCapability(dev);
  multi_process_ = GetCUDAMultiProcessors(dev);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(dev);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(dev);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(dev);
  driver_version_ = GetCUDADriverVersion(dev);
  runtime_version_ = GetCUDARuntimeVersion(dev);

  // The arithmetic below decodes versions as 1000 * major + 10 * minor.
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << dev
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t miopen_major, miopen_minor, miopen_patch;
  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
      &miopen_major, &miopen_minor, &miopen_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << dev
                          << ", MIOpen Version: " << miopen_major << "."
                          << miopen_minor << "." << miopen_patch;
#else
  const size_t cudnn_version = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << dev
                          << ", cuDNN Version: " << cudnn_version / 1000 << "."
                          << (cudnn_version % 1000) / 100 << ".";
#endif

  // CUDA version compatibility check. Both sides are reduced to
  // 10 * major + minor; the local side comes from driver_version_.
  const auto driver_cuda_version =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  const auto built_cuda_version =
      (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
  const auto built_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
  if (driver_cuda_version < built_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << dev
        << ". The installed Paddle is compiled with CUDA "
        << built_cuda_version / 10 << "." << built_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << driver_cuda_version / 10 << "." << driver_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Tears the context down with its own device made current. When built with
// NCCL/RCCL, it also destroys the attached communicator, if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
509
Place CUDADeviceContext::GetPlace() const { return place_; }
510

511
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
512

K
Kexin Zhao 已提交
513
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
514
  return compute_capability_;
K
Kexin Zhao 已提交
515 516
}

517
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
518
  return multi_process_ * max_threads_per_mp_;
519 520
}

521 522 523 524 525 526
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

527
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
528
  return context()->EigenDevice().get();
529 530
}

531
bool CUDADeviceContext::tensor_core_available() const {
532
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
533 534
}

535 536 537 538
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

539 540 541
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
542
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
543
#endif
544 545
  return context()->CudnnHandle();
}
546

547 548 549 550 551
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
552 553 554
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
555
#endif
556

S
sneaxiy 已提交
557
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
558
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
559
}
560

561
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
562 563 564
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
565
#endif
G
Guo Sheng 已提交
566

567
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
568

C
chengduoZH 已提交
569 570 571 572 573 574 575 576 577 578 579 580 581 582
// Default CUDAPinnedPlace context with its own Eigen DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  this->eigen_device_.reset(new Eigen::DefaultDevice());
}

// CUDAPinnedPlace context bound to `place`, with its own Eigen DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  this->eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning pointer to the Eigen device held by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return this->eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return this->place_; }
L
Luo Tao 已提交
583
#endif
Q
qijun 已提交
584

T
tensor-tang 已提交
585 586
#ifdef PADDLE_WITH_MKLDNN
// oneDNN-enabled CPU context: a CPUDeviceContext for `place`, plus an empty
// blob cache, empty per-shape executor bookkeeping, and the mutex that the
// cache accessors in this file lock.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place) {
  this->p_blobmap_.reset(new BlobMap());
  this->p_exec_items_.reset(new ExecShape());
  this->p_mutex_.reset(new std::mutex());
}

593 594
// Per-thread oneDNN state. It starts with:
// - a CPU engine (index 0) and a stream on that engine,
// - the default session id,
// - an empty input-shape key,
// - a shape-cache capacity of 1,
// - NCHW as the current Paddle data layout.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(mkldnn::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str = "";
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

601 602 603 604 605 606 607 608 609 610 611 612
// When a thread finishes, we clear the oneDNN cache.
// This is needed when one executor is used by many threads, e.g.
// test_analyzer_detect. The thread ID is not part of the caching key (for the
// naive executor), so the cache must be cleared when one thread finishes and
// another is about to start inference.
// The pooled CPU context's ResetBlobMap receives exec_ptr_: a null pointer
// clears the whole cache, and a non-null one only drops the entries linked
// to that executor.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // Checked downcast at compile time (unlike the previous C-style cast, which
  // could silently degrade to reinterpret_cast).
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

616 617 618 619 620 621 622 623 624 625
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
626 627
  cur_input_shape_str = input_shape_str;
}
628 629
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
630 631
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
632

633 634
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
635 636 637
  cur_paddle_data_layout = dl;
}

638 639
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
640 641 642
  return cur_paddle_data_layout;
}

643 644 645 646 647 648 649 650 651
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

652 653 654 655 656 657 658 659
const mkldnn::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

660
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
661 662 663
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
664 665 666 667 668 669
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
670 671 672 673 674
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
675
          (v.first)->second.erase(v.second);
676 677
        }
        s.second->erase(ptr);
678 679
      }
    }
680 681 682 683 684 685
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701
// Returns the key of the shape bucket in `sb` with the smallest timestamp
// (`first` of its pair, written from __rdtsc() by SetBlob/GetBlob). Ties go
// to the first such bucket in iteration order.
// Returns an empty string when `sb` holds no shapes. The scan below starts at
// std::next(begin()), which would step past end() on an empty map.
std::string MKLDNNDeviceContext::PickLeastUsedShape(
    BlobPtr_t<ShapeBlob> sb) const {
  if (sb->empty()) {
    VLOG(2) << "num_shapes: 0, no shape to remove";
    return std::string();
  }
  auto ancient_one = sb->begin();
  for (auto v = std::next(sb->begin()); v != sb->end(); ++v) {
    if (v->second->first < ancient_one->second->first) {
      ancient_one = v;
    }
  }
  VLOG(2) << "num_shapes: " << sb->size()
          << ", remove all blobs of shape: " << ancient_one->first;
  return ancient_one->first;
}

// Drops the per-executor bookkeeping (as recorded by LinkEntryWithExecutor)
// for the shape bucket `shape_to_be_removed`.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(
    std::string shape_to_be_removed) const {
  p_exec_items_->erase(shape_to_be_removed);
}

704 705 706
// Records that KeyBlob entry `it` inside `pblob` belongs to the current
// executor (TLS get_curr_exec()) and the current input shape (TLS
// cur_input_shape_str), so ResetBlobMap can later erase exactly that
// executor's entries. SetBlob calls this with p_mutex_ already held.
void MKLDNNDeviceContext::LinkEntryWithExecutor(
    BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pblob,
    KeyBlob::iterator it) const {
  // insert() keeps an existing shape bucket; either way it hands back the
  // bucket's per-executor map.
  auto inserted = p_exec_items_->insert(
      std::make_pair(tls().cur_input_shape_str, std::make_shared<ExecMap>()));
  auto& exec_map = *inserted.first->second;
  auto& links = exec_map[tls().get_curr_exec()];
  links.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << links.size() << "\n";
}

721 722 723 724
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
725
}
726

727
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
728
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
729
  BlobMap* pMap = p_blobmap_.get();
730
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
731
  if (map_it == pMap->end()) {
732 733 734
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
735 736 737 738
  }
  return map_it->second->size();
}

739
void MKLDNNDeviceContext::SetBlob(const std::string& name,
740
                                  BlobPtr_t<void> data) const {
741
  BlobMap* pMap = p_blobmap_.get();
742
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
743
  BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pBlob = nullptr;
744

745
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
746

747
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
748

749 750
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
751 752 753

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
754
    sBlob = std::make_shared<ShapeBlob>();
755 756
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
757
  } else {
758
    sBlob = map_it->second;
759
  }
T
tensor-tang 已提交
760

761
  // Find KeyBlob for current input shape
762
  auto key_it = sBlob->find(tls().cur_input_shape_str);
763

764
  if (key_it == sBlob->end()) {
765 766
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
767 768
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
769
        sBlob->size() &&
770
        (sBlob->size() >=
771
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
772 773 774
      auto shape_to_be_erased = PickLeastUsedShape(sBlob);
      sBlob->erase(shape_to_be_erased);
      RemoveShapeEntriesWithExecutor(shape_to_be_erased);
775
    }
776 777
    pBlob = std::make_shared<std::pair<unsigned long long, KeyBlob>>();
    pBlob->first = __rdtsc();
778
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
779
  } else {
780
    pBlob = key_it->second;
781 782
    // Update time stamp
    pBlob->first = __rdtsc();
783 784
  }

785
  // Find Blob via name
786 787 788 789
  auto blob_it = pBlob->second.find(name);
  if (blob_it == pBlob->second.end()) {
    auto el = pBlob->second.insert(
        std::make_pair(name, data));  //  (*pBlob)[name] = data;
790 791 792
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
793 794 795
  } else {
    blob_it->second = data;  // set data to existing blob
  }
796
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
797
  // lock will be automatically released when out of scope
798
  return;
T
tensor-tang 已提交
799 800
}

801
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
802 803 804
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
805
      num_entries += (l2.second->second).size();
806 807 808 809 810
    }
  }
  return num_entries;
}

811
// Looks up blob `name` for the current thread's session id and input shape
// (both read from TLS). Returns nullptr on any miss. On a hit, it refreshes
// the shape bucket's timestamp, which PickLeastUsedShape uses for eviction.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Label the shape key correctly and keep the session id for context.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->second.find(name);
  if (key_it == pBlob->second.end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  // Update timestamp
  sBlob_it->second->first = __rdtsc();  // TODO(windows)

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
Q
qijun 已提交
854
}  // namespace platform
Q
qijun 已提交
855
}  // namespace paddle