device_context.cc 6.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
12 13
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/memory/memory.h"
Q
qijun 已提交
14 15 16 17

namespace paddle {
namespace platform {

D
dzhwinter 已提交
18 19
// Definition of the static DeviceContextPool::pool pointer; it starts out
// null. Nothing in this file assigns it.
DeviceContextPool* DeviceContextPool::pool = nullptr;

Y
Yu Yang 已提交
20
// Returns the device context stored in device_contexts_ for `place`.
//
// Raises via PADDLE_THROW when the pool holds no entry for `place`.
const platform::DeviceContext* DeviceContextPool::Get(
    const platform::Place& place) {
  const auto entry = device_contexts_.find(place);
  const bool registered = (entry != device_contexts_.end());
  if (!registered) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  return entry->second;
}

// Builds the pool: one device context per distinct entry of `places`.
//
// `places` must be non-empty (checked with PADDLE_ENFORCE_GT).
//  - CPU places get an MKLDNNDeviceContext when built with
//    PADDLE_WITH_MKLDNN, otherwise a CPUDeviceContext.
//  - GPU places get a CUDADeviceContext when built with PADDLE_WITH_CUDA;
//    without CUDA support the constructor raises via PADDLE_THROW.
//  - Places that are neither CPU nor GPU are ignored.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  for (const auto& place : places) {
    // A place may be listed more than once. emplace() keeps an existing
    // mapping, so the context allocated with `new` for a repeated place would
    // never be stored and would leak (assumes device_contexts_ maps to raw
    // pointers, as Get() returns it->second directly). Skip repeats up front.
    if (device_contexts_.find(place) != device_contexts_.end()) {
      continue;
    }
    if (platform::is_cpu_place(place)) {
#ifdef PADDLE_WITH_MKLDNN
      device_contexts_.emplace(place,
                               new platform::MKLDNNDeviceContext(
                                   boost::get<platform::CPUPlace>(place)));
#else
      device_contexts_.emplace(place,
                               new platform::CPUDeviceContext(
                                   boost::get<platform::CPUPlace>(place)));
#endif
    } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(place,
                               new platform::CUDADeviceContext(
                                   boost::get<platform::CUDAPlace>(place)));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    }
  }
}

59 60 61 62
// Default constructor: allocates a fresh Eigen::DefaultDevice and stores it
// in eigen_device_. place_ is not set explicitly here.
CPUDeviceContext::CPUDeviceContext() {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

D
dzhwinter 已提交
63
// Constructs a CPU context bound to `place`: records the place in place_ and
// allocates a fresh Eigen::DefaultDevice into eigen_device_.
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice());
}

// Returns the raw pointer held by eigen_device_ (the Eigen::DefaultDevice
// created in the constructors above).
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  return eigen_device_.get();
}

D
dzhwinter 已提交
71
// Returns the CPU place stored in place_, converted to the generic Place type.
Place CPUDeviceContext::GetPlace() const { return place_; }
72

73
#ifdef PADDLE_WITH_CUDA
74

Q
init  
qijun 已提交
75 76 77 78 79 80 81
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
82
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
83 84 85 86 87 88 89 90 91 92 93 94
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
Q
qijun 已提交
95
    return paddle::memory::Alloc(place_, num_bytes);
Q
init  
qijun 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
  }

  void deallocate(void* buffer) const override {
    paddle::memory::Free(place_, buffer);
  }

  void* scratchpad() const override {
    if (scratch_ == NULL) {
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
      PADDLE_ENFORCE(
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
121
  CUDAPlace place_;
Q
init  
qijun 已提交
122 123
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
124
  mutable void* scratch_;
Q
init  
qijun 已提交
125 126 127
  mutable unsigned int* semaphore_;
};

D
dzhwinter 已提交
128
// Creates the per-device CUDA resources for `place`, in this order:
//   1. selects the device and caches its multiprocessor count and
//      max-threads-per-multiprocessor value;
//   2. creates stream_;
//   3. builds the Eigen stream wrapper and the Eigen::GpuDevice on top of it;
//   4. creates the cuBLAS handle and attaches it to stream_;
//   5. creates the cuDNN handle and attaches it to stream_ when
//      dynload::HasCUDNN() reports true; otherwise cudnn_handle_ is null.
// Every CUDA / cuBLAS / cuDNN call is checked with PADDLE_ENFORCE.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  SetDeviceId(place_.device);
  multi_process = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);

  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // The Eigen device keeps a raw pointer to the stream wrapper, so the wrapper
  // is created and bound to stream_ first.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  if (!dynload::HasCUDNN()) {
    cudnn_handle_ = nullptr;
  } else {
    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
  }
}

// Tears down the resources created by the constructor: selects the device,
// waits for stream_ to drain, destroys the cuBLAS handle and (if present) the
// cuDNN handle, releases the Eigen wrappers, and finally destroys stream_.
//
// NOTE(review): PADDLE_ENFORCE is used inside a destructor; if it raises an
// exception on failure (as PADDLE_THROW appears to), that would escape the
// destructor -- confirm this is acceptable.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();

  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
  if (cudnn_handle_ != nullptr) {
    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
  }

  eigen_stream_.reset();
  eigen_device_.reset();

  // The stream goes last, after everything that was attached to it.
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

L
liaogang 已提交
158
// Returns the CUDA place stored in place_, converted to the generic Place
// type.
Place CUDADeviceContext::GetPlace() const { return place_; }
159

L
liaogang 已提交
160
void CUDADeviceContext::Wait() const {
Q
init  
qijun 已提交
161
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
162 163 164
  PADDLE_ENFORCE(cudaGetLastError());
}

165 166 167 168
// Returns multi_process * max_threads_per_mp, i.e. the product of the two
// per-device values cached by the constructor (multiprocessor count and max
// threads per multiprocessor).
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
  return multi_process * max_threads_per_mp;
}

169 170 171 172
// Returns the raw pointer held by eigen_device_ (the Eigen::GpuDevice created
// in the constructor).
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  return eigen_device_.get();
}

173
// Returns the cuBLAS handle created in the constructor.
cublasHandle_t CUDADeviceContext::cublas_handle() const {
  return cublas_handle_;
}

177
// Returns the cuDNN handle; it is nullptr when the constructor found that
// dynload::HasCUDNN() was false.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
178

179
// Returns the CUDA stream created in the constructor.
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
Q
qijun 已提交
180

L
Luo Tao 已提交
181
#endif
Q
qijun 已提交
182

T
tensor-tang 已提交
183 184
#ifdef PADDLE_WITH_MKLDNN
// Constructs an MKLDNN context on top of a CPUDeviceContext for `place`:
// creates the MKLDNN CPU engine with index 0 and an empty name -> blob map.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
  using BlobMap = std::unordered_map<std::string, std::shared_ptr<void>>;
  p_blobs_.reset(new BlobMap());
}

189 190 191 192
void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                  std::shared_ptr<void> data) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
193

194
  auto it = p->find(name);
T
tensor-tang 已提交
195

196 197 198 199 200
  if (it == p->end()) {
    (*p)[name] = data;  // create new blob
  } else {
    it->second = data;  // set data to existing blob
  }
T
tensor-tang 已提交
201

202
  return;
T
tensor-tang 已提交
203 204
}

205 206 207 208
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  std::unordered_map<std::string, std::shared_ptr<void>>* p;
  p = p_blobs_.get();
T
tensor-tang 已提交
209

210
  auto it = p->find(name);
T
tensor-tang 已提交
211

212 213
  if (it != p->end()) {
    return it->second;
T
tensor-tang 已提交
214
  }
215 216

  return nullptr;
T
tensor-tang 已提交
217 218 219 220
}

#endif

Q
qijun 已提交
221
}  // namespace platform
Q
qijun 已提交
222
}  // namespace paddle