// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstdlib>
#include <initializer_list>
#include <new>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/utils.h"

namespace ir {
class IR_API InterfaceValue {
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
 public:
  template <typename ConcreteOp, typename T>
  static InterfaceValue get() {
    InterfaceValue val;
    val.type_id_ = TypeId::get<T>();
    val.model_ = malloc(sizeof(typename T::template Model<ConcreteOp>));
    if (val.model_ == nullptr) {
      throw("Alloc memory for interface failed.");
    }
    static_assert(std::is_trivially_destructible<
                      typename T::template Model<ConcreteOp>>::value,
                  "interface models must be trivially destructible");
    new (val.model_) typename T::template Model<ConcreteOp>();
    return val;
  }
  TypeId type_id() const { return type_id_; }
  void *model() const { return model_; }

  InterfaceValue() = default;
  explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {}
  InterfaceValue(const InterfaceValue &) = delete;
  InterfaceValue(InterfaceValue &&);
  InterfaceValue &operator=(const InterfaceValue &) = delete;
  InterfaceValue &operator=(InterfaceValue &&);
  ~InterfaceValue();
  void swap(InterfaceValue &&val) {
    using std::swap;
    swap(type_id_, val.type_id_);
    swap(model_, val.model_);
  }

  ///
  /// \brief Comparison operations.
  ///
  inline bool operator<(const InterfaceValue &other) const {
    return type_id_ < other.type_id_;
  }

 private:
  TypeId type_id_;
  void *model_{nullptr};
};

class IR_API OpBase {
68
 public:
69
  explicit OpBase(Operation *operation = nullptr) : operation_(operation) {}
70

Z
zhangbo9674 已提交
71
  Operation *operation() const { return operation_; }
72

73
  explicit operator bool() const { return operation() != nullptr; }
74

Z
zhangbo9674 已提交
75
  operator Operation *() const { return operation_; }
76

Z
zhangbo9674 已提交
77
  Operation *operator->() const { return operation_; }
78

79 80
  IrContext *ir_context() const { return operation_->ir_context(); }

81
 private:
Z
zhangbo9674 已提交
82
  Operation *operation_;  // Not owned
83 84 85 86 87 88 89 90
};

///
/// \brief OpTrait
///
template <class ConcreteTrait>
class OpTraitBase : public OpBase {
 public:
Z
zhangbo9674 已提交
91
  explicit OpTraitBase(Operation *op) : OpBase(op) {}
92 93

  static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
94

Z
zhangbo9674 已提交
95
  static ConcreteTrait dyn_cast(Operation *op) {
96
    if (op && op->HasTrait<ConcreteTrait>()) {
97 98 99 100
      return ConcreteTrait(op);
    }
    return ConcreteTrait(nullptr);
  }
101 102 103 104 105 106 107 108
};

///
/// \brief OpInterface
///
template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase {
 public:
Z
zhangbo9674 已提交
109
  explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
110 111

  static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
112

Z
zhangbo9674 已提交
113
  static ConcreteInterface dyn_cast(Operation *op) {
114
    if (op && op->HasInterface<ConcreteInterface>()) {
115
      return ConcreteInterface(
116
          op, op->info().GetInterfaceImpl<ConcreteInterface>());
117 118 119
    }
    return ConcreteInterface(nullptr, nullptr);
  }
120 121
};

template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
 public:
  /// Construct method for interfaces.
  static InterfaceValue *interface(InterfaceValue *p_interface) {
    (void)std::initializer_list<int>{
        0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
    return p_interface;
  }

  /// Construct method for traits.
  static TypeId *trait(TypeId *p_trait) {
    (void)std::initializer_list<int>{
        0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
    return p_trait;
  }

 private:
  /// Placement new interface.
  template <typename T>
  static void PlacementConstrctInterface(
      InterfaceValue *&p_interface) {  // NOLINT
    p_interface->swap(InterfaceValue::get<ConcreteOp, T>());
145 146
    VLOG(4) << "New a interface: id["
            << (p_interface->type_id()).AsOpaquePointer() << "].";
147 148 149 150 151 152 153
    ++p_interface;
  }

  /// Placement new trait.
  template <typename T>
  static void PlacementConstrctTrait(ir::TypeId *&p_trait) {  // NOLINT
    *p_trait = TypeId::get<T>();
154
    VLOG(4) << "New a trait: id[" << p_trait->AsOpaquePointer() << "].";
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
    ++p_trait;
  }
};

/// Partial specialization for a `std::tuple<Args...>` type list (e.g. a
/// `Filter<...>::Type`): unpacks the tuple and forwards both entry points to
/// the variadic primary template.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
  using Unpacked = ConstructInterfacesOrTraits<ConcreteOp, Args...>;

 public:
  /// Construct method for interfaces.
  static InterfaceValue *interface(InterfaceValue *p_interface) {
    return Unpacked::interface(p_interface);
  }

  /// Construct method for traits.
  static TypeId *trait(TypeId *p_trait) { return Unpacked::trait(p_trait); }
};

template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
 public:
  using OpBase::OpBase;

  using TraitList =
      typename Filter<OpTraitBase, std::tuple<TraitOrInterface...>>::Type;

  using InterfaceList =
      typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
185

Z
zhangbo9674 已提交
186
  static ConcreteOp dyn_cast(Operation *op) {
187
    if (op && op->info().id() == TypeId::get<ConcreteOp>()) {
188 189 190 191 192
      return ConcreteOp(op);
    }
    return ConcreteOp(nullptr);
  }

193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
  static std::vector<InterfaceValue> GetInterfaceMap() {
    constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
    std::vector<InterfaceValue> interfaces_map(interfaces_num);
    ConstructInterfacesOrTraits<ConcreteOp, InterfaceList>::interface(
        interfaces_map.data());
    return interfaces_map;
  }

  static std::vector<TypeId> GetTraitSet() {
    constexpr size_t traits_num = std::tuple_size<TraitList>::value;
    std::vector<TypeId> trait_set(traits_num);
    auto p_first_trait = trait_set.data();
    ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
    return trait_set;
  }
};
}  // namespace ir