/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <string>

namespace phi {

/// Identifies the kind of place a Place refers to. The underlying type is
/// int8_t and every enumerator has an explicit value.
enum class AllocationType : int8_t {
  UNDEFINED = 0,
  CPU = 1,
  GPU = 2,
  GPUPINNED = 3,
  XPU = 4,
  NPU = 5,
  NPUPINNED = 6,
  IPU = 7,
  MLU = 8,
  CUSTOM = 9,
};

34
const char* AllocationTypeStr(AllocationType type);
35

36 37 38
size_t GetOrRegisterGlobalDeviceTypeId(const std::string& device_type);
std::string GetGlobalDeviceType(size_t device_type_id_);

39
/// \brief The place is used to specify where the data is stored.
40
class Place {
41
 public:
42
  Place() : device(0), alloc_type_(AllocationType::UNDEFINED) {}
43

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
  explicit Place(AllocationType type,
                 int8_t id,
                 const std::string& dev_type = "")
      : device(id),
        alloc_type_(type),
        device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}

  explicit Place(AllocationType type, const std::string& dev_type = "")
      : device(0),
        alloc_type_(type),
        device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}

  void Reset(AllocationType type,
             int8_t device_id = 0,
             const std::string& dev_type = "") noexcept {
59 60
    alloc_type_ = type;
    device = device_id;
61 62 63
    if (!dev_type.empty()) {
      device_type_id_ = GetOrRegisterGlobalDeviceTypeId(dev_type);
    }
64 65
  }

66
  AllocationType GetType() const { return alloc_type_; }
67

68
  int8_t GetDeviceId() const { return device; }
69

70 71 72 73
  std::string GetDeviceType() const {
    return GetGlobalDeviceType(device_type_id_);
  }

74 75
  std::string DebugString() const;

76 77 78 79 80 81 82 83 84
  struct Hash {
    // Note: Now the number of bits we need does not exceed 32 bits, so there is
    // no need to use 64 bits. If needed in the future, it can be expanded,
    // but now we don’t over-design.
    uint32_t operator()(const Place& place) const;
  };

  uint32_t HashValue() const { return Hash()(*this); }

85
  inline bool operator==(const Place& rhs) const {
86 87 88 89
    return HashValue() == rhs.HashValue();
  }
  inline bool operator!=(const Place& rhs) const {
    return HashValue() != rhs.HashValue();
90 91
  }
  inline bool operator<(const Place& rhs) const {
92
    return HashValue() < rhs.HashValue();
93 94
  }

95 96 97
 public:
  // TODO(wilber): Just because of backward compatibility, it needs to be
  // changed to private in the future.
98
  int8_t device{0};
99 100

 private:
101
  AllocationType alloc_type_{AllocationType::UNDEFINED};
102
  size_t device_type_id_;
103 104 105 106
};

/// Place with allocation type AllocationType::CPU on device 0.
class CPUPlace : public Place {
 public:
  CPUPlace() : Place(AllocationType::CPU) {}

  CPUPlace(const CPUPlace&) = default;
  // Converting constructor: every place converts to the CPU place; none of
  // the source place's fields are used.
  CPUPlace(const Place& /*place*/) : Place(AllocationType::CPU) {}  // NOLINT
};

/// Place with allocation type AllocationType::GPU on a given device id.
class GPUPlace : public Place {
 public:
  GPUPlace() : Place(AllocationType::GPU, 0) {}
  // Place stores the device id as int8_t; the conversion is made explicit.
  explicit GPUPlace(int device_id)
      : Place(AllocationType::GPU, static_cast<int8_t>(device_id)) {}

  GPUPlace(const GPUPlace&) = default;
  // Converting constructor: keeps only the device id of `place`.
  GPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::GPU, place.GetDeviceId()) {}
};

/// Place with allocation type AllocationType::GPUPINNED on device 0.
class GPUPinnedPlace : public Place {
 public:
  GPUPinnedPlace() : Place(AllocationType::GPUPINNED) {}

  GPUPinnedPlace(const GPUPinnedPlace&) = default;
  // Converting constructor: none of the source place's fields are used.
  GPUPinnedPlace(const Place& /*place*/)  // NOLINT
      : Place(AllocationType::GPUPINNED) {}
};

/// Place with allocation type AllocationType::XPU on a given device id.
class XPUPlace : public Place {
 public:
  XPUPlace() : Place(AllocationType::XPU, 0) {}
  // Place stores the device id as int8_t; the conversion is made explicit.
  explicit XPUPlace(int device_id)
      : Place(AllocationType::XPU, static_cast<int8_t>(device_id)) {}

  XPUPlace(const XPUPlace&) = default;
  // Converting constructor: keeps only the device id of `place`.
  XPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::XPU, place.GetDeviceId()) {}
};

/// Place with allocation type AllocationType::NPU on a given device id.
class NPUPlace : public Place {
 public:
  NPUPlace() : Place(AllocationType::NPU, 0) {}
  // Place stores the device id as int8_t; the conversion is made explicit.
  explicit NPUPlace(int device_id)
      : Place(AllocationType::NPU, static_cast<int8_t>(device_id)) {}

  NPUPlace(const NPUPlace&) = default;
  // Converting constructor: keeps only the device id of `place`.
  NPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::NPU, place.GetDeviceId()) {}
};

/// Place with allocation type AllocationType::NPUPINNED on device 0.
class NPUPinnedPlace : public Place {
 public:
  NPUPinnedPlace() : Place(AllocationType::NPUPINNED) {}

  NPUPinnedPlace(const NPUPinnedPlace&) = default;
  // Converting constructor: none of the source place's fields are used.
  NPUPinnedPlace(const Place& /*place*/)  // NOLINT
      : Place(AllocationType::NPUPINNED) {}
};

/// Place with allocation type AllocationType::IPU on a given device id.
class IPUPlace : public Place {
 public:
  IPUPlace() : Place(AllocationType::IPU, 0) {}
  // Place stores the device id as int8_t; the conversion is made explicit.
  explicit IPUPlace(int device_id)
      : Place(AllocationType::IPU, static_cast<int8_t>(device_id)) {}

  IPUPlace(const IPUPlace&) = default;
  // Converting constructor: keeps only the device id of `place`.
  IPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::IPU, place.GetDeviceId()) {}
};

/// Place with allocation type AllocationType::MLU on a given device id.
class MLUPlace : public Place {
 public:
  MLUPlace() : Place(AllocationType::MLU, 0) {}
  // Place stores the device id as int8_t; the conversion is made explicit.
  explicit MLUPlace(int device_id)
      : Place(AllocationType::MLU, static_cast<int8_t>(device_id)) {}

  MLUPlace(const MLUPlace&) = default;
  // Converting constructor: keeps only the device id of `place`.
  MLUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::MLU, place.GetDeviceId()) {}
};

181 182
class CustomPlace : public Place {
 public:
183
  CustomPlace() : Place(AllocationType::CUSTOM, 0, "") {}
184 185 186 187 188 189 190 191 192 193 194 195 196 197
  explicit CustomPlace(const std::string dev_type)
      : Place(AllocationType::CUSTOM, 0, dev_type) {}
  CustomPlace(const std::string dev_type, int device_id)
      : Place(AllocationType::CUSTOM, device_id, dev_type) {}

  CustomPlace(const CustomPlace&) = default;
  CustomPlace(const Place& place) {  // NOLINT
    if (place.GetType() == AllocationType::CUSTOM) {
      this->Reset(
          AllocationType::CUSTOM, place.GetDeviceId(), place.GetDeviceType());
    }
  }
};

198
std::ostream& operator<<(std::ostream&, const Place&);
199

200
}  // namespace phi
201 202 203 204 205 206 207

namespace paddle {
namespace experimental {
using AllocationType = phi::AllocationType;
using Place = phi::Place;
}  // namespace experimental
}  // namespace paddle