未验证 提交 662df299 编写于 作者: W winter-wang 提交者: GitHub

[IR] add block operand support for ir. (#56354)

上级 81659f7e
......@@ -18,7 +18,10 @@
#include "paddle/ir/core/region.h"
namespace ir {
Block::~Block() { clear(); }
Block::~Block() {
assert(use_empty() && "block destroyed still has uses.");
clear();
}
void Block::push_back(Operation *op) { insert(ops_.end(), op); }
void Block::push_front(Operation *op) { insert(ops_.begin(), op); }
......@@ -51,4 +54,10 @@ void Block::SetParent(Region *parent, Region::iterator position) {
position_ = position;
}
Block::UseIterator Block::use_begin() const { return first_use_; }
Block::UseIterator Block::use_end() const { return Block::UseIterator(); }
bool Block::HasOneUse() const { return first_use_ && !first_use_.next_use(); }
} // namespace ir
......@@ -17,8 +17,10 @@
#include <cstddef>
#include <list>
#include "paddle/ir/core/block_operand.h"
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/use_iterator.h"
namespace ir {
class Operation;
......@@ -56,6 +58,18 @@ class IR_API Block {
void clear();
operator Region::iterator() { return position_; }
///
/// \brief Provide iterator interface to access Value use chain.
///
using UseIterator = ValueUseIterator<BlockOperand>;
UseIterator use_begin() const;
UseIterator use_end() const;
BlockOperand first_use() const { return first_use_; }
void set_first_use(BlockOperand first_use) { first_use_ = first_use; }
bool use_empty() const { return !first_use_; }
bool HasOneUse() const;
BlockOperand *first_use_addr() { return &first_use_; }
private:
Block(Block &) = delete;
Block &operator=(const Block &) = delete;
......@@ -68,5 +82,6 @@ class IR_API Block {
Region *parent_; // not owned
OpListType ops_; // owned
Region::iterator position_;
BlockOperand first_use_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/core/block_operand.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/block_operand_impl.h"
#include "paddle/ir/core/enforce.h"
namespace ir {
#define CHECK_BLOCKOPEREND_NULL_IMPL(func_name) \
IR_ENFORCE(impl_, \
"impl_ pointer is null when call func:" #func_name \
" , in class: BlockOperand.")
BlockOperand &BlockOperand::operator=(const BlockOperand &rhs) {
if (this == &rhs) return *this;
impl_ = rhs.impl_;
return *this;
}
BlockOperand::operator bool() const { return impl_ && impl_->source(); }
BlockOperand BlockOperand::next_use() const {
CHECK_BLOCKOPEREND_NULL_IMPL(next_use);
return impl_->next_use();
}
Block *BlockOperand::source() const {
CHECK_BLOCKOPEREND_NULL_IMPL(source);
return impl_->source();
}
void BlockOperand::set_source(Block *source) {
CHECK_BLOCKOPEREND_NULL_IMPL(set_source);
impl_->set_source(source);
}
Operation *BlockOperand::owner() const {
CHECK_BLOCKOPEREND_NULL_IMPL(owner);
return impl_->owner();
}
void BlockOperand::RemoveFromUdChain() {
CHECK_BLOCKOPEREND_NULL_IMPL(RemoveFromUdChain);
return impl_->RemoveFromUdChain();
}
// details
namespace detail {
Operation *BlockOperandImpl::owner() const { return owner_; }
BlockOperand BlockOperandImpl::next_use() const { return next_use_; }
Block *BlockOperandImpl::source() const { return source_; }
void BlockOperandImpl::set_source(Block *source) {
RemoveFromUdChain();
if (!source) {
return;
}
source_ = source;
InsertToUdChain();
}
BlockOperandImpl::BlockOperandImpl(Block *source, ir::Operation *owner)
: source_(source), owner_(owner) {
if (!source) {
return;
}
InsertToUdChain();
}
void BlockOperandImpl::InsertToUdChain() {
prev_use_addr_ = source_->first_use_addr();
next_use_ = source_->first_use();
if (next_use_) {
next_use_.impl()->prev_use_addr_ = &next_use_;
}
source_->set_first_use(this);
}
void BlockOperandImpl::RemoveFromUdChain() {
if (!source_) return;
if (!prev_use_addr_) return;
if (prev_use_addr_ == source_->first_use_addr()) {
source_->set_first_use(next_use_);
} else {
*prev_use_addr_ = next_use_;
}
if (next_use_) {
next_use_.impl()->prev_use_addr_ = prev_use_addr_;
}
next_use_ = nullptr;
prev_use_addr_ = nullptr;
source_ = nullptr;
}
BlockOperandImpl::~BlockOperandImpl() { RemoveFromUdChain(); }
} // namespace detail
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/type.h"
namespace ir {
class Operation;
class Value;
class Block;
namespace detail {
class BlockOperandImpl;
} // namespace detail
///
/// \brief OpOperand class represents the op_operand of operation. This class
/// only provides interfaces, for specific implementation, see Impl class.
///
class IR_API BlockOperand {
public:
BlockOperand() = default;
BlockOperand(const BlockOperand &other) = default;
BlockOperand(detail::BlockOperandImpl *impl) : impl_(impl) {} // NOLINT
BlockOperand &operator=(const BlockOperand &rhs);
bool operator==(const BlockOperand &other) const {
return impl_ == other.impl_;
}
bool operator!=(const BlockOperand &other) const {
return !operator==(other);
}
bool operator!() const { return impl_ == nullptr; }
operator bool() const;
BlockOperand next_use() const;
Block *source() const;
void set_source(Block *source);
Operation *owner() const;
void RemoveFromUdChain();
friend Operation;
detail::BlockOperandImpl *impl() const { return impl_; }
private:
detail::BlockOperandImpl *impl_{nullptr};
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/block_operand.h"
namespace ir {
class Operation;
class Block;
namespace detail {
///
/// \brief OpOperandImpl
///
class BlockOperandImpl {
public:
Operation* owner() const;
BlockOperand next_use() const;
Block* source() const;
void set_source(Block*);
/// Remove this op_operand from the current use list.
void RemoveFromUdChain();
~BlockOperandImpl();
friend Operation;
private:
BlockOperandImpl(Block* source, Operation* owner);
// Insert self to the UD chain holded by source_;
// It is not safe. So set provate.
void InsertToUdChain();
BlockOperand next_use_ = nullptr;
BlockOperand* prev_use_addr_ = nullptr;
Block* source_;
Operation* const owner_ = nullptr;
};
} // namespace detail
} // namespace ir
......@@ -40,9 +40,11 @@ Block *ModuleOp::block() {
ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
ir::OpInfo info = context->GetRegisteredOpInfo(name());
OperationArgument argument(info);
argument.AddRegion()->emplace_back();
argument.num_regions = 1;
argument.AddAttribute("program", PointerAttribute::get(context, pointer));
return ModuleOp(Operation::Create(std::move(argument)));
Operation *op = Operation::Create(std::move(argument));
op->region(0).emplace_back();
return ModuleOp(op);
}
void ModuleOp::Destroy() {
......
......@@ -15,6 +15,7 @@
#include <ostream>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/block_operand_impl.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/op_info.h"
......@@ -26,16 +27,12 @@
namespace ir {
Operation *Operation::Create(OperationArgument &&argument) {
Operation *op = Create(argument.inputs,
argument.attributes,
argument.output_types,
argument.info,
argument.regions.size());
for (size_t index = 0; index < argument.regions.size(); ++index) {
op->region(index).TakeBody(std::move(*argument.regions[index]));
}
return op;
return Create(argument.inputs,
argument.attributes,
argument.output_types,
argument.info,
argument.num_regions,
argument.successors);
}
// Allocate the required memory based on the size and number of inputs, outputs,
......@@ -43,13 +40,15 @@ Operation *Operation::Create(OperationArgument &&argument) {
// OpInlineResult, Operation, operand.
Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
const std::vector<Type> &output_types,
ir::OpInfo op_info,
size_t num_regions) {
size_t num_regions,
const std::vector<Block *> &successors) {
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
uint32_t num_results = output_types.size();
uint32_t num_operands = inputs.size();
uint32_t num_successors = successors.size();
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
size_t result_mem_size =
......@@ -58,11 +57,12 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
(num_results - max_inline_result_num) +
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
: sizeof(detail::OpInlineResultImpl) * num_results;
size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
size_t op_mem_size = sizeof(Operation);
size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
size_t block_operand_size = num_successors * sizeof(detail::BlockOperandImpl);
size_t region_mem_size = num_regions * sizeof(Region);
size_t base_size =
result_mem_size + op_mem_size + operand_mem_size + region_mem_size;
size_t base_size = result_mem_size + op_mem_size + operand_mem_size +
region_mem_size + block_operand_size;
// 2. Malloc memory.
char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
// 3.1. Construct OpResults.
......@@ -77,8 +77,12 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
}
}
// 3.2. Construct Operation.
Operation *op = new (base_ptr)
Operation(attributes, op_info, num_results, num_operands, num_regions);
Operation *op = new (base_ptr) Operation(attributes,
op_info,
num_results,
num_operands,
num_regions,
num_successors);
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
......@@ -88,7 +92,17 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3.4. Construct Regions
// 3.4. Construct BlockOperands.
if (num_successors > 0) {
op->block_operands_ =
reinterpret_cast<detail::BlockOperandImpl *>(base_ptr);
for (size_t idx = 0; idx < num_successors; idx++) {
new (base_ptr) detail::BlockOperandImpl(successors[idx], op);
base_ptr += sizeof(detail::BlockOperandImpl);
}
}
// 3.5. Construct Regions
if (num_regions > 0) {
op->regions_ = reinterpret_cast<Region *>(base_ptr);
for (size_t idx = 0; idx < num_regions; idx++) {
......@@ -118,8 +132,6 @@ void Operation::Destroy() {
// 2. Deconstruct Result.
for (size_t idx = 0; idx < num_results_; ++idx) {
detail::OpResultImpl *impl = result(idx).impl();
IR_ENFORCE(impl->use_empty(),
name() + " operation destroyed but still has uses.");
if (detail::OpOutlineResultImpl::classof(*impl)) {
static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
} else {
......@@ -132,8 +144,20 @@ void Operation::Destroy() {
// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
operand(idx).impl()->~OpOperandImpl();
detail::OpOperandImpl *op_operand_impl = operand(idx).impl_;
if (op_operand_impl) {
op_operand_impl->~OpOperandImpl();
}
}
// 5. Deconstruct BlockOperand.
for (size_t idx = 0; idx < num_successors_; idx++) {
detail::BlockOperandImpl *block_operand_impl = block_operands_ + idx;
if (block_operand_impl) {
block_operand_impl->~BlockOperandImpl();
}
}
// 5. Free memory.
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
......@@ -158,12 +182,14 @@ Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info,
uint32_t num_results,
uint32_t num_operands,
uint32_t num_regions)
uint32_t num_regions,
uint32_t num_successors)
: attributes_(attributes),
info_(op_info),
num_results_(num_results),
num_operands_(num_operands),
num_regions_(num_regions) {}
num_regions_(num_regions),
num_successors_(num_successors) {}
ir::OpResult Operation::result(uint32_t index) const {
if (index >= num_results_) {
......@@ -226,14 +252,26 @@ const Program *Operation::GetParentProgram() const {
ModuleOp module_op = op->dyn_cast<ModuleOp>();
return module_op ? module_op.program() : nullptr;
}
BlockOperand Operation::block_operand(uint32_t index) const {
IR_ENFORCE(index < num_successors_, "Invalid block_operand index");
return block_operands_ + index;
}
Block *Operation::successor(uint32_t index) const {
return block_operand(index).source();
}
void Operation::set_successor(Block *block, unsigned index) {
IR_ENFORCE(index < num_operands_, "Invalid block_operand index");
(block_operands_ + index)->set_source(block);
}
Region &Operation::region(unsigned index) {
assert(index < num_regions_ && "invalid region index");
IR_ENFORCE(index < num_regions_, "invalid region index");
return regions_[index];
}
const Region &Operation::region(unsigned index) const {
assert(index < num_regions_ && "invalid region index");
IR_ENFORCE(index < num_regions_, "invalid region index");
return regions_[index];
}
......
......@@ -29,6 +29,10 @@ class Program;
class OpOperand;
class OpResult;
namespace detial {
class BlockOperandImpl;
} // namespace detial
class IR_API alignas(8) Operation final {
public:
///
......@@ -41,7 +45,8 @@ class IR_API alignas(8) Operation final {
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info,
size_t num_regions = 0);
size_t num_regions = 0,
const std::vector<Block *> &successors = {});
static Operation *Create(OperationArgument &&op_argument);
///
......@@ -59,9 +64,16 @@ class IR_API alignas(8) Operation final {
Value operand_source(uint32_t index) const;
uint32_t num_successors() const { return num_successors_; }
BlockOperand block_operand(uint32_t index) const;
Block *successor(uint32_t index) const;
void set_successor(Block *block, unsigned index);
bool HasSuccessors() { return num_successors_ != 0; }
/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
const Region &region(unsigned index) const;
uint32_t num_regions() const { return num_regions_; }
void Print(std::ostream &os) const;
......@@ -90,8 +102,6 @@ class IR_API alignas(8) Operation final {
uint32_t num_operands() const { return num_operands_; }
uint32_t num_regions() const { return num_regions_; }
std::string name() const;
template <typename T>
......@@ -152,7 +162,8 @@ class IR_API alignas(8) Operation final {
ir::OpInfo op_info,
uint32_t num_results,
uint32_t num_operands,
uint32_t num_regions);
uint32_t num_regions,
uint32_t num_successors);
template <typename T, typename Enabler = void>
struct CastUtil {
......@@ -179,7 +190,9 @@ class IR_API alignas(8) Operation final {
const uint32_t num_results_ = 0;
const uint32_t num_operands_ = 0;
const uint32_t num_regions_ = 0;
const uint32_t num_successors_ = 0;
detail::BlockOperandImpl *block_operands_{nullptr};
Region *regions_{nullptr};
Block *parent_{nullptr};
Block::iterator position_;
......
......@@ -22,7 +22,7 @@
#include "paddle/ir/core/value.h"
namespace ir {
class Block;
using AttributeMap = std::unordered_map<std::string, Attribute>;
//===----------------------------------------------------------------------===//
......@@ -36,7 +36,8 @@ struct OperationArgument {
AttributeMap attributes;
std::vector<Type> output_types;
OpInfo info;
std::vector<std::unique_ptr<Region>> regions;
size_t num_regions{0};
std::vector<Block*> successors;
public:
OperationArgument(IrContext* ir_context, const std::string& name);
......@@ -45,12 +46,14 @@ struct OperationArgument {
const AttributeMap& attributes,
const std::vector<Type>& types,
OpInfo info,
std::vector<std::unique_ptr<Region>>&& regions = {})
size_t num_regions = 0,
const std::vector<Block*> successors = {})
: inputs(operands),
attributes(attributes),
output_types(types),
info(info),
regions(std::move(regions)) {}
num_regions(num_regions),
successors(successors) {}
/// Add Operand.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }
......@@ -74,10 +77,7 @@ struct OperationArgument {
/// Get the context held by this operation state.
IrContext* getContext() const { return info.ir_context(); }
Region* AddRegion() {
regions.emplace_back(new Region);
return regions.back().get();
}
void AddSuccessor(Block* successor) { successors.emplace_back(successor); }
};
template <class InputIt>
......
......@@ -46,6 +46,11 @@ void Region::TakeBody(Region &&other) {
}
void Region::clear() {
// In order to ensure the correctness of UD Chain,
// BlockOperend should be decontructed bofore its source.
for (auto iter = blocks_.rbegin(); iter != blocks_.rend(); ++iter) {
(*iter)->clear();
}
while (!empty()) {
delete blocks_.back();
blocks_.pop_back();
......
......@@ -31,8 +31,6 @@ class IR_API Region {
using reverse_iterator = std::list<Block *>::reverse_iterator;
using const_iterator = std::list<Block *>::const_iterator;
~Region();
Region() = default;
bool empty() const { return blocks_.empty(); }
size_t size() const { return blocks_.size(); }
......@@ -59,6 +57,8 @@ class IR_API Region {
IrContext *ir_context() const;
private:
// region only support construncted by operation.
Region() = delete;
Region(Region &) = delete;
Region &operator=(const Region &) = delete;
friend class Operation;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace ir {
class Operation;
///
/// \brief Value Iterator
///
template <typename OperandType>
class ValueUseIterator {
public:
ValueUseIterator(OperandType use = nullptr) : current_(use) {} // NOLINT
bool operator==(const ValueUseIterator<OperandType> &rhs) const {
return current_ == rhs.current_;
}
bool operator!=(const ValueUseIterator<OperandType> &rhs) const {
return !(*this == rhs);
}
Operation *owner() const { return current_.owner(); }
OperandType &operator*() { return current_; }
OperandType *operator->() { return &operator*(); }
ValueUseIterator<OperandType> &operator++() {
current_ = current_.next_use();
return *this;
}
ValueUseIterator<OperandType> operator++(int) {
ValueUseIterator<OperandType> tmp = *this;
current_ = current_.next_use();
return tmp;
}
protected:
OperandType current_;
};
} // namespace ir
......@@ -14,9 +14,22 @@
#include "paddle/ir/core/value.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value_impl.h"
#define CHECK_NULL_IMPL(class_name, func_name) \
IR_ENFORCE(impl_, \
"impl_ pointer is null when call func:" #func_name \
" , in class: " #class_name ".")
#define CHECK_OPOPEREND_NULL_IMPL(func_name) \
CHECK_NULL_IMPL(OpOpernad, func_name)
#define CHECK_VALUE_NULL_IMPL(func_name) CHECK_NULL_IMPL(Value, func_name)
#define CHECK_OPRESULT_NULL_IMPL(func_name) CHECK_NULL_IMPL(OpResult, func_name)
namespace ir {
// Operand
OpOperand::OpOperand(const detail::OpOperandImpl *impl)
: impl_(const_cast<detail::OpOperandImpl *>(impl)) {}
......@@ -34,22 +47,33 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
}
OpOperand::operator bool() const { return impl_ && impl_->source(); }
OpOperand OpOperand::next_use() const { return impl()->next_use(); }
OpOperand OpOperand::next_use() const {
CHECK_OPOPEREND_NULL_IMPL(next_use);
return impl_->next_use();
}
Value OpOperand::source() const { return impl()->source(); }
Value OpOperand::source() const {
CHECK_OPOPEREND_NULL_IMPL(source);
return impl_->source();
}
Type OpOperand::type() const { return source().type(); }
void OpOperand::set_source(Value value) { impl()->set_source(value); }
Operation *OpOperand::owner() const { return impl()->owner(); }
void OpOperand::set_source(Value value) {
CHECK_OPOPEREND_NULL_IMPL(set_source);
impl_->set_source(value);
}
void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); }
Operation *OpOperand::owner() const {
CHECK_OPOPEREND_NULL_IMPL(owner);
return impl_->owner();
}
detail::OpOperandImpl *OpOperand::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while op_operand is null.");
return impl_;
void OpOperand::RemoveFromUdChain() {
CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain);
return impl_->RemoveFromUdChain();
}
// Value
Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {}
......@@ -66,31 +90,48 @@ bool Value::operator!() const { return impl_ == nullptr; }
Value::operator bool() const { return impl_; }
ir::Type Value::type() const { return impl()->type(); }
ir::Type Value::type() const {
CHECK_VALUE_NULL_IMPL(type);
return impl_->type();
}
void Value::set_type(ir::Type type) { impl()->set_type(type); }
void Value::set_type(ir::Type type) {
CHECK_VALUE_NULL_IMPL(set_type);
impl_->set_type(type);
}
Operation *Value::GetDefiningOp() const {
if (auto result = dyn_cast<OpResult>()) return result.owner();
return nullptr;
}
std::string Value::PrintUdChain() { return impl()->PrintUdChain(); }
std::string Value::PrintUdChain() {
CHECK_VALUE_NULL_IMPL(PrintUdChain);
return impl()->PrintUdChain();
}
Value::use_iterator Value::begin() const { return ir::OpOperand(first_use()); }
Value::UseIterator Value::use_begin() const {
return ir::OpOperand(first_use());
}
Value::use_iterator Value::end() const { return Value::use_iterator(); }
Value::UseIterator Value::use_end() const { return Value::UseIterator(); }
OpOperand Value::first_use() const { return impl()->first_use(); }
OpOperand Value::first_use() const {
CHECK_VALUE_NULL_IMPL(first_use);
return impl_->first_use();
}
bool Value::use_empty() const { return !first_use(); }
bool Value::HasOneUse() const { return impl()->HasOneUse(); }
bool Value::HasOneUse() const {
CHECK_VALUE_NULL_IMPL(HasOneUse);
return impl_->HasOneUse();
}
void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
for (auto it = begin(); it != end();) {
for (auto it = use_begin(); it != use_end();) {
if (should_replace(*it)) {
(it++)->set_source(new_value);
}
......@@ -98,27 +139,27 @@ void Value::ReplaceUsesWithIf(
}
void Value::ReplaceAllUsesWith(Value new_value) const {
for (auto it = begin(); it != end();) {
for (auto it = use_begin(); it != use_end();) {
(it++)->set_source(new_value);
}
}
detail::ValueImpl *Value::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return impl_;
}
// OpResult
bool OpResult::classof(Value value) {
return value && ir::isa<detail::OpResultImpl>(value.impl());
}
Operation *OpResult::owner() const { return impl()->owner(); }
Operation *OpResult::owner() const {
CHECK_OPRESULT_NULL_IMPL(owner);
return impl()->owner();
}
uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); }
uint32_t OpResult::GetResultIndex() const {
CHECK_OPRESULT_NULL_IMPL(GetResultIndex);
return impl()->GetResultIndex();
}
detail::OpResultImpl *OpResult::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return reinterpret_cast<detail::OpResultImpl *>(impl_);
}
......@@ -168,7 +209,7 @@ void OpOperandImpl::InsertToUdChain() {
if (next_use_) {
next_use_->prev_use_addr_ = &next_use_;
}
source_.impl()->SetFirstUse(this);
source_.impl()->set_first_use(this);
}
void OpOperandImpl::RemoveFromUdChain() {
......@@ -176,9 +217,9 @@ void OpOperandImpl::RemoveFromUdChain() {
if (!prev_use_addr_) return;
if (prev_use_addr_ == source_.impl()->first_use_addr()) {
/// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits
/// storage index information, so need to be updated using the SetFirstUse
/// storage index information, so need to be updated using the set_first_use
/// method here.
source_.impl()->SetFirstUse(next_use_);
source_.impl()->set_first_use(next_use_);
} else {
*prev_use_addr_ = next_use_;
}
......@@ -223,6 +264,11 @@ uint32_t OpResultImpl::GetResultIndex() const {
return ir::dyn_cast<OpInlineResultImpl>(this)->GetResultIndex();
}
OpResultImpl::~OpResultImpl() {
assert(use_empty() &&
owner()->name() + " operation destroyed but still has uses.");
}
ir::Operation *OpResultImpl::owner() const {
// For inline result, pointer offset index to obtain the address of op.
if (const auto *result = ir::dyn_cast<OpInlineResultImpl>(this)) {
......
......@@ -16,6 +16,7 @@
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/use_iterator.h"
namespace ir {
class Operation;
......@@ -66,49 +67,9 @@ class IR_API OpOperand {
friend Operation;
private:
// The interface shoule ensure impl_ isn't nullptr.
// if the user can accept impl_ is nullptr, shoule use impl_ member directly.
detail::OpOperandImpl *impl() const;
detail::OpOperandImpl *impl_{nullptr};
};
///
/// \brief Value Iterator
///
template <typename OperandType>
class ValueUseIterator {
public:
ValueUseIterator(OperandType use = nullptr) : current_(use) {} // NOLINT
bool operator==(const ValueUseIterator<OperandType> &rhs) const {
return current_ == rhs.current_;
}
bool operator!=(const ValueUseIterator<OperandType> &rhs) const {
return !(*this == rhs);
}
ir::Operation *owner() const { return current_.owner(); }
OperandType &operator*() { return current_; }
OperandType *operator->() { return &operator*(); }
ValueUseIterator<OperandType> &operator++() {
current_ = current_.next_use();
return *this;
}
ValueUseIterator<OperandType> operator++(int) {
ValueUseIterator<OperandType> tmp = *this;
++*(this);
return tmp;
}
protected:
OperandType current_;
};
///
/// \brief Value class represents the SSA value in the IR system. This class
/// only provides interfaces, for specific implementation, see Impl class.
......@@ -150,11 +111,11 @@ class IR_API Value {
///
/// \brief Provide iterator interface to access Value use chain.
///
using use_iterator = ValueUseIterator<OpOperand>;
using UseIterator = ValueUseIterator<OpOperand>;
use_iterator begin() const;
UseIterator use_begin() const;
use_iterator end() const;
UseIterator use_end() const;
OpOperand first_use() const;
......@@ -169,9 +130,8 @@ class IR_API Value {
const std::function<bool(OpOperand)> &should_replace) const;
void ReplaceAllUsesWith(Value new_value) const;
// The interface shoule ensure impl_ isn't nullptr.
// if the user can accept impl_ is nullptr, shoule use impl_ member directly.
detail::ValueImpl *impl() const;
detail::ValueImpl *impl() { return impl_; }
const detail::ValueImpl *impl() const { return impl_; }
protected:
detail::ValueImpl *impl_{nullptr};
......@@ -197,11 +157,10 @@ class IR_API OpResult : public Value {
friend Operation;
detail::ValueImpl *value_impl() const;
detail::OpResultImpl *impl() const;
private:
static uint32_t GetValidInlineIndex(uint32_t index);
detail::OpResultImpl *impl() const;
};
} // namespace ir
......
......@@ -46,7 +46,7 @@ class OpOperandImpl {
OpOperandImpl(ir::Value source, ir::Operation *owner);
// Insert self to the UD chain holded by source_;
// It is not safe. So set provate.
// It is not safe. So set private.
void InsertToUdChain();
ir::detail::OpOperandImpl *next_use_ = nullptr;
......@@ -85,7 +85,7 @@ class alignas(8) ValueImpl {
reinterpret_cast<uintptr_t>(first_use_offseted_by_index_) & (~0x07));
}
void SetFirstUse(OpOperandImpl *first_use) {
void set_first_use(OpOperandImpl *first_use) {
uint32_t offset = index();
first_use_offseted_by_index_ = reinterpret_cast<OpOperandImpl *>(
reinterpret_cast<uintptr_t>(first_use) + offset);
......@@ -163,6 +163,8 @@ class alignas(8) OpResultImpl : public ValueImpl {
static uint32_t GetMaxInlineResultIndex() {
return OUTLINE_OP_RESULT_INDEX - 1;
}
~OpResultImpl();
};
///
......
add_subdirectory(tools)
add_subdirectory(core)
add_subdirectory(pass)
add_subdirectory(pattern_rewrite)
......
......@@ -95,3 +95,12 @@ cc_test_old(
program_translator
pd_dialect
ir)
cc_test_old(
block_operand_test
SRCS
block_operand_test.cc
DEPS
test_dialect
gtest
ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/program.h"
#include "test/cpp/ir/tools/test_dialect.h"
#include "test/cpp/ir/tools/test_op.h"
TEST(block_operand_test, type_block) {
ir::IrContext ctx;
ctx.GetOrRegisterDialect<test::TestDialect>();
ir::Program program(&ctx);
ir::Block* block = program.block();
ir::Builder builder(&ctx, block);
test::RegionOp region_op = builder.Build<test::RegionOp>();
auto& region = region_op->region(0);
ir::Block* block_1 = new ir::Block();
ir::Block* block_2 = new ir::Block();
ir::Block* block_3 = new ir::Block();
region.push_back(block_1);
region.push_back(block_2);
region.push_back(block_3);
builder.SetInsertionPointToEnd(block_1);
auto op1 =
builder.Build<test::BranchOp>(std::vector<ir::OpResult>{}, block_2);
EXPECT_TRUE(block_2->HasOneUse());
EXPECT_FALSE(block_2->use_empty());
auto iter_begin = block_2->use_begin();
auto iter_end = block_2->use_end();
auto block_operand = op1->block_operand(0);
auto iter_curr = iter_begin++;
EXPECT_EQ(iter_begin, iter_end);
EXPECT_EQ(*iter_curr, block_operand);
EXPECT_EQ(block_2->first_use(), block_operand);
EXPECT_EQ(iter_curr->owner(), op1);
builder.SetInsertionPointToEnd(block_3);
auto op3 =
builder.Build<test::BranchOp>(std::vector<ir::OpResult>{}, block_1);
block_operand = op3->block_operand(0);
block_operand.set_source(block_2);
EXPECT_EQ(block_2, block_operand.source());
}
......@@ -236,24 +236,26 @@ TEST(op_test, region_test) {
argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"});
argument.output_types = {ir::Float32Type::get(ctx)};
argument.regions.emplace_back(std::make_unique<ir::Region>());
ir::Region *region = argument.regions.back().get();
EXPECT_EQ(region->empty(), true);
argument.num_regions = 1;
ir::Operation *op3 = ir::Operation::Create(std::move(argument));
// argument.regions.emplace_back(std::make_unique<ir::Region>());
ir::Region &region = op3->region(0);
EXPECT_EQ(region.empty(), true);
// (3) Test custom operation printer
std::stringstream ss;
op1->Print(ss);
EXPECT_EQ(ss.str(), " (%0) = \"test.operation1\" ()");
region->push_back(new ir::Block());
region->push_front(new ir::Block());
region->insert(region->begin(), new ir::Block());
ir::Block *block = region->front();
region.push_back(new ir::Block());
region.push_front(new ir::Block());
region.insert(region.begin(), new ir::Block());
ir::Block *block = region.front();
block->push_front(op1);
block->insert(block->begin(), op1_2);
ir::Operation *op2 = ir::Operation::Create(std::move(argument));
EXPECT_EQ(op2->region(0).ir_context(), ctx);
op2->Destroy();
op3->Destroy();
}
TEST(op_test, module_op_death) {
......
......@@ -100,8 +100,8 @@ TEST(value_test, value_test) {
EXPECT_EQ(op3_first_input.next_use(), nullptr);
// Test 3: Value iterator
using my_iterator = ir::Value::use_iterator;
my_iterator iter = op1->result(0).begin();
using my_iterator = ir::Value::UseIterator;
my_iterator iter = op1->result(0).use_begin();
EXPECT_EQ(iter.owner(), op4);
++iter;
EXPECT_EQ(iter.owner(), op3);
......
cc_library(
test_dialect
SRCS test_dialect.cc test_op.cc
DEPS ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/cpp/ir/tools/test_dialect.h"
#include "test/cpp/ir/tools/test_op.h"
namespace test {
void TestDialect::initialize() { RegisterOps<RegionOp, BranchOp>(); }
} // namespace test
IR_DEFINE_EXPLICIT_TYPE_ID(test::TestDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dialect.h"
namespace test {
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "test"; }
private:
void initialize();
};
} // namespace test
IR_DECLARE_EXPLICIT_TYPE_ID(test::TestDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/cpp/ir/tools/test_op.h"
namespace test {
void RegionOp::Build(ir::Builder &builder, ir::OperationArgument &argument) {
argument.num_regions = 1;
}
void RegionOp::Verify() const {
auto num_regions = (*this)->num_regions();
IR_ENFORCE(num_regions == 1u,
"The region's number in Region Op must be 1, but current is %d",
num_regions);
}
void BranchOp::Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument,
const std::vector<ir::OpResult> &target_operands,
ir::Block *target) {
argument.AddOperands(target_operands.begin(), target_operands.end());
argument.AddSuccessor(target);
}
void BranchOp::Verify() const {
IR_ENFORCE((*this)->num_successors() == 1u,
"successors number must equal to 1.");
IR_ENFORCE((*this)->successor(0), "successor[0] can't be nullptr");
}
} // namespace test
IR_DEFINE_EXPLICIT_TYPE_ID(test::RegionOp)
IR_DEFINE_EXPLICIT_TYPE_ID(test::BranchOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h"
namespace test {
///
/// \brief TestRegionOp
///
class RegionOp : public ir::Op<RegionOp> {
public:
using Op::Op;
static const char *name() { return "test.region"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument); // NOLINT
void Verify() const;
};
///
/// \brief TestBranchOp
///
class BranchOp : public ir::Op<BranchOp> {
public:
using Op::Op;
static const char *name() { return "test.branch"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &target_operands,
ir::Block *target);
void Verify() const;
};
} // namespace test
IR_DECLARE_EXPLICIT_TYPE_ID(test::RegionOp)
IR_DECLARE_EXPLICIT_TYPE_ID(test::BranchOp)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册