/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

namespace phi {

/// \brief Tag identifying the kind of allocation a Place refers to.
///
/// The underlying type is int8_t and every enumerator has an explicit value.
/// Place::operator< orders places of different kinds by these numeric values,
/// so renumbering an enumerator changes the relative order of places.
enum class AllocationType : int8_t {
  UNDEFINED = 0,  // value used by a default-constructed Place
  CPU = 1,
  GPU = 2,
  GPUPINNED = 3,
  XPU = 4,
  NPU = 5,
  NPUPINNED = 6,
  IPU = 7,
  MLU = 8,
  CUSTOM = 9,  // paired with a device-type id inside Place
};
/// \brief Returns a C string for \p type.
/// NOTE(review): only declared in this header; presumably the enumerator's
/// name — confirm against the definition.
const char* AllocationTypeStr(AllocationType type);

/// \brief Maps a device-type name to a numeric id.
/// NOTE(review): the name suggests the id is created on first use and reused
/// afterwards; the definition is not in this header — confirm there.
size_t GetOrRegisterGlobalDeviceTypeId(const std::string& device_type);

/// \brief Maps a numeric device-type id back to a device-type name.
/// NOTE(review): behaviour for an id that was never registered is not visible
/// from this header — confirm against the definition.
std::string GetGlobalDeviceType(size_t device_type_id);

/// \brief The place is used to specify where the data is stored.
40
class Place {
41
 public:
42
  Place() : device(0), alloc_type_(AllocationType::UNDEFINED) {}
43

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
  explicit Place(AllocationType type,
                 int8_t id,
                 const std::string& dev_type = "")
      : device(id),
        alloc_type_(type),
        device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}

  explicit Place(AllocationType type, const std::string& dev_type = "")
      : device(0),
        alloc_type_(type),
        device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}

  void Reset(AllocationType type,
             int8_t device_id = 0,
             const std::string& dev_type = "") noexcept {
59 60
    alloc_type_ = type;
    device = device_id;
61 62 63
    if (!dev_type.empty()) {
      device_type_id_ = GetOrRegisterGlobalDeviceTypeId(dev_type);
    }
64 65
  }

66
  AllocationType GetType() const { return alloc_type_; }
67

68
  int8_t GetDeviceId() const { return device; }
69

70 71 72 73
  std::string GetDeviceType() const {
    return GetGlobalDeviceType(device_type_id_);
  }

74 75
  std::string DebugString() const;

76 77 78 79 80 81 82 83 84
  inline bool operator==(const Place& rhs) const {
    if (alloc_type_ != rhs.GetType()) {
      return false;
    }
    if (alloc_type_ == AllocationType::CPU ||
        alloc_type_ == AllocationType::GPUPINNED ||
        alloc_type_ == AllocationType::NPUPINNED) {
      return true;
    }
85 86 87 88
    if (alloc_type_ == AllocationType::CUSTOM) {
      return device_type_id_ == rhs.device_type_id_ &&
             device == rhs.GetDeviceId();
    }
89 90 91 92 93 94 95
    return device == rhs.GetDeviceId();
  }
  inline bool operator!=(const Place& rhs) const { return !(*this == rhs); }
  inline bool operator<(const Place& rhs) const {
    if (alloc_type_ != rhs.GetType()) {
      return static_cast<int>(alloc_type_) < static_cast<int>(rhs.GetType());
    }
96 97 98 99
    if (alloc_type_ == AllocationType::CUSTOM &&
        device_type_id_ != rhs.device_type_id_) {
      return device_type_id_ < rhs.device_type_id_;
    }
100 101 102
    return device < rhs.GetDeviceId();
  }

103 104 105
 public:
  // TODO(wilber): Just because of backward compatibility, it needs to be
  // changed to private in the future.
106
  int8_t device{0};
107 108

 private:
109
  AllocationType alloc_type_{AllocationType::UNDEFINED};
110
  size_t device_type_id_;
111 112 113 114
};

/// \brief Place whose allocation type is always AllocationType::CPU.
class CPUPlace : public Place {
 public:
  CPUPlace() : Place(AllocationType::CPU) {}

  CPUPlace(const CPUPlace&) = default;

  /// Implicit conversion from any Place. The argument is deliberately
  /// ignored (hence unnamed): the result is always a CPU place with
  /// device id 0.
  CPUPlace(const Place& /*place*/) : Place(AllocationType::CPU) {}  // NOLINT
};

/// \brief Place whose allocation type is always AllocationType::GPU.
class GPUPlace : public Place {
 public:
  /// GPU place on device 0.
  GPUPlace() : Place(AllocationType::GPU, 0) {}

  /// GPU place on \p device_id. Place stores the id as int8_t, so the value is
  /// narrowed here; ids outside the int8_t range do not round-trip.
  explicit GPUPlace(int device_id)
      : Place(AllocationType::GPU, static_cast<int8_t>(device_id)) {}

  GPUPlace(const GPUPlace&) = default;

  /// Implicit conversion from any Place: keeps only its device id.
  GPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::GPU, place.GetDeviceId()) {}
};

/// \brief Place whose allocation type is always AllocationType::GPUPINNED.
class GPUPinnedPlace : public Place {
 public:
  GPUPinnedPlace() : Place(AllocationType::GPUPINNED) {}

  GPUPinnedPlace(const GPUPinnedPlace&) = default;

  /// Implicit conversion from any Place. The argument is deliberately
  /// ignored (hence unnamed): the result always has device id 0.
  GPUPinnedPlace(const Place& /*place*/)  // NOLINT
      : Place(AllocationType::GPUPINNED) {}
};

/// \brief Place whose allocation type is always AllocationType::XPU.
class XPUPlace : public Place {
 public:
  /// XPU place on device 0.
  XPUPlace() : Place(AllocationType::XPU, 0) {}

  /// XPU place on \p device_id. Place stores the id as int8_t, so the value is
  /// narrowed here; ids outside the int8_t range do not round-trip.
  explicit XPUPlace(int device_id)
      : Place(AllocationType::XPU, static_cast<int8_t>(device_id)) {}

  XPUPlace(const XPUPlace&) = default;

  /// Implicit conversion from any Place: keeps only its device id.
  XPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::XPU, place.GetDeviceId()) {}
};

/// \brief Place whose allocation type is always AllocationType::NPU.
class NPUPlace : public Place {
 public:
  /// NPU place on device 0.
  NPUPlace() : Place(AllocationType::NPU, 0) {}

  /// NPU place on \p device_id. Place stores the id as int8_t, so the value is
  /// narrowed here; ids outside the int8_t range do not round-trip.
  explicit NPUPlace(int device_id)
      : Place(AllocationType::NPU, static_cast<int8_t>(device_id)) {}

  NPUPlace(const NPUPlace&) = default;

  /// Implicit conversion from any Place: keeps only its device id.
  NPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::NPU, place.GetDeviceId()) {}
};

/// \brief Place whose allocation type is always AllocationType::NPUPINNED.
class NPUPinnedPlace : public Place {
 public:
  NPUPinnedPlace() : Place(AllocationType::NPUPINNED) {}

  NPUPinnedPlace(const NPUPinnedPlace&) = default;

  /// Implicit conversion from any Place. The argument is deliberately
  /// ignored (hence unnamed): the result always has device id 0.
  NPUPinnedPlace(const Place& /*place*/)  // NOLINT
      : Place(AllocationType::NPUPINNED) {}
};

/// \brief Place whose allocation type is always AllocationType::IPU.
class IPUPlace : public Place {
 public:
  /// IPU place on device 0.
  IPUPlace() : Place(AllocationType::IPU, 0) {}

  /// IPU place on \p device_id. Place stores the id as int8_t, so the value is
  /// narrowed here; ids outside the int8_t range do not round-trip.
  explicit IPUPlace(int device_id)
      : Place(AllocationType::IPU, static_cast<int8_t>(device_id)) {}

  IPUPlace(const IPUPlace&) = default;

  /// Implicit conversion from any Place: keeps only its device id.
  IPUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::IPU, place.GetDeviceId()) {}
};

/// \brief Place whose allocation type is always AllocationType::MLU.
class MLUPlace : public Place {
 public:
  /// MLU place on device 0.
  MLUPlace() : Place(AllocationType::MLU, 0) {}

  /// MLU place on \p device_id. Place stores the id as int8_t, so the value is
  /// narrowed here; ids outside the int8_t range do not round-trip.
  explicit MLUPlace(int device_id)
      : Place(AllocationType::MLU, static_cast<int8_t>(device_id)) {}

  MLUPlace(const MLUPlace&) = default;

  /// Implicit conversion from any Place: keeps only its device id.
  MLUPlace(const Place& place)  // NOLINT
      : Place(AllocationType::MLU, place.GetDeviceId()) {}
};

/// \brief Place of kind AllocationType::CUSTOM, identified by a device-type
/// name in addition to the device id.
class CustomPlace : public Place {
 public:
  /// Custom place of device type \p dev_type on device 0.
  /// The string is taken by const reference to avoid a copy per construction.
  explicit CustomPlace(const std::string& dev_type)
      : Place(AllocationType::CUSTOM, 0, dev_type) {}

  /// Custom place of device type \p dev_type on \p device_id. Place stores the
  /// id as int8_t, so the value is narrowed here.
  CustomPlace(const std::string& dev_type, int device_id)
      : Place(AllocationType::CUSTOM,
              static_cast<int8_t>(device_id),
              dev_type) {}

  CustomPlace(const CustomPlace&) = default;

  /// Implicit conversion from any Place.
  ///
  /// A CUSTOM source has its device id and device type copied over. For any
  /// other source the object stays as the default-constructed base leaves it,
  /// i.e. its allocation type is UNDEFINED rather than CUSTOM.
  CustomPlace(const Place& place) {  // NOLINT
    if (place.GetType() != AllocationType::CUSTOM) {
      return;
    }
    this->Reset(
        AllocationType::CUSTOM, place.GetDeviceId(), place.GetDeviceType());
  }
};

/// \brief Stream-insertion operator for Place; only declared in this header.
/// NOTE(review): presumably writes the same text as Place::DebugString() —
/// confirm against the definition.
std::ostream& operator<<(std::ostream&, const Place&);

}  // namespace phi