/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"

#include <functional>
#include <memory>
#include <set>

#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return paddle::memory::Alloc(desired_dev_ctx.GetPlace(), size,
                                   phi::Stream(reinterpret_cast<phi::StreamId>(
                                       desired_dev_ctx.stream())));
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}
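
// Usage sketch (illustrative only; the variable names are assumptions):
// callers pick the stream-aware path by passing the device context whose
// stream should own the allocation:
//
//   auto* ctx = platform::DeviceContextPool::Instance().Get(place);
//   auto holder = paddle::memory::Alloc(*ctx, /*size=*/4096);
//   void* data = holder->ptr();  // freed when `holder` goes out of scope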

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
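
// Usage sketch (hedged; the call sites are illustrative): these process-wide
// switches are typically flipped before kernels choose their math mode:
//
//   platform::SetAllowTF32Cublas(false);  // force full FP32 in cuBLAS
//   if (platform::AllowTF32Cudnn()) {
//     // cuDNN may pick TF32-enabled algorithms here
//   }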

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}
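
// For example, Place2DeviceType(platform::CUDAPlace(0)) yields
// platform::DeviceType::CUDA, while a CPUPlace maps to DeviceType::CPU.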

DeviceContextPool* DeviceContextPool::pool = nullptr;
thread_local const std::map<Place,
                            std::shared_future<std::unique_ptr<DeviceContext>>>*
    DeviceContextPool::external_device_contexts_ = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
      ptr;
  if (external_device_contexts_ && external_device_contexts_->count(place)) {
    ptr = external_device_contexts_;
  } else {
    ptr = &device_contexts_;
  }

  auto it = ptr->find(place);
  if (it == ptr->end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your Paddle was "
        "compiled with the WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU, or "
        "WITH_ASCEND_CL option, and that your train process sets the "
        "correct device id if you use Executor.",
        place));
  }
  return it->second.get().get();
}
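
// Typical lookup (illustrative): contexts are created lazily, and repeated
// calls with the same place return the same pointer:
//
//   auto& pool = platform::DeviceContextPool::Instance();
//   auto* cpu_ctx = pool.Get(platform::CPUPlace());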

size_t DeviceContextPool::size() const {
  if (external_device_contexts_) {
    return external_device_contexts_->size();
  }
  return device_contexts_.size();
}

const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
DeviceContextPool::device_contexts() const {
  if (external_device_contexts_) {
    return *external_device_contexts_;
  }
  return device_contexts_;
}

void DeviceContextPool::SetDeviceContexts(
    const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        dev_ctxs) {
  external_device_contexts_ = dev_ctxs;
}
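
// Note: external_device_contexts_ is thread_local, so SetDeviceContexts only
// redirects lookups made by the calling thread; other threads keep using the
// pool-owned device_contexts_.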

template <typename DevCtx>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // Lazy evaluation: the device context is created only on the
        // first call to `Get`.
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
          dev_ctx->SetPinnedAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(paddle::platform::CUDAPinnedPlace())
                  .get());

          cuda_ctx->PartialInitWithAllocator();
          dev_ctx->SetGenerator(
              framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
        } else {
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
        }
        dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
}
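
// Design note: std::launch::deferred defers the lambda until the future's
// get() runs inside DeviceContextPool::Get, so contexts for places that are
// never queried are never constructed.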

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(p)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}
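
// Construction sketch (the places are assumed for illustration): the pool is
// built once with every place the process intends to use, e.g.
//
//   std::vector<platform::Place> places = {platform::CPUPlace(),
//                                          platform::CUDAPlace(0)};
//   platform::DeviceContextPool pool(places);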

CPUDeviceContext::CPUDeviceContext() : phi::CPUContext() {
  phi::CPUContext::Init();
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : phi::CPUContext(place) {
  phi::CPUContext::Init();
}

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}

const Place& IPUDeviceContext::GetPlace() const { return place_; }

void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() : phi::XPUContext() {
  phi::XPUContext::Init();
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : phi::XPUContext(place) {
  phi::XPUContext::Init();
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place.device);
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait",
                                     platform::TracerEventType::UserDefined, 2);
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

const Place& NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size " << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
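
// Note: the old workspace is released before the larger one is requested, so
// peak usage stays near max(old, new) rather than old + new.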

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  InitCuBlasLtContext();
#endif
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    DestoryCuBlasLtContext();
#endif
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    InitCuBlasLtContext();
#endif
    InitCuSolverContext();
#endif
  }
}
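
// Rationale: cuBLAS/cuDNN/cuSolver handles are bound to a stream, so swapping
// the raw stream requires destroying and re-creating every library handle on
// the new stream, as done above.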

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  DestoryCuBlasLtContext();
#endif
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();
  cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
  auto& instance = memory::allocation::AllocatorFacade::Instance();
  instance.SetDefaultStream(place, phi::GPUContext::stream());
  workspace_.reset(new phi::DnnWorkspaceHandle(
      instance.GetAllocator(place).get(), stream()));
}
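
// Design note: the constructor re-binds the allocator facade's default stream
// to this context's stream, so a plain memory::Alloc(place, size) allocates
// on the same stream this context computes on.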

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this)) {
    return context()->EigenDevice().get();
  }
  return phi::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
  VLOG(4) << "CUDA context(" << this << ")  Wait";
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
  phi::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
  return phi::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasLtHandle()->GetCublasLtHandle();
  }
  return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
  return phi::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
  return phi::GPUContext::cusolver_dn_handle();
}
#endif

void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->RecordEvent(ev, callback);
    return;
  }
  phi::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
  phi::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
  phi::GPUContext::WaitStreamCallback();
}

phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this)) {
    // return workspace_.get();
    return phi::DnnWorkspaceHandle(
        memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(GetPlace())
            .get(),
        stream());
  }
  return phi::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this)) {
    return context()->RawStream();
  }
  return phi::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  if (!thread_ctx_.count(this)) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  auto* old_stream_ptr = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return old_stream_ptr;
}
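
// Note: SetCudaStream transfers ownership; the caller receives the previous
// stream pointer and becomes responsible for destroying it.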

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When a thread finishes we clear the oneDNN cache.
// This is needed when one executor is used by many threads,
// e.g. test_analyzer_detect. Thread ID is not part of the caching key
// (for the naive executor), so we need to clear the cache when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only the part of the cache
// related to the thread that is to be terminated.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  VLOG(4) << tls().get_curr_exec() << " " << ptr;
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_ == 0) {
    VLOG(3) << "Clearing DNNL cache.";
    // If there is no specific executor pointer, clear everything;
    // otherwise clear only the objects allocated when using the
    // given executor.
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and, for each shape and the active
      // executor, release all entries of this executor.
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
    // Reset paddle layout to NCHW
    VLOG(3) << "Resetting Paddle data layout to NCHW.";
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  } else {
    --block_next_cache_clearing_;
    VLOG(3) << "Prevented Clearing DNNL cache. Updated "
               "block_next_cache_clearing_ : "
            << block_next_cache_clearing_;
    PADDLE_ENFORCE_GE(block_next_cache_clearing_, 0,
                      platform::errors::InvalidArgument(
                          "Cache clearing mark should be non-negative. "
                          "But received %d.",
                          block_next_cache_clearing_));
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take the current input shape from TLS
  // Take the current executor address from TLS
  // and add the entry defined by the arguments to this executor's items
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  ++block_next_cache_clearing_;
  VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
             "block_next_cache_clearing_ : "
          << block_next_cache_clearing_;
}
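
// Each BlockNextCacheClearing() call arms one skip: the matching ResetBlobMap
// decrements the counter instead of clearing, so nested scopes may stack.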

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register the new element in the per-executor map
    // so it can be easily erased when the executor terminates
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}
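
// Cache layout recap (as implemented above): p_blobmap_ is a three-level map,
// session id -> input shape string -> blob name -> blob, so clearing by shape
// or by executor only touches the affected subtrees.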

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: input_shape_str=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
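
// Usage sketch (the names are hypothetical): a oneDNN kernel caches a
// primitive under a string key and fetches it on the next run with the same
// input shape:
//
//   mkldnn_ctx->SetBlob("conv2d_fwd_prim", prim_ptr);
//   auto cached = mkldnn_ctx->GetBlob("conv2d_fwd_prim");  // nullptr on miss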

#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : phi::CustomContext(place) {
  Init();
  stream_.reset(new phi::stream::Stream(place, stream()));
}

CustomDeviceContext::~CustomDeviceContext() {}
#endif
}  // namespace platform
}  // namespace paddle