/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <string>

19 20
#include "paddle/phi/api/include/dll_decl.h"

21
namespace phi {
22 23

/// \brief Tag describing which kind of allocation a Place refers to.
///
/// The underlying type is int8_t and every enumerator carries an explicit
/// value, so the numeric encoding does not depend on declaration order.
enum class AllocationType : int8_t {
  UNDEFINED = 0,  // Tag of a default-constructed Place.
  CPU = 1,        // Used by CPUPlace.
  GPU = 2,        // Used by GPUPlace.
  GPUPINNED = 3,  // Used by GPUPinnedPlace.
  XPU = 4,        // Used by XPUPlace.
  NPU = 5,        // Used by NPUPlace.
  NPUPINNED = 6,  // Used by NPUPinnedPlace.
  IPU = 7,        // Used by IPUPlace.
  MLU = 8,        // Used by MLUPlace.
  CUSTOM = 9,     // Used by CustomPlace.
};
35

36
// Returns a C string describing `type`. Defined out of line.
// NOTE(review): presumably the enumerator's name, with storage owned by the
// library rather than the caller -- confirm against the definition.
const char* AllocationTypeStr(AllocationType type);
37

38 39 40 41
// Maps a device-type string to a numeric id. Place stores this id instead of
// the string itself. Defined out of line.
// NOTE(review): judging by the name, an unknown `device_type` is registered on
// first use and the same id is returned afterwards -- confirm in the
// definition.
PADDLE_API size_t
GetOrRegisterGlobalDeviceTypeId(const std::string& device_type);

// Maps an id back to a device-type string; used by Place::GetDeviceType().
// NOTE(review): presumably the inverse of GetOrRegisterGlobalDeviceTypeId();
// behaviour for an id that was never handed out is not visible here.
PADDLE_API std::string GetGlobalDeviceType(size_t device_type_id_);
42

43
/// \brief The place is used to specify where the data is stored.
44
class PADDLE_API Place {
45
 public:
46
  Place() : device(0), alloc_type_(AllocationType::UNDEFINED) {}
47

48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
  explicit Place(AllocationType type,
                 int8_t id,
                 const std::string& dev_type = "")
      : device(id),
        alloc_type_(type),
        device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}

  explicit Place(AllocationType type, const std::string& dev_type = "")
      : device(0),
        alloc_type_(type),
        device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}

  void Reset(AllocationType type,
             int8_t device_id = 0,
             const std::string& dev_type = "") noexcept {
63 64
    alloc_type_ = type;
    device = device_id;
65 66 67
    if (!dev_type.empty()) {
      device_type_id_ = GetOrRegisterGlobalDeviceTypeId(dev_type);
    }
68 69
  }

70
  AllocationType GetType() const { return alloc_type_; }
71

72
  int8_t GetDeviceId() const { return device; }
73

74 75 76 77
  std::string GetDeviceType() const {
    return GetGlobalDeviceType(device_type_id_);
  }

78 79
  std::string DebugString() const;

80 81 82 83 84 85 86 87 88
  struct Hash {
    // Note: Now the number of bits we need does not exceed 32 bits, so there is
    // no need to use 64 bits. If needed in the future, it can be expanded,
    // but now we don’t over-design.
    uint32_t operator()(const Place& place) const;
  };

  uint32_t HashValue() const { return Hash()(*this); }

89
  inline bool operator==(const Place& rhs) const {
90 91 92 93
    return HashValue() == rhs.HashValue();
  }
  inline bool operator!=(const Place& rhs) const {
    return HashValue() != rhs.HashValue();
94 95
  }
  inline bool operator<(const Place& rhs) const {
96
    return HashValue() < rhs.HashValue();
97 98
  }

99 100 101
 public:
  // TODO(wilber): Just because of backward compatibility, it needs to be
  // changed to private in the future.
102
  int8_t device{0};
103 104

 private:
105
  AllocationType alloc_type_{AllocationType::UNDEFINED};
106
  size_t device_type_id_;
107 108 109 110
};

/// \brief Place tagged AllocationType::CPU, constructed with device id 0.
class CPUPlace : public Place {
 public:
  CPUPlace() : Place(AllocationType::CPU) {}
  CPUPlace(const CPUPlace&) = default;

  // Implicit conversion from any Place. The argument is not inspected: the
  // result is the same as a default-constructed CPUPlace.
  CPUPlace(const Place&) : CPUPlace() {}  // NOLINT
};

class GPUPlace : public Place {
 public:
  GPUPlace() : Place(AllocationType::GPU, 0) {}
  explicit GPUPlace(int device_id) : Place(AllocationType::GPU, device_id) {}
121 122 123 124

  GPUPlace(const GPUPlace&) = default;
  GPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::GPU, place.GetDeviceId()) {}
125 126 127 128 129
};

/// \brief Place tagged AllocationType::GPUPINNED, constructed with device
/// id 0.
class GPUPinnedPlace : public Place {
 public:
  GPUPinnedPlace() : Place(AllocationType::GPUPINNED) {}
  GPUPinnedPlace(const GPUPinnedPlace&) = default;

  // Implicit conversion from any Place. Nothing is read from the argument;
  // the result equals a default-constructed GPUPinnedPlace.
  GPUPinnedPlace(const Place&) : GPUPinnedPlace() {}  // NOLINT
};

class XPUPlace : public Place {
 public:
  XPUPlace() : Place(AllocationType::XPU, 0) {}
  explicit XPUPlace(int device_id) : Place(AllocationType::XPU, device_id) {}
140 141 142 143

  XPUPlace(const XPUPlace&) = default;
  XPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::XPU, place.GetDeviceId()) {}
144 145 146 147 148
};

class NPUPlace : public Place {
 public:
  NPUPlace() : Place(AllocationType::NPU, 0) {}
149 150 151 152 153
  explicit NPUPlace(int device_id) : Place(AllocationType::NPU, device_id) {}

  NPUPlace(const NPUPlace&) = default;
  NPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::NPU, place.GetDeviceId()) {}
154 155 156 157 158
};

/// \brief Place tagged AllocationType::NPUPINNED, constructed with device
/// id 0.
class NPUPinnedPlace : public Place {
 public:
  NPUPinnedPlace() : Place(AllocationType::NPUPINNED) {}
  NPUPinnedPlace(const NPUPinnedPlace&) = default;

  // Implicit conversion from any Place. Nothing is read from the argument;
  // the result equals a default-constructed NPUPinnedPlace.
  NPUPinnedPlace(const Place&) : NPUPinnedPlace() {}  // NOLINT
};

class IPUPlace : public Place {
 public:
167 168 169 170 171 172
  IPUPlace() : Place(AllocationType::IPU, 0) {}
  explicit IPUPlace(int device_id) : Place(AllocationType::IPU, device_id) {}

  IPUPlace(const IPUPlace&) = default;
  IPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::IPU, place.GetDeviceId()) {}
173 174 175 176 177 178
};

class MLUPlace : public Place {
 public:
  MLUPlace() : Place(AllocationType::MLU, 0) {}
  explicit MLUPlace(int device_id) : Place(AllocationType::MLU, device_id) {}
179 180 181 182

  MLUPlace(const MLUPlace&) = default;
  MLUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::MLU, place.GetDeviceId()) {}
183 184
};

185 186
/// \brief Place tagged AllocationType::CUSTOM, identified by a device-type
/// string plus a device index.
class CustomPlace : public Place {
 public:
  /// Custom place with an empty device type on device 0.
  CustomPlace() : Place(AllocationType::CUSTOM, 0, "") {}

  /// Custom place of device type `dev_type` on device 0.
  /// The string is taken by const reference to avoid copying it on each call.
  explicit CustomPlace(const std::string& dev_type)
      : Place(AllocationType::CUSTOM, 0, dev_type) {}

  /// Custom place of device type `dev_type` on device `device_id`.
  /// The int index is converted to the int8_t that Place stores.
  CustomPlace(const std::string& dev_type, int device_id)
      : Place(AllocationType::CUSTOM, device_id, dev_type) {}

  CustomPlace(const CustomPlace&) = default;

  /// Implicit conversion from any Place.
  ///
  /// When `place` is itself CUSTOM, its device id and device type are copied.
  /// Otherwise nothing is copied and the object keeps the state of a
  /// default-constructed Place, i.e. its type is AllocationType::UNDEFINED
  /// rather than CUSTOM.
  CustomPlace(const Place& place) {  // NOLINT
    if (place.GetType() == AllocationType::CUSTOM) {
      this->Reset(
          AllocationType::CUSTOM, place.GetDeviceId(), place.GetDeviceType());
    }
  }
};

202
// Stream insertion for Place; defined out of line.
// NOTE(review): the exact text written is not visible here -- presumably
// related to Place::DebugString(); confirm in the definition.
std::ostream& operator<<(std::ostream&, const Place&);
203

204
}  // namespace phi
205 206 207 208 209

// Makes the phi place types available under paddle::experimental.
namespace paddle {
namespace experimental {
// Base type and its allocation tag.
using Place = phi::Place;
using AllocationType = phi::AllocationType;

// Concrete places.
using CPUPlace = phi::CPUPlace;
using GPUPlace = phi::GPUPlace;
using GPUPinnedPlace = phi::GPUPinnedPlace;
using XPUPlace = phi::XPUPlace;
using NPUPlace = phi::NPUPlace;
}  // namespace experimental
}  // namespace paddle