From 49bedfd366acb35482ed921f2791781d76d74114 Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Fri, 2 Jun 2023 08:01:18 +0800 Subject: [PATCH] [IR] refine the program data structure. (#54220) --- paddle/fluid/translator/op_translator.cc | 10 ++-- paddle/fluid/translator/program_translator.cc | 2 +- paddle/ir/core/block.h | 6 +++ paddle/ir/core/builtin_attribute.cc | 2 + paddle/ir/core/builtin_attribute.h | 11 +++- paddle/ir/core/builtin_attribute_storage.h | 1 + paddle/ir/core/builtin_dialect.cc | 4 +- paddle/ir/core/builtin_op.cc | 54 +++++++++++++++++++ paddle/ir/core/builtin_op.h | 25 +++++++++ paddle/ir/core/op_base.h | 10 ++-- paddle/ir/core/operation.cc | 30 ++++++++--- paddle/ir/core/operation.h | 23 ++++---- paddle/ir/core/operation_utils.h | 15 ++++-- paddle/ir/core/program.cc | 14 +++-- paddle/ir/core/program.h | 36 ++++++++----- paddle/ir/core/region.cc | 3 ++ paddle/ir/core/region.h | 3 ++ test/cpp/ir/core/ir_op_test.cc | 31 +++++++++-- test/cpp/ir/core/ir_program_test.cc | 35 +++++++----- 19 files changed, 249 insertions(+), 66 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index c8ce1ffdcab..7d917825859 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -113,7 +113,7 @@ inline ir::Operation* InsertSliceOperationForTarget( op_attribute_map, {src_vec_type[defining_info.idx_in_vector]}, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); ir::OpResult target_op_result = operation->GetResultByIndex(0); (*param_map)[arg_name] = VariableDefiningInfo(target_op_result); return operation; @@ -137,7 +137,7 @@ inline ir::Operation* InsertCombineOperationForTarget( ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); ir::Operation* operation = ir::Operation::create(src_values, {}, {target_vec_type}, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); return operation; } @@ -282,7 +282,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = ir::Operation::create(op_inputs, {}, op_output_types, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); return operation; @@ -300,7 +300,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = ir::Operation::create(op_inputs, {}, op_output_types, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); return operation; @@ -316,7 +316,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = ir::Operation::create(op_inputs, {}, op_output_types, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); return operation; } diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 85a09b2da03..ff8dc88225f 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -80,7 +80,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( {}, op_attribute_map, {translated_var_type}, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); param_map[var->Name()] = VariableDefiningInfo(operation->GetResultByIndex(0)); VLOG(10) << "[op translated][get parameter]" << operation; diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index 09b4b584b66..6534bdd60b7 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -46,6 +46,12 @@ class Block { iterator insert(const_iterator iterator, Operation *op); void clear(); + Region *GetParentRegion() const { return parent_; } + + Operation *GetParentOp() const { + return parent_ ? parent_->GetParentOp() : nullptr; + } + private: Block(Block &) = delete; Block &operator=(const Block &) = delete; diff --git a/paddle/ir/core/builtin_attribute.cc b/paddle/ir/core/builtin_attribute.cc index 78bd2903be3..76cc50d3de4 100644 --- a/paddle/ir/core/builtin_attribute.cc +++ b/paddle/ir/core/builtin_attribute.cc @@ -33,4 +33,6 @@ std::vector ArrayAttribute::data() const { return storage()->GetAsKey(); } +void* PointerAttribute::data() const { return storage()->GetAsKey(); } + } // namespace ir diff --git a/paddle/ir/core/builtin_attribute.h b/paddle/ir/core/builtin_attribute.h index edad980136f..cc4741dc3a7 100644 --- a/paddle/ir/core/builtin_attribute.h +++ b/paddle/ir/core/builtin_attribute.h @@ -25,7 +25,7 @@ class StrAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage); - bool operator<(const StrAttribute &right) const { + bool operator<(const StrAttribute& right) const { return storage() < right.storage(); } @@ -94,4 +94,13 @@ class ArrayAttribute : public Attribute { Attribute operator[](size_t index) const { return data()[index]; } }; +class PointerAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage); + + void* data() const; +}; + } // namespace ir diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/ir/core/builtin_attribute_storage.h index 3d2e23bc047..40b833fe9c7 100644 --- a/paddle/ir/core/builtin_attribute_storage.h +++ b/paddle/ir/core/builtin_attribute_storage.h @@ -83,6 +83,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *); struct ArrayAttributeStorage : public AttributeStorage { using ParamKey = std::vector; diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index bdbb2736f12..c88b9d8133d 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -40,11 +40,13 @@ void BuiltinDialect::initialize() { ir::BoolAttribute, ir::FloatAttribute, ir::DoubleAttribute, + ir::PointerAttribute, ir::Int32_tAttribute, ir::Int64_tAttribute, ir::ArrayAttribute>(); - RegisterOps(); diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index 023c96d7e13..2da74f17b52 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -19,6 +19,60 @@ #include "paddle/phi/core/enforce.h" namespace ir { + +const char *ModuleOp::attributes_name[attributes_num] = {"program"}; + +Program *ModuleOp::program() { + const AttributeMap &attr = operation()->attributes(); + auto iter = attr.find("program"); + if (iter == attr.end() || !iter->second) return nullptr; + return static_cast( + iter->second.dyn_cast().data()); +} + +Block *ModuleOp::block() { + assert(operation() != nullptr); + assert(operation()->num_regions() == 1); + assert(operation()->GetRegion(0).size() == 1); + return operation()->GetRegion(0).front(); +} + +ModuleOp ModuleOp::create(IrContext *context, Program *pointer) { + ir::OpInfo info = context->GetRegisteredOpInfo(name()); + OperationArgument argument(info); + argument.AddRegion()->emplace_back(); + argument.addAttribute("program", PointerAttribute::get(context, pointer)); + return ModuleOp(Operation::create(std::move(argument))); +} + +void ModuleOp::destroy() { + if (operation()) { + operation()->destroy(); + *this = ModuleOp(nullptr); + } +} + +void ModuleOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; + // Verify inputs type: + if (inputs.size() != 0) { + throw("The size of inputs must be equal to 0."); + } + + // Verify if attributes contain attribute name in attributes_name: + auto iter = attributes.find("program"); + if (iter == attributes.end() || !iter->second.isa()) { + throw("Type of attribute: program is not right."); + } + + // Verify outputs type: + if (outputs.size() != 0) { + throw("The size of outputs must be equal to 0."); + } +} + const char *GetParameterOp::attributes_name[attributes_num] = { "parameter_name"}; diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index 3a7f77b00aa..c9396e64620 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -18,6 +18,31 @@ namespace ir { +class Program; +class Block; +/// +/// \brief ModuleOp +/// +class ModuleOp : public ir::Op { + public: + using Op::Op; + static const char *name() { return "builtin.module"; } + static constexpr uint32_t attributes_num = 1; + static const char *attributes_name[attributes_num]; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); + + Program *program(); + Block *block(); + + // + // As the top operation, ModuleOp only support create&destroye through + // below interface: "create"&"destroy". + static ModuleOp create(IrContext *context, Program *pointer); + void destroy(); +}; + /// /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, /// StrAttribute}) diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index 20d7d7036d4..fae2b7cd1f5 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -66,7 +66,7 @@ class InterfaceValue { class OpBase { public: - explicit OpBase(Operation *operation) : operation_(operation) {} + explicit OpBase(Operation *operation = nullptr) : operation_(operation) {} Operation *operation() const { return operation_; } @@ -76,6 +76,8 @@ class OpBase { Operation *operator->() const { return operation_; } + IrContext *ir_context() const { return operation_->ir_context(); } + private: Operation *operation_; // Not owned }; @@ -91,7 +93,7 @@ class OpTraitBase : public OpBase { static TypeId GetTraitId() { return TypeId::get(); } static ConcreteTrait dyn_cast(Operation *op) { - if (op->HasTrait()) { + if (op && op->HasTrait()) { return ConcreteTrait(op); } return ConcreteTrait(nullptr); @@ -109,7 +111,7 @@ class OpInterfaceBase : public OpBase { static TypeId GetInterfaceId() { return TypeId::get(); } static ConcreteInterface dyn_cast(Operation *op) { - if (op->HasInterface()) { + if (op && op->HasInterface()) { return ConcreteInterface( op, op->op_info().GetInterfaceImpl()); } @@ -182,7 +184,7 @@ class Op : public OpBase { typename Filter>::Type; static ConcreteOp dyn_cast(Operation *op) { - if (op->op_info().id() == TypeId::get()) { + if (op && op->op_info().id() == TypeId::get()) { return ConcreteOp(op); } return ConcreteOp(nullptr); diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 4f9575c03d3..b7aa58a9b0e 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/ir/core/operation.h" +#include "paddle/ir/core/block.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/region.h" @@ -21,7 +22,7 @@ namespace ir { Operation *Operation::create(OperationArgument &&argument) { Operation *op = create(argument.inputs, - argument.attribute, + argument.attributes, argument.output_types, argument.info, argument.regions.size()); @@ -36,13 +37,13 @@ Operation *Operation::create(OperationArgument &&argument) { // and operators, and construct it in the order of: OpOutlineResult, // OpInlineResult, Operation, Operand. Operation *Operation::create(const std::vector &inputs, - const AttributeMap &attribute, + const AttributeMap &attributes, const std::vector &output_types, ir::OpInfo op_info, size_t num_regions) { // 0. Verify if (op_info) { - op_info.verify(inputs, output_types, attribute); + op_info.verify(inputs, output_types, attributes); } // 1. Calculate the required memory size for OpResults + Operation + // OpOperands. @@ -76,7 +77,7 @@ Operation *Operation::create(const std::vector &inputs, } // 3.2. Construct Operation. Operation *op = new (base_ptr) - Operation(attribute, op_info, num_results, num_operands, num_regions); + Operation(attributes, op_info, num_results, num_operands, num_regions); base_ptr += sizeof(Operation); // 3.3. Construct OpOperands. if ((reinterpret_cast(base_ptr) & 0x7) != 0) { @@ -160,12 +161,12 @@ void Operation::destroy() { IrContext *Operation::ir_context() const { return op_info_.ir_context(); } -Operation::Operation(const AttributeMap &attribute, +Operation::Operation(const AttributeMap &attributes, ir::OpInfo op_info, uint32_t num_results, uint32_t num_operands, uint32_t num_regions) - : attribute_(attribute), + : attributes_(attributes), op_info_(op_info), num_results_(num_results), num_operands_(num_operands), @@ -223,6 +224,23 @@ std::string Operation::print() { std::string Operation::op_name() const { return op_info_.name(); } +Region *Operation::GetParentRegion() const { + return parent_ ? parent_->GetParentRegion() : nullptr; +} + +Operation *Operation::GetParentOp() const { + return parent_ ? parent_->GetParentOp() : nullptr; +} + +Program *Operation::GetParentProgram() { + Operation *op = this; + while (Operation *parent_op = op->GetParentOp()) { + op = parent_op; + } + ModuleOp module_op = op->dyn_cast(); + return module_op ? module_op.program() : nullptr; +} + Region &Operation::GetRegion(unsigned index) { assert(index < num_regions_ && "invalid region index"); return regions_[index]; diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index d5804ee9e43..0775f68d63e 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -34,7 +34,7 @@ class alignas(8) Operation final { /// used in conjunction. /// static Operation *create(const std::vector &inputs, - const AttributeMap &attribute, + const AttributeMap &attributes, const std::vector &output_types, ir::OpInfo op_info, size_t num_regions = 0); @@ -45,8 +45,6 @@ class alignas(8) Operation final { /// void destroy(); - Block *parent() const { return parent_; } - IrContext *ir_context() const; ir::OpResult GetResultByIndex(uint32_t index) const; @@ -55,7 +53,11 @@ class alignas(8) Operation final { std::string print(); - const AttributeMap &attribute() const { return attribute_; } + const AttributeMap &attributes() const { return attributes_; } + + void SetAttribute(const std::string &key, Attribute value) { + attributes_[key] = value; + } ir::OpInfo op_info() const { return op_info_; } @@ -82,11 +84,13 @@ class alignas(8) Operation final { return op_info_.HasInterface(); } - Program *parent_program() const { return parent_program_; } + Block *GetParentBlock() const { return parent_; } - void set_parent_program(Program *parent_program) { - parent_program_ = parent_program; - } + Region *GetParentRegion() const; + + Operation *GetParentOp() const; + + Program *GetParentProgram(); /// Returns the region held by this operation at position 'index'. Region &GetRegion(unsigned index); @@ -115,7 +119,7 @@ class alignas(8) Operation final { static T call(Operation *op) { return T::dyn_cast(op); } }; - AttributeMap attribute_; + AttributeMap attributes_; OpInfo op_info_; @@ -124,7 +128,6 @@ class alignas(8) Operation final { const uint32_t num_regions_ = 0; Region *regions_{nullptr}; - Program *parent_program_{nullptr}; Block *parent_{nullptr}; }; diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index fb43e8a1ca0..a0f9d9f237a 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -32,7 +32,7 @@ using AttributeMap = std::unordered_map; // with the builder APIs. struct OperationArgument { std::vector inputs; - AttributeMap attribute; + AttributeMap attributes; std::vector output_types; OpInfo info; std::vector> regions; @@ -41,12 +41,12 @@ struct OperationArgument { OperationArgument(IrContext* ir_context, const std::string& name); explicit OperationArgument(OpInfo info) : info(info) {} OperationArgument(const std::vector& operands, - const AttributeMap& named_attr, + const AttributeMap& attributes, const std::vector& types, OpInfo info, std::vector>&& regions = {}) : inputs(operands), - attribute(named_attr), + attributes(attributes), output_types(types), info(info), regions(std::move(regions)) {} @@ -59,13 +59,18 @@ struct OperationArgument { /// Add an attribute with the specified name. void addAttribute(const std::string& name, Attribute attr) { - this->attribute[name] = attr; + attributes[name] = attr; } /// Add an array of named attributes. template void addAttributes(InputIt first, InputIt last); /// Get the context held by this operation state. IrContext* getContext() const { return info.ir_context(); } + + Region* AddRegion() { + regions.emplace_back(new Region); + return regions.back().get(); + } }; template @@ -83,7 +88,7 @@ void OperationArgument::addTypes(InputIt first, InputIt last) { template void OperationArgument::addAttributes(InputIt first, InputIt last) { while (first != last) { - attribute[first->first] = first->second; + attributes[first->first] = first->second; ++first; } } diff --git a/paddle/ir/core/program.cc b/paddle/ir/core/program.cc index 01bd9e2dd57..ac97d935305 100644 --- a/paddle/ir/core/program.cc +++ b/paddle/ir/core/program.cc @@ -16,11 +16,17 @@ #include "paddle/ir/core/ir_context.h" namespace ir { -Program::~Program() = default; -void Program::InsertOp(Operation* op) { - block_.push_back(op); - op->set_parent_program(this); +Program::Program(IrContext* context) { + module_ = ModuleOp::create(context, this); +} + +Program::Program() : Program(IrContext::Instance()) {} + +Program::~Program() { + if (module_) { + module_.destroy(); + } } Parameter* Program::GetParameter(std::string name) const { diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index 3169772df7a..86402a3f4f7 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -19,10 +19,13 @@ #include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/parameter.h" namespace ir { + +class IrContext; /// /// \brief Program is an abstraction of model structure, divided into /// computational graphs and weights. At the current stage, a computational @@ -33,27 +36,34 @@ namespace ir { /// class Program { public: + using ParameterMap = + std::unordered_map>; + explicit Program(IrContext* context); + Program(); + Program(Program&&) = delete; + Program(const Program& program) = delete; + Program& operator=(const Program&) = delete; + Program& operator=(Program&&); ~Program(); - - Block* block() { return &block_; } - size_t parameters_num() const { return parameters_.size(); } - /// - /// \brief Insert the Operation* constructed by Operation::create(...) into - /// this Program. NOTE: At this time, the memory management permission of - /// Operation* will be owned by this Program. The user does not need to call - /// Operation::destroy() manually - /// - void InsertOp(Operation* op); + ModuleOp module_op() { return module_; } - Parameter* GetParameter(std::string name) const; + Block* block() { return module_.block(); } + Parameter* GetParameter(std::string name) const; void SetParameter(std::string name, std::unique_ptr&& parameter); + ParameterMap& parameters() { return parameters_; } + void set_parameters(ParameterMap&& parameters) { + parameters_ = std::move(parameters); + } + private: - Block block_; - std::unordered_map> parameters_; + // computation graph + ModuleOp module_; + // weight + ParameterMap parameters_; }; std::ostream& operator<<(std::ostream& os, Program& program); diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc index 905f497c0bc..e434bbec494 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/ir/core/region.cc @@ -22,6 +22,9 @@ void Region::push_back(Block *block) { block->set_parent(this); blocks_.push_back(block); } + +void Region::emplace_back() { push_back(new Block); } + void Region::push_front(Block *block) { block->set_parent(this); blocks_.push_front(block); diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h index da84d970f1f..9a5a2f7a9b7 100644 --- a/paddle/ir/core/region.h +++ b/paddle/ir/core/region.h @@ -41,12 +41,15 @@ class Region { Block *back() const { return blocks_.back(); } Block *front() const { return blocks_.front(); } void push_back(Block *block); + void emplace_back(); void push_front(Block *block); iterator insert(const_iterator position, Block *block); void clear(); void TakeBody(Region &&other); + Operation *GetParentOp() const { return parent_; } + private: Region(Region &) = delete; Region &operator=(const Region &) = delete; diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 65796d827e0..a058b35ef35 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -17,10 +17,12 @@ #include "paddle/ir/core/block.h" #include "paddle/ir/core/builder.h" #include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_base.h" +#include "paddle/ir/core/program.h" #include "paddle/ir/core/region.h" /// \brief Define built-in Trait, derived from OpTraitBase. @@ -133,7 +135,7 @@ class Operation2 throw("Type of attribute: parameter_name is not right."); } } - static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; } + static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; const char *Operation2::attributes_name[attributes_num] = {"op2_attr1", "op2_attr2"}; @@ -212,8 +214,8 @@ TEST(op_test, region_test) { op1_info); ir::OperationArgument argument(op2_info); - argument.attribute = CreateAttributeMap({"op2_attr1", "op2_attr2"}, - {"op2_attr1", "op2_attr2"}); + argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"}, + {"op2_attr1", "op2_attr2"}); argument.output_types = {ir::Float32Type::get(ctx)}; argument.regions.emplace_back(std::make_unique()); ir::Region *region = argument.regions.back().get(); @@ -228,3 +230,26 @@ TEST(op_test, region_test) { ir::Operation *op2 = ir::Operation::create(std::move(argument)); op2->destroy(); } + +TEST(op_test, module_op_death) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name()); + + // (3) Test uses for op. + std::vector inputs{ir::OpResult()}; + ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}}; + std::vector output_types = {ir::Float32Type::get(ctx)}; + + EXPECT_THROW(ir::Operation::create(inputs, {}, {}, op_info), const char *); + EXPECT_THROW(ir::Operation::create({}, attrs, {}, op_info), const char *); + EXPECT_THROW(ir::Operation::create({}, {}, output_types, op_info), + const char *); + + ir::Program program(ctx); + + EXPECT_EQ(program.module_op().program(), &program); + EXPECT_EQ(program.module_op().ir_context(), ctx); + + program.module_op()->SetAttribute("program", + ir::PointerAttribute::get(ctx, &program)); +} diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 3150519a315..eab9f872c78 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/utils.h" +#include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_op.h" @@ -56,9 +57,7 @@ TEST(program_test, program) { ctx->GetOrRegisterDialect(); // (2) Create an empty program object - ir::Program program; - // ir::Program *program = new ir::Program(); - EXPECT_EQ(program.block()->size() == 0, true); + ir::Program program(ctx); // (3) Create a float32 DenseTensor Parameter and save into Program ir::Type fp32_dtype = ir::Float32Type::get(ctx); @@ -94,7 +93,14 @@ TEST(program_test, program) { ir::Operation *op1 = ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info); - program.InsertOp(op1); + ir::Block *block = program.block(); + block->push_back(op1); + + EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion()); + + EXPECT_EQ(program.module_op(), block->GetParentOp()); + + EXPECT_EQ(&program, op1->GetParentProgram()); EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(), paddle_dialect->id()); @@ -124,7 +130,7 @@ TEST(program_test, program) { {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; ir::Operation *op2 = ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info); - program.InsertOp(op2); + block->push_back(op2); EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), paddle_dialect->id()); @@ -155,7 +161,7 @@ TEST(program_test, program) { op3_attribute, {dense_tensor_dtype}, op3_info); - program.InsertOp(op3); + block->push_back(op3); phi::CPUContext *dev_ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get( @@ -196,9 +202,12 @@ TEST(program_test, program) { ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); std::unordered_map op4_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "c")}}; - ir::Operation *op4 = ir::Operation::create( - {op3->GetResultByIndex(0)}, op4_attribute, {}, op4_info); - program.InsertOp(op4); + + ir::OperationArgument op4_argument( + {op3->GetResultByIndex(0)}, {}, {}, op4_info); + op4_argument.addAttributes(op4_attribute.begin(), op4_attribute.end()); + ir::Operation *op4 = ir::Operation::create(std::move(op4_argument)); + block->push_back(op4); EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(), paddle_dialect->id()); @@ -244,7 +253,7 @@ TEST(program_test, slice_combine_test) { {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; ir::Operation *op1 = ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); - program.InsertOp(op1); + program.block()->push_back(op1); // (5) Def b = GetParameterOp("b") std::string op2_name = std::string(ir::GetParameterOp::name()); @@ -253,7 +262,7 @@ TEST(program_test, slice_combine_test) { {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; ir::Operation *op2 = ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info); - program.InsertOp(op2); + program.block()->push_back(op2); // (6) Def combine_op = CombineOp("a", "b") std::string combine_op_name = std::string(ir::CombineOp::name()); @@ -265,7 +274,7 @@ TEST(program_test, slice_combine_test) { {}, {output_type}, combine_op_info); - program.InsertOp(combine_op); + program.block()->push_back(combine_op); // (7) Def slice_op = SliceOp(combine_op, 0) std::string slice_op_name = std::string(ir::SliceOp::name()); @@ -276,7 +285,7 @@ TEST(program_test, slice_combine_test) { {{"index", index_attr}}, {fp32_dtype}, slice_op_info); - program.InsertOp(slice_op); + program.block()->push_back(slice_op); // (8) Traverse Program EXPECT_EQ(program.block()->size() == 4, true); -- GitLab