/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#include "glog/logging.h"

namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// A non-empty request on a GPU place whose context uses a different stream
// than the pooled (default) context for that place is served by
// CUDADeviceContextAllocatorPool for that context. Every other request —
// zero bytes, a non-GPU place, a non-CUDA build, or a context sharing the
// default stream — falls back to the place-based Alloc(place, size).
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
  if (size != 0 && platform::is_gpu_place(place)) {
    auto* pooled_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& requested_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    const bool uses_default_stream =
        pooled_ctx->stream() == requested_ctx.stream();
    if (!uses_default_stream) {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          requested_ctx, size);
    }
  }
#endif
  return Alloc(place, size);
}

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
59
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
60 61
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
G
GaoWei8 已提交
62 63
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
64 65
        "with WITH_GPU or WITH_XPU option or check that your train process "
        "hold the "
G
GaoWei8 已提交
66 67
        "correct gpu_id if you use Executor.",
        place));
D
dzhwinter 已提交
68
  }
69
  return it->second.get().get();
D
dzhwinter 已提交
70 71
}

72 73 74 75 76 77 78 79 80
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
81
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
82
                   }));
C
chengduozh 已提交
83 84
}

D
dzhwinter 已提交
85 86
// Registers one lazily constructed device context per distinct place in
// `places`; duplicate places are collapsed.
//
// Throws InvalidArgument if `places` is empty. Throws Unimplemented for a
// GPU / CUDA-pinned place when built without WITH_GPU, and for an XPU place
// when built without WITH_XPU. Places of any other kind are skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Collapse duplicates so each place gets exactly one context.
  std::set<Place> set(places.begin(), places.end());
  for (const auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    }
  }
}

CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
136
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
137 138 139 140 141 142 143
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
144
Place CPUDeviceContext::GetPlace() const { return place_; }
145

146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
#ifdef PADDLE_WITH_XPU
// Creates an xpu::Context without selecting a device first (presumably bound
// to whichever XPU device is current). Note: place_ is left default-constructed
// by this constructor.
XPUDeviceContext::XPUDeviceContext() { context_ = xpu::create_context(); }

// Destroys the xpu::Context owned by this object.
XPUDeviceContext::~XPUDeviceContext() { xpu::destroy_context(context_); }

// Creates the xpu::Context on `place.device`.
//
// The constructor records the XPU device that is current on entry and
// switches to `place.device`. It then creates the context and switches back
// to the recorded device. Each XPU runtime call is checked, and a failure is
// reported as an External error.
// NOTE(review): if the final xpu_set_device check fails, the error escapes the
// constructor after context_ was created; the destructor will not run, so the
// context is presumably leaked.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int dev_id = -1;
  // Remember the caller's current device so it can be restored below.
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  ret = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  context_ = xpu::create_context();
  // Restore the device that was current on entry.
  ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
}

// Makes place_.device the current XPU device, then calls xpu_wait()
// (presumably blocks until queued work on that device finishes). The previous
// current device is not restored.
void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  xpu_wait();
}

Place XPUDeviceContext::GetPlace() const { return place_; }

// Non-owning pointer to the xpu::Context owned by this object.
xpu::Context* XPUDeviceContext::x_context() const { return context_; }
#endif

#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
198
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
199 200 201 202 203 204 205 206 207 208 209 210
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
211 212 213
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
214 215 216
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
217
    void* retv = buf->ptr();
S
sneaxiy 已提交
218 219 220 221
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
222
    return retv;
Q
init  
qijun 已提交
223 224
  }

S
sneaxiy 已提交
225 226 227 228 229 230
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
231 232 233

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
234 235 236 237
// windows use an old version of eigen that uses kCudaScratchSize,
// once windows updates eigen to a recent version, the following code
// can use kGpuScratchSize uniformly
#ifdef _WIN32
Q
init  
qijun 已提交
238
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
Z
Zhang Ting 已提交
239 240 241
#else
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
#endif
Q
init  
qijun 已提交
242 243 244 245 246 247
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
248
#ifdef _WIN32
Q
init  
qijun 已提交
249 250
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
Z
Zhang Ting 已提交
251 252 253
#else
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
#endif
Q
init  
qijun 已提交
254
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
255
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
256 257 258 259 260 261
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
262
  CUDAPlace place_;
Q
init  
qijun 已提交
263 264
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
265
  mutable void* scratch_;
Q
init  
qijun 已提交
266
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
267
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
268
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
269 270
};

271 272 273 274 275 276 277 278 279
// Ensures the cuDNN workspace holds at least `required_workspace_bytes`.
// The workspace only ever grows: requests that already fit are no-ops.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool too_small = required_workspace_bytes > WorkspaceSize();
  if (too_small) {
    // Release the old buffer first so the old and new buffers are never
    // held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

// Thread-local map from CUDADeviceContext pointer to a CUDAContext, and the
// mutex declared alongside it. Both are thread_local, so each thread gets its
// own map and its own mutex.
// NOTE(review): a thread_local mutex can only serialize accesses made by the
// same thread; it does not synchronize across threads.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Builds this context's Eigen GPU device: an EigenCudaStreamDevice bound to
// this context's raw stream and place_, wrapped in an Eigen::GpuDevice.
// Takes the address of RawStream()'s result, so RawStream() presumably returns
// a reference to a stream stored in this object.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

// Creates the per-stream CUDA resources for `place`: a CUDAStream with
// `priority`, then the Eigen, cuBLAS, cuDNN and cuSOLVER contexts. All of this
// runs under a CUDADeviceGuard for place_.device (presumably making it the
// current device for the duration of the constructor).
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
  InitCuSolverContext();
}

// Releases the cuDNN, cuBLAS and cuSOLVER contexts under a CUDADeviceGuard for
// place_.device. (`Destory*` is the spelling of the member functions this
// calls; presumably declared that way in the header.)
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
  DestoryCuSolverContext();
}

// Creates the device context for GPU `place`. The constructor:
//   1. queries and caches the device properties (compute capability, SM
//      count, thread/grid limits, driver and runtime versions);
//   2. logs them (LOG_FIRST_N, so only once per call site);
//   3. warns if the installed driver is older than the CUDA version Paddle
//      was compiled with;
//   4. creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);

  driver_version_ = GetCUDADriverVersion(place_.device);
  runtime_version_ = GetCUDARuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";

  {
    // Check CUDA version compatibility: installed driver vs. the CUDA_VERSION
    // Paddle was compiled against. CUDA encodes versions as
    // 1000 * major + 10 * minor (e.g. 10020 for 10.2); both are reduced to
    // major * 10 + minor before comparing.
    auto local_cuda_version =
        (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
    auto compile_cuda_version =
        (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
    if (local_cuda_version < compile_cuda_version) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place_.device
          << ". The installed Paddle is compiled with CUDA "
          << compile_cuda_version / 10 << "." << compile_cuda_version % 10
          << ", but CUDA runtime version in your machine is "
          << local_cuda_version / 10 << "." << local_cuda_version % 10
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }
  }
  default_ctx_.reset(new CUDAContext(place_));
}

// Switches to this context's device and, in NCCL builds, destroys the NCCL
// communicator if one was set.
// NOTE(review): PADDLE_ENFORCE_CUDA_SUCCESS presumably throws on failure.
// Destructors are implicitly noexcept, so a failing ncclCommDestroy here would
// call std::terminate.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL)
  if (nccl_comm_) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

Place CUDADeviceContext::GetPlace() const { return place_; }
365

366
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
367

K
Kexin Zhao 已提交
368
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
369
  return compute_capability_;
K
Kexin Zhao 已提交
370 371
}

372
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
373
  return multi_process_ * max_threads_per_mp_;
374 375
}

376 377 378 379 380 381
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

382
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
383
  return context()->EigenDevice().get();
384 385
}

386
bool CUDADeviceContext::tensor_core_available() const {
387
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
388 389
}

390 391 392 393
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

394 395 396
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return context()->CudnnHandle();
}
397

S
sneaxiy 已提交
398
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
399
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
400
}
401

G
Guo Sheng 已提交
402 403 404 405
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}

406 407 408
cudaStream_t CUDADeviceContext::stream() const {
  return context()->RawStream();
}
Q
qijun 已提交
409

C
chengduoZH 已提交
410 411 412 413 414 415 416 417 418 419 420 421 422 423
// Default pinned-memory context. Delegates to the place-taking constructor
// with a default-constructed CUDAPinnedPlace.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : CUDAPinnedDeviceContext(CUDAPinnedPlace()) {}

// Host-side context for CUDA pinned memory; Eigen expressions run on an
// Eigen::DefaultDevice owned by this context.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning pointer to the Eigen device; valid for this context's lifetime.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif
#ifdef PADDLE_WITH_MKLDNN
// CPU context that additionally owns a oneDNN CPU engine (index 0), the blob
// cache used by SetBlob/GetBlob, and the mutex guarding that cache.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::kind::cpu, 0),
      p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  p_mutex_.reset(new std::mutex());
}

MKLDNNDeviceContextThreadLocals::Body::Body() {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
452 453
  cur_input_shape_str = input_shape_str;
}
454 455
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
456 457
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
458

459 460
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
461 462 463
  cur_paddle_data_layout = dl;
}

464 465
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
466 467 468
  return cur_paddle_data_layout;
}

469 470 471 472 473 474 475 476 477
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

478 479 480 481 482 483 484 485 486 487 488 489 490 491 492
// Clears the entire blob cache. If BlockNextCacheClearing() was called since
// the last reset, this call is skipped instead and the block is consumed.
void MKLDNNDeviceContext::ResetBlobMap() {
  std::lock_guard<std::mutex> lock(*p_mutex_);
  if (!block_next_cache_clearing_) {
    VLOG(3) << "Clearing DNNL cache.";
    p_blobmap_->clear();
  } else {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
  }
}

// Makes the next ResetBlobMap() call a no-op.
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<std::mutex> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
496
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
497
  BlobMap* pMap = p_blobmap_.get();
498
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
499
  if (map_it == pMap->end()) {
500 501 502
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
503 504 505 506
  }
  return map_it->second->size();
}

507
void MKLDNNDeviceContext::SetBlob(const std::string& name,
508
                                  BlobPtr_t<void> data) const {
509
  BlobMap* pMap = p_blobmap_.get();
510 511
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
512

513
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
514

515
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
516

517 518
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
519 520 521

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
522
    sBlob = std::make_shared<ShapeBlob>();
523 524
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
525
  } else {
526
    sBlob = map_it->second;
527
  }
T
tensor-tang 已提交
528

529
  // Find KeyBlob for current input shape
530
  auto key_it = sBlob->find(tls().cur_input_shape_str);
531

532
  if (key_it == sBlob->end()) {
533 534
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
535 536
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
537
        sBlob->size() &&
538
        (sBlob->size() >=
539
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
540 541 542 543
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
544 545
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
546
  } else {
547
    pBlob = key_it->second;
548 549
  }

550 551 552 553 554 555 556
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
557
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
558
  // lock will be automatically released when out of scope
559
  return;
T
tensor-tang 已提交
560 561
}

562
// Looks up blob `name` for the current session id and input-shape key (both
// read from tls()). Returns nullptr on a miss at any level; never inserts.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<std::mutex> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Label the session id and the missing shape key separately.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif

}  // namespace platform
}  // namespace paddle