device_context.cc 16.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#include "glog/logging.h"

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place of `dev_ctx`.
//
// Only one case is routed away from the plain place-based allocator: a
// non-empty request on a GPU place whose CUDA context uses a stream different
// from the stream of the context that DeviceContextPool holds for that place.
// Such requests are served by CUDADeviceContextAllocatorPool for `dev_ctx`.
// Everything else (size 0, non-GPU places, the pool's own stream, or builds
// without CUDA) goes through Alloc(place, size).
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
  const bool needs_stream_check = size != 0 && platform::is_gpu_place(place);
  if (needs_stream_check) {
    auto* pooled_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    const auto& cuda_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    if (pooled_ctx->stream() != cuda_ctx.stream()) {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          cuda_ctx, size);
    }
  }
#endif
  return Alloc(place, size);
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
54 55 56
namespace paddle {
namespace platform {

D
dzhwinter 已提交
57 58
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
59
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
60 61 62
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
63 64 65 66
        "Place %s is not supported, Please check that your paddle compiles "
        "with WITH_GPU "
        "option or check that your train process hold the correct gpu_id if "
        "you use Executor",
M
minqiyang 已提交
67
        place);
D
dzhwinter 已提交
68
  }
69
  return it->second.get().get();
D
dzhwinter 已提交
70 71
}

72 73 74 75 76 77 78 79 80
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
81
                     return PtrType(new DevCtx(BOOST_GET_CONST(PlaceType, p)));
82
                   }));
C
chengduozh 已提交
83 84
}

D
dzhwinter 已提交
85 86 87
// Registers one lazily-constructed device context for every distinct place in
// `places` (duplicates are collapsed first).
//
// CPU places get an MKLDNNDeviceContext in MKLDNN builds and a
// CPUDeviceContext otherwise. CUDA and CUDA-pinned places require a WITH_GPU
// build; without it an Unimplemented error is thrown. Places of any other
// kind are skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Deduplicate before creating contexts.
  std::set<Place> set(places.begin(), places.end());
  for (auto& p : set) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option"));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Name the place kind that was actually requested.
      PADDLE_THROW(platform::errors::Unimplemented(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option"));
#endif
    }
  }
}

120 121 122 123
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
124
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
125 126 127 128 129 130 131
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
132
Place CPUDeviceContext::GetPlace() const { return place_; }
133

134
#ifdef PADDLE_WITH_CUDA
135

Q
init  
qijun 已提交
136 137 138 139 140 141 142
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
143
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
144 145 146 147 148 149 150 151 152 153 154 155
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
156 157 158
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
159 160 161
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
162
    void* retv = buf->ptr();
S
sneaxiy 已提交
163 164 165 166
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
167
    return retv;
Q
init  
qijun 已提交
168 169
  }

S
sneaxiy 已提交
170 171 172 173 174 175
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
176 177 178

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
179 180 181 182
// windows use an old version of eigen that uses kCudaScratchSize,
// once windows updates eigen to a recent version, the following code
// can use kGpuScratchSize uniformly
#ifdef _WIN32
Q
init  
qijun 已提交
183
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
Z
Zhang Ting 已提交
184 185 186
#else
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
#endif
Q
init  
qijun 已提交
187 188 189 190 191 192
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
193
#ifdef _WIN32
Q
init  
qijun 已提交
194 195
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
Z
Zhang Ting 已提交
196 197 198
#else
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
#endif
Q
init  
qijun 已提交
199
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
200
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
201 202 203 204 205 206
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
207
  CUDAPlace place_;
Q
init  
qijun 已提交
208 209
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
210
  mutable void* scratch_;
Q
init  
qijun 已提交
211
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
212
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
213
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
214 215
};

216 217 218 219 220 221 222 223 224
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  // The workspace only grows; a buffer that is already big enough is kept.
  if (required_workspace_bytes > WorkspaceSize()) {
    // Drop the old buffer before requesting the larger one so both are never
    // held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

225 226 227 228 229 230 231 232 233 234 235 236
// thread_local storage for CUDADeviceContext: a map from device context to a
// shared CUDAContext, plus a mutex. (Presumably lets each thread bind its own
// CUDAContext to a device context — confirm against the header.)
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Wraps this context's raw CUDA stream in an EigenCudaStreamDevice and builds
// the Eigen::GpuDevice on top of it.
void CUDAContext::InitEigenContext() {
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

// Creates a stream with `priority` on `place` and then the Eigen, cuBLAS,
// cuDNN and cuSOLVER handles, all under a CUDADeviceGuard for `place`.
// The Eigen device must come after the stream because it reads RawStream().
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  CUDADeviceGuard guard(place_.device);
  stream_.reset(new stream::CUDAStream(place, priority));
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
  InitCuSolverContext();
}

// Destroys the library handles under a CUDADeviceGuard for place_.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard guard(place_.device);
  DestoryCuDNNContext();
  DestoryCuBlasContext();
  DestoryCuSolverContext();
}

254
// CUDA reports driver/runtime/toolkit versions as 1000 * major + 10 * minor
// (e.g. 10020 is 10.2). Taking the minor from `% 1000` keeps two-digit
// minors intact; for minors below 10 it matches the older `% 100` form.
static int CudaVersionMajor(int version) { return version / 1000; }
static int CudaVersionMinor(int version) { return (version % 1000) / 10; }

// Queries and caches the device properties of `place`, logs the CUDA and
// cuDNN versions once, warns when the installed CUDA driver is older than the
// CUDA version Paddle was compiled with, and creates the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);

  driver_version_ = GetCUDADriverVersion(place_.device);
  runtime_version_ = GetCUDARuntimeVersion(place_.device);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", CUDA Capability: " << compute_capability_
                          << ", Driver API Version: "
                          << CudaVersionMajor(driver_version_) << "."
                          << CudaVersionMinor(driver_version_)
                          << ", Runtime API Version: "
                          << CudaVersionMajor(runtime_version_) << "."
                          << CudaVersionMinor(runtime_version_);
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";

  {
    // Check CUDA/CUDNN version compatibility
    const int local_major = CudaVersionMajor(driver_version_);
    const int local_minor = CudaVersionMinor(driver_version_);
    const int compile_major = CudaVersionMajor(CUDA_VERSION);
    const int compile_minor = CudaVersionMinor(CUDA_VERSION);
    // Lexicographic (major, minor) comparison; unlike packing both into
    // major * 10 + minor, this stays correct for two-digit minors.
    const bool local_is_older =
        local_major < compile_major ||
        (local_major == compile_major && local_minor < compile_minor);
    if (local_is_older) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place_.device
          << ". The installed Paddle is compiled with CUDA " << compile_major
          << "." << compile_minor
          << ", but CUDA runtime version in your machine is " << local_major
          << "." << local_minor
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible CUDA "
             "version.";
    }
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes place_ the current device and, in NCCL builds, destroys the NCCL
// communicator if one was set.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL)
  if (nccl_comm_) {
    // Destructors are implicitly noexcept, so an exception escaping here
    // (e.g. from PADDLE_ENFORCE_CUDA_SUCCESS) would call std::terminate.
    // Report the failure instead of throwing.
    ncclResult_t ret = dynload::ncclCommDestroy(nccl_comm_);
    if (ret != ncclSuccess) {
      LOG(ERROR) << "ncclCommDestroy failed on device " << place_.device
                 << " with NCCL error code " << static_cast<int>(ret);
    }
  }
#endif
}

L
liaogang 已提交
307
Place CUDADeviceContext::GetPlace() const { return place_; }
308

309
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
310

K
Kexin Zhao 已提交
311
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
312
  return compute_capability_;
K
Kexin Zhao 已提交
313 314
}

315
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
316
  return multi_process_ * max_threads_per_mp_;
317 318
}

319 320 321 322 323 324
int CUDADeviceContext::GetSMCount() const { return multi_process_; }

int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  return max_threads_per_block_;
}

325
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
326
  return context()->EigenDevice().get();
327 328
}

329
bool CUDADeviceContext::tensor_core_available() const {
330
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
331 332
}

333 334 335 336
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  return max_grid_dim_size_;
}

337 338 339
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return context()->CudnnHandle();
}
340

S
sneaxiy 已提交
341
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
342
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
343
}
344

G
Guo Sheng 已提交
345 346 347 348
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  return context()->CusolverDnHandle();
}

349 350 351
cudaStream_t CUDADeviceContext::stream() const {
  return context()->RawStream();
}
Q
qijun 已提交
352

C
chengduoZH 已提交
353 354 355 356 357 358 359 360 361 362 363 364 365 366
// Pinned-host contexts evaluate Eigen expressions on an Eigen::DefaultDevice
// owned by the context; the two constructors differ only in place_.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning pointer to the context's Eigen device.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
367
#endif
Q
qijun 已提交
368

T
tensor-tang 已提交
369 370
#ifdef PADDLE_WITH_MKLDNN
// CPU context that additionally owns a CPU mkldnn engine (index 0), the blob
// cache used by SetBlob/GetBlob, and the mutex guarding that cache.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::kind::cpu, 0) {
  // p_blobmap_ starts empty (default-constructed) and is filled in here.
  p_blobmap_.reset(new BlobMap());
  p_mutex_.reset(new std::mutex());
}

378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394
MKLDNNDeviceContextThreadLocals::Body::Body() {
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
  cur_input_shape_str = "";
  cur_input_shape_cache_capacity = 1;
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
395 396
  cur_input_shape_str = input_shape_str;
}
397 398
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
399 400
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
401

402 403
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
404 405 406
  cur_paddle_data_layout = dl;
}

407 408
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
409 410 411
  return cur_paddle_data_layout;
}

412 413 414 415
void MKLDNNDeviceContext::ResetBlobMap() const {
  VLOG(3) << "Clearing DNNL cache.";
  p_blobmap_->clear();
}
416

417
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
418
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
419
  BlobMap* pMap = p_blobmap_.get();
420
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
421
  if (map_it == pMap->end()) {
422 423 424
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
425 426 427 428
  }
  return map_it->second->size();
}

429
void MKLDNNDeviceContext::SetBlob(const std::string& name,
430
                                  BlobPtr_t<void> data) const {
431
  BlobMap* pMap = p_blobmap_.get();
432 433
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
434

435
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
436

437
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
438

439 440
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
441 442 443

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
444
    sBlob = std::make_shared<ShapeBlob>();
445 446
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
447
  } else {
448
    sBlob = map_it->second;
449
  }
T
tensor-tang 已提交
450

451
  // Find KeyBlob for current input shape
452
  auto key_it = sBlob->find(tls().cur_input_shape_str);
453

454
  if (key_it == sBlob->end()) {
455 456
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
457 458
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
459
        sBlob->size() &&
460
        (sBlob->size() >=
461
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
462 463 464 465
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
466 467
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
468
  } else {
469
    pBlob = key_it->second;
470 471
  }

472 473 474 475 476 477 478
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
479
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
480
  // lock will be automatically released when out of scope
481
  return;
T
tensor-tang 已提交
482 483
}

484
// Looks up the blob named `name` for the calling thread's session id and
// current input-shape string. Returns nullptr when the session, the shape, or
// the name is not cached.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;

  int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Find ShapeBlob for current mkldnn session id firstly
  auto map_it = pMap->find(sid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Log the session id under "sid=" and the missing shape string under its
    // own label.
    VLOG(2) << "GetBlob: sid=" << sid
            << ", miss input_shape_str=" << tls().cur_input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}

#endif

Q
qijun 已提交
526
}  // namespace platform
Q
qijun 已提交
527
}  // namespace paddle