/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
14
#pragma once
15

S
sneaxiy 已提交
16
#include <functional>
Y
Yi Wang 已提交
17
#include <iostream>
18 19
#include <vector>

Y
Yi Wang 已提交
20 21
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
22 23 24
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#endif
Y
Yi Wang 已提交
25

Y
Yi Wang 已提交
26 27
namespace paddle {
namespace platform {
Y
Yi Wang 已提交
28

29
// The host CPU. CPUPlace carries no state, so all instances are
// interchangeable: they always compare equal and none orders before another.
struct CPUPlace {
  // WORKAROUND: for some reason, omitting this constructor
  // causes errors with boost 1.59 and OSX
  CPUPlace() {}

  // needed for variant equality comparison
  bool operator==(const CPUPlace &) const { return true; }
  bool operator!=(const CPUPlace &rhs) const { return !(*this == rhs); }
  bool operator<(const CPUPlace &) const { return false; }
};

D
dzhwinter 已提交
40 41 42
// A CUDA device identified by its integer device id (0 by default).
// Two CUDAPlaces are equal iff their ids match; they are ordered by id.
struct CUDAPlace {
  CUDAPlace() : device(0) {}
  explicit CUDAPlace(int d) : device(d) {}

  int GetDeviceId() const { return device; }

  // needed for variant equality comparison
  bool operator==(const CUDAPlace &rhs) const { return rhs.device == device; }
  bool operator!=(const CUDAPlace &rhs) const { return rhs.device != device; }
  bool operator<(const CUDAPlace &rhs) const { return device < rhs.device; }

  int device;
};

C
chengduoZH 已提交
55 56 57 58 59 60
// CUDA pinned place. The type carries no state, so every instance compares
// equal to every other one and none orders before another.
struct CUDAPinnedPlace {
  CUDAPinnedPlace() {}

  // needed for variant equality comparison
  bool operator==(const CUDAPinnedPlace &) const { return true; }
  bool operator!=(const CUDAPinnedPlace &rhs) const { return !(*this == rhs); }
  bool operator<(const CUDAPinnedPlace &) const { return false; }
};

64 65 66 67 68 69 70 71 72 73 74 75 76 77
// Place for Baidu Kunlun Accelerator
// Identified by an integer device id (0 by default); equality and ordering
// are both defined purely on that id.
struct XPUPlace {
  XPUPlace() : device(0) {}
  explicit XPUPlace(int d) : device(d) {}

  int GetDeviceId() const { return device; }

  // needed for variant equality comparison
  bool operator==(const XPUPlace &rhs) const { return rhs.device == device; }
  bool operator!=(const XPUPlace &rhs) const { return rhs.device != device; }
  bool operator<(const XPUPlace &rhs) const { return device < rhs.device; }

  int device;
};

78 79 80 81 82 83 84 85 86 87 88 89 90
// An NPU device identified by its integer device id (0 by default).
// Equality and ordering are both defined purely on that id.
struct NPUPlace {
  NPUPlace() : device(0) {}
  explicit NPUPlace(int d) : device(d) {}

  int GetDeviceId() const { return device; }

  // needed for variant equality comparison
  bool operator==(const NPUPlace &rhs) const { return rhs.device == device; }
  bool operator!=(const NPUPlace &rhs) const { return rhs.device != device; }
  bool operator<(const NPUPlace &rhs) const { return device < rhs.device; }

  int device;
};

91 92 93 94 95 96 97
// NPU pinned place. The type carries no state, so every instance compares
// equal to every other one and none orders before another.
struct NPUPinnedPlace {
  NPUPinnedPlace() {}

  // needed for variant equality comparison
  bool operator==(const NPUPinnedPlace &) const { return true; }
  bool operator!=(const NPUPinnedPlace &rhs) const { return !(*this == rhs); }
  bool operator<(const NPUPinnedPlace &) const { return false; }
};
J
jianghaicheng 已提交
98 99 100 101 102 103 104 105 106 107 108 109
// An IPU device identified by its integer device id (0 by default).
// Equality and ordering are both defined purely on that id.
struct IPUPlace {
  IPUPlace() : device(0) {}
  explicit IPUPlace(int d) : device(d) {}

  int GetDeviceId() const { return device; }

  // needed for variant equality comparison
  bool operator==(const IPUPlace &rhs) const { return rhs.device == device; }
  bool operator!=(const IPUPlace &rhs) const { return rhs.device != device; }
  bool operator<(const IPUPlace &rhs) const { return device < rhs.device; }

  int device;
};
110

D
dzhwinter 已提交
111
// Visitor answering whether a Place holds a CUDAPlace. Only the CUDAPlace
// overload yields true; in particular CUDAPinnedPlace yields false.
struct IsCUDAPlace : public boost::static_visitor<bool> {
  bool operator()(const CUDAPlace &) const { return true; }

  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const IPUPlace &) const { return false; }
};

C
chengduoZH 已提交
121
// Visitor answering whether a Place holds a CPUPlace; every other place
// type (including the pinned host places) yields false.
struct IsCPUPlace : public boost::static_visitor<bool> {
  bool operator()(const CPUPlace &) const { return true; }

  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const IPUPlace &) const { return false; }
};

// Visitor answering whether a Place holds a CUDAPinnedPlace; every other
// place type yields false.
struct IsCUDAPinnedPlace : public boost::static_visitor<bool> {
  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const IPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  // Only the argument's type matters, so the parameter is left unnamed like
  // in every other overload (a named, unused parameter triggers
  // -Wunused-parameter).
  bool operator()(const CUDAPinnedPlace &) const { return true; }
};

141 142
// Visitor answering whether a Place holds an XPUPlace; every other place
// type yields false.
struct IsXPUPlace : public boost::static_visitor<bool> {
  bool operator()(const XPUPlace &) const { return true; }

  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const IPUPlace &) const { return false; }
};

// Visitor answering whether a Place holds an NPUPlace; every other place
// type, NPUPinnedPlace included, yields false.
struct IsNPUPlace : public boost::static_visitor<bool> {
  bool operator()(const NPUPlace &) const { return true; }

  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
  bool operator()(const IPUPlace &) const { return false; }
};

// Visitor answering whether a Place holds an NPUPinnedPlace; every other
// place type, NPUPlace included, yields false.
struct IsNPUPinnedPlace : public boost::static_visitor<bool> {
  bool operator()(const NPUPinnedPlace &) const { return true; }

  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const IPUPlace &) const { return false; }
};
// Visitor answering whether a Place holds an IPUPlace; every other place
// type yields false.
struct IsIPUPlace : public boost::static_visitor<bool> {
  bool operator()(const IPUPlace &) const { return true; }

  bool operator()(const CPUPlace &) const { return false; }
  bool operator()(const CUDAPlace &) const { return false; }
  bool operator()(const CUDAPinnedPlace &) const { return false; }
  bool operator()(const XPUPlace &) const { return false; }
  bool operator()(const NPUPlace &) const { return false; }
  bool operator()(const NPUPinnedPlace &) const { return false; }
};

180
// A Place is a boost::variant over every concrete place type. Each concrete
// place converts implicitly to Place, so APIs taking `const Place &` accept
// any of them directly.
//
// A default-constructed Place holds a default-constructed first alternative,
// i.e. CUDAPlace(0), because boost::variant default-constructs its first
// bounded type.
class Place : public boost::variant<CUDAPlace, XPUPlace, NPUPlace, CPUPlace,
                                    CUDAPinnedPlace, NPUPinnedPlace, IPUPlace> {
 private:
  using PlaceBase = boost::variant<CUDAPlace, XPUPlace, NPUPlace, CPUPlace,
                                   CUDAPinnedPlace, NPUPinnedPlace, IPUPlace>;

 public:
  Place() = default;
  Place(const CPUPlace &cpu_place) : PlaceBase(cpu_place) {}     // NOLINT
  Place(const XPUPlace &xpu_place) : PlaceBase(xpu_place) {}     // NOLINT
  Place(const NPUPlace &npu_place) : PlaceBase(npu_place) {}     // NOLINT
  Place(const IPUPlace &ipu_place) : PlaceBase(ipu_place) {}     // NOLINT
  Place(const CUDAPlace &cuda_place) : PlaceBase(cuda_place) {}  // NOLINT
  Place(const CUDAPinnedPlace &cuda_pinned_place)                // NOLINT
      : PlaceBase(cuda_pinned_place) {}
  Place(const NPUPinnedPlace &npu_pinned_place)  // NOLINT
      : PlaceBase(npu_pinned_place) {}

  // Delegates to boost::variant ordering: by alternative index first, then by
  // the held alternative's own operator< when both hold the same type.
  bool operator<(const Place &place) const {
    return PlaceBase::operator<(static_cast<const PlaceBase &>(place));
  }
  // True iff both hold the same alternative type and those values compare
  // equal with that type's operator==.
  bool operator==(const Place &place) const {
    return PlaceBase::operator==(static_cast<const PlaceBase &>(place));
  }
  // Complement of operator==. Every concrete place type defines operator!=,
  // so Place provides it as well instead of relying on the underlying
  // boost::variant, which does not offer operator!= in every Boost version.
  bool operator!=(const Place &place) const { return !(*this == place); }
};
Y
Yi Wang 已提交
205

Y
Yang Yu 已提交
206 207
using PlaceList = std::vector<Place>;

Y
Yi Wang 已提交
208
bool is_gpu_place(const Place &);
209
bool is_xpu_place(const Place &);
210
bool is_npu_place(const Place &);
J
jianghaicheng 已提交
211
bool is_ipu_place(const Place &);
Y
Yi Wang 已提交
212
bool is_cpu_place(const Place &);
C
chengduoZH 已提交
213
bool is_cuda_pinned_place(const Place &);
214
bool is_npu_pinned_place(const Place &);
Y
Yi Wang 已提交
215
bool places_are_same_class(const Place &, const Place &);
216
bool is_same_place(const Place &, const Place &);
Y
Yi Wang 已提交
217

Y
Yi Wang 已提交
218
std::ostream &operator<<(std::ostream &, const Place &);
Y
Yi Wang 已提交
219

Y
Yang Yu 已提交
220 221 222 223 224 225 226 227 228 229
// Adapter that lets an arbitrary `Visitor` be applied to a Place through
// boost::apply_visitor.
//
// For each concrete place type, the wrapper forwards the place to `visitor_`
// only when the build enables that device (checked with the PADDLE_WITH_*
// macros below). Otherwise it raises platform::errors::Unavailable via
// PADDLE_THROW. Because the forwarding call is compiled out in that case,
// `Visitor` only needs to be callable with the place types enabled in the
// current build.
//
// Requirements on Visitor: a nested `result_type`, and callability as
// `visitor(place)` through a const reference.
//
// The wrapper stores a reference to the visitor, so it must not outlive it.
template <typename Visitor>
struct PlaceVisitorWrapper
    : public boost::static_visitor<typename Visitor::result_type> {
  const Visitor &visitor_;
  explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {}

  // CPU places have no build guard and are always forwarded.
  typename Visitor::result_type operator()(const CPUPlace &cpu) const {
    return visitor_(cpu);
  }

  typename Visitor::result_type operator()(const XPUPlace &xpu) const {
#ifdef PADDLE_WITH_XPU
    return visitor_(xpu);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with XPU. Cannot visit xpu device"));
    // PADDLE_THROW is expected to throw; this return only satisfies the
    // compiler's requirement that every path yields a value.
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(const NPUPlace &npu) const {
    // This file guards NPU-specific code with PADDLE_WITH_ASCEND_CL (the
    // enforce_npu.h include and the NPUPinnedPlace overload below). Checking
    // only PADDLE_WITH_ASCEND here made ASCEND_CL-only builds throw "not
    // compiled with NPU" for an NPUPlace. PADDLE_WITH_ASCEND is still accepted
    // so builds that previously visited NPU places keep doing so.
#if defined(PADDLE_WITH_ASCEND) || defined(PADDLE_WITH_ASCEND_CL)
    return visitor_(npu);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with NPU. Cannot visit npu device"));
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(
      const NPUPinnedPlace &npu_pinned) const {
#ifdef PADDLE_WITH_ASCEND_CL
    return visitor_(npu_pinned);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(const IPUPlace &ipu) const {
#ifdef PADDLE_WITH_IPU
    return visitor_(ipu);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with IPU. Cannot visit ipu device"));
    return typename Visitor::result_type();
#endif
  }

  // CUDA and HIP builds both provide CUDAPlace support.
  typename Visitor::result_type operator()(const CUDAPlace &cuda) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    return visitor_(cuda);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with CUDA. Cannot visit cuda device"));
    return typename Visitor::result_type();
#endif
  }

  typename Visitor::result_type operator()(
      const CUDAPinnedPlace &cuda_pinned) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    return visitor_(cuda_pinned);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
    return typename Visitor::result_type();
#endif
  }
};

template <typename Visitor>
typename Visitor::result_type VisitPlace(const Place &place,
                                         const Visitor &visitor) {
  return boost::apply_visitor(PlaceVisitorWrapper<Visitor>(visitor), place);
}

Y
Yi Wang 已提交
298 299
}  // namespace platform
}  // namespace paddle