/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <functional>
#include <iostream>
#include <vector>

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"

namespace paddle {
namespace platform {

/// Host (CPU) memory place. It carries no state, so any two CPUPlace
/// objects are equal and neither orders before the other.
struct CPUPlace {
  // WORKAROUND: for some reason, omitting this constructor
  // causes errors with boost 1.59 and OSX
  CPUPlace() {}

  // Comparison operators; needed for variant equality comparison.
  bool operator==(const CPUPlace &) const { return true; }
  bool operator!=(const CPUPlace &rhs) const { return !operator==(rhs); }
  bool operator<(const CPUPlace &) const { return false; }
};

/// A CUDA/HIP GPU device place, identified by an integer device id.
/// Equality and ordering are by device id.
struct CUDAPlace {
  /// Defaults to device 0.
  CUDAPlace() : CUDAPlace(0) {}
  explicit CUDAPlace(int device_id) : device(device_id) {}

  int GetDeviceId() const { return device; }

  // Comparison operators; needed for variant equality comparison.
  bool operator==(const CUDAPlace &rhs) const {
    return GetDeviceId() == rhs.GetDeviceId();
  }
  bool operator!=(const CUDAPlace &rhs) const {
    return GetDeviceId() != rhs.GetDeviceId();
  }
  bool operator<(const CUDAPlace &rhs) const {
    return GetDeviceId() < rhs.GetDeviceId();
  }

  int device;
};

/// Place for CUDA pinned (page-locked) host memory. It carries no state,
/// so any two CUDAPinnedPlace objects are equal and unordered.
struct CUDAPinnedPlace {
  CUDAPinnedPlace() {}

  // Comparison operators; needed for variant equality comparison.
  bool operator==(const CUDAPinnedPlace &) const { return true; }
  bool operator!=(const CUDAPinnedPlace &rhs) const {
    return !operator==(rhs);
  }
  bool operator<(const CUDAPinnedPlace &) const { return false; }
};

// Place for Baidu Kunlun Accelerator
/// An XPU device place, identified by an integer device id.
/// Equality and ordering are by device id.
struct XPUPlace {
  /// Defaults to device 0.
  XPUPlace() : XPUPlace(0) {}
  explicit XPUPlace(int device_id) : device(device_id) {}

  int GetDeviceId() const { return device; }

  // Comparison operators; needed for variant equality comparison.
  bool operator==(const XPUPlace &rhs) const {
    return GetDeviceId() == rhs.GetDeviceId();
  }
  bool operator!=(const XPUPlace &rhs) const {
    return GetDeviceId() != rhs.GetDeviceId();
  }
  bool operator<(const XPUPlace &rhs) const {
    return GetDeviceId() < rhs.GetDeviceId();
  }

  int device;
};

/// An NPU device place, identified by an integer device id.
/// Equality and ordering are by device id.
struct NPUPlace {
  /// Defaults to device 0.
  NPUPlace() : NPUPlace(0) {}
  explicit NPUPlace(int device_id) : device(device_id) {}

  int GetDeviceId() const { return device; }

  // Comparison operators; needed for variant equality comparison.
  bool operator==(const NPUPlace &rhs) const {
    return GetDeviceId() == rhs.GetDeviceId();
  }
  bool operator!=(const NPUPlace &rhs) const {
    return GetDeviceId() != rhs.GetDeviceId();
  }
  bool operator<(const NPUPlace &rhs) const {
    return GetDeviceId() < rhs.GetDeviceId();
  }

  int device;
};

/// Place for NPU pinned host memory. It carries no state, so any two
/// NPUPinnedPlace objects are equal and unordered.
struct NPUPinnedPlace {
  NPUPinnedPlace() {}

  bool operator==(const NPUPinnedPlace &) const { return true; }
  bool operator!=(const NPUPinnedPlace &rhs) const {
    return !operator==(rhs);
  }
  bool operator<(const NPUPinnedPlace &) const { return false; }
};

/// boost visitor that yields true only for a CUDAPlace.
struct IsCUDAPlace : public boost::static_visitor<bool> {
  // The one matching alternative.
  bool operator()(const CUDAPlace &) const { return true; }

  // Every other alternative.
  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
};

/// boost visitor that yields true only for a CPUPlace.
struct IsCPUPlace : public boost::static_visitor<bool> {
  // The one matching alternative.
  bool operator()(const CPUPlace &) const { return true; }

  // Every other alternative.
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
};

/// boost visitor that yields true only for a CUDAPinnedPlace.
struct IsCUDAPinnedPlace : public boost::static_visitor<bool> {
  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  // The parameter is unused, so it is left unnamed like in every other
  // overload. Naming it triggered -Wunused-parameter.
  bool operator()(const CUDAPinnedPlace &) const { return true; }
};

/// boost visitor that yields true only for an XPUPlace.
struct IsXPUPlace : public boost::static_visitor<bool> {
  // The one matching alternative.
  bool operator()(const XPUPlace &) const { return true; }

  // Every other alternative.
  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
};

/// boost visitor that yields true only for an NPUPlace.
struct IsNPUPlace : public boost::static_visitor<bool> {
  // The one matching alternative.
  bool operator()(const NPUPlace &) const { return true; }

  // Every other alternative.
  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
};

/// boost visitor that yields true only for an NPUPinnedPlace.
struct IsNPUPinnedPlace : public boost::static_visitor<bool> {
  // The one matching alternative.
  bool operator()(const NPUPinnedPlace &) const { return true; }

  // Every other alternative.
  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
};

/// A place that holds exactly one of the concrete place types declared
/// above, stored as a boost::variant.
///
/// The single-argument constructors are deliberately not `explicit`, so a
/// concrete place can be passed wherever a Place is expected. The NOLINT
/// markers presumably silence the explicit-constructor lint for that.
class Place : public boost::variant<CUDAPlace, XPUPlace, NPUPlace, CPUPlace,
                                    CUDAPinnedPlace, NPUPinnedPlace> {
 private:
  using PlaceBase = boost::variant<CUDAPlace, XPUPlace, NPUPlace, CPUPlace,
                                   CUDAPinnedPlace, NPUPinnedPlace>;

 public:
  Place() = default;
  Place(const CPUPlace &cpu_place) : PlaceBase(cpu_place) {}     // NOLINT
  Place(const XPUPlace &xpu_place) : PlaceBase(xpu_place) {}     // NOLINT
  Place(const NPUPlace &npu_place) : PlaceBase(npu_place) {}     // NOLINT
  Place(const CUDAPlace &cuda_place) : PlaceBase(cuda_place) {}  // NOLINT
  Place(const CUDAPinnedPlace &cuda_pinned_place)                // NOLINT
      : PlaceBase(cuda_pinned_place) {}
  Place(const NPUPinnedPlace &npu_pinned_place)  // NOLINT
      : PlaceBase(npu_pinned_place) {}

  // Ordering and equality forward explicitly to the PlaceBase
  // (boost::variant) operators, applied to the base-class subobjects.
  bool operator<(const Place &place) const {
    return PlaceBase::operator<(static_cast<const PlaceBase &>(place));
  }
  bool operator==(const Place &place) const {
    return PlaceBase::operator==(static_cast<const PlaceBase &>(place));
  }
  // Every concrete place type provides operator!=. Provide it for Place too,
  // defined as the negation of operator== so the two always agree.
  bool operator!=(const Place &place) const { return !(*this == place); }
};

using PlaceList = std::vector<Place>;

Y
Yi Wang 已提交
177
bool is_gpu_place(const Place &);
178
bool is_xpu_place(const Place &);
179
bool is_npu_place(const Place &);
Y
Yi Wang 已提交
180
bool is_cpu_place(const Place &);
C
chengduoZH 已提交
181
bool is_cuda_pinned_place(const Place &);
182
bool is_npu_pinned_place(const Place &);
Y
Yi Wang 已提交
183
bool places_are_same_class(const Place &, const Place &);
184
bool is_same_place(const Place &, const Place &);
Y
Yi Wang 已提交
185

Y
Yi Wang 已提交
186
std::ostream &operator<<(std::ostream &, const Place &);

/// Adapter that lets a user-supplied visitor be applied to a Place with
/// boost::apply_visitor.
///
/// `Visitor` must provide a `result_type` typedef and a const call operator
/// for every place type whose backend is enabled in this build. For XPU,
/// NPU, NPU-pinned, CUDA and CUDA-pinned places, the forwarding call is only
/// compiled when the matching PADDLE_WITH_* macro is defined. Otherwise the
/// wrapper raises an Unavailable error, so `Visitor` needs no overload for a
/// backend that is compiled out.
template <typename Visitor>
struct PlaceVisitorWrapper
    : public boost::static_visitor<typename Visitor::result_type> {
  // Non-owning reference; the wrapper must not outlive `visitor`.
  const Visitor &visitor_;
  explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {}

  // CPU places have no build guard and are always forwarded.
  typename Visitor::result_type operator()(const CPUPlace &cpu) const {
    return visitor_(cpu);
  }

  typename Visitor::result_type operator()(const XPUPlace &xpu) const {
#ifdef PADDLE_WITH_XPU
    return visitor_(xpu);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with XPU. Cannot visit xpu device"));
    // Not expected to be reached (PADDLE_THROW presumably always throws).
    // It keeps every path of this non-void function returning a value.
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(const NPUPlace &npu) const {
    // NPUPinnedPlace below is enabled by PADDLE_WITH_ASCEND_CL, so the NPU
    // device place must be visitable in those builds as well. Checking only
    // PADDLE_WITH_ASCEND made ASCEND_CL-only builds reject every NPUPlace
    // while still accepting NPU pinned memory.
#if defined(PADDLE_WITH_ASCEND) || defined(PADDLE_WITH_ASCEND_CL)
    return visitor_(npu);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with NPU. Cannot visit npu device"));
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(
      const NPUPinnedPlace &npu_pinned) const {
#ifdef PADDLE_WITH_ASCEND_CL
    return visitor_(npu_pinned);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(const CUDAPlace &cuda) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    return visitor_(cuda);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with CUDA. Cannot visit cuda device"));
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(
      const CUDAPinnedPlace &cuda_pinned) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    return visitor_(cuda_pinned);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
    return typename Visitor::result_type();
#endif
  }
};

template <typename Visitor>
typename Visitor::result_type VisitPlace(const Place &place,
                                         const Visitor &visitor) {
  return boost::apply_visitor(PlaceVisitorWrapper<Visitor>(visitor), place);
}

}  // namespace platform
}  // namespace paddle