未验证 提交 e50266fe 编写于 作者: W winter-wang 提交者: GitHub

[IR] add the erase api for region&block. (#54844)

上级 ef445ec8
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
...@@ -32,6 +33,12 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) { ...@@ -32,6 +33,12 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) {
return iter; return iter;
} }
Block::iterator Block::erase(const_iterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block.");
(*position)->Destroy();
return ops_.erase(position);
}
void Block::clear() { void Block::clear() {
while (!empty()) { while (!empty()) {
ops_.back()->Destroy(); ops_.back()->Destroy();
......
...@@ -50,6 +50,7 @@ class IR_API Block { ...@@ -50,6 +50,7 @@ class IR_API Block {
void push_back(Operation *op); void push_back(Operation *op);
void push_front(Operation *op); void push_front(Operation *op);
iterator insert(const_iterator iterator, Operation *op); iterator insert(const_iterator iterator, Operation *op);
iterator erase(const_iterator position);
void clear(); void clear();
operator Region::iterator() { return position_; } operator Region::iterator() { return position_; }
......
...@@ -33,8 +33,8 @@ Program *ModuleOp::program() { ...@@ -33,8 +33,8 @@ Program *ModuleOp::program() {
Block *ModuleOp::block() { Block *ModuleOp::block() {
assert(operation() != nullptr); assert(operation() != nullptr);
assert(operation()->num_regions() == 1); assert(operation()->num_regions() == 1);
assert(operation()->GetRegion(0).size() == 1); assert(operation()->region(0).size() == 1);
return operation()->GetRegion(0).front(); return operation()->region(0).front();
} }
ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) { ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
...@@ -71,6 +71,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -71,6 +71,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
void GetParameterOp::Build(Builder &builder,
OperationArgument &argument,
const std::string &name,
Type type) {
argument.attributes[attributes_name[0]] =
ir::StrAttribute::get(builder.ir_context(), name);
argument.output_types.emplace_back(type);
}
void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
...@@ -90,6 +99,14 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -90,6 +99,14 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *SetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
void SetParameterOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name) {
argument.AddOperand(parameter);
argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
...@@ -106,6 +123,18 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -106,6 +123,18 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
} }
void CombineOp::Build(Builder &builder,
OperationArgument &argument,
const std::vector<ir::OpResult> &inputs) {
argument.inputs = inputs;
std::vector<ir::Type> inputs_type(inputs.size());
for (size_t idx = 0; idx < inputs.size(); ++idx) {
inputs_type[idx] = inputs[idx].type();
}
argument.output_types.emplace_back(
ir::VectorType::get(builder.ir_context(), inputs_type));
}
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs, void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
......
...@@ -54,8 +54,12 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -54,8 +54,12 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
static const char *name() { return "builtin.get_parameter"; } static const char *name() { return "builtin.get_parameter"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, static void Build(Builder &builder, // NOLINT
const std::vector<ir::Type> &outputs, OperationArgument &argument, // NOLINT
const std::string &name,
Type type);
static void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -69,6 +73,10 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -69,6 +73,10 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
static const char *name() { return "builtin.set_parameter"; } static const char *name() { return "builtin.set_parameter"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name);
static void Verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
...@@ -87,6 +95,10 @@ class IR_API CombineOp : public ir::Op<CombineOp> { ...@@ -87,6 +95,10 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs);
static void Verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
......
...@@ -104,7 +104,7 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) { ...@@ -104,7 +104,7 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
void IrPrinter::PrintProgram(Program* program) { void IrPrinter::PrintProgram(Program* program) {
auto top_level_op = program->module_op(); auto top_level_op = program->module_op();
for (size_t i = 0; i < top_level_op->num_regions(); ++i) { for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->GetRegion(i); auto& region = top_level_op->region(i);
for (auto it = region.begin(); it != region.end(); ++it) { for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it; auto* block = *it;
os << "{\n"; os << "{\n";
...@@ -153,7 +153,7 @@ void IrPrinter::PrintFullOperation(Operation* op) { ...@@ -153,7 +153,7 @@ void IrPrinter::PrintFullOperation(Operation* op) {
os << newline; os << newline;
} }
for (size_t i = 0; i < op->num_regions(); ++i) { for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i); auto& region = op->region(i);
PrintRegion(region); PrintRegion(region);
} }
} }
......
...@@ -33,7 +33,7 @@ Operation *Operation::Create(OperationArgument &&argument) { ...@@ -33,7 +33,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
argument.regions.size()); argument.regions.size());
for (size_t index = 0; index < argument.regions.size(); ++index) { for (size_t index = 0; index < argument.regions.size(); ++index) {
op->GetRegion(index).TakeBody(std::move(*argument.regions[index])); op->region(index).TakeBody(std::move(*argument.regions[index]));
} }
return op; return op;
} }
...@@ -103,17 +103,35 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs, ...@@ -103,17 +103,35 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
return op; return op;
} }
// Call destructors for OpResults, Operation, and OpOperands in sequence, and // Call destructors for Region , OpResults, Operation, and OpOperands in
// finally free memory. // sequence, and finally free memory.
void Operation::Destroy() { void Operation::Destroy() {
// Deconstruct Regions. // 1. Deconstruct Regions.
if (num_regions_ > 0) { if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) { for (size_t idx = 0; idx < num_regions_; idx++) {
regions_[idx].~Region(); regions_[idx].~Region();
} }
} }
// 1. Get aligned_ptr by result_num. // 2. Deconstruct Result.
for (size_t idx = 0; idx < num_results_; ++idx) {
detail::OpResultImpl *impl = result(idx).impl();
IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses.");
if (detail::OpOutlineResultImpl::classof(*impl)) {
static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
} else {
static_cast<detail::OpInlineResultImpl *>(impl)->~OpInlineResultImpl();
}
}
// 3. Deconstruct Operation.
this->~Operation();
// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num = uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1; detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
size_t result_mem_size = size_t result_mem_size =
...@@ -122,46 +140,11 @@ void Operation::Destroy() { ...@@ -122,46 +140,11 @@ void Operation::Destroy() {
(num_results_ - max_inline_result_num) + (num_results_ - max_inline_result_num) +
sizeof(detail::OpInlineResultImpl) * max_inline_result_num sizeof(detail::OpInlineResultImpl) * max_inline_result_num
: sizeof(detail::OpInlineResultImpl) * num_results_; : sizeof(detail::OpInlineResultImpl) * num_results_;
char *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size; void *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
// 2.1. Deconstruct OpResult.
char *base_ptr = aligned_ptr; VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr
for (size_t idx = num_results_; idx > 0; idx--) {
// release the uses of this result
detail::OpOperandImpl *first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
while (first_use != nullptr) {
first_use->RemoveFromUdChain();
first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
}
// destory the result
if (idx > max_inline_result_num) {
reinterpret_cast<detail::OpOutlineResultImpl *>(base_ptr)
->~OpOutlineResultImpl();
base_ptr += sizeof(detail::OpOutlineResultImpl);
} else {
reinterpret_cast<detail::OpInlineResultImpl *>(base_ptr)
->~OpInlineResultImpl();
base_ptr += sizeof(detail::OpInlineResultImpl);
}
}
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
IR_THROW("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
// 2.3. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
reinterpret_cast<detail::OpOperandImpl *>(base_ptr)->~OpOperandImpl();
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3. Free memory.
VLOG(4) << "Destroy an Operation: {ptr = "
<< reinterpret_cast<void *>(aligned_ptr)
<< ", size = " << result_mem_size << "}"; << ", size = " << result_mem_size << "}";
aligned_free(reinterpret_cast<void *>(aligned_ptr)); aligned_free(aligned_ptr);
} }
IrContext *Operation::ir_context() const { return info_.ir_context(); } IrContext *Operation::ir_context() const { return info_.ir_context(); }
...@@ -231,7 +214,7 @@ Program *Operation::GetParentProgram() { ...@@ -231,7 +214,7 @@ Program *Operation::GetParentProgram() {
return module_op ? module_op.program() : nullptr; return module_op ? module_op.program() : nullptr;
} }
Region &Operation::GetRegion(unsigned index) { Region &Operation::region(unsigned index) {
assert(index < num_regions_ && "invalid region index"); assert(index < num_regions_ && "invalid region index");
return regions_[index]; return regions_[index];
} }
......
...@@ -54,6 +54,9 @@ class IR_API alignas(8) Operation final { ...@@ -54,6 +54,9 @@ class IR_API alignas(8) Operation final {
OpOperand operand(uint32_t index) const; OpOperand operand(uint32_t index) const;
/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
void Print(std::ostream &os); void Print(std::ostream &os);
const AttributeMap &attributes() const { return attributes_; } const AttributeMap &attributes() const { return attributes_; }
...@@ -95,11 +98,10 @@ class IR_API alignas(8) Operation final { ...@@ -95,11 +98,10 @@ class IR_API alignas(8) Operation final {
Program *GetParentProgram(); Program *GetParentProgram();
/// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index);
operator Block::iterator() { return position_; } operator Block::iterator() { return position_; }
operator Block::const_iterator() const { return position_; }
private: private:
Operation(const AttributeMap &attribute, Operation(const AttributeMap &attribute,
ir::OpInfo op_info, ir::OpInfo op_info,
......
...@@ -51,9 +51,15 @@ struct OperationArgument { ...@@ -51,9 +51,15 @@ struct OperationArgument {
info(info), info(info),
regions(std::move(regions)) {} regions(std::move(regions)) {}
/// Add Operand.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }
template <class InputIt> template <class InputIt>
void AddOperands(InputIt first, InputIt last); void AddOperands(InputIt first, InputIt last);
/// Add Output.
void AddOutput(Type type) { output_types.emplace_back(type); }
template <class InputIt> template <class InputIt>
void AddTypes(InputIt first, InputIt last); void AddTypes(InputIt first, InputIt last);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
namespace ir { namespace ir {
Region::~Region() { clear(); } Region::~Region() { clear(); }
...@@ -29,6 +30,12 @@ Region::iterator Region::insert(const_iterator position, Block *block) { ...@@ -29,6 +30,12 @@ Region::iterator Region::insert(const_iterator position, Block *block) {
block->SetParent(this, iter); block->SetParent(this, iter);
return iter; return iter;
} }
Region::iterator Region::erase(const_iterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this region.");
delete *position;
return blocks_.erase(position);
}
void Region::TakeBody(Region &&other) { void Region::TakeBody(Region &&other) {
clear(); clear();
blocks_.swap(other.blocks_); blocks_.swap(other.blocks_);
......
...@@ -48,6 +48,7 @@ class IR_API Region { ...@@ -48,6 +48,7 @@ class IR_API Region {
void emplace_back(); void emplace_back();
void push_front(Block *block); void push_front(Block *block);
iterator insert(const_iterator position, Block *block); iterator insert(const_iterator position, Block *block);
iterator erase(const_iterator position);
void clear(); void clear();
void TakeBody(Region &&other); void TakeBody(Region &&other);
......
...@@ -34,17 +34,22 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { ...@@ -34,17 +34,22 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
} }
OpOperand::operator bool() const { return impl_ && impl_->source(); } OpOperand::operator bool() const { return impl_ && impl_->source(); }
OpOperand OpOperand::next_use() const { return impl_->next_use(); } OpOperand OpOperand::next_use() const { return impl()->next_use(); }
Value OpOperand::source() const { return impl_->source(); } Value OpOperand::source() const { return impl()->source(); }
void OpOperand::set_source(Value value) { Type OpOperand::type() const { return source().type(); }
IR_ENFORCE(impl_, "Can't set source for a null value.");
impl_->set_source(value); void OpOperand::set_source(Value value) { impl()->set_source(value); }
}
Operation *OpOperand::owner() const { return impl()->owner(); }
Operation *OpOperand::owner() const { return impl_->owner(); } void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); }
detail::OpOperandImpl *OpOperand::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while operand is null.");
return impl_;
}
// Value // Value
Value::Value(const detail::ValueImpl *impl) Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {} : impl_(const_cast<detail::ValueImpl *>(impl)) {}
...@@ -84,13 +89,18 @@ void Value::ReplaceUsesWithIf( ...@@ -84,13 +89,18 @@ void Value::ReplaceUsesWithIf(
Value new_value, Value new_value,
const std::function<bool(OpOperand)> &should_replace) const { const std::function<bool(OpOperand)> &should_replace) const {
for (auto it = begin(); it != end();) { for (auto it = begin(); it != end();) {
auto cur = it++; if (should_replace(*it)) {
if (should_replace(*cur)) { (it++)->set_source(new_value);
cur->set_source(new_value);
} }
} }
} }
void Value::ReplaceAllUsesWith(Value new_value) const {
for (auto it = begin(); it != end();) {
(it++)->set_source(new_value);
}
}
detail::ValueImpl *Value::impl() const { detail::ValueImpl *Value::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null."); IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return impl_; return impl_;
...@@ -106,6 +116,7 @@ Operation *OpResult::owner() const { return impl()->owner(); } ...@@ -106,6 +116,7 @@ Operation *OpResult::owner() const { return impl()->owner(); }
uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); } uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); }
detail::OpResultImpl *OpResult::impl() const { detail::OpResultImpl *OpResult::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return reinterpret_cast<detail::OpResultImpl *>(impl_); return reinterpret_cast<detail::OpResultImpl *>(impl_);
} }
......
...@@ -55,11 +55,21 @@ class IR_API OpOperand { ...@@ -55,11 +55,21 @@ class IR_API OpOperand {
Value source() const; Value source() const;
Type type() const;
void set_source(Value value); void set_source(Value value);
Operation *owner() const; Operation *owner() const;
void RemoveFromUdChain();
friend Operation;
private: private:
// The interface shoule ensure impl_ isn't nullptr.
// if the user can accept impl_ is nullptr, shoule use impl_ member directly.
detail::OpOperandImpl *impl() const;
detail::OpOperandImpl *impl_{nullptr}; detail::OpOperandImpl *impl_{nullptr};
}; };
...@@ -155,6 +165,7 @@ class IR_API Value { ...@@ -155,6 +165,7 @@ class IR_API Value {
void ReplaceUsesWithIf( void ReplaceUsesWithIf(
Value new_value, Value new_value,
const std::function<bool(OpOperand)> &should_replace) const; const std::function<bool(OpOperand)> &should_replace) const;
void ReplaceAllUsesWith(Value new_value) const;
// The interface shoule ensure impl_ isn't nullptr. // The interface shoule ensure impl_ isn't nullptr.
// if the user can accept impl_ is nullptr, shoule use impl_ member directly. // if the user can accept impl_ is nullptr, shoule use impl_ member directly.
......
...@@ -44,7 +44,7 @@ void detail::PassAdaptor::RunImpl(Operation* op, ...@@ -44,7 +44,7 @@ void detail::PassAdaptor::RunImpl(Operation* op,
auto last_am = analysis_manager(); auto last_am = analysis_manager();
for (size_t i = 0; i < op->num_regions(); ++i) { for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i); auto& region = op->region(i);
for (auto it = region.begin(); it != region.end(); ++it) { for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it; auto* block = *it;
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
......
...@@ -2,6 +2,7 @@ cc_test_old(type_test SRCS type_test.cc DEPS ir gtest) ...@@ -2,6 +2,7 @@ cc_test_old(type_test SRCS type_test.cc DEPS ir gtest)
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS ir gtest) cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS ir gtest)
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS ir gtest) cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS ir gtest)
cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS ir gtest) cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS ir gtest)
cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS ir gtest)
cc_test_old( cc_test_old(
ir_program_test ir_program_test
SRCS SRCS
......
...@@ -48,7 +48,21 @@ class AddOp : public ir::Op<AddOp> { ...@@ -48,7 +48,21 @@ class AddOp : public ir::Op<AddOp> {
throw("The size of outputs must be equal to 1."); throw("The size of outputs must be equal to 1.");
} }
} }
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult l_operand,
ir::OpResult r_operand,
ir::Type sum_type);
}; };
void AddOp::Build(ir::Builder &,
ir::OperationArgument &argument,
ir::OpResult l_operand,
ir::OpResult r_operand,
ir::Type sum_type) {
argument.AddOperand(l_operand);
argument.AddOperand(r_operand);
argument.AddOutput(sum_type);
}
IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) IR_DECLARE_EXPLICIT_TYPE_ID(AddOp)
IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) IR_DEFINE_EXPLICIT_TYPE_ID(AddOp)
...@@ -90,22 +104,10 @@ TEST(program_test, program) { ...@@ -90,22 +104,10 @@ TEST(program_test, program) {
EXPECT_EQ(program.parameters_num() == 2, true); EXPECT_EQ(program.parameters_num() == 2, true);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a. // (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
std::string op1_name = ir::GetParameterOp::name(); ir::Builder builder(ctx, program.block());
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); auto op1 = builder.Build<ir::GetParameterOp>("a", dense_tensor_dtype);
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent());
EXPECT_EQ(program.module_op(), block->GetParentOp());
EXPECT_EQ(&program, op1->GetParentProgram()); EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id()); EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface; using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface = Interface *a_interface =
...@@ -124,14 +126,7 @@ TEST(program_test, program) { ...@@ -124,14 +126,7 @@ TEST(program_test, program) {
} }
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b. // (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
std::string op2_name = auto op2 = builder.Build<ir::GetParameterOp>("b", dense_tensor_dtype);
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id()); EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
Interface *b_interface = Interface *b_interface =
...@@ -150,16 +145,8 @@ TEST(program_test, program) { ...@@ -150,16 +145,8 @@ TEST(program_test, program) {
} }
// (6) Def c = AddOp(a, b), execute this op. // (6) Def c = AddOp(a, b), execute this op.
std::string op3_name = auto op3 =
builtin_dialect->name() + "." + std::string(AddOp::name()); builder.Build<AddOp>(op1->result(0), op2->result(0), dense_tensor_dtype);
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)},
op3_attribute,
{dense_tensor_dtype},
op3_info);
block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>( phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace())); paddle::platform::CPUPlace()));
...@@ -180,38 +167,17 @@ TEST(program_test, program) { ...@@ -180,38 +167,17 @@ TEST(program_test, program) {
} }
// (7) Def AbsOp(b) // (7) Def AbsOp(b)
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); auto abs_op = builder.Build<paddle::dialect::AbsOp>(op1->result(0));
std::vector<ir::OpResult> operands = {op1->result(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info);
abs_argument.AddOperands(operands.begin(), operands.end());
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::Create(std::move(abs_argument));
paddle::dialect::OpYamlInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c") // (8) Def SetParameterOp(c, "c")
std::string op4_name = auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name());
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4);
EXPECT_EQ(op4->operand(0).source().type().dialect().id(), EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
paddle_dialect->id()); Interface *c_interface =
Interface *c_interface = op4->operand(0) op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
.source()
.type()
.dialect()
.GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c = // ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get()); // c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c = std::unique_ptr<ir::Parameter> parameter_c =
...@@ -224,7 +190,7 @@ TEST(program_test, program) { ...@@ -224,7 +190,7 @@ TEST(program_test, program) {
program.SetParameter("c", std::move(parameter_c)); program.SetParameter("c", std::move(parameter_c));
// (8) Traverse Program // (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true); EXPECT_EQ(program.block()->size() == 5, true);
EXPECT_EQ(program.parameters_num() == 3, true); EXPECT_EQ(program.parameters_num() == 3, true);
program.Print(std::cout); program.Print(std::cout);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
TEST(region, erase_op_test) {
// (1) Init environment.
ir::IrContext* ctx = ir::IrContext::Instance();
// (2) Create an empty program object
ir::Program program(ctx);
ir::Builder builder = ir::Builder(ctx, program.block());
// (3) Def a = ConstantOp("2.0"); b = ConstantOp("2.0");
ir::FloatAttribute fp_attr = ir::FloatAttribute::get(ctx, 2.0f);
ir::Float32Type fp32_type = ir::Float32Type::get(ctx);
ir::OpResult a = builder.Build<ir::ConstantOp>(fp_attr, fp32_type)->result(0);
ir::OpResult b = builder.Build<ir::ConstantOp>(fp_attr, fp32_type)->result(0);
// (6) Def c = CombineOp(a, b)
builder.Build<ir::CombineOp>(std::vector<ir::OpResult>{a, b});
// Test ir::Block::erase
ir::Block* block = program.block();
EXPECT_EQ(block->size(), 3u);
block->erase(*(block->back()));
EXPECT_EQ(block->size(), 2u);
// Test ir::Region::erase
ir::Region& region = program.module_op()->region(0);
region.push_back(new ir::Block());
EXPECT_EQ(region.size(), 2u);
region.erase(region.begin());
EXPECT_EQ(region.size(), 1u);
}
...@@ -104,10 +104,14 @@ TEST(value_test, value_test) { ...@@ -104,10 +104,14 @@ TEST(value_test, value_test) {
// Test 4: Value Replace Use // Test 4: Value Replace Use
// a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c); // a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
// //
c.ReplaceUsesWithIf(a, [](ir::OpOperand) { return true; }); c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; });
EXPECT_EQ(op4->operand(1).source(), a); EXPECT_EQ(op4->operand(1).source(), b);
EXPECT_TRUE(c.use_empty()); EXPECT_TRUE(c.use_empty());
b.ReplaceAllUsesWith(a);
EXPECT_EQ(op4->operand(1).source(), a);
EXPECT_TRUE(b.use_empty());
// destroy // destroy
VLOG(0) << op1->result(0).PrintUdChain() << std::endl; VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op4->Destroy(); op4->Destroy();
......
...@@ -45,7 +45,21 @@ class AddOp : public ir::Op<AddOp> { ...@@ -45,7 +45,21 @@ class AddOp : public ir::Op<AddOp> {
throw("The size of outputs must be equal to 1."); throw("The size of outputs must be equal to 1.");
} }
} }
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult l_operand,
ir::OpResult r_operand,
ir::Type sum_type);
}; };
void AddOp::Build(ir::Builder &,
ir::OperationArgument &argument,
ir::OpResult l_operand,
ir::OpResult r_operand,
ir::Type sum_type) {
argument.AddOperand(l_operand);
argument.AddOperand(r_operand);
argument.AddOutput(sum_type);
}
IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) IR_DECLARE_EXPLICIT_TYPE_ID(AddOp)
IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) IR_DEFINE_EXPLICIT_TYPE_ID(AddOp)
...@@ -79,10 +93,9 @@ TEST(pass_manager_test, pass_manager) { ...@@ -79,10 +93,9 @@ TEST(pass_manager_test, pass_manager) {
// (3) Create a float32 DenseTensor Parameter and save into Program // (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx); ir::Type fp32_dtype = ir::Float32Type::get(ctx);
paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2}; phi::DDim dims = {2, 2};
paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout = phi::DataLayout data_layout = phi::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW; phi::LoD lod = {{0, 1, 2}};
paddle::dialect::DenseTensorTypeStorage::LoD lod = {{0, 1, 2}};
size_t offset = 0; size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get( ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset); ctx, fp32_dtype, dims, data_layout, lod, offset);
...@@ -104,22 +117,10 @@ TEST(pass_manager_test, pass_manager) { ...@@ -104,22 +117,10 @@ TEST(pass_manager_test, pass_manager) {
EXPECT_EQ(program.parameters_num() == 2, true); EXPECT_EQ(program.parameters_num() == 2, true);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a. // (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
std::string op1_name = ir::GetParameterOp::name(); ir::Builder builder(ctx, program.block());
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); auto op1 = builder.Build<ir::GetParameterOp>("a", dense_tensor_dtype);
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent());
EXPECT_EQ(program.module_op(), block->GetParentOp());
EXPECT_EQ(&program, op1->GetParentProgram()); EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id()); EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface; using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface = Interface *a_interface =
...@@ -138,15 +139,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -138,15 +139,7 @@ TEST(pass_manager_test, pass_manager) {
} }
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b. // (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
std::string op2_name = auto op2 = builder.Build<ir::GetParameterOp>("b", dense_tensor_dtype);
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id()); EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
Interface *b_interface = Interface *b_interface =
op2->result(0).type().dialect().GetRegisteredInterface<Interface>(); op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
...@@ -164,16 +157,8 @@ TEST(pass_manager_test, pass_manager) { ...@@ -164,16 +157,8 @@ TEST(pass_manager_test, pass_manager) {
} }
// (6) Def c = AddOp(a, b), execute this op. // (6) Def c = AddOp(a, b), execute this op.
std::string op3_name = auto op3 =
builtin_dialect->name() + "." + std::string(AddOp::name()); builder.Build<AddOp>(op1->result(0), op2->result(0), dense_tensor_dtype);
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)},
op3_attribute,
{dense_tensor_dtype},
op3_info);
block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>( phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace())); paddle::platform::CPUPlace()));
...@@ -193,39 +178,12 @@ TEST(pass_manager_test, pass_manager) { ...@@ -193,39 +178,12 @@ TEST(pass_manager_test, pass_manager) {
EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]); EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
} }
// (7) Def AbsOp(b) // (7) Def SetParameterOp(c, "c")
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
std::vector<ir::OpResult> operands = {op1->result(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info);
abs_argument.AddOperands(operands.begin(), operands.end());
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::Create(std::move(abs_argument));
paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c")
std::string op4_name =
builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name());
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4);
EXPECT_EQ(op4->operand(0).source().type().dialect().id(), EXPECT_EQ(op4->operand(0).source().type().dialect().id(),
paddle_dialect->id()); paddle_dialect->id());
Interface *c_interface = op4->operand(0) Interface *c_interface =
.source() op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
.type()
.dialect()
.GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c = // ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get()); // c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c = std::unique_ptr<ir::Parameter> parameter_c =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册