// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/utils.h" namespace ir { class IR_API InterfaceValue { public: template static InterfaceValue get() { InterfaceValue val; val.type_id_ = TypeId::get(); val.model_ = malloc(sizeof(typename T::template Model)); if (val.model_ == nullptr) { throw("Alloc memory for interface failed."); } static_assert(std::is_trivially_destructible< typename T::template Model>::value, "interface models must be trivially destructible"); new (val.model_) typename T::template Model(); return val; } TypeId type_id() const { return type_id_; } void *model() const { return model_; } InterfaceValue() = default; explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {} InterfaceValue(const InterfaceValue &) = delete; InterfaceValue(InterfaceValue &&) noexcept; InterfaceValue &operator=(const InterfaceValue &) = delete; InterfaceValue &operator=(InterfaceValue &&) noexcept; ~InterfaceValue(); void swap(InterfaceValue &&val) { using std::swap; swap(type_id_, val.type_id_); swap(model_, val.model_); } /// /// \brief Comparison operations. /// inline bool operator<(const InterfaceValue &other) const { return type_id_ < other.type_id_; } private: TypeId type_id_; void *model_{nullptr}; }; class IR_API OpBase { public: explicit OpBase(Operation *operation = nullptr) : operation_(operation) {} Operation *operation() const { IR_ENFORCE(operation_, "Can't use operation() in a null op."); return operation_; } explicit operator bool() const { return operation_ != nullptr; } operator Operation *() const { return operation(); } Operation *operator->() const { return operation(); } IrContext *ir_context() const { return operation()->ir_context(); } uint32_t num_results() const { return operation()->num_results(); } uint32_t num_operands() const { return operation()->num_operands(); } const AttributeMap &attributes() const { return operation()->attributes(); } Value operand_source(uint32_t index) const { return operation()->operand_source(index); } OpResult result(uint32_t index) const { return operation()->result(index); } ir::Attribute attribute(const std::string &name) { return operation()->attribute(name); } template T attribute(const std::string &name) { return operation()->attribute(name); } private: Operation *operation_; // Not owned }; /// /// \brief OpTrait /// template class OpTraitBase : public OpBase { public: explicit OpTraitBase(Operation *op) : OpBase(op) {} static TypeId GetTraitId() { return TypeId::get(); } static ConcreteTrait dyn_cast(Operation *op) { if (op && op->HasTrait()) { return ConcreteTrait(op); } return ConcreteTrait(nullptr); } }; /// /// \brief OpInterface /// template class OpInterfaceBase : public OpBase { public: explicit OpInterfaceBase(Operation *op) : OpBase(op) {} static TypeId GetInterfaceId() { return TypeId::get(); } static ConcreteInterface dyn_cast(Operation *op) { if (op && op->HasInterface()) { return ConcreteInterface( op, op->info().GetInterfaceImpl()); } return ConcreteInterface(nullptr, nullptr); } }; template class ConstructInterfacesOrTraits { public: /// Construct method for interfaces. static InterfaceValue *interface(InterfaceValue *p_interface) { (void)std::initializer_list{ 0, (PlacementConstrctInterface(p_interface), 0)...}; return p_interface; } /// Construct method for traits. static TypeId *trait(TypeId *p_trait) { (void)std::initializer_list{ 0, (PlacementConstrctTrait(p_trait), 0)...}; return p_trait; } private: /// Placement new interface. template static void PlacementConstrctInterface( InterfaceValue *&p_interface) { // NOLINT p_interface->swap(InterfaceValue::get()); VLOG(6) << "New a interface: id[" << (p_interface->type_id()).AsOpaquePointer() << "]."; ++p_interface; } /// Placement new trait. template static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT *p_trait = TypeId::get(); VLOG(6) << "New a trait: id[" << p_trait->AsOpaquePointer() << "]."; ++p_trait; } }; /// Specialized for tuple type. template class ConstructInterfacesOrTraits> { public: /// Construct method for interfaces. static InterfaceValue *interface(InterfaceValue *p_interface) { return ConstructInterfacesOrTraits::interface( p_interface); } /// Construct method for traits. static TypeId *trait(TypeId *p_trait) { return ConstructInterfacesOrTraits::trait(p_trait); } }; template class Op : public OpBase { public: using OpBase::OpBase; using TraitList = typename Filter>::Type; using InterfaceList = typename Filter>::Type; static ConcreteOp dyn_cast(Operation *op) { if (op && op->info().id() == TypeId::get()) { return ConcreteOp(op); } return ConcreteOp(nullptr); } static bool classof(const Operation *op) { return op && op->info().id() == TypeId::get(); } static std::vector GetInterfaceMap() { constexpr size_t interfaces_num = std::tuple_size::value; std::vector interfaces_map(interfaces_num); ConstructInterfacesOrTraits::interface( interfaces_map.data()); return interfaces_map; } static std::vector GetTraitSet() { constexpr size_t traits_num = std::tuple_size::value; std::vector trait_set(traits_num); auto p_first_trait = trait_set.data(); ConstructInterfacesOrTraits::trait(p_first_trait); return trait_set; } static constexpr bool HasNoDataMembers() { class EmptyOp : public Op {}; return sizeof(ConcreteOp) == sizeof(EmptyOp); } static void VerifyInvariants(Operation *op) { static_assert(HasNoDataMembers(), "Op class shouldn't define new data members"); op->dyn_cast().Verify(); } }; } // namespace ir