device_context.cc 8.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5 6 7 8 9 10
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
11
#include "paddle/fluid/platform/device_context.h"
12

13
#include <set>
14
#include <string>
Y
Yu Yang 已提交
15
#include <unordered_set>
16 17
#include <vector>

Y
Yi Wang 已提交
18
#include "paddle/fluid/memory/memory.h"
19

Q
qijun 已提交
20 21 22
namespace paddle {
namespace platform {

D
dzhwinter 已提交
23 24
// Definition of the static pool pointer; it starts out null.
// NOTE(review): presumably populated by an Init/Instance helper declared in
// device_context.h — confirm against the header.
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
25
// Returns the (non-owned) DeviceContext registered for `place`.
// Throws via PADDLE_THROW when no context was created for that place, for
// example when a GPU place is requested from a build without WITH_GPU.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  const auto entry = device_contexts_.find(place);
  const bool registered = entry != device_contexts_.end();
  if (!registered) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  return entry->second.get();
}

// Creates one DeviceContext per distinct place in `places`.
//
// CPU places get an MKLDNNDeviceContext when built with MKLDNN, otherwise a
// CPUDeviceContext. CUDA and CUDA-pinned places get their CUDA-specific
// contexts. In a build without PADDLE_WITH_CUDA, both of those throw with a
// message naming the unsupported place type.
//
// Requires `places` to be non-empty (enforced).
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  using PtrType = std::unique_ptr<DeviceContext>;

  // Duplicate places must map to a single context, so deduplicate first.
  const std::set<Place> unique_places(places.begin(), places.end());

  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      device_contexts_.emplace(
          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else
      device_contexts_.emplace(
          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          p,
          PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else
      // The message names the place type actually being rejected here.
      PADDLE_THROW(
          "'CUDAPinnedPlace' is not supported, Please re-compile with "
          "WITH_GPU option");
#endif
    }
  }
}

76 77 78 79
// Default CPU context. `place_` is left default-constructed; the context owns
// a fresh Eigen::DefaultDevice used to evaluate Eigen expressions on the host.
CPUDeviceContext::CPUDeviceContext() {
  this->eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
80
// CPU context bound to `place`. It reuses the default constructor to create
// the Eigen::DefaultDevice, then records the requested place.
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : CPUDeviceContext() {
  place_ = place;
}

// Non-owning access to the Eigen host device held by this context.
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}

D
dzhwinter 已提交
88
// The CPU place this context was created for.
Place CPUDeviceContext::GetPlace() const {
  return this->place_;
}
89

90
#ifdef PADDLE_WITH_CUDA
91

Q
init  
qijun 已提交
92 93 94 95 96 97 98
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
99
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
100 101 102 103 104 105 106 107 108 109 110 111
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Q
qijun 已提交
112
    return paddle::memory::Alloc(place_, num_bytes);
Q
init  
qijun 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  }

  void deallocate(void* buffer) const override {
    paddle::memory::Free(place_, buffer);
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
138
  CUDAPlace place_;
Q
init  
qijun 已提交
139 140
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
141
  mutable void* scratch_;
Q
init  
qijun 已提交
142 143 144
  mutable unsigned int* semaphore_;
};

145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
// Owns a cuDNN handle bound to a (non-owned) CUDA stream, plus a device
// workspace buffer that only grows and is shared by cuDNN calls on that stream.
class CudnnHolder {
 public:
  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
    // cudnnSetStream takes the stream object itself, not a pointer to it.
    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
  }

  // Destroys the handle and releases the workspace.
  // NOTE(review): assumes the owning context has already drained the stream
  // (CUDADeviceContext's destructor calls Wait() first), so no kernel still
  // reads the workspace.
  ~CudnnHolder() {
    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
    if (workspace_ != nullptr) {
      paddle::memory::Free(place_, workspace_);
    }
  }

  // Owns a handle and a device buffer; copying would double-release both.
  CudnnHolder(const CudnnHolder&) = delete;
  CudnnHolder& operator=(const CudnnHolder&) = delete;

  cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; }

  // Returns a workspace of at least `required_len` bytes. The buffer is only
  // reallocated when a larger size is requested. The previous buffer is freed
  // only after the stream is synchronized, because queued kernels may still
  // use it.
  void* get_workspace(size_t required_len) {
    if (required_len > workspace_len_) {
      void* new_workspace = paddle::memory::Alloc(place_, required_len);
      if (workspace_ != nullptr) {
        // Maybe someone is using the current workspace
        PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
        PADDLE_ENFORCE(cudaGetLastError());
        paddle::memory::Free(place_, workspace_);
      }
      workspace_ = new_workspace;
      // Remember the capacity so smaller requests reuse this buffer.
      workspace_len_ = required_len;
    }
    return workspace_;
  }

 private:
  cudnnHandle_t cudnn_handle_;
  void* workspace_;
  size_t workspace_len_;

  const cudaStream_t* stream_;  // not owned;
  const CUDAPlace place_;
};

// Builds a CUDA context on `place`. Steps:
//   1. Select the device and cache some of its hardware attributes.
//   2. Create the context's own stream.
//   3. Bind Eigen and cuBLAS to that stream.
//   4. If the cuDNN library is available, create a cuDNN holder on the same
//      stream.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
    : place_(place), cudnn_holder_(nullptr) {
  const int device_id = place_.device;
  SetDeviceId(device_id);

  // Hardware attributes, queried once and served by the getters below.
  compute_capability = GetCUDAComputeCapability(device_id);
  multi_process = GetCUDAMultiProcessors(device_id);
  max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(device_id);

  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // Eigen evaluates on stream_ through the stream adapter.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place_);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  // cuDNN is optional at runtime; without it cudnn_holder_ stays null.
  if (!dynload::HasCUDNN()) {
    return;
  }
  cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
}

// Tears the context down in dependency order:
//   1. Wait for pending work.
//   2. Release everything bound to stream_ (cuBLAS, cuDNN, Eigen).
//   3. Destroy the stream itself.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();
  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
  // The cuDNN handle and its workspace are bound to stream_. Release them
  // explicitly while the stream still exists: implicit member destruction
  // would only run after the cudaStreamDestroy below.
  cudnn_holder_.reset();
  eigen_stream_.reset();
  eigen_device_.reset();
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
206
// The CUDA place (device) this context was created for.
Place CUDADeviceContext::GetPlace() const {
  return this->place_;
}
207

L
liaogang 已提交
208
void CUDADeviceContext::Wait() const {
Q
init  
qijun 已提交
209
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
210 211 212
  PADDLE_ENFORCE(cudaGetLastError());
}

K
Kexin Zhao 已提交
213 214 215 216
int CUDADeviceContext::GetComputeCapability() const {
  return compute_capability;
}

217 218 219 220
// Upper bound on concurrently resident threads: multiprocessor count times
// the per-multiprocessor thread limit (both cached at construction).
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  const auto resident_threads = multi_process * max_threads_per_mp;
  return resident_threads;
}

221 222 223 224
// Non-owning access to the Eigen GPU device bound to this context's stream.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  Eigen::GpuDevice* const device = eigen_device_.get();
  return device;
}

225
// cuBLAS handle created in the constructor and bound to this context's stream.
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return this->cublas_handle_;
}

229 230 231 232 233 234 235
// cuDNN handle bound to this context's stream. Returns nullptr when the cuDNN
// library was not available at construction (cudnn_holder_ is then null),
// instead of dereferencing a null holder.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  if (cudnn_holder_ == nullptr) {
    return nullptr;
  }
  return cudnn_holder_->get_cudnn_handle();
}

void* cudnn_workspace(size_t required_len) const {
  return cudnn_holder_->get_workspace(required_len);
}
236

237
// The CUDA stream owned by this context (created in the constructor).
cudaStream_t CUDADeviceContext::stream() const {
  return this->stream_;
}
Q
qijun 已提交
238

C
chengduoZH 已提交
239 240 241 242 243 244 245 246 247 248 249 250 251 252
// Default pinned-memory context. `place_` is left default-constructed; Eigen
// expressions run on the host through a fresh Eigen::DefaultDevice.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
  this->eigen_device_.reset(new Eigen::DefaultDevice());
}

// Pinned-memory context bound to `place`. It reuses the default constructor
// to create the Eigen::DefaultDevice, then records the requested place.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : CUDAPinnedDeviceContext() {
  place_ = place;
}

// Non-owning access to the Eigen host device held by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}

// The pinned-memory place this context was created for.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return this->place_;
}
L
Luo Tao 已提交
253
#endif
Q
qijun 已提交
254

T
tensor-tang 已提交
255 256
#ifdef PADDLE_WITH_MKLDNN
// CPU context extended with an MKL-DNN CPU engine (index 0) and an initially
// empty name -> blob map used by SetBlob/GetBlob.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::cpu, 0),
      p_blobs_(new std::unordered_map<std::string, std::shared_ptr<void>>()) {}

261 262 263 264
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
265

266
  auto it = p->find(name);
T
tensor-tang 已提交
267

268 269 270 271 272
  if (it == p->end()) {
    (*p)[name] = data;  // create new blob
  } else {
    it->second = data;  // set data to existing blob
  }
T
tensor-tang 已提交
273

274
  return;
T
tensor-tang 已提交
275 276
}

277 278 279 280
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
281

282
  auto it = p->find(name);
T
tensor-tang 已提交
283

284 285
  if (it != p->end()) {
    return it->second;
T
tensor-tang 已提交
286
  }
287 288

  return nullptr;
T
tensor-tang 已提交
289 290 291 292
}

#endif

Q
qijun 已提交
293
}  // namespace platform
Q
qijun 已提交
294
}  // namespace paddle