From 1681edc8d0a095b7fbe0fc9e1f9d118af4985b23 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 9 Jun 2023 12:56:46 +0800 Subject: [PATCH] [IR] Adapt startup program (#54452) * adapt_startup_program * refactor program translator * polish --- .../ir_adaptor/translator/op_compat_gen.py | 14 ++- .../translator/program_translator.cc | 112 +++++++++++++++--- .../translator/program_translator.h | 13 +- test/cpp/ir/core/CMakeLists.txt | 11 +- test/cpp/ir/core/program_translator_test.cc | 22 +++- 5 files changed, 144 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 5a852754aed..5fbe508ce80 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -70,8 +70,10 @@ def OpNameNormalizerInitialization( def insert_new_mutable_attributes( op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]] ): - op_mutable_attribues[op_name] = set() - op_mutable_attribute_infos[op_name] = {} + if op_name not in op_mutable_attribues: + op_mutable_attribues[op_name] = set() + if op_name not in op_mutable_attribute_infos: + op_mutable_attribute_infos[op_name] = {} for ( attribute_name, mutable_attribute_info, @@ -116,6 +118,14 @@ def OpNameNormalizerInitialization( if "scalar" in op_compat_item: insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) + if "int_array" in op_compat_item: + insert_new_mutable_attributes( + legacy_name, op_compat_item["int_array"] + ) + + if "scalar" in op_compat_item: + insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) + # special op mappings op_name_mappings["fetch_v2"] = "fetch" diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index b637a95e4dd..a91846c033a 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -19,13 +19,16 @@ #include "glog/logging.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/ir_adaptor/translator/op_translator.h" #include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/ir/core/attribute.h" #include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" +#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/operation.h" +#include "paddle/ir/core/value.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -33,6 +36,7 @@ namespace translator { using ProgramDesc = ::paddle::framework::ProgramDesc; using BlockDesc = ::paddle::framework::BlockDesc; +using VarDesc = ::paddle::framework::VarDesc; ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ir::Program* program) @@ -55,38 +59,77 @@ void ProgramTranslator::Translate() { for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { const BlockDesc& block = legacy_program->Block(block_idx); - ExtractParameterFromSingleBlock(block); + GetParameterForSingleBlock(block); } for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { const BlockDesc& block = legacy_program->Block(block_idx); InsertOperationToSingleBlock(block); } + + for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { + const BlockDesc& block = legacy_program->Block(block_idx); + SetParameterFromSingleBlock(block); + } } -void ProgramTranslator::ExtractParameterFromSingleBlock( - const BlockDesc& block) { +inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx, + const VarDesc* var) { auto& type_translator = TypeTranslator::instance(); + std::string get_parameter_op_name(ir::GetParameterOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); + std::unordered_map op_attribute_map = { + {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, + }; + + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + ir::Operation* operation = ir::Operation::Create( + {}, op_attribute_map, {translated_var_type}, op_info); + return operation; +} + +inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx, + ir::OpResult defining_op_result, + const VarDesc* var) { + std::string set_parameter_op_name(ir::SetParameterOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name); + std::unordered_map op_attribute_map = { + {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, + }; + + ir::Operation* operation = ir::Operation::Create( + {defining_op_result}, op_attribute_map, {}, op_info); + return operation; +} +void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { for (auto& var : block.AllVars()) { if (!var->Persistable()) continue; if (param_map.count(var->Name()) != 0) continue; if (no_cast_var_names.count(var->Name()) != 0) continue; - std::string get_parameter_op_name(ir::GetParameterOp::name()); - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); - std::unordered_map op_attribute_map = { - {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, - }; - ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); - ir::Operation* operation = ir::Operation::Create( - {}, op_attribute_map, {translated_var_type}, op_info); - program->block()->push_back(operation); - param_map[var->Name()] = - VariableDefiningInfo(operation->GetResultByIndex(0)); - VLOG(10) << "[op translated][get parameter]" << operation; - - program->SetParameter(var->Name(), nullptr); + parameter_name_mappings[var->Name()] = var; + } + + for (auto op_desc : block.AllOps()) { + for (const auto& n : op_desc->Inputs()) { + const auto& input_var_names = n.second; + for (const auto& var_name : input_var_names) { + bool need_get_parameter_op = (parameter_name_mappings.find(var_name) != + parameter_name_mappings.end()); + need_get_parameter_op &= (parameter_visited.count(var_name) == 0); + if (need_get_parameter_op) { + ir::Operation* op = + InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]); + program->block()->push_back(op); + param_map[var_name] = VariableDefiningInfo(op->GetResultByIndex(0)); + VLOG(10) << "[op translated][get parameter]" << op; + + program->SetParameter(var_name, nullptr); + parameter_visited.insert(var_name); + } + } + } } } @@ -99,5 +142,40 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { } } +void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { + const auto& ops = block.AllOps(); + for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) { + for (const auto& n : (*op_desc)->Outputs()) { + const auto& output_var_names = n.second; + for (const auto& var_name : output_var_names) { + bool need_set_parameter_op = (parameter_name_mappings.find(var_name) != + parameter_name_mappings.end()); + need_set_parameter_op &= (parameter_visited.count(var_name) == 0); + if (need_set_parameter_op) { + ir::OpResult defining_op_result = param_map[var_name].value; + ir::Operation* op = InsertSetParamaterOp( + ctx, defining_op_result, parameter_name_mappings[var_name]); + + ir::Block* block = program->block(); + ir::Block::iterator insert_pos = std::find( + block->begin(), block->end(), defining_op_result.owner()); + + IR_ENFORCE( + insert_pos != block->end(), + "Parameter %s must have corresponding its defining operation", + var_name); + insert_pos++; + + block->insert(insert_pos, op); + VLOG(10) << "[op translated][set parameter]" << op; + + program->SetParameter(var_name, nullptr); + parameter_visited.insert(var_name); + } + } + } + } +} + } // namespace translator } // namespace paddle diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 3012dfc84c6..cd9b1b4a81f 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -50,6 +50,7 @@ using TranslationContext = class ProgramTranslator { using ProgramDesc = ::paddle::framework::ProgramDesc; using BlockDesc = ::paddle::framework::BlockDesc; + using VarDesc = ::paddle::framework::VarDesc; public: explicit ProgramTranslator(const ProgramDesc* legacy_program, @@ -58,10 +59,13 @@ class ProgramTranslator { void Translate(); private: - const ProgramDesc* legacy_program; - ir::Program* program; + const ProgramDesc* legacy_program; // not owned + ir::Program* program; // not owned + ir::IrContext* ctx; // not owned + TranslationContext param_map; - ir::IrContext* ctx; + std::unordered_map parameter_name_mappings; + std::unordered_set parameter_visited; /// In the legacy program desc, there are two special named varibales: /// 1. "feed", the input variable of feed op @@ -71,8 +75,9 @@ class ProgramTranslator { /// `ExtractParameterFromSingleBlock` static const std::unordered_set no_cast_var_names; - void ExtractParameterFromSingleBlock(const BlockDesc& block); + void GetParameterForSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); + void SetParameterFromSingleBlock(const BlockDesc& block); }; } // namespace translator diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt index f2d119742e1..a7817ffc02b 100644 --- a/test/cpp/ir/core/CMakeLists.txt +++ b/test/cpp/ir/core/CMakeLists.txt @@ -52,11 +52,16 @@ cc_test_old( gtest) file( - DOWNLOAD - https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog - ${CMAKE_CURRENT_BINARY_DIR}/restnet50_main.prog + DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_main.prog + ${CMAKE_CURRENT_BINARY_DIR}/resnet50_main.prog EXPECTED_MD5 b64c0ad3c96d99fc37d12094623ce1ad) +file( + DOWNLOAD + https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_startup.prog + ${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog + EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50) + cc_test_old( program_translator_test SRCS diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index c8824a5f7c8..bac606a5e12 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -46,8 +46,8 @@ ProgramDesc load_from_file(const std::string &file_name) { return ProgramDesc(buffer); } -TEST(PaddleDialectTest, Translator) { - auto p = load_from_file("restnet50_main.prog"); +TEST(PaddleDialectTest, MainProgram) { + auto p = load_from_file("resnet50_main.prog"); EXPECT_EQ(p.Size(), 1u); ir::IrContext *ctx = ir::IrContext::Instance(); @@ -63,3 +63,21 @@ TEST(PaddleDialectTest, Translator) { program->Print(std::cout); } + +TEST(PaddleDialectTest, StartupProgram) { + auto p = load_from_file("resnet50_startup.prog"); + EXPECT_EQ(p.Size(), 1u); + + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + size_t op_size = program->block()->size(); + // ops.size() = op size in BlockDesc + get_parameter_op + + // consant_op_for_uniform + // + consant_op for guassian + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53); + + program->Print(std::cout); +} -- GitLab