device_context.cc 29.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <memory>
14
#include <set>
15

16
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
17
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
18
#include "paddle/fluid/platform/cuda_device_guard.h"
19
#endif
F
fwenguang 已提交
20 21 22 23
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
24
#include "glog/logging.h"
25
#include "paddle/fluid/framework/expect.h"
26
#include "paddle/fluid/memory/allocation/allocator_facade.h"
27
#include "paddle/fluid/platform/profiler.h"
28

29 30 31 32 33
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// For GPU (CUDA/HIP) and MLU places the pooled default context for the place
// is looked up in DeviceContextPool:
// - If its stream equals the stream of `dev_ctx`, the ordinary place-based
//   `Alloc(place, size)` is used.
// - Otherwise the request is routed through the per-device-context allocator
//   pool, presumably so the memory is associated with `dev_ctx`'s stream.
// XPU and every other place (and any zero-byte request) simply use
// `Alloc(place, size)`.
//
// Throws PermissionDenied when the place needs a backend that Paddle was not
// compiled with.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  // A zero-byte request never needs stream-specific handling.
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
89 90 91
namespace paddle {
namespace platform {

92
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
93 94 95
// Process-wide switches deciding whether cuBLAS / cuDNN may use TF32 math.
// Both start out enabled and are toggled through the setters below.
bool allow_tf32_cublas = true;
bool allow_tf32_cudnn = true;

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
100 101
#endif  // PADDLE_WITH_CUDA

102 103 104 105 106 107 108
// Maps a platform::Place to the coarse platform::DeviceType enum.
// Only CPU, GPU, XPU and MLU places are mapped; any other place kind raises
// an Unavailable error.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  }
  if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  }
  if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  }
  if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  }
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
117 118
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
119
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
120
  VLOG(6) << "DeviceContextPool Get: " << place;
D
dzhwinter 已提交
121 122
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
123 124
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
F
fwenguang 已提交
125 126
        "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
        "or check "
J
jianghaicheng 已提交
127 128
        "that your train process set the correct device id if you use "
        "Executor.",
G
GaoWei8 已提交
129
        place));
D
dzhwinter 已提交
130
  }
131
  return it->second.get().get();
D
dzhwinter 已提交
132 133
}

W
Wilber 已提交
134
template <typename DevCtx>
135 136 137 138 139
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // lazy evaluation. i.e., only create device context at
        // first `Get`
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
          dev_ctx->SetDeviceAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(p, cuda_ctx->context()->RawStream())
                  .get());
#endif
        } else {
          dev_ctx->SetDeviceAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(p)
                  .get());
        }
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
C
chengduozh 已提交
173 174
}

D
dzhwinter 已提交
175 176
// Registers one lazily created device context per distinct place in `places`.
//
// Duplicate places are collapsed. Each place kind gets its matching context
// type (MKLDNNDeviceContext replaces CPUDeviceContext in MKLDNN builds). A
// place whose backend was not compiled in raises Unimplemented; place kinds
// not listed below are skipped.
// Throws InvalidArgument when `places` is empty.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // std::set drops duplicate places.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      // Name the pinned place explicitly rather than plain "CUDAPlace".
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

W
Wilber 已提交
254
// Default CPU context; all state lives in the pten::CPUContext base.
CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {}

// The place argument is not used; this delegates to the default constructor.
CPUDeviceContext::CPUDeviceContext(CPUPlace /*place*/) : CPUDeviceContext() {}
257

J
jianghaicheng 已提交
258
#ifdef PADDLE_WITH_IPU
A
Allen Guo 已提交
259
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}
J
jianghaicheng 已提交
260 261

Place IPUDeviceContext::GetPlace() const { return place_; }
A
Allen Guo 已提交
262

J
jianghaicheng 已提交
263 264 265 266 267 268 269
void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
270
#ifdef PADDLE_WITH_XPU
W
Wilber 已提交
271
XPUDeviceContext::XPUDeviceContext() : pten::XPUContext() {}

// Builds an XPU context for `place`. The device id is logged only the first
// time this statement executes (LOG_FIRST_N with N = 1).
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : pten::XPUContext(place) {
  const int device_id = static_cast<int>(place.device);
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << device_id;
}

XPUDeviceContext::~XPUDeviceContext() {}
#endif

281 282 283 284 285 286 287
#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync strean after aclrtSetDevice.
288
  platform::GetCurrentNPUContext(&context_);
289 290 291 292 293 294 295
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}
296

297
void NPUDeviceContext::Wait() const {
298 299 300
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
301 302 303 304 305 306 307
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

Place NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place NPUPinnedDeviceContext::GetPlace() const { return place_; }

324 325 326
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
327 328 329 330 331 332 333
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

334
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
335 336 337 338 339
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

340
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
341

342 343 344
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
345
  const cudaDeviceProp& deviceProperties() const override {
346
#endif
Q
init  
qijun 已提交
347 348 349 350
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
351 352 353
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
354 355 356
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
357
    void* retv = buf->ptr();
S
sneaxiy 已提交
358 359 360 361
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
362
    return retv;
Q
init  
qijun 已提交
363 364
  }

S
sneaxiy 已提交
365 366 367 368 369 370
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
371 372 373

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
374
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
375 376 377 378 379 380
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
381
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
382
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
383
#ifdef PADDLE_WITH_HIP
384
      PADDLE_ENFORCE_GPU_SUCCESS(
385 386
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
387
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
388
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
389
#endif
Q
init  
qijun 已提交
390 391 392 393 394
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
395
  CUDAPlace place_;
396 397 398 399
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
400
  const cudaDeviceProp* device_prop_;  // not owned;
401
#endif
Q
qijun 已提交
402
  mutable void* scratch_;
Q
init  
qijun 已提交
403
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
404
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
405
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
406 407
};

408 409 410 411 412 413 414 415 416
// Grows the cached cuDNN workspace to at least `required_workspace_bytes`.
// A workspace that is already big enough is kept (it never shrinks here).
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes > WorkspaceSize()) {
    // Free the old buffer first so old and new never coexist in memory.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

417 418 419 420 421 422 423 424 425 426 427 428
// Thread-local statics of CUDADeviceContext: a per-thread map from a device
// context to a shared CUDAContext, and a per-thread mutex.
// NOTE(review): how these are used is defined outside this excerpt.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// (Re)builds the Eigen GPU device. The stream adapter is bound to this
// context's raw stream, and the Eigen::GpuDevice wraps that adapter; both are
// owned by this CUDAContext.
void CUDAContext::InitEigenContext() {
  auto* adapter = new EigenCudaStreamDevice();
  eigen_stream_.reset(adapter);
  adapter->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(adapter));
}

// Builds a context on `place`. It creates a CUDAStream with the requested
// priority and flags, then the Eigen device and the library handles on top of
// it: cuBLAS, cuDNN, and on non-HIP builds also cuSPARSE and cuSOLVER.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  // Keep place_.device selected while the stream and handles are created.
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place_, priority, flag));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462
// Rebinds this context to `stream`. Does nothing if `stream` is already the
// current raw stream.
//
// Steps:
// 1. Destroy every library handle created by the constructor.
// 2. Switch the underlying CUDAStream.
// 3. Rebuild the Eigen device and the handles on the new stream.
//
// The cuSPARSE handle is rebuilt too. The constructor creates it alongside
// the other handles, presumably bound to the stream current at creation time;
// leaving it alone would keep cuSPARSE work on the old stream.
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() == stream) {
    return;
  }
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif

  stream_->SetStream(stream);

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

463 464 465 466
// Destroys the library handles with place_.device selected. The Eigen device
// and the stream are released by their owning members.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

473
// Creates the context for GPU `place`:
// - caches the device's static properties (compute capability, SM count,
//   thread and grid limits, driver/runtime versions);
// - logs them once, along with the cuDNN/MIOpen version;
// - warns when the installed driver is older than the CUDA/HIP version Paddle
//   was built with;
// - creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);

  // Static device properties; the getters read these cached copies.
  compute_capability_ = GetGPUComputeCapability(place_.device);
  multi_process_ = GetGPUMultiProcessors(place_.device);
  max_threads_per_mp_ = GetGPUMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetGPUMaxThreadsPerBlock(place_.device);
  driver_version_ = GetGPUDriverVersion(place_.device);
  runtime_version_ = GetGPURuntimeVersion(place_.device);

  // Version integers are decoded as major = v / 1000, minor = (v % 100) / 10.
  const auto cc_major = compute_capability_ / 10;
  const auto cc_minor = compute_capability_ % 10;
  const auto driver_major = driver_version_ / 1000;
  const auto driver_minor = (driver_version_ % 100) / 10;
  const auto runtime_major = runtime_version_ / 1000;
  const auto runtime_minor = (runtime_version_ % 100) / 10;

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: "
                          << static_cast<int>(place_.device)
                          << ", GPU Compute Capability: " << cc_major << "."
                          << cc_minor << ", Driver API Version: "
                          << driver_major << "." << driver_minor
                          << ", Runtime API Version: " << runtime_major << "."
                          << runtime_minor;
#ifdef PADDLE_WITH_HIP
  size_t miopen_major, miopen_minor, miopen_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(
      dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
                          << ", MIOpen Version: " << miopen_major << "."
                          << miopen_minor << "." << miopen_patch;
#else
  const size_t cudnn_version = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
                          << ", cuDNN Version: " << cudnn_version / 1000 << "."
                          << (cudnn_version % 1000) / 100 << ".";
#endif

  // CUDA/HIP version compatibility check, both sides encoded as
  // 10 * major + minor.
  const auto local_cuda_version = driver_major * 10 + driver_minor;
#ifdef PADDLE_WITH_HIP
  const auto compile_cuda_version =
      (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
  const auto compile_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
  if (local_cuda_version < compile_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << static_cast<int>(place_.device)
        << ". The installed Paddle is compiled with CUDA "
        << compile_cuda_version / 10 << "." << compile_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << local_cuda_version / 10 << "." << local_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Selects this context's device and, in NCCL/RCCL builds, destroys the
// attached communicator if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
541
Place CUDADeviceContext::GetPlace() const { return place_; }
542

543
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
544

K
Kexin Zhao 已提交
545
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
546
  return compute_capability_;
K
Kexin Zhao 已提交
547 548
}

549
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
550
  return multi_process_ * max_threads_per_mp_;
551 552
}

553 554 555 556 557 558
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

559
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
560
  return context()->EigenDevice().get();
561 562
}

563
bool CUDADeviceContext::tensor_core_available() const {
564
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
565 566
}

567 568 569 570
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

571 572 573
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
574
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
575
#endif
576 577
  return context()->CudnnHandle();
}
578

579 580 581 582 583
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
#else
584 585 586
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}
Z
zhangkaihuo 已提交
587 588 589
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  return context()->CusparseHandle()->GetCusparseHandle();
}
590
#endif
591

S
sneaxiy 已提交
592
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
593
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
594
}
595

596
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
597 598 599
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
600
#endif
G
Guo Sheng 已提交
601

602
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
603

C
chengduoZH 已提交
604 605 606 607 608 609 610 611 612 613 614 615 616 617
// CUDA pinned-host context: computations go through an Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}
L
Luo Tao 已提交
618
#endif
Q
qijun 已提交
619

T
tensor-tang 已提交
620 621
#ifdef PADDLE_WITH_MKLDNN
// oneDNN-enabled CPU context. It allocates:
// - the blob cache (p_blobmap_);
// - the per-executor bookkeeping used to evict cache entries (p_exec_items_);
// - the mutex that guards both (p_mutex_).
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place) {
  p_mutex_.reset(new std::mutex());
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
}

628
// Initial thread-local oneDNN state:
// - a CPU engine (index 0) and a stream on it;
// - NCHW as the current Paddle data layout;
// - an empty input-shape key with a shape-cache capacity of 1;
// - the default session id.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str.clear();
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

636 637 638 639 640 641 642 643 644 645 646 647
// When a thread finishes, the oneDNN cache is cleared.
// This is needed when one executor is used by many threads
// (e.g. test_analyzer_detect). The thread ID is not part of the caching key
// (for the naive executor), so the cache must be cleared when one thread
// finishes and another is about to start inference.
// MKLDNNDeviceContext::ResetBlobMap decides what is dropped:
// - everything when exec_ptr_ is null;
// - otherwise only the entries linked to that executor;
// - nothing at all if the next clearing was blocked.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // The pool registers an MKLDNNDeviceContext for CPU places in MKLDNN builds.
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

651 652 653 654 655 656 657 658 659 660
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
661 662
  cur_input_shape_str = input_shape_str;
}
663 664
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
665 666
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
667

668 669
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
670 671 672
  cur_paddle_data_layout = dl;
}

673 674
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
675 676 677
  return cur_paddle_data_layout;
}

678 679 680 681 682 683 684 685 686
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

687
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
688 689 690
  return cur_engine;
}

691
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
692 693 694
  return cur_stream;
}

695
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
696 697 698
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
699 700 701 702 703 704
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
705 706 707 708 709
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
710
          (v.first)->erase(v.second);
711 712
        }
        s.second->erase(ptr);
713 714
      }
    }
715 716 717 718 719 720
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

721 722
// Drops the first shape entry (in the container's iteration order) from the
// per-executor bookkeeping. SetBlob calls this when, in cache-clearing mode,
// it evicts a whole input shape from the blob cache.
// erase(begin()) on an empty container is undefined behavior, so an empty
// p_exec_items_ is left untouched.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

725 726
// Records that blob entry `it` of `pblob` belongs to the current executor.
// It is filed under the current input shape; both the executor address and
// the shape are read from TLS. ResetBlobMap(executor) can then erase exactly
// these entries.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Bucket for the current input shape, created on first use.
  ExecMap& exec_map =
      *p_exec_items_
           ->insert(std::make_pair(tls().cur_input_shape_str,
                                   std::make_shared<ExecMap>()))
           .first->second;
  auto& entries = exec_map[tls().get_curr_exec()];
  entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << entries.size() << "\n";
}

741 742 743 744
// Makes the next ResetBlobMap() call a no-op (it resets this flag instead).
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> guard(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
746

747
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
748
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
749
  BlobMap* pMap = p_blobmap_.get();
750
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
751
  if (map_it == pMap->end()) {
752 753 754
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
755 756 757 758
  }
  return map_it->second->size();
}

759
void MKLDNNDeviceContext::SetBlob(const std::string& name,
760
                                  BlobPtr_t<void> data) const {
761
  BlobMap* pMap = p_blobmap_.get();
762
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
763
  BlobPtr_t<KeyBlob> pBlob = nullptr;
764

765
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
766

767
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
768

769 770
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
771 772 773

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
774
    sBlob = std::make_shared<ShapeBlob>();
775 776
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
777
  } else {
778
    sBlob = map_it->second;
779
  }
T
tensor-tang 已提交
780

781
  // Find KeyBlob for current input shape
782
  auto key_it = sBlob->find(tls().cur_input_shape_str);
783

784
  if (key_it == sBlob->end()) {
785 786
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
787 788
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
789
        sBlob->size() &&
790
        (sBlob->size() >=
791
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
792 793 794 795
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
796
    }
797
    pBlob = std::make_shared<KeyBlob>();
798
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
799
  } else {
800
    pBlob = key_it->second;
801 802
  }

803
  // Find Blob via name
804 805 806 807
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
808 809 810
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
811 812 813
  } else {
    blob_it->second = data;  // set data to existing blob
  }
814
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
815
  // lock will be automatically released when out of scope
816
  return;
T
tensor-tang 已提交
817 818
}

819
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
820 821 822
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
823
      num_entries += (l2.second)->size();
824 825 826 827 828
    }
  }
  return num_entries;
}

829
// Looks up blob `name` for the current thread's session id and input shape
// (both read from TLS), under p_mutex_. Returns nullptr on any cache miss.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    // Report the session id and the missing shape key under their own labels.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
Q
qijun 已提交
873
}  // namespace platform
Q
qijun 已提交
874
}  // namespace paddle