/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"

#include <functional>
#include <memory>
#include <set>

#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return paddle::memory::Alloc(desired_dev_ctx.GetPlace(), size,
                                   phi::Stream(reinterpret_cast<phi::StreamId>(
                                       desired_dev_ctx.stream())));
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU. "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle
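// Usage sketch (illustrative; the call sites are assumptions, not part of
// this file): routing an allocation through a device context ties it to that
// context's stream, which is what selects the stream-ordered path above.
//
//   auto& pool = paddle::platform::DeviceContextPool::Instance();
//   auto* dev_ctx = pool.Get(paddle::platform::CUDAPlace(0));
//   paddle::memory::AllocationPtr buf =
//       paddle::memory::Alloc(*dev_ctx, /*size=*/4096);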

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
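// A minimal sketch of how these toggles are meant to be used (assumed call
// sites): TF32 trades precision for speed on Ampere-and-newer GPUs, and the
// flags above gate it process-wide for cuBLAS and cuDNN respectively.
//
//   paddle::platform::SetAllowTF32Cublas(false);  // force full-FP32 cuBLAS
//   if (paddle::platform::AllowTF32Cudnn()) {
//     // cuDNN is permitted to select TF32 tensor-core kernels here.
//   }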

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}
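// Example (a sketch): Place2DeviceType(platform::CUDAPlace(0)) yields
// platform::DeviceType::CUDA; an unsupported place throws errors::Unavailable.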

DeviceContextPool* DeviceContextPool::pool = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that Paddle was compiled "
        "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL, and "
        "that your training process sets the correct device id if you use "
        "Executor.",
        place));
  }
  return it->second.get().get();
}

template <typename DevCtx>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // Lazy evaluation: the device context is created only at the
        // first `Get`.
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
          dev_ctx->SetPinnedAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(paddle::platform::CUDAPinnedPlace())
                  .get());

          cuda_ctx->PartialInitWithAllocator();
          dev_ctx->SetGenerator(
              framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
        } else {
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
        }
        dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
}
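// Note on the deferred future used above: std::async with
// std::launch::deferred runs the lambda lazily on the first get()/wait() of
// the shared_future, so the expensive context construction happens inside
// DeviceContextPool::Get rather than at pool construction. A standalone
// sketch of the same semantics:
//
//   auto fut = std::async(std::launch::deferred, [] { return 42; });
//   // nothing has run yet ...
//   int v = fut.get();  // the lambda executes here, on the calling thread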

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(p)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}
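// Construction sketch (the place list is an assumption; in practice the
// framework builds the singleton pool during initialization): contexts are
// registered for every requested place here but materialized lazily on Get().
//
//   std::vector<paddle::platform::Place> places = {
//       paddle::platform::CPUPlace(), paddle::platform::CUDAPlace(0)};
//   paddle::platform::DeviceContextPool pool(places);
//   auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());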

CPUDeviceContext::CPUDeviceContext() : phi::CPUContext() {
  phi::CPUContext::Init();
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : phi::CPUContext(place) {
  phi::CPUContext::Init();
}

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}

const Place& IPUDeviceContext::GetPlace() const { return place_; }

void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for the completion of all operations in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() : phi::XPUContext() {
  phi::XPUContext::Init();
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : phi::XPUContext(place) {
  phi::XPUContext::Init();
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place.device);
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, there is no need to create a context explicitly;
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait",
                                     platform::TracerEventType::UserDefined, 2);
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

const Place& NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size " << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};
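// Intent note: this adapter satisfies Eigen's StreamInterface so Eigen GPU
// kernels run on Paddle's stream, and Eigen's scratch/semaphore allocations
// are routed through memory::Alloc and thus share the framework allocator.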

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // Release the old allocation before re-allocating to save memory.
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
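// Workspace usage sketch (RunFunc is assumed to be the companion helper
// declared with this class in device_context.h): kernels borrow cuDNN scratch
// space through the handle, and ReallocWorkspace grows the buffer only when a
// request exceeds the current size.
//
//   auto workspace = dev_ctx.cudnn_workspace_handle();
//   workspace.RunFunc(
//       [&](void* ws) { /* issue the cuDNN call with ws as workspace */ },
//       required_workspace_bytes);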

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  InitCuBlasLtContext();
#endif
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    DestoryCuBlasLtContext();
#endif
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
    InitCuBlasLtContext();
#endif
    InitCuSolverContext();
#endif
  }
}

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  // Fix: the destructor must tear down the cuBLASLt handle, not re-create it.
  DestoryCuBlasLtContext();
#endif
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}
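// SetStream sketch (external_stream is an assumed caller-owned handle): when
// an externally created stream is attached, every stream-bound library
// context (Eigen, cuBLAS, cuDNN, cuSOLVER, ...) is destroyed and re-created
// on the new stream, as implemented above.
//
//   gpuStream_t external_stream /* e.g. obtained from cudaStreamCreate */;
//   cuda_context->SetStream(external_stream);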

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();
  cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
  auto& instance = memory::allocation::AllocatorFacade::Instance();
  instance.SetDefaultStream(place, phi::GPUContext::stream());
  workspace_.reset(new phi::DnnWorkspaceHandle(
      instance.GetAllocator(place).get(), stream()));
}

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this)) {
    return context()->EigenDevice().get();
  }
  return phi::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
  VLOG(4) << "CUDA context(" << this << ")  Wait";
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
  phi::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
  return phi::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasLtHandle()->GetCublasLtHandle();
  }
  return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
  return phi::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
  return phi::GPUContext::cusolver_dn_handle();
}
#endif
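// Pattern note for the handle accessors above and the overrides below: each
// first consults thread_ctx_, the thread-local CUDAContext override map, and
// only falls back to the phi::GPUContext base implementation when no
// per-thread context has been installed for this device context.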

void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->RecordEvent(ev, callback);
    return;
  }
  phi::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
  phi::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
  phi::GPUContext::WaitStreamCallback();
}

phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this)) {
    // return workspace_.get();
    return phi::DnnWorkspaceHandle(
        memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(GetPlace())
            .get(),
        stream());
  }
  return phi::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this)) {
    return context()->RawStream();
  }
  return phi::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  if (!thread_ctx_.count(this)) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext::context() failed: no thread-local CUDAContext "
        "is bound to this device context."));
  }
  return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  auto* old_stream_ptr = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return old_stream_ptr;
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When a thread finishes we clear the oneDNN cache. This is needed when one
// executor is used by many threads (e.g. test_analyzer_detect): thread ID is
// not part of the caching key (for the naive executor), so the cache must be
// cleared when one thread finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only the part of the cache
// related to the thread that is being terminated.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  VLOG(4) << tls().get_curr_exec() << " " << ptr;
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_ == 0) {
    VLOG(3) << "Clearing DNNL cache.";
    // If there is no specific executor pointer, clear everything; for a
    // given executor pointer, clear only the objects allocated while using
    // that executor.
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and, for each shape and the active
      // executor, release all entries of this executor.
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
    // Reset paddle layout to NCHW
    VLOG(3) << "Resetting Paddle data layout to NCHW.";
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  } else {
    --block_next_cache_clearing_;
    VLOG(3) << "Prevented Clearing DNNL cache. Updated "
               "block_next_cache_clearing_ : "
            << block_next_cache_clearing_;
    PADDLE_ENFORCE_GE(block_next_cache_clearing_, 0,
                      platform::errors::InvalidArgument(
                          "Cache clearing mark should be non-negative. "
                          "But received %d.",
                          block_next_cache_clearing_));
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take the current input shape and the current executor address from TLS,
  // and add the entry defined by the arguments to this executor's items.
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  ++block_next_cache_clearing_;
  VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
             "block_next_cache_clearing_ : "
          << block_next_cache_clearing_;
}
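// Pairing sketch (illustrative): each BlockNextCacheClearing() call arms one
// skipped clear, and a later ResetBlobMap() consumes it instead of clearing,
// letting nested executors protect a cache that is still in use.
//
//   dev_ctx->BlockNextCacheClearing();  // arm: skip the next clear
//   // ... inner executor runs; its teardown calls ResetBlobMap(...) ...
//   // that call decrements block_next_cache_clearing_ and keeps the cache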

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register the new element in the per-executor map so it can be easily
    // erased when the executor terminates.
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After the first iteration of a model's execution we should
  // have all elements cached (mostly), so failures are unlikely (even less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: shape=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
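// Cache lookup sketch (blob name and pointer are illustrative): entries are
// keyed by session id, then current input shape, then name, so a primitive
// cached via SetBlob is only visible to GetBlob under the same TLS state.
//
//   dev_ctx->SetBlob("conv_fwd_prim", prim_ptr);   // cache for current shape
//   auto hit = dev_ctx->GetBlob("conv_fwd_prim");  // nullptr on miss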

#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : phi::CustomContext(place) {
  Init();
  stream_.reset(new phi::stream::Stream(place, stream()));
}

CustomDeviceContext::~CustomDeviceContext() {}
#endif
}  // namespace platform
}  // namespace paddle