device_context.cc 30.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14

15
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
16
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
17
#include "paddle/fluid/platform/cuda_device_guard.h"
18
#endif
F
fwenguang 已提交
19 20 21 22
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
J
jianghaicheng 已提交
23 24 25
#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/ipu/ipu_backend.h"
#endif
26
#include "glog/logging.h"
27
#include "paddle/fluid/platform/profiler.h"
28

29 30 31 32 33
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// Zero-byte requests and places other than GPU/MLU go straight to the
// place-based `Alloc`. For GPU and MLU places, when `dev_ctx` runs on the same
// stream as the pooled default context of that place, the place-based `Alloc`
// is used as well. Otherwise the request is served by the per-context
// allocator pool of that backend.
// Throws PermissionDenied when the place's backend was not compiled in.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
89 90 91
namespace paddle {
namespace platform {

92
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Process-wide TF32 switches, both enabled by default. They are only stored
// and returned here; presumably the cuBLAS / cuDNN call sites consult the
// getters below (confirm at the call sites).
bool allow_tf32_cublas = true;
bool allow_tf32_cudnn = true;

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
#endif  // PADDLE_WITH_CUDA

102 103 104 105 106 107 108
// Maps a Place to the matching DeviceType (CPU, CUDA, XPU or MLU).
// Any other kind of place raises an Unavailable error.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) return platform::DeviceType::CPU;
  if (platform::is_gpu_place(place)) return platform::DeviceType::CUDA;
  if (platform::is_xpu_place(place)) return platform::DeviceType::XPU;
  if (platform::is_mlu_place(place)) return platform::DeviceType::MLU;

  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
117 118
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
119
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
120
  VLOG(6) << "DeviceContextPool Get: " << place;
D
dzhwinter 已提交
121 122
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
123 124
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
F
fwenguang 已提交
125 126
        "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
        "or check "
J
jianghaicheng 已提交
127 128
        "that your train process set the correct device id if you use "
        "Executor.",
G
GaoWei8 已提交
129
        place));
D
dzhwinter 已提交
130
  }
131
  return it->second.get().get();
D
dzhwinter 已提交
132 133
}

134 135 136 137 138 139 140 141 142
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
143
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
144
                   }));
C
chengduozh 已提交
145 146
}

D
dzhwinter 已提交
147 148
// Registers one lazily-built DeviceContext per distinct place in `places`.
//
// Duplicate places are collapsed. For a place kind whose backend was not
// compiled in, an Unimplemented error is thrown. Place kinds matching none of
// the branches below are skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Deduplicate so every place gets exactly one context.
  std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Report the place that was actually requested (pinned host memory).
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext, MLUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext, IPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

228 229 230 231
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
232
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
233 234 235 236 237 238 239
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
240
Place CPUDeviceContext::GetPlace() const { return place_; }
241

J
jianghaicheng 已提交
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
#ifdef PADDLE_WITH_IPU
// Binds this context to the IPU backend device whose id matches `place`.
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
  std::shared_ptr<platform::ipu::IpuBackend> backend =
      platform::ipu::IpuBackend::GetInstance();
  device_ = backend->GetDevice(place.GetDeviceId());
}

Place IPUDeviceContext::GetPlace() const { return place_; }

// Intended to wait for all operations on the stream to complete; currently
// a no-op.
void IPUDeviceContext::Wait() const {}

IPUDeviceContext::~IPUDeviceContext() = default;

#endif
258
#ifdef PADDLE_WITH_XPU
Q
QingshuChen 已提交
259 260 261 262
// Default constructor: creates an XPU runtime context and records the XPU
// version of the device held by the default-constructed place_.
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
  xpu_version_ = get_xpu_version(place_.device);
}

XPUDeviceContext::~XPUDeviceContext() = default;
265 266 267 268 269 270 271 272 273 274 275 276 277 278 279

// Creates an XPU context for `place`.
//
// Temporarily makes `place.device` the current XPU device, creates the runtime
// context, and, when the device is listed in GetXPUSelectedDevices(), attaches
// a process-wide L3 buffer for that device. The device that was current on
// entry is restored before returning.
//
// The L3 size defaults to 13.5 * 1024 * 1024 and can be overridden through
// the XPU_PADDLE_L3_SIZE environment variable (parsed with atoi; presumably
// bytes — confirm against the xpu_malloc contract).
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
  // One L3 buffer slot per device id, shared by every context created in this
  // process and allocated on first use.
  const int MAX_XPU_NUM = 16;
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  int l3_size = 13.5 * 1024 * 1024;
  const char* l3_size_env = std::getenv("XPU_PADDLE_L3_SIZE");
  if (l3_size_env != nullptr) {
    l3_size = atoi(l3_size_env);
  }

  auto selected_xpus = GetXPUSelectedDevices();
  for (unsigned int i = 0; i < selected_xpus.size(); i++) {
    if (place.device != selected_xpus[i]) {
      continue;
    }
    // l3ptrs only has MAX_XPU_NUM slots. L3 setup is best-effort (a failed
    // xpu_malloc is also silently skipped), so skip it rather than index out
    // of bounds; this keeps the device-restore below reachable.
    if (place.device < 0 || place.device >= MAX_XPU_NUM) {
      LOG(WARNING) << "xpu place " << place.device << " is out of range [0, "
                   << MAX_XPU_NUM << "), skip setting l3 cache";
      break;
    }
    if (l3ptrs[place.device] == nullptr) {
      xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                 XPU_MEM_L3);
    }
    if (l3ptrs[place.device] != nullptr) {
      context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
      VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
    }
    break;
  }

  // Restore the device that was current on entry.
  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
322
  xpu_wait(context_->xpu_stream);
323 324 325 326 327 328 329
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

330 331 332 333 334 335 336
#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync strean after aclrtSetDevice.
337
  platform::GetCurrentNPUContext(&context_);
338 339 340 341 342 343 344
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}
345

346
void NPUDeviceContext::Wait() const {
347 348 349
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
350 351 352 353 354 355 356
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }
357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372

// Pinned-host NPU context: both constructors give it its own
// Eigen::DefaultDevice; the default constructor leaves place_ default.
NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning view of the Eigen device held by this context.
Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const {
  return place_;
}

373 374 375
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
376 377 378 379 380 381 382
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

383
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
384 385 386 387 388
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

389
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
390

391 392 393
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
394
  const cudaDeviceProp& deviceProperties() const override {
395
#endif
Q
init  
qijun 已提交
396 397 398 399
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
400 401 402
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
403 404 405
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
406
    void* retv = buf->ptr();
S
sneaxiy 已提交
407 408 409 410
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
411
    return retv;
Q
init  
qijun 已提交
412 413
  }

S
sneaxiy 已提交
414 415 416 417 418 419
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
420 421 422

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
423
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
424 425 426 427 428 429
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
430
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
431
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
432
#ifdef PADDLE_WITH_HIP
433
      PADDLE_ENFORCE_GPU_SUCCESS(
434 435
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
436
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
437
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
438
#endif
Q
init  
qijun 已提交
439 440 441 442 443
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
444
  CUDAPlace place_;
445 446 447 448
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
449
  const cudaDeviceProp* device_prop_;  // not owned;
450
#endif
Q
qijun 已提交
451
  mutable void* scratch_;
Q
init  
qijun 已提交
452
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
453
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
454
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
455 456
};

457 458 459 460 461 462 463 464 465
// Grows the workspace to at least `required_workspace_bytes`; a workspace that
// is already large enough is kept as-is.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes > WorkspaceSize()) {
    // Drop the old buffer before requesting the new one so that both are not
    // held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

466 467 468 469 470 471 472 473 474 475 476 477
// Per-thread map from a CUDADeviceContext to a CUDAContext, plus a
// per-thread mutex; both are static members of CUDADeviceContext.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// (Re)creates the Eigen stream adapter on this context's raw stream and the
// Eigen::GpuDevice that uses it.
void CUDAContext::InitEigenContext() {
  auto* stream_device = new EigenCudaStreamDevice();
  eigen_stream_.reset(stream_device);
  stream_device->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(stream_device));
}

// Creates a stream with the given priority/flags on `place` and initializes
// the Eigen, cuBLAS, cuDNN (and, on CUDA, cuSOLVER) contexts for it.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard device_guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));

  // Built after the stream exists; InitEigenContext reads RawStream().
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
// Rebinds this context to `stream`. When the stream actually changes, the
// library handles are destroyed, the stream is swapped, and all handles are
// re-created (presumably because they are tied to the old stream).
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() == stream) {
    return;  // Already bound to this stream.
  }

  CUDADeviceGuard device_guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif

  stream_->SetStream(stream);

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

511 512 513 514
// Releases the cuDNN, cuBLAS (and, on CUDA, cuSOLVER) handles with this
// context's device made current.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard device_guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif
}

520
// Builds the device context for `place`.
//
// Caches static device properties, logs (once per process for each log site)
// the device, driver/runtime and cuDNN/MIOpen versions, warns when Paddle was
// compiled against a newer CUDA/HIP version than the driver reports, and
// creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard device_guard(place_.device);
  const int dev_id = place_.device;

  compute_capability_ = GetGPUComputeCapability(dev_id);
  multi_process_ = GetGPUMultiProcessors(dev_id);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(dev_id);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(dev_id);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(dev_id);
  driver_version_ = GetGPUDriverVersion(dev_id);
  runtime_version_ = GetGPURuntimeVersion(dev_id);

  // The arithmetic below assumes CUDA's 1000 * major + 10 * minor encoding.
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << dev_id
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << dev_id
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  const size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << dev_id
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif

  // Check CUDA/CUDNN version compatibility: warn when the build targets a
  // newer version than the installed driver reports.
  const auto local_cuda_version =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  const auto compile_cuda_version =
      (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
  const auto compile_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
  if (local_cuda_version < compile_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << dev_id
        << ". The installed Paddle is compiled with CUDA "
        << compile_cuda_version / 10 << "." << compile_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << local_cuda_version / 10 << "." << local_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes this context's device current and, when NCCL/RCCL is enabled,
// destroys the communicator stored in nccl_comm_ (if one was set).
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (!nccl_comm_) {
    return;
  }
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
#endif
}

L
liaogang 已提交
587
Place CUDADeviceContext::GetPlace() const { return place_; }
588

589
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
590

K
Kexin Zhao 已提交
591
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
592
  return compute_capability_;
K
Kexin Zhao 已提交
593 594
}

595
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
596
  return multi_process_ * max_threads_per_mp_;
597 598
}

599 600 601 602 603 604
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

605
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
606
  return context()->EigenDevice().get();
607 608
}

609
bool CUDADeviceContext::tensor_core_available() const {
610
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
611 612
}

613 614 615 616
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

617 618 619
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
620
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
621
#endif
622 623
  return context()->CudnnHandle();
}
624

625 626 627 628 629
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
630 631 632
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
633
#endif
634

S
sneaxiy 已提交
635
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
636
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
637
}
638

639
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
640 641 642
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
643
#endif
G
Guo Sheng 已提交
644

645
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
646

C
chengduoZH 已提交
647 648 649 650 651 652 653 654 655 656 657 658 659 660
// Pinned-host CUDA context: both constructors give it its own
// Eigen::DefaultDevice; the default constructor leaves place_ default.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning view of the Eigen device held by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
661
#endif
Q
qijun 已提交
662

T
tensor-tang 已提交
663 664
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place) {
  // Fresh, empty blob cache, executor index and the mutex guarding them.
  p_mutex_.reset(new std::mutex());
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
}

671
// Per-thread oneDNN state: a CPU engine (index 0) with a stream on it, the
// default session id, NCHW layout, an empty input-shape key and a shape-cache
// capacity of 1.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
}

679 680 681 682 683 684 685 686 687 688 689 690
// When a thread finishes, ask the CPU MKLDNN device context to reset its blob
// cache for this thread's executor (exec_ptr_).
// This is needed when one executor is used by many threads
// (e.g. test_analyzer_detect): the thread ID is not part of the caching key
// (for the naive executor), so the cache must be cleared when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally only the part of the cache related to the thread
// that is terminating would be cleared.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // With PADDLE_WITH_MKLDNN the pool registers an MKLDNNDeviceContext for
  // CPUPlace, so this downcast is valid.
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

694 695 696 697 698 699 700 701 702 703
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
704 705
  cur_input_shape_str = input_shape_str;
}
706 707
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
708 709
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
710

711 712
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
713 714 715
  cur_paddle_data_layout = dl;
}

716 717
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
718 719 720
  return cur_paddle_data_layout;
}

721 722 723 724 725 726 727 728 729
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

730
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
731 732 733
  return cur_engine;
}

734
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
735 736 737
  return cur_stream;
}

738
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
739 740 741
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
742 743 744 745 746 747
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
748 749 750 751 752
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
753
          (v.first)->erase(v.second);
754 755
        }
        s.second->erase(ptr);
756 757
      }
    }
758 759 760 761 762 763
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

764 765
// Drops the first per-shape executor index entry, if any.
// erase(begin()) on an empty container would pass end() to erase, which is
// undefined behavior, so an empty index is left untouched.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

768 769
// Records that blob entry `it` of `pblob` was created under the current
// executor and the current input shape, both taken from TLS, so that
// ResetBlobMap can later erase it per executor.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Find or create the executor map for the current input shape.
  auto shape_it = p_exec_items_
                      ->insert(std::make_pair(tls().cur_input_shape_str,
                                              std::make_shared<ExecMap>()))
                      .first;
  auto& exec_entries = (*shape_it->second)[tls().get_curr_exec()];
  exec_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << exec_entries.size() << "\n";
}

784 785 786 787
// Makes the next ResetBlobMap call skip clearing; that call resets the flag.
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> guard(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
789

790
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
791
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
792
  BlobMap* pMap = p_blobmap_.get();
793
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
794
  if (map_it == pMap->end()) {
795 796 797
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
798 799 800 801
  }
  return map_it->second->size();
}

802
void MKLDNNDeviceContext::SetBlob(const std::string& name,
803
                                  BlobPtr_t<void> data) const {
804
  BlobMap* pMap = p_blobmap_.get();
805
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
806
  BlobPtr_t<KeyBlob> pBlob = nullptr;
807

808
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
809

810
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
811

812 813
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
814 815 816

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
817
    sBlob = std::make_shared<ShapeBlob>();
818 819
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
820
  } else {
821
    sBlob = map_it->second;
822
  }
T
tensor-tang 已提交
823

824
  // Find KeyBlob for current input shape
825
  auto key_it = sBlob->find(tls().cur_input_shape_str);
826

827
  if (key_it == sBlob->end()) {
828 829
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
830 831
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
832
        sBlob->size() &&
833
        (sBlob->size() >=
834
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
835 836 837 838
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
839
    }
840
    pBlob = std::make_shared<KeyBlob>();
841
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
842
  } else {
843
    pBlob = key_it->second;
844 845
  }

846
  // Find Blob via name
847 848 849 850
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
851 852 853
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
854 855 856
  } else {
    blob_it->second = data;  // set data to existing blob
  }
857
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
858
  // lock will be automatically released when out of scope
859
  return;
T
tensor-tang 已提交
860 861
}

862
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
863 864 865
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
866
      num_entries += (l2.second)->size();
867 868 869 870 871
    }
  }
  return num_entries;
}

872
// Looks up the blob named `name` for the current MKLDNN session id and the
// current input shape (both read from TLS), under *p_mutex_. Returns nullptr
// when the session, the shape, or the name is not cached. It only uses find(),
// so it never inserts anything.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Report the real session id under "sid=" and the missed shape string
    // under its own label, so the two are not confused in logs.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
Q
qijun 已提交
913
}  // namespace platform
Q
qijun 已提交
914
}  // namespace paddle