diff --git a/paddle/ir/builtin_attribute_storage.cc b/paddle/ir/builtin_attribute_storage.cc index c7feacae4d64affd9b35ac46463f5fc095f8913f..3f785d20c9b92218aa5b315d6ad4ec94fd1310cb 100644 --- a/paddle/ir/builtin_attribute_storage.cc +++ b/paddle/ir/builtin_attribute_storage.cc @@ -59,19 +59,16 @@ DictionaryAttributeStorage::ParamKey DictionaryAttributeStorage::GetAsKey() } Attribute DictionaryAttributeStorage::GetValue(const StrAttribute &name) const { - if (size_ > 0) { - size_t left = 0; - size_t right = size_ - 1; - size_t mid = 0; - while (left <= right) { - mid = (left + right) / 2; - if (data_[mid].name() == name) { - return data_[mid].value(); - } else if (data_[mid].name() < name) { - left = mid + 1; - } else { - right = mid - 1; - } + size_t left = 0; + size_t right = size_; + while (left < right) { + size_t mid = left + (right - left) / 2; + if (data_[mid].name() == name) { + return data_[mid].value(); + } else if (data_[mid].name() < name) { + left = mid + 1; + } else { + right = mid; } } return nullptr; diff --git a/paddle/ir/dialect.cc b/paddle/ir/dialect.cc index cda00baadc3cb4f3dd2c7703e2d0cfdaefb3e4ba..cdc171557148abf4a8abf05c24895c15c26f4263 100644 --- a/paddle/ir/dialect.cc +++ b/paddle/ir/dialect.cc @@ -31,4 +31,8 @@ void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) { this->ir_context()->RegisterAbstractAttribute( new_abstract_attribute->type_id(), new_abstract_attribute); } + +void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) { + this->ir_context()->RegisterOpInfo(name, op_info); +} } // namespace ir diff --git a/paddle/ir/dialect.h b/paddle/ir/dialect.h index 40a5a823e3dff5aa96e51f2257f7887a31f17d6d..10adc6b3e2e0b61666796d878d94e2a31f6a8453 100644 --- a/paddle/ir/dialect.h +++ b/paddle/ir/dialect.h @@ -16,6 +16,7 @@ #include "paddle/ir/attribute_base.h" #include "paddle/ir/ir_context.h" +#include "paddle/ir/op_info_impl.h" #include "paddle/ir/type_base.h" namespace ir { @@ -45,17 +46,19 @@ class Dialect { (void)std::initializer_list{0, (RegisterType(), 0)...}; } - /// - /// \brief Register type of class T. - /// template void RegisterType() { VLOG(4) << "Type registered into Dialect. --->"; - ir::AbstractType *abstract_type = - new ir::AbstractType(std::move(ir::AbstractType::get(*this))); - this->ir_context()->RegisterAbstractType(ir::TypeId::get(), - abstract_type); - ir::TypeManager::RegisterType(this->ir_context()); + // if (this->ir_context()->registed_abstract_type().count( + // ir::TypeId::get()) == 0) { + if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get()) == + nullptr) { + ir::AbstractType *abstract_type = + new ir::AbstractType(std::move(ir::AbstractType::get(*this))); + this->ir_context()->RegisterAbstractType(ir::TypeId::get(), + abstract_type); + ir::TypeManager::RegisterType(this->ir_context()); + } VLOG(4) << "----------------------------------"; } @@ -78,24 +81,42 @@ class Dialect { (void)std::initializer_list{0, (RegisterAttribute(), 0)...}; } - /// - /// \brief Register attribute of class T. - /// template void RegisterAttribute() { VLOG(4) << "Attribute registered into Dialect. --->"; - ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute( - std::move(ir::AbstractAttribute::get(*this))); - this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get(), - abstract_attribute); - ir::AttributeManager::RegisterAttribute(this->ir_context()); + if (this->ir_context()->GetRegisteredAbstractAttribute( + ir::TypeId::get()) == nullptr) { + ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute( + std::move(ir::AbstractAttribute::get(*this))); + this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get(), + abstract_attribute); + ir::AttributeManager::RegisterAttribute(this->ir_context()); + } VLOG(4) << "----------------------------------"; } + void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute); + /// - /// \brief Register abstract_attribute into context. + /// \brief Register Operation methods. /// - void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute); + template + void RegisterOps() { + (void)std::initializer_list{0, (RegisterOp(), 0)...}; + } + + template + void RegisterOp() { + std::string name = this->name() + "." + std::string(ConcertOp::name()); + VLOG(4) << "Op " << name << " registered into Dialect. --->"; + if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) { + ir::OpInfoImpl *op_info = ir::OpInfoImpl::create(this); + this->ir_context()->RegisterOpInfo(name, op_info); + } + VLOG(4) << "----------------------------------"; + } + + void RegisterOp(const std::string &name, OpInfoImpl *op_info); private: std::string name_; diff --git a/paddle/ir/ir_context.cc b/paddle/ir/ir_context.cc index 17098ccb5e5560eebf3b6f84975d10846ad58569..d172b748cddae92638959c0452c1ceaa7ad7f3f2 100644 --- a/paddle/ir/ir_context.cc +++ b/paddle/ir/ir_context.cc @@ -20,6 +20,7 @@ #include "paddle/ir/builtin_dialect.h" #include "paddle/ir/builtin_type.h" #include "paddle/ir/dialect.h" +#include "paddle/ir/op_info_impl.h" #include "paddle/ir/spin_lock.h" #include "paddle/ir/type_base.h" @@ -46,6 +47,11 @@ class IrContextImpl { delete dialect_map.second; } registed_dialect_.clear(); + + for (auto &op_map : registed_op_infos_) { + op_map.second->destroy(); + } + registed_op_infos_.clear(); } void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) { @@ -93,6 +99,25 @@ class IrContextImpl { return nullptr; } + void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { + std::lock_guard guard(registed_op_infos_lock_); + VLOG(4) << "Register an operation of: [Name=" << name + << ", OpInfoImpl ptr=" << opinfo << "]."; + registed_op_infos_.emplace(name, opinfo); + } + + OpInfoImpl *GetOpInfo(const std::string &name) { + std::lock_guard guard(registed_op_infos_lock_); + auto iter = registed_op_infos_.find(name); + if (iter != registed_op_infos_.end()) { + VLOG(4) << "Fonund a cached operation of: [name=" << name + << ", OpInfoImpl ptr=" << iter->second << "]."; + return iter->second; + } + LOG(WARNING) << "No cache found operation of: [Name=" << name << "]."; + return nullptr; + } + void RegisterDialect(std::string name, Dialect *dialect) { std::lock_guard guard(registed_dialect_lock_); VLOG(4) << "Register a dialect of: [name=" << name @@ -135,6 +160,10 @@ class IrContextImpl { std::unordered_map registed_dialect_; ir::SpinLock registed_dialect_lock_; + // The Op registered in the context. + std::unordered_map registed_op_infos_; + ir::SpinLock registed_op_infos_lock_; + ir::SpinLock destructor_lock_; }; @@ -165,9 +194,12 @@ StorageManager &IrContext::type_storage_manager() { return impl().registed_type_storage_manager_; } -std::unordered_map - &IrContext::registed_abstracted_type() { - return impl().registed_abstract_types_; +AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) { + auto search = impl().registed_abstract_types_.find(id); + if (search != impl().registed_abstract_types_.end()) { + return search->second; + } + return nullptr; } void IrContext::RegisterAbstractAttribute( @@ -179,9 +211,12 @@ StorageManager &IrContext::attribute_storage_manager() { return impl().registed_attribute_storage_manager_; } -std::unordered_map - &IrContext::registed_abstracted_attribute() { - return impl().registed_abstract_attributes_; +AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) { + auto search = impl().registed_abstract_attributes_.find(id); + if (search != impl().registed_abstract_attributes_.end()) { + return search->second; + } + return nullptr; } Dialect *IrContext::GetOrRegisterDialect( @@ -216,6 +251,17 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { return nullptr; } +OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) { + OpInfoImpl *rtn = impl().GetOpInfo(name); + return rtn ? rtn : nullptr; +} + +void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { + if (impl().GetOpInfo(name) == nullptr) { + impl().RegisterOpInfo(name, opinfo); + } +} + const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { auto &impl = ctx->impl(); AbstractType *abstract_type = impl.GetAbstractType(type_id); diff --git a/paddle/ir/ir_context.h b/paddle/ir/ir_context.h index f7512344888f63f7fbc6529c4964ebceaf511dc1..efa03dd569536a954189371a57033c56ea854417 100644 --- a/paddle/ir/ir_context.h +++ b/paddle/ir/ir_context.h @@ -26,6 +26,7 @@ class AbstractType; class AbstractAttribute; class TypeId; class Dialect; +class OpInfoImpl; /// /// \brief IrContext is a global parameterless class used to store and manage @@ -47,7 +48,7 @@ class IrContext { IrContextImpl &impl() { return *impl_; } /// - /// \brief Register an AbstractType to IrContext + /// \brief Register an AbstractType to IrContext. /// /// \param type_id The type id of the AbstractType. /// \param abstract_type AbstractType* provided by user. @@ -64,13 +65,9 @@ class IrContext { StorageManager &type_storage_manager(); /// - /// \brief Returns the storage uniquer used for constructing TypeStorage - /// instances. - /// - /// \return The storage uniquer used for constructing TypeStorage - /// instances. + /// \brief Get registered AbstractType from IrContext. /// - std::unordered_map ®isted_abstracted_type(); + AbstractType *GetRegisteredAbstractType(TypeId id); /// /// \brief Register an AbstractAttribute to IrContext @@ -91,14 +88,16 @@ class IrContext { StorageManager &attribute_storage_manager(); /// - /// \brief Returns the storage uniquer used for constructing AttributeStorage - /// instances. + /// \brief Get registered AbstractAttribute from IrContext. /// - /// \return The storage uniquer used for constructing AttributeStorage - /// instances. + AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id); + + /// + /// \brief Get or register operaiton. /// - std::unordered_map - ®isted_abstracted_attribute(); + void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo); + + OpInfoImpl *GetRegisteredOpInfo(const std::string &name); /// /// \brief Get the dialect of the DialectT class in the context, ff not found, diff --git a/paddle/ir/op_base.h b/paddle/ir/op_base.h index 38ff4002c6b2b387f9e849d356958622b8a0c162..23318f5e69c4717fabbd9560c1b504c5a92d9f4b 100644 --- a/paddle/ir/op_base.h +++ b/paddle/ir/op_base.h @@ -15,23 +15,59 @@ #pragma once #include "paddle/ir/operation.h" +#include "paddle/ir/utils.h" namespace ir { class OpBase { public: - Operation *operation() { return operation_; } + explicit OpBase(const Operation *operation) : operation_(operation) {} - explicit operator bool() { return operation() != nullptr; } + const Operation *operation() const { return operation_; } - operator Operation *() const { return operation_; } + explicit operator bool() const { return operation() != nullptr; } - Operation *operator->() const { return operation_; } + operator const Operation *() const { return operation_; } - protected: - explicit OpBase(Operation *operation) : operation_(operation) {} + const Operation *operator->() const { return operation_; } private: - Operation *operation_; + const Operation *operation_; // Not owned +}; + +/// +/// \brief OpTrait +/// +template +class OpTraitBase : public OpBase { + public: + explicit OpTraitBase(const Operation *op) : OpBase(op) {} + + static TypeId GetTraitId() { return TypeId::get(); } +}; + +/// +/// \brief OpInterface +/// +template +class OpInterfaceBase : public OpBase { + public: + // explicit OpInterfaceBase(Operation *op) : OpBase(op) {} + + explicit OpInterfaceBase(const Operation *op) : OpBase(op) {} + + static TypeId GetInterfaceId() { return TypeId::get(); } +}; + +template +class Op : public OpBase { + public: + using OpBase::OpBase; + + using TraitList = + typename Filter>::Type; + + using InterfaceList = + typename Filter>::Type; }; } // namespace ir diff --git a/paddle/ir/op_info.h b/paddle/ir/op_info.h new file mode 100644 index 0000000000000000000000000000000000000000..2828f59c2942ae4d8e916b0bd7f80f3bbab2528a --- /dev/null +++ b/paddle/ir/op_info.h @@ -0,0 +1,67 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/ir/op_info_impl.h" + +namespace ir { +class OpInfo { + public: + constexpr OpInfo() = default; + + OpInfo(const OpInfoImpl *impl) : impl_(impl) {} // NOLINT + + OpInfo(const OpInfo &other) = default; + + OpInfo &operator=(const OpInfo &other) = default; + + bool operator==(OpInfo other) const { return impl_ == other.impl_; } + + bool operator!=(OpInfo other) const { return impl_ != other.impl_; } + + explicit operator bool() const { return impl_; } + + bool operator!() const { return impl_ == nullptr; } + + const OpInfoImpl *impl() const { return impl_; } + + template + bool HasTrait() const { + return impl_->HasTrait(); + } + + template + bool HasInterface() const { + return impl_->HasInterface(); + } + + friend struct std::hash; + + private: + const OpInfoImpl *impl_{nullptr}; // not owned +}; + +} // namespace ir + +namespace std { +template <> +struct hash { + std::size_t operator()(const ir::OpInfo &obj) const { + return std::hash()(obj.impl_); + } +}; +} // namespace std diff --git a/paddle/ir/op_info_impl.h b/paddle/ir/op_info_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..78ca586d9d209887a9c9e307226c9ab049eab968 --- /dev/null +++ b/paddle/ir/op_info_impl.h @@ -0,0 +1,282 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/ir/builtin_attribute.h" +// #include "paddle/ir/ir_context.h" +#include "paddle/ir/type.h" + +namespace ir { +class Dialect; +/// +/// \brief Tool template class for construct interfaces or Traits. +/// +template +class ConstructInterfacesOrTraits { + public: + /// Construct method for interfaces. + static std::pair *interface( + std::pair *p_interface) { + (void)std::initializer_list{ + 0, (PlacementConstrctInterface(p_interface), 0)...}; + return p_interface; + } + + /// Construct method for traits. + static TypeId *trait(TypeId *p_trait) { + (void)std::initializer_list{ + 0, (PlacementConstrctTrait(p_trait), 0)...}; + return p_trait; + } + + private: + /// Placement new interface. + template + static void PlacementConstrctInterface( + std::pair *&p_interface) { // NOLINT + new (&(p_interface->first)) TypeId(ir::TypeId::get()); + p_interface->second = + malloc(sizeof(typename T::template Model)); + new (p_interface->second) typename T::template Model(); + VLOG(4) << "New a interface: id[" << p_interface->first.storage() + << "], interface[" << p_interface->second << "]."; + ++p_interface; + } + + /// Placement new trait. + template + static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT + new (p_trait) TypeId(ir::TypeId::get()); + VLOG(4) << "New a trait: id[" << (*p_trait).storage() << "]."; + ++p_trait; + } +}; + +/// Specialized for tuple type. +template +class ConstructInterfacesOrTraits> { + public: + /// Construct method for interfaces. + static std::pair *interface( + std::pair *p_interface) { + return ConstructInterfacesOrTraits::interface( + p_interface); + } + + /// Construct method for traits. + static TypeId *trait(TypeId *p_trait) { + return ConstructInterfacesOrTraits::trait(p_trait); + } +}; + +/// +/// \brief OpInfoImpl class. +/// +class OpInfoImpl { + public: + /// + /// \brief Construct and Deconstruct OpInfoImpl. The memory layout of + /// OpInfoImpl is: std::pair... | TypeId... | OpInfoImpl + /// + template + static OpInfoImpl *create(ir::Dialect *dialect) { + // (1) Malloc memory for interfaces, traits, opinfo_impl. + size_t interfaces_num = + std::tuple_size::value; + size_t traits_num = std::tuple_size::value; + size_t attributes_num = ConcreteOp::attributes_num(); + VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " + << traits_num << " traits, " << attributes_num << " attributes."; + size_t base_size = sizeof(std::pair) * interfaces_num + + sizeof(ir::TypeId) * traits_num + sizeof(OpInfoImpl); + void *base_ptr = malloc(base_size); + VLOG(4) << "Malloc " << base_size << " Bytes at " << base_ptr; + + // (2) Construct interfaces and sort by TypeId. + std::pair *p_first_interface = nullptr; + if (interfaces_num > 0) { + p_first_interface = + reinterpret_cast *>(base_ptr); + VLOG(4) << "Construct interfaces at " << p_first_interface << " ......"; + ConstructInterfacesOrTraits< + ConcreteOp, + typename ConcreteOp::InterfaceList>::interface(p_first_interface); + std::sort(p_first_interface, p_first_interface + interfaces_num); + base_ptr = reinterpret_cast(p_first_interface + interfaces_num); + } + + // (3) Construct traits and sort by TypeId. + ir::TypeId *p_first_trait = nullptr; + if (traits_num > 0) { + p_first_trait = reinterpret_cast(base_ptr); + VLOG(4) << "Construct traits at " << p_first_trait << " ......"; + ConstructInterfacesOrTraits:: + trait(p_first_trait); + std::sort(p_first_trait, p_first_trait + traits_num); + base_ptr = reinterpret_cast(p_first_trait + traits_num); + } + + // (4) Construct opinfo_impl. + OpInfoImpl *p_opinfo_impl = reinterpret_cast(base_ptr); + VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......"; + OpInfoImpl *op_info = + new (p_opinfo_impl) OpInfoImpl(interfaces_num, + traits_num, + ConcreteOp::attributes_name_, + attributes_num, + ir::TypeId::get(), + ConcreteOp::name(), + dialect); + return op_info; + } + + void destroy() { + VLOG(4) << "Destroy op_info impl at " << this; + // (1) free interfaces + void *base_ptr = reinterpret_cast( + reinterpret_cast(this) - sizeof(ir::TypeId) * num_traits_ - + sizeof(std::pair) * num_interfaces_); + if (num_interfaces_ > 0) { + std::pair *p_first_interface = + reinterpret_cast *>(base_ptr); + for (size_t i = 0; i < num_interfaces_; i++) { + free((p_first_interface + i)->second); + } + } + // (2) free memeory + VLOG(4) << "Free base_ptr " << base_ptr; + free(base_ptr); + } + + /// + /// \brief Search methods for Trait or Interface. + /// + template + bool HasTrait() const { + return HasTrait(TypeId::get()); + } + + bool HasTrait(TypeId trait_id) const { + if (num_traits_ > 0) { + TypeId *p_first_trait = reinterpret_cast( + reinterpret_cast(const_cast(this)) - + sizeof(ir::TypeId) * num_traits_); + return std::binary_search( + p_first_trait, p_first_trait + num_traits_, trait_id); + } + return false; + } + + template + bool HasInterface() const { + return HasInterface(TypeId::get()); + } + + bool HasInterface(TypeId interface_id) const { + if (num_interfaces_ > 0) { + std::pair *p_first_interface = + reinterpret_cast *>( + reinterpret_cast(const_cast(this)) - + sizeof(ir::TypeId) * num_traits_ - + sizeof(std::pair) * num_interfaces_); + return std::binary_search(p_first_interface, + p_first_interface + num_interfaces_, + std::make_pair(interface_id, nullptr), + CompareInterface); + } + return false; + } + + template + typename Interface::Concept *GetInterfaceImpl() const { + if (num_interfaces_ > 0) { + ir::TypeId interface_id = ir::TypeId::get(); + std::pair *p_first_interface = + reinterpret_cast *>( + reinterpret_cast(const_cast(this)) - + sizeof(ir::TypeId) * num_traits_ - + sizeof(std::pair) * num_interfaces_); + size_t left = 0; + size_t right = num_interfaces_; + while (left < right) { + size_t mid = left + (right - left) / 2; + if ((p_first_interface + mid)->first == interface_id) { + return reinterpret_cast( + (p_first_interface + mid)->second); + } else if ((p_first_interface + mid)->first < interface_id) { + left = mid + 1; + } else { + right = mid; + } + } + } + return nullptr; + } + + ir::TypeId id() const { return op_id_; } + + const char *name() const { return op_name_; } + + ir::Dialect *dialect() const { return dialect_; } + + private: + OpInfoImpl(uint32_t num_interfaces, + uint32_t num_traits, + const char **p_attributes, + uint32_t num_attributes, + TypeId op_id, + const char *op_name, + ir::Dialect *dialect) + : num_interfaces_(num_interfaces), + num_traits_(num_traits), + p_attributes_(p_attributes), + num_attributes_(num_attributes), + op_id_(op_id), + op_name_(op_name), + dialect_(dialect) {} + + static bool CompareInterface(const std::pair &a, + const std::pair &b) { + return a.first < b.first; + } + + /// Interface will be recorded by std::pair. + uint32_t num_interfaces_ = 0; + + /// Trait will be recorded by TypeId. + uint32_t num_traits_ = 0; + + /// Attributes array address. + const char **p_attributes_{nullptr}; + + /// The number of attributes for this Op. + uint32_t num_attributes_ = 0; + + /// The TypeId of this Op. + TypeId op_id_; + + /// The name of this Op. + const char *op_name_; + + /// The dialect of this Op belong to. + ir::Dialect *dialect_; +}; + +} // namespace ir diff --git a/paddle/ir/operation.cc b/paddle/ir/operation.cc index e9d727f1b5fb34b77804828dc7553a6ef0ec28ea..9222ee7e1afbb4777174c69d2a5f39eafd18f25a 100644 --- a/paddle/ir/operation.cc +++ b/paddle/ir/operation.cc @@ -21,7 +21,8 @@ namespace ir { // OpInlineResult, Operation, Operand. Operation *Operation::create(const std::vector &inputs, const std::vector &output_types, - ir::DictionaryAttribute attribute) { + ir::DictionaryAttribute attribute, + ir::OpInfo op_info) { // 1. Calculate the required memory size for OpResults + Operation + // OpOperands. uint32_t num_results = output_types.size(); @@ -52,7 +53,7 @@ Operation *Operation::create(const std::vector &inputs, } // 3.2. Construct Operation. Operation *op = - new (base_ptr) Operation(num_results, num_operands, attribute); + new (base_ptr) Operation(num_results, num_operands, attribute, op_info); base_ptr += sizeof(Operation); // 3.3. Construct OpOperands. if ((reinterpret_cast(base_ptr) & 0x7) != 0) { @@ -116,13 +117,15 @@ void Operation::destroy() { Operation::Operation(uint32_t num_results, uint32_t num_operands, - ir::DictionaryAttribute attribute) { + ir::DictionaryAttribute attribute, + ir::OpInfo op_info) { if (!attribute) { throw("unexpected null attribute dictionary"); } num_results_ = num_results; num_operands_ = num_operands; attribute_ = attribute; + op_info_ = op_info; } ir::OpResult Operation::GetResultByIndex(uint32_t index) { diff --git a/paddle/ir/operation.h b/paddle/ir/operation.h index 924dcafb73dfc4da68a67ab3753dee1ab099e986..a51043ad5687e565eda1eb12392729c58d87b872 100644 --- a/paddle/ir/operation.h +++ b/paddle/ir/operation.h @@ -15,10 +15,15 @@ #pragma once #include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/op_info.h" #include "paddle/ir/type.h" #include "paddle/ir/value_impl.h" namespace ir { +template +class OpTraitBase; +template +class OpInterfaceBase; class alignas(8) Operation final { public: @@ -28,7 +33,8 @@ class alignas(8) Operation final { /// static Operation *create(const std::vector &inputs, const std::vector &output_types, - ir::DictionaryAttribute attribute); + ir::DictionaryAttribute attribute, + ir::OpInfo op_info); void destroy(); @@ -36,19 +42,60 @@ class alignas(8) Operation final { std::string print(); - ir::DictionaryAttribute attribute() { return attribute_; } + ir::DictionaryAttribute attribute() const { return attribute_; } - uint32_t num_results() { return num_results_; } + ir::OpInfo op_info() const { return op_info_; } - uint32_t num_operands() { return num_operands_; } + uint32_t num_results() const { return num_results_; } + + uint32_t num_operands() const { return num_operands_; } + + template + T dyn_cast() const { + return CastUtil::call(this); + } + + template + bool HasTrait() const { + return op_info_.HasTrait(); + } + + template + bool HasInterface() const { + return op_info_.HasInterface(); + } private: Operation(uint32_t num_results, uint32_t num_operands, - ir::DictionaryAttribute attribute); + ir::DictionaryAttribute attribute, + ir::OpInfo op_info); + + template + struct CastUtil { + static T call(const Operation *op) { + throw("Can't dyn_cast to T, T should be a Trait or Interface"); + } + }; + template + struct CastUtil, T>::value>::type> { + static T call(const Operation *op) { return T(op); } + }; + template + struct CastUtil, T>::value>::type> { + static T call(const Operation *op) { + return T(op, op->op_info_.impl()->GetInterfaceImpl()); + } + }; ir::DictionaryAttribute attribute_; + ir::OpInfo op_info_; + uint32_t num_results_ = 0; uint32_t num_operands_ = 0; diff --git a/paddle/ir/type_id.h b/paddle/ir/type_id.h index b7a2dcd362d012eb200a3f909eb78f60a43e8cea..33cd09b5d4a785bae0f6dc491a89da03fd6799e1 100644 --- a/paddle/ir/type_id.h +++ b/paddle/ir/type_id.h @@ -45,6 +45,12 @@ class TypeId { return TypeId(&instance); } + TypeId(const TypeId &other) = default; + + TypeId &operator=(const TypeId &other) = default; + + const Storage *storage() const { return storage_; } + /// /// \brief Comparison operations. /// @@ -54,6 +60,9 @@ class TypeId { inline bool operator!=(const TypeId &other) const { return !(*this == other); } + inline bool operator<(const TypeId &other) const { + return storage_ < other.storage_; + } /// /// \brief Enable hashing TypeId instances. diff --git a/paddle/ir/utils.h b/paddle/ir/utils.h index b4dd00281e15982b8c9c658ec03d3cb2e33b9877..f4316a7e57e446c3029c657e10f69a65149b3cd8 100644 --- a/paddle/ir/utils.h +++ b/paddle/ir/utils.h @@ -17,12 +17,107 @@ #include #include #include +#include +#include namespace ir { +/// +/// \brief Equivalent to boost::hash_combine. +/// std::size_t hash_combine(std::size_t lhs, std::size_t rhs); +/// +/// \brief Aligned malloc and free functions. +/// void *aligned_malloc(size_t size, size_t alignment); void aligned_free(void *mem_ptr); +/// +/// \brief Some template methods for manipulating std::tuple. +/// +/// (1) Pop front element from Tuple +template +struct PopFrontT; + +template +struct PopFrontT> { + public: + using Type = std::tuple; +}; + +template +using PopFront = typename PopFrontT::Type; + +/// (2) Push front element to Tuple +template +struct PushFrontT; + +template +struct PushFrontT> { + public: + using Type = std::tuple; +}; + +template +struct PushFrontT, std::tuple> { + public: + using Type = std::tuple; +}; + +template +using PushFront = typename PushFrontT::Type; + +/// (3) IsEmpty +template +struct IsEmpty { + static constexpr bool value = false; +}; + +template <> +struct IsEmpty> { + static constexpr bool value = true; +}; + +/// (4) IfThenElseT +template +struct IfThenElseT { + using Type = TrueT; +}; + +template +struct IfThenElseT { + using Type = FalseT; +}; + +template +using IfThenElse = typename IfThenElseT::Type; + +/// (5) Filter out all types inherited from BaseT from the tuple. +template