device_context.cc 10.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5 6 7 8 9 10
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
11
#include "paddle/fluid/platform/device_context.h"
12
#include <set>
13
#include <string>
Y
Yu Yang 已提交
14
#include <unordered_set>
15 16
#include <vector>

Y
Yi Wang 已提交
17
#include "paddle/fluid/memory/memory.h"
18 19
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
S
sneaxiy 已提交
20
#include "paddle/fluid/platform/cuda_device_guard.h"
21
#endif
22

Q
qijun 已提交
23 24 25
namespace paddle {
namespace platform {

D
dzhwinter 已提交
26 27
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
28
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
29 30 31 32 33 34
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
35
  return it->second.get().get();
D
dzhwinter 已提交
36 37
}

38 39 40 41 42 43 44 45 46 47 48
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  using PtrType = std::unique_ptr<DeviceContext>;
  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
                     // lazy evaluation. i.e., only create device context at
                     // first `Get`
                     return PtrType(new DevCtx(boost::get<PlaceType>(p)));
                   }));
C
chengduozh 已提交
49 50
}

D
dzhwinter 已提交
51 52 53
// Builds the pool from `places` (must be non-empty). Duplicate places are
// collapsed so each distinct place gets exactly one lazily-created context.
// GPU and CUDA-pinned places throw unless the build has PADDLE_WITH_CUDA;
// a place matching none of the checked kinds is silently skipped.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  std::set<Place> unique_places(places.begin(), places.end());

  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      // Report the place kind that was actually requested.
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}

87 88 89 90
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
91
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
92 93 94 95 96 97 98
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
99
Place CPUDeviceContext::GetPlace() const { return place_; }
100

101
#ifdef PADDLE_WITH_CUDA
102

Q
init  
qijun 已提交
103 104 105 106 107 108 109
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
110
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
111 112 113 114 115 116 117 118 119 120 121 122
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Y
Yu Yang 已提交
123 124
    auto buf = paddle::memory::Alloc(place_, num_bytes,
                                     memory::Allocator::kScratchpad);
125 126 127
    void* retv = buf->ptr();
    allocations_[buf->ptr()] = std::move(buf);
    return retv;
Q
init  
qijun 已提交
128 129 130
  }

  void deallocate(void* buffer) const override {
131
    allocations_.erase(allocations_.find(buffer));
Q
init  
qijun 已提交
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
153
  CUDAPlace place_;
Q
init  
qijun 已提交
154 155
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
156
  mutable void* scratch_;
Q
init  
qijun 已提交
157
  mutable unsigned int* semaphore_;
158 159
  mutable std::unordered_map<void*, std::unique_ptr<memory::Allocation>>
      allocations_;
Q
init  
qijun 已提交
160 161
};

S
sneaxiy 已提交
162 163 164 165 166
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
    : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
167

S
sneaxiy 已提交
168 169 170 171
CudnnHolder::~CudnnHolder() {
  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
  if (workspace_ != nullptr) {
    paddle::memory::Free(place_, workspace_);
172
  }
S
sneaxiy 已提交
173
}
174

S
sneaxiy 已提交
175 176 177
void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
  if (required_workspace_len <= workspace_len_) {
    return;
Y
Yu Yang 已提交
178
  }
S
sneaxiy 已提交
179 180 181 182
  if (workspace_ != nullptr) {
    // Maybe someone is using the current workspace
    PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
    paddle::memory::Free(place_, workspace_);
183
  }
S
sneaxiy 已提交
184 185
  workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
  workspace_len_ = required_workspace_len;
S
sneaxiy 已提交
186
}
187 188 189

// Creates a context on GPU `place`. It queries device properties, creates a
// dedicated CUDA stream, and binds Eigen, cuBLAS, cuDNN (when
// dynload::HasCUDNN()) and a stream-callback manager to that stream.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : place_(place), cudnn_holder_(nullptr) {
  // Make this context's device current for every CUDA call below.
  CUDADeviceGuard guard(place_.device);
  const int device_id = place_.device;

  // Device properties.
  compute_capability_ = GetCUDAComputeCapability(device_id);
  multi_process_ = GetCUDAMultiProcessors(device_id);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(device_id);

  // The stream all library handles below are bound to.
  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // Eigen evaluates on the same stream through the adapter.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  // cudnn_holder_ stays null when cuDNN cannot be loaded.
  if (dynload::HasCUDNN()) {
    cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
  }

  driver_version_ = GetCUDADriverVersion(device_id);
  runtime_version_ = GetCUDARuntimeVersion(device_id);

  // Printed as major = v / 1000 and minor = (v % 100) / 10.
  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
                          << ", CUDA Capability: " << compute_capability_
                          << ", Driver Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime Version: " << runtime_version_ / 1000
                          << "." << (runtime_version_ % 100) / 10;

  callback_manager_.reset(new StreamCallbackManager(stream_));
}

// Drains all pending work, then releases every stream-bound resource before
// the stream itself.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();
  WaitStreamCallback();
  // The callback manager and cuDNN holder were both constructed with
  // `stream_`. As members they would otherwise be destroyed only after this
  // body has already called cudaStreamDestroy, so release them here while
  // the stream is still alive.
  callback_manager_.reset();
  cudnn_holder_.reset();
  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
  // The GpuDevice was built on top of eigen_stream_; drop it first.
  eigen_device_.reset();
  eigen_stream_.reset();
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
227
Place CUDADeviceContext::GetPlace() const { return place_; }
228

L
liaogang 已提交
229
void CUDADeviceContext::Wait() const {
Q
init  
qijun 已提交
230
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
231 232 233
  PADDLE_ENFORCE(cudaGetLastError());
}

K
Kexin Zhao 已提交
234
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
235
  return compute_capability_;
K
Kexin Zhao 已提交
236 237
}

238
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
239
  return multi_process_ * max_threads_per_mp_;
240 241
}

242 243 244 245
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return eigen_device_.get();
}

246
cublasHandle_t CUDADeviceContext::cublas_handle() const {
247 248 249
  return cublas_handle_;
}

250 251 252 253
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return cudnn_holder_->cudnn_handle();
}

S
sneaxiy 已提交
254 255
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  return CudnnWorkspaceHandle(cudnn_holder_.get());
256
}
257

258
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
Q
qijun 已提交
259

C
chengduoZH 已提交
260 261 262 263 264 265 266 267 268 269 270 271 272 273
// Default-constructed pinned-memory context: only the Eigen device is
// created; `place_` keeps its default value.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Pinned-memory context bound to `place`, evaluating through an Eigen
// DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning access to the Eigen device; this context retains ownership.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

// The place this context was constructed with.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
274
#endif
Q
qijun 已提交
275

T
tensor-tang 已提交
276 277
#ifdef PADDLE_WITH_MKLDNN
// CPU context that also owns an MKL-DNN CPU engine (index 0) and the
// per-thread blob storage used by SetBlob/GetBlob.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
  // Mutex guarding every access to the blob map.
  p_mutex_.reset(new std::mutex());
  // Blob map keyed by get_cur_thread_id().
  p_blobmap_.reset(new BlobMap());
}

S
Sylwester Fraczek 已提交
283 284 285 286 287 288 289 290
namespace {
// Id recorded for the calling thread by set_cur_thread_id(); starts at 0.
// Declared thread_local, so each thread has its own copy.
thread_local int cur_thread_id = 0;
}  // namespace

// Records `tid` as the calling thread's id.
void set_cur_thread_id(int tid) {
  cur_thread_id = tid;
}

// Returns the id last recorded for the calling thread (0 if never set).
int get_cur_thread_id() {
  return cur_thread_id;
}

291 292
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
293 294 295 296
  BlobMap* pMap = p_blobmap_.get();
  std::shared_ptr<KeyBlob> pBlob = nullptr;

  int tid = platform::get_cur_thread_id();
T
tensor-tang 已提交
297

298
  std::lock_guard<std::mutex> lock(*p_mutex_.get());
T
tensor-tang 已提交
299

300 301 302 303 304 305 306
  // Find KeyBlob for current thread
  auto map_it = pMap->find(tid);

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
    pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
    (*pMap)[tid] = pBlob;
307
  } else {
308
    pBlob = map_it->second;
309
  }
T
tensor-tang 已提交
310

311 312 313 314 315 316 317 318 319 320
  // Find Key in found (or newly created) KeyBlob
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) {
    (*pBlob)[name] = data;  // create new blob
  } else {
    key_it->second = data;  // set data to existing blob
  }

  // lock will be automatically released when out of scope
321
  return;
T
tensor-tang 已提交
322 323
}

324 325
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
326 327
  BlobMap* pMap = p_blobmap_.get();
  std::shared_ptr<KeyBlob> pBlob = nullptr;
T
tensor-tang 已提交
328

329
  int tid = platform::get_cur_thread_id();
T
tensor-tang 已提交
330

331 332 333 334 335 336 337 338 339 340 341
  std::lock_guard<std::mutex> lock(*p_mutex_.get());

  // Find KeyBlob for current thread firstly
  auto map_it = pMap->find(tid);
  if (map_it == pMap->end()) return nullptr;
  pBlob = map_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);

  if (key_it == pBlob->end()) return nullptr;
342

343 344
  // lock will be automatically released when out of scope
  return key_it->second;
T
tensor-tang 已提交
345 346 347 348
}

#endif

Q
qijun 已提交
349
}  // namespace platform
Q
qijun 已提交
350
}  // namespace paddle