diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index c6ff3f94125fb5c6b0e6f4cf480f21e9cf28734b..65102a4eb0b1928dd22d045d9d76dd8cc1fc7db1 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -24,6 +24,8 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" +#include "paddle/ir/builtin_op.h" +#include "paddle/ir/builtin_type.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/value.h" #include "paddle/phi/core/enforce.h" @@ -84,23 +86,101 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { return op_info; } +inline ir::Operation* InsertSliceOperationForTarget( + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const VariableDefiningInfo& defining_info, + const std::string& arg_name) { + std::string slice_op_name(ir::SliceOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name); + std::unordered_map op_attribute_map = { + {"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)}, + }; + ir::VectorType src_vec_type = + defining_info.value.type().dyn_cast(); + ir::Operation* operation = + ir::Operation::create({defining_info.value}, + {src_vec_type[defining_info.idx_in_vector]}, + op_attribute_map, + op_info); + program->InsertOp(operation); + ir::OpResult target_op_result = operation->GetResultByIndex(0); + (*param_map)[arg_name] = VariableDefiningInfo(target_op_result); + return operation; +} + +inline ir::Operation* InsertCombineOperationForTarget( + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const std::vector& args) { + std::string combine_op_name(ir::CombineOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); + + std::vector src_values; + std::vector types_in_vec; + for (auto arg_name : args) { + auto defining_info = param_map->at(arg_name); + src_values.push_back(defining_info.value); + types_in_vec.push_back(defining_info.value.type()); + } + ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); + ir::Operation* operation = + ir::Operation::create(src_values, {target_vec_type}, {}, op_info); + program->InsertOp(operation); + return operation; +} + inline std::vector GenerateOperationInput( - TranslationContext* param_map, const OpDesc& op_desc) { + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const OpDesc& op_desc) { std::vector op_inputs = {}; + + // scan all inputs to see if any of them is generated as a vector + // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { auto& name = n.first; - VLOG(10) << "[input retriving]" - << "[" << op_desc.Type() << "]" << name; auto& args = n.second; + for (const auto& arg_name : args) { PADDLE_ENFORCE_NE( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s as input should be exists before prasing %d", + "arg %s.%s as input should be exists before prasing %d", + name, arg_name, op_desc.Type())); - op_inputs.push_back((*param_map)[arg_name]); + auto defining_info = (*param_map)[arg_name]; + if (defining_info.generated_by_vector) { + InsertSliceOperationForTarget( + ctx, param_map, program, defining_info, arg_name); + } + } + } + + for (const auto& n : op_desc.Inputs()) { + auto& name = n.first; + VLOG(10) << "[input retriving]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; + + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + auto defining_info = (*param_map)[arg_name]; + op_inputs.push_back(defining_info.value); + } + + // if src type is Vector , need an additional `CombineOp` to + // assemble them. + } else { + auto* combine_op = + InsertCombineOperationForTarget(ctx, param_map, program, args); + op_inputs.push_back(combine_op->GetResultByIndex(0)); } } return op_inputs; @@ -119,16 +199,39 @@ inline std::tuple GenerateOperationOutput( VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << name; auto& args = n.second; - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name << " " - << var->GetType(); - - ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); - arg_to_idx[arg_name] = op_output_types.size(); - op_output_types.push_back(translated_var_type); + size_t cur_output_idx = op_output_types.size(); + + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); + + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); + + arg_to_idx[arg_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); + } + + // if src type is Vector + } else { + std::vector types; + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); + types.push_back(translated_var_type); + arg_to_idx[arg_name] = cur_output_idx; + } + ir::Type vec_type = ir::VectorType::get(ctx, types); + op_output_types.push_back(vec_type); } } return {op_output_types, arg_to_idx}; @@ -143,12 +246,17 @@ inline void RecordOpResultMapping(TranslationContext* param_map, VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << name; auto& args = n.second; + size_t idx_in_vector = 0; for (const auto& arg_name : args) { auto idx = arg_to_idx.at(arg_name); VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << arg_name << " " << idx; - (*param_map)[arg_name] = operation->GetResultByIndex(idx); + ir::OpResult value = operation->GetResultByIndex(idx); + bool generated_by_vector = value.type().isa(); + (*param_map)[arg_name] = VariableDefiningInfo( + value, generated_by_vector, generated_by_vector ? idx_in_vector : -1); + idx_in_vector++; } } } @@ -157,7 +265,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; @@ -193,7 +301,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); diff --git a/paddle/fluid/translator/op_translator.h b/paddle/fluid/translator/op_translator.h index c767f639d534b9394736d878b0605df9eb20c734..92c03458300257c6b3813e58b4db2e59abc52158 100644 --- a/paddle/fluid/translator/op_translator.h +++ b/paddle/fluid/translator/op_translator.h @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/translator/program_translator.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/operation.h" #include "paddle/ir/program.h" @@ -28,8 +29,6 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; - class OpTranslator { public: using ResultIdx = size_t; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 7cdbef589067ba8572c356fca0684e2ede334d01..7618972f1080409797cdd78fafd22caf6a5314b4 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/translator/op_translator.h" +#include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/attribute.h" #include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_type.h" @@ -38,6 +39,11 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ctx = ir::IrContext::Instance(); } +const std::unordered_set ProgramTranslator::no_cast_var_names = { + "feed", + "fetch", +}; + void ProgramTranslator::Translate() { PADDLE_ENFORCE_EQ( legacy_program->Size(), @@ -59,19 +65,24 @@ void ProgramTranslator::Translate() { void ProgramTranslator::ExtractParameterFromSingleBlock( const BlockDesc& block) { + auto& type_translator = TypeTranslator::instance(); + for (auto& var : block.AllVars()) { if (!var->Persistable()) continue; if (param_map.count(var->Name()) != 0) continue; + if (no_cast_var_names.count(var->Name()) != 0) continue; std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, }; + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( - {}, {ir::Float32Type::get(ctx)}, op_attribute_map, op_info); + {}, {translated_var_type}, op_attribute_map, op_info); program->InsertOp(operation); - param_map[var->Name()] = operation->GetResultByIndex(0); + param_map[var->Name()] = + VariableDefiningInfo(operation->GetResultByIndex(0)); VLOG(10) << "[op translated][get parameter]" << operation; program->SetParameter(var->Name(), nullptr); diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 569b93b06aa6d943e4b901eefe91aef8972856df..f7fd4e2890ea64ae4a5a8d434816837fd107c321 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -27,7 +27,25 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; +struct VariableDefiningInfo { + VariableDefiningInfo(ir::OpResult value, + bool generated_by_vector = false, + int idx_in_vector = -1) + : value(value), + generated_by_vector(generated_by_vector), + idx_in_vector(idx_in_vector) {} + VariableDefiningInfo() {} + + ir::OpResult value; + + bool generated_by_vector = + false; // true if target variabe is generated by Vector + int idx_in_vector = + -1; // positive if target variabe is generated by Vector +}; + +using TranslationContext = + std::unordered_map; class ProgramTranslator { using ProgramDesc = ::paddle::framework::ProgramDesc; @@ -45,6 +63,14 @@ class ProgramTranslator { TranslationContext param_map; ir::IrContext* ctx; + /// In the legacy program desc, there are two special named varibales: + /// 1. "feed", the input variable of feed op + /// 2. "fetch", the output variable of fetch op + /// However, new feed has no input and new fetch has no output + /// So we don't handle these two vairables when + /// `ExtractParameterFromSingleBlock` + static const std::unordered_set no_cast_var_names; + void ExtractParameterFromSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); }; diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h index aa2571f74c4cf4b05de6a31e154ef67daf1d6fc6..9cf013ea2bca36d259e7a8d489dbc21badeec930 100644 --- a/paddle/fluid/translator/translate.h +++ b/paddle/fluid/translator/translate.h @@ -22,10 +22,7 @@ namespace paddle { -using LegacyProgramDesc = ::paddle::framework::ProgramDesc; -using Program = ::ir::Program; - -std::unique_ptr TranslateLegacyProgramToProgram( - const LegacyProgramDesc& legacy_program); +std::unique_ptr<::ir::Program> TranslateLegacyProgramToProgram( + const ::paddle::framework::ProgramDesc& legacy_program); } // namespace paddle diff --git a/paddle/fluid/translator/type_translator.cc b/paddle/fluid/translator/type_translator.cc index 9792a4b8537cef3b27d36bda69497ca02ce434f4..e0c31913ffefe2c641eb61edefab3ce081b9babd 100644 --- a/paddle/fluid/translator/type_translator.cc +++ b/paddle/fluid/translator/type_translator.cc @@ -22,6 +22,10 @@ namespace paddle { namespace translator { +using OpDesc = paddle::framework::OpDesc; +using BlockDesc = paddle::framework::BlockDesc; +using VarDesc = paddle::framework::VarDesc; +using VarType = paddle::framework::proto::VarType; using DenseTensorType = paddle::dialect::DenseTensorType; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; diff --git a/paddle/fluid/translator/type_translator.h b/paddle/fluid/translator/type_translator.h index b16c1a222a534257f3d5a7d752af88fa48b0924b..707b8913f3385d9a23c5d70c1a79a316527a8bf2 100644 --- a/paddle/fluid/translator/type_translator.h +++ b/paddle/fluid/translator/type_translator.h @@ -27,13 +27,13 @@ namespace paddle { namespace translator { -using OpDesc = paddle::framework::OpDesc; -using BlockDesc = paddle::framework::BlockDesc; -using VarDesc = paddle::framework::VarDesc; -using VarType = paddle::framework::proto::VarType; -using TypeTranslateFn = std::function; +using TypeTranslateFn = + std::function; class TypeTranslator { + public: + using VarType = paddle::framework::proto::VarType; + private: TypeTranslator(); // Disallow instantiation outside of the class. std::unordered_map handlers; diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc index 9c8cacf7bff94a0157b2a37df75262c61414b399..32f8750b7109416647e86433b42363929433bbd1 100644 --- a/paddle/ir/builtin_dialect.cc +++ b/paddle/ir/builtin_dialect.cc @@ -44,7 +44,10 @@ void BuiltinDialect::initialize() { ir::Int64_tAttribute, ir::ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 2f5be8a8c2683635665f3f24c8513908da5b6b02..63bfc2196dca35ac8d4a6cab63dfeee138b467c8 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -13,7 +13,10 @@ // limitations under the License. #include "paddle/ir/builtin_op.h" + #include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/phi/core/enforce.h" namespace ir { const char *GetParameterOp::attributes_name[attributes_num] = { @@ -58,4 +61,103 @@ void SetParameterOp::verify(const std::vector &inputs, } } +void CombineOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + // outputs[0].type == Vector + PADDLE_ENFORCE(outputs[0].isa(), + phi::errors::PreconditionNotMet( + "The type %s of outputs[0] must be equal to VectorType.", + outputs[0])); + ir::VectorType output_type = outputs[0].dyn_cast(); + // inputs.size() == outputs[0].size() + PADDLE_ENFORCE_EQ( + output_type.size(), + inputs.size(), + phi::errors::PreconditionNotMet( + "The size %d of outputs[0] must be equal to size %d of inputs.", + output_type.size(), + inputs.size())); + + // forall i in inputs.size(): inputs[i].type == outputs[0][i].type + for (size_t i = 0; i < inputs.size(); i++) { + PADDLE_ENFORCE_EQ( + output_type[i], + inputs[i].type(), + phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be " + "equal to type %s of inputs[%d].", + output_type[i], + i, + inputs[i].type(), + i)); + } +} + +const char *SliceOp::attributes_name[attributes_num] = {"index"}; +void SliceOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + // inputs.size() == 1 + PADDLE_ENFORCE_EQ( + inputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", inputs.size())); + + // inputs[0].type == Vector + PADDLE_ENFORCE(inputs[0].type().isa(), + phi::errors::PreconditionNotMet( + "The type %s of inputs[0] must be equal to VectorType.", + inputs[0].type())); + ir::VectorType input_type = inputs[0].type().dyn_cast(); + + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + + // attributes contains index: Int32 + PADDLE_ENFORCE_NE( + attributes.count("index"), + 0, + phi::errors::PreconditionNotMet("The attributes must contains index.")); + const ir::Attribute &attr = attributes.at("index"); + PADDLE_ENFORCE( + attr.isa(), + phi::errors::PreconditionNotMet("The attribute index must be INT32.")); + auto index = attr.dyn_cast().data(); + + // index >= 0 and < inputs[0].size() + PADDLE_ENFORCE_GE( + index, + 0, + phi::errors::PreconditionNotMet( + "The index %d must be greater or equal than 0.", index)); + PADDLE_ENFORCE_LT( + index, + input_type.size(), + phi::errors::PreconditionNotMet( + "The index %d must be less or equal than size %d of inputs[0].", + index, + input_type.size())); + + // inputs[index].type == outputs[0].type + PADDLE_ENFORCE_EQ( + input_type[index], + outputs[0], + phi::errors::PreconditionNotMet( + "The type %s of inputs[%d] must be equal to type %s of outputs[0].", + input_type[index], + index, + outputs[0])); +} + } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index c1953136f8c93c6d7fb4d669fd60244e9c5887ac..d1d2c20b5725dbe04818cd6ff493d51b9c97112a 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -33,7 +33,7 @@ class GetParameterOp : public ir::Op { }; /// -/// \brief GetParameterOp: SetParameterOp(OpOperand, {StrAttribute, +/// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// class SetParameterOp : public ir::Op { @@ -47,4 +47,38 @@ class SetParameterOp : public ir::Op { const ir::AttributeMap &attributes); }; +/// +/// \brief CombineOp: CombineOp(OpOperand) +/// +class CombineOp : public ir::Op { + public: + using Op::Op; + + static const char *name() { return "builtin.combine"; } + + static constexpr uint32_t attributes_num = 0; + + static constexpr const char **attributes_name = nullptr; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); +}; + +/// +/// \brief SliceOp: SliceOp(OpOperand) +/// +class SliceOp : public ir::Op { + public: + using Op::Op; + + static const char *name() { return "builtin.slice"; } + + static constexpr uint32_t attributes_num = 1; + + static const char *attributes_name[attributes_num]; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); +}; + } // namespace ir diff --git a/paddle/ir/printer.cc b/paddle/ir/printer.cc index b4d3acb930a34205bf78403c3e728252ca89b4bb..421dedabb3a10088ee52f5508789eff1a1a0155a 100644 --- a/paddle/ir/printer.cc +++ b/paddle/ir/printer.cc @@ -28,6 +28,21 @@ namespace ir { namespace { constexpr char newline[] = "\n"; + +template +void PrintInterleave(ForwardIterator begin, + ForwardIterator end, + UnaryFunctor print_func, + NullFunctor between_func) { + if (begin == end) return; + print_func(*begin); + begin++; + for (; begin != end; begin++) { + between_func(); + print_func(*begin); + } +} + } // namespace class Printer { @@ -47,6 +62,15 @@ class Printer { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "vec<"; + auto inner_types = type.dyn_cast().data(); + PrintInterleave( + inner_types.begin(), + inner_types.end(), + [this](ir::Type v) { this->PrintType(v); }, + [this]() { this->os << ","; }); + os << ">"; } else { auto& dialect = type.dialect(); dialect.PrintType(type, os); @@ -77,22 +101,6 @@ class ProgramPrinter : public Printer { } } - template - void PrintInterleave(ForwardIterator begin, - ForwardIterator end, - UnaryFunctor print_func, - NullFunctor between_func) { - if (begin == end) return; - print_func(*begin); - begin++; - for (; begin != end; begin++) { - between_func(); - print_func(*begin); - } - } - void PrintValue(ir::Value v) { const void* key = static_cast(v.impl()); auto ret = aliases.find(key); diff --git a/paddle/ir/type.cc b/paddle/ir/type.cc index bde3194f8fd4146e5c98714c4493b0a22819503d..e9c24672e5b5e4bed0015078f9c1a7522eecf558 100644 --- a/paddle/ir/type.cc +++ b/paddle/ir/type.cc @@ -16,6 +16,11 @@ #include "paddle/ir/dialect.h" namespace ir { -IrContext *Type::ir_context() const { return dialect().ir_context(); } +IrContext* Type::ir_context() const { return dialect().ir_context(); } + +std::ostream& operator<<(std::ostream& os, Type type) { + type.print(os); + return os; +} } // namespace ir diff --git a/paddle/ir/type.h b/paddle/ir/type.h index fce17db82ebf5a73152da03533521557ab7c20af..89d153c089476e8373427c0c4cf1c5c3ecf6edc2 100644 --- a/paddle/ir/type.h +++ b/paddle/ir/type.h @@ -89,6 +89,8 @@ class Type { const Storage *storage_{nullptr}; }; +std::ostream &operator<<(std::ostream &os, Type type); + } // namespace ir namespace std { diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index c430d9b320b406ee615900a3893aeb893d98a68c..7c6c9acaf52c66cd766255611eeaefa701f2224b 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -211,3 +211,62 @@ TEST(program_test, program) { EXPECT_EQ(ops.size() == 4, true); EXPECT_EQ(program.parameters_num() == 3, true); } + +TEST(program_test, slice_combine_test) { + // (1) Init environment. + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + // (2) Create an empty program object + ir::Program program; + // ir::Program *program = new ir::Program(); + EXPECT_EQ(program.ops().size() == 0, true); + + // (3) Create a float32 DenseTensor Parameter and save into Program + ir::Type fp32_dtype = ir::Float32Type::get(ctx); + + // (4) Def a = GetParameterOp("a") + std::string op1_name = ir::GetParameterOp::name(); + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); + std::unordered_map op1_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; + ir::Operation *op1 = + ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info); + program.InsertOp(op1); + + // (5) Def b = GetParameterOp("b") + std::string op2_name = std::string(ir::GetParameterOp::name()); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); + std::unordered_map op2_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; + ir::Operation *op2 = + ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); + program.InsertOp(op2); + + // (6) Def combine_op = CombineOp("a", "b") + std::string combine_op_name = std::string(ir::CombineOp::name()); + ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); + ir::Type output_type = + ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); + ir::Operation *combine_op = ir::Operation::create( + {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, + {output_type}, + {}, + combine_op_info); + program.InsertOp(combine_op); + + // (7) Def slice_op = SliceOp(combine_op, 0) + std::string slice_op_name = std::string(ir::SliceOp::name()); + ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); + ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); + ir::Operation *slice_op = + ir::Operation::create({combine_op->GetResultByIndex(0)}, + {fp32_dtype}, + {{"index", index_attr}}, + slice_op_info); + program.InsertOp(slice_op); + + // (8) Traverse Program + std::list ops = program.ops(); + EXPECT_EQ(ops.size() == 4, true); +} diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index 38cb983ca6dd07f6cf5f5839f541cbb5f7bb4bac..5ae7b0b9b31e7ef8446956145ecfb39c5606a7f0 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -59,6 +59,7 @@ TEST(PaddleDialectTest, Translator) { // auto program = paddle::TranslateLegacyProgramToProgram(p); // std::list ops = program->ops(); - // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num()); - // VLOG(0) << *program << std::endl; + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + + // 20); std::cout << *program << std::endl; }