From 9a8e94179c3e6a08c571d1abcc778654570c646d Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Fri, 2 Jun 2023 12:48:08 +0800 Subject: [PATCH] [IR] standardize the use of new IR api. (#54289) --- paddle/fluid/dialect/op_gen.py | 6 ++--- paddle/fluid/translator/program_translator.cc | 1 + paddle/fluid/translator/translate.cc | 2 +- paddle/ir/core/attribute_base.h | 16 ------------- paddle/ir/core/builder.cc | 7 +++--- paddle/ir/core/builder.h | 1 - paddle/ir/core/builtin_op.cc | 2 +- paddle/ir/core/operation.cc | 1 + paddle/ir/core/operation.h | 7 +++--- paddle/ir/core/operation_utils.h | 18 +++++++-------- paddle/ir/core/printer.cc | 6 ++--- paddle/ir/core/program.cc | 2 -- paddle/ir/core/program.h | 2 +- paddle/ir/core/type_base.h | 13 ----------- paddle/ir/core/value.cc | 16 +++++-------- paddle/ir/core/value.h | 23 +++++++++++++------ test/cpp/ir/core/ir_op_test.cc | 6 ++--- test/cpp/ir/core/ir_program_test.cc | 15 ++++++------ test/cpp/ir/core/ir_value_test.cc | 17 ++++++-------- 19 files changed, 67 insertions(+), 94 deletions(-) diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 0d8c4d336f1..3aaee4e42ef 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -613,7 +613,7 @@ def GenBuildInputArgsStr( def GenBuildInputs(op_input_name_list): BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; - argument.addOperands(argument_inputs.begin(), argument_inputs.end()); + argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); """ build_input_str = "" if len(op_input_name_list) > 0: @@ -693,7 +693,7 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): op_attribute_type=op_attribute_type_list[idx], attr=op_attribute_name_list[idx], ) - attr_str += """ argument.addAttribute("{attr_name}", attr_{attr_name});\n""".format( + attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n""".format( attr_name=op_attribute_name_list[idx] ) @@ -847,7 +847,7 @@ def GenBuildOutputs( name=op_output_name_list[idx] ) - build_output_str += " argument.addTypes(argument_outputs.begin(), argument_outputs.end());\n" + build_output_str += " argument.AddTypes(argument_outputs.begin(), argument_outputs.end());\n" return build_output_str diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index ff8dc88225f..2b98e4e11cf 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/translator/op_translator.h" #include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/operation.h" diff --git a/paddle/fluid/translator/translate.cc b/paddle/fluid/translator/translate.cc index eaf9c35d403..2fdbd39b34e 100644 --- a/paddle/fluid/translator/translate.cc +++ b/paddle/fluid/translator/translate.cc @@ -28,7 +28,7 @@ using Program = ::ir::Program; std::unique_ptr TranslateLegacyProgramToProgram( const LegacyProgramDesc& legacy_program) { - auto program = std::make_unique(); + auto program = std::make_unique(ir::IrContext::Instance()); translator::ProgramTranslator program_translator(&legacy_program, program.get()); diff --git a/paddle/ir/core/attribute_base.h b/paddle/ir/core/attribute_base.h index 8ad3c877416..dd120a02a45 100644 --- a/paddle/ir/core/attribute_base.h +++ b/paddle/ir/core/attribute_base.h @@ -263,20 +263,4 @@ struct AttributeManager { return ir::AttributeManager::template get(ctx, \ args...); \ } - -/// -/// \brief This macro definition is used to register custom Attribute class. -/// -#define REGISTER_ATTRIBUTE_2_IRCONTEXT(concrete_attribute, dialect) \ - ir::AbstractAttribute *abstract_attribute_##concrete_attribute = \ - new ir::AbstractAttribute(std::move( \ - ir::AbstractAttribute::get(*dialect))); \ - \ - dialect->ir_context()->RegisterAbstractAttribute( \ - ir::TypeId::get(), \ - abstract_attribute_##concrete_attribute); \ - \ - ir::AttributeManager::RegisterAttribute( \ - dialect->ir_context()); - } // namespace ir diff --git a/paddle/ir/core/builder.cc b/paddle/ir/core/builder.cc index 13c16db6a6b..4f789e7ce2b 100644 --- a/paddle/ir/core/builder.cc +++ b/paddle/ir/core/builder.cc @@ -14,6 +14,7 @@ #include "paddle/ir/core/builder.h" #include "paddle/ir/core/region.h" +#include "paddle/ir/core/value.h" namespace ir { Operation *Builder::insert(Operation *op) { @@ -31,10 +32,10 @@ Operation *Builder::create(OperationArgument &&argument) { } /// Creates an operation with the given fields. -Operation *Builder::create(const std::vector &inputs, +Operation *Builder::create(const std::vector &inputs, const AttributeMap &attribute, - const std::vector &output_types, - ir::OpInfo op_info) { + const std::vector &output_types, + OpInfo op_info) { return create(OperationArgument(inputs, attribute, output_types, op_info)); } diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index 7bd187961c3..3c8a579811f 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -18,7 +18,6 @@ #include "paddle/ir/core/block.h" #include "paddle/ir/core/operation.h" -#include "paddle/ir/core/program.h" namespace ir { /// diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index 2da74f17b52..105edd83caa 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -41,7 +41,7 @@ ModuleOp ModuleOp::create(IrContext *context, Program *pointer) { ir::OpInfo info = context->GetRegisteredOpInfo(name()); OperationArgument argument(info); argument.AddRegion()->emplace_back(); - argument.addAttribute("program", PointerAttribute::get(context, pointer)); + argument.AddAttribute("program", PointerAttribute::get(context, pointer)); return ModuleOp(Operation::create(std::move(argument))); } diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index b7aa58a9b0e..1f068000c98 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -18,6 +18,7 @@ #include "paddle/ir/core/program.h" #include "paddle/ir/core/region.h" #include "paddle/ir/core/utils.h" +#include "paddle/ir/core/value_impl.h" namespace ir { Operation *Operation::create(OperationArgument &&argument) { diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 0775f68d63e..759e91f9052 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -18,12 +18,13 @@ #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/type.h" -#include "paddle/ir/core/value_impl.h" namespace ir { class OpBase; class Program; class Block; +class OpOperand; +class OpResult; class alignas(8) Operation final { public: @@ -47,9 +48,9 @@ class alignas(8) Operation final { IrContext *ir_context() const; - ir::OpResult GetResultByIndex(uint32_t index) const; + OpResult GetResultByIndex(uint32_t index) const; - ir::OpOperand GetOperandByIndex(uint32_t index) const; + OpOperand GetOperandByIndex(uint32_t index) const; std::string print(); diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index a0f9d9f237a..7c012bcd0d5 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/region.h" #include "paddle/ir/core/type.h" -#include "paddle/ir/core/value_impl.h" +#include "paddle/ir/core/value.h" namespace ir { @@ -52,18 +52,18 @@ struct OperationArgument { regions(std::move(regions)) {} template - void addOperands(InputIt first, InputIt last); + void AddOperands(InputIt first, InputIt last); template - void addTypes(InputIt first, InputIt last); + void AddTypes(InputIt first, InputIt last); /// Add an attribute with the specified name. - void addAttribute(const std::string& name, Attribute attr) { + void AddAttribute(const std::string& name, Attribute attr) { attributes[name] = attr; } /// Add an array of named attributes. template - void addAttributes(InputIt first, InputIt last); + void AddAttributes(InputIt first, InputIt last); /// Get the context held by this operation state. IrContext* getContext() const { return info.ir_context(); } @@ -74,19 +74,19 @@ struct OperationArgument { }; template -void OperationArgument::addOperands(InputIt first, InputIt last) { +void OperationArgument::AddOperands(InputIt first, InputIt last) { while (first != last) { inputs.emplace_back(*first++); } } template -void OperationArgument::addTypes(InputIt first, InputIt last) { +void OperationArgument::AddTypes(InputIt first, InputIt last) { while (first != last) { output_types.emplace_back(*first++); } } template -void OperationArgument::addAttributes(InputIt first, InputIt last) { +void OperationArgument::AddAttributes(InputIt first, InputIt last) { while (first != last) { attributes[first->first] = first->second; ++first; diff --git a/paddle/ir/core/printer.cc b/paddle/ir/core/printer.cc index 5dc91142fb5..f373f0a70a1 100644 --- a/paddle/ir/core/printer.cc +++ b/paddle/ir/core/printer.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" @@ -160,7 +161,7 @@ class ProgramPrinter : public Printer { std::vector op_operands; op_operands.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operands.push_back(op->GetOperandByIndex(idx).impl()->source()); + op_operands.push_back(op->GetOperandByIndex(idx).source()); } PrintInterleave( op_operands.begin(), @@ -175,8 +176,7 @@ class ProgramPrinter : public Printer { std::vector op_operand_types; op_operand_types.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operand_types.push_back( - op->GetOperandByIndex(idx).impl()->source().type()); + op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); } PrintInterleave( op_operand_types.begin(), diff --git a/paddle/ir/core/program.cc b/paddle/ir/core/program.cc index ac97d935305..e104f821bfc 100644 --- a/paddle/ir/core/program.cc +++ b/paddle/ir/core/program.cc @@ -21,8 +21,6 @@ Program::Program(IrContext* context) { module_ = ModuleOp::create(context, this); } -Program::Program() : Program(IrContext::Instance()) {} - Program::~Program() { if (module_) { module_.destroy(); diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index 86402a3f4f7..55b6de2fee8 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -17,6 +17,7 @@ #include #include +#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_op.h" @@ -39,7 +40,6 @@ class Program { using ParameterMap = std::unordered_map>; explicit Program(IrContext* context); - Program(); Program(Program&&) = delete; Program(const Program& program) = delete; Program& operator=(const Program&) = delete; diff --git a/paddle/ir/core/type_base.h b/paddle/ir/core/type_base.h index 3b62ae87fcc..b15b1b2656f 100644 --- a/paddle/ir/core/type_base.h +++ b/paddle/ir/core/type_base.h @@ -259,17 +259,4 @@ struct TypeManager { static concrete_type get(ir::IrContext *ctx, Args... args) { \ return ir::TypeManager::template get(ctx, args...); \ } - -/// -/// \brief This macro definition is used to register custom Type class. -/// -#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, dialect) \ - ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \ - std::move(ir::AbstractType::get(*dialect))); \ - \ - dialect->ir_context()->RegisterAbstractType( \ - ir::TypeId::get(), abstract_type_##concrete_type); \ - \ - ir::TypeManager::RegisterType(dialect->ir_context()); - } // namespace ir diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 631f0ba7adf..72f261238f8 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -32,19 +32,13 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { return *this; } -bool OpOperand::operator==(OpOperand other) const { - return impl_ == other.impl_; -} - -bool OpOperand::operator!=(OpOperand other) const { - return impl_ != other.impl_; -} +OpOperand OpOperand::next_use() const { return impl_->next_use(); } -bool OpOperand::operator!() const { return impl_ == nullptr; } +Value OpOperand::source() const { return impl_->source(); } -OpOperand::operator bool() const { return impl_; } +Operation *OpOperand::owner() const { return impl_->owner(); } -detail::OpOperandImpl *OpOperand::impl() const { return impl_; } +// detail::OpOperandImpl *OpOperand::impl() const { return impl_; } // Value Value::Value(const detail::ValueImpl *impl) @@ -81,6 +75,8 @@ Value::use_iterator Value::begin() const { Value::use_iterator Value::end() const { return Value::use_iterator(); } +OpOperand Value::first_use() const { return impl()->first_use(); } + // OpResult bool OpResult::classof(Value value) { return ir::isa(value.impl()); diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index c0bb5cc4d4d..795aa6f5a80 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -19,6 +19,7 @@ namespace ir { class Operation; +class Value; namespace detail { class OpOperandImpl; @@ -42,15 +43,21 @@ class OpOperand { OpOperand &operator=(const detail::OpOperandImpl *impl); - bool operator==(OpOperand other) const; + bool operator==(const OpOperand &other) const { return impl_ == other.impl_; } - bool operator!=(OpOperand other) const; + bool operator!=(const OpOperand &other) const { return !operator==(other); } - bool operator!() const; + bool operator!() const { return impl_ == nullptr; } - explicit operator bool() const; + operator bool() const { return impl_; } - detail::OpOperandImpl *impl() const; + OpOperand next_use() const; + + Value source() const; + + Operation *owner() const; + + // detail::OpOperandImpl *impl() const { return impl_;} private: detail::OpOperandImpl *impl_{nullptr}; @@ -71,14 +78,14 @@ class ValueUseIterator { return !(*this == rhs); } - ir::Operation *owner() const { return current_.impl()->owner(); } + ir::Operation *owner() const { return current_.owner(); } OperandType get() const { return current_; } OperandType operator*() const { return get(); } ValueUseIterator &operator++() { - current_ = current_.impl()->next_use(); + current_ = current_.next_use(); return *this; } @@ -141,6 +148,8 @@ class Value { use_iterator end() const; + OpOperand first_use() const; + friend struct std::hash; protected: diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index a058b35ef35..3c749d3d1aa 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -103,11 +103,11 @@ class Operation1 : public ir::Op { std::unordered_map attributes = CreateAttributeMap({"op1_attr1", "op1_attr2"}, {"op1_attr1", "op1_attr2"}); - argument.addOperands::iterator>(inputs.begin(), + argument.AddOperands::iterator>(inputs.begin(), inputs.end()); - argument.addTypes::iterator>(output_types.begin(), + argument.AddTypes::iterator>(output_types.begin(), output_types.end()); - argument.addAttributes< + argument.AddAttributes< std::unordered_map::iterator>( attributes.begin(), attributes.end()); } diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index eab9f872c78..ea18452c2ff 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -188,9 +188,9 @@ TEST(program_test, program) { std::unordered_map abs_op_attribute; std::vector output_types = {dense_tensor_dtype}; ir::OperationArgument abs_argument(abs_info); - abs_argument.addOperands(operands.begin(), operands.end()); - abs_argument.addAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); - abs_argument.addTypes(output_types.begin(), output_types.end()); + abs_argument.AddOperands(operands.begin(), operands.end()); + abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); + abs_argument.AddTypes(output_types.begin(), output_types.end()); ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); paddle::dialect::GetOpInfoInterface interface = abs_op->dyn_cast(); @@ -205,15 +205,14 @@ TEST(program_test, program) { ir::OperationArgument op4_argument( {op3->GetResultByIndex(0)}, {}, {}, op4_info); - op4_argument.addAttributes(op4_attribute.begin(), op4_attribute.end()); + op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); ir::Operation *op4 = ir::Operation::create(std::move(op4_argument)); block->push_back(op4); - EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(), + EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(), paddle_dialect->id()); Interface *c_interface = op4->GetOperandByIndex(0) - .impl() - ->source() + .source() .type() .dialect() .GetRegisteredInterface(); @@ -239,7 +238,7 @@ TEST(program_test, slice_combine_test) { ctx->GetOrRegisterDialect(); // (2) Create an empty program object - ir::Program program; + ir::Program program(ctx); // ir::Program *program = new ir::Program(); EXPECT_EQ(program.block()->size() == 0, true); diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index 28e340e52a5..8af7c03e598 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -85,16 +85,13 @@ TEST(value_test, value_test) { // Test 2: op1_first_output -> op4_first_input ir::OpResult op1_first_output = op1->GetResultByIndex(0); - ir::detail::OpOperandImpl *op4_first_input = - reinterpret_cast( - reinterpret_cast(op4) + sizeof(ir::Operation)); - EXPECT_EQ(static_cast(op1_first_output).impl()->first_use(), - op4_first_input); - ir::detail::OpOperandImpl *op3_first_input = - reinterpret_cast( - reinterpret_cast(op3) + sizeof(ir::Operation)); - EXPECT_EQ(op4_first_input->next_use(), op3_first_input); - EXPECT_EQ(op3_first_input->next_use(), nullptr); + ir::OpOperand op4_first_input = op4->GetOperandByIndex(0); + + EXPECT_EQ(op1_first_output.first_use(), op4_first_input); + ir::OpOperand op3_first_input = op3->GetOperandByIndex(0); + + EXPECT_EQ(op4_first_input.next_use(), op3_first_input); + EXPECT_EQ(op3_first_input.next_use(), nullptr); // Test 3: Value iterator ir::Value::use_iterator iter = op1->GetResultByIndex(0).begin(); -- GitLab