diff --git a/paddle/ir/core/block.cc b/paddle/ir/core/block.cc index b74a960f7fc453384c43ad7a450e94c7b745dae1..f99ec340e4c49cf2bd388c41a3c7f7d1d06985ea 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/ir/core/block.cc @@ -18,7 +18,10 @@ #include "paddle/ir/core/region.h" namespace ir { -Block::~Block() { clear(); } +Block::~Block() { + assert(use_empty() && "block destroyed still has uses."); + clear(); +} void Block::push_back(Operation *op) { insert(ops_.end(), op); } void Block::push_front(Operation *op) { insert(ops_.begin(), op); } @@ -51,4 +54,10 @@ void Block::SetParent(Region *parent, Region::iterator position) { position_ = position; } +Block::UseIterator Block::use_begin() const { return first_use_; } + +Block::UseIterator Block::use_end() const { return Block::UseIterator(); } + +bool Block::HasOneUse() const { return first_use_ && !first_use_.next_use(); } + } // namespace ir diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index 68b2257723236a44ec902a951092dfe9b58e2ced..ebe4b6cb8ecf4e6c6147f02e58a08e92b1be592d 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -17,8 +17,10 @@ #include #include +#include "paddle/ir/core/block_operand.h" #include "paddle/ir/core/dll_decl.h" #include "paddle/ir/core/region.h" +#include "paddle/ir/core/use_iterator.h" namespace ir { class Operation; @@ -56,6 +58,18 @@ class IR_API Block { void clear(); operator Region::iterator() { return position_; } + /// + /// \brief Provide iterator interface to access Value use chain. + /// + using UseIterator = ValueUseIterator; + UseIterator use_begin() const; + UseIterator use_end() const; + BlockOperand first_use() const { return first_use_; } + void set_first_use(BlockOperand first_use) { first_use_ = first_use; } + bool use_empty() const { return !first_use_; } + bool HasOneUse() const; + BlockOperand *first_use_addr() { return &first_use_; } + private: Block(Block &) = delete; Block &operator=(const Block &) = delete; @@ -68,5 +82,6 @@ class IR_API Block { Region *parent_; // not owned OpListType ops_; // owned Region::iterator position_; + BlockOperand first_use_; }; } // namespace ir diff --git a/paddle/ir/core/block_operand.cc b/paddle/ir/core/block_operand.cc new file mode 100644 index 0000000000000000000000000000000000000000..f64a07fd50dfe964828f8303f1d2c2af7eda0ae4 --- /dev/null +++ b/paddle/ir/core/block_operand.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/core/block_operand.h" +#include "paddle/ir/core/block.h" +#include "paddle/ir/core/block_operand_impl.h" +#include "paddle/ir/core/enforce.h" + +namespace ir { + +#define CHECK_BLOCKOPEREND_NULL_IMPL(func_name) \ + IR_ENFORCE(impl_, \ + "impl_ pointer is null when call func:" #func_name \ + " , in class: BlockOperand.") + +BlockOperand &BlockOperand::operator=(const BlockOperand &rhs) { + if (this == &rhs) return *this; + impl_ = rhs.impl_; + return *this; +} + +BlockOperand::operator bool() const { return impl_ && impl_->source(); } + +BlockOperand BlockOperand::next_use() const { + CHECK_BLOCKOPEREND_NULL_IMPL(next_use); + return impl_->next_use(); +} + +Block *BlockOperand::source() const { + CHECK_BLOCKOPEREND_NULL_IMPL(source); + return impl_->source(); +} + +void BlockOperand::set_source(Block *source) { + CHECK_BLOCKOPEREND_NULL_IMPL(set_source); + impl_->set_source(source); +} + +Operation *BlockOperand::owner() const { + CHECK_BLOCKOPEREND_NULL_IMPL(owner); + return impl_->owner(); +} + +void BlockOperand::RemoveFromUdChain() { + CHECK_BLOCKOPEREND_NULL_IMPL(RemoveFromUdChain); + return impl_->RemoveFromUdChain(); +} + +// details +namespace detail { + +Operation *BlockOperandImpl::owner() const { return owner_; } + +BlockOperand BlockOperandImpl::next_use() const { return next_use_; } + +Block *BlockOperandImpl::source() const { return source_; } + +void BlockOperandImpl::set_source(Block *source) { + RemoveFromUdChain(); + if (!source) { + return; + } + source_ = source; + InsertToUdChain(); +} + +BlockOperandImpl::BlockOperandImpl(Block *source, ir::Operation *owner) + : source_(source), owner_(owner) { + if (!source) { + return; + } + InsertToUdChain(); +} + +void BlockOperandImpl::InsertToUdChain() { + prev_use_addr_ = source_->first_use_addr(); + next_use_ = source_->first_use(); + if (next_use_) { + next_use_.impl()->prev_use_addr_ = &next_use_; + } + source_->set_first_use(this); +} + +void BlockOperandImpl::RemoveFromUdChain() { + if (!source_) return; + if (!prev_use_addr_) return; + if (prev_use_addr_ == source_->first_use_addr()) { + source_->set_first_use(next_use_); + } else { + *prev_use_addr_ = next_use_; + } + if (next_use_) { + next_use_.impl()->prev_use_addr_ = prev_use_addr_; + } + next_use_ = nullptr; + prev_use_addr_ = nullptr; + source_ = nullptr; +} + +BlockOperandImpl::~BlockOperandImpl() { RemoveFromUdChain(); } +} // namespace detail +} // namespace ir diff --git a/paddle/ir/core/block_operand.h b/paddle/ir/core/block_operand.h new file mode 100644 index 0000000000000000000000000000000000000000..ec55a90a1c65d971a669b5155aa37ba09d682f2e --- /dev/null +++ b/paddle/ir/core/block_operand.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/cast_utils.h" +#include "paddle/ir/core/type.h" + +namespace ir { +class Operation; +class Value; +class Block; + +namespace detail { +class BlockOperandImpl; +} // namespace detail + +/// +/// \brief OpOperand class represents the op_operand of operation. This class +/// only provides interfaces, for specific implementation, see Impl class. +/// +class IR_API BlockOperand { + public: + BlockOperand() = default; + + BlockOperand(const BlockOperand &other) = default; + + BlockOperand(detail::BlockOperandImpl *impl) : impl_(impl) {} // NOLINT + + BlockOperand &operator=(const BlockOperand &rhs); + + bool operator==(const BlockOperand &other) const { + return impl_ == other.impl_; + } + + bool operator!=(const BlockOperand &other) const { + return !operator==(other); + } + + bool operator!() const { return impl_ == nullptr; } + + operator bool() const; + + BlockOperand next_use() const; + + Block *source() const; + + void set_source(Block *source); + + Operation *owner() const; + + void RemoveFromUdChain(); + + friend Operation; + + detail::BlockOperandImpl *impl() const { return impl_; } + + private: + detail::BlockOperandImpl *impl_{nullptr}; +}; + +} // namespace ir diff --git a/paddle/ir/core/block_operand_impl.h b/paddle/ir/core/block_operand_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..53d8257c10032d643a42a69964da00a895a4dd34 --- /dev/null +++ b/paddle/ir/core/block_operand_impl.h @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/block_operand.h" + +namespace ir { +class Operation; +class Block; + +namespace detail { +/// +/// \brief OpOperandImpl +/// +class BlockOperandImpl { + public: + Operation* owner() const; + + BlockOperand next_use() const; + + Block* source() const; + + void set_source(Block*); + + /// Remove this op_operand from the current use list. + void RemoveFromUdChain(); + + ~BlockOperandImpl(); + + friend Operation; + + private: + BlockOperandImpl(Block* source, Operation* owner); + + // Insert self to the UD chain holded by source_; + // It is not safe. So set provate. + void InsertToUdChain(); + + BlockOperand next_use_ = nullptr; + + BlockOperand* prev_use_addr_ = nullptr; + + Block* source_; + + Operation* const owner_ = nullptr; +}; + +} // namespace detail +} // namespace ir diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index 6522bb1fbb9fc52ac0545f8f8f239cb9641fb5a4..1feb4d691d99b75385be064f91cadebf82b44ca7 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -40,9 +40,11 @@ Block *ModuleOp::block() { ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) { ir::OpInfo info = context->GetRegisteredOpInfo(name()); OperationArgument argument(info); - argument.AddRegion()->emplace_back(); + argument.num_regions = 1; argument.AddAttribute("program", PointerAttribute::get(context, pointer)); - return ModuleOp(Operation::Create(std::move(argument))); + Operation *op = Operation::Create(std::move(argument)); + op->region(0).emplace_back(); + return ModuleOp(op); } void ModuleOp::Destroy() { diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 5cdc154a0a5da2d771374eceae8fc35179fd6267..94722d1ada0b525db90123ad7b7121f36fbf49c4 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -15,6 +15,7 @@ #include #include "paddle/ir/core/block.h" +#include "paddle/ir/core/block_operand_impl.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/op_info.h" @@ -26,16 +27,12 @@ namespace ir { Operation *Operation::Create(OperationArgument &&argument) { - Operation *op = Create(argument.inputs, - argument.attributes, - argument.output_types, - argument.info, - argument.regions.size()); - - for (size_t index = 0; index < argument.regions.size(); ++index) { - op->region(index).TakeBody(std::move(*argument.regions[index])); - } - return op; + return Create(argument.inputs, + argument.attributes, + argument.output_types, + argument.info, + argument.num_regions, + argument.successors); } // Allocate the required memory based on the size and number of inputs, outputs, @@ -43,13 +40,15 @@ Operation *Operation::Create(OperationArgument &&argument) { // OpInlineResult, Operation, operand. Operation *Operation::Create(const std::vector &inputs, const AttributeMap &attributes, - const std::vector &output_types, + const std::vector &output_types, ir::OpInfo op_info, - size_t num_regions) { + size_t num_regions, + const std::vector &successors) { // 1. Calculate the required memory size for OpResults + Operation + // OpOperands. uint32_t num_results = output_types.size(); uint32_t num_operands = inputs.size(); + uint32_t num_successors = successors.size(); uint32_t max_inline_result_num = detail::OpResultImpl::GetMaxInlineResultIndex() + 1; size_t result_mem_size = @@ -58,11 +57,12 @@ Operation *Operation::Create(const std::vector &inputs, (num_results - max_inline_result_num) + sizeof(detail::OpInlineResultImpl) * max_inline_result_num : sizeof(detail::OpInlineResultImpl) * num_results; - size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands; size_t op_mem_size = sizeof(Operation); + size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands; + size_t block_operand_size = num_successors * sizeof(detail::BlockOperandImpl); size_t region_mem_size = num_regions * sizeof(Region); - size_t base_size = - result_mem_size + op_mem_size + operand_mem_size + region_mem_size; + size_t base_size = result_mem_size + op_mem_size + operand_mem_size + + region_mem_size + block_operand_size; // 2. Malloc memory. char *base_ptr = reinterpret_cast(aligned_malloc(base_size, 8)); // 3.1. Construct OpResults. @@ -77,8 +77,12 @@ Operation *Operation::Create(const std::vector &inputs, } } // 3.2. Construct Operation. - Operation *op = new (base_ptr) - Operation(attributes, op_info, num_results, num_operands, num_regions); + Operation *op = new (base_ptr) Operation(attributes, + op_info, + num_results, + num_operands, + num_regions, + num_successors); base_ptr += sizeof(Operation); // 3.3. Construct OpOperands. if ((reinterpret_cast(base_ptr) & 0x7) != 0) { @@ -88,7 +92,17 @@ Operation *Operation::Create(const std::vector &inputs, new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); base_ptr += sizeof(detail::OpOperandImpl); } - // 3.4. Construct Regions + // 3.4. Construct BlockOperands. + if (num_successors > 0) { + op->block_operands_ = + reinterpret_cast(base_ptr); + for (size_t idx = 0; idx < num_successors; idx++) { + new (base_ptr) detail::BlockOperandImpl(successors[idx], op); + base_ptr += sizeof(detail::BlockOperandImpl); + } + } + + // 3.5. Construct Regions if (num_regions > 0) { op->regions_ = reinterpret_cast(base_ptr); for (size_t idx = 0; idx < num_regions; idx++) { @@ -118,8 +132,6 @@ void Operation::Destroy() { // 2. Deconstruct Result. for (size_t idx = 0; idx < num_results_; ++idx) { detail::OpResultImpl *impl = result(idx).impl(); - IR_ENFORCE(impl->use_empty(), - name() + " operation destroyed but still has uses."); if (detail::OpOutlineResultImpl::classof(*impl)) { static_cast(impl)->~OpOutlineResultImpl(); } else { @@ -132,8 +144,20 @@ void Operation::Destroy() { // 4. Deconstruct OpOperand. for (size_t idx = 0; idx < num_operands_; idx++) { - operand(idx).impl()->~OpOperandImpl(); + detail::OpOperandImpl *op_operand_impl = operand(idx).impl_; + if (op_operand_impl) { + op_operand_impl->~OpOperandImpl(); + } } + + // 5. Deconstruct BlockOperand. + for (size_t idx = 0; idx < num_successors_; idx++) { + detail::BlockOperandImpl *block_operand_impl = block_operands_ + idx; + if (block_operand_impl) { + block_operand_impl->~BlockOperandImpl(); + } + } + // 5. Free memory. uint32_t max_inline_result_num = detail::OpResultImpl::GetMaxInlineResultIndex() + 1; @@ -158,12 +182,14 @@ Operation::Operation(const AttributeMap &attributes, ir::OpInfo op_info, uint32_t num_results, uint32_t num_operands, - uint32_t num_regions) + uint32_t num_regions, + uint32_t num_successors) : attributes_(attributes), info_(op_info), num_results_(num_results), num_operands_(num_operands), - num_regions_(num_regions) {} + num_regions_(num_regions), + num_successors_(num_successors) {} ir::OpResult Operation::result(uint32_t index) const { if (index >= num_results_) { @@ -226,14 +252,26 @@ const Program *Operation::GetParentProgram() const { ModuleOp module_op = op->dyn_cast(); return module_op ? module_op.program() : nullptr; } +BlockOperand Operation::block_operand(uint32_t index) const { + IR_ENFORCE(index < num_successors_, "Invalid block_operand index"); + return block_operands_ + index; +} +Block *Operation::successor(uint32_t index) const { + return block_operand(index).source(); +} + +void Operation::set_successor(Block *block, unsigned index) { + IR_ENFORCE(index < num_operands_, "Invalid block_operand index"); + (block_operands_ + index)->set_source(block); +} Region &Operation::region(unsigned index) { - assert(index < num_regions_ && "invalid region index"); + IR_ENFORCE(index < num_regions_, "invalid region index"); return regions_[index]; } const Region &Operation::region(unsigned index) const { - assert(index < num_regions_ && "invalid region index"); + IR_ENFORCE(index < num_regions_, "invalid region index"); return regions_[index]; } diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index a223c57abdd08f1c91b84fe9ccc307927f2f6022..dec0dfa6883ea106ef2ab4e4c69ceb0ccd35709e 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -29,6 +29,10 @@ class Program; class OpOperand; class OpResult; +namespace detial { +class BlockOperandImpl; +} // namespace detial + class IR_API alignas(8) Operation final { public: /// @@ -41,7 +45,8 @@ class IR_API alignas(8) Operation final { const AttributeMap &attributes, const std::vector &output_types, ir::OpInfo op_info, - size_t num_regions = 0); + size_t num_regions = 0, + const std::vector &successors = {}); static Operation *Create(OperationArgument &&op_argument); /// @@ -59,9 +64,16 @@ class IR_API alignas(8) Operation final { Value operand_source(uint32_t index) const; + uint32_t num_successors() const { return num_successors_; } + BlockOperand block_operand(uint32_t index) const; + Block *successor(uint32_t index) const; + void set_successor(Block *block, unsigned index); + bool HasSuccessors() { return num_successors_ != 0; } + /// Returns the region held by this operation at position 'index'. Region ®ion(unsigned index); const Region ®ion(unsigned index) const; + uint32_t num_regions() const { return num_regions_; } void Print(std::ostream &os) const; @@ -90,8 +102,6 @@ class IR_API alignas(8) Operation final { uint32_t num_operands() const { return num_operands_; } - uint32_t num_regions() const { return num_regions_; } - std::string name() const; template @@ -152,7 +162,8 @@ class IR_API alignas(8) Operation final { ir::OpInfo op_info, uint32_t num_results, uint32_t num_operands, - uint32_t num_regions); + uint32_t num_regions, + uint32_t num_successors); template struct CastUtil { @@ -179,7 +190,9 @@ class IR_API alignas(8) Operation final { const uint32_t num_results_ = 0; const uint32_t num_operands_ = 0; const uint32_t num_regions_ = 0; + const uint32_t num_successors_ = 0; + detail::BlockOperandImpl *block_operands_{nullptr}; Region *regions_{nullptr}; Block *parent_{nullptr}; Block::iterator position_; diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index 3ab421f945daed1c6fd2d1cd35085c2e7fba02b0..9e317a6510f5928a9f494fc5e60d1222ac0b5ce2 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -22,7 +22,7 @@ #include "paddle/ir/core/value.h" namespace ir { - +class Block; using AttributeMap = std::unordered_map; //===----------------------------------------------------------------------===// @@ -36,7 +36,8 @@ struct OperationArgument { AttributeMap attributes; std::vector output_types; OpInfo info; - std::vector> regions; + size_t num_regions{0}; + std::vector successors; public: OperationArgument(IrContext* ir_context, const std::string& name); @@ -45,12 +46,14 @@ struct OperationArgument { const AttributeMap& attributes, const std::vector& types, OpInfo info, - std::vector>&& regions = {}) + size_t num_regions = 0, + const std::vector successors = {}) : inputs(operands), attributes(attributes), output_types(types), info(info), - regions(std::move(regions)) {} + num_regions(num_regions), + successors(successors) {} /// Add Operand. void AddOperand(OpResult operand) { inputs.emplace_back(operand); } @@ -74,10 +77,7 @@ struct OperationArgument { /// Get the context held by this operation state. IrContext* getContext() const { return info.ir_context(); } - Region* AddRegion() { - regions.emplace_back(new Region); - return regions.back().get(); - } + void AddSuccessor(Block* successor) { successors.emplace_back(successor); } }; template diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc index ac7cd4ccdfa8b8554b1adebfc5f29874af6cdddd..e9fdb91758219fa8a2e0b39722ad62dd522a3408 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/ir/core/region.cc @@ -46,6 +46,11 @@ void Region::TakeBody(Region &&other) { } void Region::clear() { + // In order to ensure the correctness of UD Chain, + // BlockOperend should be decontructed bofore its source. + for (auto iter = blocks_.rbegin(); iter != blocks_.rend(); ++iter) { + (*iter)->clear(); + } while (!empty()) { delete blocks_.back(); blocks_.pop_back(); diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h index 5335588790f021c97c1f6dc4fbf02c44ab3ce7e3..cc1c1ab791df5f0e402ca2a549c4f386a4782af5 100644 --- a/paddle/ir/core/region.h +++ b/paddle/ir/core/region.h @@ -31,8 +31,6 @@ class IR_API Region { using reverse_iterator = std::list::reverse_iterator; using const_iterator = std::list::const_iterator; ~Region(); - Region() = default; - bool empty() const { return blocks_.empty(); } size_t size() const { return blocks_.size(); } @@ -59,6 +57,8 @@ class IR_API Region { IrContext *ir_context() const; private: + // region only support construncted by operation. + Region() = delete; Region(Region &) = delete; Region &operator=(const Region &) = delete; friend class Operation; diff --git a/paddle/ir/core/use_iterator.h b/paddle/ir/core/use_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..d7ef2a675649f9db38411b12b63cbd2f9743ec59 --- /dev/null +++ b/paddle/ir/core/use_iterator.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +namespace ir { + +class Operation; +/// +/// \brief Value Iterator +/// +template +class ValueUseIterator { + public: + ValueUseIterator(OperandType use = nullptr) : current_(use) {} // NOLINT + + bool operator==(const ValueUseIterator &rhs) const { + return current_ == rhs.current_; + } + bool operator!=(const ValueUseIterator &rhs) const { + return !(*this == rhs); + } + + Operation *owner() const { return current_.owner(); } + + OperandType &operator*() { return current_; } + + OperandType *operator->() { return &operator*(); } + + ValueUseIterator &operator++() { + current_ = current_.next_use(); + return *this; + } + + ValueUseIterator operator++(int) { + ValueUseIterator tmp = *this; + current_ = current_.next_use(); + return tmp; + } + + protected: + OperandType current_; +}; + +} // namespace ir diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 018342aa8154791d7a4a72110c525bd10b979dbc..d38bcdca36314b1ababeae6b03aadd0dc75c8102 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -14,9 +14,22 @@ #include "paddle/ir/core/value.h" #include "paddle/ir/core/enforce.h" +#include "paddle/ir/core/operation.h" #include "paddle/ir/core/value_impl.h" +#define CHECK_NULL_IMPL(class_name, func_name) \ + IR_ENFORCE(impl_, \ + "impl_ pointer is null when call func:" #func_name \ + " , in class: " #class_name ".") + +#define CHECK_OPOPEREND_NULL_IMPL(func_name) \ + CHECK_NULL_IMPL(OpOpernad, func_name) + +#define CHECK_VALUE_NULL_IMPL(func_name) CHECK_NULL_IMPL(Value, func_name) + +#define CHECK_OPRESULT_NULL_IMPL(func_name) CHECK_NULL_IMPL(OpResult, func_name) namespace ir { + // Operand OpOperand::OpOperand(const detail::OpOperandImpl *impl) : impl_(const_cast(impl)) {} @@ -34,22 +47,33 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { } OpOperand::operator bool() const { return impl_ && impl_->source(); } -OpOperand OpOperand::next_use() const { return impl()->next_use(); } +OpOperand OpOperand::next_use() const { + CHECK_OPOPEREND_NULL_IMPL(next_use); + return impl_->next_use(); +} -Value OpOperand::source() const { return impl()->source(); } +Value OpOperand::source() const { + CHECK_OPOPEREND_NULL_IMPL(source); + return impl_->source(); +} Type OpOperand::type() const { return source().type(); } -void OpOperand::set_source(Value value) { impl()->set_source(value); } - -Operation *OpOperand::owner() const { return impl()->owner(); } +void OpOperand::set_source(Value value) { + CHECK_OPOPEREND_NULL_IMPL(set_source); + impl_->set_source(value); +} -void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); } +Operation *OpOperand::owner() const { + CHECK_OPOPEREND_NULL_IMPL(owner); + return impl_->owner(); +} -detail::OpOperandImpl *OpOperand::impl() const { - IR_ENFORCE(impl_, "Can't use impl() interface while op_operand is null."); - return impl_; +void OpOperand::RemoveFromUdChain() { + CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain); + return impl_->RemoveFromUdChain(); } + // Value Value::Value(const detail::ValueImpl *impl) : impl_(const_cast(impl)) {} @@ -66,31 +90,48 @@ bool Value::operator!() const { return impl_ == nullptr; } Value::operator bool() const { return impl_; } -ir::Type Value::type() const { return impl()->type(); } +ir::Type Value::type() const { + CHECK_VALUE_NULL_IMPL(type); + return impl_->type(); +} -void Value::set_type(ir::Type type) { impl()->set_type(type); } +void Value::set_type(ir::Type type) { + CHECK_VALUE_NULL_IMPL(set_type); + impl_->set_type(type); +} Operation *Value::GetDefiningOp() const { if (auto result = dyn_cast()) return result.owner(); return nullptr; } -std::string Value::PrintUdChain() { return impl()->PrintUdChain(); } +std::string Value::PrintUdChain() { + CHECK_VALUE_NULL_IMPL(PrintUdChain); + return impl()->PrintUdChain(); +} -Value::use_iterator Value::begin() const { return ir::OpOperand(first_use()); } +Value::UseIterator Value::use_begin() const { + return ir::OpOperand(first_use()); +} -Value::use_iterator Value::end() const { return Value::use_iterator(); } +Value::UseIterator Value::use_end() const { return Value::UseIterator(); } -OpOperand Value::first_use() const { return impl()->first_use(); } +OpOperand Value::first_use() const { + CHECK_VALUE_NULL_IMPL(first_use); + return impl_->first_use(); +} bool Value::use_empty() const { return !first_use(); } -bool Value::HasOneUse() const { return impl()->HasOneUse(); } +bool Value::HasOneUse() const { + CHECK_VALUE_NULL_IMPL(HasOneUse); + return impl_->HasOneUse(); +} void Value::ReplaceUsesWithIf( Value new_value, const std::function &should_replace) const { - for (auto it = begin(); it != end();) { + for (auto it = use_begin(); it != use_end();) { if (should_replace(*it)) { (it++)->set_source(new_value); } @@ -98,27 +139,27 @@ void Value::ReplaceUsesWithIf( } void Value::ReplaceAllUsesWith(Value new_value) const { - for (auto it = begin(); it != end();) { + for (auto it = use_begin(); it != use_end();) { (it++)->set_source(new_value); } } -detail::ValueImpl *Value::impl() const { - IR_ENFORCE(impl_, "Can't use impl() interface while value is null."); - return impl_; -} - // OpResult bool OpResult::classof(Value value) { return value && ir::isa(value.impl()); } -Operation *OpResult::owner() const { return impl()->owner(); } +Operation *OpResult::owner() const { + CHECK_OPRESULT_NULL_IMPL(owner); + return impl()->owner(); +} -uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); } +uint32_t OpResult::GetResultIndex() const { + CHECK_OPRESULT_NULL_IMPL(GetResultIndex); + return impl()->GetResultIndex(); +} detail::OpResultImpl *OpResult::impl() const { - IR_ENFORCE(impl_, "Can't use impl() interface while value is null."); return reinterpret_cast(impl_); } @@ -168,7 +209,7 @@ void OpOperandImpl::InsertToUdChain() { if (next_use_) { next_use_->prev_use_addr_ = &next_use_; } - source_.impl()->SetFirstUse(this); + source_.impl()->set_first_use(this); } void OpOperandImpl::RemoveFromUdChain() { @@ -176,9 +217,9 @@ void OpOperandImpl::RemoveFromUdChain() { if (!prev_use_addr_) return; if (prev_use_addr_ == source_.impl()->first_use_addr()) { /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits - /// storage index information, so need to be updated using the SetFirstUse + /// storage index information, so need to be updated using the set_first_use /// method here. - source_.impl()->SetFirstUse(next_use_); + source_.impl()->set_first_use(next_use_); } else { *prev_use_addr_ = next_use_; } @@ -223,6 +264,11 @@ uint32_t OpResultImpl::GetResultIndex() const { return ir::dyn_cast(this)->GetResultIndex(); } +OpResultImpl::~OpResultImpl() { + assert(use_empty() && + owner()->name() + " operation destroyed but still has uses."); +} + ir::Operation *OpResultImpl::owner() const { // For inline result, pointer offset index to obtain the address of op. if (const auto *result = ir::dyn_cast(this)) { diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 86a5566393a220d57c6292c50aa79ddf7de5f785..6594016e74402227b846210647186b95b5e41c48 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -16,6 +16,7 @@ #include "paddle/ir/core/cast_utils.h" #include "paddle/ir/core/type.h" +#include "paddle/ir/core/use_iterator.h" namespace ir { class Operation; @@ -66,49 +67,9 @@ class IR_API OpOperand { friend Operation; private: - // The interface shoule ensure impl_ isn't nullptr. - // if the user can accept impl_ is nullptr, shoule use impl_ member directly. - detail::OpOperandImpl *impl() const; - detail::OpOperandImpl *impl_{nullptr}; }; -/// -/// \brief Value Iterator -/// -template -class ValueUseIterator { - public: - ValueUseIterator(OperandType use = nullptr) : current_(use) {} // NOLINT - - bool operator==(const ValueUseIterator &rhs) const { - return current_ == rhs.current_; - } - bool operator!=(const ValueUseIterator &rhs) const { - return !(*this == rhs); - } - - ir::Operation *owner() const { return current_.owner(); } - - OperandType &operator*() { return current_; } - - OperandType *operator->() { return &operator*(); } - - ValueUseIterator &operator++() { - current_ = current_.next_use(); - return *this; - } - - ValueUseIterator operator++(int) { - ValueUseIterator tmp = *this; - ++*(this); - return tmp; - } - - protected: - OperandType current_; -}; - /// /// \brief Value class represents the SSA value in the IR system. This class /// only provides interfaces, for specific implementation, see Impl class. @@ -150,11 +111,11 @@ class IR_API Value { /// /// \brief Provide iterator interface to access Value use chain. /// - using use_iterator = ValueUseIterator; + using UseIterator = ValueUseIterator; - use_iterator begin() const; + UseIterator use_begin() const; - use_iterator end() const; + UseIterator use_end() const; OpOperand first_use() const; @@ -169,9 +130,8 @@ class IR_API Value { const std::function &should_replace) const; void ReplaceAllUsesWith(Value new_value) const; - // The interface shoule ensure impl_ isn't nullptr. - // if the user can accept impl_ is nullptr, shoule use impl_ member directly. - detail::ValueImpl *impl() const; + detail::ValueImpl *impl() { return impl_; } + const detail::ValueImpl *impl() const { return impl_; } protected: detail::ValueImpl *impl_{nullptr}; @@ -197,11 +157,10 @@ class IR_API OpResult : public Value { friend Operation; detail::ValueImpl *value_impl() const; + detail::OpResultImpl *impl() const; private: static uint32_t GetValidInlineIndex(uint32_t index); - - detail::OpResultImpl *impl() const; }; } // namespace ir diff --git a/paddle/ir/core/value_impl.h b/paddle/ir/core/value_impl.h index 9c3c56cdefd387813c3888875816541b2f74f723..14a7b4d63f5d3d3d047dcd9402e7b9c74155ce56 100644 --- a/paddle/ir/core/value_impl.h +++ b/paddle/ir/core/value_impl.h @@ -46,7 +46,7 @@ class OpOperandImpl { OpOperandImpl(ir::Value source, ir::Operation *owner); // Insert self to the UD chain holded by source_; - // It is not safe. So set provate. + // It is not safe. So set private. void InsertToUdChain(); ir::detail::OpOperandImpl *next_use_ = nullptr; @@ -85,7 +85,7 @@ class alignas(8) ValueImpl { reinterpret_cast(first_use_offseted_by_index_) & (~0x07)); } - void SetFirstUse(OpOperandImpl *first_use) { + void set_first_use(OpOperandImpl *first_use) { uint32_t offset = index(); first_use_offseted_by_index_ = reinterpret_cast( reinterpret_cast(first_use) + offset); @@ -163,6 +163,8 @@ class alignas(8) OpResultImpl : public ValueImpl { static uint32_t GetMaxInlineResultIndex() { return OUTLINE_OP_RESULT_INDEX - 1; } + + ~OpResultImpl(); }; /// diff --git a/test/cpp/ir/CMakeLists.txt b/test/cpp/ir/CMakeLists.txt index d2117ad5c24e2e660464712695288ab872fa098c..4eec7e8ef94c14b5e7e3b0fa00a2e46691920a40 100644 --- a/test/cpp/ir/CMakeLists.txt +++ b/test/cpp/ir/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(tools) add_subdirectory(core) add_subdirectory(pass) add_subdirectory(pattern_rewrite) diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt index b1ff16025714d2eb982f76c4a26fdfc02271b452..80cd506648cd8f897ba9611809212dc9b5648c35 100644 --- a/test/cpp/ir/core/CMakeLists.txt +++ b/test/cpp/ir/core/CMakeLists.txt @@ -95,3 +95,12 @@ cc_test_old( program_translator pd_dialect ir) + +cc_test_old( + block_operand_test + SRCS + block_operand_test.cc + DEPS + test_dialect + gtest + ir) diff --git a/test/cpp/ir/core/block_operand_test.cc b/test/cpp/ir/core/block_operand_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2b74e9781a3fdba0d6e31806bf5dcd167610ba0 --- /dev/null +++ b/test/cpp/ir/core/block_operand_test.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/program.h" + +#include "test/cpp/ir/tools/test_dialect.h" +#include "test/cpp/ir/tools/test_op.h" + +TEST(block_operand_test, type_block) { + ir::IrContext ctx; + ctx.GetOrRegisterDialect(); + + ir::Program program(&ctx); + ir::Block* block = program.block(); + + ir::Builder builder(&ctx, block); + test::RegionOp region_op = builder.Build(); + + auto& region = region_op->region(0); + + ir::Block* block_1 = new ir::Block(); + ir::Block* block_2 = new ir::Block(); + ir::Block* block_3 = new ir::Block(); + region.push_back(block_1); + region.push_back(block_2); + region.push_back(block_3); + + builder.SetInsertionPointToEnd(block_1); + auto op1 = + builder.Build(std::vector{}, block_2); + EXPECT_TRUE(block_2->HasOneUse()); + EXPECT_FALSE(block_2->use_empty()); + + auto iter_begin = block_2->use_begin(); + auto iter_end = block_2->use_end(); + auto block_operand = op1->block_operand(0); + auto iter_curr = iter_begin++; + EXPECT_EQ(iter_begin, iter_end); + EXPECT_EQ(*iter_curr, block_operand); + EXPECT_EQ(block_2->first_use(), block_operand); + EXPECT_EQ(iter_curr->owner(), op1); + + builder.SetInsertionPointToEnd(block_3); + auto op3 = + builder.Build(std::vector{}, block_1); + block_operand = op3->block_operand(0); + block_operand.set_source(block_2); + EXPECT_EQ(block_2, block_operand.source()); +} diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 16cfc01730d60dd10446216ec6d2a0e2227b6085..39880c4e5bdaa245b3f6d275647cb4ba2f7a2f61 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -236,24 +236,26 @@ TEST(op_test, region_test) { argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"}, {"op2_attr1", "op2_attr2"}); argument.output_types = {ir::Float32Type::get(ctx)}; - argument.regions.emplace_back(std::make_unique()); - ir::Region *region = argument.regions.back().get(); - EXPECT_EQ(region->empty(), true); + argument.num_regions = 1; + + ir::Operation *op3 = ir::Operation::Create(std::move(argument)); + // argument.regions.emplace_back(std::make_unique()); + + ir::Region ®ion = op3->region(0); + EXPECT_EQ(region.empty(), true); // (3) Test custom operation printer std::stringstream ss; op1->Print(ss); EXPECT_EQ(ss.str(), " (%0) = \"test.operation1\" ()"); - region->push_back(new ir::Block()); - region->push_front(new ir::Block()); - region->insert(region->begin(), new ir::Block()); - ir::Block *block = region->front(); + region.push_back(new ir::Block()); + region.push_front(new ir::Block()); + region.insert(region.begin(), new ir::Block()); + ir::Block *block = region.front(); block->push_front(op1); block->insert(block->begin(), op1_2); - ir::Operation *op2 = ir::Operation::Create(std::move(argument)); - EXPECT_EQ(op2->region(0).ir_context(), ctx); - op2->Destroy(); + op3->Destroy(); } TEST(op_test, module_op_death) { diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index 3f90e3a4fd6c80899edb28c9732a92577605cc90..fb7fcfd6fdda1ff4ca6d44934ed6249f6e813d97 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -100,8 +100,8 @@ TEST(value_test, value_test) { EXPECT_EQ(op3_first_input.next_use(), nullptr); // Test 3: Value iterator - using my_iterator = ir::Value::use_iterator; - my_iterator iter = op1->result(0).begin(); + using my_iterator = ir::Value::UseIterator; + my_iterator iter = op1->result(0).use_begin(); EXPECT_EQ(iter.owner(), op4); ++iter; EXPECT_EQ(iter.owner(), op3); diff --git a/test/cpp/ir/tools/CMakeLists.txt b/test/cpp/ir/tools/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..58179d87e0d883c3e792cc5924b3eb815503ecf6 --- /dev/null +++ b/test/cpp/ir/tools/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library( + test_dialect + SRCS test_dialect.cc test_op.cc + DEPS ir) diff --git a/test/cpp/ir/tools/test_dialect.cc b/test/cpp/ir/tools/test_dialect.cc new file mode 100644 index 0000000000000000000000000000000000000000..c16b9be067663c22702b59c39411da7b8e8a7a77 --- /dev/null +++ b/test/cpp/ir/tools/test_dialect.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "test/cpp/ir/tools/test_dialect.h" +#include "test/cpp/ir/tools/test_op.h" +namespace test { +void TestDialect::initialize() { RegisterOps(); } +} // namespace test +IR_DEFINE_EXPLICIT_TYPE_ID(test::TestDialect) diff --git a/test/cpp/ir/tools/test_dialect.h b/test/cpp/ir/tools/test_dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..4403719458e4b688f9967b13654a7e8875289298 --- /dev/null +++ b/test/cpp/ir/tools/test_dialect.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/dialect.h" + +namespace test { +class TestDialect : public ir::Dialect { + public: + explicit TestDialect(ir::IrContext *context) + : ir::Dialect(name(), context, ir::TypeId::get()) { + initialize(); + } + static const char *name() { return "test"; } + + private: + void initialize(); +}; + +} // namespace test +IR_DECLARE_EXPLICIT_TYPE_ID(test::TestDialect) diff --git a/test/cpp/ir/tools/test_op.cc b/test/cpp/ir/tools/test_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..40dc46c0b8e14bcbb2a3f726fd748f9f9cb21f00 --- /dev/null +++ b/test/cpp/ir/tools/test_op.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "test/cpp/ir/tools/test_op.h" + +namespace test { +void RegionOp::Build(ir::Builder &builder, ir::OperationArgument &argument) { + argument.num_regions = 1; +} +void RegionOp::Verify() const { + auto num_regions = (*this)->num_regions(); + IR_ENFORCE(num_regions == 1u, + "The region's number in Region Op must be 1, but current is %d", + num_regions); +} + +void BranchOp::Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, + const std::vector &target_operands, + ir::Block *target) { + argument.AddOperands(target_operands.begin(), target_operands.end()); + argument.AddSuccessor(target); +} + +void BranchOp::Verify() const { + IR_ENFORCE((*this)->num_successors() == 1u, + "successors number must equal to 1."); + IR_ENFORCE((*this)->successor(0), "successor[0] can't be nullptr"); +} + +} // namespace test + +IR_DEFINE_EXPLICIT_TYPE_ID(test::RegionOp) +IR_DEFINE_EXPLICIT_TYPE_ID(test::BranchOp) diff --git a/test/cpp/ir/tools/test_op.h b/test/cpp/ir/tools/test_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1462a9555cb07dddfb3cff31590dccc595d72c7b --- /dev/null +++ b/test/cpp/ir/tools/test_op.h @@ -0,0 +1,54 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/op_base.h" + +namespace test { +/// +/// \brief TestRegionOp +/// +class RegionOp : public ir::Op { + public: + using Op::Op; + static const char *name() { return "test.region"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument); // NOLINT + void Verify() const; +}; + +/// +/// \brief TestBranchOp +/// +class BranchOp : public ir::Op { + public: + using Op::Op; + static const char *name() { return "test.branch"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + const std::vector &target_operands, + ir::Block *target); + void Verify() const; +}; + +} // namespace test + +IR_DECLARE_EXPLICIT_TYPE_ID(test::RegionOp) +IR_DECLARE_EXPLICIT_TYPE_ID(test::BranchOp)