device_context.cc 29.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
W
Wilber 已提交
13
#include <functional>
14
#include <memory>
15
#include <set>
W
Wilber 已提交
16 17
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
18 19
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"
20

21
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
22
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
23
#include "paddle/fluid/platform/cuda_device_guard.h"
24
#endif
F
fwenguang 已提交
25 26 27 28
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
29
#include "glog/logging.h"
30
#include "paddle/fluid/framework/expect.h"
W
Wilber 已提交
31
#include "paddle/fluid/framework/generator.h"
32
#include "paddle/fluid/memory/allocation/allocator_facade.h"
33
#include "paddle/fluid/platform/device/device_wrapper.h"
34
#include "paddle/fluid/platform/profiler.h"
35

36 37 38 39 40
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place reported by `dev_ctx`.
//
// Zero-sized requests, and places other than GPU/XPU/MLU, are forwarded to
// `Alloc(place, size)`. For GPU and MLU places the stream of `dev_ctx` is
// compared with the stream of the pool's context for the same place: when
// they match, `Alloc(place, size)` is used; otherwise the request is served
// by the per-device-context allocator pool.
//
// Throws PermissionDenied when the place is GPU/XPU/MLU but the matching
// backend macro was not defined at compile time.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA, "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU, "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU, "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
96 97 98
namespace paddle {
namespace platform {

99
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
100 101 102
// Flag behind SetAllowTF32Cublas()/AllowTF32Cublas(); starts out true.
bool allow_tf32_cublas = true;

// Overwrites the flag with `active`.
void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

// Returns the current value of the flag.
bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}
A
AshburnLee 已提交
103 104 105 106

// Flag behind SetAllowTF32Cudnn()/AllowTF32Cudnn(); starts out true.
bool allow_tf32_cudnn = true;

// Overwrites the flag with `active`.
void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

// Returns the current value of the flag.
bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
107 108
#endif  // PADDLE_WITH_CUDA

109 110 111 112 113 114 115
// Translates a Place into the matching platform::DeviceType.
// CPU, GPU (reported as CUDA), XPU and MLU places are recognised; every
// other place raises an Unavailable error.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) return platform::DeviceType::CPU;
  if (platform::is_gpu_place(place)) return platform::DeviceType::CUDA;
  if (platform::is_xpu_place(place)) return platform::DeviceType::XPU;
  if (platform::is_mlu_place(place)) return platform::DeviceType::MLU;

  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
124 125
// Out-of-class definition of the static pool pointer; null until assigned
// elsewhere (the assignment is not part of this file section).
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
126
// Looks up the device context registered for `place`.
// The map stores shared futures, so the first `.get()` resolves the future
// and the second unwraps the owning pointer it holds.
// Raises Unimplemented when `place` was never registered in this pool.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  const auto found = device_contexts_.find(place);
  if (found != device_contexts_.end()) {
    return found->second.get().get();
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Place %s is not supported. Please check that your paddle compiles "
      "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
      "or check "
      "that your train process set the correct device id if you use "
      "Executor.",
      place));
}

W
Wilber 已提交
141
template <typename DevCtx>
142 143 144 145 146
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
147 148 149 150 151 152 153 154 155 156 157 158
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // lazy evaluation. i.e., only create device context at
        // first `Get`
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
W
Wilber 已提交
159 160 161 162 163 164
          // Note: A trick method to init context, why GetAllocator interface
          // needs a stream parameter?
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p, cuda_ctx->stream())
                                    .get());
          cuda_ctx->PartialInitWithAllocator();
W
Wilber 已提交
165 166
          dev_ctx->SetGenerator(
              framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
167 168
#endif
        } else {
W
Wilber 已提交
169 170 171
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
W
Wilber 已提交
172
          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
173 174 175 176 177 178 179 180 181 182 183
        }
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
C
chengduozh 已提交
184 185
}

D
dzhwinter 已提交
186 187
// Builds the pool from `places`.
// Duplicates are dropped first; each remaining place then gets a lazily
// created context of the type matching its kind. A place whose backend was
// not compiled in raises Unimplemented. A place matching none of the kinds
// below is silently ignored.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Collapse repeated places so each one is registered exactly once.
  std::set<Place> unique_places(places.begin(), places.end());
  for (const auto& place : unique_places) {
    if (platform::is_cpu_place(place)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, place);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, place);
#endif
    } else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(&device_contexts_, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}

274 275
// Default-constructs the phi::CPUContext base and then runs its Init().
CPUDeviceContext::CPUDeviceContext()
    : phi::CPUContext() {
  phi::CPUContext::Init();
}
277

278 279
// Constructs the phi::CPUContext base for `place` and then runs its Init().
CPUDeviceContext::CPUDeviceContext(CPUPlace place)
    : phi::CPUContext(place) {
  phi::CPUContext::Init();
}
281

J
jianghaicheng 已提交
282
#ifdef PADDLE_WITH_IPU
A
Allen Guo 已提交
283
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}
J
jianghaicheng 已提交
284

W
Wilber 已提交
285
const Place& IPUDeviceContext::GetPlace() const { return place_; }
A
Allen Guo 已提交
286

J
jianghaicheng 已提交
287 288 289 290 291 292 293
void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
294
#ifdef PADDLE_WITH_XPU
295 296
// Default-constructs the phi::XPUContext base and then runs its Init().
XPUDeviceContext::XPUDeviceContext()
    : phi::XPUContext() {
  phi::XPUContext::Init();
}
298

299
// Nothing to release here; base-class cleanup runs as usual.
XPUDeviceContext::~XPUDeviceContext() {
}
300

301 302
// Constructs the phi::XPUContext base for `place`, runs its Init(), and
// logs the device index once (first call only) at WARNING level.
XPUDeviceContext::XPUDeviceContext(XPUPlace place)
    : phi::XPUContext(place) {
  phi::XPUContext::Init();
  const int device_index = static_cast<int>(place.device);
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << device_index;
}
#endif

308 309 310 311 312 313 314
#ifdef PADDLE_WITH_ASCEND_CL
// Binds the context to `place`: while a device guard for that device is
// active, it records the current NPU context and creates the stream.
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard device_guard(place_.device);
  // NOTE(zhiqiu): an explicit
  //   PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // is usually unnecessary. ACL creates a default context, holding one
  // default stream and one sync stream, after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

// Intentionally empty. The explicit teardown below is kept for reference
// but left disabled:
//   NPUDeviceGuard guard(place_.device);
//   PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
NPUDeviceContext::~NPUDeviceContext() {
}
323

324
void NPUDeviceContext::Wait() const {
325 326 327
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
328 329 330 331
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

W
Wilber 已提交
332
const Place& NPUDeviceContext::GetPlace() const { return place_; }
333 334

aclrtContext NPUDeviceContext::context() const { return context_; }
335 336 337 338 339 340 341 342 343 344 345 346 347 348

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

W
Wilber 已提交
349
const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }
350

351 352 353
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
354 355 356 357 358 359 360
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

361
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
362 363 364 365 366
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

367
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
368

369 370 371
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
372
  const cudaDeviceProp& deviceProperties() const override {
373
#endif
Q
init  
qijun 已提交
374 375 376 377
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
378 379 380
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
381 382 383
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
384
    void* retv = buf->ptr();
S
sneaxiy 已提交
385 386 387 388
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
389
    return retv;
Q
init  
qijun 已提交
390 391
  }

S
sneaxiy 已提交
392 393 394 395 396 397
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
398 399 400

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
401
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
402 403 404 405 406 407
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
408
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
409
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
410
#ifdef PADDLE_WITH_HIP
411
      PADDLE_ENFORCE_GPU_SUCCESS(
412 413
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
414
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
415
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
416
#endif
Q
init  
qijun 已提交
417 418 419 420 421
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
422
  CUDAPlace place_;
423 424 425 426
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
427
  const cudaDeviceProp* device_prop_;  // not owned;
428
#endif
Q
qijun 已提交
429
  mutable void* scratch_;
Q
init  
qijun 已提交
430
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
431
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
432
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
433 434
};

435 436 437 438 439 440 441 442 443
// Grows the workspace to at least `required_workspace_bytes`.
// A request that already fits in WorkspaceSize() is a no-op.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Drop the old buffer before asking for the new one so the two are not
    // held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

444 445 446 447 448 449 450 451 452 453 454 455
// Out-of-class definitions of CUDADeviceContext's thread_local statics.
// thread_ctx_ maps a device context to a CUDAContext for the current thread;
// the accessors in this file consult it before falling back to
// phi::GPUContext.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
// NOTE(review): this mutex is itself thread_local, so each thread gets its
// own instance — confirm that this is intended.
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// (Re)creates the Eigen stream wrapper on top of RawStream() for place_,
// then builds the Eigen::GpuDevice that uses it.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);

  auto* stream_interface = eigen_stream_.get();
  eigen_device_.reset(new Eigen::GpuDevice(stream_interface));
}

// Creates a stream on `place` with the given priority and flag, then sets up
// the Eigen, cuBLAS and cuDNN contexts (plus cuSPARSE and cuSolver on
// non-HIP builds). A device guard for the place is held throughout.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard device_guard(place_.device);

  stream_.reset(new stream::CUDAStream(place, priority, flag));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489
// Switches this context to `stream`. Nothing happens when `stream` is
// already the current raw stream.
//
// The library contexts are torn down, the stream is replaced, and the
// contexts are rebuilt. The cuSPARSE context is included so that the set of
// contexts handled here matches what the constructor creates and the
// destructor destroys; previously it was the only one left untouched by a
// stream change.
// NOTE(review): assumes InitCuSparseContext() derives its state from the
// current stream like the other Init* helpers — confirm against its
// definition.
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() == stream) {
    return;
  }

  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif

  stream_->SetStream(stream);

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

490 491 492 493
// Destroys the cuDNN and cuBLAS contexts (plus cuSPARSE and cuSolver on
// non-HIP builds) while a device guard for place_ is active.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard device_guard(place_.device);

  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

500 501 502 503
// Constructs the phi::GPUContext base for `place`, runs the part of its
// initialisation that needs no allocator, wraps the base stream in a
// stream::CUDAStream, and creates the DNN workspace handle on the allocator
// that the facade returns for (place, base stream).
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();

  cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));

  workspace_.reset(new phi::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(place, phi::GPUContext::stream())
          .get()));
}

W
Wilber 已提交
509
// Defaulted here, out of line, rather than in the class definition.
CUDADeviceContext::~CUDADeviceContext() = default;
510

511
// Eigen device of the current thread's CUDAContext when one is registered
// for this object; otherwise the phi::GPUContext one.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (!thread_ctx_.count(this)) {
    return phi::GPUContext::eigen_device();
  }
  return context()->EigenDevice().get();
}

W
Wilber 已提交
518 519 520 521 522
void CUDADeviceContext::Wait() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
523
  phi::GPUContext::Wait();
524 525
}

526 527 528
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
529
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
530
#endif
W
Wilber 已提交
531 532 533
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
534
  return phi::GPUContext::cudnn_handle();
535
}
536

537 538
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
W
Wilber 已提交
539 540 541
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
542
  return phi::GPUContext::cublas_handle();
543 544
}
#else
545
cublasHandle_t CUDADeviceContext::cublas_handle() const {
W
Wilber 已提交
546 547 548
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
549
  return phi::GPUContext::cublas_handle();
550
}
Z
zhangkaihuo 已提交
551
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
W
Wilber 已提交
552 553 554
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
555
  return phi::GPUContext::cusparse_handle();
W
Wilber 已提交
556 557 558 559 560
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
561
  return phi::GPUContext::cusolver_dn_handle();
Z
zhangkaihuo 已提交
562
}
563
#endif
564

W
Wilber 已提交
565 566 567 568 569 570
// Forwards (ev, callback) to the current thread's CUDAContext stream when
// one is registered for this object; otherwise to phi::GPUContext.
void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (!thread_ctx_.count(this)) {
    phi::GPUContext::RecordEvent(ev, callback);
    return;
  }
  context()->Stream()->RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
580
  phi::GPUContext::AddStreamCallback(callback);
W
Wilber 已提交
581 582 583 584 585 586 587
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
588
  phi::GPUContext::WaitStreamCallback();
W
Wilber 已提交
589 590
}

591
// Without a thread-local CUDAContext this is phi::GPUContext's workspace
// handle. With one, a fresh handle is built on the allocator the facade
// returns for (GetPlace(), phi::GPUContext::stream()).
phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (!thread_ctx_.count(this)) {
    return phi::GPUContext::cudnn_workspace_handle();
  }
  // Disabled alternative kept from the original: return workspace_.get();
  return phi::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(GetPlace(), phi::GPUContext::stream())
          .get());
}
601

W
Wilber 已提交
602 603 604 605
// Raw stream of the current thread's CUDAContext when one is registered for
// this object; otherwise the phi::GPUContext stream.
gpuStream_t CUDADeviceContext::stream() const {
  if (!thread_ctx_.count(this)) {
    return phi::GPUContext::stream();
  }
  return context()->RawStream();
}

W
Wilber 已提交
609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
// Returns the CUDAContext registered for this object in the calling thread's
// thread_ctx_ map. Raises PermissionDenied when there is none.
std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  const auto entry = thread_ctx_.find(this);
  if (entry == thread_ctx_.end()) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return entry->second;
}

// Non-owning pointer to the stream wrapper held by this context.
stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  stream::CUDAStream* current = cuda_stream_.get();
  return current;
}

// Installs `new_stream_ptr` as this context's stream wrapper and returns the
// previous one. The previous wrapper is released, not destroyed, so the
// caller receives it as a raw pointer and becomes responsible for it.
stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  stream::CUDAStream* previous = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return previous;
}
Q
qijun 已提交
628

C
chengduoZH 已提交
629 630 631 632 633 634 635 636 637 638 639 640 641
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

W
Wilber 已提交
642
const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
L
Luo Tao 已提交
643
#endif
Q
qijun 已提交
644

T
tensor-tang 已提交
645 646
#ifdef PADDLE_WITH_MKLDNN
// Builds the CPU base context for `place` and allocates the three pieces of
// cache state: the blob map, the per-executor item map and the mutex that
// the cache methods lock.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

653
// Creates a CPU engine (index 0) with a stream on it, then sets the
// remaining fields to their defaults: default session id, empty input-shape
// string, shape-cache capacity 1 and NCHW layout.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0),
      cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

661 662 663 664 665 666 667 668 669 670 671 672
// When Thread finish we clear oneDNN cache
// This is needed when we have one executor used by many threads
// e.g. test_analyzer_detect. Thread ID is not part of caching key
// (for naive executor) so we need to clear cache when one thread finish
// and other is to start inference
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
// Asks the pool's CPU device context to drop the cache entries linked to
// this thread's executor pointer (exec_ptr_).
// NOTE(review): the downcast assumes the pool's CPU context is an
// MKLDNNDeviceContext; the pool constructor only registers that type when
// PADDLE_WITH_MKLDNN is defined.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // Named cast instead of a C-style cast: greppable, and it cannot silently
  // degrade into a reinterpret_cast if the class hierarchy ever changes.
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

676 677 678 679 680 681 682 683 684 685
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
686 687
  cur_input_shape_str = input_shape_str;
}
688 689
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
690 691
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
692

693 694
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
695 696 697
  cur_paddle_data_layout = dl;
}

698 699
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
700 701 702
  return cur_paddle_data_layout;
}

703 704 705 706 707 708 709 710 711
// Logs the oneDNN library version at INFO level the first time it is called
// on this Body; later calls do nothing.
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (said_once) {
    return;
  }
  said_once = true;
  auto version = dnnl::version();
  LOG(INFO) << "oneDNN v" << version->major << "." << version->minor << "."
            << version->patch;
}

712
// Read-only access to the engine created in the constructor.
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return this->cur_engine;
}

// Mutable access to the stream created on that engine.
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return this->cur_stream;
}

720
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
721 722 723
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
724 725 726 727 728 729
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
730 731 732 733 734
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
735
          (v.first)->erase(v.second);
736 737
        }
        s.second->erase(ptr);
738 739
      }
    }
740 741 742 743 744 745
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

746 747
// Drops the first entry of the per-shape executor map, if there is one.
// Takes no lock itself; the caller in this file (SetBlob) already holds the
// cache mutex when it calls this.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  // erase(begin()) on an empty container is undefined behaviour, so an
  // empty map is treated as "nothing to remove".
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

750 751
// Records that blob entry `it` of `pblob` belongs to the current executor,
// so ResetBlobMap(executor) can erase it later. The record is filed under
// the current thread's input-shape string and executor pointer; a new
// per-shape map is created on first use.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // insert() keeps an existing per-shape map and only adds a fresh one when
  // the shape has not been seen before.
  auto shape_slot =
      p_exec_items_
          ->insert(std::make_pair(tls().cur_input_shape_str,
                                  std::make_shared<ExecMap>()))
          .first;
  // Items of the current executor (address taken from TLS).
  auto& executor_items = (*shape_slot->second)[tls().get_curr_exec()];
  executor_items.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << executor_items.size() << "\n";
}

766 767 768 769
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
770
}
771

772
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
773
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
774
  BlobMap* pMap = p_blobmap_.get();
775
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
776
  if (map_it == pMap->end()) {
777 778 779
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
780 781 782 783
  }
  return map_it->second->size();
}

784
void MKLDNNDeviceContext::SetBlob(const std::string& name,
785
                                  BlobPtr_t<void> data) const {
786
  BlobMap* pMap = p_blobmap_.get();
787
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
788
  BlobPtr_t<KeyBlob> pBlob = nullptr;
789

790
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
791

792
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
793

794 795
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
796 797 798

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
799
    sBlob = std::make_shared<ShapeBlob>();
800 801
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
802
  } else {
803
    sBlob = map_it->second;
804
  }
T
tensor-tang 已提交
805

806
  // Find KeyBlob for current input shape
807
  auto key_it = sBlob->find(tls().cur_input_shape_str);
808

809
  if (key_it == sBlob->end()) {
810 811
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
812 813
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
814
        sBlob->size() &&
815
        (sBlob->size() >=
816
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
817 818 819 820
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
821
    }
822
    pBlob = std::make_shared<KeyBlob>();
823
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
824
  } else {
825
    pBlob = key_it->second;
826 827
  }

828
  // Find Blob via name
829 830 831 832
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
833 834 835
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
836 837 838
  } else {
    blob_it->second = data;  // set data to existing blob
  }
839
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
840
  // lock will be automatically released when out of scope
841
  return;
T
tensor-tang 已提交
842 843
}

844
// Total number of named blobs across all sessions and input shapes.
// Note that the cache mutex is not taken here.
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int total = 0;
  for (auto const& session_entry : *p_blobmap_) {
    auto const& shapes = *(session_entry.second);
    for (auto const& shape_entry : shapes) {
      total += (shape_entry.second)->size();
    }
  }
  return total;
}

854
// Looks `name` up in the three-level cache
//   session id -> input-shape string -> blob name
// and returns the stored blob, or nullptr as soon as any level misses.
// The lookup runs under the cache mutex.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* blob_map = p_blobmap_.get();
  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Level 1: shape map of the current session id.
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  auto session_it = blob_map->find(sid);
  if (unlikely(session_it == blob_map->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  BlobPtr_t<ShapeBlob> shape_blob = session_it->second;

  // Level 2: key map of the current input shape.
  auto shape_it = shape_blob->find(tls().cur_input_shape_str);
  if (unlikely(shape_it == shape_blob->end())) {
    VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  BlobPtr_t<KeyBlob> key_blob = shape_it->second;

  // Level 3: the named blob itself.
  auto blob_it = key_blob->find(name);
  if (unlikely(blob_it == key_blob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // The lock is released when `lock` goes out of scope.
  return blob_it->second;
}

897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Stores `place` and, with a device guard for it active, creates and
// initialises the stream on that place.
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : place_(place) {
  DeviceGuard device_guard(place_);
  stream_.reset(new stream::Stream());
  stream_->Init(place_);
}

// Nothing to release explicitly.
CustomDeviceContext::~CustomDeviceContext() {
}

// Place stored at construction.
const Place& CustomDeviceContext::GetPlace() const {
  return place_;
}

// Waits on this context's stream; logged at VLOG level 4.
void CustomDeviceContext::Wait() const {
  // Profiler event intentionally left disabled:
  //   platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "CustomDevice context(" << this << ")  Wait";
  stream_->Wait();
}
T
tensor-tang 已提交
915
#endif
Q
qijun 已提交
916
}  // namespace platform
Q
qijun 已提交
917
}  // namespace paddle