/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"

#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#include "glog/logging.h"

namespace paddle {
namespace memory {

// Allocates `size` bytes on behalf of `dev_ctx`.
//
// CUDA builds: a non-empty request on a GPU place whose context does not share
// the stream of the pooled (default) context for that place is served by
// CUDADeviceContextAllocatorPool, keyed on the requesting context. Every other
// request (zero bytes, non-GPU place, default stream, or any non-CUDA build)
// goes to the place-based Alloc(place, size).
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
  const bool maybe_custom_stream = size != 0 && platform::is_gpu_place(place);
  if (maybe_custom_stream) {
    auto* pooled_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    const auto& requested_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (pooled_ctx->stream() != requested_ctx.stream()) {
      // Memory for a non-default stream comes from an allocator bound to
      // that specific context.
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          requested_ctx, size);
    }
  }
#endif
  return Alloc(place, size);
}

}  // namespace memory
}  // namespace paddle

namespace paddle {
namespace platform {

D
dzhwinter 已提交
57 58
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
59
// Returns the device context registered for `place`.
//
// Entries hold a deferred shared_future, so the first `.get()` for a place
// runs its factory and constructs the context. Throws via PADDLE_THROW when
// `place` was never registered with the pool.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  auto entry = device_contexts_.find(place);
  if (entry == device_contexts_.end()) {
    PADDLE_THROW(
        "Place %s is not supported, Please check that your paddle compiles "
        "with WITH_GPU "
        "option or check that your train process hold the correct gpu_id if "
        "you use Executor",
        place);
  }
  const auto& ctx_holder = entry->second.get();
  return ctx_holder.get();
}

// Registers a lazily constructed `DevCtx` for place `p` in `*map_ptr`.
//
// The factory is wrapped in a std::launch::deferred future, so nothing is
// constructed here. The context is built on the first `.get()` of the stored
// shared_future. `PlaceType` is the concrete place alternative extracted from
// `p` with boost::get.
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  // `p` is captured by copy: the factory runs after this function returns.
  auto make_context = [p]() -> std::unique_ptr<DeviceContext> {
    return std::unique_ptr<DeviceContext>(
        new DevCtx(boost::get<PlaceType>(p)));
  };
  map_ptr->emplace(p, std::async(std::launch::deferred, make_context));
}

// Creates the pool's device contexts, one per distinct place in `places`.
//
// Each context is registered through EmplaceDeviceContext, which defers
// construction until the first DeviceContextPool::Get for that place.
// - CPU places get an MKLDNNDeviceContext in MKLDNN builds and a
//   CPUDeviceContext otherwise.
// - GPU and CUDA-pinned places get CUDA contexts in CUDA builds and
//   PADDLE_THROW otherwise.
// - Places matching none of the branches below are skipped without error.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  // Collapse duplicate places so each one maps to exactly one context.
  std::set<Place> unique_places(places.begin(), places.end());
  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // This branch handles CUDAPinnedPlace, so name that place kind in the
      // error rather than CUDAPlace.
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}

// Default CPU context: default-constructed CPUPlace plus an owned
// Eigen::DefaultDevice.
CPUDeviceContext::CPUDeviceContext()
    : eigen_device_(new Eigen::DefaultDevice()) {}

// CPU context bound to `place`; owns its own Eigen::DefaultDevice.
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  eigen_device_ =
      std::unique_ptr<Eigen::DefaultDevice>(new Eigen::DefaultDevice());
}

// Non-owning pointer to the Eigen device; the context keeps ownership.
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return this->eigen_device_.get();
}

// Returns the stored CPUPlace (default-constructed when the no-argument
// constructor was used).
Place CPUDeviceContext::GetPlace() const {
  return this->place_;
}

#ifdef PADDLE_WITH_CUDA
135

Q
init  
qijun 已提交
136 137 138 139 140 141 142
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
143
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
144 145 146 147 148 149 150 151 152 153 154 155
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
156 157 158
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
159 160 161
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
162
    void* retv = buf->ptr();
S
sneaxiy 已提交
163 164 165 166
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
167
    return retv;
Q
init  
qijun 已提交
168 169
  }

S
sneaxiy 已提交
170 171 172 173 174 175
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
189
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
190 191 192 193 194 195
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
196
  CUDAPlace place_;
Q
init  
qijun 已提交
197 198
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
199
  mutable void* scratch_;
Q
init  
qijun 已提交
200
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
201
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
202
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
203 204
};

205 206 207 208 209 210 211 212 213
// Grows the cuDNN workspace to at least `required_workspace_bytes`. A request
// that already fits in the current workspace leaves it untouched.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Release the old buffer first so old and new never coexist, keeping
    // peak memory lower.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

// Builds the CUDA device context for `place`. In order, it:
// - caches device limits (compute capability, multiprocessor count, threads
//   per multiprocessor, max grid size);
// - creates the context's own CUDA stream;
// - binds an Eigen GpuDevice and cuBLAS handle(s) to that stream (the
//   tensor-core handle only when TensorCoreAvailable() and
//   CUDA_VERSION >= 9000);
// - logs driver/runtime/cuDNN versions once per process;
// - warns when the local driver or cuDNN is older than the compile-time
//   versions;
// - creates a cuDNN handle on the stream when cuDNN is loadable (otherwise
//   cudnn_handle_ is set to nullptr);
// - creates the stream callback manager.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  // NOTE(review): presumably makes place_.device current for the rest of the
  // constructor -- see cuda_device_guard.h.
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
  cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));

  if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
    cublas_tensor_core_handle_.reset(
        new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
  }

  driver_version_ = GetCUDADriverVersion(place_.device);
  runtime_version_ = GetCUDARuntimeVersion(place_.device);

  // Versions are printed as (v / 1000) "." ((v % 100) / 10).
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", CUDA Capability: " << compute_capability_
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";

  {
    // Check CUDA/CUDNN version compatibility. CUDA versions are reduced to
    // (v / 1000) * 10 + (v % 100) / 10; cuDNN versions to v / 100.
    auto local_cuda_version =
        (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
    auto compile_cuda_version =
        (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
    if (local_cuda_version < compile_cuda_version) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place_.device
          << ". The installed Paddle is compiled with CUDA "
          << compile_cuda_version / 10 << "." << compile_cuda_version % 10
          << ", but CUDA runtime version in your machine is "
          << local_cuda_version / 10 << "." << local_cuda_version % 10
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }

    if (dynload::HasCUDNN()) {
      auto local_cudnn_version = cudnn_dso_ver / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnCreate(&cudnn_handle_),
          "Failed to create Cudnn handle in DeviceContext");
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnSetStream(cudnn_handle_, stream_),
          "Failed to set stream for Cudnn handle in DeviceContext");
    } else {
      cudnn_handle_ = nullptr;
    }
  }

  callback_manager_.reset(new StreamCallbackManager(stream_));
}

// Tears the context down on its own device, in this order:
// 1. Drain the stream and pending stream callbacks.
// 2. Release resources bound to the stream (cuBLAS, Eigen, cuDNN).
// 3. Destroy the stream itself.
// 4. On non-Windows builds, destroy the NCCL communicator if one is set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();
  WaitStreamCallback();
  cublas_handle_.reset();
  cublas_tensor_core_handle_.reset();
  // eigen_device_ was constructed from eigen_stream_.get(); drop it first so
  // it never outlives the stream interface it points to.
  eigen_device_.reset();
  eigen_stream_.reset();
  // The cuDNN handle was bound to stream_ with cudnnSetStream; destroy it
  // while that stream still exists.
  if (cudnn_handle_) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_),
                                "Failed to destory Cudnn handle");
  }
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
#if !defined(_WIN32)
  if (nccl_comm_) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

// Returns the CUDAPlace this context was created for.
Place CUDADeviceContext::GetPlace() const {
  return this->place_;
}

void CUDADeviceContext::Wait() const {
317 318 319 320 321 322 323 324 325 326 327
  cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
  e_sync = cudaStreamSynchronize(stream_);
#else
  while (e_sync = cudaStreamQuery(stream_)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif

  if (cudaSuccess != e_sync) {
328 329 330
    LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
               << " errno: " << e_sync;
  }
331

332
  cudaError_t e_get = cudaGetLastError();
333
  if (cudaSuccess != e_get) {
334 335 336
    LOG(FATAL) << "cudaGetLastError  " << cudaGetErrorString(e_get)
               << " errno: " << e_get;
  }
337 338
}

K
Kexin Zhao 已提交
339
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
340
  return compute_capability_;
K
Kexin Zhao 已提交
341 342
}

343
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
344
  return multi_process_ * max_threads_per_mp_;
345 346
}

347 348 349 350
// Non-owning pointer to the Eigen GPU device built on this context's stream.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return this->eigen_device_.get();
}

// True when the constructor created a tensor-core cuBLAS handle.
bool CUDADeviceContext::tensor_core_available() const {
  return static_cast<bool>(cublas_tensor_core_handle_);
}

// Maximum grid dimensions cached by the constructor.
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return this->max_grid_dim_size_;
}

// cuDNN handle bound to stream_, or nullptr when cuDNN was unavailable when
// the context was constructed.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return this->cudnn_handle_;
}

// Builds a workspace handle for this context, handing it cudnn_handle_mtx_.
// NOTE(review): the mutex is presumably used to serialize workspace use across
// handles from the same context -- confirm in the header.
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  auto* workspace_mtx = &cudnn_handle_mtx_;
  return CudnnWorkspaceHandle(*this, workspace_mtx);
}

// The CUDA stream created in the constructor and destroyed in the destructor.
cudaStream_t CUDADeviceContext::stream() const {
  return this->stream_;
}

// Default pinned-memory context: default place plus an owned
// Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : eigen_device_(new Eigen::DefaultDevice()) {}

// Pinned-memory context bound to `place`; owns its own Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_ =
      std::unique_ptr<Eigen::DefaultDevice>(new Eigen::DefaultDevice());
}

// Non-owning pointer to the Eigen device; the context keeps ownership.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return this->eigen_device_.get();
}

// Returns the stored CUDAPinnedPlace.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return this->place_;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
// CPU context that also owns:
// - an MKL-DNN CPU engine (index 0);
// - a blob cache map (p_blobmap_);
// - the mutex that guards the blob cache (p_mutex_).
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::cpu, 0),
      p_blobmap_(new BlobMap()) {
  p_mutex_.reset(new std::mutex());
}

namespace {
391 392
// Current mkldnn session id.
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
393 394 395 396
// Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
thread_local std::string cur_input_shape_str = "";
397 398 399
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
thread_local int cur_input_shape_cache_capacity = 1;
400 401 402 403
// Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN
thread_local paddle::framework::DataLayout cur_paddle_data_layout =
    paddle::framework::DataLayout::kNCHW;
404
}  // namespace
S
Sylwester Fraczek 已提交
405

406 407
// Sets the calling thread's current MKL-DNN session id.
void set_cur_mkldnn_session_id(size_t sid) {
  cur_mkldnn_session_id = sid;
}
// Returns the calling thread's current MKL-DNN session id.
size_t get_cur_mkldnn_session_id() {
  return cur_mkldnn_session_id;
}
// Sets the calling thread's current input-shape string. This string is the
// key used for the per-shape level of the MKL-DNN blob cache.
// The parameter is taken by value, so it is moved (not copied) into place.
void set_cur_input_shape_str(std::string input_shape_str) {
  cur_input_shape_str = std::move(input_shape_str);
}
// Sets the calling thread's input-shape cache capacity. SetBlob consults it in
// cache-clearing mode.
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

// Records the calling thread's most recent Paddle data layout.
void set_cur_paddle_data_layout(framework::DataLayout dl) {
  cur_paddle_data_layout = dl;
}

// Returns the calling thread's most recently recorded data layout. The
// default is kNCHW.
framework::DataLayout get_cur_paddle_data_layout() {
  return cur_paddle_data_layout;
}

void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }

425 426 427 428 429 430 431 432 433 434 435
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  std::lock_guard<std::mutex> lock(*p_mutex_);
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
               << cur_mkldnn_session_id;
  }
  return map_it->second->size();
}

436 437
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
438
  BlobMap* pMap = p_blobmap_.get();
439
  std::shared_ptr<ShapeBlob> sBlob = nullptr;
440 441
  std::shared_ptr<KeyBlob> pBlob = nullptr;

442
  int sid = platform::get_cur_mkldnn_session_id();
T
tensor-tang 已提交
443

444
  std::lock_guard<std::mutex> lock(*p_mutex_);
T
tensor-tang 已提交
445

446 447
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
448 449 450

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
451
    sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
452 453
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
454
  } else {
455
    sBlob = map_it->second;
456
  }
T
tensor-tang 已提交
457

458 459
  // Find KeyBlob for current input shape
  auto key_it = sBlob->find(cur_input_shape_str);
460

461
  if (key_it == sBlob->end()) {
462 463
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
464 465
    if ((static_cast<size_t>(sid) == kMKLDNNSessionID_CacheClearing) &&
        sBlob->size() &&
466 467 468 469 470 471
        (sBlob->size() >=
         static_cast<size_t>(cur_input_shape_cache_capacity))) {
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
472 473
    pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
    (*sBlob)[cur_input_shape_str] = pBlob;
474
  } else {
475
    pBlob = key_it->second;
476 477
  }

478 479 480 481 482 483 484
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
485
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
486
  // lock will be automatically released when out of scope
487
  return;
T
tensor-tang 已提交
488 489
}

490 491
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
492
  BlobMap* pMap = p_blobmap_.get();
493
  std::shared_ptr<ShapeBlob> sBlob = nullptr;
494
  std::shared_ptr<KeyBlob> pBlob = nullptr;
T
tensor-tang 已提交
495

496
  int sid = platform::get_cur_mkldnn_session_id();
T
tensor-tang 已提交
497

498
  std::lock_guard<std::mutex> lock(*p_mutex_);
499

500 501
  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
502
  if (map_it == pMap->end()) {
503
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
504 505 506 507 508 509 510
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
511
    VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
512 513 514 515
            << ", miss input_shape_str\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;
516 517 518 519

  // Find Blob via name
  auto key_it = pBlob->find(name);

520
  if (key_it == pBlob->end()) {
521
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
522 523
    return nullptr;
  }
524

525
  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
526 527
  // lock will be automatically released when out of scope
  return key_it->second;
T
tensor-tang 已提交
528 529 530 531
}

#endif

}  // namespace platform
}  // namespace paddle