device_context.h 4.4 KB
Newer Older
Q
QI JUN 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dzhwinter 已提交
14 15
#include <memory>
#include <unordered_map>
#include <vector>
L
liaogang 已提交
16

17
#ifdef PADDLE_WITH_CUDA
Q
QI JUN 已提交
18 19
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
L
liaogang 已提交
20
#include "paddle/platform/gpu_info.h"
Q
QI JUN 已提交
21 22
#define EIGEN_USE_GPU
#endif
D
dzhwinter 已提交
23 24

#include "paddle/platform/enforce.h"
Q
qijun 已提交
25 26
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
27

D
dzhwinter 已提交
28 29
#include "glog/logging.h"

Q
QI JUN 已提交
30 31 32 33 34 35
namespace paddle {
namespace platform {

class DeviceContext {
 public:
  virtual ~DeviceContext() {}
L
liaogang 已提交
36
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
37

38
  virtual void Wait() const {}
Q
QI JUN 已提交
39 40
};

Q
qijun 已提交
41 42
class CPUDeviceContext : public DeviceContext {
 public:
43
  CPUDeviceContext();
Q
qijun 已提交
44
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
45

46
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
47

L
liaogang 已提交
48
  Place GetPlace() const override;
Y
Yu Yang 已提交
49

Q
qijun 已提交
50
 private:
D
dzhwinter 已提交
51
  CPUPlace place_;
Q
qijun 已提交
52
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
53 54
};

Y
Yang Yu 已提交
55 56 57 58 59 60 61 62
/*! \brief Trait mapping a Place type to its default DeviceContext type,
 *  exposed as the nested alias TYPE.
 *
 *  The primary template is only declared, never defined. Naming it with a
 *  Place type that has no specialization below is therefore a compile error.
 */
template <typename Place>
struct DefaultDeviceContextType;

/*! \brief CPUPlace is served by CPUDeviceContext. */
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

63
#ifdef PADDLE_WITH_CUDA
64

Q
qijun 已提交
65
class EigenCudaStreamDevice;
D
dongzhihong 已提交
66

67
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
68
 public:
D
dzhwinter 已提交
69
  explicit CUDADeviceContext(CUDAPlace place);
70
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
71

72
  /*! \brief  Wait for all operations completion in the stream. */
73
  void Wait() const override;
Q
QI JUN 已提交
74

75
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
76
  Place GetPlace() const override;
77 78 79 80 81

  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

  /*! \brief  Return cublas handle in the device context. */
82
  cublasHandle_t cublas_handle() const;
83 84

  /*! \brief  Return cudnn  handle in the device context. */
85
  cudnnHandle_t cudnn_handle() const;
86

Q
init  
qijun 已提交
87
  /*! \brief  Return cuda stream in the device context. */
88
  cudaStream_t stream() const;
Q
QI JUN 已提交
89 90

 private:
D
dzhwinter 已提交
91
  CUDAPlace place_;
Q
QI JUN 已提交
92

Q
qijun 已提交
93
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
Q
init  
qijun 已提交
94
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
Q
QI JUN 已提交
95

96 97 98
  cudaStream_t stream_;
  cudnnHandle_t cudnn_handle_;
  cublasHandle_t cublas_handle_;
Q
QI JUN 已提交
99
};
Q
qijun 已提交
100

Y
Yang Yu 已提交
101 102
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
103
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
104 105
};

106
class CUDNNDeviceContext : public CUDADeviceContext {
D
dzhwinter 已提交
107
 public:
Q
QI JUN 已提交
108
  explicit CUDNNDeviceContext(CUDAPlace place);
109
  virtual ~CUDNNDeviceContext();
D
dzhwinter 已提交
110 111 112 113 114 115 116 117

  /*! \brief  Return cudnn  handle in the device context. */
  cudnnHandle_t cudnn_handle() const;

 private:
  cudnnHandle_t cudnn_handle_;
};

Q
QI JUN 已提交
118
#endif
Q
qijun 已提交
119

D
dzhwinter 已提交
120 121 122 123 124
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
125
  static DeviceContextPool& Instance() {
D
dzhwinter 已提交
126 127 128 129 130
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
131
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
132 133 134 135 136 137 138
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

  /*! \brief  Return handle of single device context. */
Y
Yang Yu 已提交
139
  const platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
140

Y
Yang Yu 已提交
141 142 143 144 145 146 147
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

D
dzhwinter 已提交
148 149
 private:
  static DeviceContextPool* pool;
D
dzhwinter 已提交
150
  constexpr static int LEFT_SHIFT = 8;
D
dzhwinter 已提交
151 152 153
  struct Hash {
    std::hash<int> hash_;
    size_t operator()(const platform::Place& place) const {
D
dzhwinter 已提交
154
      int pre_hash = place.which() + (1 << LEFT_SHIFT);
D
dzhwinter 已提交
155
      if (platform::is_gpu_place(place)) {
D
dzhwinter 已提交
156
        pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
D
dzhwinter 已提交
157 158 159 160 161 162 163 164 165 166
      }
      return hash_(pre_hash);
    }
  };
  std::unordered_map<const platform::Place, const platform::DeviceContext*,
                     Hash>
      device_contexts_;
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
167 168
}  // namespace platform
}  // namespace paddle