/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14
#include <string>
Y
Yu Yang 已提交
15
#include <unordered_set>
16 17
#include <vector>

Y
Yi Wang 已提交
18
#include "paddle/fluid/memory/memory.h"
19 20
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
21
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
22
#include "paddle/fluid/platform/cuda_device_guard.h"
23
#endif
24

25 26
#include "glog/logging.h"

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// On CUDA builds, a non-empty request on a GPU place whose context uses a
// stream different from the pool's default context for that place is routed
// to CUDADeviceContextAllocatorPool together with the context itself. Every
// other request (CPU places, zero-byte requests, the default stream, or a
// non-CUDA build) falls through to the place-based Alloc.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
  const bool wants_gpu_memory = size != 0 && platform::is_gpu_place(place);
  if (wants_gpu_memory) {
    // NOTE: a context reporting a GPU place is assumed to be a
    // CUDADeviceContext.
    auto* pool_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    const auto& cuda_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (pool_ctx->stream() != cuda_ctx.stream()) {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          cuda_ctx, size);
    }
  }
#endif
  return Alloc(place, size);
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
54 55 56
namespace paddle {
namespace platform {

57 58 59 60 61 62
#ifdef PADDLE_WITH_CUDA
// Process-wide TF32 switch for cuBLAS, on by default. It is read through
// AllowTF32Cublas() and changed through SetAllowTF32Cublas().
bool allow_tf32_cublas{true};

void SetAllowTF32Cublas(bool active) {
  allow_tf32_cublas = active;
}

bool AllowTF32Cublas() {
  return allow_tf32_cublas;
}
#endif  // PADDLE_WITH_CUDA

D
dzhwinter 已提交
63 64
// Storage for the pool instance pointer; null until it is set elsewhere.
DeviceContextPool* DeviceContextPool::pool{nullptr};

Y
Yu Yang 已提交
65
// Returns the device context registered for `place`.
//
// Throws an Unimplemented error when the pool has no entry for `place`
// (e.g. the backend was not compiled in, or a wrong device id was used).
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  const auto entry = device_contexts_.find(place);
  if (entry == device_contexts_.end()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with WITH_GPU or WITH_XPU option or check that your train process "
        "hold the "
        "correct gpu_id if you use Executor.",
        place));
  }
  // Each entry is a shared_future over a deferred factory (see
  // EmplaceDeviceContext), so the first get() builds the context and later
  // calls return the stored unique_ptr.
  const auto& owned_ctx = entry->second.get();
  return owned_ctx.get();
}

78 79 80 81 82 83 84 85 86
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
87
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
88
                   }));
C
chengduozh 已提交
89 90
}

D
dzhwinter 已提交
91 92
// Builds the pool for `places`.
//
// Duplicate places are collapsed, and each distinct place gets one lazily
// created device context (see EmplaceDeviceContext). Context type by place:
//   - CPU: MKLDNNDeviceContext when built with MKLDNN, otherwise
//     CPUDeviceContext.
//   - GPU, CUDA-pinned, XPU: the matching context when that backend is
//     compiled in; otherwise an Unimplemented error is thrown.
//   - Any other place kind: skipped, no context is registered.
// Throws InvalidArgument when `places` is empty.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Deduplicate so that a place listed several times gets a single context.
  std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually requested (previously this
      // message wrongly said "CUDAPlace").
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    }
  }
}

138 139 140 141
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
142
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
143 144 145 146 147 148 149
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
150
Place CPUDeviceContext::GetPlace() const { return place_; }
151

152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
#ifdef PADDLE_WITH_XPU
// Creates an xpu::Context on whatever XPU device is currently active.
XPUDeviceContext::XPUDeviceContext() : context_(xpu::create_context()) {}

XPUDeviceContext::~XPUDeviceContext() { xpu::destroy_context(context_); }

// Creates an xpu::Context on `place.device`. The active XPU device is
// recorded first and switched back afterwards, so the caller's device
// selection is unchanged. Any failing XPU call raises an External error.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int saved_dev_id = -1;
  int status = xpu_current_device(&saved_dev_id);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));

  // Temporarily make the target device current while creating the context.
  status = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));
  context_ = xpu::create_context();

  // Restore the device that was active on entry.
  status = xpu_set_device(saved_dev_id);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));
}

// Makes place_.device current, then calls xpu_wait().
void XPUDeviceContext::Wait() const {
  const int status = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));
  xpu_wait();
}

Place XPUDeviceContext::GetPlace() const { return place_; }

// Non-owning handle to the xpu::Context owned by this device context.
xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

195
#ifdef PADDLE_WITH_CUDA
196

Q
init  
qijun 已提交
197 198 199 200 201 202 203
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
204
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
205 206 207 208 209 210 211 212 213 214 215 216
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
217 218 219
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
220 221 222
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
223
    void* retv = buf->ptr();
S
sneaxiy 已提交
224 225 226 227
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
228
    return retv;
Q
init  
qijun 已提交
229 230
  }

S
sneaxiy 已提交
231 232 233 234 235 236
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
237 238 239

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
240 241 242 243
// windows use an old version of eigen that uses kCudaScratchSize,
// once windows updates eigen to a recent version, the following code
// can use kGpuScratchSize uniformly
#ifdef _WIN32
Q
init  
qijun 已提交
244
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
Z
Zhang Ting 已提交
245 246 247
#else
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
#endif
Q
init  
qijun 已提交
248 249 250 251 252 253
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
254
#ifdef _WIN32
Q
init  
qijun 已提交
255 256
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
Z
Zhang Ting 已提交
257 258 259
#else
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
#endif
Q
init  
qijun 已提交
260
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
261
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
262 263 264 265 266 267
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
268
  CUDAPlace place_;
Q
init  
qijun 已提交
269 270
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
271
  mutable void* scratch_;
Q
init  
qijun 已提交
272
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
273
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
274
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
275 276
};

277 278 279 280 281 282 283 284 285
// Grow-only resize of the cuDNN workspace: nothing happens when the current
// workspace already holds `required_workspace_bytes`.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  if (WorkspaceSize() < required_workspace_bytes) {
    // Free the old buffer before requesting the larger one so the two are
    // never held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

286 287 288 289 290 291 292 293 294 295 296 297
// Per-thread map from CUDADeviceContext pointer to a CUDAContext, plus a
// per-thread mutex (presumably used around thread_ctx_; its uses are not in
// this section). Both are thread_local, so every thread has its own copy.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_{};
thread_local std::mutex CUDADeviceContext::ctx_mtx_{};

// Wraps this context's raw CUDA stream in an EigenCudaStreamDevice and builds
// the Eigen::GpuDevice on top of that adapter.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

// Creates, on `place`, a stream with `priority` followed by the Eigen,
// cuBLAS, cuDNN and cuSOLVER state bound to it (in that order).
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  // Make place_.device current while the per-device resources are created.
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place_, priority));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
  InitCuSolverContext();
}

// Destroys the cuDNN, cuBLAS and cuSOLVER handles (in that order) with
// place_.device made current.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
  DestoryCuSolverContext();
}

315
// Builds the device context for `place`.
//
// Caches the device limits and CUDA versions, logs the device/driver/runtime
// and cuDNN versions (once per process), warns when the driver is older than
// the CUDA version this binary was compiled with, and finally creates the
// default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);

  // Device limits queried once and cached for the getters.
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);

  // CUDA encodes versions as 1000 * major + 10 * minor (e.g. 10020 -> 10.2).
  driver_version_ = GetCUDADriverVersion(place_.device);
  runtime_version_ = GetCUDARuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;

  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";

  // Compare driver and compile-time CUDA versions as major * 10 + minor.
  const auto driver_major_minor =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
  const auto compiled_major_minor =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
  if (driver_major_minor < compiled_major_minor) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << place_.device
        << ". The installed Paddle is compiled with CUDA "
        << compiled_major_minor / 10 << "." << compiled_major_minor % 10
        << ", but CUDA runtime version in your machine is "
        << driver_major_minor / 10 << "." << driver_major_minor % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes this context's device current, then (NCCL builds only) destroys the
// NCCL communicator if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL)
  // NOTE(review): PADDLE_ENFORCE_CUDA_SUCCESS presumably throws on failure;
  // an exception leaving a destructor ends in std::terminate — confirm this
  // is the intended failure mode.
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
370
Place CUDADeviceContext::GetPlace() const { return place_; }
371

372
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
373

K
Kexin Zhao 已提交
374
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
375
  return compute_capability_;
K
Kexin Zhao 已提交
376 377
}

378
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
379
  return multi_process_ * max_threads_per_mp_;
380 381
}

382 383 384 385 386 387
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

388
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
389
  return context()->EigenDevice().get();
390 391
}

392
bool CUDADeviceContext::tensor_core_available() const {
393
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
394 395
}

396 397 398 399
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

400 401 402
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return context()->CudnnHandle();
}
403

S
sneaxiy 已提交
404
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
405
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
406
}
407

G
Guo Sheng 已提交
408 409 410 411
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}

412 413 414
cudaStream_t CUDADeviceContext::stream() const {
  return context()->RawStream();
}
Q
qijun 已提交
415

C
chengduoZH 已提交
416 417 418 419 420 421 422 423 424 425 426 427 428 429
// Default pinned-memory context: place_ keeps its default value and Eigen
// work runs on a host-side Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Pinned-memory context bound to `place`, also backed by a DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning pointer to the Eigen device held by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
430
#endif
Q
qijun 已提交
431

T
tensor-tang 已提交
432 433
#ifdef PADDLE_WITH_MKLDNN
// CPU context with a oneDNN CPU engine (index 0), a fresh blob cache and the
// mutex that guards the cache.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::kind::cpu, 0) {
  p_blobmap_.reset(new BlobMap());
  p_mutex_.reset(new std::mutex());
}

441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457
MKLDNNDeviceContextThreadLocals::Body::Body() {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
458 459
  cur_input_shape_str = input_shape_str;
}
460 461
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
462 463
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
464

465 466
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
467 468 469
  cur_paddle_data_layout = dl;
}

470 471
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
472 473 474
  return cur_paddle_data_layout;
}

475 476 477 478 479 480 481 482 483
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

484 485 486 487 488 489 490 491 492 493 494 495 496 497 498
// Clears the whole blob cache, unless BlockNextCacheClearing() was called
// since the last reset; in that case this one call is skipped and the block
// is consumed.
void MKLDNNDeviceContext::ResetBlobMap() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_) {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
    return;
  }
  VLOG(3) << "Clearing DNNL cache.";
  p_blobmap_->clear();
}

// Makes the next ResetBlobMap() call a no-op.
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
500

501
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
502
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
503
  BlobMap* pMap = p_blobmap_.get();
504
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
505
  if (map_it == pMap->end()) {
506 507 508
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
509 510 511 512
  }
  return map_it->second->size();
}

513
void MKLDNNDeviceContext::SetBlob(const std::string& name,
514
                                  BlobPtr_t<void> data) const {
515
  BlobMap* pMap = p_blobmap_.get();
516 517
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
518

519
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
520

521
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
522

523 524
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
525 526 527

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
528
    sBlob = std::make_shared<ShapeBlob>();
529 530
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
531
  } else {
532
    sBlob = map_it->second;
533
  }
T
tensor-tang 已提交
534

535
  // Find KeyBlob for current input shape
536
  auto key_it = sBlob->find(tls().cur_input_shape_str);
537

538
  if (key_it == sBlob->end()) {
539 540
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
541 542
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
543
        sBlob->size() &&
544
        (sBlob->size() >=
545
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
546 547 548 549
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
550 551
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
552
  } else {
553
    pBlob = key_it->second;
554 555
  }

556 557 558 559 560 561 562
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
563
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
564
  // lock will be automatically released when out of scope
565
  return;
T
tensor-tang 已提交
566 567
}

568
// Looks up the blob `name` for the calling thread's current session id and
// input shape string. Returns nullptr when the session, the shape or the
// name is missing. All lookups happen under *p_mutex_.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Level 1: ShapeBlob of the current session id.
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Level 2: KeyBlob of the current input shape.
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Report the real session id and the shape key that was not found
    // (previously the shape string was printed under the "sid=" label).
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Level 3: the blob itself, by name.
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // The lock is released when it goes out of scope.
  return key_it->second;
}

#endif

Q
qijun 已提交
610
}  // namespace platform
Q
qijun 已提交
611
}  // namespace paddle