device_context.cc 10.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
11
#include "paddle/fluid/platform/device_context.h"
12

13
#include <set>
14
#include <string>
Y
Yu Yang 已提交
15
#include <unordered_set>
16 17
#include <vector>

Y
Yi Wang 已提交
18
#include "paddle/fluid/memory/memory.h"
19 20 21
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#endif
22

Q
qijun 已提交
23 24 25
namespace paddle {
namespace platform {

D
dzhwinter 已提交
26 27
// Process-wide pool instance; starts out as nullptr.
// NOTE(review): the code that creates and assigns the pool is not visible
// here (presumably an Init-style helper declared in the header) — confirm.
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
28
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  // Contexts are created once, in the constructor, for each requested place.
  // Asking for any place that was not registered there is an error.
  const auto iter = device_contexts_.find(place);
  if (iter == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  // Hand out a non-owning pointer; the pool keeps ownership.
  platform::DeviceContext* ctx = iter->second.get();
  return ctx;
}

C
chengduozh 已提交
38 39 40 41 42 43 44 45 46 47
const std::vector<const DeviceContext*>
DeviceContextPool::GetAllDeviceContexts() const {
  std::vector<const DeviceContext*> all_device_ctx;
  all_device_ctx.reserve(device_contexts_.size());
  for (auto& dev_ctx : device_contexts_) {
    all_device_ctx.emplace_back(dev_ctx.second.get());
  }
  return all_device_ctx;
}

D
dzhwinter 已提交
48 49 50
// Builds one device context per distinct place in |places|.
//
// CPU places get an MKLDNNDeviceContext when built with MKLDNN and a plain
// CPUDeviceContext otherwise. CUDA and CUDA-pinned places require a CUDA
// build; without one, requesting them throws.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  using PtrType = std::unique_ptr<DeviceContext>;

  // Deduplicate so that a place listed several times still gets exactly one
  // context.
  std::set<Place> unique_places(places.begin(), places.end());

  for (auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      device_contexts_.emplace(
          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else
      device_contexts_.emplace(
          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p,
          PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else
      // Name the place kind that was actually requested (this used to be a
      // copy-paste of the CUDAPlace message).
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}

89 90 91 92
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
93
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
94 95 96 97 98 99 100
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
101
Place CPUDeviceContext::GetPlace() const { return place_; }
102

103
#ifdef PADDLE_WITH_CUDA
104

Q
init  
qijun 已提交
105 106 107 108 109 110 111
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
112
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
113 114 115 116 117 118 119 120 121 122 123 124
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Q
qijun 已提交
125
    return paddle::memory::Alloc(place_, num_bytes);
Q
init  
qijun 已提交
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
  }

  void deallocate(void* buffer) const override {
    paddle::memory::Free(place_, buffer);
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
151
  CUDAPlace place_;
Q
init  
qijun 已提交
152 153
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
154
  mutable void* scratch_;
Q
init  
qijun 已提交
155 156 157
  mutable unsigned int* semaphore_;
};

158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
// Owns a cuDNN handle bound to a (non-owned) CUDA stream, plus a growable
// device workspace shared by all cuDNN calls made through RunFunc().
// RunFunc() serializes callers with mtx_, so the workspace is never handed
// to two callers at once.
class CudnnHolder {
 public:
  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
  }

  // Owns a cuDNN handle and a device buffer; a copy would release them twice.
  CudnnHolder(const CudnnHolder&) = delete;
  CudnnHolder& operator=(const CudnnHolder&) = delete;

  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }

  // Runs |cudnn_func| with a workspace of at least |required_workspace_len|
  // bytes, growing the shared workspace first if it is too small.
  void RunFunc(const std::function<void(void*)>& cudnn_func,
               size_t required_workspace_len) {
    std::lock_guard<std::mutex> lock(mtx_);
    if (required_workspace_len > workspace_len_) {
      ReallocateWorkspace(required_workspace_len);
    }
    cudnn_func(workspace_);
  }

  ~CudnnHolder() {
    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
    if (workspace_ != nullptr) {
      paddle::memory::Free(place_, workspace_);
    }
  }

 private:
  // Replaces the workspace with one of |required_workspace_len| bytes.
  // Must be called with mtx_ held.
  void ReallocateWorkspace(size_t required_workspace_len) {
    if (required_workspace_len <= workspace_len_) {
      return;
    }
    if (workspace_ != nullptr) {
      // Maybe someone is using the current workspace
      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
      paddle::memory::Free(place_, workspace_);
      // Forget the released buffer before allocating: if Alloc below throws,
      // we must not keep a dangling pointer that a later RunFunc would use
      // and the destructor would free a second time.
      workspace_ = nullptr;
      workspace_len_ = 0;
    }
    workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
    workspace_len_ = required_workspace_len;
  }

  cudnnHandle_t cudnn_handle_;
  void* workspace_;
  size_t workspace_len_;

  const cudaStream_t* stream_;  // not owned;
  const CUDAPlace place_;

  std::mutex mtx_;
};

// Sets up everything a CUDA context needs on |place|:
// - a dedicated stream;
// - an Eigen GpuDevice bound to that stream;
// - a cuBLAS handle on the stream;
// - a cuDNN holder, only when cuDNN is available;
// - a stream-callback manager.
// Device properties and CUDA versions are queried once and logged.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : place_(place), cudnn_holder_(nullptr) {
  SetDeviceId(place_.device);
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);

  PADDLE_ENFORCE(cudaStreamCreate(&stream_));
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
  if (dynload::HasCUDNN()) {
    cudnn_holder_.reset(new CudnnHolder(&stream_, place));
  }

  driver_version_ = GetCUDADriverVersion(place_.device);
  runtime_version_ = GetCUDARuntimeVersion(place_.device);

  // CUDA reports versions as 1000 * major + 10 * minor (e.g. 9020 -> 9.2),
  // so the minor number is (v % 1000) / 10; (v % 100) / 10 would drop any
  // minor version >= 10.
  LOG(INFO) << "device: " << place_.device
            << ", CUDA Capability: " << compute_capability_
            << ", Driver Version: " << driver_version_ / 1000 << "."
            << (driver_version_ % 1000) / 10
            << ", Runtime Version: " << runtime_version_ / 1000 << "."
            << (runtime_version_ % 1000) / 10;

  callback_manager_.reset(new StreamCallbackManager(stream_));
}

// Tears the context down in dependency order: drain the stream, release
// everything that was created on top of stream_, then destroy stream_.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();
  // NOTE(review): presumably waits for callbacks queued through
  // callback_manager_ — confirm against the header.
  WaitStreamCallback();
  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
  // eigen_device_ was constructed from eigen_stream_, so release it first.
  eigen_device_.reset();
  eigen_stream_.reset();
  // cudnn_holder_ and callback_manager_ were both built on stream_. As
  // members they would only be destroyed after this body returns, i.e.
  // after stream_ is gone, so release them while it is still valid.
  cudnn_holder_.reset();
  callback_manager_.reset();
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
247
Place CUDADeviceContext::GetPlace() const { return place_; }
248

L
liaogang 已提交
249
void CUDADeviceContext::Wait() const {
Q
init  
qijun 已提交
250
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
251 252 253
  PADDLE_ENFORCE(cudaGetLastError());
}

K
Kexin Zhao 已提交
254
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
255
  return compute_capability_;
K
Kexin Zhao 已提交
256 257
}

258
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
259
  return multi_process_ * max_threads_per_mp_;
260 261
}

262 263 264 265
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return eigen_device_.get();
}

266
cublasHandle_t CUDADeviceContext::cublas_handle() const {
267 268 269
  return cublas_handle_;
}

270 271 272 273 274 275 276 277
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return cudnn_holder_->cudnn_handle();
}

void CUDADeviceContext::RunCudnnFuncWithWorkspace(
    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
}
278

279
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
Q
qijun 已提交
280

C
chengduoZH 已提交
281 282 283 284 285 286 287 288 289 290 291 292 293 294
// Default constructor: place_ is default-initialized; pinned host memory is
// driven through an Eigen host (DefaultDevice) device.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Constructs a context for the given pinned-memory |place|.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning pointer to the Eigen device held by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

// Returns the place this context was built for.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
295
#endif
Q
qijun 已提交
296

T
tensor-tang 已提交
297 298
#ifdef PADDLE_WITH_MKLDNN
// CPU context extended with an MKL-DNN CPU engine (index 0) and a per-thread
// blob cache used by SetBlob()/GetBlob().
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
  // Mutex that SetBlob()/GetBlob() take before touching the blob map.
  p_mutex_.reset(new std::mutex());
  // Map from thread id to that thread's named blobs.
  p_blobmap_.reset(new BlobMap());
}

S
Sylwester Fraczek 已提交
304 305 306 307 308 309 310 311
namespace {
// Caller-assigned id of the current thread. Each thread has its own copy,
// which starts at 0 until set_cur_thread_id() is called on that thread.
thread_local int cur_thread_id = 0;
}  // namespace

// Records |tid| as the calling thread's id.
void set_cur_thread_id(int tid) {
  cur_thread_id = tid;
}

// Returns the id last recorded on the calling thread (0 if never set).
int get_cur_thread_id(void) {
  const int tid = cur_thread_id;
  return tid;
}

312 313
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
314 315 316 317
  BlobMap* pMap = p_blobmap_.get();
  std::shared_ptr<KeyBlob> pBlob = nullptr;

  int tid = platform::get_cur_thread_id();
T
tensor-tang 已提交
318

319
  std::lock_guard<std::mutex> lock(*p_mutex_.get());
T
tensor-tang 已提交
320

321 322 323 324 325 326 327
  // Find KeyBlob for current thread
  auto map_it = pMap->find(tid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
    (*pMap)[tid] = pBlob;
328
  } else {
329
    pBlob = map_it->second;
330
  }
T
tensor-tang 已提交
331

332 333 334 335 336 337 338 339 340 341
  // Find Key in found (or newly created) KeyBlob
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) {
    (*pBlob)[name] = data;  // create new blob
  } else {
    key_it->second = data;  // set data to existing blob
  }

  // lock will be automatically released when out of scope
342
  return;
T
tensor-tang 已提交
343 344
}

345 346
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
347 348
  BlobMap* pMap = p_blobmap_.get();
  std::shared_ptr<KeyBlob> pBlob = nullptr;
T
tensor-tang 已提交
349

350
  int tid = platform::get_cur_thread_id();
T
tensor-tang 已提交
351

352 353 354 355 356 357 358 359 360 361 362
  std::lock_guard<std::mutex> lock(*p_mutex_.get());

  // Find KeyBlob for current thread firstly
  auto map_it = pMap->find(tid);
  if (map_it == pMap->end()) return nullptr;
  pBlob = map_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) return nullptr;
363

364 365
  // lock will be automatically released when out of scope
  return key_it->second;
T
tensor-tang 已提交
366 367 368 369
}

#endif

Q
qijun 已提交
370
}  // namespace platform
Q
qijun 已提交
371
}  // namespace paddle