From 7e1dd338eee911a9e2a0151170a1560544fce6e2 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 16:56:51 +0800 Subject: [PATCH] [IR] Add vector type support for program translator (#54035) * add vector type support for program translator * polish * resolve conflicts * add verify for combine/slice and unittests * polish --- paddle/fluid/translator/op_translator.cc | 142 +++++++++++++++--- paddle/fluid/translator/op_translator.h | 3 +- paddle/fluid/translator/program_translator.cc | 15 +- paddle/fluid/translator/program_translator.h | 28 +++- paddle/fluid/translator/translate.h | 7 +- paddle/fluid/translator/type_translator.cc | 4 + paddle/fluid/translator/type_translator.h | 10 +- paddle/ir/builtin_dialect.cc | 5 +- paddle/ir/builtin_op.cc | 102 +++++++++++++ paddle/ir/builtin_op.h | 36 ++++- paddle/ir/printer.cc | 40 +++-- paddle/ir/type.cc | 7 +- paddle/ir/type.h | 2 + test/cpp/ir/ir_program_test.cc | 59 ++++++++ test/cpp/ir/program_translator_test.cc | 5 +- 15 files changed, 412 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index c6ff3f94125..65102a4eb0b 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -24,6 +24,8 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" +#include "paddle/ir/builtin_op.h" +#include "paddle/ir/builtin_type.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/value.h" #include "paddle/phi/core/enforce.h" @@ -84,23 +86,101 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { return op_info; } +inline ir::Operation* InsertSliceOperationForTarget( + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const VariableDefiningInfo& defining_info, + const std::string& arg_name) { + std::string slice_op_name(ir::SliceOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name); + std::unordered_map op_attribute_map = { + {"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)}, + }; + ir::VectorType src_vec_type = + defining_info.value.type().dyn_cast(); + ir::Operation* operation = + ir::Operation::create({defining_info.value}, + {src_vec_type[defining_info.idx_in_vector]}, + op_attribute_map, + op_info); + program->InsertOp(operation); + ir::OpResult target_op_result = operation->GetResultByIndex(0); + (*param_map)[arg_name] = VariableDefiningInfo(target_op_result); + return operation; +} + +inline ir::Operation* InsertCombineOperationForTarget( + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const std::vector& args) { + std::string combine_op_name(ir::CombineOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); + + std::vector src_values; + std::vector types_in_vec; + for (auto arg_name : args) { + auto defining_info = param_map->at(arg_name); + src_values.push_back(defining_info.value); + types_in_vec.push_back(defining_info.value.type()); + } + ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); + ir::Operation* operation = + ir::Operation::create(src_values, {target_vec_type}, {}, op_info); + program->InsertOp(operation); + return operation; +} + inline std::vector GenerateOperationInput( - TranslationContext* param_map, const OpDesc& op_desc) { + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const OpDesc& op_desc) { std::vector op_inputs = {}; + + // scan all inputs to see if any of them is generated as a vector + // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { auto& name = n.first; - VLOG(10) << "[input retriving]" - << "[" << op_desc.Type() << "]" << name; auto& args = n.second; + for (const auto& arg_name : args) { PADDLE_ENFORCE_NE( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s as input should be exists before prasing %d", + "arg %s.%s as input should be exists before prasing %d", + name, arg_name, op_desc.Type())); - op_inputs.push_back((*param_map)[arg_name]); + auto defining_info = (*param_map)[arg_name]; + if (defining_info.generated_by_vector) { + InsertSliceOperationForTarget( + ctx, param_map, program, defining_info, arg_name); + } + } + } + + for (const auto& n : op_desc.Inputs()) { + auto& name = n.first; + VLOG(10) << "[input retriving]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; + + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + auto defining_info = (*param_map)[arg_name]; + op_inputs.push_back(defining_info.value); + } + + // if src type is Vector , need an additional `CombineOp` to + // assemble them. + } else { + auto* combine_op = + InsertCombineOperationForTarget(ctx, param_map, program, args); + op_inputs.push_back(combine_op->GetResultByIndex(0)); } } return op_inputs; @@ -119,16 +199,39 @@ inline std::tuple GenerateOperationOutput( VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << name; auto& args = n.second; - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name << " " - << var->GetType(); - - ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); - arg_to_idx[arg_name] = op_output_types.size(); - op_output_types.push_back(translated_var_type); + size_t cur_output_idx = op_output_types.size(); + + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); + + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); + + arg_to_idx[arg_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); + } + + // if src type is Vector + } else { + std::vector types; + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); + types.push_back(translated_var_type); + arg_to_idx[arg_name] = cur_output_idx; + } + ir::Type vec_type = ir::VectorType::get(ctx, types); + op_output_types.push_back(vec_type); } } return {op_output_types, arg_to_idx}; @@ -143,12 +246,17 @@ inline void RecordOpResultMapping(TranslationContext* param_map, VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << name; auto& args = n.second; + size_t idx_in_vector = 0; for (const auto& arg_name : args) { auto idx = arg_to_idx.at(arg_name); VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << arg_name << " " << idx; - (*param_map)[arg_name] = operation->GetResultByIndex(idx); + ir::OpResult value = operation->GetResultByIndex(idx); + bool generated_by_vector = value.type().isa(); + (*param_map)[arg_name] = VariableDefiningInfo( + value, generated_by_vector, generated_by_vector ? idx_in_vector : -1); + idx_in_vector++; } } } @@ -157,7 +265,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; @@ -193,7 +301,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); diff --git a/paddle/fluid/translator/op_translator.h b/paddle/fluid/translator/op_translator.h index c767f639d53..92c03458300 100644 --- a/paddle/fluid/translator/op_translator.h +++ b/paddle/fluid/translator/op_translator.h @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/translator/program_translator.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/operation.h" #include "paddle/ir/program.h" @@ -28,8 +29,6 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; - class OpTranslator { public: using ResultIdx = size_t; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 7cdbef58906..7618972f108 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/translator/op_translator.h" +#include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/attribute.h" #include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_type.h" @@ -38,6 +39,11 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ctx = ir::IrContext::Instance(); } +const std::unordered_set ProgramTranslator::no_cast_var_names = { + "feed", + "fetch", +}; + void ProgramTranslator::Translate() { PADDLE_ENFORCE_EQ( legacy_program->Size(), @@ -59,19 +65,24 @@ void ProgramTranslator::Translate() { void ProgramTranslator::ExtractParameterFromSingleBlock( const BlockDesc& block) { + auto& type_translator = TypeTranslator::instance(); + for (auto& var : block.AllVars()) { if (!var->Persistable()) continue; if (param_map.count(var->Name()) != 0) continue; + if (no_cast_var_names.count(var->Name()) != 0) continue; std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, }; + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( - {}, {ir::Float32Type::get(ctx)}, op_attribute_map, op_info); + {}, {translated_var_type}, op_attribute_map, op_info); program->InsertOp(operation); - param_map[var->Name()] = operation->GetResultByIndex(0); + param_map[var->Name()] = + VariableDefiningInfo(operation->GetResultByIndex(0)); VLOG(10) << "[op translated][get parameter]" << operation; program->SetParameter(var->Name(), nullptr); diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 569b93b06aa..f7fd4e2890e 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -27,7 +27,25 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; +struct VariableDefiningInfo { + VariableDefiningInfo(ir::OpResult value, + bool generated_by_vector = false, + int idx_in_vector = -1) + : value(value), + generated_by_vector(generated_by_vector), + idx_in_vector(idx_in_vector) {} + VariableDefiningInfo() {} + + ir::OpResult value; + + bool generated_by_vector = + false; // true if target variabe is generated by Vector + int idx_in_vector = + -1; // positive if target variabe is generated by Vector +}; + +using TranslationContext = + std::unordered_map; class ProgramTranslator { using ProgramDesc = ::paddle::framework::ProgramDesc; @@ -45,6 +63,14 @@ class ProgramTranslator { TranslationContext param_map; ir::IrContext* ctx; + /// In the legacy program desc, there are two special named varibales: + /// 1. "feed", the input variable of feed op + /// 2. "fetch", the output variable of fetch op + /// However, new feed has no input and new fetch has no output + /// So we don't handle these two vairables when + /// `ExtractParameterFromSingleBlock` + static const std::unordered_set no_cast_var_names; + void ExtractParameterFromSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); }; diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h index aa2571f74c4..9cf013ea2bc 100644 --- a/paddle/fluid/translator/translate.h +++ b/paddle/fluid/translator/translate.h @@ -22,10 +22,7 @@ namespace paddle { -using LegacyProgramDesc = ::paddle::framework::ProgramDesc; -using Program = ::ir::Program; - -std::unique_ptr TranslateLegacyProgramToProgram( - const LegacyProgramDesc& legacy_program); +std::unique_ptr<::ir::Program> TranslateLegacyProgramToProgram( + const ::paddle::framework::ProgramDesc& legacy_program); } // namespace paddle diff --git a/paddle/fluid/translator/type_translator.cc b/paddle/fluid/translator/type_translator.cc index 9792a4b8537..e0c31913ffe 100644 --- a/paddle/fluid/translator/type_translator.cc +++ b/paddle/fluid/translator/type_translator.cc @@ -22,6 +22,10 @@ namespace paddle { namespace translator { +using OpDesc = paddle::framework::OpDesc; +using BlockDesc = paddle::framework::BlockDesc; +using VarDesc = paddle::framework::VarDesc; +using VarType = paddle::framework::proto::VarType; using DenseTensorType = paddle::dialect::DenseTensorType; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; diff --git a/paddle/fluid/translator/type_translator.h b/paddle/fluid/translator/type_translator.h index b16c1a222a5..707b8913f33 100644 --- a/paddle/fluid/translator/type_translator.h +++ b/paddle/fluid/translator/type_translator.h @@ -27,13 +27,13 @@ namespace paddle { namespace translator { -using OpDesc = paddle::framework::OpDesc; -using BlockDesc = paddle::framework::BlockDesc; -using VarDesc = paddle::framework::VarDesc; -using VarType = paddle::framework::proto::VarType; -using TypeTranslateFn = std::function; +using TypeTranslateFn = + std::function; class TypeTranslator { + public: + using VarType = paddle::framework::proto::VarType; + private: TypeTranslator(); // Disallow instantiation outside of the class. std::unordered_map handlers; diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc index 9c8cacf7bff..32f8750b710 100644 --- a/paddle/ir/builtin_dialect.cc +++ b/paddle/ir/builtin_dialect.cc @@ -44,7 +44,10 @@ void BuiltinDialect::initialize() { ir::Int64_tAttribute, ir::ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 2f5be8a8c26..63bfc2196dc 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -13,7 +13,10 @@ // limitations under the License. #include "paddle/ir/builtin_op.h" + #include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/phi/core/enforce.h" namespace ir { const char *GetParameterOp::attributes_name[attributes_num] = { @@ -58,4 +61,103 @@ void SetParameterOp::verify(const std::vector &inputs, } } +void CombineOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + // outputs[0].type == Vector + PADDLE_ENFORCE(outputs[0].isa(), + phi::errors::PreconditionNotMet( + "The type %s of outputs[0] must be equal to VectorType.", + outputs[0])); + ir::VectorType output_type = outputs[0].dyn_cast(); + // inputs.size() == outputs[0].size() + PADDLE_ENFORCE_EQ( + output_type.size(), + inputs.size(), + phi::errors::PreconditionNotMet( + "The size %d of outputs[0] must be equal to size %d of inputs.", + output_type.size(), + inputs.size())); + + // forall i in inputs.size(): inputs[i].type == outputs[0][i].type + for (size_t i = 0; i < inputs.size(); i++) { + PADDLE_ENFORCE_EQ( + output_type[i], + inputs[i].type(), + phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be " + "equal to type %s of inputs[%d].", + output_type[i], + i, + inputs[i].type(), + i)); + } +} + +const char *SliceOp::attributes_name[attributes_num] = {"index"}; +void SliceOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + // inputs.size() == 1 + PADDLE_ENFORCE_EQ( + inputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", inputs.size())); + + // inputs[0].type == Vector + PADDLE_ENFORCE(inputs[0].type().isa(), + phi::errors::PreconditionNotMet( + "The type %s of inputs[0] must be equal to VectorType.", + inputs[0].type())); + ir::VectorType input_type = inputs[0].type().dyn_cast(); + + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + + // attributes contains index: Int32 + PADDLE_ENFORCE_NE( + attributes.count("index"), + 0, + phi::errors::PreconditionNotMet("The attributes must contains index.")); + const ir::Attribute &attr = attributes.at("index"); + PADDLE_ENFORCE( + attr.isa(), + phi::errors::PreconditionNotMet("The attribute index must be INT32.")); + auto index = attr.dyn_cast().data(); + + // index >= 0 and < inputs[0].size() + PADDLE_ENFORCE_GE( + index, + 0, + phi::errors::PreconditionNotMet( + "The index %d must be greater or equal than 0.", index)); + PADDLE_ENFORCE_LT( + index, + input_type.size(), + phi::errors::PreconditionNotMet( + "The index %d must be less or equal than size %d of inputs[0].", + index, + input_type.size())); + + // inputs[index].type == outputs[0].type + PADDLE_ENFORCE_EQ( + input_type[index], + outputs[0], + phi::errors::PreconditionNotMet( + "The type %s of inputs[%d] must be equal to type %s of outputs[0].", + input_type[index], + index, + outputs[0])); +} + } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index c1953136f8c..d1d2c20b572 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -33,7 +33,7 @@ class GetParameterOp : public ir::Op { }; /// -/// \brief GetParameterOp: SetParameterOp(OpOperand, {StrAttribute, +/// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// class SetParameterOp : public ir::Op { @@ -47,4 +47,38 @@ class SetParameterOp : public ir::Op { const ir::AttributeMap &attributes); }; +/// +/// \brief CombineOp: CombineOp(OpOperand) +/// +class CombineOp : public ir::Op { + public: + using Op::Op; + + static const char *name() { return "builtin.combine"; } + + static constexpr uint32_t attributes_num = 0; + + static constexpr const char **attributes_name = nullptr; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); +}; + +/// +/// \brief SliceOp: SliceOp(OpOperand) +/// +class SliceOp : public ir::Op { + public: + using Op::Op; + + static const char *name() { return "builtin.slice"; } + + static constexpr uint32_t attributes_num = 1; + + static const char *attributes_name[attributes_num]; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); +}; + } // namespace ir diff --git a/paddle/ir/printer.cc b/paddle/ir/printer.cc index b4d3acb930a..421dedabb3a 100644 --- a/paddle/ir/printer.cc +++ b/paddle/ir/printer.cc @@ -28,6 +28,21 @@ namespace ir { namespace { constexpr char newline[] = "\n"; + +template +void PrintInterleave(ForwardIterator begin, + ForwardIterator end, + UnaryFunctor print_func, + NullFunctor between_func) { + if (begin == end) return; + print_func(*begin); + begin++; + for (; begin != end; begin++) { + between_func(); + print_func(*begin); + } +} + } // namespace class Printer { @@ -47,6 +62,15 @@ class Printer { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "vec<"; + auto inner_types = type.dyn_cast().data(); + PrintInterleave( + inner_types.begin(), + inner_types.end(), + [this](ir::Type v) { this->PrintType(v); }, + [this]() { this->os << ","; }); + os << ">"; } else { auto& dialect = type.dialect(); dialect.PrintType(type, os); @@ -77,22 +101,6 @@ class ProgramPrinter : public Printer { } } - template - void PrintInterleave(ForwardIterator begin, - ForwardIterator end, - UnaryFunctor print_func, - NullFunctor between_func) { - if (begin == end) return; - print_func(*begin); - begin++; - for (; begin != end; begin++) { - between_func(); - print_func(*begin); - } - } - void PrintValue(ir::Value v) { const void* key = static_cast(v.impl()); auto ret = aliases.find(key); diff --git a/paddle/ir/type.cc b/paddle/ir/type.cc index bde3194f8fd..e9c24672e5b 100644 --- a/paddle/ir/type.cc +++ b/paddle/ir/type.cc @@ -16,6 +16,11 @@ #include "paddle/ir/dialect.h" namespace ir { -IrContext *Type::ir_context() const { return dialect().ir_context(); } +IrContext* Type::ir_context() const { return dialect().ir_context(); } + +std::ostream& operator<<(std::ostream& os, Type type) { + type.print(os); + return os; +} } // namespace ir diff --git a/paddle/ir/type.h b/paddle/ir/type.h index fce17db82eb..89d153c0894 100644 --- a/paddle/ir/type.h +++ b/paddle/ir/type.h @@ -89,6 +89,8 @@ class Type { const Storage *storage_{nullptr}; }; +std::ostream &operator<<(std::ostream &os, Type type); + } // namespace ir namespace std { diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index c430d9b320b..7c6c9acaf52 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -211,3 +211,62 @@ TEST(program_test, program) { EXPECT_EQ(ops.size() == 4, true); EXPECT_EQ(program.parameters_num() == 3, true); } + +TEST(program_test, slice_combine_test) { + // (1) Init environment. + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + // (2) Create an empty program object + ir::Program program; + // ir::Program *program = new ir::Program(); + EXPECT_EQ(program.ops().size() == 0, true); + + // (3) Create a float32 DenseTensor Parameter and save into Program + ir::Type fp32_dtype = ir::Float32Type::get(ctx); + + // (4) Def a = GetParameterOp("a") + std::string op1_name = ir::GetParameterOp::name(); + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); + std::unordered_map op1_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; + ir::Operation *op1 = + ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info); + program.InsertOp(op1); + + // (5) Def b = GetParameterOp("b") + std::string op2_name = std::string(ir::GetParameterOp::name()); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); + std::unordered_map op2_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; + ir::Operation *op2 = + ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); + program.InsertOp(op2); + + // (6) Def combine_op = CombineOp("a", "b") + std::string combine_op_name = std::string(ir::CombineOp::name()); + ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); + ir::Type output_type = + ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); + ir::Operation *combine_op = ir::Operation::create( + {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, + {output_type}, + {}, + combine_op_info); + program.InsertOp(combine_op); + + // (7) Def slice_op = SliceOp(combine_op, 0) + std::string slice_op_name = std::string(ir::SliceOp::name()); + ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); + ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); + ir::Operation *slice_op = + ir::Operation::create({combine_op->GetResultByIndex(0)}, + {fp32_dtype}, + {{"index", index_attr}}, + slice_op_info); + program.InsertOp(slice_op); + + // (8) Traverse Program + std::list ops = program.ops(); + EXPECT_EQ(ops.size() == 4, true); +} diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index 38cb983ca6d..5ae7b0b9b31 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -59,6 +59,7 @@ TEST(PaddleDialectTest, Translator) { // auto program = paddle::TranslateLegacyProgramToProgram(p); // std::list ops = program->ops(); - // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num()); - // VLOG(0) << *program << std::endl; + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + + // 20); std::cout << *program << std::endl; } -- GitLab