/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14
#include <string>
Y
Yu Yang 已提交
15
#include <unordered_set>
16 17
#include <vector>

Y
Yi Wang 已提交
18
#include "paddle/fluid/memory/memory.h"
19 20
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
21
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
22
#include "paddle/fluid/platform/cuda_device_guard.h"
23
#endif
24

25 26
#include "glog/logging.h"

27 28 29 30 31
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place that `dev_ctx` belongs to.
//
// - Zero-byte requests, and places that are neither GPU nor XPU, are served
//   directly by Alloc(place, size).
// - GPU: if `dev_ctx` uses the same stream as the pool's default context for
//   that place, the plain place allocator is used; otherwise the request goes
//   through CUDADeviceContextAllocatorPool with `dev_ctx` itself
//   (NOTE(review): the stream-association semantics of that pool live in
//   cuda_device_context_allocator.h — confirm there).
// - XPU: currently always the plain place allocator.
//
// Throws PermissionDenied when a GPU/XPU place is requested but Paddle was
// built without the corresponding backend.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    // The two literals are concatenated; keep the sentence break so the
    // message does not read "...with CUDA,Please...".
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
70 71 72
namespace paddle {
namespace platform {

73 74 75 76 77 78 79 80 81 82
#ifdef PADDLE_WITH_CUDA
// Process-wide TF32 switches, both enabled by default. Presumably consulted
// when configuring cuBLAS / cuDNN math modes — confirm at the call sites.
// These are plain bools with no synchronization; callers must not toggle them
// concurrently with readers.
bool allow_tf32_cublas = true;
bool allow_tf32_cudnn = true;

// Enables (true) or disables (false) the cuBLAS TF32 switch.
void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

// Returns the current value of the cuBLAS TF32 switch.
bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}

// Enables (true) or disables (false) the cuDNN TF32 switch.
void SetAllowTF32Cudnn(bool active) {
  allow_tf32_cudnn = active;
}

// Returns the current value of the cuDNN TF32 switch.
bool AllowTF32Cudnn() {
  return allow_tf32_cudnn;
}
#endif  // PADDLE_WITH_CUDA

D
dzhwinter 已提交
83 84
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
85
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
86 87
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
88 89
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
90 91
        "with WITH_GPU or WITH_XPU option or check that your train process "
        "hold the "
G
GaoWei8 已提交
92 93
        "correct gpu_id if you use Executor.",
        place));
D
dzhwinter 已提交
94
  }
95
  return it->second.get().get();
D
dzhwinter 已提交
96 97
}

98 99 100 101 102 103 104 105 106
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
107
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
108
                   }));
C
chengduozh 已提交
109 110
}

D
dzhwinter 已提交
111 112
// Builds the pool for the given places.
//
// Duplicate places are collapsed (via std::set) before registration. Each
// distinct place gets a lazily constructed context (see
// EmplaceDeviceContext):
//   CPU        -> MKLDNNDeviceContext (with MKLDNN) or CPUDeviceContext
//   GPU        -> CUDADeviceContext
//   CUDAPinned -> CUDAPinnedDeviceContext
//   XPU        -> XPUDeviceContext
// Places of any other kind are silently skipped.
//
// Throws InvalidArgument if `places` is empty. Throws Unimplemented for a
// GPU / CUDA-pinned / XPU place when the corresponding backend was not
// compiled in.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  std::set<Place> set;
  for (auto& p : places) {
    set.insert(p);
  }
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually rejected (previously this
      // reported "CUDAPlace", which pointed users at the wrong place type).
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    }
  }
}

158 159 160 161
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
162
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
163 164 165 166 167 168 169
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
170
Place CPUDeviceContext::GetPlace() const { return place_; }
171

172 173 174
#ifdef PADDLE_WITH_XPU
// Creates an XDNN context on whatever XPU device is current.
XPUDeviceContext::XPUDeviceContext() { context_ = xpu::create_context(); }

// Hands the XDNN context back to xpu::destroy_context.
//
// The previous body kept going after destroy_context(context_). It
// xpu_malloc'ed a new L3 buffer and wrote it into context_->_l3_mgr, which
// dereferenced the context after it had been passed to destroy_context.
// That buffer was also never freed. The code was copy-pasted from the
// constructor, and teardown must not touch context_ once it is destroyed.
// NOTE(review): the L3 buffer attached in XPUDeviceContext(XPUPlace) is not
// released here; whether destroy_context frees it is not visible from this
// file — confirm against the XPU runtime docs.
XPUDeviceContext::~XPUDeviceContext() { xpu::destroy_context(context_); }
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200

// Creates an XDNN context on `place.device`.
//
// Steps: temporarily switch the current XPU device to `place.device`, create
// the context, try to attach a 13.5 MiB L3 buffer to it, then switch back
// to the device that was current on entry. If allocating the L3 buffer fails,
// the context is used without one.
//
// Throws External if querying or switching the current device fails.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  // Remember the caller's device so it can be restored at the end.
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  context_ = xpu::create_context();

  // Best-effort L3 attachment: xpu_malloc's return code is not checked; only a
  // non-null pointer is taken as success.
  const int l3_size = 13.5 * 1024 * 1024;
  void* l3_buffer = nullptr;
  xpu_malloc(&l3_buffer, l3_size, XPU_MEM_L3);
  if (l3_buffer != nullptr) {
    context_->_l3_mgr.set(l3_buffer, l3_size);
    std::cout << "set l3 size " << l3_size << std::endl;
  }

  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
223
  xpu_wait(context_->xpu_stream);
224 225 226 227 228 229 230
}

Place XPUDeviceContext::GetPlace() const { return place_; }

xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

231
#ifdef PADDLE_WITH_CUDA
232

Q
init  
qijun 已提交
233 234 235 236 237 238 239
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
240
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
241 242 243 244 245 246 247 248 249 250 251 252
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
253 254 255
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
256 257 258
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
259
    void* retv = buf->ptr();
S
sneaxiy 已提交
260 261 262 263
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
264
    return retv;
Q
init  
qijun 已提交
265 266
  }

S
sneaxiy 已提交
267 268 269 270 271 272
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
273 274 275

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
276 277 278 279
// windows use an old version of eigen that uses kCudaScratchSize,
// once windows updates eigen to a recent version, the following code
// can use kGpuScratchSize uniformly
#ifdef _WIN32
Q
init  
qijun 已提交
280
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
Z
Zhang Ting 已提交
281 282 283
#else
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
#endif
Q
init  
qijun 已提交
284 285 286 287 288 289
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
290
#ifdef _WIN32
Q
init  
qijun 已提交
291 292
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
Z
Zhang Ting 已提交
293 294 295
#else
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
#endif
Q
init  
qijun 已提交
296
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
297
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
298 299 300 301 302 303
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
304
  CUDAPlace place_;
Q
init  
qijun 已提交
305 306
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
307
  mutable void* scratch_;
Q
init  
qijun 已提交
308
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
309
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
310
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
311 312
};

313 314 315 316 317 318 319 320 321
// Grows the cuDNN workspace to at least `required_workspace_bytes`.
// The workspace never shrinks: if it is already large enough, this is a no-op.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Release the old buffer before requesting the new one so both are never
    // held at the same time (lower peak memory).
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

322 323 324 325 326 327 328 329 330 331 332 333
// Per-thread map from a CUDADeviceContext to a shared CUDAContext.
// NOTE(review): presumably lets each thread bind its own CUDAContext to a
// device context; the readers/writers are declared in device_context.h —
// confirm there.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
// NOTE(review): this mutex is thread_local, so every thread has its own
// instance and it cannot serialize access across threads — confirm intent.
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Builds the Eigen GPU device on top of this context's raw stream. The
// stream interface object is owned by eigen_stream_. The Eigen::GpuDevice
// only receives a pointer to it, so eigen_stream_ must outlive eigen_device_.
void CUDAContext::InitEigenContext() {
  auto* stream_iface = new EigenCudaStreamDevice();
  eigen_stream_.reset(stream_iface);
  stream_iface->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

// Creates a CUDA stream with `priority` on `place`, then initializes the
// Eigen, cuBLAS, cuDNN and cuSOLVER handles, all under a device guard for
// `place`.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place_, priority));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
  InitCuSolverContext();
}

// Tears down the cuDNN, cuBLAS and cuSOLVER handles (in that order) with this
// context's device made current.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
  DestoryCuSolverContext();
}

351
// Creates the device context for `place`.
//
// Under a device guard, this caches the device attributes (compute
// capability, SM count, thread/grid limits, driver and runtime versions). It
// logs the GPU and cuDNN versions once per process, warns (once) if the
// driver's CUDA version is older than the toolkit Paddle was compiled with,
// and finally creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);

  driver_version_ = GetCUDADriverVersion(place_.device);
  runtime_version_ = GetCUDARuntimeVersion(place_.device);

  // CUDA reports versions as 1000 * major + 10 * minor; compute capability is
  // 10 * major + minor.
  const auto cc_major = compute_capability_ / 10;
  const auto cc_minor = compute_capability_ % 10;
  const auto driver_major = driver_version_ / 1000;
  const auto driver_minor = (driver_version_ % 100) / 10;
  const auto runtime_major = runtime_version_ / 1000;
  const auto runtime_minor = (runtime_version_ % 100) / 10;

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", GPU Compute Capability: " << cc_major << "."
                          << cc_minor << ", Driver API Version: "
                          << driver_major << "." << driver_minor
                          << ", Runtime API Version: " << runtime_major << "."
                          << runtime_minor;

  const size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  const size_t cudnn_major = cudnn_dso_ver / 1000;
  const size_t cudnn_minor = (cudnn_dso_ver % 1000) / 100;
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_major << "."
                          << cudnn_minor << ".";

  // Only the CUDA version is checked: the driver's version (as major*10+minor)
  // must not be older than the CUDA_VERSION Paddle was built against.
  const auto local_cuda_version = driver_major * 10 + driver_minor;
  const auto compile_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
  if (local_cuda_version < compile_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << place_.device
        << ". The installed Paddle is compiled with CUDA "
        << compile_cuda_version / 10 << "." << compile_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << local_cuda_version / 10 << "." << local_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes this context's device current and, when built with NCCL, destroys
// the NCCL communicator if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL)
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
406
Place CUDADeviceContext::GetPlace() const { return place_; }
407

408
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
409

K
Kexin Zhao 已提交
410
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
411
  return compute_capability_;
K
Kexin Zhao 已提交
412 413
}

414
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
415
  return multi_process_ * max_threads_per_mp_;
416 417
}

418 419 420 421 422 423
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

424
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
425
  return context()->EigenDevice().get();
426 427
}

428
bool CUDADeviceContext::tensor_core_available() const {
429
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
430 431
}

432 433 434 435
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

436 437 438
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return context()->CudnnHandle();
}
439

S
sneaxiy 已提交
440
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
441
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
442
}
443

G
Guo Sheng 已提交
444 445 446 447
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}

448 449 450
cudaStream_t CUDADeviceContext::stream() const {
  return context()->RawStream();
}
Q
qijun 已提交
451

C
chengduoZH 已提交
452 453 454 455 456 457 458 459 460 461 462 463 464 465
// Default construction is the same as constructing for a default
// CUDAPinnedPlace.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : CUDAPinnedDeviceContext(CUDAPinnedPlace()) {}

// Pinned host memory is computed on by the host, so the Eigen device is
// Eigen's DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
466
#endif
Q
qijun 已提交
467

T
tensor-tang 已提交
468 469
#ifdef PADDLE_WITH_MKLDNN
// CPU context with an oneDNN CPU engine (index 0), an empty blob cache, and
// the mutex that guards that cache.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::kind::cpu, 0) {
  p_blobmap_.reset(new BlobMap());
  p_mutex_.reset(new std::mutex());
}

477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
MKLDNNDeviceContextThreadLocals::Body::Body() {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
494 495
  cur_input_shape_str = input_shape_str;
}
496 497
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
498 499
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
500

501 502
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
503 504 505
  cur_paddle_data_layout = dl;
}

506 507
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
508 509 510
  return cur_paddle_data_layout;
}

511 512 513 514 515 516 517 518 519
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
// Clears the whole blob cache, unless BlockNextCacheClearing() was called
// since the last reset. In that case this call only consumes the block flag.
void MKLDNNDeviceContext::ResetBlobMap() {
  std::lock_guard<std::mutex> lock(*p_mutex_);
  if (block_next_cache_clearing_) {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
    return;
  }
  VLOG(3) << "Clearing DNNL cache.";
  p_blobmap_->clear();
}

// Makes the next ResetBlobMap() call a no-op (one-shot).
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<std::mutex> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
536

537
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
538
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
539
  BlobMap* pMap = p_blobmap_.get();
540
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
541
  if (map_it == pMap->end()) {
542 543 544
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
545 546 547 548
  }
  return map_it->second->size();
}

549
void MKLDNNDeviceContext::SetBlob(const std::string& name,
550
                                  BlobPtr_t<void> data) const {
551
  BlobMap* pMap = p_blobmap_.get();
552 553
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
554

555
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
556

557
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
558

559 560
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
561 562 563

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
564
    sBlob = std::make_shared<ShapeBlob>();
565 566
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
567
  } else {
568
    sBlob = map_it->second;
569
  }
T
tensor-tang 已提交
570

571
  // Find KeyBlob for current input shape
572
  auto key_it = sBlob->find(tls().cur_input_shape_str);
573

574
  if (key_it == sBlob->end()) {
575 576
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
577 578
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
579
        sBlob->size() &&
580
        (sBlob->size() >=
581
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
582 583 584 585
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
586 587
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
588
  } else {
589
    pBlob = key_it->second;
590 591
  }

592 593 594 595 596 597 598
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
599
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
600
  // lock will be automatically released when out of scope
601
  return;
T
tensor-tang 已提交
602 603
}

604
// Looks up the blob stored under `name` for the calling thread's current
// MKLDNN session id and input-shape string.
//
// Returns nullptr on a miss at any level (session, shape, or name) and never
// inserts anything. Guarded by *p_mutex_.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Report the real session id and the missing shape key. The old message
    // printed the shape string under the "sid=" label.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif

Q
qijun 已提交
646
}  // namespace platform
Q
qijun 已提交
647
}  // namespace paddle