op_base.h 7.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16
#include <cstdlib>
#include <initializer_list>
#include <new>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/utils.h"

namespace ir {
23

24
class IR_API InterfaceValue {
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
 public:
  template <typename ConcreteOp, typename T>
  static InterfaceValue get() {
    InterfaceValue val;
    val.type_id_ = TypeId::get<T>();
    val.model_ = malloc(sizeof(typename T::template Model<ConcreteOp>));
    if (val.model_ == nullptr) {
      throw("Alloc memory for interface failed.");
    }
    static_assert(std::is_trivially_destructible<
                      typename T::template Model<ConcreteOp>>::value,
                  "interface models must be trivially destructible");
    new (val.model_) typename T::template Model<ConcreteOp>();
    return val;
  }
  TypeId type_id() const { return type_id_; }
  void *model() const { return model_; }

  InterfaceValue() = default;
  explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {}
  InterfaceValue(const InterfaceValue &) = delete;
C
cyberslack_lee 已提交
46
  InterfaceValue(InterfaceValue &&) noexcept;
47
  InterfaceValue &operator=(const InterfaceValue &) = delete;
C
cyberslack_lee 已提交
48
  InterfaceValue &operator=(InterfaceValue &&) noexcept;
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
  ~InterfaceValue();
  void swap(InterfaceValue &&val) {
    using std::swap;
    swap(type_id_, val.type_id_);
    swap(model_, val.model_);
  }

  ///
  /// \brief Comparison operations.
  ///
  inline bool operator<(const InterfaceValue &other) const {
    return type_id_ < other.type_id_;
  }

 private:
  TypeId type_id_;
  void *model_{nullptr};
};

68
class IR_API OpBase {
69
 public:
70
  explicit OpBase(Operation *operation = nullptr) : operation_(operation) {}
71

72 73 74 75 76 77
  Operation *operation() const {
    IR_ENFORCE(operation_, "Can't use operation() in a null op.");
    return operation_;
  }

  explicit operator bool() const { return operation_ != nullptr; }
78

79
  operator Operation *() const { return operation(); }
80

81
  Operation *operator->() const { return operation(); }
82

83
  IrContext *ir_context() const { return operation()->ir_context(); }
84

85
  uint32_t num_results() const { return operation()->num_results(); }
86

87
  uint32_t num_operands() const { return operation()->num_operands(); }
88

89
  const AttributeMap &attributes() const { return operation()->attributes(); }
90

91 92 93
  Value operand_source(uint32_t index) const {
    return operation()->operand_source(index);
  }
94

95
  OpResult result(uint32_t index) const { return operation()->result(index); }
96

97 98 99 100 101 102 103 104
  ir::Attribute attribute(const std::string &name) {
    return operation()->attribute(name);
  }

  template <typename T>
  T attribute(const std::string &name) {
    return operation()->attribute<T>(name);
  }
105

106
 private:
Z
zhangbo9674 已提交
107
  Operation *operation_;  // Not owned
108 109 110 111 112 113 114 115
};

///
/// \brief OpTrait
///
template <class ConcreteTrait>
class OpTraitBase : public OpBase {
 public:
Z
zhangbo9674 已提交
116
  explicit OpTraitBase(Operation *op) : OpBase(op) {}
117 118

  static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
119

Z
zhangbo9674 已提交
120
  static ConcreteTrait dyn_cast(Operation *op) {
121
    if (op && op->HasTrait<ConcreteTrait>()) {
122 123 124 125
      return ConcreteTrait(op);
    }
    return ConcreteTrait(nullptr);
  }
126 127 128 129 130 131 132 133
};

///
/// \brief OpInterface
///
template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase {
 public:
Z
zhangbo9674 已提交
134
  explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
135 136

  static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
137

Z
zhangbo9674 已提交
138
  static ConcreteInterface dyn_cast(Operation *op) {
139
    if (op && op->HasInterface<ConcreteInterface>()) {
140
      return ConcreteInterface(
141
          op, op->info().GetInterfaceImpl<ConcreteInterface>());
142 143 144
    }
    return ConcreteInterface(nullptr, nullptr);
  }
145 146
};

147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
///
/// \brief Fills caller-provided arrays with the interface records / trait ids
/// of the types in `Args`, in the order they are listed.
///
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
 public:
  ///
  /// \brief Writes one InterfaceValue per type in `Args` into consecutive
  /// slots starting at `p_interface`.
  /// \param p_interface First of at least sizeof...(Args) existing slots.
  /// \return Pointer one past the last slot written.
  ///
  static InterfaceValue *interface(InterfaceValue *p_interface) {
    InterfaceValue *cursor = p_interface;
    // Braced-init-list elements are evaluated strictly left to right, so
    // each step sees the cursor advanced by the previous one.
    (void)std::initializer_list<int>{
        0, (cursor = FillInterfaceSlot<Args>(cursor), 0)...};
    return cursor;
  }

  ///
  /// \brief Writes the TypeId of every type in `Args` into consecutive slots
  /// starting at `p_trait`.
  /// \param p_trait First of at least sizeof...(Args) existing slots.
  /// \return Pointer one past the last slot written.
  ///
  static TypeId *trait(TypeId *p_trait) {
    TypeId *cursor = p_trait;
    (void)std::initializer_list<int>{
        0, (cursor = FillTraitSlot<Args>(cursor), 0)...};
    return cursor;
  }

 private:
  /// Swaps a freshly built InterfaceValue for `T` into `*slot`; returns the
  /// following slot.
  template <typename T>
  static InterfaceValue *FillInterfaceSlot(InterfaceValue *slot) {
    slot->swap(InterfaceValue::get<ConcreteOp, T>());
    VLOG(6) << "New a interface: id[" << (slot->type_id()).AsOpaquePointer()
            << "].";
    return slot + 1;
  }

  /// Stores the TypeId of `T` into `*slot`; returns the following slot.
  template <typename T>
  static TypeId *FillTraitSlot(TypeId *slot) {
    *slot = TypeId::get<T>();
    VLOG(6) << "New a trait: id[" << slot->AsOpaquePointer() << "].";
    return slot + 1;
  }
};

///
/// \brief Specialization for a `std::tuple<Args...>` argument (the form used
/// with Op's TraitList / InterfaceList). It unpacks the tuple and forwards to
/// the variadic primary template.
///
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
  using Expanded = ConstructInterfacesOrTraits<ConcreteOp, Args...>;

 public:
  /// Construct method for interfaces; see the primary template.
  static InterfaceValue *interface(InterfaceValue *p_interface) {
    return Expanded::interface(p_interface);
  }

  /// Construct method for traits; see the primary template.
  static TypeId *trait(TypeId *p_trait) { return Expanded::trait(p_trait); }
};

200 201 202 203 204 205 206 207 208 209
template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
 public:
  using OpBase::OpBase;

  using TraitList =
      typename Filter<OpTraitBase, std::tuple<TraitOrInterface...>>::Type;

  using InterfaceList =
      typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
210

Z
zhangbo9674 已提交
211
  static ConcreteOp dyn_cast(Operation *op) {
212
    if (op && op->info().id() == TypeId::get<ConcreteOp>()) {
213 214 215 216 217
      return ConcreteOp(op);
    }
    return ConcreteOp(nullptr);
  }

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
  static std::vector<InterfaceValue> GetInterfaceMap() {
    constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
    std::vector<InterfaceValue> interfaces_map(interfaces_num);
    ConstructInterfacesOrTraits<ConcreteOp, InterfaceList>::interface(
        interfaces_map.data());
    return interfaces_map;
  }

  static std::vector<TypeId> GetTraitSet() {
    constexpr size_t traits_num = std::tuple_size<TraitList>::value;
    std::vector<TypeId> trait_set(traits_num);
    auto p_first_trait = trait_set.data();
    ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
    return trait_set;
  }
233 234 235 236 237 238 239 240 241 242
  static constexpr bool HasNoDataMembers() {
    class EmptyOp : public Op<EmptyOp, TraitOrInterface...> {};
    return sizeof(ConcreteOp) == sizeof(EmptyOp);
  }

  static void VerifyInvariants(Operation *op) {
    static_assert(HasNoDataMembers(),
                  "Op class shouldn't define new data members");
    op->dyn_cast<ConcreteOp>().Verify();
  }
243
};
244

245
}  // namespace ir