device_context.cc 30.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
14
#include "paddle/fluid/platform/device_context.h"
W
Wilber 已提交
15
#include <functional>
16
#include <memory>
17
#include <set>
W
Wilber 已提交
18 19
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
20 21
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"
22

23
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
24
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
25
#include "paddle/fluid/platform/cuda_device_guard.h"
26
#endif
F
fwenguang 已提交
27 28 29 30
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
31
#include "glog/logging.h"
32
#include "paddle/fluid/framework/expect.h"
W
Wilber 已提交
33
#include "paddle/fluid/framework/generator.h"
34
#include "paddle/fluid/memory/allocation/allocator_facade.h"
35
#include "paddle/fluid/platform/device/device_wrapper.h"
36
#include "paddle/fluid/platform/profiler.h"
37
#include "paddle/fluid/platform/profiler/event_tracing.h"
38

39 40 41 42 43
namespace paddle {
namespace memory {

/*!
 * Allocates `size` bytes for use with `dev_ctx`.
 *
 * - Zero-size requests, and places other than GPU/XPU/MLU, go straight to the
 *   place-based `Alloc(place, size)` overload.
 * - GPU and MLU: if `dev_ctx` shares its stream with the pool's default
 *   context for the same place, the place-based overload is used; otherwise
 *   the request is served by the per-context allocator pool for that backend.
 * - XPU: always uses the place-based overload.
 *
 * Throws PermissionDenied when the place's backend was not compiled in.
 */
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      // Non-default stream: use the allocator pool keyed by this context.
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA, "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU, "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      // Non-default stream: use the allocator pool keyed by this context.
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU, "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
99 100 101
namespace paddle {
namespace platform {

102
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
103 104 105
// Runtime switches controlling whether cuBLAS / cuDNN may use TF32 math.
// Both default to enabled.
bool allow_tf32_cublas = true;

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

bool allow_tf32_cudnn = true;

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
110 111
#endif  // PADDLE_WITH_CUDA

112 113 114 115 116 117 118
/*!
 * Maps a Place onto its platform::DeviceType.
 * Only CPU, GPU (reported as CUDA), XPU and MLU places are recognised; any
 * other place raises an Unavailable error.
 */
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) return platform::DeviceType::CPU;
  if (platform::is_gpu_place(place)) return platform::DeviceType::CUDA;
  if (platform::is_xpu_place(place)) return platform::DeviceType::XPU;
  if (platform::is_mlu_place(place)) return platform::DeviceType::MLU;
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
127 128
// Storage for the DeviceContextPool singleton pointer; starts out null.
// NOTE(review): presumably assigned by the pool's Init/Instance logic, which
// is not visible in this section — confirm.
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
129
/*!
 * Returns the device context registered for `place`.
 *
 * The stored value is a shared future; resolving it builds the context on
 * first access. Throws Unimplemented if `place` was never registered.
 */
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  const auto entry = device_contexts_.find(place);
  if (entry == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
        "or check "
        "that your train process set the correct device id if you use "
        "Executor.",
        place));
  }
  // shared_future::get() -> unique_ptr, then the raw (non-owning) pointer.
  const auto& owned_ctx = entry->second.get();
  return owned_ctx.get();
}

W
Wilber 已提交
144
template <typename DevCtx>
145 146 147 148 149
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
150 151 152 153 154 155 156 157 158 159 160 161
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // lazy evaluation. i.e., only create device context at
        // first `Get`
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
W
Wilber 已提交
162
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
163
                                    .GetAllocator(p)
W
Wilber 已提交
164
                                    .get());
W
wanghuancoder 已提交
165 166 167 168 169
          dev_ctx->SetPinnedAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(paddle::platform::CUDAPinnedPlace())
                  .get());

W
Wilber 已提交
170
          cuda_ctx->PartialInitWithAllocator();
W
Wilber 已提交
171
          dev_ctx->SetGenerator(
172
              framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
173 174
#endif
        } else {
W
Wilber 已提交
175 176 177
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
W
Wilber 已提交
178
          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
179
        }
L
Leo Chen 已提交
180
        dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
181 182 183 184 185 186 187 188 189 190
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
C
chengduozh 已提交
191 192
}

D
dzhwinter 已提交
193 194
/*!
 * Builds the pool for `places`.
 *
 * Duplicate places are collapsed. Each distinct place is registered with a
 * lazily-constructed context of the matching type via EmplaceDeviceContext.
 * Registering a place whose backend was not compiled in throws Unimplemented.
 * Places matching none of the branches below are skipped.
 *
 * Throws InvalidArgument if `places` is empty.
 */
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Collapse duplicates so every place is registered exactly once.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      // Report the place that was actually requested (pinned, not CUDAPlace).
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(p)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}

281 282
// Both constructors delegate construction to phi::CPUContext and then run its
// Init() step; they differ only in whether an explicit CPUPlace is supplied.
CPUDeviceContext::CPUDeviceContext() : phi::CPUContext() {
  phi::CPUContext::Init();
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place)
    : phi::CPUContext(place) {
  phi::CPUContext::Init();
}
288

J
jianghaicheng 已提交
289
#ifdef PADDLE_WITH_IPU
A
Allen Guo 已提交
290
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}
J
jianghaicheng 已提交
291

W
Wilber 已提交
292
const Place& IPUDeviceContext::GetPlace() const { return place_; }
A
Allen Guo 已提交
293

J
jianghaicheng 已提交
294 295 296 297 298 299 300
void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
301
#ifdef PADDLE_WITH_XPU
302 303
// XPU device contexts delegate construction to phi::XPUContext and then run
// its Init() step.
XPUDeviceContext::XPUDeviceContext() : phi::XPUContext() {
  phi::XPUContext::Init();
}

XPUDeviceContext::XPUDeviceContext(XPUPlace place)
    : phi::XPUContext(place) {
  phi::XPUContext::Init();
  // LOG_FIRST_N(..., 1): only the first execution of this statement logs.
  const int device_id = static_cast<int>(place.device);
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << device_id;
}

XPUDeviceContext::~XPUDeviceContext() {}
#endif

315 316 317 318 319 320 321
#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync strean after aclrtSetDevice.
322
  platform::GetCurrentNPUContext(&context_);
323 324 325 326 327 328 329
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}
330

331
void NPUDeviceContext::Wait() const {
332 333
  platform::RecordEvent record_event("NPUDeviceContext/wait",
                                     platform::TracerEventType::UserDefined, 2);
334 335
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
336 337 338 339
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

W
Wilber 已提交
340
const Place& NPUDeviceContext::GetPlace() const { return place_; }
341 342

aclrtContext NPUDeviceContext::context() const { return context_; }
343 344 345 346 347 348 349 350 351 352 353 354 355 356

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

W
Wilber 已提交
357
const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }
358

359 360 361
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
362 363 364 365 366 367 368
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

369
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
370 371 372 373 374
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

375
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
376

377 378 379
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
380
  const cudaDeviceProp& deviceProperties() const override {
381
#endif
Q
init  
qijun 已提交
382 383 384 385
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
386 387 388
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
389 390 391
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
392
    void* retv = buf->ptr();
S
sneaxiy 已提交
393 394 395 396
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
397
    return retv;
Q
init  
qijun 已提交
398 399
  }

S
sneaxiy 已提交
400 401 402 403 404 405
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
406 407 408

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
409
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
410 411 412 413 414 415
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
416
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
417
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
418
#ifdef PADDLE_WITH_HIP
419
      PADDLE_ENFORCE_GPU_SUCCESS(
420 421
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
422
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
423
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
424
#endif
Q
init  
qijun 已提交
425 426 427 428 429
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
430
  CUDAPlace place_;
431 432 433 434
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
435
  const cudaDeviceProp* device_prop_;  // not owned;
436
#endif
Q
qijun 已提交
437
  mutable void* scratch_;
Q
init  
qijun 已提交
438
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
439
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
440
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
441 442
};

443 444 445 446 447 448 449 450 451
// Grows the workspace to at least `required_workspace_bytes`; it never
// shrinks. The old buffer is dropped before the new one is requested, so both
// are never held at the same time.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

452 453 454 455 456 457 458 459 460 461 462 463
// Per-thread registry: CUDADeviceContext -> CUDAContext overriding it on the
// current thread. The CUDADeviceContext accessors in this file consult it via
// thread_ctx_.count(this) and fall back to phi::GPUContext when absent.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
// Per-thread mutex declared with thread_ctx_; its users are not visible in
// this section.
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// (Re)creates the Eigen stream adapter on this context's raw stream and
// place_, then wraps it in a fresh Eigen::GpuDevice.
void CUDAContext::InitEigenContext() {
  auto* stream_device = new EigenCudaStreamDevice();
  eigen_stream_.reset(stream_device);
  stream_device->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(stream_device));
}

/*!
 * Creates a new stream on `place` with the given priority/flags and builds
 * the Eigen device plus every library handle on it. All of this happens with
 * `place`'s device made current.
 */
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard device_guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  InitCuBlasLtContext();
#endif
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
481 482 483 484 485 486
/*!
 * Rebinds this context to `stream`.
 *
 * If `stream` differs from the current raw stream, every library handle
 * created by the constructor is destroyed, the stream is swapped, and the
 * Eigen device and all handles are rebuilt. Rebinding to the current stream
 * is a no-op.
 */
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() == stream) {
    return;
  }
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  DestoryCuBlasLtContext();
#endif
  // The constructor creates a cuSPARSE handle alongside the others; rebuild
  // it too. NOTE(review): presumably it is bound to the stream it was created
  // on, like the other handles.
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif

  stream_->SetStream(stream);

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  InitCuBlasLtContext();
#endif
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

507 508 509 510
/*!
 * Releases every library handle created by the constructor, with this
 * context's device made current.
 */
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
  // Must destroy (not re-initialize) the cuBLASLt handle on teardown.
  DestoryCuBlasLtContext();
#endif
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

520 521 522
/*!
 * Initializes the phi GPU context (without allocator) and then:
 * - wraps phi's stream in a stream::CUDAStream (cuda_stream_),
 * - registers that stream as the allocator facade's default for `place`,
 * - creates the DNN workspace handle on `place`'s allocator.
 */
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();
  cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));

  auto& facade = memory::allocation::AllocatorFacade::Instance();
  facade.SetDefaultStream(place, phi::GPUContext::stream());

  auto* place_allocator = facade.GetAllocator(place).get();
  workspace_.reset(new phi::DnnWorkspaceHandle(place_allocator, stream()));
}

W
Wilber 已提交
529
// Defaulted out of line; owned members such as cuda_stream_ and workspace_
// are released by their own destructors.
CUDADeviceContext::~CUDADeviceContext() = default;
530

531
// Eigen device of the thread-local CUDAContext when one is registered for
// this context on the calling thread; otherwise phi's default.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (!thread_ctx_.count(this)) {
    return phi::GPUContext::eigen_device();
  }
  return context()->EigenDevice().get();
}

W
Wilber 已提交
538
void CUDADeviceContext::Wait() const {
539
  VLOG(4) << "CUDA context(" << this << ")  Wait";
W
Wilber 已提交
540 541 542 543
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
544
  phi::GPUContext::Wait();
545 546
}

547 548 549
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
550
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
551
#endif
W
Wilber 已提交
552 553 554
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
555
  return phi::GPUContext::cudnn_handle();
556
}
557

558 559
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
W
Wilber 已提交
560 561 562
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
563
  return phi::GPUContext::cublas_handle();
564 565
}
#else
566
cublasHandle_t CUDADeviceContext::cublas_handle() const {
W
Wilber 已提交
567 568 569
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
570
  return phi::GPUContext::cublas_handle();
571
}
572 573 574 575 576 577 578 579
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasLtHandle()->GetCublasLtHandle();
  }
  return phi::GPUContext::cublaslt_handle();
}
#endif
Z
zhangkaihuo 已提交
580
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
W
Wilber 已提交
581 582 583
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
584
  return phi::GPUContext::cusparse_handle();
W
Wilber 已提交
585 586 587 588 589
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
590
  return phi::GPUContext::cusolver_dn_handle();
Z
zhangkaihuo 已提交
591
}
592
#endif
593

W
Wilber 已提交
594 595 596 597 598 599
// Records `ev` (with `callback`) on the thread-local stream if registered,
// otherwise forwards to phi.
void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (!thread_ctx_.count(this)) {
    phi::GPUContext::RecordEvent(ev, callback);
    return;
  }
  context()->Stream()->RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
609
  phi::GPUContext::AddStreamCallback(callback);
W
Wilber 已提交
610 611 612 613 614 615 616
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
617
  phi::GPUContext::WaitStreamCallback();
W
Wilber 已提交
618 619
}

620
// Without a thread-local override, returns phi's workspace handle. With one,
// builds a fresh handle on this place's allocator and the thread-local stream.
phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (!thread_ctx_.count(this)) {
    return phi::GPUContext::cudnn_workspace_handle();
  }
  return phi::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(GetPlace())
          .get(),
      stream());
}
631

W
Wilber 已提交
632 633 634 635
// Raw stream of the thread-local CUDAContext if registered, else phi's.
gpuStream_t CUDADeviceContext::stream() const {
  if (!thread_ctx_.count(this)) {
    return phi::GPUContext::stream();
  }
  return context()->RawStream();
}

W
Wilber 已提交
639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
// Returns the CUDAContext registered for this object on the calling thread.
// Throws PermissionDenied when no thread-local context is registered.
std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  const auto found = thread_ctx_.find(this);
  if (found == thread_ctx_.end()) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return found->second;
}

// Non-owning view of the wrapped stream.
stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

// Takes ownership of `new_stream_ptr` and hands ownership of the previous
// stream back to the caller.
stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  stream::CUDAStream* previous = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return previous;
}
Q
qijun 已提交
658

C
chengduoZH 已提交
659 660 661 662 663 664 665 666 667 668 669 670 671
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

W
Wilber 已提交
672
const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
L
Luo Tao 已提交
673
#endif
Q
qijun 已提交
674

T
tensor-tang 已提交
675 676
#ifdef PADDLE_WITH_MKLDNN
// CPU context plus the oneDNN cache state: the blob map, the per-executor
// bookkeeping and the mutex guarding them.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

683
// Per-thread oneDNN state: a CPU engine (index 0) with a stream on it, plus
// default session id, shape-cache settings and NCHW data layout.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str = "";
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

691 692 693 694 695 696 697 698 699 700 701 702
// When a thread finishes, clear the oneDNN cache.
// This is needed when one executor is used by many threads
// (e.g. test_analyzer_detect). The thread ID is not part of the caching key
// (for the naive executor), so the cache must be cleared when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // With MKLDNN enabled the pool registers an MKLDNNDeviceContext for CPU
  // places, so this downcast is expected to be valid.
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

706 707 708 709 710 711 712 713 714 715
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
716 717
  cur_input_shape_str = input_shape_str;
}
718 719
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
720 721
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
722

723 724
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
725 726 727
  cur_paddle_data_layout = dl;
}

728 729
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
730 731 732
  return cur_paddle_data_layout;
}

733 734 735 736 737 738 739 740 741
// Logs the oneDNN library version the first time it is called on this Body.
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (said_once) {
    return;
  }
  said_once = true;
  const auto dv = dnnl::version();
  LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
            << dv->patch;
}

742
// Per-thread oneDNN engine (CPU) and the stream created on it.
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return this->cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return this->cur_stream;
}

750
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
L
Leo Chen 已提交
751
  VLOG(4) << tls().get_curr_exec() << " " << ptr;
752 753 754
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
755 756 757 758 759 760
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
761 762 763 764 765
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
766
          (v.first)->erase(v.second);
767 768
        }
        s.second->erase(ptr);
769 770
      }
    }
771 772 773 774 775 776
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

777 778
// Drops the first shape entry (with all of its per-executor links) from the
// executor bookkeeping map. SetBlob calls this when the per-session shape
// cache reaches capacity.
// NOTE(review): SetBlob evicts sBlob->begin() from the blob map while this
// evicts p_exec_items_->begin(); whether both name the same shape depends on
// container types not visible here — confirm.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  // erase(begin()) on an empty container is undefined behavior.
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

781 782
// Records (pblob, it) under the current input shape and current executor,
// both taken from TLS, so that ResetBlobMap can later erase exactly the
// entries created while that executor was active.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  auto shape_it = p_exec_items_
                      ->insert(std::make_pair(tls().cur_input_shape_str,
                                              std::make_shared<ExecMap>()))
                      .first;
  auto& exec_entries = (*shape_it->second)[tls().get_curr_exec()];
  exec_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << exec_entries.size() << "\n";
}

797 798 799 800
// Makes the next ResetBlobMap call a no-op (it resets this flag instead).
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> guard(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
802

803
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
804
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
805
  BlobMap* pMap = p_blobmap_.get();
806
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
807
  if (map_it == pMap->end()) {
808 809 810
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
811 812 813 814
  }
  return map_it->second->size();
}

815
void MKLDNNDeviceContext::SetBlob(const std::string& name,
816
                                  BlobPtr_t<void> data) const {
817
  BlobMap* pMap = p_blobmap_.get();
818
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
819
  BlobPtr_t<KeyBlob> pBlob = nullptr;
820

821
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
822

823
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
824

825 826
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
827 828 829

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
830
    sBlob = std::make_shared<ShapeBlob>();
831 832
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
833
  } else {
834
    sBlob = map_it->second;
835
  }
T
tensor-tang 已提交
836

837
  // Find KeyBlob for current input shape
838
  auto key_it = sBlob->find(tls().cur_input_shape_str);
839

840
  if (key_it == sBlob->end()) {
841 842
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
843 844
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
845
        sBlob->size() &&
846
        (sBlob->size() >=
847
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
848 849 850 851
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
852
    }
853
    pBlob = std::make_shared<KeyBlob>();
854
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
855
  } else {
856
    pBlob = key_it->second;
857 858
  }

859
  // Find Blob via name
860 861 862 863
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
864 865 866
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
867 868 869
  } else {
    blob_it->second = data;  // set data to existing blob
  }
870
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
871
  // lock will be automatically released when out of scope
872
  return;
T
tensor-tang 已提交
873 874
}

875
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
876 877 878
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
879
      num_entries += (l2.second)->size();
880 881 882 883 884
    }
  }
  return num_entries;
}

885
// Looks up the blob cached under `name` for the current thread's MKLDNN
// session id and current input shape (both read from thread-local state).
// Returns nullptr on any miss: unknown session, unknown shape or unknown
// name. The lookup runs under p_mutex_.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  // Bind by reference: the maps are not modified while the lock is held, so
  // copying the shared_ptr (an atomic refcount round-trip) is unnecessary.
  const BlobPtr_t<ShapeBlob>& sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    // Report the real session id alongside the shape string that missed.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  const BlobPtr_t<KeyBlob>& pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

928 929 930
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
931 932 933
// Builds the phi::CustomContext for `place`, runs Init(), and then wraps the
// stream returned by stream() in a phi::stream::Stream held by stream_.
// Init() runs first, presumably because stream() depends on it — TODO confirm.
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : phi::CustomContext(place) {
  Init();
  stream_.reset(new phi::stream::Stream(place, stream()));
}

// Nothing to release explicitly; members and the base clean up themselves.
CustomDeviceContext::~CustomDeviceContext() = default;
T
tensor-tang 已提交
938
#endif
Q
qijun 已提交
939
}  // namespace platform
Q
qijun 已提交
940
}  // namespace paddle