From 8dfcf03ed670787a4a079779be0075422a05f09e Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Fri, 9 Jun 2023 23:02:55 +0800 Subject: [PATCH] [IR] add positon member in new ir operation. (#54483) --- paddle/ir/core/block.cc | 19 ++-- paddle/ir/core/block.h | 14 +-- paddle/ir/core/builder.h | 6 +- paddle/ir/core/ir_context.cc | 65 ++++++------ paddle/ir/core/ir_context.h | 7 ++ paddle/ir/core/op_info.cc | 119 +--------------------- paddle/ir/core/op_info.h | 13 ++- paddle/ir/core/op_info_impl.cc | 140 ++++++++++++++++++++++++++ paddle/ir/core/op_info_impl.h | 36 +++---- paddle/ir/core/operation.cc | 2 +- paddle/ir/core/operation.h | 12 ++- paddle/ir/core/region.cc | 8 +- paddle/ir/core/region.h | 2 +- test/cpp/ir/core/CMakeLists.txt | 2 + test/cpp/ir/core/ir_op_test.cc | 4 +- test/cpp/ir/core/ir_program_test.cc | 4 +- test/cpp/ir/core/ir_value_test.cc | 8 +- test/cpp/ir/core/op_info_test.cc | 44 ++++++++ test/cpp/ir/pass/pass_manager_test.cc | 2 +- 19 files changed, 293 insertions(+), 214 deletions(-) create mode 100644 paddle/ir/core/op_info_impl.cc create mode 100644 test/cpp/ir/core/op_info_test.cc diff --git a/paddle/ir/core/block.cc b/paddle/ir/core/block.cc index 42a4362eb77..a0f74cfa6a6 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/ir/core/block.cc @@ -13,22 +13,23 @@ // limitations under the License. #include "paddle/ir/core/block.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/region.h" namespace ir { Block::~Block() { clear(); } -void Block::push_back(Operation *op) { - op->set_parent(this); - ops_.push_back(op); -} +void Block::push_back(Operation *op) { insert(ops_.end(), op); } + +void Block::push_front(Operation *op) { insert(ops_.begin(), op); } -void Block::push_front(Operation *op) { - op->set_parent(this); - ops_.push_front(op); +Operation *Block::GetParentOp() const { + return parent_ ? parent_->GetParent() : nullptr; } Block::iterator Block::insert(const_iterator iterator, Operation *op) { - op->set_parent(this); - return ops_.insert(iterator, op); + Block::iterator iter = ops_.insert(iterator, op); + op->SetParent(this, iter); + return iter; } void Block::clear() { diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index 6534bdd60b7..5c4bf08019d 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -16,10 +16,10 @@ #include #include -#include "paddle/ir/core/operation.h" namespace ir { class Region; +class Operation; class Block { public: @@ -30,7 +30,9 @@ class Block { Block() = default; ~Block(); - Region *parent() const { return parent_; } + Region *GetParent() const { return parent_; } + Operation *GetParentOp() const; + bool empty() const { return ops_.empty(); } size_t size() const { return ops_.size(); } @@ -46,18 +48,12 @@ class Block { iterator insert(const_iterator iterator, Operation *op); void clear(); - Region *GetParentRegion() const { return parent_; } - - Operation *GetParentOp() const { - return parent_ ? parent_->GetParentOp() : nullptr; - } - private: Block(Block &) = delete; Block &operator=(const Block &) = delete; friend class Region; - void set_parent(Region *parent) { parent_ = parent; } + void SetParent(Region *parent) { parent_ = parent; } private: Region *parent_; // not owned diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index 80255c3280b..347e725ec47 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -26,10 +26,10 @@ namespace ir { /// class Builder { public: - explicit Builder(IrContext *context, - Block *block, - Block::iterator insert_point) + Builder(IrContext *context, Block *block, Block::iterator insert_point) : context_(context), block_(block), insert_point_(insert_point) {} + Builder(IrContext *context, Block *block) + : Builder(context, block, block->end()) {} static Builder AtBlockBegin(IrContext *context, Block *block) { return Builder(context, block, block->begin()); diff --git a/paddle/ir/core/ir_context.cc b/paddle/ir/core/ir_context.cc index 1525ff772dc..36621bf099e 100644 --- a/paddle/ir/core/ir_context.cc +++ b/paddle/ir/core/ir_context.cc @@ -49,7 +49,7 @@ class IrContextImpl { registed_dialect_.clear(); for (auto &op_map : registed_op_infos_) { - op_map.second->destroy(); + OpInfoImpl::Destroy(op_map.second); } registed_op_infos_.clear(); } @@ -103,24 +103,25 @@ class IrContextImpl { return registed_op_infos_.find(name) != registed_op_infos_.end(); } - void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { + void RegisterOpInfo(const std::string &name, OpInfo info) { std::lock_guard guard(registed_op_infos_lock_); VLOG(4) << "Register an operation of: [Name=" << name - << ", OpInfoImpl ptr=" << opinfo << "]."; - registed_op_infos_.emplace(name, opinfo); + << ", OpInfo ptr=" << info.AsOpaquePointer() << "]."; + registed_op_infos_.emplace(name, info); } - OpInfoImpl *GetOpInfo(const std::string &name) { + OpInfo GetOpInfo(const std::string &name) { std::lock_guard guard(registed_op_infos_lock_); auto iter = registed_op_infos_.find(name); if (iter != registed_op_infos_.end()) { - VLOG(4) << "Found a cached operation of: [name=" << name - << ", OpInfoImpl ptr=" << iter->second << "]."; + VLOG(4) << "Found a cached OpInfo of: [name=" << name + << ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "]."; return iter->second; } LOG(WARNING) << "No cache found operation of: [Name=" << name << "]."; - return nullptr; + return OpInfo(); } + const OpInfoMap ®istered_op_info_map() { return registed_op_infos_; } void RegisterDialect(std::string name, Dialect *dialect) { std::lock_guard guard(registed_dialect_lock_); @@ -170,7 +171,7 @@ class IrContextImpl { ir::SpinLock registed_dialect_lock_; // The Op registered in the context. - std::unordered_map registed_op_infos_; + OpInfoMap registed_op_infos_; ir::SpinLock registed_op_infos_lock_; ir::SpinLock destructor_lock_; @@ -282,43 +283,39 @@ void IrContext::RegisterOpInfo(Dialect *dialect, if (impl().IsOpInfoRegistered(name)) { LOG(WARNING) << name << " op already registered."; } else { - OpInfoImpl *opinfo = OpInfoImpl::create(dialect, - op_id, - name, - std::move(interface_map), - trait_set, - attributes_num, - attributes_name, - verify); - impl().RegisterOpInfo(name, opinfo); + OpInfo info = OpInfoImpl::Create(dialect, + op_id, + name, + std::move(interface_map), + trait_set, + attributes_num, + attributes_name, + verify); + impl().RegisterOpInfo(name, info); VLOG(4) << name << " op registered into IrContext. --->"; } } OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) { - OpInfoImpl *rtn = impl().GetOpInfo(name); - return rtn ? rtn : nullptr; + return impl().GetOpInfo(name); +} + +const OpInfoMap &IrContext::registered_op_info_map() { + return impl().registered_op_info_map(); } const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { - auto &impl = ctx->impl(); - AbstractType *abstract_type = impl.GetAbstractType(type_id); - if (abstract_type) { - return *abstract_type; - } else { - throw("Abstract type not found in IrContext."); - } + AbstractType *abstract_type = ctx->impl().GetAbstractType(type_id); + IR_ENFORCE(abstract_type, "Abstract type not found in IrContext."); + return *abstract_type; } const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id, IrContext *ctx) { - auto &impl = ctx->impl(); - AbstractAttribute *abstract_attribute = impl.GetAbstractAttribute(type_id); - if (abstract_attribute) { - return *abstract_attribute; - } else { - throw("Abstract attribute not found in IrContext."); - } + AbstractAttribute *abstract_attribute = + ctx->impl().GetAbstractAttribute(type_id); + IR_ENFORCE(abstract_attribute, "Abstract attribute not found in IrContext."); + return *abstract_attribute; } BFloat16Type BFloat16Type::get(IrContext *ctx) { diff --git a/paddle/ir/core/ir_context.h b/paddle/ir/core/ir_context.h index c5fb7fa5550..9628a46aed8 100644 --- a/paddle/ir/core/ir_context.h +++ b/paddle/ir/core/ir_context.h @@ -31,6 +31,8 @@ class Type; class OpResult; class Attribute; +using OpInfoMap = std::unordered_map; + /// /// \brief IrContext is a global parameterless class used to store and manage /// Type, Attribute and other related data structures. @@ -116,6 +118,11 @@ class IrContext { /// OpInfo GetRegisteredOpInfo(const std::string &name); + /// + /// \brief Get registered operaiton infomation map. + /// + const OpInfoMap ®istered_op_info_map(); + /// /// \brief Get the dialect of the DialectT class in the context, ff not found, /// create and register to context. diff --git a/paddle/ir/core/op_info.cc b/paddle/ir/core/op_info.cc index adbe9d06105..b52cdf11387 100644 --- a/paddle/ir/core/op_info.cc +++ b/paddle/ir/core/op_info.cc @@ -41,123 +41,6 @@ void OpInfo::Verify(const std::vector &inputs, } void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { - return impl_ ? impl_->interface_impl(interface_id) : nullptr; + return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr; } - -ir::IrContext *OpInfoImpl::ir_context() const { - return dialect()->ir_context(); -} - -void *OpInfoImpl::interface_impl(TypeId interface_id) const { - if (num_interfaces_ > 0) { - const InterfaceValue *p_first_interface = - reinterpret_cast( - reinterpret_cast(this) - - sizeof(TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_); - size_t left = 0, right = num_interfaces_; - while (left < right) { - size_t mid = (left + right) / 2; - if ((p_first_interface + mid)->type_id() == interface_id) { - return (p_first_interface + mid)->model(); - } else if ((p_first_interface + mid)->type_id() < interface_id) { - left = mid + 1; - } else { - right = mid; - } - } - } - return nullptr; -} -bool OpInfoImpl::HasTrait(TypeId trait_id) const { - if (num_traits_ > 0) { - const TypeId *p_first_trait = - reinterpret_cast(reinterpret_cast(this) - - sizeof(ir::TypeId) * num_traits_); - return std::binary_search( - p_first_trait, p_first_trait + num_traits_, trait_id); - } - return false; -} - -bool OpInfoImpl::HasInterface(TypeId interface_id) const { - if (num_interfaces_ > 0) { - const InterfaceValue *p_first_interface = - reinterpret_cast( - reinterpret_cast(this) - - sizeof(ir::TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_); - return std::binary_search(p_first_interface, - p_first_interface + num_interfaces_, - InterfaceValue(interface_id)); - } - return false; -} - -OpInfoImpl *OpInfoImpl::create(Dialect *dialect, - TypeId op_id, - const char *op_name, - std::vector &&interface_map, - const std::vector &trait_set, - size_t attributes_num, - const char *attributes_name[], - VerifyPtr verify) { - // (1) Malloc memory for interfaces, traits, opinfo_impl. - size_t interfaces_num = interface_map.size(); - size_t traits_num = trait_set.size(); - VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " - << traits_num << " traits, " << attributes_num << " attributes."; - size_t base_size = sizeof(InterfaceValue) * interfaces_num + - sizeof(TypeId) * traits_num + sizeof(OpInfoImpl); - char *base_ptr = static_cast(::operator new(base_size)); - VLOG(4) << "Malloc " << base_size << " Bytes at " - << static_cast(base_ptr); - if (interfaces_num > 0) { - std::sort(interface_map.begin(), interface_map.end()); - for (size_t index = 0; index < interfaces_num; ++index) { - new (base_ptr + index * sizeof(InterfaceValue)) - InterfaceValue(std::move(interface_map[index])); - } - base_ptr += interfaces_num * sizeof(InterfaceValue); - } - if (traits_num > 0) { - auto p_first_trait = reinterpret_cast(base_ptr); - memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num); - std::sort(p_first_trait, p_first_trait + traits_num); - base_ptr += traits_num * sizeof(TypeId); - } - // Construct opinfo_impl. - OpInfoImpl *p_opinfo_impl = reinterpret_cast(base_ptr); - VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......"; - OpInfoImpl *op_info = new (p_opinfo_impl) OpInfoImpl(dialect, - op_id, - op_name, - interfaces_num, - traits_num, - attributes_num, - attributes_name, - verify - - ); - return op_info; -} - -void OpInfoImpl::destroy() { - VLOG(4) << "Destroy op_info impl at " << this; - // (1) free interfaces - char *base_ptr = reinterpret_cast(this) - - sizeof(ir::TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_; - if (num_interfaces_ > 0) { - InterfaceValue *p_interface_val = - reinterpret_cast(base_ptr); - for (size_t i = 0; i < num_interfaces_; i++) { - (p_interface_val + i)->~InterfaceValue(); - } - } - // (2) free memeory - VLOG(4) << "Free base_ptr " << base_ptr; - free(base_ptr); -} - } // namespace ir diff --git a/paddle/ir/core/op_info.h b/paddle/ir/core/op_info.h index 7544342e412..1d8cf19f5c9 100644 --- a/paddle/ir/core/op_info.h +++ b/paddle/ir/core/op_info.h @@ -28,8 +28,6 @@ class OpInfo { public: constexpr OpInfo() = default; - OpInfo(const OpInfoImpl *impl) : impl_(impl) {} // NOLINT - OpInfo(const OpInfo &other) = default; OpInfo &operator=(const OpInfo &other) = default; @@ -52,8 +50,6 @@ class OpInfo { const std::vector &outputs, const std::unordered_map &attributes); - const OpInfoImpl *impl() const; - template bool HasTrait() const { return HasTrait(TypeId::get()); @@ -71,13 +67,20 @@ class OpInfo { template typename Interface::Concept *GetInterfaceImpl() const; + void *AsOpaquePointer() const { return impl_; } + static OpInfo RecoverFromOpaquePointer(void *impl) { + return static_cast(impl); + } + + friend class OpInfoImpl; friend struct std::hash; private: + OpInfo(OpInfoImpl *impl) : impl_(impl) {} // NOLINT void *GetInterfaceImpl(TypeId interface_id) const; private: - const OpInfoImpl *impl_{nullptr}; // not owned + OpInfoImpl *impl_{nullptr}; // not owned }; template diff --git a/paddle/ir/core/op_info_impl.cc b/paddle/ir/core/op_info_impl.cc new file mode 100644 index 00000000000..57d15b22c28 --- /dev/null +++ b/paddle/ir/core/op_info_impl.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/core/op_info_impl.h" +#include "paddle/ir/core/dialect.h" + +namespace ir { +OpInfo OpInfoImpl::Create(Dialect *dialect, + TypeId op_id, + const char *op_name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char *attributes_name[], + VerifyPtr verify) { + // (1) Malloc memory for interfaces, traits, opinfo_impl. + size_t interfaces_num = interface_map.size(); + size_t traits_num = trait_set.size(); + VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " + << traits_num << " traits, " << attributes_num << " attributes."; + size_t base_size = sizeof(InterfaceValue) * interfaces_num + + sizeof(TypeId) * traits_num + sizeof(OpInfoImpl); + char *base_ptr = static_cast(::operator new(base_size)); + VLOG(4) << "Malloc " << base_size << " Bytes at " + << static_cast(base_ptr); + if (interfaces_num > 0) { + std::sort(interface_map.begin(), interface_map.end()); + for (size_t index = 0; index < interfaces_num; ++index) { + new (base_ptr + index * sizeof(InterfaceValue)) + InterfaceValue(std::move(interface_map[index])); + } + base_ptr += interfaces_num * sizeof(InterfaceValue); + } + if (traits_num > 0) { + auto p_first_trait = reinterpret_cast(base_ptr); + memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num); + std::sort(p_first_trait, p_first_trait + traits_num); + base_ptr += traits_num * sizeof(TypeId); + } + // Construct OpInfoImpl. + VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......"; + OpInfo op_info = new (base_ptr) OpInfoImpl(dialect, + op_id, + op_name, + interfaces_num, + traits_num, + attributes_num, + attributes_name, + verify); + return op_info; +} +void OpInfoImpl::Destroy(OpInfo info) { + if (info.impl_) { + info.impl_->Destroy(); + } else { + LOG(WARNING) << "A nullptr OpInfo is destoryed."; + } +} + +ir::IrContext *OpInfoImpl::ir_context() const { + return dialect_ ? dialect_->ir_context() : nullptr; +} + +bool OpInfoImpl::HasTrait(TypeId trait_id) const { + if (num_traits_ > 0) { + const TypeId *p_first_trait = + reinterpret_cast(reinterpret_cast(this) - + sizeof(ir::TypeId) * num_traits_); + return std::binary_search( + p_first_trait, p_first_trait + num_traits_, trait_id); + } + return false; +} + +bool OpInfoImpl::HasInterface(TypeId interface_id) const { + if (num_interfaces_ > 0) { + const InterfaceValue *p_first_interface = + reinterpret_cast( + reinterpret_cast(this) - + sizeof(ir::TypeId) * num_traits_ - + sizeof(InterfaceValue) * num_interfaces_); + return std::binary_search(p_first_interface, + p_first_interface + num_interfaces_, + InterfaceValue(interface_id)); + } + return false; +} + +void *OpInfoImpl::GetInterfaceImpl(TypeId interface_id) const { + if (num_interfaces_ > 0) { + const InterfaceValue *p_first_interface = + reinterpret_cast( + reinterpret_cast(this) - + sizeof(TypeId) * num_traits_ - + sizeof(InterfaceValue) * num_interfaces_); + size_t left = 0, right = num_interfaces_; + while (left < right) { + size_t mid = (left + right) / 2; + if ((p_first_interface + mid)->type_id() == interface_id) { + return (p_first_interface + mid)->model(); + } else if ((p_first_interface + mid)->type_id() < interface_id) { + left = mid + 1; + } else { + right = mid; + } + } + } + return nullptr; +} + +void OpInfoImpl::Destroy() { + VLOG(4) << "Destroy op_info impl at " << this; + // (1) free interfaces + char *base_ptr = reinterpret_cast(this) - + sizeof(ir::TypeId) * num_traits_ - + sizeof(InterfaceValue) * num_interfaces_; + if (num_interfaces_ > 0) { + InterfaceValue *p_interface_val = + reinterpret_cast(base_ptr); + for (size_t i = 0; i < num_interfaces_; i++) { + (p_interface_val + i)->~InterfaceValue(); + } + } + // (2) free memeory + VLOG(4) << "Free base_ptr " << base_ptr; + free(base_ptr); +} + +} // namespace ir diff --git a/paddle/ir/core/op_info_impl.h b/paddle/ir/core/op_info_impl.h index 409417eee53..e5d8fd25aaf 100644 --- a/paddle/ir/core/op_info_impl.h +++ b/paddle/ir/core/op_info_impl.h @@ -38,40 +38,39 @@ class OpInfoImpl { /// \brief Construct and Deconstruct OpInfoImpl. The memory layout of /// OpInfoImpl is: std::pair... | TypeId... | OpInfoImpl /// - static OpInfoImpl *create(Dialect *dialect, - TypeId op_id, - const char *op_name, - std::vector &&interface_map, - const std::vector &trait_set, - size_t attributes_num, - const char *attributes_name[], - VerifyPtr verify); + static OpInfo Create(Dialect *dialect, + TypeId op_id, + const char *op_name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char *attributes_name[], + VerifyPtr verify); + static void Destroy(OpInfo info); - void destroy(); + TypeId id() const { return op_id_; } - ir::IrContext *ir_context() const; + Dialect *dialect() const { return dialect_; } + + VerifyPtr verify() const { return verify_; } + + IrContext *ir_context() const; /// \brief Search methods for Trait or Interface. bool HasTrait(TypeId trait_id) const; bool HasInterface(TypeId interface_id) const; - ir::TypeId id() const { return op_id_; } - - void *interface_impl(TypeId interface_id) const; + void *GetInterfaceImpl(TypeId interface_id) const; const char *name() const { return op_name_; } - ir::Dialect *dialect() const { return dialect_; } - uint32_t AttributeNum() const { return num_attributes_; } const char *GetAttributeByIndex(size_t idx) const { return idx < num_attributes_ ? p_attributes_[idx] : nullptr; } - VerifyPtr verify() const { return verify_; } - private: OpInfoImpl(ir::Dialect *dialect, TypeId op_id, @@ -89,9 +88,10 @@ class OpInfoImpl { num_attributes_(num_attributes), p_attributes_(p_attributes), verify_(verify) {} + void Destroy(); /// The dialect of this Op belong to. - ir::Dialect *dialect_; + Dialect *dialect_; /// The TypeId of this Op. TypeId op_id_; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 198259c2416..26dc06e29b5 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -213,7 +213,7 @@ std::string Operation::name() const { } Region *Operation::GetParentRegion() const { - return parent_ ? parent_->GetParentRegion() : nullptr; + return parent_ ? parent_->GetParent() : nullptr; } Operation *Operation::GetParentOp() const { diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index f7f7f19965c..87149fa562e 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -15,6 +15,7 @@ #pragma once #include +#include "paddle/ir/core/block.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/type.h" @@ -22,7 +23,6 @@ namespace ir { class OpBase; class Program; -class Block; class OpOperand; class OpResult; @@ -85,7 +85,7 @@ class alignas(8) Operation final { return info_.HasInterface(); } - Block *GetParentBlock() const { return parent_; } + Block *GetParent() const { return parent_; } Region *GetParentRegion() const; @@ -96,6 +96,8 @@ class alignas(8) Operation final { /// Returns the region held by this operation at position 'index'. Region &GetRegion(unsigned index); + operator Block::iterator() { return position_; } + private: Operation(const AttributeMap &attribute, ir::OpInfo op_info, @@ -111,7 +113,10 @@ class alignas(8) Operation final { }; friend class Block; - void set_parent(Block *parent) { parent_ = parent; } + void SetParent(Block *parent, const Block::iterator &position) { + parent_ = parent; + position_ = position; + } template struct CastUtil< @@ -130,6 +135,7 @@ class alignas(8) Operation final { Region *regions_{nullptr}; Block *parent_{nullptr}; + Block::iterator position_; }; } // namespace ir diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc index e434bbec494..854df9cf9bb 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/ir/core/region.cc @@ -19,26 +19,26 @@ namespace ir { Region::~Region() { clear(); } void Region::push_back(Block *block) { - block->set_parent(this); + block->SetParent(this); blocks_.push_back(block); } void Region::emplace_back() { push_back(new Block); } void Region::push_front(Block *block) { - block->set_parent(this); + block->SetParent(this); blocks_.push_front(block); } Region::iterator Region::insert(const_iterator position, Block *block) { - block->set_parent(this); + block->SetParent(this); return blocks_.insert(position, block); } void Region::TakeBody(Region &&other) { clear(); blocks_.swap(other.blocks_); for (auto &block : blocks_) { - block->set_parent(this); + block->SetParent(this); } } diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h index 9a5a2f7a9b7..fa150e4889a 100644 --- a/paddle/ir/core/region.h +++ b/paddle/ir/core/region.h @@ -48,7 +48,7 @@ class Region { void TakeBody(Region &&other); - Operation *GetParentOp() const { return parent_; } + Operation *GetParent() const { return parent_; } private: Region(Region &) = delete; diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt index a7817ffc02b..d033303fd25 100644 --- a/test/cpp/ir/core/CMakeLists.txt +++ b/test/cpp/ir/core/CMakeLists.txt @@ -71,3 +71,5 @@ cc_test_old( gtest new_ir pd_dialect) + +cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest new_ir) diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 37b29b0175e..c071a0d45df 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -163,10 +163,10 @@ TEST(op_test, op_test) { // (2) Get registered operations. std::string op1_name = Operation1::name(); ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); - EXPECT_EQ(op1_info != nullptr, true); + EXPECT_TRUE(op1_info); std::string op2_name = Operation2::name(); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); - EXPECT_EQ(op2_info != nullptr, true); + EXPECT_TRUE(op2_info); EXPECT_EQ(op1_info.HasTrait(), false); EXPECT_EQ(op1_info.HasInterface(), false); EXPECT_EQ(op2_info.HasTrait(), true); diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 66985143315..2388752a87b 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -98,7 +98,7 @@ TEST(program_test, program) { ir::Block *block = program.block(); block->push_back(op1); - EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion()); + EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent()); EXPECT_EQ(program.module_op(), block->GetParentOp()); @@ -299,7 +299,7 @@ TEST(program_test, builder) { ir::IrContext *ctx = ir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ir::Program program(ctx); - ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block()); + ir::Builder builder = ir::Builder(ctx, program.block()); paddle::dialect::FullOp full_op = builder.Build( std::vector{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index c99fe31fef1..b3ea7e836e1 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -42,7 +42,7 @@ TEST(value_test, value_test) { ir::Operation::Create(op1_inputs, CreateAttributeMap("op1_name", "op1_attr"), op1_output_types, - nullptr); + ir::OpInfo()); op1->Print(std::cout); // 2. Construct OP2: b = OP2(); std::vector op2_inputs = {}; @@ -51,7 +51,7 @@ TEST(value_test, value_test) { ir::Operation::Create(op2_inputs, CreateAttributeMap("op2_name", "op2_attr"), op2_output_types, - nullptr); + ir::OpInfo()); op2->Print(std::cout); // 3. Construct OP3: c = OP3(a, b); std::vector op3_inputs = {op1->GetResultByIndex(0), @@ -61,7 +61,7 @@ TEST(value_test, value_test) { ir::Operation::Create(op3_inputs, CreateAttributeMap("op3_name", "op3_attr"), op3_output_types, - nullptr); + ir::OpInfo()); op3->Print(std::cout); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); std::vector op4_inputs = {op1->GetResultByIndex(0), @@ -74,7 +74,7 @@ TEST(value_test, value_test) { ir::Operation::Create(op4_inputs, CreateAttributeMap("op4_name", "op4_attr"), op4_output_types, - nullptr); + ir::OpInfo()); op4->Print(std::cout); // Test 1: diff --git a/test/cpp/ir/core/op_info_test.cc b/test/cpp/ir/core/op_info_test.cc new file mode 100644 index 00000000000..da3142dfe5f --- /dev/null +++ b/test/cpp/ir/core/op_info_test.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ir/core/block.h" +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/builtin_type.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/program.h" + +TEST(ir_op_info_test, op_op_info_test) { + ir::IrContext* context = ir::IrContext::Instance(); + ir::Program program(context); + + ir::Block* block = program.block(); + ir::Builder builder(context, block); + builder.Build(ir::Int32_tAttribute::get(context, 5), + ir::Int32Type::get(context)); + + ir::Operation* op = block->back(); + + EXPECT_EQ(block->end(), ++ir::Block::iterator(*op)); + + auto& info_map = context->registered_op_info_map(); + EXPECT_FALSE(info_map.empty()); + + void* info_1 = op->info().AsOpaquePointer(); + auto info_2 = ir::OpInfo::RecoverFromOpaquePointer(info_1); + EXPECT_EQ(op->info(), info_2); +} diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc index 67b9f439bc4..71d4dccac37 100644 --- a/test/cpp/ir/pass/pass_manager_test.cc +++ b/test/cpp/ir/pass/pass_manager_test.cc @@ -112,7 +112,7 @@ TEST(pass_manager_test, pass_manager) { ir::Block *block = program.block(); block->push_back(op1); - EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion()); + EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent()); EXPECT_EQ(program.module_op(), block->GetParentOp()); -- GitLab