未验证 提交 1681edc8 编写于 作者: K kangguangli 提交者: GitHub

[IR] Adapt startup program (#54452)

* adapt_startup_program

* refactor program translator

* polish
上级 8a1af374
......@@ -70,8 +70,10 @@ def OpNameNormalizerInitialization(
def insert_new_mutable_attributes(
op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
):
op_mutable_attribues[op_name] = set()
op_mutable_attribute_infos[op_name] = {}
if op_name not in op_mutable_attribues:
op_mutable_attribues[op_name] = set()
if op_name not in op_mutable_attribute_infos:
op_mutable_attribute_infos[op_name] = {}
for (
attribute_name,
mutable_attribute_info,
......@@ -116,6 +118,14 @@ def OpNameNormalizerInitialization(
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# special op mappings
op_name_mappings["fetch_v2"] = "fetch"
......
......@@ -19,13 +19,16 @@
#include "glog/logging.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/ir_adaptor/translator/op_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
......@@ -33,6 +36,7 @@ namespace translator {
using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program)
......@@ -55,38 +59,77 @@ void ProgramTranslator::Translate() {
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
ExtractParameterFromSingleBlock(block);
GetParameterForSingleBlock(block);
}
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
InsertOperationToSingleBlock(block);
}
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
SetParameterFromSingleBlock(block);
}
}
void ProgramTranslator::ExtractParameterFromSingleBlock(
const BlockDesc& block) {
inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
const VarDesc* var) {
auto& type_translator = TypeTranslator::instance();
std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::Create(
{}, op_attribute_map, {translated_var_type}, op_info);
return operation;
}
inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx,
ir::OpResult defining_op_result,
const VarDesc* var) {
std::string set_parameter_op_name(ir::SetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Operation* operation = ir::Operation::Create(
{defining_op_result}, op_attribute_map, {}, op_info);
return operation;
}
void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
for (auto& var : block.AllVars()) {
if (!var->Persistable()) continue;
if (param_map.count(var->Name()) != 0) continue;
if (no_cast_var_names.count(var->Name()) != 0) continue;
std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::Create(
{}, op_attribute_map, {translated_var_type}, op_info);
program->block()->push_back(operation);
param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << operation;
program->SetParameter(var->Name(), nullptr);
parameter_name_mappings[var->Name()] = var;
}
for (auto op_desc : block.AllOps()) {
for (const auto& n : op_desc->Inputs()) {
const auto& input_var_names = n.second;
for (const auto& var_name : input_var_names) {
bool need_get_parameter_op = (parameter_name_mappings.find(var_name) !=
parameter_name_mappings.end());
need_get_parameter_op &= (parameter_visited.count(var_name) == 0);
if (need_get_parameter_op) {
ir::Operation* op =
InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]);
program->block()->push_back(op);
param_map[var_name] = VariableDefiningInfo(op->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << op;
program->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name);
}
}
}
}
}
......@@ -99,5 +142,40 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
}
}
void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
const auto& ops = block.AllOps();
for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) {
for (const auto& n : (*op_desc)->Outputs()) {
const auto& output_var_names = n.second;
for (const auto& var_name : output_var_names) {
bool need_set_parameter_op = (parameter_name_mappings.find(var_name) !=
parameter_name_mappings.end());
need_set_parameter_op &= (parameter_visited.count(var_name) == 0);
if (need_set_parameter_op) {
ir::OpResult defining_op_result = param_map[var_name].value;
ir::Operation* op = InsertSetParamaterOp(
ctx, defining_op_result, parameter_name_mappings[var_name]);
ir::Block* block = program->block();
ir::Block::iterator insert_pos = std::find(
block->begin(), block->end(), defining_op_result.owner());
IR_ENFORCE(
insert_pos != block->end(),
"Parameter %s must have corresponding its defining operation",
var_name);
insert_pos++;
block->insert(insert_pos, op);
VLOG(10) << "[op translated][set parameter]" << op;
program->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name);
}
}
}
}
}
} // namespace translator
} // namespace paddle
......@@ -50,6 +50,7 @@ using TranslationContext =
class ProgramTranslator {
using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;
public:
explicit ProgramTranslator(const ProgramDesc* legacy_program,
......@@ -58,10 +59,13 @@ class ProgramTranslator {
void Translate();
private:
const ProgramDesc* legacy_program;
ir::Program* program;
const ProgramDesc* legacy_program; // not owned
ir::Program* program; // not owned
ir::IrContext* ctx; // not owned
TranslationContext param_map;
ir::IrContext* ctx;
std::unordered_map<std::string, VarDesc*> parameter_name_mappings;
std::unordered_set<std::string> parameter_visited;
/// In the legacy program desc, there are two special named varibales:
/// 1. "feed", the input variable of feed op
......@@ -71,8 +75,9 @@ class ProgramTranslator {
/// `ExtractParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names;
void ExtractParameterFromSingleBlock(const BlockDesc& block);
void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
};
} // namespace translator
......
......@@ -52,11 +52,16 @@ cc_test_old(
gtest)
file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog
${CMAKE_CURRENT_BINARY_DIR}/restnet50_main.prog
DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_main.prog
${CMAKE_CURRENT_BINARY_DIR}/resnet50_main.prog
EXPECTED_MD5 b64c0ad3c96d99fc37d12094623ce1ad)
file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_startup.prog
${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog
EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50)
cc_test_old(
program_translator_test
SRCS
......
......@@ -46,8 +46,8 @@ ProgramDesc load_from_file(const std::string &file_name) {
return ProgramDesc(buffer);
}
TEST(PaddleDialectTest, Translator) {
auto p = load_from_file("restnet50_main.prog");
TEST(PaddleDialectTest, MainProgram) {
auto p = load_from_file("resnet50_main.prog");
EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance();
......@@ -63,3 +63,21 @@ TEST(PaddleDialectTest, Translator) {
program->Print(std::cout);
}
TEST(PaddleDialectTest, StartupProgram) {
auto p = load_from_file("resnet50_startup.prog");
EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op +
// consant_op_for_uniform
// + consant_op for guassian
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53);
program->Print(std::cout);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册