device_context.cc 9.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
11
#include "paddle/fluid/platform/device_context.h"
12

13
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

F
fengjiayi 已提交
18
#include "paddle/fluid/memory/memory.h"
F
fengjiayi 已提交
19
#ifdef PADDLE_WITH_CUDA
F
fengjiayi 已提交
20
#include "paddle/fluid/framework/rw_lock.h"
F
fengjiayi 已提交
21 22
#endif

Q
qijun 已提交
23 24 25
namespace paddle {
namespace platform {

D
dzhwinter 已提交
26 27
// Static pool instance. It stays null until it is assigned elsewhere; the
// assignment is not part of this translation unit.
DeviceContextPool* DeviceContextPool::pool{nullptr};

Y
Yu Yang 已提交
28
// Looks up the DeviceContext that was created for `place` when the pool was
// constructed. The pool keeps ownership; the returned pointer is non-owning.
// Throws (PADDLE_THROW) if no context exists for `place`.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  const auto entry = device_contexts_.find(place);
  const bool registered = entry != device_contexts_.end();
  if (!registered) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  return entry->second.get();
}

// Creates exactly one DeviceContext per distinct Place in `places`.
//  - CPU places get an MKLDNNDeviceContext when built with MKLDNN, otherwise a
//    CPUDeviceContext.
//  - GPU places get a CUDADeviceContext; CUDA-pinned places get a
//    CUDAPinnedDeviceContext. Both require a CUDA build and throw otherwise.
// Enforces that `places` is non-empty.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  using PtrType = std::unique_ptr<DeviceContext>;

  // Collapse duplicates so no place gets more than one context.
  const std::set<Place> unique_places(places.begin(), places.end());

  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      device_contexts_.emplace(
          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else
      device_contexts_.emplace(
          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p,
          PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else
      // Name the place kind that was actually requested (this message used to
      // say 'CUDAPlace', copied from the branch above).
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}

79 80 81 82
// Default-constructed CPU context: default place, plus an Eigen DefaultDevice
// owned by this context.
CPUDeviceContext::CPUDeviceContext()
    : eigen_device_(new Eigen::DefaultDevice()) {}

D
dzhwinter 已提交
83
// CPU context bound to `place`. Reuses the default constructor to create the
// Eigen device, then records the requested place.
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : CPUDeviceContext() {
  place_ = place;
}

// Non-owning access to the Eigen device; the context keeps ownership.
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}

D
dzhwinter 已提交
91
// Reports the place this context was created for, as a generic Place.
Place CPUDeviceContext::GetPlace() const {
  return Place(place_);
}
92

93
#ifdef PADDLE_WITH_CUDA
94

Q
init  
qijun 已提交
95 96 97 98 99 100 101
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
102
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
103 104 105 106 107 108 109 110 111 112 113 114
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Q
qijun 已提交
115
    return paddle::memory::Alloc(place_, num_bytes);
Q
init  
qijun 已提交
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
  }

  void deallocate(void* buffer) const override {
    paddle::memory::Free(place_, buffer);
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
141
  CUDAPlace place_;
Q
init  
qijun 已提交
142 143
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
144
  mutable void* scratch_;
Q
init  
qijun 已提交
145 146 147
  mutable unsigned int* semaphore_;
};

148 149 150
class CudnnHolder {
 public:
  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
F
fengjiayi 已提交
151
      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
152
    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
F
fengjiayi 已提交
153
    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
154 155
  }

F
fengjiayi 已提交
156 157 158 159
  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }

  void RunFunc(const std::function<void(void*)>& cudnn_func,
               size_t required_workspace_len) {
F
fengjiayi 已提交
160 161
    framework::RWLockGuard lock_guard(&rw_lock_,
                                      framework::RWLockGuard::Status::kRDLock);
F
fengjiayi 已提交
162
    if (required_workspace_len > workspace_len_) {
F
fengjiayi 已提交
163 164 165 166 167
      lock_guard.UnLock();
      lock_guard.WRLock();
      ReallocateWorkspace(required_workspace_len);
      lock_guard.UnLock();
      lock_guard.RDLock();
168
    }
F
fengjiayi 已提交
169
    cudnn_func(workspace_);
170 171 172 173 174
  }

  ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }

 private:
F
fengjiayi 已提交
175
  void ReallocateWorkspace(size_t required_workspace_len) {
F
fengjiayi 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189
    if (required_workspace_len <= workspace_len_) {
      return;
    }
    void* new_workspace = paddle::memory::Alloc(place_, required_len);
    if (workspace_ != nullptr) {
      // Maybe someone is using the current workspace
      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
      PADDLE_ENFORCE(cudaGetLastError());
      paddle::memory::Free(place_, workspace_);
    }
    workspace_ = new_workspace;
    workspace_len_ = required_len;
  }

190 191 192 193 194 195
  cudnnHandle_t cudnn_handle_;
  void* workspace_;
  size_t workspace_len_;

  const cudaStream_t* stream_;  // not owned;
  const CUDAPlace place_;
F
fengjiayi 已提交
196

F
fengjiayi 已提交
197
  framework::RWLock rw_lock_;
198 199 200 201
};

// Builds a GPU context for `place`:
//  1. selects the device and caches its attributes;
//  2. creates a dedicated CUDA stream;
//  3. wraps that stream for Eigen;
//  4. creates a cuBLAS handle on the stream;
//  5. creates a cuDNN holder on the stream, when dynload::HasCUDNN() is true.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : place_(place), cudnn_holder_(nullptr) {
  const auto device_id = place_.device;
  SetDeviceId(device_id);

  // Cached for GetComputeCapability() / GetMaxPhysicalThreadCount().
  compute_capability = GetCUDAComputeCapability(device_id);
  multi_process = GetCUDAMultiProcessors(device_id);
  max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(device_id);

  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // Eigen reaches the stream through the adapter, which stores &stream_.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  const bool cudnn_available = dynload::HasCUDNN();
  if (cudnn_available) {
    cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
  }
}

// Drains the stream, then releases the cuBLAS handle and the Eigen wrappers
// that point at stream_, and destroys the stream last.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();  // no queued work may outlive the resources released below

  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));

  // Both Eigen objects reference stream_; drop them before destroying it.
  eigen_stream_.reset();
  eigen_device_.reset();

  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
226
// Reports the GPU place this context was created for, as a generic Place.
Place CUDADeviceContext::GetPlace() const {
  return Place(place_);
}
227

L
liaogang 已提交
228
void CUDADeviceContext::Wait() const {
Q
init  
qijun 已提交
229
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
230 231 232
  PADDLE_ENFORCE(cudaGetLastError());
}

K
Kexin Zhao 已提交
233 234 235 236
// Compute capability cached in the constructor via GetCUDAComputeCapability.
int CUDADeviceContext::GetComputeCapability() const {
  const int capability = compute_capability;
  return capability;
}

237 238 239 240
// Upper bound on concurrently resident threads: the number of multiprocessors
// times the maximum threads per multiprocessor (both cached in the ctor).
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return max_threads_per_mp * multi_process;
}

241 242 243 244
// Non-owning access to the Eigen GPU device bound to this context's stream.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  Eigen::GpuDevice* const device = eigen_device_.get();
  return device;
}

245
// cuBLAS handle created in the constructor and bound to stream_; it is
// destroyed by this context's destructor.
cublasHandle_t CUDADeviceContext::cublas_handle() const { return cublas_handle_; }

249
// Returns the cuDNN handle bound to this context's stream.
// cudnn_holder_ is created only when dynload::HasCUDNN() was true in the
// constructor. Fail with a descriptive error instead of dereferencing a null
// holder.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  PADDLE_ENFORCE(cudnn_holder_ != nullptr,
                 "cuDNN is not available for this CUDADeviceContext; "
                 "cannot return a cuDNN handle.");
  return cudnn_holder_->cudnn_handle();
}

F
fengjiayi 已提交
253 254 255
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
256
}
257

258
// The CUDA stream owned by this context. It is destroyed in the context's
// destructor, so callers must not keep it beyond the context's lifetime.
cudaStream_t CUDADeviceContext::stream() const {
  return stream_;
}
Q
qijun 已提交
259

C
chengduoZH 已提交
260 261 262 263 264 265 266 267 268 269 270 271 272 273
// Default-constructed pinned-memory context: default place, plus an Eigen
// DefaultDevice owned by this context.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : eigen_device_(new Eigen::DefaultDevice()) {}

// Pinned-memory context bound to `place`. Reuses the default constructor to
// create the Eigen device, then records the requested place.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : CUDAPinnedDeviceContext() {
  place_ = place;
}

// Non-owning access to the Eigen device; the context keeps ownership.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}

// Reports the pinned place this context was created for, as a generic Place.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return Place(place_);
}
L
Luo Tao 已提交
274
#endif
Q
qijun 已提交
275

T
tensor-tang 已提交
276 277
#ifdef PADDLE_WITH_MKLDNN
// CPU context backed by an MKL-DNN CPU engine (index 0), with an initially
// empty name -> blob map used by SetBlob/GetBlob.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::cpu, 0),
      p_blobs_(new std::unordered_map<std::string, std::shared_ptr<void>>()) {}

282 283 284 285
// Stores `data` under `name` in this context's blob map. A missing entry is
// created; an existing one is overwritten.
// NOTE(review): no locking here, so concurrent SetBlob/GetBlob calls on the
// same context would race on the map — confirm callers serialize access.
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
  // operator[] default-inserts an empty entry when `name` is absent, so one
  // lookup covers both the create and overwrite cases. `data` is a by-value
  // sink, so move it rather than paying for another refcount increment.
  (*p_blobs_)[name] = std::move(data);
}

298 299 300 301
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
302

303
  auto it = p->find(name);
T
tensor-tang 已提交
304

305 306
  if (it != p->end()) {
    return it->second;
T
tensor-tang 已提交
307
  }
308 309

  return nullptr;
T
tensor-tang 已提交
310 311 312 313
}

#endif

Q
qijun 已提交
314
}  // namespace platform
Q
qijun 已提交
315
}  // namespace paddle