/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_context.h"

#include <functional>
#include <memory>
#include <set>

#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return paddle::memory::Alloc(desired_dev_ctx.GetPlace(), size,
                                   phi::Stream(reinterpret_cast<phi::StreamId>(
                                       desired_dev_ctx.stream())));
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle
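
// Illustrative usage (a sketch, not part of the original file): requesting
// stream-ordered device memory through the helper above. The place, device
// id and size below are assumptions for the example.
//
//   auto& pool = paddle::platform::DeviceContextPool::Instance();
//   auto* dev_ctx = pool.Get(paddle::platform::CUDAPlace(0));
//   // Served from the default allocator when dev_ctx uses the default
//   // stream, otherwise from the per-context allocator pool.
//   paddle::memory::AllocationPtr buf =
//       paddle::memory::Alloc(*dev_ctx, /*size=*/4096);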

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}
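
// Illustrative mapping (assumed caller code):
//   auto dt = platform::Place2DeviceType(platform::CUDAPlace(0));
//   // dt == platform::DeviceType::CUDA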

DeviceContextPool* DeviceContextPool::pool = nullptr;
thread_local const std::map<Place,
                            std::shared_future<std::unique_ptr<DeviceContext>>>*
    DeviceContextPool::external_device_contexts_ = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
      ptr;
  if (external_device_contexts_ && external_device_contexts_->count(place)) {
    ptr = external_device_contexts_;
  } else {
    ptr = &device_contexts_;
  }

  auto it = ptr->find(place);
  if (it == ptr->end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle was "
        "compiled with the WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or "
        "WITH_ASCEND_CL option, or check that your training process sets "
        "the correct device id if you use Executor.",
        place));
  }
  return it->second.get().get();
}
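
// Illustrative lookup (a sketch; the place is assumed for the example):
//   auto* ctx = platform::DeviceContextPool::Instance().Get(
//       platform::CPUPlace());
//   ctx->Wait();  // block until the device's pending work completes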

size_t DeviceContextPool::size() const {
  if (external_device_contexts_) {
    return external_device_contexts_->size();
  }
  return device_contexts_.size();
}

const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
DeviceContextPool::device_contexts() const {
  if (external_device_contexts_) {
    return *external_device_contexts_;
  }
  return device_contexts_;
}

void DeviceContextPool::SetDeviceContexts(
    const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        dev_ctxs) {
  external_device_contexts_ = dev_ctxs;
}
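
// Illustrative thread-local override (assumed usage): a caller may build its
// own context map and point the pool at it for the current thread only.
//
//   std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> ctxs;
//   EmplaceDeviceContexts(&ctxs, {platform::CPUPlace()},
//                         /*disable_setting_default_stream_for_allocator=*/
//                         true);
//   DeviceContextPool::Instance().SetDeviceContexts(&ctxs);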

template <typename DevCtx>
std::unique_ptr<DeviceContext> CreateDeviceContext(
    const platform::Place& p,
    bool disable_setting_default_stream_for_allocator = false) {
  using PtrType = std::unique_ptr<DeviceContext>;
  auto* dev_ctx = new DevCtx(p);
  if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
    PADDLE_ENFORCE_NOT_NULL(
        cuda_ctx,
        platform::errors::InvalidArgument(
            "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));

    auto& instance = memory::allocation::AllocatorFacade::Instance();
    if (!disable_setting_default_stream_for_allocator) {
      instance.SetDefaultStream(CUDAPlace(p.GetDeviceId()), cuda_ctx->stream());
    }
    dev_ctx->SetAllocator(instance.GetAllocator(p).get());
    dev_ctx->SetPinnedAllocator(
        instance.GetAllocator(paddle::platform::CUDAPinnedPlace()).get());

    cuda_ctx->PartialInitWithAllocator();
    dev_ctx->SetGenerator(
        framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
  } else {
    dev_ctx->SetAllocator(
        memory::allocation::AllocatorFacade::Instance().GetAllocator(p).get());
    dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
  }
  dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
  dev_ctx->SetHostAllocator(memory::allocation::AllocatorFacade::Instance()
                                .GetAllocator(platform::CPUPlace())
                                .get());
  dev_ctx->SetZeroAllocator(memory::allocation::AllocatorFacade::Instance()
                                .GetZeroAllocator(p)
                                .get());
  return PtrType(dev_ctx);
}

template <typename DevCtx>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        place_to_device_context,
    platform::Place place, bool disable_setting_default_stream_for_allocator) {
  // lazy evaluation. i.e., only create device context at first `Get`
  place_to_device_context->emplace(
      place, std::async(std::launch::deferred, CreateDeviceContext<DevCtx>,
                        place, disable_setting_default_stream_for_allocator));
}
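
// The deferred std::async above postpones CreateDeviceContext until the
// shared_future is first queried, i.e. until the first pool `Get`. A minimal
// standalone sketch of the same lazy-init pattern (illustrative only, no
// Paddle types):
//
//   #include <future>
//   #include <memory>
//
//   std::shared_future<std::unique_ptr<int>> lazy = std::async(
//       std::launch::deferred, [] { return std::make_unique<int>(42); });
//   // Nothing is constructed yet; the lambda runs on the first get().
//   int v = *lazy.get();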

void EmplaceDeviceContexts(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        place_to_device_context,
    const std::vector<platform::Place>& places,
    bool disable_setting_default_stream_for_allocator) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));

  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }

  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      EmplaceDeviceContext<CPUDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(p)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(
          place_to_device_context, p,
          disable_setting_default_stream_for_allocator);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  EmplaceDeviceContexts(&device_contexts_, places,
                        /*disable_setting_default_stream_for_allocator=*/false);
}

CPUDeviceContext::CPUDeviceContext() : phi::CPUContext() {
  phi::CPUContext::Init();
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : phi::CPUContext(place) {
  phi::CPUContext::Init();
}

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}

const Place& IPUDeviceContext::GetPlace() const { return place_; }

void IPUDeviceContext::Wait() const {
  /*! \brief Wait for completion of all operations in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() : phi::XPUContext() {
  phi::XPUContext::Init();
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : phi::XPUContext(place) {
  phi::XPUContext::Init();
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place.device);
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait",
                                     platform::TracerEventType::UserDefined, 2);
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

const Place& NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
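
// Sketch of intended use (assumed caller code): run work under a workspace
// of a given size; the handle regrows its buffer only when the requested
// size exceeds the current allocation.
//
//   auto workspace = dev_ctx.cudnn_workspace_handle();
//   workspace.RunFunc([](void* ws) { /* hand ws to a cudnn call */ },
//                     required_workspace_bytes);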

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  InitCuBlasLtContext();
#endif
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    DestoryCuBlasLtContext();
#endif
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    InitCuBlasLtContext();
#endif
    InitCuSolverContext();
#endif
  }
}
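
// Illustrative external-stream adoption (assumed usage; requires the
// thread-local CUDAContext to exist): hand a stream created elsewhere to
// the context; the Eigen, cuBLAS and cuDNN handles are torn down and
// rebuilt on the new stream.
//
//   gpuStream_t external;
//   cudaStreamCreate(&external);
//   dev_ctx.context()->SetStream(external);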

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  DestoryCuBlasLtContext();  // tear down the cuBLASLt handle in the destructor
#endif
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();
  cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
}

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this)) {
    return context()->EigenDevice().get();
  }
  return phi::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
  VLOG(4) << "CUDA context(" << this << ")  Wait";
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
  phi::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
  return phi::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasLtHandle()->GetCublasLtHandle();
  }
  return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
  return phi::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
  return phi::GPUContext::cusolver_dn_handle();
}
#endif

void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->RecordEvent(ev, callback);
    return;
  }
  phi::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
  phi::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
  phi::GPUContext::WaitStreamCallback();
}
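
// Illustrative host callback (assumed usage): enqueue host-side work to run
// after the kernels already in the stream, then block until it has run.
//
//   dev_ctx.AddStreamCallback([] { VLOG(1) << "stream work done"; });
//   dev_ctx.WaitStreamCallback();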

phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this)) {
    // return workspace_.get();
    return phi::DnnWorkspaceHandle(
        memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(GetPlace())
            .get(),
        stream());
  }
  return phi::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this)) {
    return context()->RawStream();
  }
  return phi::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  if (!thread_ctx_.count(this)) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  auto* old_stream_ptr = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return old_stream_ptr;
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When a thread finishes we clear the oneDNN cache. This is needed when one
// executor is used by many threads, e.g. test_analyzer_detect. Thread ID is
// not part of the caching key (for the naive executor), so we need to clear
// the cache when one thread finishes and another starts inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  VLOG(4) << tls().get_curr_exec() << " " << ptr;
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_ == 0) {
    VLOG(3) << "Clearing DNNL cache.";
    // If no executor pointer is given, clear everything; otherwise clear
    // only the objects allocated while the given executor was in use.
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
    // Reset paddle layout to NCHW
    VLOG(3) << "Resetting Paddle data layout to NCHW.";
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  } else {
    --block_next_cache_clearing_;
    VLOG(3) << "Prevented Clearing DNNL cache. Updated "
               "block_next_cache_clearing_ : "
            << block_next_cache_clearing_;
    PADDLE_ENFORCE_GE(block_next_cache_clearing_, 0,
                      platform::errors::InvalidArgument(
                          "Cache clearing mark should be non-negative "
                          ". But received %d.",
                          block_next_cache_clearing_));
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take current input shape from TLS
  // Take current executor address from TLS
  // and for this executor's items add the one defined with arguments
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  ++block_next_cache_clearing_;
  VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
             "block_next_cache_clearing_ : "
          << block_next_cache_clearing_;
}
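
// Illustrative pairing (assumed usage; `mkldnn_ctx` is a made-up name): each
// BlockNextCacheClearing() call makes one subsequent ResetBlobMap() decrement
// the counter instead of clearing the oneDNN cache.
//
//   mkldnn_ctx->BlockNextCacheClearing();
//   mkldnn_ctx->ResetBlobMap(nullptr);  // skipped; counter returns to 0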

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: shape=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
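
// Illustrative cache round-trip (assumed usage; the key, payload and
// `mkldnn_ctx` name are made up for the example):
//
//   auto blob = std::make_shared<dnnl::memory::desc>();
//   mkldnn_ctx->SetBlob("conv_fwd:src_md", blob);
//   auto hit = mkldnn_ctx->GetBlob("conv_fwd:src_md");  // nullptr on miss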

#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : phi::CustomContext(place) {
  Init();
  stream_.reset(new phi::stream::Stream(place, stream()));
}

CustomDeviceContext::~CustomDeviceContext() {}
#endif
}  // namespace platform
}  // namespace paddle