// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <new>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/utils.h"

namespace ir {
class IR_API InterfaceValue {
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
 public:
  template <typename ConcreteOp, typename T>
  static InterfaceValue get() {
    InterfaceValue val;
    val.type_id_ = TypeId::get<T>();
    val.model_ = malloc(sizeof(typename T::template Model<ConcreteOp>));
    if (val.model_ == nullptr) {
      throw("Alloc memory for interface failed.");
    }
    static_assert(std::is_trivially_destructible<
                      typename T::template Model<ConcreteOp>>::value,
                  "interface models must be trivially destructible");
    new (val.model_) typename T::template Model<ConcreteOp>();
    return val;
  }
  TypeId type_id() const { return type_id_; }
  void *model() const { return model_; }

  InterfaceValue() = default;
  explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {}
  InterfaceValue(const InterfaceValue &) = delete;
  InterfaceValue(InterfaceValue &&);
  InterfaceValue &operator=(const InterfaceValue &) = delete;
  InterfaceValue &operator=(InterfaceValue &&);
  ~InterfaceValue();
  void swap(InterfaceValue &&val) {
    using std::swap;
    swap(type_id_, val.type_id_);
    swap(model_, val.model_);
  }

  ///
  /// \brief Comparison operations.
  ///
  inline bool operator<(const InterfaceValue &other) const {
    return type_id_ < other.type_id_;
  }

 private:
  TypeId type_id_;
  void *model_{nullptr};
};

class IR_API OpBase {
68
 public:
69
  explicit OpBase(Operation *operation = nullptr) : operation_(operation) {}
70

Z
zhangbo9674 已提交
71
  Operation *operation() const { return operation_; }
72

73
  explicit operator bool() const { return operation() != nullptr; }
74

Z
zhangbo9674 已提交
75
  operator Operation *() const { return operation_; }
76

Z
zhangbo9674 已提交
77
  Operation *operator->() const { return operation_; }
78

79 80
  IrContext *ir_context() const { return operation_->ir_context(); }

81 82 83 84 85 86
  uint32_t num_results() const { return operation_->num_results(); }

  uint32_t num_operands() const { return operation_->num_operands(); }

  const AttributeMap &attributes() const { return operation_->attributes(); }

87
 private:
Z
zhangbo9674 已提交
88
  Operation *operation_;  // Not owned
89 90 91 92 93 94 95 96
};

///
/// \brief OpTrait
///
template <class ConcreteTrait>
class OpTraitBase : public OpBase {
 public:
Z
zhangbo9674 已提交
97
  explicit OpTraitBase(Operation *op) : OpBase(op) {}
98 99

  static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
100

Z
zhangbo9674 已提交
101
  static ConcreteTrait dyn_cast(Operation *op) {
102
    if (op && op->HasTrait<ConcreteTrait>()) {
103 104 105 106
      return ConcreteTrait(op);
    }
    return ConcreteTrait(nullptr);
  }
107 108 109 110 111 112 113 114
};

///
/// \brief OpInterface
///
template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase {
 public:
Z
zhangbo9674 已提交
115
  explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
116 117

  static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
118

Z
zhangbo9674 已提交
119
  static ConcreteInterface dyn_cast(Operation *op) {
120
    if (op && op->HasInterface<ConcreteInterface>()) {
121
      return ConcreteInterface(
122
          op, op->info().GetInterfaceImpl<ConcreteInterface>());
123 124 125
    }
    return ConcreteInterface(nullptr, nullptr);
  }
126 127
};

template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
 public:
  /// Construct method for interfaces.
  static InterfaceValue *interface(InterfaceValue *p_interface) {
    (void)std::initializer_list<int>{
        0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
    return p_interface;
  }

  /// Construct method for traits.
  static TypeId *trait(TypeId *p_trait) {
    (void)std::initializer_list<int>{
        0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
    return p_trait;
  }

 private:
  /// Placement new interface.
  template <typename T>
  static void PlacementConstrctInterface(
      InterfaceValue *&p_interface) {  // NOLINT
    p_interface->swap(InterfaceValue::get<ConcreteOp, T>());
151 152
    VLOG(4) << "New a interface: id["
            << (p_interface->type_id()).AsOpaquePointer() << "].";
153 154 155 156 157 158 159
    ++p_interface;
  }

  /// Placement new trait.
  template <typename T>
  static void PlacementConstrctTrait(ir::TypeId *&p_trait) {  // NOLINT
    *p_trait = TypeId::get<T>();
160
    VLOG(4) << "New a trait: id[" << p_trait->AsOpaquePointer() << "].";
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
    ++p_trait;
  }
};

/// Partial specialization for a `std::tuple<Args...>` type list (the form
/// used for Op::InterfaceList / Op::TraitList). It unwraps the tuple and
/// delegates to the variadic primary template.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
  /// Primary template instantiated on the unpacked type list.
  using Unpacked = ConstructInterfacesOrTraits<ConcreteOp, Args...>;

 public:
  /// Construct method for interfaces; forwards to the unpacked list.
  static InterfaceValue *interface(InterfaceValue *p_interface) {
    InterfaceValue *const past_end = Unpacked::interface(p_interface);
    return past_end;
  }

  /// Construct method for traits; forwards to the unpacked list.
  static TypeId *trait(TypeId *p_trait) {
    TypeId *const past_end = Unpacked::trait(p_trait);
    return past_end;
  }
};

template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
 public:
  using OpBase::OpBase;

  using TraitList =
      typename Filter<OpTraitBase, std::tuple<TraitOrInterface...>>::Type;

  using InterfaceList =
      typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
191

Z
zhangbo9674 已提交
192
  static ConcreteOp dyn_cast(Operation *op) {
193
    if (op && op->info().id() == TypeId::get<ConcreteOp>()) {
194 195 196 197 198
      return ConcreteOp(op);
    }
    return ConcreteOp(nullptr);
  }

199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
  static std::vector<InterfaceValue> GetInterfaceMap() {
    constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
    std::vector<InterfaceValue> interfaces_map(interfaces_num);
    ConstructInterfacesOrTraits<ConcreteOp, InterfaceList>::interface(
        interfaces_map.data());
    return interfaces_map;
  }

  static std::vector<TypeId> GetTraitSet() {
    constexpr size_t traits_num = std::tuple_size<TraitList>::value;
    std::vector<TypeId> trait_set(traits_num);
    auto p_first_trait = trait_set.data();
    ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
    return trait_set;
  }
214 215 216 217 218 219 220 221 222 223
  static constexpr bool HasNoDataMembers() {
    class EmptyOp : public Op<EmptyOp, TraitOrInterface...> {};
    return sizeof(ConcreteOp) == sizeof(EmptyOp);
  }

  static void VerifyInvariants(Operation *op) {
    static_assert(HasNoDataMembers(),
                  "Op class shouldn't define new data members");
    op->dyn_cast<ConcreteOp>().Verify();
  }
224
};
}  // namespace ir