未验证 提交 88e43625 编写于 作者: W winter-wang 提交者: GitHub

[IR] add region data structure. (#54185)

上级 9efa5af4
...@@ -110,8 +110,8 @@ inline ir::Operation* InsertSliceOperationForTarget( ...@@ -110,8 +110,8 @@ inline ir::Operation* InsertSliceOperationForTarget(
defining_info.value.type().dyn_cast<ir::VectorType>(); defining_info.value.type().dyn_cast<ir::VectorType>();
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create({defining_info.value}, ir::Operation::create({defining_info.value},
{src_vec_type[defining_info.idx_in_vector]},
op_attribute_map, op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]},
op_info); op_info);
program->InsertOp(operation); program->InsertOp(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0); ir::OpResult target_op_result = operation->GetResultByIndex(0);
...@@ -136,7 +136,7 @@ inline ir::Operation* InsertCombineOperationForTarget( ...@@ -136,7 +136,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
} }
ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(src_values, {target_vec_type}, {}, op_info); ir::Operation::create(src_values, {}, {target_vec_type}, op_info);
program->InsertOp(operation); program->InsertOp(operation);
return operation; return operation;
} }
...@@ -281,7 +281,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, ...@@ -281,7 +281,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, op_output_types, {}, op_info); ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation); program->InsertOp(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
...@@ -299,7 +299,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ...@@ -299,7 +299,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, op_output_types, {}, op_info); ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation); program->InsertOp(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
...@@ -315,7 +315,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, ...@@ -315,7 +315,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
OpOutputTypeList op_output_types = {}; OpOutputTypeList op_output_types = {};
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, op_output_types, {}, op_info); ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation); program->InsertOp(operation);
return operation; return operation;
......
...@@ -79,7 +79,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( ...@@ -79,7 +79,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
}; };
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create( ir::Operation* operation = ir::Operation::create(
{}, {translated_var_type}, op_attribute_map, op_info); {}, op_attribute_map, {translated_var_type}, op_info);
program->InsertOp(operation); program->InsertOp(operation);
param_map[var->Name()] = param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0)); VariableDefiningInfo(operation->GetResultByIndex(0));
......
...@@ -16,6 +16,20 @@ ...@@ -16,6 +16,20 @@
namespace ir { namespace ir {
Block::~Block() { clear(); } Block::~Block() { clear(); }
void Block::push_back(Operation *op) {
op->set_parent(this);
ops_.push_back(op);
}
void Block::push_front(Operation *op) {
op->set_parent(this);
ops_.push_front(op);
}
Block::iterator Block::insert(const_iterator iterator, Operation *op) {
op->set_parent(this);
return ops_.insert(iterator, op);
}
void Block::clear() { void Block::clear() {
while (!empty()) { while (!empty()) {
......
...@@ -14,18 +14,23 @@ ...@@ -14,18 +14,23 @@
#pragma once #pragma once
#include <cstddef>
#include <list> #include <list>
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
namespace ir { namespace ir {
class Region;
class Block { class Block {
public: public:
using iterator = std::list<Operation *>::iterator; using iterator = std::list<Operation *>::iterator;
using reverse_iterator = std::list<Operation *>::reverse_iterator; using reverse_iterator = std::list<Operation *>::reverse_iterator;
using const_iterator = std::list<Operation *>::const_iterator;
Block() = default; Block() = default;
~Block(); ~Block();
Region *parent() const { return parent_; }
bool empty() const { return ops_.empty(); } bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); } size_t size() const { return ops_.size(); }
...@@ -34,21 +39,22 @@ class Block { ...@@ -34,21 +39,22 @@ class Block {
reverse_iterator rbegin() { return ops_.rbegin(); } reverse_iterator rbegin() { return ops_.rbegin(); }
reverse_iterator rend() { return ops_.rend(); } reverse_iterator rend() { return ops_.rend(); }
Operation *back() { return ops_.back(); } Operation *back() const { return ops_.back(); }
Operation *front() { return ops_.front(); } Operation *front() const { return ops_.front(); }
void push_back(Operation *op) { ops_.push_back(op); } void push_back(Operation *op);
void push_front(Operation *op) { ops_.push_front(op); } void push_front(Operation *op);
std::list<Operation *>::iterator insert( iterator insert(const_iterator iterator, Operation *op);
std::list<Operation *>::const_iterator iterator, Operation *op) {
return ops_.insert(iterator, op);
}
void clear(); void clear();
private: private:
Block(Block &) = delete; Block(Block &) = delete;
void operator=(Block &) = delete; Block &operator=(const Block &) = delete;
friend class Region;
void set_parent(Region *parent) { parent_ = parent; }
private: private:
Region *parent_; // not owned
std::list<Operation *> ops_; // owned std::list<Operation *> ops_; // owned
}; };
} // namespace ir } // namespace ir
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builder.h"
#include "paddle/ir/core/region.h"
namespace ir { namespace ir {
Operation *Builder::insert(Operation *op) { Operation *Builder::insert(Operation *op) {
...@@ -25,17 +26,16 @@ Operation *Builder::insert(Operation *op) { ...@@ -25,17 +26,16 @@ Operation *Builder::insert(Operation *op) {
} }
/// Create an operation given the fields represented as an OperationState. /// Create an operation given the fields represented as an OperationState.
Operation *Builder::create(const OperationArgument &argument) { Operation *Builder::create(OperationArgument &&argument) {
return insert(Operation::create(argument)); return insert(Operation::create(std::move(argument)));
} }
/// Creates an operation with the given fields. /// Creates an operation with the given fields.
Operation *Builder::create(const std::vector<ir::OpResult> &inputs, Operation *Builder::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute, const AttributeMap &attribute,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info) { ir::OpInfo op_info) {
OperationArgument argument(op_info, inputs, output_types, attribute); return create(OperationArgument(inputs, attribute, output_types, op_info));
return create(argument);
} }
} // namespace ir } // namespace ir
...@@ -47,12 +47,12 @@ class Builder { ...@@ -47,12 +47,12 @@ class Builder {
Operation *insert(Operation *op); Operation *insert(Operation *op);
/// Creates an operation given the fields represented as an OperationState. /// Creates an operation given the fields represented as an OperationState.
Operation *create(const OperationArgument &argument); Operation *create(OperationArgument &&argument);
/// Creates an operation with the given fields. /// Creates an operation with the given fields.
Operation *create(const std::vector<ir::OpResult> &inputs, Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute, const AttributeMap &attribute,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info); ir::OpInfo op_info);
/// Create an operation of specific op type at the current insertion point. /// Create an operation of specific op type at the current insertion point.
...@@ -60,7 +60,7 @@ class Builder { ...@@ -60,7 +60,7 @@ class Builder {
OpTy create(Args &&...args) { OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...); OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = create(argument); Operation *op = create(std::move(argument));
return op->dyn_cast<OpTy>(); return op->dyn_cast<OpTy>();
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
namespace ir { namespace ir {
/// ///
/// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
/// StrAttribute}) /// StrAttribute})
......
...@@ -15,23 +15,31 @@ ...@@ -15,23 +15,31 @@
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
namespace ir { namespace ir {
Operation *Operation::create(const OperationArgument &argument) { Operation *Operation::create(OperationArgument &&argument) {
return create(argument.inputs_, Operation *op = create(argument.inputs,
argument.output_types_, argument.attribute,
argument.attribute_, argument.output_types,
argument.info_); argument.info,
argument.regions.size());
for (size_t index = 0; index < argument.regions.size(); ++index) {
op->GetRegion(index).TakeBody(std::move(*argument.regions[index]));
}
return op;
} }
// Allocate the required memory based on the size and number of inputs, outputs, // Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult, // and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand. // OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs, Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute, const AttributeMap &attribute,
ir::OpInfo op_info) { const std::vector<ir::Type> &output_types,
ir::OpInfo op_info,
size_t num_regions) {
// 0. Verify // 0. Verify
if (op_info) { if (op_info) {
op_info.verify(inputs, output_types, attribute); op_info.verify(inputs, output_types, attribute);
...@@ -50,7 +58,9 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -50,7 +58,9 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
: sizeof(detail::OpInlineResultImpl) * num_results; : sizeof(detail::OpInlineResultImpl) * num_results;
size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands; size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
size_t op_mem_size = sizeof(Operation); size_t op_mem_size = sizeof(Operation);
size_t base_size = result_mem_size + op_mem_size + operand_mem_size; size_t region_mem_size = num_regions * sizeof(Region);
size_t base_size =
result_mem_size + op_mem_size + operand_mem_size + region_mem_size;
// 2. Malloc memory. // 2. Malloc memory.
char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8)); char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
// 3.1. Construct OpResults. // 3.1. Construct OpResults.
...@@ -65,8 +75,8 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -65,8 +75,8 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
} }
} }
// 3.2. Construct Operation. // 3.2. Construct Operation.
Operation *op = Operation *op = new (base_ptr)
new (base_ptr) Operation(num_results, num_operands, attribute, op_info); Operation(attribute, op_info, num_results, num_operands, num_regions);
base_ptr += sizeof(Operation); base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands. // 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) { if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
...@@ -76,13 +86,27 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -76,13 +86,27 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
base_ptr += sizeof(detail::OpOperandImpl); base_ptr += sizeof(detail::OpOperandImpl);
} }
// 3.4. Construct Regions
if (num_regions > 0) {
op->regions_ = reinterpret_cast<Region *>(base_ptr);
for (size_t idx = 0; idx < num_regions; idx++) {
new (base_ptr) Region(op);
base_ptr += sizeof(Region);
}
}
return op; return op;
} }
// Call destructors for OpResults, Operation, and OpOperands in sequence, and // Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory. // finally free memory.
void Operation::destroy() { void Operation::destroy() {
// Deconstruct Regions.
if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) {
regions_[idx].~Region();
}
}
// 1. Get aligned_ptr by result_num. // 1. Get aligned_ptr by result_num.
uint32_t max_inline_result_num = uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1; detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
...@@ -136,15 +160,16 @@ void Operation::destroy() { ...@@ -136,15 +160,16 @@ void Operation::destroy() {
IrContext *Operation::ir_context() const { return op_info_.ir_context(); } IrContext *Operation::ir_context() const { return op_info_.ir_context(); }
Operation::Operation(uint32_t num_results, Operation::Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
uint32_t num_results,
uint32_t num_operands, uint32_t num_operands,
const AttributeMap &attribute, uint32_t num_regions)
ir::OpInfo op_info) { : attribute_(attribute),
num_results_ = num_results; op_info_(op_info),
num_operands_ = num_operands; num_results_(num_results),
attribute_ = attribute; num_operands_(num_operands),
op_info_ = op_info; num_regions_(num_regions) {}
}
ir::OpResult Operation::GetResultByIndex(uint32_t index) const { ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
if (index >= num_results_) { if (index >= num_results_) {
...@@ -198,4 +223,9 @@ std::string Operation::print() { ...@@ -198,4 +223,9 @@ std::string Operation::print() {
std::string Operation::op_name() const { return op_info_.name(); } std::string Operation::op_name() const { return op_info_.name(); }
Region &Operation::GetRegion(unsigned index) {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
}
} // namespace ir } // namespace ir
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector>
#include "paddle/ir/core/op_info.h" #include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h" #include "paddle/ir/core/type.h"
...@@ -24,6 +23,7 @@ ...@@ -24,6 +23,7 @@
namespace ir { namespace ir {
class OpBase; class OpBase;
class Program; class Program;
class Block;
class alignas(8) Operation final { class alignas(8) Operation final {
public: public:
...@@ -34,16 +34,19 @@ class alignas(8) Operation final { ...@@ -34,16 +34,19 @@ class alignas(8) Operation final {
/// used in conjunction. /// used in conjunction.
/// ///
static Operation *create(const std::vector<ir::OpResult> &inputs, static Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute, const AttributeMap &attribute,
ir::OpInfo op_info); const std::vector<ir::Type> &output_types,
static Operation *create(const OperationArgument &op_argument); ir::OpInfo op_info,
size_t num_regions = 0);
static Operation *create(OperationArgument &&op_argument);
/// ///
/// \brief Destroy the operation objects and free memory by create(). /// \brief Destroy the operation objects and free memory by create().
/// ///
void destroy(); void destroy();
Block *parent() const { return parent_; }
IrContext *ir_context() const; IrContext *ir_context() const;
ir::OpResult GetResultByIndex(uint32_t index) const; ir::OpResult GetResultByIndex(uint32_t index) const;
...@@ -60,6 +63,8 @@ class alignas(8) Operation final { ...@@ -60,6 +63,8 @@ class alignas(8) Operation final {
uint32_t num_operands() const { return num_operands_; } uint32_t num_operands() const { return num_operands_; }
uint32_t num_regions() const { return num_regions_; }
std::string op_name() const; std::string op_name() const;
template <typename T> template <typename T>
...@@ -83,11 +88,15 @@ class alignas(8) Operation final { ...@@ -83,11 +88,15 @@ class alignas(8) Operation final {
parent_program_ = parent_program; parent_program_ = parent_program;
} }
/// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index);
private: private:
Operation(uint32_t num_results, Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
uint32_t num_results,
uint32_t num_operands, uint32_t num_operands,
const AttributeMap &attribute, uint32_t num_regions);
ir::OpInfo op_info);
template <typename T, typename Enabler = void> template <typename T, typename Enabler = void>
struct CastUtil { struct CastUtil {
...@@ -96,6 +105,9 @@ class alignas(8) Operation final { ...@@ -96,6 +105,9 @@ class alignas(8) Operation final {
} }
}; };
friend class Block;
void set_parent(Block *parent) { parent_ = parent; }
template <typename T> template <typename T>
struct CastUtil< struct CastUtil<
T, T,
...@@ -107,11 +119,13 @@ class alignas(8) Operation final { ...@@ -107,11 +119,13 @@ class alignas(8) Operation final {
OpInfo op_info_; OpInfo op_info_;
uint32_t num_results_ = 0; const uint32_t num_results_ = 0;
const uint32_t num_operands_ = 0;
uint32_t num_operands_ = 0; const uint32_t num_regions_ = 0;
Region *regions_{nullptr};
Program *parent_program_{nullptr}; Program *parent_program_{nullptr};
Block *parent_{nullptr};
}; };
} // namespace ir } // namespace ir
...@@ -13,19 +13,11 @@ ...@@ -13,19 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/region.h"
namespace ir { namespace ir {
OperationArgument::OperationArgument(IrContext* ir_context, std::string name) { OperationArgument::OperationArgument(IrContext* ir_context,
info_ = ir_context->GetRegisteredOpInfo(name); const std::string& name) {
info = ir_context->GetRegisteredOpInfo(name);
} }
OperationArgument::OperationArgument(OpInfo info,
const std::vector<OpResult>& operands,
const std::vector<Type>& types,
const AttributeMap& named_attr)
: info_(info),
inputs_(operands),
output_types_(types),
attribute_(named_attr) {}
} // namespace ir } // namespace ir
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/op_info.h" #include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/type.h" #include "paddle/ir/core/type.h"
#include "paddle/ir/core/value_impl.h" #include "paddle/ir/core/value_impl.h"
...@@ -30,18 +31,25 @@ using AttributeMap = std::unordered_map<std::string, Attribute>; ...@@ -30,18 +31,25 @@ using AttributeMap = std::unordered_map<std::string, Attribute>;
// This represents an operation arguments in an combined form, suitable for use // This represents an operation arguments in an combined form, suitable for use
// with the builder APIs. // with the builder APIs.
struct OperationArgument { struct OperationArgument {
OpInfo info_; std::vector<OpResult> inputs;
std::vector<OpResult> inputs_; AttributeMap attribute;
std::vector<Type> output_types_; std::vector<Type> output_types;
AttributeMap attribute_; OpInfo info;
std::vector<std::unique_ptr<Region>> regions;
public: public:
OperationArgument(IrContext* ir_context, std::string name); OperationArgument(IrContext* ir_context, const std::string& name);
explicit OperationArgument(OpInfo info) : info_(info) {} explicit OperationArgument(OpInfo info) : info(info) {}
OperationArgument(OpInfo info, OperationArgument(const std::vector<OpResult>& operands,
const std::vector<OpResult>& operands, const AttributeMap& named_attr,
const std::vector<Type>& types, const std::vector<Type>& types,
const AttributeMap& named_attr = {}); OpInfo info,
std::vector<std::unique_ptr<Region>>&& regions = {})
: inputs(operands),
attribute(named_attr),
output_types(types),
info(info),
regions(std::move(regions)) {}
template <class InputIt> template <class InputIt>
void addOperands(InputIt first, InputIt last); void addOperands(InputIt first, InputIt last);
...@@ -51,31 +59,31 @@ struct OperationArgument { ...@@ -51,31 +59,31 @@ struct OperationArgument {
/// Add an attribute with the specified name. /// Add an attribute with the specified name.
void addAttribute(const std::string& name, Attribute attr) { void addAttribute(const std::string& name, Attribute attr) {
attribute_[name] = attr; this->attribute[name] = attr;
} }
/// Add an array of named attributes. /// Add an array of named attributes.
template <class InputIt> template <class InputIt>
void addAttributes(InputIt first, InputIt last); void addAttributes(InputIt first, InputIt last);
/// Get the context held by this operation state. /// Get the context held by this operation state.
IrContext* getContext() const { return info_.ir_context(); } IrContext* getContext() const { return info.ir_context(); }
}; };
template <class InputIt> template <class InputIt>
void OperationArgument::addOperands(InputIt first, InputIt last) { void OperationArgument::addOperands(InputIt first, InputIt last) {
while (first != last) { while (first != last) {
inputs_.emplace_back(*first++); inputs.emplace_back(*first++);
} }
} }
template <class InputIt> template <class InputIt>
void OperationArgument::addTypes(InputIt first, InputIt last) { void OperationArgument::addTypes(InputIt first, InputIt last) {
while (first != last) { while (first != last) {
output_types_.emplace_back(*first++); output_types.emplace_back(*first++);
} }
} }
template <class InputIt> template <class InputIt>
void OperationArgument::addAttributes(InputIt first, InputIt last) { void OperationArgument::addAttributes(InputIt first, InputIt last) {
while (first != last) { while (first != last) {
attribute_[first->first] = first->second; attribute[first->first] = first->second;
++first; ++first;
} }
} }
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/block.h"
namespace ir {
Region::~Region() { clear(); }
void Region::push_back(Block *block) {
block->set_parent(this);
blocks_.push_back(block);
}
void Region::push_front(Block *block) {
block->set_parent(this);
blocks_.push_front(block);
}
Region::iterator Region::insert(const_iterator position, Block *block) {
block->set_parent(this);
return blocks_.insert(position, block);
}
void Region::TakeBody(Region &&other) {
clear();
blocks_.swap(other.blocks_);
for (auto &block : blocks_) {
block->set_parent(this);
}
}
void Region::clear() {
while (!empty()) {
delete blocks_.back();
blocks_.pop_back();
}
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <list>
namespace ir {
class Block;
class Operation;
class Region {
public:
using iterator = std::list<Block *>::iterator;
using reverse_iterator = std::list<Block *>::reverse_iterator;
using const_iterator = std::list<Block *>::const_iterator;
~Region();
Region() = default;
bool empty() const { return blocks_.empty(); }
size_t size() const { return blocks_.size(); }
iterator begin() { return blocks_.begin(); }
iterator end() { return blocks_.end(); }
reverse_iterator rbegin() { return blocks_.rbegin(); }
reverse_iterator rend() { return blocks_.rend(); }
Block *back() const { return blocks_.back(); }
Block *front() const { return blocks_.front(); }
void push_back(Block *block);
void push_front(Block *block);
iterator insert(const_iterator position, Block *block);
void clear();
void TakeBody(Region &&other);
private:
Region(Region &) = delete;
Region &operator=(const Region &) = delete;
friend class Operation;
explicit Region(Operation *op) : parent_(op) {}
private:
Operation *parent_{nullptr}; // not owned
std::list<Block *> blocks_; // owned
};
} // namespace ir
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/region.h"
/// \brief Define built-in Trait, derived from OpTraitBase. /// \brief Define built-in Trait, derived from OpTraitBase.
class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> { class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
...@@ -175,9 +177,9 @@ TEST(op_test, op_test) { ...@@ -175,9 +177,9 @@ TEST(op_test, op_test) {
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create(op_inputs, ir::Operation::create(op_inputs,
op_output_types,
CreateAttributeMap({"op2_attr1", "op2_attr2"}, CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"}), {"op2_attr1", "op2_attr2"}),
op_output_types,
op2_info); op2_info);
ReadOnlyTrait trait = op2->dyn_cast<ReadOnlyTrait>(); ReadOnlyTrait trait = op2->dyn_cast<ReadOnlyTrait>();
...@@ -188,3 +190,44 @@ TEST(op_test, op_test) { ...@@ -188,3 +190,44 @@ TEST(op_test, op_test) {
EXPECT_EQ(Op2.operation(), op2); EXPECT_EQ(Op2.operation(), op2);
op2->destroy(); op2->destroy();
} }
TEST(op_test, region_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
EXPECT_EQ(test_dialect != nullptr, true);
// (2) Get registered operations.
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(Operation1::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(Operation2::name());
ir::Operation *op1 =
ir::Operation::create({},
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}),
{ir::Float32Type::get(ctx)},
op1_info);
ir::Operation *op1_2 =
ir::Operation::create({},
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}),
{ir::Float32Type::get(ctx)},
op1_info);
ir::OperationArgument argument(op2_info);
argument.attribute = CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"});
argument.output_types = {ir::Float32Type::get(ctx)};
argument.regions.emplace_back(std::make_unique<ir::Region>());
ir::Region *region = argument.regions.back().get();
EXPECT_EQ(region->empty(), true);
region->push_back(new ir::Block());
region->push_front(new ir::Block());
region->insert(region->begin(), new ir::Block());
ir::Block *block = region->front();
block->push_front(op1);
block->insert(block->begin(), op1_2);
ir::Operation *op2 = ir::Operation::create(std::move(argument));
op2->destroy();
}
...@@ -91,7 +91,7 @@ TEST(program_test, program) { ...@@ -91,7 +91,7 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, {dense_tensor_dtype}, op1_attribute, op1_info); ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
program.InsertOp(op1); program.InsertOp(op1);
...@@ -123,7 +123,7 @@ TEST(program_test, program) { ...@@ -123,7 +123,7 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op2_attribute{ std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, {dense_tensor_dtype}, op2_attribute, op2_info); ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
program.InsertOp(op2); program.InsertOp(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
...@@ -153,8 +153,8 @@ TEST(program_test, program) { ...@@ -153,8 +153,8 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op3_attribute; std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::create( ir::Operation *op3 = ir::Operation::create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{dense_tensor_dtype},
op3_attribute, op3_attribute,
{dense_tensor_dtype},
op3_info); op3_info);
program.InsertOp(op3); program.InsertOp(op3);
...@@ -184,7 +184,7 @@ TEST(program_test, program) { ...@@ -184,7 +184,7 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op4_attribute{ std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}}; {"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::Operation *op4 = ir::Operation::create( ir::Operation *op4 = ir::Operation::create(
{op3->GetResultByIndex(0)}, {}, op4_attribute, op4_info); {op3->GetResultByIndex(0)}, op4_attribute, {}, op4_info);
program.InsertOp(op4); program.InsertOp(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(), EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(),
...@@ -230,7 +230,7 @@ TEST(program_test, slice_combine_test) { ...@@ -230,7 +230,7 @@ TEST(program_test, slice_combine_test) {
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info); ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info);
program.InsertOp(op1); program.InsertOp(op1);
// (5) Def b = GetParameterOp("b") // (5) Def b = GetParameterOp("b")
...@@ -239,7 +239,7 @@ TEST(program_test, slice_combine_test) { ...@@ -239,7 +239,7 @@ TEST(program_test, slice_combine_test) {
std::unordered_map<std::string, ir::Attribute> op2_attribute{ std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info);
program.InsertOp(op2); program.InsertOp(op2);
// (6) Def combine_op = CombineOp("a", "b") // (6) Def combine_op = CombineOp("a", "b")
...@@ -249,8 +249,8 @@ TEST(program_test, slice_combine_test) { ...@@ -249,8 +249,8 @@ TEST(program_test, slice_combine_test) {
ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype})); ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
ir::Operation *combine_op = ir::Operation::create( ir::Operation *combine_op = ir::Operation::create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{output_type},
{}, {},
{output_type},
combine_op_info); combine_op_info);
program.InsertOp(combine_op); program.InsertOp(combine_op);
...@@ -260,8 +260,8 @@ TEST(program_test, slice_combine_test) { ...@@ -260,8 +260,8 @@ TEST(program_test, slice_combine_test) {
ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0);
ir::Operation *slice_op = ir::Operation *slice_op =
ir::Operation::create({combine_op->GetResultByIndex(0)}, ir::Operation::create({combine_op->GetResultByIndex(0)},
{fp32_dtype},
{{"index", index_attr}}, {{"index", index_attr}},
{fp32_dtype},
slice_op_info); slice_op_info);
program.InsertOp(slice_op); program.InsertOp(slice_op);
......
...@@ -40,8 +40,8 @@ TEST(value_test, value_test) { ...@@ -40,8 +40,8 @@ TEST(value_test, value_test) {
std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create(op1_inputs, ir::Operation::create(op1_inputs,
op1_output_types,
CreateAttributeMap("op1_name", "op1_attr"), CreateAttributeMap("op1_name", "op1_attr"),
op1_output_types,
nullptr); nullptr);
VLOG(0) << op1->print(); VLOG(0) << op1->print();
// 2. Construct OP2: b = OP2(); // 2. Construct OP2: b = OP2();
...@@ -49,8 +49,8 @@ TEST(value_test, value_test) { ...@@ -49,8 +49,8 @@ TEST(value_test, value_test) {
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create(op2_inputs, ir::Operation::create(op2_inputs,
op2_output_types,
CreateAttributeMap("op2_name", "op2_attr"), CreateAttributeMap("op2_name", "op2_attr"),
op2_output_types,
nullptr); nullptr);
VLOG(0) << op2->print() << std::endl; VLOG(0) << op2->print() << std::endl;
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
...@@ -59,8 +59,8 @@ TEST(value_test, value_test) { ...@@ -59,8 +59,8 @@ TEST(value_test, value_test) {
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation *op3 =
ir::Operation::create(op3_inputs, ir::Operation::create(op3_inputs,
op3_output_types,
CreateAttributeMap("op3_name", "op3_attr"), CreateAttributeMap("op3_name", "op3_attr"),
op3_output_types,
nullptr); nullptr);
VLOG(0) << op3->print() << std::endl; VLOG(0) << op3->print() << std::endl;
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
...@@ -72,8 +72,8 @@ TEST(value_test, value_test) { ...@@ -72,8 +72,8 @@ TEST(value_test, value_test) {
} }
ir::Operation *op4 = ir::Operation *op4 =
ir::Operation::create(op4_inputs, ir::Operation::create(op4_inputs,
op4_output_types,
CreateAttributeMap("op4_name", "op4_attr"), CreateAttributeMap("op4_name", "op4_attr"),
op4_output_types,
nullptr); nullptr);
VLOG(0) << op4->print() << std::endl; VLOG(0) << op4->print() << std::endl;
......
...@@ -96,8 +96,8 @@ TEST(pass_manager_test, pass_manager_test) { ...@@ -96,8 +96,8 @@ TEST(pass_manager_test, pass_manager_test) {
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op = ir::Operation *op =
ir::Operation::create(op_inputs, ir::Operation::create(op_inputs,
op_output_types,
CreateAttributeMap(ctx, "op1_attr1", "op1_attr1"), CreateAttributeMap(ctx, "op1_attr1", "op1_attr1"),
op_output_types,
op_info); op_info);
// (4) Test pass manager for op. // (4) Test pass manager for op.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册