/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"

#include <unordered_set>
#include <utility>

#include "paddle/fluid/memory/memory.h"
Q
qijun 已提交
15 16 17
namespace paddle {
namespace platform {

D
dzhwinter 已提交
18 19
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
20
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
D
dzhwinter 已提交
21 22 23 24 25 26
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
Y
Yu Yang 已提交
27
  return it->second.get();
D
dzhwinter 已提交
28 29 30 31 32
}

// Builds one device context per distinct place in `places`.
// CPU places get an MKLDNNDeviceContext when built with MKLDNN, otherwise a
// CPUDeviceContext. GPU places get a CUDADeviceContext when built with CUDA;
// without CUDA support a GPU place is rejected with PADDLE_THROW.
// `places` must be non-empty (enforced).
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  using PtrType = std::unique_ptr<DeviceContext>;

  // Collapse duplicate places so each distinct place gets exactly one context.
  std::unordered_set<Place, PlaceHash> unique_places;
  for (const auto& requested : places) {
    unique_places.insert(requested);
  }

  for (const auto& place : unique_places) {
    if (platform::is_cpu_place(place)) {
#ifdef PADDLE_WITH_MKLDNN
      device_contexts_.emplace(
          place, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(place))));
#else
      device_contexts_.emplace(
          place, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(place))));
#endif
      continue;
    }

    if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          place, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(place))));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    }
  }
}

61 62 63 64
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
65
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
66 67 68 69 70 71 72
  eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
73
Place CPUDeviceContext::GetPlace() const { return place_; }
74

75
#ifdef PADDLE_WITH_CUDA
76

Q
init  
qijun 已提交
77 78 79 80 81 82 83
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
84
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
85 86 87 88 89 90 91 92 93 94 95 96
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Q
qijun 已提交
97
    return paddle::memory::Alloc(place_, num_bytes);
Q
init  
qijun 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
  }

  void deallocate(void* buffer) const override {
    paddle::memory::Free(place_, buffer);
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
123
  CUDAPlace place_;
Q
init  
qijun 已提交
124 125
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
126
  mutable void* scratch_;
Q
init  
qijun 已提交
127 128 129
  mutable unsigned int* semaphore_;
};

D
dzhwinter 已提交
130
// Creates a CUDA context on `place`: makes the device current, caches a few
// device attributes, creates a stream, and binds Eigen, cuBLAS and (when
// available) cuDNN to that stream.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  SetDeviceId(place_.device);

  // Device attributes, queried once and cached on the context.
  compute_capability = GetCUDAComputeCapability(place_.device);
  multi_process = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);

  // The stream that every handle created below is bound to.
  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // Eigen reaches the stream through the EigenCudaStreamDevice adapter.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  // The cuDNN handle stays null when dynload::HasCUDNN() reports false.
  if (!dynload::HasCUDNN()) {
    cudnn_handle_ = nullptr;
  } else {
    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
  }
}

// Tears the context down in reverse dependency order: finish outstanding work,
// destroy the library handles, drop the Eigen wrappers, then the stream.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  // Block until everything queued on stream_ has completed.
  Wait();

  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
  if (cudnn_handle_) {
    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
  }

  // The Eigen objects reference stream_, so release them before destroying it.
  eigen_stream_.reset();
  eigen_device_.reset();
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
161
Place CUDADeviceContext::GetPlace() const { return place_; }
162

L
liaogang 已提交
163
void CUDADeviceContext::Wait() const {
Y
Yu Yang 已提交
164
  std::lock_guard<std::mutex> guard(mutex_);
Q
init  
qijun 已提交
165
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
166 167 168
  PADDLE_ENFORCE(cudaGetLastError());
}

K
Kexin Zhao 已提交
169 170 171 172
int CUDADeviceContext::GetComputeCapability() const {
  return compute_capability;
}

173 174 175 176
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return multi_process * max_threads_per_mp;
}

177 178 179 180
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return eigen_device_.get();
}

181
cublasHandle_t CUDADeviceContext::cublas_handle() const {
182 183 184
  return cublas_handle_;
}

185
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
186

187
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
Q
qijun 已提交
188

L
Luo Tao 已提交
189
#endif
Q
qijun 已提交
190

T
tensor-tang 已提交
191 192
#ifdef PADDLE_WITH_MKLDNN
// CPU context with an MKL-DNN CPU engine (index 0) and an empty, freshly
// allocated name -> blob map.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::cpu, 0),
      p_blobs_(new std::unordered_map<std::string, std::shared_ptr<void>>()) {}

197 198 199 200
// Stores `data` under `name`, creating the entry when it is absent and
// replacing the previous blob otherwise. The method is const because the map
// is reached through p_blobs_. No locking is done here, so concurrent callers
// must synchronize externally.
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
  // operator[] default-constructs a missing entry, so a single lookup covers
  // both the insert and the overwrite case (the old find + branch looked the
  // key up twice on insert). `data` is a by-value sink: move it in rather than
  // paying for another atomic refcount increment/decrement.
  (*p_blobs_)[name] = std::move(data);
}

213 214 215 216
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
217

218
  auto it = p->find(name);
T
tensor-tang 已提交
219

220 221
  if (it != p->end()) {
    return it->second;
T
tensor-tang 已提交
222
  }
223 224

  return nullptr;
T
tensor-tang 已提交
225 226 227 228
}

#endif

Q
qijun 已提交
229
}  // namespace platform
Q
qijun 已提交
230
}  // namespace paddle