未验证 提交 1681edc8 编写于 作者: K kangguangli 提交者: GitHub

[IR] Adapt startup program (#54452)

* adapt_startup_program

* refactor program translator

* polish
上级 8a1af374
...@@ -70,8 +70,10 @@ def OpNameNormalizerInitialization( ...@@ -70,8 +70,10 @@ def OpNameNormalizerInitialization(
def insert_new_mutable_attributes( def insert_new_mutable_attributes(
op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]] op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
): ):
op_mutable_attribues[op_name] = set() if op_name not in op_mutable_attribues:
op_mutable_attribute_infos[op_name] = {} op_mutable_attribues[op_name] = set()
if op_name not in op_mutable_attribute_infos:
op_mutable_attribute_infos[op_name] = {}
for ( for (
attribute_name, attribute_name,
mutable_attribute_info, mutable_attribute_info,
...@@ -116,6 +118,14 @@ def OpNameNormalizerInitialization( ...@@ -116,6 +118,14 @@ def OpNameNormalizerInitialization(
if "scalar" in op_compat_item: if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# special op mappings # special op mappings
op_name_mappings["fetch_v2"] = "fetch" op_name_mappings["fetch_v2"] = "fetch"
......
...@@ -19,13 +19,16 @@ ...@@ -19,13 +19,16 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/ir_adaptor/translator/op_translator.h" #include "paddle/fluid/ir_adaptor/translator/op_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace paddle { namespace paddle {
...@@ -33,6 +36,7 @@ namespace translator { ...@@ -33,6 +36,7 @@ namespace translator {
using ProgramDesc = ::paddle::framework::ProgramDesc; using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc; using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program) ir::Program* program)
...@@ -55,38 +59,77 @@ void ProgramTranslator::Translate() { ...@@ -55,38 +59,77 @@ void ProgramTranslator::Translate() {
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx); const BlockDesc& block = legacy_program->Block(block_idx);
ExtractParameterFromSingleBlock(block); GetParameterForSingleBlock(block);
} }
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx); const BlockDesc& block = legacy_program->Block(block_idx);
InsertOperationToSingleBlock(block); InsertOperationToSingleBlock(block);
} }
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
SetParameterFromSingleBlock(block);
}
} }
void ProgramTranslator::ExtractParameterFromSingleBlock( inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
const BlockDesc& block) { const VarDesc* var) {
auto& type_translator = TypeTranslator::instance(); auto& type_translator = TypeTranslator::instance();
std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::Create(
{}, op_attribute_map, {translated_var_type}, op_info);
return operation;
}
inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx,
ir::OpResult defining_op_result,
const VarDesc* var) {
std::string set_parameter_op_name(ir::SetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Operation* operation = ir::Operation::Create(
{defining_op_result}, op_attribute_map, {}, op_info);
return operation;
}
void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
if (!var->Persistable()) continue; if (!var->Persistable()) continue;
if (param_map.count(var->Name()) != 0) continue; if (param_map.count(var->Name()) != 0) continue;
if (no_cast_var_names.count(var->Name()) != 0) continue; if (no_cast_var_names.count(var->Name()) != 0) continue;
std::string get_parameter_op_name(ir::GetParameterOp::name()); parameter_name_mappings[var->Name()] = var;
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); }
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, for (auto op_desc : block.AllOps()) {
}; for (const auto& n : op_desc->Inputs()) {
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); const auto& input_var_names = n.second;
ir::Operation* operation = ir::Operation::Create( for (const auto& var_name : input_var_names) {
{}, op_attribute_map, {translated_var_type}, op_info); bool need_get_parameter_op = (parameter_name_mappings.find(var_name) !=
program->block()->push_back(operation); parameter_name_mappings.end());
param_map[var->Name()] = need_get_parameter_op &= (parameter_visited.count(var_name) == 0);
VariableDefiningInfo(operation->GetResultByIndex(0)); if (need_get_parameter_op) {
VLOG(10) << "[op translated][get parameter]" << operation; ir::Operation* op =
InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]);
program->SetParameter(var->Name(), nullptr); program->block()->push_back(op);
param_map[var_name] = VariableDefiningInfo(op->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << op;
program->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name);
}
}
}
} }
} }
...@@ -99,5 +142,40 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { ...@@ -99,5 +142,40 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
} }
} }
void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
const auto& ops = block.AllOps();
for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) {
for (const auto& n : (*op_desc)->Outputs()) {
const auto& output_var_names = n.second;
for (const auto& var_name : output_var_names) {
bool need_set_parameter_op = (parameter_name_mappings.find(var_name) !=
parameter_name_mappings.end());
need_set_parameter_op &= (parameter_visited.count(var_name) == 0);
if (need_set_parameter_op) {
ir::OpResult defining_op_result = param_map[var_name].value;
ir::Operation* op = InsertSetParamaterOp(
ctx, defining_op_result, parameter_name_mappings[var_name]);
ir::Block* block = program->block();
ir::Block::iterator insert_pos = std::find(
block->begin(), block->end(), defining_op_result.owner());
IR_ENFORCE(
insert_pos != block->end(),
"Parameter %s must have corresponding its defining operation",
var_name);
insert_pos++;
block->insert(insert_pos, op);
VLOG(10) << "[op translated][set parameter]" << op;
program->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name);
}
}
}
}
}
} // namespace translator } // namespace translator
} // namespace paddle } // namespace paddle
...@@ -50,6 +50,7 @@ using TranslationContext = ...@@ -50,6 +50,7 @@ using TranslationContext =
class ProgramTranslator { class ProgramTranslator {
using ProgramDesc = ::paddle::framework::ProgramDesc; using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc; using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;
public: public:
explicit ProgramTranslator(const ProgramDesc* legacy_program, explicit ProgramTranslator(const ProgramDesc* legacy_program,
...@@ -58,10 +59,13 @@ class ProgramTranslator { ...@@ -58,10 +59,13 @@ class ProgramTranslator {
void Translate(); void Translate();
private: private:
const ProgramDesc* legacy_program; const ProgramDesc* legacy_program; // not owned
ir::Program* program; ir::Program* program; // not owned
ir::IrContext* ctx; // not owned
TranslationContext param_map; TranslationContext param_map;
ir::IrContext* ctx; std::unordered_map<std::string, VarDesc*> parameter_name_mappings;
std::unordered_set<std::string> parameter_visited;
/// In the legacy program desc, there are two special named varibales: /// In the legacy program desc, there are two special named varibales:
/// 1. "feed", the input variable of feed op /// 1. "feed", the input variable of feed op
...@@ -71,8 +75,9 @@ class ProgramTranslator { ...@@ -71,8 +75,9 @@ class ProgramTranslator {
/// `ExtractParameterFromSingleBlock` /// `ExtractParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names; static const std::unordered_set<std::string> no_cast_var_names;
void ExtractParameterFromSingleBlock(const BlockDesc& block); void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
}; };
} // namespace translator } // namespace translator
......
...@@ -52,11 +52,16 @@ cc_test_old( ...@@ -52,11 +52,16 @@ cc_test_old(
gtest) gtest)
file( file(
DOWNLOAD DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_main.prog
https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog ${CMAKE_CURRENT_BINARY_DIR}/resnet50_main.prog
${CMAKE_CURRENT_BINARY_DIR}/restnet50_main.prog
EXPECTED_MD5 b64c0ad3c96d99fc37d12094623ce1ad) EXPECTED_MD5 b64c0ad3c96d99fc37d12094623ce1ad)
file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_startup.prog
${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog
EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50)
cc_test_old( cc_test_old(
program_translator_test program_translator_test
SRCS SRCS
......
...@@ -46,8 +46,8 @@ ProgramDesc load_from_file(const std::string &file_name) { ...@@ -46,8 +46,8 @@ ProgramDesc load_from_file(const std::string &file_name) {
return ProgramDesc(buffer); return ProgramDesc(buffer);
} }
TEST(PaddleDialectTest, Translator) { TEST(PaddleDialectTest, MainProgram) {
auto p = load_from_file("restnet50_main.prog"); auto p = load_from_file("resnet50_main.prog");
EXPECT_EQ(p.Size(), 1u); EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
...@@ -63,3 +63,21 @@ TEST(PaddleDialectTest, Translator) { ...@@ -63,3 +63,21 @@ TEST(PaddleDialectTest, Translator) {
program->Print(std::cout); program->Print(std::cout);
} }
TEST(PaddleDialectTest, StartupProgram) {
auto p = load_from_file("resnet50_startup.prog");
EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op +
// consant_op_for_uniform
// + consant_op for guassian
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53);
program->Print(std::cout);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册