// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "paddle/ir/core/operation.h" #include "paddle/ir/core/utils.h" namespace ir { class InterfaceValue { public: template static InterfaceValue get() { InterfaceValue val; val.type_id_ = TypeId::get(); val.model_ = malloc(sizeof(typename T::template Model)); if (val.model_ == nullptr) { throw("Alloc memory for interface failed."); } static_assert(std::is_trivially_destructible< typename T::template Model>::value, "interface models must be trivially destructible"); new (val.model_) typename T::template Model(); return val; } TypeId type_id() const { return type_id_; } void *model() const { return model_; } InterfaceValue() = default; explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {} InterfaceValue(const InterfaceValue &) = delete; InterfaceValue(InterfaceValue &&); InterfaceValue &operator=(const InterfaceValue &) = delete; InterfaceValue &operator=(InterfaceValue &&); ~InterfaceValue(); void swap(InterfaceValue &&val) { using std::swap; swap(type_id_, val.type_id_); swap(model_, val.model_); } /// /// \brief Comparison operations. /// inline bool operator<(const InterfaceValue &other) const { return type_id_ < other.type_id_; } private: TypeId type_id_; void *model_{nullptr}; }; class OpBase { public: explicit OpBase(Operation *operation) : operation_(operation) {} Operation *operation() const { return operation_; } explicit operator bool() const { return operation() != nullptr; } operator Operation *() const { return operation_; } Operation *operator->() const { return operation_; } private: Operation *operation_; // Not owned }; /// /// \brief OpTrait /// template class OpTraitBase : public OpBase { public: explicit OpTraitBase(Operation *op) : OpBase(op) {} static TypeId GetTraitId() { return TypeId::get(); } static ConcreteTrait dyn_cast(Operation *op) { if (op->HasTrait()) { return ConcreteTrait(op); } return ConcreteTrait(nullptr); } }; /// /// \brief OpInterface /// template class OpInterfaceBase : public OpBase { public: explicit OpInterfaceBase(Operation *op) : OpBase(op) {} static TypeId GetInterfaceId() { return TypeId::get(); } static ConcreteInterface dyn_cast(Operation *op) { if (op->HasInterface()) { return ConcreteInterface( op, op->op_info().GetInterfaceImpl()); } return ConcreteInterface(nullptr, nullptr); } }; template class ConstructInterfacesOrTraits { public: /// Construct method for interfaces. static InterfaceValue *interface(InterfaceValue *p_interface) { (void)std::initializer_list{ 0, (PlacementConstrctInterface(p_interface), 0)...}; return p_interface; } /// Construct method for traits. static TypeId *trait(TypeId *p_trait) { (void)std::initializer_list{ 0, (PlacementConstrctTrait(p_trait), 0)...}; return p_trait; } private: /// Placement new interface. template static void PlacementConstrctInterface( InterfaceValue *&p_interface) { // NOLINT p_interface->swap(InterfaceValue::get()); VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage() << "]."; ++p_interface; } /// Placement new trait. template static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT *p_trait = TypeId::get(); VLOG(4) << "New a trait: id[" << p_trait->storage() << "]."; ++p_trait; } }; /// Specialized for tuple type. template class ConstructInterfacesOrTraits> { public: /// Construct method for interfaces. static InterfaceValue *interface(InterfaceValue *p_interface) { return ConstructInterfacesOrTraits::interface( p_interface); } /// Construct method for traits. static TypeId *trait(TypeId *p_trait) { return ConstructInterfacesOrTraits::trait(p_trait); } }; template class Op : public OpBase { public: using OpBase::OpBase; using TraitList = typename Filter>::Type; using InterfaceList = typename Filter>::Type; static ConcreteOp dyn_cast(Operation *op) { if (op->op_info().id() == TypeId::get()) { return ConcreteOp(op); } return ConcreteOp(nullptr); } static std::vector GetInterfaceMap() { constexpr size_t interfaces_num = std::tuple_size::value; std::vector interfaces_map(interfaces_num); ConstructInterfacesOrTraits::interface( interfaces_map.data()); return interfaces_map; } static std::vector GetTraitSet() { constexpr size_t traits_num = std::tuple_size::value; std::vector trait_set(traits_num); auto p_first_trait = trait_set.data(); ConstructInterfacesOrTraits::trait(p_first_trait); return trait_set; } }; } // namespace ir