/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <unordered_map>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif

#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

#include "glog/logging.h"

namespace paddle {
namespace platform {

class DeviceContext {
 public:
  virtual ~DeviceContext() {}
L
liaogang 已提交
36
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
37

38
  virtual void Wait() const {}
Q
QI JUN 已提交
39 40
};

class CPUDeviceContext : public DeviceContext {
 public:
43
  CPUDeviceContext();
Q
qijun 已提交
44
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
45

46
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
47

L
liaogang 已提交
48
  Place GetPlace() const override;
Y
Yu Yang 已提交
49

Q
qijun 已提交
50
 private:
D
dzhwinter 已提交
51
  CPUPlace place_;
Q
qijun 已提交
52
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
53 54
};

#ifdef PADDLE_WITH_CUDA

class EigenCudaStreamDevice;
D
dongzhihong 已提交
58

59
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
60
 public:
D
dzhwinter 已提交
61
  explicit CUDADeviceContext(CUDAPlace place);
62
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
63

64
  /*! \brief  Wait for all operations completion in the stream. */
65
  void Wait() const override;
Q
QI JUN 已提交
66

67
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
68
  Place GetPlace() const override;
69 70 71 72 73

  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

  /*! \brief  Return cublas handle in the device context. */
74
  cublasHandle_t cublas_handle() const;
75 76

  /*! \brief  Return cudnn  handle in the device context. */
77
  cudnnHandle_t cudnn_handle() const;
78

Q
init  
qijun 已提交
79
  /*! \brief  Return cuda stream in the device context. */
80
  cudaStream_t stream() const;
Q
QI JUN 已提交
81 82

 private:
D
dzhwinter 已提交
83
  CUDAPlace place_;
Q
QI JUN 已提交
84

Q
qijun 已提交
85
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
Q
init  
qijun 已提交
86
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
Q
QI JUN 已提交
87

88 89 90
  cudaStream_t stream_;
  cudnnHandle_t cudnn_handle_;
  cublasHandle_t cublas_handle_;
Q
QI JUN 已提交
91
};

class CUDNNDeviceContext : public CUDADeviceContext {
D
dzhwinter 已提交
94
 public:
Q
QI JUN 已提交
95
  explicit CUDNNDeviceContext(CUDAPlace place);
96
  virtual ~CUDNNDeviceContext();
D
dzhwinter 已提交
97 98 99 100 101 102 103 104

  /*! \brief  Return cudnn  handle in the device context. */
  cudnnHandle_t cudnn_handle() const;

 private:
  cudnnHandle_t cudnn_handle_;
};

#endif

/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

  static DeviceContextPool& Get() {
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
  static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

  /*! \brief  Return handle of single device context. */
  const platform::DeviceContext* Borrow(const platform::Place& place);

  /*! \brief  Return handle of multi-device context. */
  std::vector<const platform::DeviceContext*> Borrow(
      const std::vector<platform::Place>& places);

  ~DeviceContextPool() {}

 private:
  static DeviceContextPool* pool;
D
dzhwinter 已提交
136
  constexpr static int LEFT_SHIFT = 8;
D
dzhwinter 已提交
137 138 139
  struct Hash {
    std::hash<int> hash_;
    size_t operator()(const platform::Place& place) const {
D
dzhwinter 已提交
140
      int pre_hash = place.which() + (1 << LEFT_SHIFT);
D
dzhwinter 已提交
141
      if (platform::is_gpu_place(place)) {
D
dzhwinter 已提交
142
        pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
D
dzhwinter 已提交
143 144 145 146 147 148 149 150 151 152
      }
      return hash_(pre_hash);
    }
  };
  std::unordered_map<const platform::Place, const platform::DeviceContext*,
                     Hash>
      device_contexts_;
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

}  // namespace platform
}  // namespace paddle