/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
// Must be defined before the Eigen Tensor header below is included.
#define EIGEN_USE_GPU
#endif

#ifdef PADDLE_WITH_MKLDNN
#include <mkldnn.hpp>
#endif

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

#include "glog/logging.h"

namespace paddle {
namespace platform {

class DeviceContext {
 public:
  virtual ~DeviceContext() {}
L
liaogang 已提交
40
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
41

42
  virtual void Wait() const {}
Q
QI JUN 已提交
43 44
};

Q
qijun 已提交
45 46
class CPUDeviceContext : public DeviceContext {
 public:
47
  CPUDeviceContext();
Q
qijun 已提交
48
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
49

50
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
51

L
liaogang 已提交
52
  Place GetPlace() const override;
Y
Yu Yang 已提交
53

Q
qijun 已提交
54
 private:
D
dzhwinter 已提交
55
  CPUPlace place_;
Q
qijun 已提交
56
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
57 58
};

Y
Yang Yu 已提交
59 60 61 62 63 64 65 66
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

67
#ifdef PADDLE_WITH_CUDA

class EigenCudaStreamDevice;

/*! \brief Device context for a single CUDA device.
 *
 *  Bundles the CUDA stream, the cuBLAS and cuDNN handles, and the Eigen GPU
 *  device used for work issued through this context. */
class CUDADeviceContext : public DeviceContext {
 public:
  explicit CUDADeviceContext(CUDAPlace place);
  ~CUDADeviceContext() override;

  /*! \brief  Wait for all operations completion in the stream. */
  void Wait() const override;

  /*! \brief  Return place in the device context. */
  Place GetPlace() const override;

  /*! \brief  Return compute capability in the device context. */
  int GetComputeCapability() const;

  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

  /*! \brief  Return cublas handle in the device context. */
  cublasHandle_t cublas_handle() const;

  /*! \brief  Return cudnn handle in the device context. */
  cudnnHandle_t cudnn_handle() const;

  /*! \brief  Return cuda stream in the device context. */
  cudaStream_t stream() const;

 private:
  CUDAPlace place_;

  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;

  // Raw CUDA library handles. Null-initialized so they never hold
  // indeterminate values; presumably assigned by the out-of-line constructor.
  cudaStream_t stream_ = nullptr;
  cudnnHandle_t cudnn_handle_ = nullptr;
  cublasHandle_t cublas_handle_ = nullptr;

  // Device properties backing GetComputeCapability() /
  // GetMaxPhysicalThreadCount(); zero until set (presumably by the
  // constructor — verify in the .cc).
  // NOTE(review): these lack the trailing underscore used by the other
  // members; left as-is because out-of-line definitions likely refer to them.
  int compute_capability = 0;
  int multi_process = 0;
  int max_threads_per_mp = 0;
};

/*! \brief CUDAPlace is served by CUDADeviceContext. */
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
  using TYPE = CUDADeviceContext;
};

#endif

#ifdef PADDLE_WITH_MKLDNN
/*! \brief CPU device context that additionally holds an MKL-DNN engine and a
 *  name -> data "blob" store. */
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  explicit MKLDNNDeviceContext(CPUPlace place);

  /*! \brief  Get the active engine */
  const mkldnn::engine& GetEngine() const { return engine_; }

  /*! \brief  Set data to blob (i.e. name/data pair). Create blob if not
   *  existing.
   *
   *  Declared const even though it modifies the blob store: the store is
   *  reached through the p_blobs_ shared_ptr, whose pointee is not const
   *  inside a const member function. */
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;

  /*! \brief  Find a saved blob. Return nullptr if not found */
  std::shared_ptr<void> GetBlob(const std::string& name) const;

 private:
  mkldnn::engine engine_;
  // Blob store keyed by name; accessed via SetBlob() / GetBlob().
  std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
      p_blobs_;
};
#endif

/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
148
  static DeviceContextPool& Instance() {
D
dzhwinter 已提交
149 150 151 152 153
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
154
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
155 156 157 158 159 160 161
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

  /*! \brief  Return handle of single device context. */
Y
Yang Yu 已提交
162
  const platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
163

Y
Yang Yu 已提交
164 165 166 167 168 169 170
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

171 172
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
173 174
 private:
  static DeviceContextPool* pool;
D
dzhwinter 已提交
175
  constexpr static int LEFT_SHIFT = 8;
D
dzhwinter 已提交
176 177 178
  struct Hash {
    std::hash<int> hash_;
    size_t operator()(const platform::Place& place) const {
D
dzhwinter 已提交
179
      int pre_hash = place.which() << LEFT_SHIFT;
D
dzhwinter 已提交
180
      if (platform::is_gpu_place(place)) {
D
dzhwinter 已提交
181
        pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
D
dzhwinter 已提交
182 183 184 185 186 187 188 189 190 191
      }
      return hash_(pre_hash);
    }
  };
  std::unordered_map<const platform::Place, const platform::DeviceContext*,
                     Hash>
      device_contexts_;
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
192 193
}  // namespace platform
}  // namespace paddle