device_context.cc 23.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14

15
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
16
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
17
#include "paddle/fluid/platform/cuda_device_guard.h"
18
#endif
19

20 21
#include "glog/logging.h"

22 23 24 25 26
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// Zero-size requests go straight to the place-based Alloc(place, size). For a
// GPU context whose stream is not the stream of the pool's default context
// for that place, the request is served by
// CUDADeviceContextAllocatorPool (presumably so the memory is associated with
// that non-default stream — confirm in cuda_device_context_allocator.h).
// Every other case (default stream, XPU, CPU and remaining places) uses the
// place-based Alloc.
//
// Throws PermissionDenied when `dev_ctx` lives on a GPU or XPU place but Paddle
// was built without the corresponding device support.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    // The adjacent literals are concatenated, so each piece needs its own
    // separator to produce a readable message.
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
65 66 67
namespace paddle {
namespace platform {

68
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
69 70 71
// Process-wide flag read back through AllowTF32Cublas(). Per its name it gates
// TF32 math in cuBLAS; the consumers live outside this file. Starts enabled.
bool allow_tf32_cublas = true;

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

// Process-wide flag read back through AllowTF32Cudnn(). Per its name it gates
// TF32 math in cuDNN; the consumers live outside this file. Starts enabled.
bool allow_tf32_cudnn = true;

void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
76 77
#endif  // PADDLE_WITH_CUDA

D
dzhwinter 已提交
78 79
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
80
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
81 82
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
83 84
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
85 86
        "with WITH_GPU or WITH_XPU option or check that your train process "
        "hold the "
G
GaoWei8 已提交
87 88
        "correct gpu_id if you use Executor.",
        place));
D
dzhwinter 已提交
89
  }
90
  return it->second.get().get();
D
dzhwinter 已提交
91 92
}

93 94 95 96 97 98 99 100 101
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
102
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
103
                   }));
C
chengduozh 已提交
104 105
}

D
dzhwinter 已提交
106 107
// Builds the pool for `places`.
//
// Duplicate places are collapsed, and every distinct place gets a lazily
// created context via EmplaceDeviceContext:
//   * CPU places: MKLDNNDeviceContext when built with MKLDNN, otherwise
//     CPUDeviceContext;
//   * GPU / CUDA-pinned places: CUDADeviceContext / CUDAPinnedDeviceContext;
//   * XPU places: XPUDeviceContext.
// Places of any other kind are silently skipped.
//
// Throws InvalidArgument when `places` is empty, and Unimplemented when a GPU,
// CUDA-pinned or XPU place is requested in a build without that support.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // std::set de-duplicates the places (and iterates them in sorted order).
  std::set<Place> set(places.begin(), places.end());
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    }
  }
}

153 154 155 156
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
157
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
158 159 160 161 162 163 164
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
165
Place CPUDeviceContext::GetPlace() const { return place_; }
166

167 168 169
#ifdef PADDLE_WITH_XPU
// XPU context without an explicit place: only an xpu::Context is created.
XPUDeviceContext::XPUDeviceContext() {
  context_ = xpu::create_context();
}

// Empty destructor. NOTE(review): context_ obtained from xpu::create_context()
// is not released here — confirm whether that is intended.
XPUDeviceContext::~XPUDeviceContext() {}
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185

// Creates an XPU context bound to `place`.
//
// The current XPU device is temporarily switched to `place.device`, an
// xpu::Context is created there, and the previous device is restored before
// returning. If `place.device` is one of GetXPUSelectedDevices(), an
// XPU_MEM_L3 buffer of 13.5 * 1024 * 1024 bytes is attached to the context.
// That buffer is allocated once per device id and cached in a function-local
// static table, so later contexts for the same device reuse it.
// Throws External when xpu_current_device / xpu_set_device fail.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
  const int MAX_XPU_NUM = 16;
  // The cast makes the double -> int conversion explicit (the value is exact).
  const int l3_size = static_cast<int>(13.5 * 1024 * 1024);
  // One cached L3 buffer per device id, indexed by place.device.
  static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

  auto selected_xpus = GetXPUSelectedDevices();
  for (const auto& selected : selected_xpus) {
    if (place.device != selected) {
      continue;
    }
    // l3ptrs has exactly MAX_XPU_NUM slots; any other device id would index
    // it out of bounds, so such a device runs without an L3 buffer.
    if (place.device < 0 || place.device >= MAX_XPU_NUM) {
      LOG(WARNING) << "xpu place " << place.device
                   << " is out of the supported device id range [0, "
                   << MAX_XPU_NUM << "), skip setting L3 cache.";
      break;
    }
    if (l3ptrs[place.device] == nullptr) {
      xpu_malloc(static_cast<void**>(&l3ptrs[place.device]), l3_size,
                 XPU_MEM_L3);
    }
    if (l3ptrs[place.device] != nullptr) {
      context_->_l3_mgr.set(l3ptrs[place.device], l3_size);
      VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size;
    }
    break;
  }

  // Restore the device that was current on entry.
  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
224
  xpu_wait(context_->xpu_stream);
225 226 227 228 229 230 231
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

232
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
233

Q
init  
qijun 已提交
234 235 236 237 238 239 240
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

241
  void Reinitialize(const gpuStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
242 243 244 245 246
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

247
  const gpuStream_t& stream() const override { return *stream_; }
Q
init  
qijun 已提交
248

249 250 251
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t& deviceProperties() const override {
#else
Q
init  
qijun 已提交
252
  const cudaDeviceProp& deviceProperties() const override {
253
#endif
Q
init  
qijun 已提交
254 255 256 257
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
258 259 260
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
261 262 263
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
264
    void* retv = buf->ptr();
S
sneaxiy 已提交
265 266 267 268
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
269
    return retv;
Q
init  
qijun 已提交
270 271
  }

S
sneaxiy 已提交
272 273 274 275 276 277
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
278 279 280

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
281
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
Q
init  
qijun 已提交
282 283 284 285 286 287
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
288
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
Q
init  
qijun 已提交
289
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
290 291 292 293
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
#else
294
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
295
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
296
#endif
Q
init  
qijun 已提交
297 298 299 300 301
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
302
  CUDAPlace place_;
303 304 305 306
  const gpuStream_t* stream_;  // not owned;
#ifdef PADDLE_WITH_HIP
  const hipDeviceProp_t* device_prop_;
#else
Q
init  
qijun 已提交
307
  const cudaDeviceProp* device_prop_;  // not owned;
308
#endif
Q
qijun 已提交
309
  mutable void* scratch_;
Q
init  
qijun 已提交
310
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
311
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
312
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
313 314
};

315 316 317 318 319 320 321 322 323
// Ensures the workspace can hold at least `required_workspace_bytes`.
// A workspace that is already large enough is left untouched. Otherwise the
// old buffer is released before the new one is requested, so the two never
// coexist.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Free first to keep peak memory low.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

324 325 326 327 328 329 330 331 332 333 334 335
// Out-of-line definitions of CUDADeviceContext's thread_local static members.
// thread_ctx_ maps a CUDADeviceContext to a shared CUDAContext; because it is
// thread_local, every thread has its own map. ctx_mtx_ is likewise one mutex
// per thread. NOTE(review): presumably consulted by context() (declared in the
// header) to let a thread use a CUDAContext other than default_ctx_; confirm.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Creates the Eigen stream adapter for this context's raw stream and place,
// then builds the Eigen::GpuDevice on top of it.
void CUDAContext::InitEigenContext() {
  auto* stream_device = new EigenCudaStreamDevice();
  eigen_stream_.reset(stream_device);
  stream_device->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(stream_device));
}

// Creates a CUDAStream with `priority` on `place`. It then initializes, in
// this order, the Eigen, cuBLAS and cuDNN contexts, plus cuSOLVER in non-HIP
// builds. All of this runs under a CUDADeviceGuard for place_.device.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority));

  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
  InitCuSolverContext();
#endif
}

// Tears down the library contexts under a CUDADeviceGuard for place_.device:
// cuDNN first, then cuBLAS, then (non-HIP builds) cuSOLVER.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);

  DestoryCuDNNContext();
  DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
  DestoryCuSolverContext();
#endif
}

357
// Creates the device context for `place`. Everything runs under a
// CUDADeviceGuard for place_.device, in this order:
//   1. cache the device limits: compute capability, SM count, threads per SM,
//      max grid dims, max threads per block;
//   2. cache the driver and runtime API versions;
//   3. log these once per process;
//   4. log the MIOpen (HIP) or cuDNN version once;
//   5. warn once if the local driver version is older than the CUDA/HIP
//      version Paddle was compiled with;
//   6. create the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  const auto dev = place_.device;

  compute_capability_ = GetCUDAComputeCapability(dev);
  multi_process_ = GetCUDAMultiProcessors(dev);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(dev);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(dev);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(dev);

  driver_version_ = GetCUDADriverVersion(dev);
  runtime_version_ = GetCUDARuntimeVersion(dev);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << dev
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  LOG_FIRST_N(WARNING, 1) << "device: " << dev
                          << ", MIOpen Version: " << version_major << "."
                          << version_minor << "." << version_patch;
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << dev
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";
#endif

  // Check CUDA/CUDNN version compatibility. Both sides are encoded as
  // major * 10 + minor before being compared.
  const auto local_cuda_version =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
  const auto compile_cuda_version =
      (HIP_VERSION / 100) * 10 + (HIP_VERSION % 10);
#else
  const auto compile_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
#endif
  const bool driver_too_old = local_cuda_version < compile_cuda_version;
  if (driver_too_old) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << dev
        << ". The installed Paddle is compiled with CUDA "
        << compile_cuda_version / 10 << "." << compile_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << local_cuda_version / 10 << "." << local_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes this context's device current via SetDeviceId. When built with
// NCCL/RCCL, it also destroys the communicator if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
424
Place CUDADeviceContext::GetPlace() const { return place_; }
425

426
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
427

K
Kexin Zhao 已提交
428
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
429
  return compute_capability_;
K
Kexin Zhao 已提交
430 431
}

432
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
433
  return multi_process_ * max_threads_per_mp_;
434 435
}

436 437 438 439 440 441
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

442
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
443
  return context()->EigenDevice().get();
444 445
}

446
bool CUDADeviceContext::tensor_core_available() const {
447
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
448 449
}

450 451 452 453
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

454 455 456
#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
457
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
458
#endif
459 460
  return context()->CudnnHandle();
}
461

462 463 464 465
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return context()->CublasHandle()->GetCublasHandle();
}

S
sneaxiy 已提交
466
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
467
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
468
}
469

470
#ifndef PADDLE_WITH_HIP
G
Guo Sheng 已提交
471 472 473
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}
474
#endif
G
Guo Sheng 已提交
475

476
gpuStream_t CUDADeviceContext::stream() const { return context()->RawStream(); }
Q
qijun 已提交
477

C
chengduoZH 已提交
478 479 480 481 482 483 484 485 486 487 488 489 490 491
// Pinned-memory context with a default-constructed place_; it owns an
// Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Pinned-memory context bound to `place`; it owns an Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Borrowed pointer to the Eigen device held by eigen_device_.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  auto* device = eigen_device_.get();
  return device;
}

// The pinned place this context was created for.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
492
#endif
Q
qijun 已提交
493

T
tensor-tang 已提交
494 495
#ifdef PADDLE_WITH_MKLDNN
// oneDNN-enabled CPU context. Creates an empty blob cache (p_blobmap_) and the
// mutex (p_mutex_) that the blob-cache accessors lock.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), p_blobmap_(new BlobMap()) {
  p_mutex_.reset(new std::mutex());
}

501 502
// Creates a oneDNN CPU engine (index 0) and a stream on it. It then resets the
// cache settings to their defaults: NCHW layout, shape-cache capacity 1, empty
// input-shape key, default session id.
MKLDNNDeviceContextThreadLocals::Body::Body()
    : cur_engine(mkldnn::engine::kind::cpu, 0), cur_stream(cur_engine) {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str = "";
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
// When a thread finishes, the oneDNN cache is cleared.
// This is needed when one executor is used by many threads (e.g.
// test_analyzer_detect). The thread ID is not part of the caching key (for the
// naive executor), so the cache must be cleared when one thread finishes and
// another is about to start inference. The clear can still be skipped once
// through MKLDNNDeviceContext::BlockNextCacheClearing().
// TODO(jczaja): Ideally it would be good to clear only the part of the cache
// related to the thread that is being terminated.
MKLDNNDeviceContextThreadLocals::Body::~Body() {
  auto cpu_place = paddle::platform::CPUPlace();
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // In MKLDNN builds the pool emplaces an MKLDNNDeviceContext for CPU places,
  // so this is a named downcast within the DeviceContext hierarchy.
  auto* dev_ctx =
      static_cast<platform::MKLDNNDeviceContext*>(pool.Get(cpu_place));
  dev_ctx->ResetBlobMap();
}

524 525 526 527 528 529 530 531 532 533
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
534 535
  cur_input_shape_str = input_shape_str;
}
536 537
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
538 539
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
540

541 542
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
543 544 545
  cur_paddle_data_layout = dl;
}

546 547
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
548 549 550
  return cur_paddle_data_layout;
}

551 552 553 554 555 556 557 558 559
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

560 561 562 563 564 565 566 567
const mkldnn::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
  return cur_engine;
}

mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
  return cur_stream;
}

568 569 570 571 572 573 574 575 576 577 578 579 580 581 582
// Clears the whole blob cache under p_mutex_. If BlockNextCacheClearing() was
// called since the last reset, this call clears nothing and only lifts the
// block.
void MKLDNNDeviceContext::ResetBlobMap() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_) {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
    return;
  }
  VLOG(3) << "Clearing DNNL cache.";
  p_blobmap_->clear();
}

// Makes the next ResetBlobMap() call skip clearing (one-shot).
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
584

585
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
586
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
587
  BlobMap* pMap = p_blobmap_.get();
588
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
589
  if (map_it == pMap->end()) {
590 591 592
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
593 594 595 596
  }
  return map_it->second->size();
}

597
void MKLDNNDeviceContext::SetBlob(const std::string& name,
598
                                  BlobPtr_t<void> data) const {
599
  BlobMap* pMap = p_blobmap_.get();
600 601
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
602

603
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
604

605
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
606

607 608
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
609 610 611

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
612
    sBlob = std::make_shared<ShapeBlob>();
613 614
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
615
  } else {
616
    sBlob = map_it->second;
617
  }
T
tensor-tang 已提交
618

619
  // Find KeyBlob for current input shape
620
  auto key_it = sBlob->find(tls().cur_input_shape_str);
621

622
  if (key_it == sBlob->end()) {
623 624
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
625 626
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
627
        sBlob->size() &&
628
        (sBlob->size() >=
629
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
630 631 632 633
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
634 635
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
636
  } else {
637
    pBlob = key_it->second;
638 639
  }

640 641 642 643 644 645 646
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
647
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
648
  // lock will be automatically released when out of scope
649
  return;
T
tensor-tang 已提交
650 651
}

652 653 654 655 656 657 658 659 660 661
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) {
  unsigned int num_entries = 0;
  for (auto const& l3 : *p_blobmap_) {
    for (auto const& l2 : *(l3.second)) {
      num_entries += (l2.second)->size();
    }
  }
  return num_entries;
}

662
// Looks up the blob stored under `name` for the current session id and
// input-shape key (both read from tls()). Returns nullptr when the session,
// the shape key or the name is not cached. The lookup runs under p_mutex_.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Report the session id under "sid=" and the missing shape key under its
    // own label.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif

Q
qijun 已提交
704
}  // namespace platform
Q
qijun 已提交
705
}  // namespace paddle