device_context.cc 29.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
W
Wilber 已提交
13
#include <functional>
14
#include <memory>
15
#include <set>
W
Wilber 已提交
16 17
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
18 19
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"
20

21
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
22
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
23
#include "paddle/fluid/platform/cuda_device_guard.h"
24
#endif
F
fwenguang 已提交
25 26 27 28
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
#endif
29
#include "glog/logging.h"
30
#include "paddle/fluid/framework/expect.h"
W
Wilber 已提交
31
#include "paddle/fluid/framework/generator.h"
32
#include "paddle/fluid/memory/allocation/allocator_facade.h"
33
#include "paddle/fluid/platform/device/device_wrapper.h"
34
#include "paddle/fluid/platform/profiler.h"
35
#include "paddle/fluid/platform/profiler/event_tracing.h"
36

37 38 39 40 41
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place reported by `dev_ctx`.
//
// For GPU and MLU places the stream of `dev_ctx` is compared with the stream
// of the pooled context for the same place: when they match (or when `size`
// is zero) the place-based Alloc overload is used, otherwise the request is
// routed to the per-device-context allocator pool. XPU and every other place
// always use the place-based overload. Requests for a GPU/XPU/MLU place in a
// build without the corresponding backend throw PermissionDenied.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto target_place = dev_ctx.GetPlace();

  // Empty requests never need a stream-specific allocator.
  if (size == 0) {
    return Alloc(target_place, size);
  }

  if (platform::is_gpu_place(target_place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* pooled_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(target_place));
    auto& requested_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (pooled_ctx->stream() != requested_ctx.stream()) {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          requested_ctx, size);
    }
    return Alloc(target_place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA,"
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  }

  if (platform::is_xpu_place(target_place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(target_place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU,"
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  }

  if (platform::is_mlu_place(target_place)) {
#ifdef PADDLE_WITH_MLU
    auto* pooled_ctx = static_cast<platform::MLUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(target_place));
    auto& requested_ctx =
        static_cast<const platform::MLUDeviceContext&>(dev_ctx);
    if (pooled_ctx->stream() != requested_ctx.stream()) {
      return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc(
          requested_ctx, size);
    }
    return Alloc(target_place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use MLU device since it's not compiled with MLU,"
        "Please recompile or reinstall Paddle with MLU support."));
#endif
  }

  // Any remaining place kind has no stream-aware path here.
  return Alloc(target_place, size);
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
97 98 99
namespace paddle {
namespace platform {

100
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
101 102 103
// Process-wide switch read by AllowTF32Cublas(); enabled by default.
bool allow_tf32_cublas = true;

// Overwrites the cuBLAS TF32 switch with `active`.
void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

// Returns the current value of the cuBLAS TF32 switch.
bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}
A
AshburnLee 已提交
104 105 106 107

// Process-wide switch read by AllowTF32Cudnn(); enabled by default.
bool allow_tf32_cudnn = true;

// Overwrites the cuDNN TF32 switch with `active`.
void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

// Returns the current value of the cuDNN TF32 switch.
bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
108 109
#endif  // PADDLE_WITH_CUDA

110 111 112 113 114 115 116
// Maps a Place to the matching platform::DeviceType.
// CPU, GPU (reported as CUDA), XPU and MLU places are supported; any other
// place throws an Unavailable error.
DeviceType Place2DeviceType(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return platform::DeviceType::CPU;
  }
  if (platform::is_gpu_place(place)) {
    return platform::DeviceType::CUDA;
  }
  if (platform::is_xpu_place(place)) {
    return platform::DeviceType::XPU;
  }
  if (platform::is_mlu_place(place)) {
    return platform::DeviceType::MLU;
  }
  PADDLE_THROW(platform::errors::Unavailable(
      "Unsupported place %s to convert into platform::DeviceType.", place));
}

D
dzhwinter 已提交
125 126
// Definition of the static `pool` pointer of DeviceContextPool; it starts out
// null. Nothing in this file assigns it.
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
127
// Returns the device context registered for `place`.
// The stored value is a shared_future, so the first Get for a place is what
// actually runs the deferred construction. Throws Unimplemented when the
// place was never registered with this pool.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;
  auto found = device_contexts_.find(place);
  if (found != device_contexts_.end()) {
    return found->second.get().get();
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Place %s is not supported. Please check that your paddle compiles "
      "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
      "or check "
      "that your train process set the correct device id if you use "
      "Executor.",
      place));
}

W
Wilber 已提交
142
template <typename DevCtx>
143 144 145 146 147
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
148 149 150 151 152 153 154 155 156 157 158 159
  map_ptr->emplace(
      p, std::async(std::launch::deferred, [=] {
        // lazy evaluation. i.e., only create device context at
        // first `Get`
        auto* dev_ctx = new DevCtx(p);
        if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
          PADDLE_ENFORCE_NOT_NULL(
              cuda_ctx,
              platform::errors::InvalidArgument(
                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
W
Wilber 已提交
160 161 162 163 164 165
          // Note: A trick method to init context, why GetAllocator interface
          // needs a stream parameter?
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p, cuda_ctx->stream())
                                    .get());
          cuda_ctx->PartialInitWithAllocator();
W
Wilber 已提交
166 167
          dev_ctx->SetGenerator(
              framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
168 169
#endif
        } else {
W
Wilber 已提交
170 171 172
          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                    .GetAllocator(p)
                                    .get());
W
Wilber 已提交
173
          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
174 175 176 177 178 179 180 181 182 183 184
        }
        dev_ctx->SetHostAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetAllocator(platform::CPUPlace())
                .get());
        dev_ctx->SetZeroAllocator(
            memory::allocation::AllocatorFacade::Instance()
                .GetZeroAllocator(p)
                .get());
        return PtrType(dev_ctx);
      }));
C
chengduozh 已提交
185 186
}

D
dzhwinter 已提交
187 188
// Builds the pool from `places`.
//
// `places` must be non-empty (InvalidArgument otherwise). Duplicates are
// removed, then one lazily-constructed context is registered per distinct
// place, with the context type chosen by place kind. A place whose backend
// was not compiled in throws Unimplemented. A place that matches none of the
// kinds below is silently skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // De-duplicate so each place gets exactly one context.
  std::set<Place> set(places.begin(), places.end());
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
      // Name the place kind that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("MLUPlace is not supported. Please "
                                          "re-compile with WITH_MLU option."));
#endif
    } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("IPUPlace is not supported. Please "
                                          "re-compile with WITH_IPU option."));
#endif
    } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPlace is not supported. Please "
          "re-compile with WITH_ASCEND_CL option."));
#endif
    } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "NPUPinnedPlace is not supported. Please re-compile with "
          "WITH_ASCEND_CL "
          "option."));
#endif
    } else if (platform::is_custom_place(p)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      EmplaceDeviceContext<CustomDeviceContext>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CustomPlace is not supported. Please re-compile with "
          "WITH_CUSTOM_DEVICE "
          "option."));
#endif
    }
  }
}

275 276
// Default constructor: default-constructs the phi::CPUContext base and then
// runs its Init().
CPUDeviceContext::CPUDeviceContext()
    : phi::CPUContext() {
  phi::CPUContext::Init();
}
278

279 280
// Constructs the phi::CPUContext base for `place` and then runs its Init().
CPUDeviceContext::CPUDeviceContext(CPUPlace place)
    : phi::CPUContext(place) {
  phi::CPUContext::Init();
}
282

J
jianghaicheng 已提交
283
#ifdef PADDLE_WITH_IPU
A
Allen Guo 已提交
284
// Remembers the IPU place this context belongs to; no other setup is done.
IPUDeviceContext::IPUDeviceContext(IPUPlace place)
    : place_(place) {}
J
jianghaicheng 已提交
285

W
Wilber 已提交
286
// Returns the place passed at construction.
const Place& IPUDeviceContext::GetPlace() const {
  return place_;
}
A
Allen Guo 已提交
287

J
jianghaicheng 已提交
288 289 290 291 292 293 294
/*! \brief  Wait for all operations completion in the stream. */
// The body is intentionally empty: this implementation performs no wait.
void IPUDeviceContext::Wait() const {}

// Nothing to release explicitly.
IPUDeviceContext::~IPUDeviceContext() {
}

#endif
295
#ifdef PADDLE_WITH_XPU
296 297
// Default constructor: default-constructs the phi::XPUContext base and then
// runs its Init().
XPUDeviceContext::XPUDeviceContext()
    : phi::XPUContext() {
  phi::XPUContext::Init();
}
299

300
// Nothing to release explicitly.
XPUDeviceContext::~XPUDeviceContext() {
}
301

302 303
// Constructs the phi::XPUContext base for `place`, runs its Init(), and logs
// the device id once (first construction only) at WARNING level.
XPUDeviceContext::XPUDeviceContext(XPUPlace place)
    : phi::XPUContext(place) {
  phi::XPUContext::Init();
  const int device_id = static_cast<int>(place.device);
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << device_id;
}
#endif

309 310 311 312 313 314 315
#ifdef PADDLE_WITH_ASCEND_CL
// Binds the context to `place`: under a device guard for that device it
// captures the current NPU context and creates a new NPUStream.
NPUDeviceContext::NPUDeviceContext(NPUPlace place)
    : place_(place) {
  NPUDeviceGuard guard(place_.device);
  // NOTE(zhiqiu): Usually there is no need to create a context explicitly
  // (the disabled call was
  //   PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device));
  // ). ACL creates a default context which contains 1 default stream and
  // 1 sync stream after aclrtSetDevice.
  platform::GetCurrentNPUContext(&context_);
  stream_.reset(new stream::NPUStream(place));
}

// The context captured in the constructor is not destroyed here. The explicit
// teardown is kept below, disabled:
//   NPUDeviceGuard guard(place_.device);
//   PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_));
NPUDeviceContext::~NPUDeviceContext() {}
324

325
void NPUDeviceContext::Wait() const {
326 327
  platform::RecordEvent record_event("NPUDeviceContext/wait",
                                     platform::TracerEventType::UserDefined, 2);
328 329
  VLOG(4) << "NPU context(" << this << ")  Wait";
  stream_->Wait();
330 331 332 333
}

// Returns the raw ACL stream held by this context's NPUStream.
aclrtStream NPUDeviceContext::stream() const {
  return stream_->raw_stream();
}

W
Wilber 已提交
334
// Returns the place passed at construction.
const Place& NPUDeviceContext::GetPlace() const {
  return place_;
}
335 336

// Returns the ACL context captured in the constructor.
aclrtContext NPUDeviceContext::context() const {
  return context_;
}
337 338 339 340 341 342 343 344 345 346 347 348 349 350

// Default constructor: only creates the Eigen default device; place_ is left
// default-initialised.
NPUPinnedDeviceContext::NPUPinnedDeviceContext() {
  auto* host_device = new Eigen::DefaultDevice();
  eigen_device_.reset(host_device);
}

// Stores `place` and creates the Eigen default device.
NPUPinnedDeviceContext::NPUPinnedDeviceContext(NPUPinnedPlace place)
    : place_(place) {
  auto* host_device = new Eigen::DefaultDevice();
  eigen_device_.reset(host_device);
}

// Returns a non-owning pointer to the Eigen device created in the constructor.
Eigen::DefaultDevice* NPUPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

W
Wilber 已提交
351
// Returns the stored place.
const Place& NPUPinnedDeviceContext::GetPlace() const {
  return place_;
}
352

353 354 355
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Q
init  
qijun 已提交
356 357 358 359 360 361 362
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

363
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
364 365 366 367 368
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

369
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
370

371 372 373
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
374
  const cudaDeviceProp& deviceProperties() const override {
375
#endif
Q
init  
qijun 已提交
376 377 378 379
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
380 381 382
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
383 384 385
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
386
    void* retv = buf->ptr();
S
sneaxiy 已提交
387 388 389 390
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
391
    return retv;
Q
init  
qijun 已提交
392 393
  }

S
sneaxiy 已提交
394 395 396 397 398 399
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
400 401 402

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
403
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
404 405 406 407 408 409
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
410
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
411
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
412
#ifdef PADDLE_WITH_HIP
413
      PADDLE_ENFORCE_GPU_SUCCESS(
414 415
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
416
      PADDLE_ENFORCE_GPU_SUCCESS(
Q
init  
qijun 已提交
417
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
418
#endif
Q
init  
qijun 已提交
419 420 421 422 423
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
424
  CUDAPlace place_;
425 426 427 428
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
429
  const cudaDeviceProp* device_prop_;  // not owned;
430
#endif
Q
qijun 已提交
431
  mutable void* scratch_;
Q
init  
qijun 已提交
432
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
433
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
434
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
435 436
};

437 438 439 440 441 442 443 444 445
// Grows the workspace to at least `required_workspace_bytes`.
// Nothing happens when the current workspace is already large enough.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool must_grow = required_workspace_bytes > WorkspaceSize();
  if (must_grow) {
    // Release the old buffer before asking for the new one to save memory.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

446 447 448 449 450 451 452 453 454 455 456 457
// Out-of-class definitions of CUDADeviceContext's static thread_local members.
// thread_ctx_ maps a CUDADeviceContext to a shared CUDAContext, per thread;
// the accessors in this file use that CUDAContext instead of the
// phi::GPUContext base whenever an entry for `this` exists.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
// NOTE(review): being thread_local, each thread gets its own mutex instance,
// so it cannot serialise access between threads — confirm the intended use.
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// (Re)creates the Eigen stream wrapper for the current raw stream and place,
// then builds a fresh Eigen::GpuDevice on top of it.
void CUDAContext::InitEigenContext() {
  // Step 1: new stream wrapper, bound to this context's stream and place.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);

  // Step 2: Eigen device that borrows (does not own) the wrapper.
  auto* stream_device = eigen_stream_.get();
  eigen_device_.reset(new Eigen::GpuDevice(stream_device));
}

// Creates a CUDA context on `place`: under a device guard it builds a stream
// with the given `priority` and `flag`, then initialises the Eigen, cuBLAS and
// cuDNN state, plus cuSPARSE and cuSOLVER when not building for HIP.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority,
                         const stream::StreamFlag& flag) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);

  // The stream must exist before the per-library contexts are initialised.
  stream_.reset(new stream::CUDAStream(place, priority, flag));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

W
Wilber 已提交
472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
// Switches this context to `stream`.
//
// No-op when `stream` is already the current raw stream. Otherwise, under a
// device guard, the library contexts are torn down, the stream is replaced,
// and the Eigen state plus every library context are re-created.
//
// The set of contexts recycled here mirrors what the constructor creates and
// the destructor destroys, including cuSPARSE on non-HIP builds, so no handle
// created by the constructor is skipped on a stream switch.
// NOTE(review): presumably each Init*Context helper ties its handle to the
// current stream — confirm against their definitions.
void CUDAContext::SetStream(gpuStream_t stream) {
  if (stream_->raw_stream() == stream) {
    return;
  }

  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif

  stream_->SetStream(stream);

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSparseContext();
  InitCuSolverContext();
#endif
}

492 493 494 495
// Tears down the library contexts under a device guard for this context's
// device: cuDNN and cuBLAS always, cuSPARSE and cuSOLVER on non-HIP builds.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);

  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSparseContext();
  DestoryCuSolverContext();
#endif
}

502 503 504 505
// Builds the context for `place`: partially initialises the phi::GPUContext
// base (the allocator is supplied later), wraps the base's stream in a
// CUDAStream object, and creates the DNN workspace handle from the allocator
// registered for (place, stream).
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : phi::GPUContext(place) {
  phi::GPUContext::PartialInitWithoutAllocator();

  cuda_stream_.reset(
      new stream::CUDAStream(phi::GPUContext::stream(), place));

  workspace_.reset(new phi::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(place, phi::GPUContext::stream())
          .get()));
}

W
Wilber 已提交
511
// Compiler-generated destruction of the members is sufficient.
CUDADeviceContext::~CUDADeviceContext()
    = default;
512

513
// Returns the Eigen GPU device: the one of the calling thread's registered
// CUDAContext when there is one for this object, otherwise the base
// phi::GPUContext's.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::eigen_device();
  }
  return context()->EigenDevice().get();
}

W
Wilber 已提交
520 521 522 523 524
void CUDADeviceContext::Wait() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->Wait();
    return;
  }
525
  phi::GPUContext::Wait();
526 527
}

528 529 530
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
531
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
532
#endif
W
Wilber 已提交
533 534 535
  if (thread_ctx_.count(this)) {
    return context()->CudnnHandle();
  }
536
  return phi::GPUContext::cudnn_handle();
537
}
538

539 540
#ifdef PADDLE_WITH_HIP
// HIP build: returns the rocBLAS handle, preferring the calling thread's
// registered CUDAContext over the phi::GPUContext base.
rocblas_handle CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::cublas_handle();
  }
  return context()->CublasHandle()->GetCublasHandle();
}
#else
547
// CUDA build: returns the cuBLAS handle, preferring the calling thread's
// registered CUDAContext over the phi::GPUContext base.
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::cublas_handle();
  }
  return context()->CublasHandle()->GetCublasHandle();
}
Z
zhangkaihuo 已提交
553
// Returns the cuSPARSE handle, preferring the calling thread's registered
// CUDAContext over the phi::GPUContext base.
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::cusparse_handle();
  }
  return context()->CusparseHandle()->GetCusparseHandle();
}
// Returns the cuSOLVER dense handle, preferring the calling thread's
// registered CUDAContext over the phi::GPUContext base.
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::cusolver_dn_handle();
  }
  return context()->CusolverDnHandle();
}
565
#endif
566

W
Wilber 已提交
567 568 569 570 571 572
// Forwards `ev` and `callback` to RecordEvent on the stream of the calling
// thread's registered CUDAContext when one exists for this object; otherwise
// to phi::GPUContext::RecordEvent.
void CUDADeviceContext::RecordEvent(
    gpuEvent_t ev, const std::function<void()>& callback) const {
  if (thread_ctx_.count(this) == 0) {
    phi::GPUContext::RecordEvent(ev, callback);
    return;
  }
  context()->Stream()->RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
    const std::function<void()>& callback) const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->AddCallback(callback);
    return;
  }
582
  phi::GPUContext::AddStreamCallback(callback);
W
Wilber 已提交
583 584 585 586 587 588 589
}

void CUDADeviceContext::WaitStreamCallback() const {
  if (thread_ctx_.count(this)) {
    context()->Stream()->WaitCallback();
    return;
  }
590
  phi::GPUContext::WaitStreamCallback();
W
Wilber 已提交
591 592
}

593
// Returns a DNN workspace handle.
// Without a thread-registered CUDAContext for this object the base
// phi::GPUContext's handle is returned. With one, a new handle is built on
// the allocator for (GetPlace(), base GPUContext stream); the member
// `workspace_` is not used on that path.
phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::cudnn_workspace_handle();
  }
  // return workspace_.get();
  return phi::DnnWorkspaceHandle(
      memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(GetPlace(), phi::GPUContext::stream())
          .get());
}
603

W
Wilber 已提交
604 605 606 607
// Returns the raw stream: that of the calling thread's registered CUDAContext
// when one exists for this object, otherwise the phi::GPUContext base's.
gpuStream_t CUDADeviceContext::stream() const {
  if (thread_ctx_.count(this) == 0) {
    return phi::GPUContext::stream();
  }
  return context()->RawStream();
}

W
Wilber 已提交
611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629
// Returns the CUDAContext registered for this object in the calling thread's
// thread_ctx_ map. Throws PermissionDenied when no entry exists.
std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
  auto entry = thread_ctx_.find(this);
  if (entry == thread_ctx_.end()) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceContext call context() failed, make sure in the "
        "thread_local semantic."));
  }
  return entry->second;
}

// Returns a non-owning pointer to the CUDAStream wrapper held by this context.
stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
  stream::CUDAStream* current = cuda_stream_.get();
  return current;
}

// Installs `new_stream_ptr` as the held CUDAStream wrapper and returns the
// previous one. The previous wrapper is released, not destroyed, so the
// returned raw pointer is no longer managed by this object.
stream::CUDAStream* CUDADeviceContext::SetCudaStream(
    stream::CUDAStream* new_stream_ptr) {
  stream::CUDAStream* previous = cuda_stream_.release();
  cuda_stream_.reset(new_stream_ptr);
  return previous;
}
Q
qijun 已提交
630

C
chengduoZH 已提交
631 632 633 634 635 636 637 638 639 640 641 642 643
// Default constructor: only creates the Eigen default device; place_ is left
// default-initialised.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  auto* host_device = new Eigen::DefaultDevice();
  eigen_device_.reset(host_device);
}

// Stores `place` and creates the Eigen default device.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  auto* host_device = new Eigen::DefaultDevice();
  eigen_device_.reset(host_device);
}

// Returns a non-owning pointer to the Eigen device created in the constructor.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

W
Wilber 已提交
644
// Returns the stored place.
const Place& CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
645
#endif
Q
qijun 已提交
646

T
tensor-tang 已提交
647 648
#ifdef PADDLE_WITH_MKLDNN
// Constructs the CPU base for `place` and creates the three pieces of cache
// state owned by this context: an empty blob map, an empty per-executor item
// map, and the mutex used by the cache methods.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_() {
  // Blob cache (filled by SetBlob).
  p_blobmap_.reset(new BlobMap());
  // Per-shape, per-executor bookkeeping (filled by LinkEntryWithExecutor).
  p_exec_items_.reset(new ExecShape());
  // Mutex taken by the cache accessors.
  p_mutex_.reset(new std::mutex());
}

655
// Creates a CPU engine (index 0) with a stream on it and resets the
// thread-local cache settings to their defaults: default session id, empty
// input-shape string, shape-cache capacity 1, NCHW data layout.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str = "";
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

663 664 665 666 667 668 669 670 671 672 673 674
// When a thread finishes, its oneDNN cache entries are cleared.
// This is needed when one executor is used by many threads
// (e.g. test_analyzer_detect). The thread ID is not part of the caching key
// (for the naive executor), so the cache has to be cleared when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  // Fetch the pooled CPU context and reset the blobs linked to this
  // thread's executor pointer.
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto cpu_place = paddle::platform::CPUPlace();
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap(exec_ptr_);
}

678 679 680 681 682 683 684 685 686 687
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
688 689
  cur_input_shape_str = input_shape_str;
}
690 691
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
692 693
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
694

695 696
// Stores `dl` as this thread's current Paddle data layout.
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  this->cur_paddle_data_layout = dl;
}

700 701
// Returns this thread's current Paddle data layout.
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout() {
  return this->cur_paddle_data_layout;
}

705 706 707 708 709 710 711 712 713
// Logs the oneDNN library version at INFO level the first time it is called
// on this Body; later calls do nothing.
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version() {
  if (said_once) {
    return;
  }
  said_once = true;
  auto lib_version = dnnl::version();
  LOG(INFO) << "oneDNN v" << lib_version->major << "." << lib_version->minor
            << "." << lib_version->patch;
}

714
// Returns the engine created in the constructor.
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine() {
  return this->cur_engine;
}

718
// Returns the stream created on this Body's engine.
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream() {
  return this->cur_stream;
}

722
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
723 724 725
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
726 727 728 729 730 731
    // If no specific executor pointer then clear
    // everything. For executor pointer then clear only
    // objects allocated when using given executor
    if (ptr == nullptr) {
      p_blobmap_->clear();
    } else {
732 733 734 735 736
      // Iterate through all shapes and release
      // for each shape and active executor all entries
      // of this executor
      for (auto& s : *p_exec_items_) {
        for (auto& v : (*s.second)[ptr]) {
737
          (v.first)->erase(v.second);
738 739
        }
        s.second->erase(ptr);
740 741
      }
    }
742 743 744 745 746 747
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

748 749
// Drops the first entry of the per-shape executor map, if there is one.
// Erasing `begin()` of an empty container is undefined behaviour, so an empty
// map is left untouched.
// No lock is taken here; the only caller in this file (SetBlob) already holds
// the cache mutex.
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
  if (p_exec_items_->empty()) {
    return;
  }
  p_exec_items_->erase(p_exec_items_->begin());
}

752 753
// Records that blob entry `it` of `pblob` was created while the current
// thread's executor was running with the current input shape, so that
// ResetBlobMap(executor) can later erase it.
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
                                                KeyBlob::iterator it) const {
  // Take the current input shape from TLS and find (or create) its
  // executor map.
  auto shape_slot = p_exec_items_
                        ->insert(std::make_pair(tls().cur_input_shape_str,
                                                std::make_shared<ExecMap>()))
                        .first;

  // Take the current executor address from TLS and append the new entry to
  // that executor's list.
  auto& linked_entries = (*shape_slot->second)[tls().get_curr_exec()];
  linked_entries.push_back(std::make_pair(pblob, it));

  VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
          << " curr exec size: " << linked_entries.size() << "\n";
}

768 769 770 771
// Arms a one-shot flag, under the cache mutex, that makes the next
// ResetBlobMap() call skip clearing.
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
773

774
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
775
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
776
  BlobMap* pMap = p_blobmap_.get();
777
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
778
  if (map_it == pMap->end()) {
779 780 781
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
782 783 784 785
  }
  return map_it->second->size();
}

786
void MKLDNNDeviceContext::SetBlob(const std::string& name,
787
                                  BlobPtr_t<void> data) const {
788
  BlobMap* pMap = p_blobmap_.get();
789
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
790
  BlobPtr_t<KeyBlob> pBlob = nullptr;
791

792
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
793

794
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
795

796 797
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
798 799 800

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
801
    sBlob = std::make_shared<ShapeBlob>();
802 803
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
804
  } else {
805
    sBlob = map_it->second;
806
  }
T
tensor-tang 已提交
807

808
  // Find KeyBlob for current input shape
809
  auto key_it = sBlob->find(tls().cur_input_shape_str);
810

811
  if (key_it == sBlob->end()) {
812 813
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
814 815
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
816
        sBlob->size() &&
817
        (sBlob->size() >=
818
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
819 820 821 822
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
      RemoveShapeEntriesWithExecutor();
823
    }
824
    pBlob = std::make_shared<KeyBlob>();
825
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
826
  } else {
827
    pBlob = key_it->second;
828 829
  }

830
  // Find Blob via name
831 832 833 834
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    auto el =
        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
835 836 837
    // Register new element in per executor map
    // to have easily erased when executor terminated
    LinkEntryWithExecutor(pBlob, el.first);
838 839 840
  } else {
    blob_it->second = data;  // set data to existing blob
  }
841
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
842
  // lock will be automatically released when out of scope
843
  return;
T
tensor-tang 已提交
844 845
}

846
// Returns the total number of cached blobs, summed over every session and
// every input shape. The cache mutex is not taken here.
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
  unsigned int total = 0;
  for (auto const& session_entry : *p_blobmap_) {
    auto const& shapes = *(session_entry.second);
    for (auto const& shape_entry : shapes) {
      total += (shape_entry.second)->size();
    }
  }
  return total;
}

856
// Looks up the blob cached under (current session id, current input shape,
// `name`). Returns nullptr when any of the three levels is missing; each miss
// is logged at VLOG(2). The lookup runs under the cache mutex.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  // (jczaja): After first iteration of model's execution we
  // should have all elements cached (mostly) so failures are unlikely (less
  // likely for dynamic shapes)
  if (unlikely(map_it == pMap->end())) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (unlikely(sBlob_it == sBlob->end())) {
    // Report the session id under "sid=" and the missing shape key under its
    // own label, rather than printing the shape string as the sid.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (unlikely(key_it == pBlob->end())) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

899 900 901
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
902 903 904 905
// Constructs the phi::CustomContext base for `place`, runs Init(), and wraps
// the stream returned by stream() in a platform Stream object.
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
    : phi::CustomContext(place) {
  Init();

  // Init() runs first; the stream it leaves behind is then wrapped.
  stream_.reset(new platform::stream::Stream(place, stream()));
}

// Nothing to release explicitly.
CustomDeviceContext::~CustomDeviceContext() {
}
T
tensor-tang 已提交
909
#endif
Q
qijun 已提交
910
}  // namespace platform
Q
qijun 已提交
911
}  // namespace paddle