From 9d9f0ce5106897a9789ae146e925a13ec8cd3ff8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Fri, 19 May 2023 15:36:38 +0800 Subject: [PATCH] [IR] fine-tune the implementation of ir component. (#53894) --- paddle/ir/builder.h | 49 +++++++ paddle/ir/builtin_op.cc | 6 +- paddle/ir/builtin_op.h | 12 +- paddle/ir/dialect.cc | 18 --- paddle/ir/dialect.h | 57 +++----- paddle/ir/ir_context.cc | 52 +++++-- paddle/ir/ir_context.h | 30 ++-- paddle/ir/op_base.cc | 31 ++++ paddle/ir/op_base.h | 116 ++++++++++++++- paddle/ir/op_info.cc | 153 ++++++++++++++++++++ paddle/ir/op_info.h | 31 +++- paddle/ir/op_info_impl.h | 252 +++++---------------------------- paddle/ir/operation.cc | 14 +- paddle/ir/operation.h | 13 +- paddle/ir/operation_utils.cc | 31 ++++ paddle/ir/operation_utils.h | 83 +++++++++++ paddle/ir/type_id.h | 5 +- test/cpp/ir/ir_op_test.cc | 43 +++--- test/cpp/ir/ir_program_test.cc | 18 ++- 19 files changed, 657 insertions(+), 357 deletions(-) create mode 100644 paddle/ir/builder.h create mode 100644 paddle/ir/op_base.cc create mode 100644 paddle/ir/op_info.cc create mode 100644 paddle/ir/operation_utils.cc create mode 100644 paddle/ir/operation_utils.h diff --git a/paddle/ir/builder.h b/paddle/ir/builder.h new file mode 100644 index 00000000000..a9b67582b7d --- /dev/null +++ b/paddle/ir/builder.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/ir/operation.h" + +namespace ir { +/// +/// \brief Unified interface of the Attribute class. Derivation of all Attribute +/// classes only derives interfaces, not members. +/// +class Builder { + public: + explicit Builder(IrContext *context) : context_(context) {} + explicit Builder(Operation *op) : Builder(op->ir_context()) {} + + /// Create an operation of specific op type at the current insertion point. + template + OpTy create(Args &&...args) { + OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); + OpTy::build(*this, argument, std::forward(args)...); + Operation *op = Operation::create(argument); + return dyn_cast(op); + } + + private: + IrContext *context_; + // The current op list this builder is inserting into. + // After the design of the block data structure is completed, + // this member will be replaced by the block. + std::list *op_list_ = nullptr; + // The insertion point within the list that this builder is inserting before. + std::list::iterator insertPoint; +}; +} // namespace ir diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 528631b0b79..d000d086b0f 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -15,8 +15,10 @@ #include "paddle/ir/builtin_op.h" namespace ir { -const char *GetParameterOp::attributes_name_[] = {"parameter_name"}; +const char *GetParameterOp::attributes_name[attributes_num] = { + "parameter_name"}; -const char *SetParameterOp::attributes_name_[] = {"parameter_name"}; +const char *SetParameterOp::attributes_name[attributes_num] = { + "parameter_name"}; } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index f5c4296394a..ca29867ff4a 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -32,11 +32,11 @@ class GetParameterOp : public ir::Op { public: using Op::Op; - static const char* name() { return "GetParameterOp"; } + static const char* name() { return "builtin.get_parameter"; } - static uint32_t attributes_num() { return 1; } + static constexpr uint32_t attributes_num = 1; - static const char* attributes_name_[]; + static const char* attributes_name[attributes_num]; }; /// @@ -47,11 +47,11 @@ class SetParameterOp : public ir::Op { public: using Op::Op; - static const char* name() { return "SetParameterOp"; } + static const char* name() { return "builtin.set_parameter"; } - static uint32_t attributes_num() { return 1; } + static constexpr uint32_t attributes_num = 1; - static const char* attributes_name_[]; + static const char* attributes_name[attributes_num]; }; } // namespace ir diff --git a/paddle/ir/dialect.cc b/paddle/ir/dialect.cc index 8d52f3bbc9d..8764daf861a 100644 --- a/paddle/ir/dialect.cc +++ b/paddle/ir/dialect.cc @@ -20,24 +20,6 @@ Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id) Dialect::~Dialect() = default; -void Dialect::RegisterType(ir::AbstractType &&abstract_type) { - ir::AbstractType *new_abstract_type = - new ir::AbstractType(std::move(abstract_type)); - this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(), - new_abstract_type); -} - -void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) { - ir::AbstractAttribute *new_abstract_attribute = - new ir::AbstractAttribute(std::move(abstract_attribute)); - this->ir_context()->RegisterAbstractAttribute( - new_abstract_attribute->type_id(), new_abstract_attribute); -} - -void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) { - this->ir_context()->RegisterOpInfo(name, op_info); -} - void Dialect::RegisterInterface(std::unique_ptr interface) { VLOG(4) << "Register interface into dialect" << std::endl; auto it = registered_interfaces_.emplace(interface->interface_id(), diff --git a/paddle/ir/dialect.h b/paddle/ir/dialect.h index 9a6b42dad5a..9fdb931f733 100644 --- a/paddle/ir/dialect.h +++ b/paddle/ir/dialect.h @@ -17,7 +17,7 @@ #include "paddle/ir/attribute_base.h" #include "paddle/ir/dialect_interface.h" #include "paddle/ir/ir_context.h" -#include "paddle/ir/op_info_impl.h" +#include "paddle/ir/op_base.h" #include "paddle/ir/type_base.h" namespace ir { @@ -52,27 +52,11 @@ class Dialect { template void RegisterType() { - VLOG(4) << "Type registered into Dialect. --->"; - if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get()) == - nullptr) { - ir::AbstractType *abstract_type = - new ir::AbstractType(std::move(ir::AbstractType::get(*this))); - this->ir_context()->RegisterAbstractType(ir::TypeId::get(), - abstract_type); - ir::TypeManager::RegisterType(this->ir_context()); - } - VLOG(4) << "----------------------------------"; + ir_context()->RegisterAbstractType(TypeId::get(), + AbstractType::get(*this)); + TypeManager::RegisterType(ir_context()); } - /// - /// \brief Register abstract_type into context. - /// NOTE: It's not recommended to use this interface directly. This interface - /// only registers abstract_type. To register TypeStorage into context, you - /// need to call ir::TypeManager::RegisterType() additionally, - /// RegisterType() is recommended to use. - /// - void RegisterType(ir::AbstractType &&abstract_type); - /// /// \brief Register all attributes contained in the template parameter Args. /// To register only one Attribute, you can use the RegisterAttribute template @@ -85,37 +69,28 @@ class Dialect { template void RegisterAttribute() { - VLOG(4) << "Attribute registered into Dialect. --->"; - if (this->ir_context()->GetRegisteredAbstractAttribute( - ir::TypeId::get()) == nullptr) { - ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute( - std::move(ir::AbstractAttribute::get(*this))); - this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get(), - abstract_attribute); - ir::AttributeManager::RegisterAttribute(this->ir_context()); - } - VLOG(4) << "----------------------------------"; + ir_context()->RegisterAbstractAttribute(TypeId::get(), + AbstractAttribute::get(*this)); + AttributeManager::RegisterAttribute(ir_context()); } - void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute); - /// - /// \brief Register Operation methods. + /// \brief Register Ops. /// template void RegisterOps() { (void)std::initializer_list{0, (RegisterOp(), 0)...}; } - template + template void RegisterOp() { - std::string name = this->name() + "." + std::string(ConcertOp::name()); - VLOG(4) << "Op " << name << " registered into Dialect. --->"; - if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) { - ir::OpInfoImpl *op_info = ir::OpInfoImpl::create(this); - this->ir_context()->RegisterOpInfo(name, op_info); - } - VLOG(4) << "----------------------------------"; + ir_context()->RegisterOpInfo(this, + TypeId::get(), + ConcreteOp::name(), + ConcreteOp::GetInterfaceMap(), + ConcreteOp::GetTraitSet(), + ConcreteOp::attributes_num, + ConcreteOp::attributes_name); } void RegisterOp(const std::string &name, OpInfoImpl *op_info); diff --git a/paddle/ir/ir_context.cc b/paddle/ir/ir_context.cc index d172b748cdd..03907599690 100644 --- a/paddle/ir/ir_context.cc +++ b/paddle/ir/ir_context.cc @@ -185,11 +185,6 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { impl_->int64_type = TypeManager::get(this); } -void IrContext::RegisterAbstractType(ir::TypeId type_id, - AbstractType *abstract_type) { - impl().RegisterAbstractType(type_id, abstract_type); -} - StorageManager &IrContext::type_storage_manager() { return impl().registed_type_storage_manager_; } @@ -203,8 +198,14 @@ AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) { } void IrContext::RegisterAbstractAttribute( - ir::TypeId type_id, AbstractAttribute *abstract_attribute) { - impl().RegisterAbstractAttribute(type_id, abstract_attribute); + ir::TypeId type_id, AbstractAttribute &&abstract_attribute) { + if (GetRegisteredAbstractAttribute(type_id) == nullptr) { + impl().RegisterAbstractAttribute( + type_id, new AbstractAttribute(std::move(abstract_attribute))); + VLOG(4) << "<--- Attribute registered into IrContext. --->"; + } else { + LOG(WARNING) << " Attribute already registered."; + } } StorageManager &IrContext::attribute_storage_manager() { @@ -251,17 +252,44 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { return nullptr; } -OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) { - OpInfoImpl *rtn = impl().GetOpInfo(name); - return rtn ? rtn : nullptr; +void IrContext::RegisterAbstractType(ir::TypeId type_id, + AbstractType &&abstract_type) { + if (GetRegisteredAbstractType(type_id) == nullptr) { + impl().RegisterAbstractType(type_id, + new AbstractType(std::move(abstract_type))); + VLOG(4) << "<--- Type registered into IrContext. --->"; + } else { + LOG(WARNING) << " type already registered."; + } } -void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { - if (impl().GetOpInfo(name) == nullptr) { +void IrContext::RegisterOpInfo(Dialect *dialect, + TypeId op_id, + const char *name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char **attributes_name) { + if (GetRegisteredOpInfo(name) == nullptr) { + OpInfoImpl *opinfo = OpInfoImpl::create(dialect, + op_id, + name, + std::move(interface_map), + trait_set, + attributes_num, + attributes_name); impl().RegisterOpInfo(name, opinfo); + VLOG(4) << "Op " << name << " registered into IrContext. --->"; + } else { + LOG(WARNING) << name << " op already registered."; } } +OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) { + OpInfoImpl *rtn = impl().GetOpInfo(name); + return rtn ? rtn : nullptr; +} + const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { auto &impl = ctx->impl(); AbstractType *abstract_type = impl.GetAbstractType(type_id); diff --git a/paddle/ir/ir_context.h b/paddle/ir/ir_context.h index efa03dd5695..c3b8c5b34bb 100644 --- a/paddle/ir/ir_context.h +++ b/paddle/ir/ir_context.h @@ -13,11 +13,9 @@ // limitations under the License. #pragma once - -#include #include #include -#include +#include namespace ir { class IrContextImpl; @@ -26,8 +24,8 @@ class AbstractType; class AbstractAttribute; class TypeId; class Dialect; -class OpInfoImpl; - +class OpInfo; +class InterfaceValue; /// /// \brief IrContext is a global parameterless class used to store and manage /// Type, Attribute and other related data structures. @@ -53,7 +51,7 @@ class IrContext { /// \param type_id The type id of the AbstractType. /// \param abstract_type AbstractType* provided by user. /// - void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type); + void RegisterAbstractType(TypeId type_id, AbstractType &&abstract_type); /// /// \brief Returns the storage uniquer used for constructing TypeStorage @@ -73,10 +71,10 @@ class IrContext { /// \brief Register an AbstractAttribute to IrContext /// /// \param type_id The type id of the AbstractAttribute. - /// \param abstract_attribute AbstractAttribute* provided by user. + /// \param abstract_attribute AbstractAttribute provided by user. /// void RegisterAbstractAttribute(ir::TypeId type_id, - AbstractAttribute *abstract_attribute); + AbstractAttribute &&abstract_attribute); /// /// \brief Returns the storage uniquer used for constructing AttributeStorage @@ -93,11 +91,20 @@ class IrContext { AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id); /// - /// \brief Get or register operaiton. + /// \brief Register an op infomation to IrContext /// - void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo); + void RegisterOpInfo(Dialect *dialect, + TypeId op_id, + const char *name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char **attributes_name); - OpInfoImpl *GetRegisteredOpInfo(const std::string &name); + /// + /// \brief Get registered operaiton infomation. + /// + OpInfo GetRegisteredOpInfo(const std::string &name); /// /// \brief Get the dialect of the DialectT class in the context, ff not found, @@ -162,7 +169,6 @@ class IrContext { private: IrContext(); - const std::unique_ptr impl_; }; diff --git a/paddle/ir/op_base.cc b/paddle/ir/op_base.cc new file mode 100644 index 00000000000..30e5a68e933 --- /dev/null +++ b/paddle/ir/op_base.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/op_base.h" +namespace ir { +InterfaceValue::~InterfaceValue() { + if (model_) free(model_); +} + +InterfaceValue::InterfaceValue(InterfaceValue&& val) { + type_id_ = val.type_id_; + model_ = val.model_; + val.model_ = nullptr; +} + +InterfaceValue& InterfaceValue::operator=(InterfaceValue&& val) { + swap(std::move(val)); + return *this; +} +} // namespace ir diff --git a/paddle/ir/op_base.h b/paddle/ir/op_base.h index 23318f5e69c..9a42afa0c67 100644 --- a/paddle/ir/op_base.h +++ b/paddle/ir/op_base.h @@ -13,11 +13,57 @@ // limitations under the License. #pragma once +#include #include "paddle/ir/operation.h" #include "paddle/ir/utils.h" namespace ir { + +class InterfaceValue { + public: + template + static InterfaceValue get() { + InterfaceValue val; + val.type_id_ = TypeId::get(); + val.model_ = malloc(sizeof(typename T::template Model)); + if (val.model_ == nullptr) { + throw("Alloc memory for interface failed."); + } + static_assert(std::is_trivially_destructible< + typename T::template Model>::value, + "interface models must be trivially destructible"); + new (val.model_) typename T::template Model(); + return val; + } + TypeId type_id() const { return type_id_; } + void *model() const { return model_; } + + InterfaceValue() = default; + explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {} + InterfaceValue(const InterfaceValue &) = delete; + InterfaceValue(InterfaceValue &&); + InterfaceValue &operator=(const InterfaceValue &) = delete; + InterfaceValue &operator=(InterfaceValue &&); + ~InterfaceValue(); + void swap(InterfaceValue &&val) { + using std::swap; + swap(type_id_, val.type_id_); + swap(model_, val.model_); + } + + /// + /// \brief Comparison operations. + /// + inline bool operator<(const InterfaceValue &other) const { + return type_id_ < other.type_id_; + } + + private: + TypeId type_id_; + void *model_{nullptr}; +}; + class OpBase { public: explicit OpBase(const Operation *operation) : operation_(operation) {} @@ -58,6 +104,59 @@ class OpInterfaceBase : public OpBase { static TypeId GetInterfaceId() { return TypeId::get(); } }; +template +class ConstructInterfacesOrTraits { + public: + /// Construct method for interfaces. + static InterfaceValue *interface(InterfaceValue *p_interface) { + (void)std::initializer_list{ + 0, (PlacementConstrctInterface(p_interface), 0)...}; + return p_interface; + } + + /// Construct method for traits. + static TypeId *trait(TypeId *p_trait) { + (void)std::initializer_list{ + 0, (PlacementConstrctTrait(p_trait), 0)...}; + return p_trait; + } + + private: + /// Placement new interface. + template + static void PlacementConstrctInterface( + InterfaceValue *&p_interface) { // NOLINT + p_interface->swap(InterfaceValue::get()); + VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage() + << "]."; + ++p_interface; + } + + /// Placement new trait. + template + static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT + *p_trait = TypeId::get(); + VLOG(4) << "New a trait: id[" << p_trait->storage() << "]."; + ++p_trait; + } +}; + +/// Specialized for tuple type. +template +class ConstructInterfacesOrTraits> { + public: + /// Construct method for interfaces. + static InterfaceValue *interface(InterfaceValue *p_interface) { + return ConstructInterfacesOrTraits::interface( + p_interface); + } + + /// Construct method for traits. + static TypeId *trait(TypeId *p_trait) { + return ConstructInterfacesOrTraits::trait(p_trait); + } +}; + template class Op : public OpBase { public: @@ -68,6 +167,21 @@ class Op : public OpBase { using InterfaceList = typename Filter>::Type; -}; + static std::vector GetInterfaceMap() { + constexpr size_t interfaces_num = std::tuple_size::value; + std::vector interfaces_map(interfaces_num); + ConstructInterfacesOrTraits::interface( + interfaces_map.data()); + return interfaces_map; + } + + static std::vector GetTraitSet() { + constexpr size_t traits_num = std::tuple_size::value; + std::vector trait_set(traits_num); + auto p_first_trait = trait_set.data(); + ConstructInterfacesOrTraits::trait(p_first_trait); + return trait_set; + } +}; } // namespace ir diff --git a/paddle/ir/op_info.cc b/paddle/ir/op_info.cc new file mode 100644 index 00000000000..e68839f937d --- /dev/null +++ b/paddle/ir/op_info.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/op_info.h" +#include "paddle/ir/dialect.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/op_info_impl.h" + +namespace ir { +bool OpInfo::HasTrait(TypeId trait_id) const { + return impl_ && impl_->HasTrait(trait_id); +} + +bool OpInfo::HasInterface(TypeId interface_id) const { + return impl_ && impl_->HasInterface(interface_id); +} + +IrContext *OpInfo::ir_context() const { + return impl_ ? impl_->ir_context() : nullptr; +} + +const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } + +void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { + return impl_ ? impl_->interface_impl(interface_id) : nullptr; +} + +ir::IrContext *OpInfoImpl::ir_context() const { + return dialect()->ir_context(); +} + +void *OpInfoImpl::interface_impl(TypeId interface_id) const { + if (num_interfaces_ > 0) { + const InterfaceValue *p_first_interface = + reinterpret_cast( + reinterpret_cast(this) - + sizeof(TypeId) * num_traits_ - + sizeof(InterfaceValue) * num_interfaces_); + size_t left = 0, right = num_interfaces_; + while (left < right) { + size_t mid = (left + right) / 2; + if ((p_first_interface + mid)->type_id() == interface_id) { + return (p_first_interface + mid)->model(); + } else if ((p_first_interface + mid)->type_id() < interface_id) { + left = mid + 1; + } else { + right = mid; + } + } + } + return nullptr; +} +bool OpInfoImpl::HasTrait(TypeId trait_id) const { + if (num_traits_ > 0) { + const TypeId *p_first_trait = + reinterpret_cast(reinterpret_cast(this) - + sizeof(ir::TypeId) * num_traits_); + return std::binary_search( + p_first_trait, p_first_trait + num_traits_, trait_id); + } + return false; +} + +bool OpInfoImpl::HasInterface(TypeId interface_id) const { + if (num_interfaces_ > 0) { + const InterfaceValue *p_first_interface = + reinterpret_cast( + reinterpret_cast(this) - + sizeof(ir::TypeId) * num_traits_ - + sizeof(InterfaceValue) * num_interfaces_); + return std::binary_search(p_first_interface, + p_first_interface + num_interfaces_, + InterfaceValue(interface_id)); + } + return false; +} + +OpInfoImpl *OpInfoImpl::create(Dialect *dialect, + TypeId op_id, + const char *op_name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char *attributes_name[]) { + // (1) Malloc memory for interfaces, traits, opinfo_impl. + size_t interfaces_num = interface_map.size(); + size_t traits_num = trait_set.size(); + VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " + << traits_num << " traits, " << attributes_num << " attributes."; + size_t base_size = sizeof(InterfaceValue) * interfaces_num + + sizeof(TypeId) * traits_num + sizeof(OpInfoImpl); + char *base_ptr = static_cast(::operator new(base_size)); + VLOG(4) << "Malloc " << base_size << " Bytes at " + << static_cast(base_ptr); + if (interfaces_num > 0) { + std::sort(interface_map.begin(), interface_map.end()); + for (size_t index = 0; index < interfaces_num; ++index) { + new (base_ptr + index * sizeof(InterfaceValue)) + InterfaceValue(std::move(interface_map[index])); + } + base_ptr += interfaces_num * sizeof(InterfaceValue); + } + if (traits_num > 0) { + auto p_first_trait = reinterpret_cast(base_ptr); + memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num); + std::sort(p_first_trait, p_first_trait + traits_num); + base_ptr += traits_num * sizeof(TypeId); + } + // Construct opinfo_impl. + OpInfoImpl *p_opinfo_impl = reinterpret_cast(base_ptr); + VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......"; + OpInfoImpl *op_info = new (p_opinfo_impl) OpInfoImpl(dialect, + op_id, + op_name, + interfaces_num, + traits_num, + attributes_num, + attributes_name + + ); + return op_info; +} + +void OpInfoImpl::destroy() { + VLOG(4) << "Destroy op_info impl at " << this; + // (1) free interfaces + char *base_ptr = reinterpret_cast(this) - + sizeof(ir::TypeId) * num_traits_ - + sizeof(InterfaceValue) * num_interfaces_; + if (num_interfaces_ > 0) { + InterfaceValue *p_interface_val = + reinterpret_cast(base_ptr); + for (size_t i = 0; i < num_interfaces_; i++) { + (p_interface_val + i)->~InterfaceValue(); + } + } + // (2) free memeory + VLOG(4) << "Free base_ptr " << base_ptr; + free(base_ptr); +} + +} // namespace ir diff --git a/paddle/ir/op_info.h b/paddle/ir/op_info.h index 2828f59c294..14526c091cd 100644 --- a/paddle/ir/op_info.h +++ b/paddle/ir/op_info.h @@ -13,12 +13,13 @@ // limitations under the License. #pragma once - #include - -#include "paddle/ir/op_info_impl.h" +#include "paddle/ir/type_id.h" namespace ir { +class OpInfoImpl; +class IrContext; + class OpInfo { public: constexpr OpInfo() = default; @@ -37,24 +38,42 @@ class OpInfo { bool operator!() const { return impl_ == nullptr; } - const OpInfoImpl *impl() const { return impl_; } + IrContext *ir_context() const; + + const char *name() const; template bool HasTrait() const { - return impl_->HasTrait(); + return HasTrait(TypeId::get()); } + bool HasTrait(TypeId trait_id) const; + template bool HasInterface() const { - return impl_->HasInterface(); + return HasInterface(TypeId::get()); } + bool HasInterface(TypeId interface_id) const; + + template + typename Interface::Concept *GetInterfaceImpl() const; + friend struct std::hash; + private: + void *GetInterfaceImpl(TypeId interface_id) const; + private: const OpInfoImpl *impl_{nullptr}; // not owned }; +template +typename Interface::Concept *OpInfo::GetInterfaceImpl() const { + void *model = GetInterfaceImpl(TypeId::get()); + return reinterpret_cast(model); +} + } // namespace ir namespace std { diff --git a/paddle/ir/op_info_impl.h b/paddle/ir/op_info_impl.h index 4380866bceb..4c7a1d361f0 100644 --- a/paddle/ir/op_info_impl.h +++ b/paddle/ir/op_info_impl.h @@ -15,77 +15,16 @@ #pragma once #include -#include #include +#include #include #include "paddle/ir/builtin_attribute.h" -// #include "paddle/ir/ir_context.h" +#include "paddle/ir/op_base.h" #include "paddle/ir/type.h" namespace ir { class Dialect; -/// -/// \brief Tool template class for construct interfaces or Traits. -/// -template -class ConstructInterfacesOrTraits { - public: - /// Construct method for interfaces. - static std::pair *interface( - std::pair *p_interface) { - (void)std::initializer_list{ - 0, (PlacementConstrctInterface(p_interface), 0)...}; - return p_interface; - } - - /// Construct method for traits. - static TypeId *trait(TypeId *p_trait) { - (void)std::initializer_list{ - 0, (PlacementConstrctTrait(p_trait), 0)...}; - return p_trait; - } - - private: - /// Placement new interface. - template - static void PlacementConstrctInterface( - std::pair *&p_interface) { // NOLINT - new (&(p_interface->first)) TypeId(ir::TypeId::get()); - p_interface->second = - malloc(sizeof(typename T::template Model)); - new (p_interface->second) typename T::template Model(); - VLOG(4) << "New a interface: id[" << p_interface->first.storage() - << "], interface[" << p_interface->second << "]."; - ++p_interface; - } - - /// Placement new trait. - template - static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT - new (p_trait) TypeId(ir::TypeId::get()); - VLOG(4) << "New a trait: id[" << (*p_trait).storage() << "]."; - ++p_trait; - } -}; - -/// Specialized for tuple type. -template -class ConstructInterfacesOrTraits> { - public: - /// Construct method for interfaces. - static std::pair *interface( - std::pair *p_interface) { - return ConstructInterfacesOrTraits::interface( - p_interface); - } - - /// Construct method for traits. - static TypeId *trait(TypeId *p_trait) { - return ConstructInterfacesOrTraits::trait(p_trait); - } -}; - /// /// \brief OpInfoImpl class. /// @@ -95,143 +34,27 @@ class OpInfoImpl { /// \brief Construct and Deconstruct OpInfoImpl. The memory layout of /// OpInfoImpl is: std::pair... | TypeId... | OpInfoImpl /// - template - static OpInfoImpl *create(ir::Dialect *dialect) { - // (1) Malloc memory for interfaces, traits, opinfo_impl. - size_t interfaces_num = - std::tuple_size::value; - size_t traits_num = std::tuple_size::value; - size_t attributes_num = ConcreteOp::attributes_num(); - VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " - << traits_num << " traits, " << attributes_num << " attributes."; - size_t base_size = sizeof(std::pair) * interfaces_num + - sizeof(ir::TypeId) * traits_num + sizeof(OpInfoImpl); - void *base_ptr = malloc(base_size); - VLOG(4) << "Malloc " << base_size << " Bytes at " << base_ptr; - - // (2) Construct interfaces and sort by TypeId. - std::pair *p_first_interface = nullptr; - if (interfaces_num > 0) { - p_first_interface = - reinterpret_cast *>(base_ptr); - VLOG(4) << "Construct interfaces at " << p_first_interface << " ......"; - ConstructInterfacesOrTraits< - ConcreteOp, - typename ConcreteOp::InterfaceList>::interface(p_first_interface); - std::sort(p_first_interface, p_first_interface + interfaces_num); - base_ptr = reinterpret_cast(p_first_interface + interfaces_num); - } + static OpInfoImpl *create(Dialect *dialect, + TypeId op_id, + const char *op_name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char *attributes_name[]); - // (3) Construct traits and sort by TypeId. - ir::TypeId *p_first_trait = nullptr; - if (traits_num > 0) { - p_first_trait = reinterpret_cast(base_ptr); - VLOG(4) << "Construct traits at " << p_first_trait << " ......"; - ConstructInterfacesOrTraits:: - trait(p_first_trait); - std::sort(p_first_trait, p_first_trait + traits_num); - base_ptr = reinterpret_cast(p_first_trait + traits_num); - } + void destroy(); - // (4) Construct opinfo_impl. - OpInfoImpl *p_opinfo_impl = reinterpret_cast(base_ptr); - VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......"; - OpInfoImpl *op_info = - new (p_opinfo_impl) OpInfoImpl(interfaces_num, - traits_num, - ConcreteOp::attributes_name_, - attributes_num, - ir::TypeId::get(), - ConcreteOp::name(), - dialect); - return op_info; - } - - void destroy() { - VLOG(4) << "Destroy op_info impl at " << this; - // (1) free interfaces - void *base_ptr = reinterpret_cast( - reinterpret_cast(this) - sizeof(ir::TypeId) * num_traits_ - - sizeof(std::pair) * num_interfaces_); - if (num_interfaces_ > 0) { - std::pair *p_first_interface = - reinterpret_cast *>(base_ptr); - for (size_t i = 0; i < num_interfaces_; i++) { - free((p_first_interface + i)->second); - } - } - // (2) free memeory - VLOG(4) << "Free base_ptr " << base_ptr; - free(base_ptr); - } + ir::IrContext *ir_context() const; - /// /// \brief Search methods for Trait or Interface. - /// - template - bool HasTrait() const { - return HasTrait(TypeId::get()); - } - - bool HasTrait(TypeId trait_id) const { - if (num_traits_ > 0) { - TypeId *p_first_trait = reinterpret_cast( - reinterpret_cast(const_cast(this)) - - sizeof(ir::TypeId) * num_traits_); - return std::binary_search( - p_first_trait, p_first_trait + num_traits_, trait_id); - } - return false; - } - - template - bool HasInterface() const { - return HasInterface(TypeId::get()); - } - - bool HasInterface(TypeId interface_id) const { - if (num_interfaces_ > 0) { - std::pair *p_first_interface = - reinterpret_cast *>( - reinterpret_cast(const_cast(this)) - - sizeof(ir::TypeId) * num_traits_ - - sizeof(std::pair) * num_interfaces_); - return std::binary_search(p_first_interface, - p_first_interface + num_interfaces_, - std::make_pair(interface_id, nullptr), - CompareInterface); - } - return false; - } + bool HasTrait(TypeId trait_id) const; - template - typename Interface::Concept *GetInterfaceImpl() const { - if (num_interfaces_ > 0) { - ir::TypeId interface_id = ir::TypeId::get(); - std::pair *p_first_interface = - reinterpret_cast *>( - reinterpret_cast(const_cast(this)) - - sizeof(ir::TypeId) * num_traits_ - - sizeof(std::pair) * num_interfaces_); - size_t left = 0; - size_t right = num_interfaces_; - while (left < right) { - size_t mid = left + (right - left) / 2; - if ((p_first_interface + mid)->first == interface_id) { - return reinterpret_cast( - (p_first_interface + mid)->second); - } else if ((p_first_interface + mid)->first < interface_id) { - left = mid + 1; - } else { - right = mid; - } - } - } - return nullptr; - } + bool HasInterface(TypeId interface_id) const; ir::TypeId id() const { return op_id_; } + void *interface_impl(TypeId interface_id) const; + const char *name() const { return op_name_; } ir::Dialect *dialect() const { return dialect_; } @@ -243,25 +66,29 @@ class OpInfoImpl { } private: - OpInfoImpl(uint32_t num_interfaces, - uint32_t num_traits, - const char **p_attributes, - uint32_t num_attributes, + OpInfoImpl(ir::Dialect *dialect, TypeId op_id, const char *op_name, - ir::Dialect *dialect) - : num_interfaces_(num_interfaces), - num_traits_(num_traits), - p_attributes_(p_attributes), - num_attributes_(num_attributes), + uint32_t num_interfaces, + uint32_t num_traits, + uint32_t num_attributes, + const char **p_attributes) + : dialect_(dialect), op_id_(op_id), op_name_(op_name), - dialect_(dialect) {} + num_interfaces_(num_interfaces), + num_traits_(num_traits), + num_attributes_(num_attributes), + p_attributes_(p_attributes) {} - static bool CompareInterface(const std::pair &a, - const std::pair &b) { - return a.first < b.first; - } + /// The dialect of this Op belong to. + ir::Dialect *dialect_; + + /// The TypeId of this Op. + TypeId op_id_; + + /// The name of this Op. + const char *op_name_; /// Interface will be recorded by std::pair. uint32_t num_interfaces_ = 0; @@ -269,20 +96,11 @@ class OpInfoImpl { /// Trait will be recorded by TypeId. uint32_t num_traits_ = 0; - /// Attributes array address. - const char **p_attributes_{nullptr}; - /// The number of attributes for this Op. uint32_t num_attributes_ = 0; - /// The TypeId of this Op. - TypeId op_id_; - - /// The name of this Op. - const char *op_name_; - - /// The dialect of this Op belong to. - ir::Dialect *dialect_; + /// Attributes array address. + const char **p_attributes_{nullptr}; }; } // namespace ir diff --git a/paddle/ir/operation.cc b/paddle/ir/operation.cc index 3cfe2b048e4..357747a7dbb 100644 --- a/paddle/ir/operation.cc +++ b/paddle/ir/operation.cc @@ -18,6 +18,13 @@ #include "paddle/ir/utils.h" namespace ir { +Operation *Operation::create(const OperationArgument &argument) { + return create(argument.inputs_, + argument.output_types_, + argument.attribute_, + argument.info_); +} + // Allocate the required memory based on the size and number of inputs, outputs, // and operators, and construct it in the order of: OpOutlineResult, // OpInlineResult, Operation, Operand. @@ -126,6 +133,8 @@ void Operation::destroy() { aligned_free(reinterpret_cast(aligned_ptr)); } +IrContext *Operation::ir_context() const { return op_info_.ir_context(); } + Operation::Operation(uint32_t num_results, uint32_t num_operands, const AttributeMap &attribute, @@ -190,9 +199,6 @@ std::string Operation::print() { return result.str(); } -std::string Operation::op_name() const { - return op_info_.impl()->dialect()->name() + "." + - std::string(op_info_.impl()->name()); -} +std::string Operation::op_name() const { return op_info_.name(); } } // namespace ir diff --git a/paddle/ir/operation.h b/paddle/ir/operation.h index 9730244851f..c2346bf180a 100644 --- a/paddle/ir/operation.h +++ b/paddle/ir/operation.h @@ -16,6 +16,7 @@ #include "paddle/ir/builtin_attribute.h" #include "paddle/ir/op_info.h" +#include "paddle/ir/operation_utils.h" #include "paddle/ir/type.h" #include "paddle/ir/value_impl.h" @@ -26,8 +27,6 @@ template class OpInterfaceBase; class Program; -using AttributeMap = std::unordered_map; - class alignas(8) Operation final { public: /// @@ -40,12 +39,15 @@ class alignas(8) Operation final { const std::vector &output_types, const AttributeMap &attribute, ir::OpInfo op_info); + static Operation *create(const OperationArgument &op_argument); /// /// \brief Destroy the operation objects and free memeory by create(). /// void destroy(); + IrContext *ir_context() const; + ir::OpResult GetResultByIndex(uint32_t index); ir::OpOperand GetOperandByIndex(uint32_t index); @@ -99,14 +101,17 @@ class alignas(8) Operation final { struct CastUtil, T>::value>::type> { - static T call(const Operation *op) { return T(op); } + static T call(const Operation *op) { + return T(op->HasTrait() ? op : nullptr); + } }; template struct CastUtil, T>::value>::type> { static T call(const Operation *op) { - return T(op, op->op_info_.impl()->GetInterfaceImpl()); + typename T::Concept *interface_impl = op->op_info().GetInterfaceImpl(); + return interface_impl ? T(op, interface_impl) : T(nullptr, nullptr); } }; diff --git a/paddle/ir/operation_utils.cc b/paddle/ir/operation_utils.cc new file mode 100644 index 00000000000..07b217e8c05 --- /dev/null +++ b/paddle/ir/operation_utils.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/operation_utils.h" + +namespace ir { +OperationArgument::OperationArgument(IrContext* ir_context, std::string name) { + info_ = ir_context->GetRegisteredOpInfo(name); +} + +OperationArgument::OperationArgument(OpInfo info, + const std::vector& operands, + const std::vector& types, + const AttributeMap& named_attr) + : info_(info), + inputs_(operands), + output_types_(types), + attribute_(named_attr) {} + +} // namespace ir diff --git a/paddle/ir/operation_utils.h b/paddle/ir/operation_utils.h new file mode 100644 index 00000000000..26f831ee403 --- /dev/null +++ b/paddle/ir/operation_utils.h @@ -0,0 +1,83 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/op_info.h" +#include "paddle/ir/type.h" +#include "paddle/ir/value_impl.h" + +namespace ir { + +using AttributeMap = std::unordered_map; + +//===----------------------------------------------------------------------===// +// OperationArgument +//===----------------------------------------------------------------------===// + +// This represents an operation arguments in an combined form, suitable for use +// with the builder APIs. +struct OperationArgument { + OpInfo info_; + std::vector inputs_; + std::vector output_types_; + AttributeMap attribute_; + + public: + OperationArgument(IrContext* ir_context, std::string name); + explicit OperationArgument(OpInfo info) : info_(info) {} + OperationArgument(OpInfo info, + const std::vector& operands, + const std::vector& types, + const AttributeMap& named_attr = {}); + + template + void addOperands(InputIt first, InputIt last); + + template + void addTypes(InputIt first, InputIt last); + + /// Add an attribute with the specified name. + void addAttribute(const std::string& name, Attribute attr) { + attribute_[name] = attr; + } + /// Add an array of named attributes. + template + void addAttributes(InputIt first, InputIt last); + /// Get the context held by this operation state. + IrContext* getContext() const { return info_.ir_context(); } +}; + +template +void OperationArgument::addOperands(InputIt first, InputIt last) { + while (first != last) { + inputs_.emplace_back(*first++); + } +} +template +void OperationArgument::addTypes(InputIt first, InputIt last) { + while (first != last) { + output_types_.emplace_back(*first++); + } +} +template +void OperationArgument::addAttributes(InputIt first, InputIt last) { + while (first != last) { + attribute_[first->first] = first->second; + ++first; + } +} + +} // namespace ir diff --git a/paddle/ir/type_id.h b/paddle/ir/type_id.h index 33cd09b5d4a..736152b4ff6 100644 --- a/paddle/ir/type_id.h +++ b/paddle/ir/type_id.h @@ -45,6 +45,8 @@ class TypeId { return TypeId(&instance); } + TypeId() = default; + TypeId(const TypeId &other) = default; TypeId &operator=(const TypeId &other) = default; @@ -77,9 +79,8 @@ class TypeId { /// explicit TypeId(const Storage *storage) : storage_(storage) {} - const Storage *storage_; + const Storage *storage_{nullptr}; }; - } // namespace ir namespace std { diff --git a/test/cpp/ir/ir_op_test.cc b/test/cpp/ir/ir_op_test.cc index f6fbff5dd3b..fc23166fa41 100644 --- a/test/cpp/ir/ir_op_test.cc +++ b/test/cpp/ir/ir_op_test.cc @@ -47,9 +47,8 @@ class InferShapeInterface : public ir::OpInterfaceBase { } Model() : Concept(InferShape) { - if (sizeof(Model) != sizeof(Concept)) { - throw("sizeof(Model) != sizeof(Concept)"); - } + static_assert(sizeof(Model) == sizeof(Concept), + "sizeof(Model) != sizeof(Concept)"); } }; @@ -66,25 +65,27 @@ class InferShapeInterface : public ir::OpInterfaceBase { class Operation1 : public ir::Op { public: using Op::Op; - static const char *name() { return "Operation1"; } - static const char *attributes_name_[]; - static uint32_t attributes_num() { return 2; } + static const char *name() { return "test.operation1"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; }; -const char *Operation1::attributes_name_[] = {"op1_attr1", "op1_attr2"}; +const char *Operation1::attributes_name[attributes_num] = {"op1_attr1", + "op1_attr2"}; // Define op2. class Operation2 : public ir::Op { public: using Op::Op; - static const char *name() { return "Operation2"; } - static const char *attributes_name_[]; - static uint32_t attributes_num() { return 2; } + static const char *name() { return "test.operation2"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; static void InferShape() { std::cout << "This is op2's InferShape interface." << std::endl; } }; -const char *Operation2::attributes_name_[] = {"op2_attr1", "op2_attr2"}; +const char *Operation2::attributes_name[attributes_num] = {"op2_attr1", + "op2_attr2"}; // Define a dialect, op1 and op2 will be registered by this dialect. class TestDialect : public ir::Dialect { @@ -93,7 +94,7 @@ class TestDialect : public ir::Dialect { : ir::Dialect(name(), context, ir::TypeId::get()) { initialize(); } - static const char *name() { return "op_test"; } + static const char *name() { return "test"; } private: void initialize() { RegisterOps(); } @@ -116,19 +117,17 @@ TEST(op_test, op_test) { std::cout << test_dialect << std::endl; // (2) Get registered operations. - std::string op1_name = - test_dialect->name() + "." + std::string(Operation1::name()); - ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name); + std::string op1_name = Operation1::name(); + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); EXPECT_EQ(op1_info != nullptr, true); - std::string op2_name = - test_dialect->name() + "." + std::string(Operation2::name()); - ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name); + std::string op2_name = Operation2::name(); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); EXPECT_EQ(op2_info != nullptr, true); - EXPECT_EQ(op1_info->HasTrait(), false); - EXPECT_EQ(op1_info->HasInterface(), false); - EXPECT_EQ(op2_info->HasTrait(), true); - EXPECT_EQ(op2_info->HasInterface(), true); + EXPECT_EQ(op1_info.HasTrait(), false); + EXPECT_EQ(op1_info.HasInterface(), false); + EXPECT_EQ(op2_info.HasTrait(), true); + EXPECT_EQ(op2_info.HasInterface(), true); // (3) Test uses for op. std::vector op_inputs = {}; diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index 91e0fbf9bf4..c37b0f2040b 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -31,11 +31,10 @@ class AddOp : public ir::Op { public: using Op::Op; - static const char *name() { return "Add"; } - static const char **attributes_name_; - static uint32_t attributes_num() { return 0; } + static const char *name() { return "test.add"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; }; -const char **AddOp::attributes_name_ = nullptr; TEST(program_test, program) { // (1) Init environment. @@ -78,9 +77,8 @@ TEST(program_test, program) { EXPECT_EQ(program.parameters_num() == 2, true); // (4) Def a = GetParameterOp("a"), and create DenseTensor for a. - std::string op1_name = - builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name()); - ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name); + std::string op1_name = ir::GetParameterOp::name(); + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); std::unordered_map op1_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; ir::Operation *op1 = @@ -112,7 +110,7 @@ TEST(program_test, program) { // (5) Def b = GetParameterOp("b"), and create DenseTensor for b. std::string op2_name = builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name()); - ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); std::unordered_map op2_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; ir::Operation *op2 = @@ -142,7 +140,7 @@ TEST(program_test, program) { // (6) Def c = AddOp(a, b), execute this op. std::string op3_name = builtin_dialect->name() + "." + std::string(AddOp::name()); - ir::OpInfoImpl *op3_info = ctx->GetRegisteredOpInfo(op3_name); + ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); std::unordered_map op3_attribute; ir::Operation *op3 = ir::Operation::create( {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, @@ -173,7 +171,7 @@ TEST(program_test, program) { // (7) Def SetParameterOp(c, "c") std::string op4_name = builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name()); - ir::OpInfoImpl *op4_info = ctx->GetRegisteredOpInfo(op4_name); + ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); std::unordered_map op4_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "c")}}; ir::Operation *op4 = ir::Operation::create( -- GitLab