未验证 提交 49bedfd3 编写于 作者: W winter-wang 提交者: GitHub

[IR] refine the program data structure. (#54220)

上级 4bd5b695
...@@ -113,7 +113,7 @@ inline ir::Operation* InsertSliceOperationForTarget( ...@@ -113,7 +113,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
op_attribute_map, op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]}, {src_vec_type[defining_info.idx_in_vector]},
op_info); op_info);
program->InsertOp(operation); program->block()->push_back(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0); ir::OpResult target_op_result = operation->GetResultByIndex(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result); (*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation; return operation;
...@@ -137,7 +137,7 @@ inline ir::Operation* InsertCombineOperationForTarget( ...@@ -137,7 +137,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(src_values, {}, {target_vec_type}, op_info); ir::Operation::create(src_values, {}, {target_vec_type}, op_info);
program->InsertOp(operation); program->block()->push_back(operation);
return operation; return operation;
} }
...@@ -282,7 +282,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, ...@@ -282,7 +282,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info); ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation); program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
return operation; return operation;
...@@ -300,7 +300,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ...@@ -300,7 +300,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info); ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation); program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
return operation; return operation;
...@@ -316,7 +316,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, ...@@ -316,7 +316,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info); ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation); program->block()->push_back(operation);
return operation; return operation;
} }
......
...@@ -80,7 +80,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( ...@@ -80,7 +80,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create( ir::Operation* operation = ir::Operation::create(
{}, op_attribute_map, {translated_var_type}, op_info); {}, op_attribute_map, {translated_var_type}, op_info);
program->InsertOp(operation); program->block()->push_back(operation);
param_map[var->Name()] = param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0)); VariableDefiningInfo(operation->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << operation; VLOG(10) << "[op translated][get parameter]" << operation;
......
...@@ -46,6 +46,12 @@ class Block { ...@@ -46,6 +46,12 @@ class Block {
iterator insert(const_iterator iterator, Operation *op); iterator insert(const_iterator iterator, Operation *op);
void clear(); void clear();
Region *GetParentRegion() const { return parent_; }
Operation *GetParentOp() const {
return parent_ ? parent_->GetParentOp() : nullptr;
}
private: private:
Block(Block &) = delete; Block(Block &) = delete;
Block &operator=(const Block &) = delete; Block &operator=(const Block &) = delete;
......
...@@ -33,4 +33,6 @@ std::vector<Attribute> ArrayAttribute::data() const { ...@@ -33,4 +33,6 @@ std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey(); return storage()->GetAsKey();
} }
void* PointerAttribute::data() const { return storage()->GetAsKey(); }
} // namespace ir } // namespace ir
...@@ -25,7 +25,7 @@ class StrAttribute : public Attribute { ...@@ -25,7 +25,7 @@ class StrAttribute : public Attribute {
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage); DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
bool operator<(const StrAttribute &right) const { bool operator<(const StrAttribute& right) const {
return storage() < right.storage(); return storage() < right.storage();
} }
...@@ -94,4 +94,13 @@ class ArrayAttribute : public Attribute { ...@@ -94,4 +94,13 @@ class ArrayAttribute : public Attribute {
Attribute operator[](size_t index) const { return data()[index]; } Attribute operator[](size_t index) const { return data()[index]; }
}; };
class PointerAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);
void* data() const;
};
} // namespace ir } // namespace ir
...@@ -83,6 +83,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float); ...@@ -83,6 +83,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
struct ArrayAttributeStorage : public AttributeStorage { struct ArrayAttributeStorage : public AttributeStorage {
using ParamKey = std::vector<Attribute>; using ParamKey = std::vector<Attribute>;
......
...@@ -40,11 +40,13 @@ void BuiltinDialect::initialize() { ...@@ -40,11 +40,13 @@ void BuiltinDialect::initialize() {
ir::BoolAttribute, ir::BoolAttribute,
ir::FloatAttribute, ir::FloatAttribute,
ir::DoubleAttribute, ir::DoubleAttribute,
ir::PointerAttribute,
ir::Int32_tAttribute, ir::Int32_tAttribute,
ir::Int64_tAttribute, ir::Int64_tAttribute,
ir::ArrayAttribute>(); ir::ArrayAttribute>();
RegisterOps<ir::GetParameterOp, RegisterOps<ir::ModuleOp,
ir::GetParameterOp,
ir::SetParameterOp, ir::SetParameterOp,
ir::CombineOp, ir::CombineOp,
ir::SliceOp>(); ir::SliceOp>();
......
...@@ -19,6 +19,60 @@ ...@@ -19,6 +19,60 @@
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace ir { namespace ir {
const char *ModuleOp::attributes_name[attributes_num] = {"program"};
Program *ModuleOp::program() {
const AttributeMap &attr = operation()->attributes();
auto iter = attr.find("program");
if (iter == attr.end() || !iter->second) return nullptr;
return static_cast<Program *>(
iter->second.dyn_cast<PointerAttribute>().data());
}
Block *ModuleOp::block() {
assert(operation() != nullptr);
assert(operation()->num_regions() == 1);
assert(operation()->GetRegion(0).size() == 1);
return operation()->GetRegion(0).front();
}
ModuleOp ModuleOp::create(IrContext *context, Program *pointer) {
ir::OpInfo info = context->GetRegisteredOpInfo(name());
OperationArgument argument(info);
argument.AddRegion()->emplace_back();
argument.addAttribute("program", PointerAttribute::get(context, pointer));
return ModuleOp(Operation::create(std::move(argument)));
}
void ModuleOp::destroy() {
if (operation()) {
operation()->destroy();
*this = ModuleOp(nullptr);
}
}
void ModuleOp::verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
// Verify if attributes contain attribute name in attributes_name:
auto iter = attributes.find("program");
if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) {
throw("Type of attribute: program is not right.");
}
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
}
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
......
...@@ -18,6 +18,31 @@ ...@@ -18,6 +18,31 @@
namespace ir { namespace ir {
class Program;
class Block;
///
/// \brief ModuleOp
///
class ModuleOp : public ir::Op<ModuleOp> {
public:
using Op::Op;
static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Program *program();
Block *block();
//
// As the top operation, ModuleOp only support create&destroye through
// below interface: "create"&"destroy".
static ModuleOp create(IrContext *context, Program *pointer);
void destroy();
};
/// ///
/// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
/// StrAttribute}) /// StrAttribute})
......
...@@ -66,7 +66,7 @@ class InterfaceValue { ...@@ -66,7 +66,7 @@ class InterfaceValue {
class OpBase { class OpBase {
public: public:
explicit OpBase(Operation *operation) : operation_(operation) {} explicit OpBase(Operation *operation = nullptr) : operation_(operation) {}
Operation *operation() const { return operation_; } Operation *operation() const { return operation_; }
...@@ -76,6 +76,8 @@ class OpBase { ...@@ -76,6 +76,8 @@ class OpBase {
Operation *operator->() const { return operation_; } Operation *operator->() const { return operation_; }
IrContext *ir_context() const { return operation_->ir_context(); }
private: private:
Operation *operation_; // Not owned Operation *operation_; // Not owned
}; };
...@@ -91,7 +93,7 @@ class OpTraitBase : public OpBase { ...@@ -91,7 +93,7 @@ class OpTraitBase : public OpBase {
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); } static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
static ConcreteTrait dyn_cast(Operation *op) { static ConcreteTrait dyn_cast(Operation *op) {
if (op->HasTrait<ConcreteTrait>()) { if (op && op->HasTrait<ConcreteTrait>()) {
return ConcreteTrait(op); return ConcreteTrait(op);
} }
return ConcreteTrait(nullptr); return ConcreteTrait(nullptr);
...@@ -109,7 +111,7 @@ class OpInterfaceBase : public OpBase { ...@@ -109,7 +111,7 @@ class OpInterfaceBase : public OpBase {
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); } static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
static ConcreteInterface dyn_cast(Operation *op) { static ConcreteInterface dyn_cast(Operation *op) {
if (op->HasInterface<ConcreteInterface>()) { if (op && op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface( return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>()); op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
} }
...@@ -182,7 +184,7 @@ class Op : public OpBase { ...@@ -182,7 +184,7 @@ class Op : public OpBase {
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type; typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(Operation *op) { static ConcreteOp dyn_cast(Operation *op) {
if (op->op_info().id() == TypeId::get<ConcreteOp>()) { if (op && op->op_info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op); return ConcreteOp(op);
} }
return ConcreteOp(nullptr); return ConcreteOp(nullptr);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
...@@ -21,7 +22,7 @@ ...@@ -21,7 +22,7 @@
namespace ir { namespace ir {
Operation *Operation::create(OperationArgument &&argument) { Operation *Operation::create(OperationArgument &&argument) {
Operation *op = create(argument.inputs, Operation *op = create(argument.inputs,
argument.attribute, argument.attributes,
argument.output_types, argument.output_types,
argument.info, argument.info,
argument.regions.size()); argument.regions.size());
...@@ -36,13 +37,13 @@ Operation *Operation::create(OperationArgument &&argument) { ...@@ -36,13 +37,13 @@ Operation *Operation::create(OperationArgument &&argument) {
// and operators, and construct it in the order of: OpOutlineResult, // and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand. // OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs, Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attribute, const AttributeMap &attributes,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::OpInfo op_info, ir::OpInfo op_info,
size_t num_regions) { size_t num_regions) {
// 0. Verify // 0. Verify
if (op_info) { if (op_info) {
op_info.verify(inputs, output_types, attribute); op_info.verify(inputs, output_types, attributes);
} }
// 1. Calculate the required memory size for OpResults + Operation + // 1. Calculate the required memory size for OpResults + Operation +
// OpOperands. // OpOperands.
...@@ -76,7 +77,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -76,7 +77,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
} }
// 3.2. Construct Operation. // 3.2. Construct Operation.
Operation *op = new (base_ptr) Operation *op = new (base_ptr)
Operation(attribute, op_info, num_results, num_operands, num_regions); Operation(attributes, op_info, num_results, num_operands, num_regions);
base_ptr += sizeof(Operation); base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands. // 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) { if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
...@@ -160,12 +161,12 @@ void Operation::destroy() { ...@@ -160,12 +161,12 @@ void Operation::destroy() {
IrContext *Operation::ir_context() const { return op_info_.ir_context(); } IrContext *Operation::ir_context() const { return op_info_.ir_context(); }
Operation::Operation(const AttributeMap &attribute, Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info, ir::OpInfo op_info,
uint32_t num_results, uint32_t num_results,
uint32_t num_operands, uint32_t num_operands,
uint32_t num_regions) uint32_t num_regions)
: attribute_(attribute), : attributes_(attributes),
op_info_(op_info), op_info_(op_info),
num_results_(num_results), num_results_(num_results),
num_operands_(num_operands), num_operands_(num_operands),
...@@ -223,6 +224,23 @@ std::string Operation::print() { ...@@ -223,6 +224,23 @@ std::string Operation::print() {
std::string Operation::op_name() const { return op_info_.name(); } std::string Operation::op_name() const { return op_info_.name(); }
Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParentRegion() : nullptr;
}
Operation *Operation::GetParentOp() const {
return parent_ ? parent_->GetParentOp() : nullptr;
}
Program *Operation::GetParentProgram() {
Operation *op = this;
while (Operation *parent_op = op->GetParentOp()) {
op = parent_op;
}
ModuleOp module_op = op->dyn_cast<ModuleOp>();
return module_op ? module_op.program() : nullptr;
}
Region &Operation::GetRegion(unsigned index) { Region &Operation::GetRegion(unsigned index) {
assert(index < num_regions_ && "invalid region index"); assert(index < num_regions_ && "invalid region index");
return regions_[index]; return regions_[index];
......
...@@ -34,7 +34,7 @@ class alignas(8) Operation final { ...@@ -34,7 +34,7 @@ class alignas(8) Operation final {
/// used in conjunction. /// used in conjunction.
/// ///
static Operation *create(const std::vector<ir::OpResult> &inputs, static Operation *create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attribute, const AttributeMap &attributes,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::OpInfo op_info, ir::OpInfo op_info,
size_t num_regions = 0); size_t num_regions = 0);
...@@ -45,8 +45,6 @@ class alignas(8) Operation final { ...@@ -45,8 +45,6 @@ class alignas(8) Operation final {
/// ///
void destroy(); void destroy();
Block *parent() const { return parent_; }
IrContext *ir_context() const; IrContext *ir_context() const;
ir::OpResult GetResultByIndex(uint32_t index) const; ir::OpResult GetResultByIndex(uint32_t index) const;
...@@ -55,7 +53,11 @@ class alignas(8) Operation final { ...@@ -55,7 +53,11 @@ class alignas(8) Operation final {
std::string print(); std::string print();
const AttributeMap &attribute() const { return attribute_; } const AttributeMap &attributes() const { return attributes_; }
void SetAttribute(const std::string &key, Attribute value) {
attributes_[key] = value;
}
ir::OpInfo op_info() const { return op_info_; } ir::OpInfo op_info() const { return op_info_; }
...@@ -82,11 +84,13 @@ class alignas(8) Operation final { ...@@ -82,11 +84,13 @@ class alignas(8) Operation final {
return op_info_.HasInterface<Interface>(); return op_info_.HasInterface<Interface>();
} }
Program *parent_program() const { return parent_program_; } Block *GetParentBlock() const { return parent_; }
void set_parent_program(Program *parent_program) { Region *GetParentRegion() const;
parent_program_ = parent_program;
} Operation *GetParentOp() const;
Program *GetParentProgram();
/// Returns the region held by this operation at position 'index'. /// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index); Region &GetRegion(unsigned index);
...@@ -115,7 +119,7 @@ class alignas(8) Operation final { ...@@ -115,7 +119,7 @@ class alignas(8) Operation final {
static T call(Operation *op) { return T::dyn_cast(op); } static T call(Operation *op) { return T::dyn_cast(op); }
}; };
AttributeMap attribute_; AttributeMap attributes_;
OpInfo op_info_; OpInfo op_info_;
...@@ -124,7 +128,6 @@ class alignas(8) Operation final { ...@@ -124,7 +128,6 @@ class alignas(8) Operation final {
const uint32_t num_regions_ = 0; const uint32_t num_regions_ = 0;
Region *regions_{nullptr}; Region *regions_{nullptr};
Program *parent_program_{nullptr};
Block *parent_{nullptr}; Block *parent_{nullptr};
}; };
......
...@@ -32,7 +32,7 @@ using AttributeMap = std::unordered_map<std::string, Attribute>; ...@@ -32,7 +32,7 @@ using AttributeMap = std::unordered_map<std::string, Attribute>;
// with the builder APIs. // with the builder APIs.
struct OperationArgument { struct OperationArgument {
std::vector<OpResult> inputs; std::vector<OpResult> inputs;
AttributeMap attribute; AttributeMap attributes;
std::vector<Type> output_types; std::vector<Type> output_types;
OpInfo info; OpInfo info;
std::vector<std::unique_ptr<Region>> regions; std::vector<std::unique_ptr<Region>> regions;
...@@ -41,12 +41,12 @@ struct OperationArgument { ...@@ -41,12 +41,12 @@ struct OperationArgument {
OperationArgument(IrContext* ir_context, const std::string& name); OperationArgument(IrContext* ir_context, const std::string& name);
explicit OperationArgument(OpInfo info) : info(info) {} explicit OperationArgument(OpInfo info) : info(info) {}
OperationArgument(const std::vector<OpResult>& operands, OperationArgument(const std::vector<OpResult>& operands,
const AttributeMap& named_attr, const AttributeMap& attributes,
const std::vector<Type>& types, const std::vector<Type>& types,
OpInfo info, OpInfo info,
std::vector<std::unique_ptr<Region>>&& regions = {}) std::vector<std::unique_ptr<Region>>&& regions = {})
: inputs(operands), : inputs(operands),
attribute(named_attr), attributes(attributes),
output_types(types), output_types(types),
info(info), info(info),
regions(std::move(regions)) {} regions(std::move(regions)) {}
...@@ -59,13 +59,18 @@ struct OperationArgument { ...@@ -59,13 +59,18 @@ struct OperationArgument {
/// Add an attribute with the specified name. /// Add an attribute with the specified name.
void addAttribute(const std::string& name, Attribute attr) { void addAttribute(const std::string& name, Attribute attr) {
this->attribute[name] = attr; attributes[name] = attr;
} }
/// Add an array of named attributes. /// Add an array of named attributes.
template <class InputIt> template <class InputIt>
void addAttributes(InputIt first, InputIt last); void addAttributes(InputIt first, InputIt last);
/// Get the context held by this operation state. /// Get the context held by this operation state.
IrContext* getContext() const { return info.ir_context(); } IrContext* getContext() const { return info.ir_context(); }
Region* AddRegion() {
regions.emplace_back(new Region);
return regions.back().get();
}
}; };
template <class InputIt> template <class InputIt>
...@@ -83,7 +88,7 @@ void OperationArgument::addTypes(InputIt first, InputIt last) { ...@@ -83,7 +88,7 @@ void OperationArgument::addTypes(InputIt first, InputIt last) {
template <class InputIt> template <class InputIt>
void OperationArgument::addAttributes(InputIt first, InputIt last) { void OperationArgument::addAttributes(InputIt first, InputIt last) {
while (first != last) { while (first != last) {
attribute[first->first] = first->second; attributes[first->first] = first->second;
++first; ++first;
} }
} }
......
...@@ -16,11 +16,17 @@ ...@@ -16,11 +16,17 @@
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
namespace ir { namespace ir {
Program::~Program() = default;
void Program::InsertOp(Operation* op) { Program::Program(IrContext* context) {
block_.push_back(op); module_ = ModuleOp::create(context, this);
op->set_parent_program(this); }
Program::Program() : Program(IrContext::Instance()) {}
Program::~Program() {
if (module_) {
module_.destroy();
}
} }
Parameter* Program::GetParameter(std::string name) const { Parameter* Program::GetParameter(std::string name) const {
......
...@@ -19,10 +19,13 @@ ...@@ -19,10 +19,13 @@
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h" #include "paddle/ir/core/parameter.h"
namespace ir { namespace ir {
class IrContext;
/// ///
/// \brief Program is an abstraction of model structure, divided into /// \brief Program is an abstraction of model structure, divided into
/// computational graphs and weights. At the current stage, a computational /// computational graphs and weights. At the current stage, a computational
...@@ -33,27 +36,34 @@ namespace ir { ...@@ -33,27 +36,34 @@ namespace ir {
/// ///
class Program { class Program {
public: public:
using ParameterMap =
std::unordered_map<std::string, std::unique_ptr<Parameter>>;
explicit Program(IrContext* context);
Program();
Program(Program&&) = delete;
Program(const Program& program) = delete;
Program& operator=(const Program&) = delete;
Program& operator=(Program&&);
~Program(); ~Program();
Block* block() { return &block_; }
size_t parameters_num() const { return parameters_.size(); } size_t parameters_num() const { return parameters_.size(); }
/// ModuleOp module_op() { return module_; }
/// \brief Insert the Operation* constructed by Operation::create(...) into
/// this Program. NOTE: At this time, the memory management permission of
/// Operation* will be owned by this Program. The user does not need to call
/// Operation::destroy() manually
///
void InsertOp(Operation* op);
Parameter* GetParameter(std::string name) const; Block* block() { return module_.block(); }
Parameter* GetParameter(std::string name) const;
void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter); void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter);
ParameterMap& parameters() { return parameters_; }
void set_parameters(ParameterMap&& parameters) {
parameters_ = std::move(parameters);
}
private: private:
Block block_; // computation graph
std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_; ModuleOp module_;
// weight
ParameterMap parameters_;
}; };
std::ostream& operator<<(std::ostream& os, Program& program); std::ostream& operator<<(std::ostream& os, Program& program);
......
...@@ -22,6 +22,9 @@ void Region::push_back(Block *block) { ...@@ -22,6 +22,9 @@ void Region::push_back(Block *block) {
block->set_parent(this); block->set_parent(this);
blocks_.push_back(block); blocks_.push_back(block);
} }
void Region::emplace_back() { push_back(new Block); }
void Region::push_front(Block *block) { void Region::push_front(Block *block) {
block->set_parent(this); block->set_parent(this);
blocks_.push_front(block); blocks_.push_front(block);
......
...@@ -41,12 +41,15 @@ class Region { ...@@ -41,12 +41,15 @@ class Region {
Block *back() const { return blocks_.back(); } Block *back() const { return blocks_.back(); }
Block *front() const { return blocks_.front(); } Block *front() const { return blocks_.front(); }
void push_back(Block *block); void push_back(Block *block);
void emplace_back();
void push_front(Block *block); void push_front(Block *block);
iterator insert(const_iterator position, Block *block); iterator insert(const_iterator position, Block *block);
void clear(); void clear();
void TakeBody(Region &&other); void TakeBody(Region &&other);
Operation *GetParentOp() const { return parent_; }
private: private:
Region(Region &) = delete; Region(Region &) = delete;
Region &operator=(const Region &) = delete; Region &operator=(const Region &) = delete;
......
...@@ -17,10 +17,12 @@ ...@@ -17,10 +17,12 @@
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
/// \brief Define built-in Trait, derived from OpTraitBase. /// \brief Define built-in Trait, derived from OpTraitBase.
...@@ -133,7 +135,7 @@ class Operation2 ...@@ -133,7 +135,7 @@ class Operation2
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
} }
} }
static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; } static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
}; };
const char *Operation2::attributes_name[attributes_num] = {"op2_attr1", const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"}; "op2_attr2"};
...@@ -212,7 +214,7 @@ TEST(op_test, region_test) { ...@@ -212,7 +214,7 @@ TEST(op_test, region_test) {
op1_info); op1_info);
ir::OperationArgument argument(op2_info); ir::OperationArgument argument(op2_info);
argument.attribute = CreateAttributeMap({"op2_attr1", "op2_attr2"}, argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"}); {"op2_attr1", "op2_attr2"});
argument.output_types = {ir::Float32Type::get(ctx)}; argument.output_types = {ir::Float32Type::get(ctx)};
argument.regions.emplace_back(std::make_unique<ir::Region>()); argument.regions.emplace_back(std::make_unique<ir::Region>());
...@@ -228,3 +230,26 @@ TEST(op_test, region_test) { ...@@ -228,3 +230,26 @@ TEST(op_test, region_test) {
ir::Operation *op2 = ir::Operation::create(std::move(argument)); ir::Operation *op2 = ir::Operation::create(std::move(argument));
op2->destroy(); op2->destroy();
} }
TEST(op_test, module_op_death) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name());
// (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::create(inputs, {}, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::create({}, attrs, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::create({}, {}, output_types, op_info),
const char *);
ir::Program program(ctx);
EXPECT_EQ(program.module_op().program(), &program);
EXPECT_EQ(program.module_op().ir_context(), ctx);
program.module_op()->SetAttribute("program",
ir::PointerAttribute::get(ctx, &program));
}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/utils.h" #include "paddle/fluid/dialect/utils.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
...@@ -56,9 +57,7 @@ TEST(program_test, program) { ...@@ -56,9 +57,7 @@ TEST(program_test, program) {
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
// (2) Create an empty program object // (2) Create an empty program object
ir::Program program; ir::Program program(ctx);
// ir::Program *program = new ir::Program();
EXPECT_EQ(program.block()->size() == 0, true);
// (3) Create a float32 DenseTensor Parameter and save into Program // (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx); ir::Type fp32_dtype = ir::Float32Type::get(ctx);
...@@ -94,7 +93,14 @@ TEST(program_test, program) { ...@@ -94,7 +93,14 @@ TEST(program_test, program) {
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info); ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
program.InsertOp(op1); ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion());
EXPECT_EQ(program.module_op(), block->GetParentOp());
EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id()); paddle_dialect->id());
...@@ -124,7 +130,7 @@ TEST(program_test, program) { ...@@ -124,7 +130,7 @@ TEST(program_test, program) {
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info); ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
program.InsertOp(op2); block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id()); paddle_dialect->id());
...@@ -155,7 +161,7 @@ TEST(program_test, program) { ...@@ -155,7 +161,7 @@ TEST(program_test, program) {
op3_attribute, op3_attribute,
{dense_tensor_dtype}, {dense_tensor_dtype},
op3_info); op3_info);
program.InsertOp(op3); block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>( phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::DeviceContextPool::Instance().Get(
...@@ -196,9 +202,12 @@ TEST(program_test, program) { ...@@ -196,9 +202,12 @@ TEST(program_test, program) {
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{ std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}}; {"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::Operation *op4 = ir::Operation::create(
{op3->GetResultByIndex(0)}, op4_attribute, {}, op4_info); ir::OperationArgument op4_argument(
program.InsertOp(op4); {op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.addAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::create(std::move(op4_argument));
block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(), EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(),
paddle_dialect->id()); paddle_dialect->id());
...@@ -244,7 +253,7 @@ TEST(program_test, slice_combine_test) { ...@@ -244,7 +253,7 @@ TEST(program_test, slice_combine_test) {
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info);
program.InsertOp(op1); program.block()->push_back(op1);
// (5) Def b = GetParameterOp("b") // (5) Def b = GetParameterOp("b")
std::string op2_name = std::string(ir::GetParameterOp::name()); std::string op2_name = std::string(ir::GetParameterOp::name());
...@@ -253,7 +262,7 @@ TEST(program_test, slice_combine_test) { ...@@ -253,7 +262,7 @@ TEST(program_test, slice_combine_test) {
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info); ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info);
program.InsertOp(op2); program.block()->push_back(op2);
// (6) Def combine_op = CombineOp("a", "b") // (6) Def combine_op = CombineOp("a", "b")
std::string combine_op_name = std::string(ir::CombineOp::name()); std::string combine_op_name = std::string(ir::CombineOp::name());
...@@ -265,7 +274,7 @@ TEST(program_test, slice_combine_test) { ...@@ -265,7 +274,7 @@ TEST(program_test, slice_combine_test) {
{}, {},
{output_type}, {output_type},
combine_op_info); combine_op_info);
program.InsertOp(combine_op); program.block()->push_back(combine_op);
// (7) Def slice_op = SliceOp(combine_op, 0) // (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name()); std::string slice_op_name = std::string(ir::SliceOp::name());
...@@ -276,7 +285,7 @@ TEST(program_test, slice_combine_test) { ...@@ -276,7 +285,7 @@ TEST(program_test, slice_combine_test) {
{{"index", index_attr}}, {{"index", index_attr}},
{fp32_dtype}, {fp32_dtype},
slice_op_info); slice_op_info);
program.InsertOp(slice_op); program.block()->push_back(slice_op);
// (8) Traverse Program // (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true); EXPECT_EQ(program.block()->size() == 4, true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册