/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <functional>
#include <memory>
#include <set>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/allocator.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA,"
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU,"
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
    auto* default_dev_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU,"
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}
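
// A hedged usage sketch (not part of the original file): a caller that owns a
// device context passes it here so the returned buffer is tied to that
// context's stream on GPU/MLU places; elsewhere this overload falls back to
// Alloc(place, size). The place and size below are hypothetical.
//
//   auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
//   memory::AllocationPtr buf = memory::Alloc(*dev_ctx, /*size=*/4096);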

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }

bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  } else if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  } else if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  } else if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported place %s to convert into platform::DeviceType.", place));
  }
}

DeviceContextPool* DeviceContextPool::pool = nullptr;

platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle was "
        "compiled with the WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or "
        "WITH_ASCEND_CL option, or check that your train process sets the "
        "correct device id if you use Executor.",
        place));
  }
  return it->second.get().get();
}
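
// A hedged usage sketch (not part of the original file): fetching the pooled
// context for a place and synchronizing on it. The CUDAPlace id below is
// hypothetical.
//
//   auto& pool = platform::DeviceContextPool::Instance();
//   auto* dev_ctx = pool.Get(platform::CUDAPlace(0));
//   dev_ctx->Wait();  // block until the device finishes its queued work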

template <typename DevCtx>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // Lazy evaluation: the device context is only created on the
        // first call to `Get`.
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
          // Note: a workaround for context initialization; it is unclear why
          // the GetAllocator interface needs a stream parameter.
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p, cuda_ctx->stream())
                                    .get());
          cuda_ctx->PartialInitWithAllocator();
#endif
        } else {
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
        }
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
}

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    }
  }
}

CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {
  pten::CPUContext::Init();
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext(place) {
  pten::CPUContext::Init();
}

#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {}

const Place& IPUDeviceContext::GetPlace() const { return place_; }

void IPUDeviceContext::Wait() const {
  /*! \brief  Wait for all operations completion in the stream. */
}

IPUDeviceContext::~IPUDeviceContext() {}

#endif
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() : pten::XPUContext() {
  pten::XPUContext::Init();
}

XPUDeviceContext::~XPUDeviceContext() {}

XPUDeviceContext::XPUDeviceContext(XPUPlace place) : pten::XPUContext(place) {
  pten::XPUContext::Init();
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                          << static_cast<int>(place.device);
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // NOTE(zhiqiu): Usually, no need to create context explicitly,
  // ACL creates a default context which contains 1 default stream
  // and 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

NPUDeviceContext::~NPUDeviceContext() {
  // NPUDeviceGuard guard(place_.device);
  // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
}

void NPUDeviceContext::Wait() const {
  platform::RecordEvent record_event("NPUDeviceContext/wait");
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
}

aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); }

const Place& NPUDeviceContext::GetPlace() const { return place_; }

aclrtContext NPUDeviceContext::context() const { return context_; }

NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& NPUPinnedDeviceContext::GetPlace() const { return place_; }

#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const gpuStream_t& stream() const override { return *stream_; }

#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
  const cudaDeviceProp& deviceProperties() const override {
#endif
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
    void* retv = buf->ptr();
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
    return retv;
  }

  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#endif
    }
    return semaphore_;
  }

 private:
  CUDAPlace place_;
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
  const cudaDeviceProp* device_prop_;  // not owned;
#endif
  mutable void* scratch_;
  mutable unsigned int* semaphore_;
  mutable std::mutex mtx_;  // to protect allocations_
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};

void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (required_workspace_bytes <= WorkspaceSize()) {
    return;
  }
  // reset allocation first before re-allocate to save memory
  allocation_.reset();
  allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
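
// A hedged usage sketch (not part of the original file), assuming the RunFunc
// interface of CudnnWorkspaceHandle; the byte count and lambda body are
// hypothetical.
//
//   auto workspace = dev_ctx.cudnn_workspace_handle();
//   workspace.RunFunc(
//       [&](void* ws) { /* pass ws to the cuDNN call that needs it */ },
//       required_workspace_bytes);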

thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority, flag));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() != stream) {
    CUDADeviceGuard guard(place_.device);
    DestoryCuDNNContext();
    DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
    DestoryCuSolverContext();
#endif

    stream_->SetStream(stream);

    InitEigenContext();
    InitCuBlasContext();
    InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
    InitCuSolverContext();
#endif
  }
}

CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : pten::GPUContext(place) {
  pten::GPUContext::PartialInitWithoutAllocator();
  cuda_stream_.reset(new stream::CUDAStream(pten::GPUContext::stream(), place));
  workspace_.reset(new pten::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(place, pten::GPUContext::stream())
          .get()));
}

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this)) {
    return context()->EigenDevice().get();
  }
  return pten::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
  pten::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
  return pten::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return pten::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CublasHandle()->GetCublasHandle();
  }
  return pten::GPUContext::cublas_handle();
}
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusparseHandle()->GetCusparseHandle();
  }
  return pten::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this)) {
    return context()->CusolverDnHandle();
  }
  return pten::GPUContext::cusolver_dn_handle();
}
#endif

void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->RecordEvent(ev, callback);
    return;
  }
  pten::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
  pten::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
  pten::GPUContext::WaitStreamCallback();
}

pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this)) {
    // return workspace_.get();
    return pten::DnnWorkspaceHandle(
        memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(GetPlace(), pten::GPUContext::stream())
            .get());
  }
  return pten::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this)) {
    return context()->RawStream();
  }
  return pten::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  if (!thread_ctx_.count(this)) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  return cuda_stream_.get();
}

stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  auto* old_stream_ptr = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return old_stream_ptr;
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_exec_items_.reset(new ExecShape());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

// When a thread finishes we clear the oneDNN cache.
// This is needed when one executor is used by many threads,
// e.g. test_analyzer_detect. Thread ID is not part of the caching key
// (for the naive executor), so we need to clear the cache when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only the part of the cache
// related to the thread that is being terminated.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::MKLDNNDeviceContext* dev_ctx =
      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
  dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
  cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
    // If no specific executor pointer is given, clear everything;
    // otherwise clear only the objects allocated while using the
    // given executor.
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
      // Iterate through all shapes and, for each shape and active executor,
      // release all cache entries of this executor.
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
          (v.first)->erase(v.second);
        }
        s.second->erase(ptr);
      }
    }
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  p_exec_items_->erase(p_exec_items_->begin());
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take current input shape from TLS
  // Take current executor address from TLS
  // and for this executor's items add the one defined with arguments
  auto key_it = p_exec_items_
                    ->insert(std::make_pair(tls().cur_input_shape_str,
                                            std::make_shared<ExecMap>()))
                    .first;
  (*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: "
          << (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext doesn't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
  }
  return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  BlobPtr_t<void> data) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);

  if (map_it == pMap->end()) {
    // First time a blob is set in the current thread
    sBlob = std::make_shared<ShapeBlob>();
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
  } else {
    sBlob = map_it->second;
  }

  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(tls().cur_input_shape_str);

  if (key_it == sBlob->end()) {
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
        (sBlob->size() >=
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
    }
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
  } else {
    pBlob = key_it->second;
  }

  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
    // Register the new element in the per-executor map
    // so it can be easily erased when the executor terminates
    LinkEntryWithExecutor(pBlob, el.first);
  } else {
    blob_it->second = data;  // set data to existing blob
  }
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return;
}

unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After the first iteration of the model's execution most
  // elements should already be cached, so lookup failures are unlikely
  // (less so for dynamic shapes).
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    VLOG(2) << "GetBlob: input_shape_str=" << tls().cur_input_shape_str
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
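
// A hedged usage sketch (not part of the original file): oneDNN kernels cache
// primitives keyed by session id / input shape via SetBlob and look them up
// again with GetBlob. The key string and primitive type below are
// hypothetical.
//
//   auto& dev_ctx = *static_cast<platform::MKLDNNDeviceContext*>(
//       platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
//   auto prim = std::static_pointer_cast<dnnl::convolution_forward>(
//       dev_ctx.GetBlob("conv2d_fwd_key"));
//   if (prim == nullptr) {
//     prim = std::make_shared<dnnl::convolution_forward>(/* pd */);
//     dev_ctx.SetBlob("conv2d_fwd_key", prim);
//   }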

#endif
}  // namespace platform
}  // namespace paddle