/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_context.h"

#include <functional>
#include <memory>
#include <set>

#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif

namespace paddle {
namespace memory {

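// Allocates device memory through the allocator facade. For GPU places the
// allocation is tied to the stream of the requesting device context: if that
// stream matches the default context's stream, the regular stream-aware
// allocator is used; otherwise the request goes through
// CUDADeviceContextAllocatorPool so the allocation stays valid on the
// caller's stream.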
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return paddle::memory::Alloc(desired_dev_ctx.GetPlace(),
                                   size,
                                   phi::Stream(reinterpret_cast<phi::StreamId>(
                                       desired_dev_ctx.stream())));
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}

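// DeviceContextPool is a process-wide singleton. A thread may temporarily
// redirect lookups to an externally owned map of device contexts via
// SetDeviceContexts(); otherwise Get() serves contexts owned by the pool.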
DeviceContextPool* DeviceContextPool::pool = nullptr;
thread_local const std::map<Place,
                            std::shared_future<std::unique_ptr<DeviceContext>>>*
    DeviceContextPool::external_device_contexts_ = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
      ptr;
  if (external_device_contexts_ && external_device_contexts_->count(place)) {
    ptr = external_device_contexts_;
  } else {
    ptr = &device_contexts_;
  }

  auto it = ptr->find(place);
  if (it == ptr->end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
        "or check "
        "that your train process set the correct device id if you use "
        "Executor.",
        place));
  }
  return it->second.get().get();
}

size_t DeviceContextPool::size() const {
  if (external_device_contexts_) {
    return external_device_contexts_->size();
  }
  return device_contexts_.size();
}

const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
DeviceContextPool::device_contexts() const {
  if (external_device_contexts_) {
    return *external_device_contexts_;
  }
  return device_contexts_;
}

void DeviceContextPool::SetDeviceContexts(
    const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        dev_ctxs) {
  external_device_contexts_ = dev_ctxs;
}

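// Constructs a device context for the given place and wires up its allocators
// and random-number generators. For CUDA places the allocator facade's
// default stream is also bound to the new context's stream unless
// disable_setting_default_stream_for_allocator is set.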
template <typename DevCtx>
std::unique_ptr<DeviceContext> CreateDeviceContext(
    const platform::Place& p,
    bool disable_setting_default_stream_for_allocator = false) {
  using PtrType = std::unique_ptr<DeviceContext>;
  auto* dev_ctx = new DevCtx(p);
  if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
    PADDLE_ENFORCE_NOT_NULL(
        cuda_ctx,
        platform::errors::InvalidArgument(
            "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));

    auto& instance = memory::allocation::AllocatorFacade::Instance();
    if (!disable_setting_default_stream_for_allocator) {
      instance.SetDefaultStream(CUDAPlace(p.GetDeviceId()), cuda_ctx->stream());
    }
    dev_ctx->SetAllocator(instance.GetAllocator(p).get());
    dev_ctx->SetPinnedAllocator(
        instance.GetAllocator(paddle::platform::CUDAPinnedPlace()).get());

    cuda_ctx->PartialInitWithAllocator();
    dev_ctx->SetGenerator(
        framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
  } else {
    dev_ctx->SetAllocator(
        memory::allocation::AllocatorFacade::Instance().GetAllocator(p).get());
    dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
  }
  dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
  dev_ctx->SetHostAllocator(memory::allocation::AllocatorFacade::Instance()
                                .GetAllocator(platform::CPUPlace())
                                .get());
  dev_ctx->SetZeroAllocator(memory::allocation::AllocatorFacade::Instance()
                                .GetZeroAllocator(p)
                                .get());
  return PtrType(dev_ctx);
}

template <typename DevCtx>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        place_to_device_context,
    platform::Place place,
    bool disable_setting_default_stream_for_allocator) {
  // lazy evaluation. i.e., only create device context at first `Get`
  place_to_device_context->emplace(
      place,
      std::async(std::launch::deferred,
                 CreateDeviceContext<DevCtx>,
                 place,
                 disable_setting_default_stream_for_allocator));
}

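// Creates (lazily, via EmplaceDeviceContext) one device context per unique
// place in `places`, dispatching on the place type and on the device support
// Paddle was compiled with.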
void EmplaceDeviceContexts(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        place_to_device_context,
    const std::vector<platform::Place>& places,
    bool disable_setting_default_stream_for_allocator) {
  PADDLE_ENFORCE_GT(
      places.size(),
      0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));

  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }

  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      EmplaceDeviceContext<CPUDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(p)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(
          place_to_device_context,
          p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  EmplaceDeviceContexts(&device_contexts_,
                        places,
                        /*disable_setting_default_stream_for_allocator=*/false);
}

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}

const Place& IPUDeviceContext::GetPlace() const { return place_; }

void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() : phi::XPUContext() {
  phi::XPUContext::Init();
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : phi::XPUContext(place) {
  phi::XPUContext::Init();
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place.device);
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event(
      "NPUDeviceContext/wait", platform::TracerEventType::UserDefined, 2);
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

const Place& NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
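// Adapts a Paddle CUDA/HIP stream and allocator to Eigen's StreamInterface so
// Eigen kernels can run on the context's stream. Temporary buffers requested
// by Eigen are served from memory::Alloc and tracked in `allocations_` until
// Eigen releases them.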
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size " << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}

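// Each thread may install its own CUDAContext for a given CUDADeviceContext
// (see context()). The CUDADeviceContext methods below first consult this
// thread-local map and only fall back to the underlying phi::GPUContext when
// no per-thread context is registered.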
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  InitCuBlasLtContext();
#endif
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    DestoryCuBlasLtContext();
#endif
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    InitCuBlasLtContext();
#endif
    InitCuSolverContext();
#endif
  }
}

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  // Destroy (not re-initialize) the cuBLASLt context in the destructor.
  DestoryCuBlasLtContext();
#endif
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();
  cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
}

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this)) {
    return context()->EigenDevice().get();
  }
  return phi::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
  VLOG(4) << "CUDA context(" << this << ")  Wait";
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
  phi::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
  return phi::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasLtHandle()->GetCublasLtHandle();
  }
  return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
  return phi::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
  return phi::GPUContext::cusolver_dn_handle();
}
#endif

void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->RecordEvent(ev, callback);
    return;
  }
  phi::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
  phi::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
  phi::GPUContext::WaitStreamCallback();
}

phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this)) {
    // return workspace_.get();
    return phi::DnnWorkspaceHandle(
        memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(GetPlace())
            .get(),
        stream());
  }
  return phi::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this)) {
    return context()->RawStream();
  }
  return phi::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  if (!thread_ctx_.count(this)) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  auto* old_stream_ptr = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return old_stream_ptr;
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

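// The oneDNN (MKLDNN) device context keeps a three-level cache:
// session id -> input shape string -> blob name -> cached blob. Per-executor
// bookkeeping (p_exec_items_) records which cache entries were created by
// which executor so they can be released selectively.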
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When a thread finishes we clear the oneDNN cache.
// This is needed when we have one executor used by many threads,
// e.g. test_analyzer_detect. Thread ID is not part of the caching key
// (for the naive executor), so we need to clear the cache when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only the part of the cache
// related to the thread that is to be terminated.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

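// ResetBlobMap clears the oneDNN cache, either entirely (ptr == nullptr) or
// only the entries registered for the given executor. BlockNextCacheClearing()
// increments block_next_cache_clearing_; a subsequent ResetBlobMap call then
// decrements the counter and skips the clear.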
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  VLOG(4) << tls().get_curr_exec() << " " << ptr;
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_ == 0) {
    VLOG(3) << "Clearing DNNL cache.";
    // If no specific executor pointer is given, clear everything;
    // for a given executor pointer, clear only the objects allocated
    // while that executor was in use.
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and release,
      // for each shape and active executor, all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
    // Reset paddle layout to NCHW
    VLOG(3) << "Resetting Paddle data layout to NCHW.";
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  } else {
    --block_next_cache_clearing_;
    VLOG(3) << "Prevented Clearing DNNL cache. Updated "
               "block_next_cache_clearing_ : "
            << block_next_cache_clearing_;
    PADDLE_ENFORCE_GE(block_next_cache_clearing_,
                      0,
                      platform::errors::InvalidArgument(
                          "Cache clearing mark should be non-negative. "
                          "But received %d.",
                          block_next_cache_clearing_));
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take current input shape from TLS
  // Take current executor address from TLS
  // and for this executor's items add the one defined with arguments
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  ++block_next_cache_clearing_;
  VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
             "block_next_cache_clearing_ : "
          << block_next_cache_clearing_;
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register the new element in the per-executor map
    // so it can be easily erased when the executor terminates
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After the first iteration of the model's execution we
  // should have all elements cached (mostly), so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: input_shape_str=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : phi::CustomContext(place) {
  Init();
  stream_.reset(new phi::stream::Stream(place, stream()));
}

CustomDeviceContext::~CustomDeviceContext() {}
#endif
}  // namespace platform
}  // namespace paddle