device_context.cc 28.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
W
Wilber 已提交
13
#include <functional>
14
#include <memory>
15
#include <set>
W
Wilber 已提交
16 17 18
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
W
Wilber 已提交
19
#include "paddle/pten/core/allocator.h"
20

21
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
22
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
23
#include "paddle/fluid/platform/cuda_device_guard.h"
24
#endif
F
fwenguang 已提交
25 26 27 28
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
29
#include "glog/logging.h"
30
#include "paddle/fluid/framework/expect.h"
W
Wilber 已提交
31
#include "paddle/fluid/framework/generator.h"
32
#include "paddle/fluid/memory/allocation/allocator_facade.h"
33
#include "paddle/fluid/platform/profiler.h"
34

35 36 37 38 39
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place that `dev_ctx` reports via GetPlace().
//
// Zero-sized requests and places other than GPU/XPU/MLU are forwarded to the
// place-based Alloc overload. For GPU and MLU places the stream of `dev_ctx`
// is compared with the stream of the pool's default context for that place:
// equal streams use the place-based overload, otherwise the request goes to
// the per-device-context allocator pool. XPU currently always uses the
// place-based overload.
//
// Throws PermissionDenied when the place belongs to a backend this build was
// not compiled with.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    // Adjacent literals are concatenated, so the separating space must be
    // written explicitly.
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA, "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU, "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU, "
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
95 96 97
namespace paddle {
namespace platform {

98
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
99 100 101
// Global switch read by AllowTF32Cublas(); starts out enabled.
bool allow_tf32_cublas = true;

// Overwrites the cuBLAS TF32 switch with `active`.
void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

// Returns the current value of the cuBLAS TF32 switch.
bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}
A
AshburnLee 已提交
102 103 104 105

// Global switch read by AllowTF32Cudnn(); starts out enabled.
bool allow_tf32_cudnn = true;

// Overwrites the cuDNN TF32 switch with `active`.
void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

// Returns the current value of the cuDNN TF32 switch.
bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
106 107
#endif  // PADDLE_WITH_CUDA

108 109 110 111 112 113 114
// Maps a place to the matching platform::DeviceType.
// CPU, GPU, XPU and MLU places are recognised; any other place raises
// an Unavailable error.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  }
  if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  }
  if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  }
  if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  }
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
123 124
// Definition of the static `pool` pointer; it starts out null.
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
125
// Returns the device context registered for `place`.
// The stored value is a shared_future, so the first call for a place runs
// the deferred construction and later calls reuse its result.
// Raises Unimplemented when no context was registered for `place`.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  const auto found = device_contexts_.find(place);
  if (found != device_contexts_.end()) {
    return found->second.get().get();
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Place %s is not supported. Please check that your paddle compiles "
      "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
      "or check "
      "that your train process set the correct device id if you use "
      "Executor.",
      place));
}

W
Wilber 已提交
140
template <typename DevCtx>
141 142 143 144 145
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
146 147 148 149 150 151 152 153 154 155 156 157
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // lazy evaluation. i.e., only create device context at
        // first `Get`
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
W
Wilber 已提交
158 159 160 161 162 163
          // Note: A trick method to init context, why GetAllocator interface
          // needs a stream parameter?
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p, cuda_ctx->stream())
                                    .get());
          cuda_ctx->PartialInitWithAllocator();
W
Wilber 已提交
164 165
          dev_ctx->SetGenerator(
              framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
166 167
#endif
        } else {
W
Wilber 已提交
168 169 170
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
W
Wilber 已提交
171
          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
172 173 174 175 176 177 178 179 180 181 182
        }
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
C
chengduozh 已提交
183 184
}

D
dzhwinter 已提交
185 186
// Builds the pool with one lazily-constructed device context per distinct
// place in `places`.
//
// `places` must be non-empty. Duplicates are collapsed through a std::set
// before registration. A place whose backend was not compiled into this
// build raises Unimplemented. A place that matches none of the branches
// below is skipped without an error.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  const std::set<Place> unique_places(places.begin(), places.end());
  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      // Name the place type that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

W
Wilber 已提交
264 265 266
// Default-constructs the pten::CPUContext base, then runs its Init().
CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {
  pten::CPUContext::Init();
}
267

W
Wilber 已提交
268 269 270
// Constructs the pten::CPUContext base for `place`, then runs its Init().
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext(place) {
  pten::CPUContext::Init();
}
271

J
jianghaicheng 已提交
272
#ifdef PADDLE_WITH_IPU
A
Allen Guo 已提交
273
// Stores `place`; no other initialisation is performed here.
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}
J
jianghaicheng 已提交
274

W
Wilber 已提交
275
// Returns the place this context was constructed with.
const Place& IPUDeviceContext::GetPlace() const { return place_; }
A
Allen Guo 已提交
276

J
jianghaicheng 已提交
277 278 279 280 281 282 283
// Intended to wait for all operations in the stream to complete.
// The body is currently empty, so this call returns immediately.
void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

// Nothing to release explicitly.
IPUDeviceContext::~IPUDeviceContext() {}

#endif
284
#ifdef PADDLE_WITH_XPU
W
Wilber 已提交
285 286 287
// Default-constructs the pten::XPUContext base, then runs its Init().
XPUDeviceContext::XPUDeviceContext() : pten::XPUContext() {
  pten::XPUContext::Init();
}
288

289
// Nothing to release explicitly.
XPUDeviceContext::~XPUDeviceContext() {}
290

W
Wilber 已提交
291
// Constructs the pten::XPUContext base for `place`, runs its Init(), and
// emits a one-time warning naming the device id.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : pten::XPUContext(place) {
  pten::XPUContext::Init();
  const int device_id = static_cast<int>(place.device);
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << device_id;
}
#endif

298 299 300 301 302 303 304
#ifdef PADDLE_WITH_ASCEND_CL
// Stores `place`, fetches the current NPU context into context_, and
// creates a new NPU stream for the place. An NPUDeviceGuard for the device
// is held for the duration of the constructor.
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // NOTE(zhiqiu): An explicit
  //   PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // is usually unnecessary: after aclrtSetDevice, ACL creates a default
  // context that contains 1 default stream and 1 sync stream.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

// No explicit teardown: the context destruction call below is disabled,
// matching the constructor, which does not create a context explicitly.
NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}
313

314
void NPUDeviceContext::Wait() const {
315 316 317
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
318 319 320 321
}

// Returns the raw stream handle of the stream owned by this context.
aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

W
Wilber 已提交
322
// Returns the place this context was constructed with.
const Place& NPUDeviceContext::GetPlace() const { return place_; }
323 324

// Returns the ACL context captured in the constructor.
aclrtContext NPUDeviceContext::context() const { return context_; }
325 326 327 328 329 330 331 332 333 334 335 336 337 338

// Creates the context with a fresh Eigen::DefaultDevice; place_ is left
// default-initialised.
NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Stores `place` and creates a fresh Eigen::DefaultDevice.
NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Returns a non-owning pointer to the Eigen device held by this context.
Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

W
Wilber 已提交
339
// Returns the stored place.
const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }
340

341 342 343
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
344 345 346 347 348 349 350
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

351
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
352 353 354 355 356
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

357
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
358

359 360 361
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
362
  const cudaDeviceProp& deviceProperties() const override {
363
#endif
Q
init  
qijun 已提交
364 365 366 367
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
368 369 370
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
371 372 373
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
374
    void* retv = buf->ptr();
S
sneaxiy 已提交
375 376 377 378
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
379
    return retv;
Q
init  
qijun 已提交
380 381
  }

S
sneaxiy 已提交
382 383 384 385 386 387
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
388 389 390

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
391
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
392 393 394 395 396 397
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
398
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
399
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
400
#ifdef PADDLE_WITH_HIP
401
      PADDLE_ENFORCE_GPU_SUCCESS(
402 403
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
404
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
405
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
406
#endif
Q
init  
qijun 已提交
407 408 409 410 411
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
412
  CUDAPlace place_;
413 414 415 416
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
417
  const cudaDeviceProp* device_prop_;  // not owned;
418
#endif
Q
qijun 已提交
419
  mutable void* scratch_;
Q
init  
qijun 已提交
420
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
421
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
422
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
423 424
};

425 426 427 428 429 430 431 432 433
// Grows the workspace to at least `required_workspace_bytes`.
// Does nothing when the current workspace is already large enough.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Release the old buffer before requesting the new one so both are not
    // held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

434 435 436 437 438 439 440 441 442 443 444 445
// Definition of the static thread_local map from a CUDADeviceContext to the
// CUDAContext registered for it on the current thread. The accessors of
// CUDADeviceContext test `thread_ctx_.count(this)` to decide whether to use
// that context or the pten::GPUContext base.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
// Definition of the static thread_local mutex member.
// NOTE(review): as a thread_local, each thread sees its own mutex instance;
// confirm that this is the intended scope of protection.
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// (Re)creates the Eigen stream wrapper and the Eigen::GpuDevice built on it.
// The wrapper is pointed at RawStream() and place_ before the GpuDevice is
// constructed from it, so the order of the three statements matters.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

// Creates a context for `place`: a new stream with the given priority and
// flag, then the Eigen, cuBLAS and cuDNN state, plus cuSPARSE and cuSOLVER
// on non-HIP builds. A CUDADeviceGuard for the device is held for the
// whole constructor.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);

  // The stream must exist before the per-library state is initialised.
  stream_.reset(new stream::CUDAStream(place, priority, flag));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
// Rebinds this context to `stream`.
//
// When `stream` differs from the current raw stream, the library handles
// are destroyed, the stream is replaced, and the Eigen device and handles
// are re-created. The set of handles rebuilt here matches what the
// constructor creates and the destructor destroys, including cuSPARSE on
// non-HIP builds, so no handle is left attached to the previous stream.
// Does nothing when `stream` is already the current stream.
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
    DestoryCuSparseContext();
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
    InitCuSparseContext();
    InitCuSolverContext();
#endif
  }
}

480 481 482 483
// Destroys the cuDNN and cuBLAS state, plus cuSPARSE and cuSOLVER on
// non-HIP builds, while a CUDADeviceGuard for the device is held.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);

  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

W
Wilber 已提交
490 491 492
// Constructs the pten::GPUContext base for `place` and runs the part of its
// initialisation that does not need an allocator. It then wraps the base
// stream in a stream::CUDAStream and creates the DNN workspace handle on
// the facade allocator for (place, base stream).
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : pten::GPUContext(place) {
  pten::GPUContext::PartialInitWithoutAllocator();

  // Wrapper around the stream owned by the base context.
  cuda_stream_.reset(new stream::CUDAStream(pten::GPUContext::stream(), place));

  // Workspace handle backed by the stream-aware allocator.
  workspace_.reset(new pten::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(place, pten::GPUContext::stream())
          .get()));
}

W
Wilber 已提交
500
// Defaulted: members release their resources through their own destructors.
CUDADeviceContext::~CUDADeviceContext() = default;
501

502
// Returns the Eigen GPU device: the one of the thread-local CUDAContext
// when this thread has one registered for this object, otherwise the base
// pten::GPUContext device.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this) == 0) {
    return pten::GPUContext::eigen_device();
  }
  return context()->EigenDevice().get();
}

W
Wilber 已提交
509 510 511 512 513 514
// Waits on the thread-local context's stream when this thread has one
// registered for this object, otherwise delegates to pten::GPUContext.
void CUDADeviceContext::Wait() const {
  if (thread_ctx_.count(this) == 0) {
    pten::GPUContext::Wait();
    return;
  }
  context()->Stream()->Wait();
}

517 518 519
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
520
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
521
#endif
W
Wilber 已提交
522 523 524 525
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
  return pten::GPUContext::cudnn_handle();
526
}
527

528 529
#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
W
Wilber 已提交
530 531 532 533
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return pten::GPUContext::cublas_handle();
534 535
}
#else
536
cublasHandle_t CUDADeviceContext::cublas_handle() const {
W
Wilber 已提交
537 538 539 540
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return pten::GPUContext::cublas_handle();
541
}
Z
zhangkaihuo 已提交
542
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
W
Wilber 已提交
543 544 545 546 547 548 549 550 551 552
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
  return pten::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
  return pten::GPUContext::cusolver_dn_handle();
Z
zhangkaihuo 已提交
553
}
554
#endif
555

W
Wilber 已提交
556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581
// Forwards `ev` and `callback` to the thread-local context's stream when
// this thread has one registered for this object, otherwise to
// pten::GPUContext::RecordEvent.
void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this) == 0) {
    pten::GPUContext::RecordEvent(ev, callback);
    return;
  }
  context()->Stream()->RecordEvent(ev, callback);
}

// Adds `callback` to the thread-local context's stream when this thread
// has one registered for this object, otherwise delegates to
// pten::GPUContext::AddStreamCallback.
void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this) == 0) {
    pten::GPUContext::AddStreamCallback(callback);
    return;
  }
  context()->Stream()->AddCallback(callback);
}

// Waits for stream callbacks on the thread-local context's stream when this
// thread has one registered for this object, otherwise delegates to
// pten::GPUContext::WaitStreamCallback.
void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this) == 0) {
    pten::GPUContext::WaitStreamCallback();
    return;
  }
  context()->Stream()->WaitCallback();
}

W
Wilber 已提交
582 583 584 585 586 587 588 589 590
// Returns a DNN workspace handle. Without a thread-local context this is
// the base pten::GPUContext handle. With one, a new handle is built on the
// facade allocator for (GetPlace(), base-context stream).
// NOTE(review): the thread-local branch uses pten::GPUContext::stream(),
// not the thread-local context's stream; confirm that this is intended.
pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this) == 0) {
    return pten::GPUContext::cudnn_workspace_handle();
  }
  // return workspace_.get();
  return pten::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(GetPlace(), pten::GPUContext::stream())
          .get());
}
592

W
Wilber 已提交
593 594 595 596 597
// Returns the raw stream of the thread-local context when this thread has
// one registered for this object, otherwise the base pten::GPUContext
// stream.
gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this) == 0) {
    return pten::GPUContext::stream();
  }
  return context()->RawStream();
}

W
Wilber 已提交
600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618
// Returns the CUDAContext registered for this object on the calling thread.
// Raises PermissionDenied when the calling thread has none registered.
std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  const auto registered = thread_ctx_.find(this);
  if (registered == thread_ctx_.end()) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return registered->second;
}

// Returns a non-owning pointer to the stream wrapper held in cuda_stream_.
stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

// Replaces the stream wrapper held in cuda_stream_ with `new_stream_ptr`
// and returns the previous one.
// The previous wrapper is released, not destroyed, so the caller becomes
// responsible for it; cuda_stream_ takes over `new_stream_ptr`.
// NOTE(review): only cuda_stream_ is swapped here; nothing visible in this
// function updates the stream of the pten::GPUContext base.
stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  auto* old_stream_ptr = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return old_stream_ptr;
}
Q
qijun 已提交
619

C
chengduoZH 已提交
620 621 622 623 624 625 626 627 628 629 630 631 632
// Creates the context with a fresh Eigen::DefaultDevice; place_ is left
// default-initialised.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Stores `place` and creates a fresh Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Returns a non-owning pointer to the Eigen device held by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

W
Wilber 已提交
633
// Returns the stored place.
const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
L
Luo Tao 已提交
634
#endif
Q
qijun 已提交
635

T
tensor-tang 已提交
636 637
#ifdef PADDLE_WITH_MKLDNN
// Constructs the CPUDeviceContext base for `place` and allocates the empty
// blob cache, the per-shape executor bookkeeping, and the mutex that the
// cache methods lock.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());      // blob cache
  p_exec_items_.reset(new ExecShape()); // shape -> executor -> cache entries
  p_mutex_.reset(new std::mutex());     // guards the cache
}

644
// Creates the per-thread state: a CPU engine (index 0), a stream on that
// engine, the default session id, an empty input-shape string, a
// shape-cache capacity of 1 and the NCHW data layout.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

652 653 654 655 656 657 658 659 660 661 662 663
// When a thread finishes we clear the oneDNN cache.
// This is needed when one executor is used by many threads
// (e.g. test_analyzer_detect). The thread id is not part of the caching key
// (for the naive executor), so the cache has to be cleared when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto& ctx_pool = platform::DeviceContextPool::Instance();
  const auto cpu_place = paddle::platform::CPUPlace();
  auto* mkldnn_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(ctx_pool.Get(cpu_place));
  mkldnn_ctx->ResetBlobMap(exec_ptr_);
}

667 668 669 670 671 672 673 674 675 676
// Stores `sid` as this thread-local body's current session id.
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
// Returns the current session id stored in this body.
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
677 678
  cur_input_shape_str = input_shape_str;
}
679 680
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
681 682
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
683

684 685
// Stores `dl` as the current Paddle data layout.
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  this->cur_paddle_data_layout = dl;
}

689 690
// Returns the currently stored Paddle data layout.
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return this->cur_paddle_data_layout;
}

694 695 696 697 698 699 700 701 702
// Logs the oneDNN library version at INFO level the first time it is
// called on this body; later calls do nothing.
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (said_once) {
    return;
  }
  said_once = true;
  auto lib_version = dnnl::version();
  LOG(INFO) << "oneDNN v" << lib_version->major << "." << lib_version->minor
            << "." << lib_version->patch;
}

703
// Returns the engine owned by this body.
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return this->cur_engine;
}

707
// Returns the stream owned by this body.
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return this->cur_stream;
}

711
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
712 713 714
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
715 716 717 718 719 720
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
721 722 723 724 725
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
726
          (v.first)->erase(v.second);
727 728
        }
        s.second->erase(ptr);
729 730
      }
    }
731 732 733 734 735 736
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

737 738
// Drops the first entry of the per-shape executor bookkeeping.
// Does nothing when the bookkeeping is empty: erasing begin() of an empty
// container would be undefined behaviour.
// NOTE(review): SetBlob calls this right after erasing the first shape of
// the current session's shape map; nothing here checks that the entry
// removed belongs to that same shape. Confirm against the container types.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

741 742
// Records the cache entry (`pblob`, `it`) under the current input shape and
// the current executor, both taken from thread-local state, so that the
// entry can later be erased per executor.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Find the executor map for the current input shape, creating it on first
  // use.
  auto shape_slot = p_exec_items_
                        ->insert(std::make_pair(tls().cur_input_shape_str,
                                                std::make_shared<ExecMap>()))
                        .first;
  // Append the entry to the list kept for the current executor address.
  auto& linked_entries = (*shape_slot->second)[tls().get_curr_exec()];
  linked_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << linked_entries.size() << "\n";
}

757 758 759 760
// Sets the flag that makes the next ResetBlobMap() call skip clearing.
// The flag is set while holding the cache mutex.
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  VLOG(3) << "Next DNNL cache clearing has been blocked.";

  block_next_cache_clearing_ = true;
}
762

763
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
764
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
765
  BlobMap* pMap = p_blobmap_.get();
766
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
767
  if (map_it == pMap->end()) {
768 769 770
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
771 772 773 774
  }
  return map_it->second->size();
}

775
void MKLDNNDeviceContext::SetBlob(const std::string& name,
776
                                  BlobPtr_t<void> data) const {
777
  BlobMap* pMap = p_blobmap_.get();
778
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
779
  BlobPtr_t<KeyBlob> pBlob = nullptr;
780

781
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
782

783
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
784

785 786
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
787 788 789

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
790
    sBlob = std::make_shared<ShapeBlob>();
791 792
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
793
  } else {
794
    sBlob = map_it->second;
795
  }
T
tensor-tang 已提交
796

797
  // Find KeyBlob for current input shape
798
  auto key_it = sBlob->find(tls().cur_input_shape_str);
799

800
  if (key_it == sBlob->end()) {
801 802
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
803 804
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
805
        sBlob->size() &&
806
        (sBlob->size() >=
807
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
808 809 810 811
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
812
    }
813
    pBlob = std::make_shared<KeyBlob>();
814
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
815
  } else {
816
    pBlob = key_it->second;
817 818
  }

819
  // Find Blob via name
820 821 822 823
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
824 825 826
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
827 828 829
  } else {
    blob_it->second = data;  // set data to existing blob
  }
830
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
831
  // lock will be automatically released when out of scope
832
  return;
T
tensor-tang 已提交
833 834
}

835
// Returns the total number of cached blobs, summed over every session id
// and every input shape.
// NOTE(review): unlike the other cache methods, this one does not lock
// p_mutex_; confirm that callers do not race with SetBlob/ResetBlobMap.
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int total = 0;
  for (auto const& session_entry : *p_blobmap_) {
    auto const& shapes = *(session_entry.second);
    for (auto const& shape_entry : shapes) {
      total += (shape_entry.second)->size();
    }
  }
  return total;
}

845
// Looks up the blob cached under `name` for the current session id and
// input shape, both read from thread-local state.
//
// Returns nullptr when the session id, the input shape or the name has no
// cache entry; each miss is logged at verbosity 2. The cache mutex is held
// for the whole lookup.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    // Report the session id under "sid=" and the missing shape under its
    // own label, so the message is not misread.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif
Q
qijun 已提交
889
}  // namespace platform
Q
qijun 已提交
890
}  // namespace paddle