/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
// Must be defined before the Eigen Tensor header below is included.
#define EIGEN_USE_GPU
#endif

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

#include "glog/logging.h"

namespace paddle {
namespace platform {

class DeviceContext {
 public:
  virtual ~DeviceContext() {}
L
liaogang 已提交
40
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
41

42
  virtual void Wait() const {}
Q
QI JUN 已提交
43 44
};

Q
qijun 已提交
45 46
class CPUDeviceContext : public DeviceContext {
 public:
47
  CPUDeviceContext();
Q
qijun 已提交
48
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
49

50
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
51

L
liaogang 已提交
52
  Place GetPlace() const override;
Y
Yu Yang 已提交
53

Q
qijun 已提交
54
 private:
D
dzhwinter 已提交
55
  CPUPlace place_;
Q
qijun 已提交
56
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
57 58
};

Y
Yang Yu 已提交
59 60 61 62 63 64 65 66
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

#ifdef PADDLE_WITH_CUDA

// Forward declaration only; the full definition is not in this header.
// CUDADeviceContext holds one through std::unique_ptr (eigen_stream_).
class EigenCudaStreamDevice;

71
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
72
 public:
D
dzhwinter 已提交
73
  explicit CUDADeviceContext(CUDAPlace place);
74
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
75

76
  /*! \brief  Wait for all operations completion in the stream. */
77
  void Wait() const override;
Q
QI JUN 已提交
78

79
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
80
  Place GetPlace() const override;
81 82 83 84 85

  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

  /*! \brief  Return cublas handle in the device context. */
86
  cublasHandle_t cublas_handle() const;
87 88

  /*! \brief  Return cudnn  handle in the device context. */
89
  cudnnHandle_t cudnn_handle() const;
90

Q
init  
qijun 已提交
91
  /*! \brief  Return cuda stream in the device context. */
92
  cudaStream_t stream() const;
Q
QI JUN 已提交
93 94

 private:
D
dzhwinter 已提交
95
  CUDAPlace place_;
Q
QI JUN 已提交
96

Q
qijun 已提交
97
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
Q
init  
qijun 已提交
98
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
Q
QI JUN 已提交
99

100 101 102
  cudaStream_t stream_;
  cudnnHandle_t cudnn_handle_;
  cublasHandle_t cublas_handle_;
Q
QI JUN 已提交
103
};
Q
qijun 已提交
104

Y
Yang Yu 已提交
105 106
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
107
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
108 109
};

Q
QI JUN 已提交
110
#endif
Q
qijun 已提交
111

#ifdef PADDLE_WITH_MKLDNN
/*! \brief CPU device context extended with MKL-DNN state: an engine, a
 *  stream, a pipeline of submitted primitives, and string-keyed pools of
 *  memory objects, primitives and primitive descriptors. */
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  explicit MKLDNNDeviceContext(CPUPlace place);

  /*! \brief  Add new element: memory, primitive or primitive desc */
  template <typename T>
  void AddElement(const std::string& op_key, const T& value);

  /*! \brief  Get existing element: memory, primitive or primitive desc */
  template <typename T>
  const T& GetElement(const std::string& op_key) const;

  /*! \brief  Get element pool: memory, primitive or primitive desc pool */
  template <typename T>
  const std::unordered_map<const std::string, const T, std::hash<std::string>>&
  GetElementPool() const;

  /*! \brief  Get the active engine */
  const MKLDNNEngine& engine() const { return *engine_; }

  /*! \brief  Submit primitive to pipeline (stores a copy of *p) */
  void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }

  /*! \brief  Execute all submitted primitives in pipeline */
  void Execute(bool block = true);

 protected:
  /*! \brief  Reset the stream to prepare the next execution */
  void ResetStream();

 private:
  // Keys are of type `const std::string`, for which std::hash has no
  // specialization, so the std::string hasher is named explicitly.
  std::unordered_map<const std::string, const MKLDNNMemoryPtr,
                     std::hash<std::string>>
      memory_pool_;
  std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
                     std::hash<std::string>>
      primitive_pool_;
  std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
                     std::hash<std::string>>
      primitive_desc_pool_;
  std::vector<MKLDNNPrimitive> pipeline_;
  MKLDNNStreamPtr stream_;
  MKLDNNEnginePtr engine_;
  // Explicit default so the flag is never read uninitialized; a constructor
  // initializer, if present, still takes precedence.
  bool ready_ = false;
};
#endif  // PADDLE_WITH_MKLDNN

D
dzhwinter 已提交
160 161 162 163 164
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
165
  static DeviceContextPool& Instance() {
D
dzhwinter 已提交
166 167 168 169 170
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
171
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
172 173 174 175 176 177 178
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

  /*! \brief  Return handle of single device context. */
Y
Yang Yu 已提交
179
  const platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
180

Y
Yang Yu 已提交
181 182 183 184 185 186 187
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

188 189
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
190 191
 private:
  static DeviceContextPool* pool;
D
dzhwinter 已提交
192
  constexpr static int LEFT_SHIFT = 8;
D
dzhwinter 已提交
193 194 195
  struct Hash {
    std::hash<int> hash_;
    size_t operator()(const platform::Place& place) const {
D
dzhwinter 已提交
196
      int pre_hash = place.which() << LEFT_SHIFT;
D
dzhwinter 已提交
197
      if (platform::is_gpu_place(place)) {
D
dzhwinter 已提交
198
        pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
D
dzhwinter 已提交
199 200 201 202 203 204 205 206 207 208
      }
      return hash_(pre_hash);
    }
  };
  std::unordered_map<const platform::Place, const platform::DeviceContext*,
                     Hash>
      device_contexts_;
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
209 210
}  // namespace platform
}  // namespace paddle