device_context.cc 20.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14
#include <string>
Y
Yu Yang 已提交
15
#include <unordered_set>
16 17
#include <vector>

Y
Yi Wang 已提交
18
#include "paddle/fluid/memory/memory.h"
19 20
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
21
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
22
#include "paddle/fluid/platform/cuda_device_guard.h"
23
#endif
24

25 26
#include "glog/logging.h"

27 28 29 30 31
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// For a CUDA context whose stream differs from the stream of the pool's
// default context for the same place, the allocation is routed through
// CUDADeviceContextAllocatorPool (keyed by that context); every other case
// (zero-byte requests, the default stream, XPU, CPU, ...) uses the plain
// place-based Alloc.
//
// Throws PermissionDenied when a GPU/XPU place is requested but Paddle was
// built without the corresponding support.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
  if (size == 0) {
    return Alloc(place, size);
  }

  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    auto* default_dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto& desired_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    // Only contexts on a non-default stream need the stream-bound allocator.
    if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
      return Alloc(place, size);
    } else {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          desired_dev_ctx, size);
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use CUDA device since it's not compiled with CUDA. "
        "Please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // TODO(liuyuhui): Consider xpu stream later
    return Alloc(place, size);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't use XPU device since it's not compiled with XPU. "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else {
    return Alloc(place, size);
  }
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
70 71 72
namespace paddle {
namespace platform {

D
dzhwinter 已提交
73 74
// Process-wide pool instance; starts out null until it is assigned elsewhere.
DeviceContextPool* DeviceContextPool::pool{nullptr};

Y
Yu Yang 已提交
75
// Returns the (non-owning) device context registered for `place`.
// The context is stored behind a deferred shared_future, so it is created on
// the first lookup of that place. Throws Unimplemented for unknown places.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  auto entry = device_contexts_.find(place);
  const bool registered = entry != device_contexts_.end();
  if (!registered) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with WITH_GPU or WITH_XPU option or check that your train process "
        "hold the correct gpu_id if you use Executor.",
        place));
  }
  const auto& owned_ctx = entry->second.get();
  return owned_ctx.get();
}

88 89 90 91 92 93 94 95 96
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
97
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
98
                   }));
C
chengduozh 已提交
99 100
}

D
dzhwinter 已提交
101 102
// Builds the pool with one lazily-created device context per distinct place
// in `places`. Duplicate places are collapsed.
//
// Throws InvalidArgument if `places` is empty, and Unimplemented when a
// GPU / CUDA-pinned / XPU place is requested in a build without that
// backend. Places of any other kind are skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Deduplicate (and order) the requested places.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("CUDAPlace is not supported. Please "
                                          "re-compile with WITH_GPU option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "CUDAPinnedPlace is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
      EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                          "re-compile with WITH_XPU option."));
#endif
    }
  }
}

148 149 150 151
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
152
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
153 154 155 156 157 158 159
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
160
Place CPUDeviceContext::GetPlace() const { return place_; }
161

162 163 164
#ifdef PADDLE_WITH_XPU
// Default XPU context: creates an xpu::Context without selecting a device.
XPUDeviceContext::XPUDeviceContext() : context_(xpu::create_context()) {}

165 166 167 168 169 170 171 172 173 174
// Releases the xpu::Context created by the constructor.
//
// The previous implementation allocated new L3 memory here and attached it
// to `context_` *after* xpu::destroy_context(context_). That dereferenced the
// destroyed context and leaked the freshly allocated L3 buffer; the
// allocation belongs in the constructor only.
// NOTE(review): the L3 buffer attached by the XPUPlace constructor is not
// tracked in a member, so it is still not released here — confirm whether
// destroy_context frees it or a member is needed to xpu_free it.
XPUDeviceContext::~XPUDeviceContext() { xpu::destroy_context(context_); }
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190

// Creates an xpu::Context on `place.device`.
//
// Steps:
// 1. Temporarily switches the current XPU device to `place.device`.
// 2. Creates the context and tries to attach a 13.5 MiB L3 buffer to it.
// 3. Switches back to the device that was current on entry.
//
// Every XPU runtime call that returns a status is checked and raises
// External on failure. The L3 allocation is best-effort: if it yields null,
// the context is simply used without L3.
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
  int prev_dev_id = -1;
  int status = xpu_current_device(&prev_dev_id);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));

  status = xpu_set_device(place.device);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));

  context_ = xpu::create_context();

  const int l3_size = static_cast<int>(13.5 * 1024 * 1024);
  void* l3_buffer = nullptr;
  xpu_malloc(static_cast<void**>(&l3_buffer), l3_size, XPU_MEM_L3);
  if (l3_buffer != nullptr) {
    context_->_l3_mgr.set(l3_buffer, l3_size);
    std::cout << "set l3 size " << l3_size << std::endl;
  }

  // Restore the device that was current before this constructor ran.
  status = xpu_set_device(prev_dev_id);
  PADDLE_ENFORCE_EQ(status, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        status));
}

void XPUDeviceContext::Wait() const {
  int ret = xpu_set_device(place_.device);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
213
  xpu_wait(context_->xpu_stream);
214 215 216 217 218 219 220
}

Place XPUDeviceContext::GetPlace() const {
  return place_;
}

// Non-owning access to the underlying xpu::Context.
xpu::Context* XPUDeviceContext::x_context() const {
  return context_;
}
#endif

221
#ifdef PADDLE_WITH_CUDA
222

Q
init  
qijun 已提交
223 224 225 226 227 228 229
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
230
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
231 232 233 234 235 236 237 238 239 240 241 242
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
243 244 245
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
246 247 248
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
249
    void* retv = buf->ptr();
S
sneaxiy 已提交
250 251 252 253
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
254
    return retv;
Q
init  
qijun 已提交
255 256
  }

S
sneaxiy 已提交
257 258 259 260 261 262
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
263 264 265

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
266 267 268 269
// windows use an old version of eigen that uses kCudaScratchSize,
// once windows updates eigen to a recent version, the following code
// can use kGpuScratchSize uniformly
#ifdef _WIN32
Q
init  
qijun 已提交
270
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
Z
Zhang Ting 已提交
271 272 273
#else
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
#endif
Q
init  
qijun 已提交
274 275 276 277 278 279
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
280
#ifdef _WIN32
Q
init  
qijun 已提交
281 282
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
Z
Zhang Ting 已提交
283 284 285
#else
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
#endif
Q
init  
qijun 已提交
286
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
287
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
288 289 290 291 292 293
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
294
  CUDAPlace place_;
Q
init  
qijun 已提交
295 296
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
297
  mutable void* scratch_;
Q
init  
qijun 已提交
298
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
299
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
300
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
301 302
};

303 304 305 306 307 308 309 310 311
// Grows the cuDNN workspace to at least `required_workspace_bytes`.
// The workspace never shrinks: requests that fit the current size are no-ops.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Free the old buffer first so both are never held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

312 313 314 315 316 317 318 319 320 321 322 323
// Per-thread map keyed by CUDADeviceContext, holding a shared CUDAContext for
// each (presumably the per-thread override consulted by context() — verify
// against the header), and a per-thread mutex declared alongside it.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_{};
thread_local std::mutex CUDADeviceContext::ctx_mtx_{};

// Creates the Eigen stream adapter bound to this context's raw stream and
// place, and an Eigen::GpuDevice on top of it.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  auto* stream_adapter = eigen_stream_.get();
  stream_adapter->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(stream_adapter));
}

// Creates a CUDA stream with `priority` on `place` and initializes the Eigen,
// cuBLAS, cuDNN and cuSOLVER contexts, in that order, with the device of
// `place` held current by a CUDADeviceGuard.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority)
    : place_(place) {
  CUDADeviceGuard device_guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
  InitCuSolverContext();
}

// Tears down the cuDNN, cuBLAS and cuSOLVER contexts (in that order) while a
// CUDADeviceGuard keeps this context's device selected.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard device_guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
  DestoryCuSolverContext();
}

341
// Builds the device context for `place`:
//  - caches the device's hardware limits and the CUDA driver/runtime versions,
//  - logs them once per process (LOG_FIRST_N),
//  - warns when the local driver is older than the CUDA version Paddle was
//    compiled against,
//  - creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  const auto device_id = place_.device;

  compute_capability_ = GetCUDAComputeCapability(device_id);
  multi_process_ = GetCUDAMultiProcessors(device_id);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(device_id);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(device_id);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(device_id);

  driver_version_ = GetCUDADriverVersion(device_id);
  runtime_version_ = GetCUDARuntimeVersion(device_id);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << device_id
                          << ", GPU Compute Capability: "
                          << compute_capability_ / 10 << "."
                          << compute_capability_ % 10
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;

  const size_t cudnn_version = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << device_id
                          << ", cuDNN Version: " << cudnn_version / 1000 << "."
                          << (cudnn_version % 1000) / 100 << ".";

  // Check CUDA/CUDNN version compatiblity.
  // Folds a CUDA version number into (10 * major + minor), using the same
  // arithmetic as the log line above, so two versions compare directly.
  auto to_major_minor = [](int version) {
    return (version / 1000) * 10 + (version % 100) / 10;
  };
  const int local_cuda_version = to_major_minor(driver_version_);
  const int compile_cuda_version = to_major_minor(CUDA_VERSION);
  if (local_cuda_version < compile_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << place_.device
        << ". The installed Paddle is compiled with CUDA "
        << compile_cuda_version / 10 << "." << compile_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << local_cuda_version / 10 << "." << local_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Selects this context's device and, in NCCL builds, destroys the NCCL
// communicator if one was set.
// NOTE(review): PADDLE_ENFORCE_CUDA_SUCCESS can throw, and a throwing
// destructor terminates the program — confirm this is intended.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#ifdef PADDLE_WITH_NCCL
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
396
Place CUDADeviceContext::GetPlace() const { return place_; }
397

398
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
399

K
Kexin Zhao 已提交
400
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
401
  return compute_capability_;
K
Kexin Zhao 已提交
402 403
}

404
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
405
  return multi_process_ * max_threads_per_mp_;
406 407
}

408 409 410 411 412 413
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

414
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
415
  return context()->EigenDevice().get();
416 417
}

418
bool CUDADeviceContext::tensor_core_available() const {
419
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
420 421
}

422 423 424 425
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

426 427 428
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return context()->CudnnHandle();
}
429

S
sneaxiy 已提交
430
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
431
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
432
}
433

G
Guo Sheng 已提交
434 435 436 437
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}

438 439 440
cudaStream_t CUDADeviceContext::stream() const {
  return context()->RawStream();
}
Q
qijun 已提交
441

C
chengduoZH 已提交
442 443 444 445 446 447 448 449 450 451 452 453 454 455
// Pinned-host context: computes on the host through an Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : CUDAPinnedDeviceContext() {
  place_ = place;
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  auto* device = eigen_device_.get();
  return device;
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
456
#endif
Q
qijun 已提交
457

T
tensor-tang 已提交
458 459
#ifdef PADDLE_WITH_MKLDNN
// CPU context backed by a oneDNN CPU engine (index 0).
// It owns an empty blob cache and the mutex that guards that cache.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::kind::cpu, 0) {
  p_blobmap_.reset(new BlobMap());
  p_mutex_.reset(new std::mutex());
}

467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
MKLDNNDeviceContextThreadLocals::Body::Body() {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
484 485
  cur_input_shape_str = input_shape_str;
}
486 487
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
488 489
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
490

491 492
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
493 494 495
  cur_paddle_data_layout = dl;
}

496 497
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
498 499 500
  return cur_paddle_data_layout;
}

501 502 503 504 505 506 507 508 509
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
  if (!said_once) {
    said_once = true;
    auto dv = dnnl::version();
    LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
              << dv->patch;
  }
}

510 511 512 513 514 515 516 517 518 519 520 521 522 523 524
// Clears the whole blob cache, unless BlockNextCacheClearing() was called
// since the last reset; in that case this call only consumes the block.
void MKLDNNDeviceContext::ResetBlobMap() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  if (block_next_cache_clearing_) {
    VLOG(3) << "Prevented Clearing DNNL cache.";
    block_next_cache_clearing_ = false;
    return;
  }
  VLOG(3) << "Clearing DNNL cache.";
  p_blobmap_->clear();
}

// Makes the next ResetBlobMap() call a no-op (one-shot).
void MKLDNNDeviceContext::BlockNextCacheClearing() {
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
  VLOG(3) << "Next DNNL cache clearing has been blocked.";
  block_next_cache_clearing_ = true;
}
526

527
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
528
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
529
  BlobMap* pMap = p_blobmap_.get();
530
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
531
  if (map_it == pMap->end()) {
532 533 534
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
535 536 537 538
  }
  return map_it->second->size();
}

539
void MKLDNNDeviceContext::SetBlob(const std::string& name,
540
                                  BlobPtr_t<void> data) const {
541
  BlobMap* pMap = p_blobmap_.get();
542 543
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
544

545
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
546

547
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
548

549 550
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
551 552 553

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
554
    sBlob = std::make_shared<ShapeBlob>();
555 556
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
557
  } else {
558
    sBlob = map_it->second;
559
  }
T
tensor-tang 已提交
560

561
  // Find KeyBlob for current input shape
562
  auto key_it = sBlob->find(tls().cur_input_shape_str);
563

564
  if (key_it == sBlob->end()) {
565 566
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
567 568
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
569
        sBlob->size() &&
570
        (sBlob->size() >=
571
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
572 573 574 575
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
576 577
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
578
  } else {
579
    pBlob = key_it->second;
580 581
  }

582 583 584 585 586 587 588
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
589
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
590
  // lock will be automatically released when out of scope
591
  return;
T
tensor-tang 已提交
592 593
}

594
// Looks up the blob stored under `name` for the calling thread's current
// session id and input-shape string.
// Returns nullptr on a miss at any level (session, shape, or name).
// Lookups happen under *p_mutex_.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Previously this printed the shape string under the "sid=" label.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  return key_it->second;
}

#endif

Q
qijun 已提交
636
}  // namespace platform
Q
qijun 已提交
637
}  // namespace paddle