/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"

#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/memory/memory.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#endif
22

Q
qijun 已提交
23 24 25
namespace paddle {
namespace platform {

D
dzhwinter 已提交
26 27
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
28
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
29 30 31 32 33 34
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
Y
Yu Yang 已提交
35
  return it->second.get();
D
dzhwinter 已提交
36 37 38 39 40
}

// Builds one device context per distinct place in `places` (which must be
// non-empty). Duplicate places are collapsed. Requesting a GPU or CUDA-pinned
// place in a build without PADDLE_WITH_CUDA throws. Places that are neither
// CPU, GPU nor CUDA-pinned are skipped without creating a context.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  using PtrType = std::unique_ptr<DeviceContext>;

  // Deduplicate so each place gets exactly one context.
  const std::set<Place> unique_places(places.begin(), places.end());

  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      device_contexts_.emplace(
          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else
      device_contexts_.emplace(
          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p,
          PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else
      // Report the place that was actually requested (was a copy-paste of
      // the CUDAPlace message).
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}

79 80 81 82
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
83
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
84 85 86 87 88 89 90
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
91
Place CPUDeviceContext::GetPlace() const { return place_; }
92

93
#ifdef PADDLE_WITH_CUDA
94

Q
init  
qijun 已提交
95 96 97 98 99 100 101
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
102
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
103 104 105 106 107 108 109 110 111 112 113 114
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Y
Yu Yang 已提交
115 116
    auto buf = paddle::memory::Alloc(place_, num_bytes,
                                     memory::Allocator::kScratchpad);
117 118 119
    void* retv = buf->ptr();
    allocations_[buf->ptr()] = std::move(buf);
    return retv;
Q
init  
qijun 已提交
120 121 122
  }

  void deallocate(void* buffer) const override {
123
    allocations_.erase(allocations_.find(buffer));
Q
init  
qijun 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
145
  CUDAPlace place_;
Q
init  
qijun 已提交
146 147
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
148
  mutable void* scratch_;
Q
init  
qijun 已提交
149
  mutable unsigned int* semaphore_;
150 151
  mutable std::unordered_map<void*, std::unique_ptr<memory::Allocation>>
      allocations_;
Q
init  
qijun 已提交
152 153
};

154 155 156
class CudnnHolder {
 public:
  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
157
      : workspace_(nullptr), stream_(stream), place_(place) {
158 159 160 161 162 163 164 165 166
    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
  }

  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }

  void RunFunc(const std::function<void(void*)>& cudnn_func,
               size_t required_workspace_len) {
    std::lock_guard<std::mutex> lock(mtx_);
167
    if (required_workspace_len > WorkspaceSize()) {
168 169
      ReallocateWorkspace(required_workspace_len);
    }
Y
Yu Yang 已提交
170
    cudnn_func(WorkspacePtr());
171 172
  }

173 174 175 176 177 178 179 180
  ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }

 private:
  size_t WorkspaceSize() const {
    if (workspace_ == nullptr) {
      return 0;
    } else {
      return workspace_->size();
181 182 183
    }
  }

Y
Yu Yang 已提交
184 185 186 187 188 189 190 191
  void* WorkspacePtr() const {
    if (workspace_ == nullptr) {
      return nullptr;
    } else {
      return workspace_->ptr();
    }
  }

192
  void ReallocateWorkspace(size_t required_workspace_len) {
193
    if (required_workspace_len <= WorkspaceSize()) {
194 195 196 197 198
      return;
    }
    if (workspace_ != nullptr) {
      // Maybe someone is using the current workspace
      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
199
      workspace_.reset();
200
    }
201 202
    workspace_ = paddle::memory::Alloc(place_, required_workspace_len,
                                       memory::Allocator::kFluxHuge);
203 204 205
  }

  cudnnHandle_t cudnn_handle_;
206
  std::unique_ptr<memory::Allocation> workspace_;
207 208 209 210 211 212 213 214 215

  const cudaStream_t* stream_;  // not owned;
  const CUDAPlace place_;

  std::mutex mtx_;
};

// Sets up everything this context owns for `place`:
// - cached device limits,
// - a dedicated CUDA stream,
// - an Eigen GPU device driving that stream,
// - a cuBLAS handle bound to it,
// - a cuDNN holder (only when cuDNN is loadable),
// - a stream callback manager.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : place_(place), cudnn_holder_(nullptr) {
  // Scoped device guard for this context's device (see cuda_device_guard.h)
  // while the resources below are created.
  CUDADeviceGuard guard(place_.device);
  const auto device_id = place_.device;

  // Device limits are queried once and cached.
  compute_capability = GetCUDAComputeCapability(device_id);
  multi_process = GetCUDAMultiProcessors(device_id);
  max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(device_id);

  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // Eigen reaches stream_ through the adapter, which stores its address.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  // cudnn_holder_ stays null when cuDNN cannot be loaded.
  if (dynload::HasCUDNN()) {
    cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
  }

  callback_manager_.reset(new StreamCallbackManager(stream_));
}

// Tears the context down in dependency order: drain the stream, release
// cuBLAS and the Eigen objects that reference the stream, then destroy the
// stream itself.
//
// The device selected by SetDeviceId is deliberately not restored, so member
// destructors that run after this body (e.g. cudnn_holder_) still execute
// with this context's device current.
//
// NOTE(review): PADDLE_ENFORCE failures here throw from a destructor, which
// terminates the process.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);

  // Wait for all queued work, then for stream callbacks, before releasing
  // anything that may still be in use.
  Wait();
  WaitStreamCallback();

  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));

  // The Eigen adapter holds a pointer to stream_; drop the Eigen objects
  // before the stream goes away.
  eigen_stream_.reset();
  eigen_device_.reset();

  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
243
Place CUDADeviceContext::GetPlace() const { return place_; }
244

L
liaogang 已提交
245
void CUDADeviceContext::Wait() const {
Q
init  
qijun 已提交
246
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
247 248 249
  PADDLE_ENFORCE(cudaGetLastError());
}

K
Kexin Zhao 已提交
250 251 252 253
int CUDADeviceContext::GetComputeCapability() const {
  return compute_capability;
}

254 255 256 257
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return multi_process * max_threads_per_mp;
}

258 259 260 261
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return eigen_device_.get();
}

262
cublasHandle_t CUDADeviceContext::cublas_handle() const {
263 264 265
  return cublas_handle_;
}

266 267 268 269 270 271 272 273
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return cudnn_holder_->cudnn_handle();
}

void CUDADeviceContext::RunCudnnFuncWithWorkspace(
    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
}
274

275
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
Q
qijun 已提交
276

C
chengduoZH 已提交
277 278 279 280 281 282 283 284 285 286 287 288 289 290
// Default construction uses a default CUDAPinnedPlace; the setup lives in
// the place-taking constructor below.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : CUDAPinnedDeviceContext(CUDAPinnedPlace()) {}

// Pinned host memory is computed on with Eigen's host (DefaultDevice)
// backend.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Non-owning access to the Eigen device owned by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

Place CUDAPinnedDeviceContext::GetPlace() const {
  return place_;
}
L
Luo Tao 已提交
291
#endif
Q
qijun 已提交
292

T
tensor-tang 已提交
293 294
#ifdef PADDLE_WITH_MKLDNN
// CPU context extended with an MKL-DNN CPU engine (index 0) and an initially
// empty name -> blob map used by SetBlob/GetBlob.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::cpu, 0),
      p_blobs_(new std::unordered_map<std::string, std::shared_ptr<void>>()) {}

299 300 301 302
// Stores `data` under `name`, creating the entry if needed or replacing the
// existing blob otherwise.
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
  // operator[] default-inserts on a miss and returns the existing slot on a
  // hit, so one hash lookup covers both cases (the old find + insert did
  // two). `data` is a by-value sink, so move it in rather than paying for
  // another atomic refcount increment/decrement.
  (*p_blobs_)[name] = std::move(data);
}

315 316 317 318
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
319

320
  auto it = p->find(name);
T
tensor-tang 已提交
321

322 323
  if (it != p->end()) {
    return it->second;
T
tensor-tang 已提交
324
  }
325 326

  return nullptr;
T
tensor-tang 已提交
327 328 329 330
}

#endif

Q
qijun 已提交
331
}  // namespace platform
Q
qijun 已提交
332
}  // namespace paddle