未验证 提交 45ba9cf0 编写于 作者: A Aurelius84 提交者: GitHub

[IR]Polish ProgramTranslator private member code style (#54470)

* [IR]Polish ProgramTranslator private member code style

* update blog
上级 548fb821
...@@ -40,8 +40,8 @@ using VarDesc = ::paddle::framework::VarDesc; ...@@ -40,8 +40,8 @@ using VarDesc = ::paddle::framework::VarDesc;
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program) ir::Program* program)
: legacy_program(legacy_program), program(program) { : legacy_program_(legacy_program), program_(program) {
ctx = ir::IrContext::Instance(); ctx_ = ir::IrContext::Instance();
} }
const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = { const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
...@@ -51,24 +51,24 @@ const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = { ...@@ -51,24 +51,24 @@ const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
void ProgramTranslator::Translate() { void ProgramTranslator::Translate() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
legacy_program->Size(), legacy_program_->Size(),
1u, 1u,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Not support multi block ProgramDesc translated, now has %d blocks", "Not support multi block ProgramDesc translated, now has %d blocks",
legacy_program->Size())); legacy_program_->Size()));
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx); const BlockDesc& block = legacy_program_->Block(block_idx);
GetParameterForSingleBlock(block); GetParameterForSingleBlock(block);
} }
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx); const BlockDesc& block = legacy_program_->Block(block_idx);
InsertOperationToSingleBlock(block); InsertOperationToSingleBlock(block);
} }
for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx); const BlockDesc& block = legacy_program_->Block(block_idx);
SetParameterFromSingleBlock(block); SetParameterFromSingleBlock(block);
} }
} }
...@@ -105,28 +105,28 @@ inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx, ...@@ -105,28 +105,28 @@ inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx,
void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
if (!var->Persistable()) continue; if (!var->Persistable()) continue;
if (param_map.count(var->Name()) != 0) continue; if (param_map_.count(var->Name()) != 0) continue;
if (no_cast_var_names.count(var->Name()) != 0) continue; if (no_cast_var_names.count(var->Name()) != 0) continue;
parameter_name_mappings[var->Name()] = var; parameter_name_mappings_[var->Name()] = var;
} }
for (auto op_desc : block.AllOps()) { for (auto op_desc : block.AllOps()) {
for (const auto& n : op_desc->Inputs()) { for (const auto& n : op_desc->Inputs()) {
const auto& input_var_names = n.second; const auto& input_var_names = n.second;
for (const auto& var_name : input_var_names) { for (const auto& var_name : input_var_names) {
bool need_get_parameter_op = (parameter_name_mappings.find(var_name) != bool need_get_parameter_op = (parameter_name_mappings_.find(var_name) !=
parameter_name_mappings.end()); parameter_name_mappings_.end());
need_get_parameter_op &= (parameter_visited.count(var_name) == 0); need_get_parameter_op &= (parameter_visited_.count(var_name) == 0);
if (need_get_parameter_op) { if (need_get_parameter_op) {
ir::Operation* op = ir::Operation* op =
InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]); InsertGetParamaterOp(ctx_, parameter_name_mappings_[var_name]);
program->block()->push_back(op); program_->block()->push_back(op);
param_map[var_name] = VariableDefiningInfo(op->result(0)); param_map_[var_name] = VariableDefiningInfo(op->result(0));
VLOG(10) << "[op translated][get parameter]" << op; VLOG(10) << "[op translated][get parameter]" << op;
program->SetParameter(var_name, nullptr); program_->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name); parameter_visited_.insert(var_name);
} }
} }
} }
...@@ -137,7 +137,7 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { ...@@ -137,7 +137,7 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
auto& op_translator = OpTranslator::instance(); auto& op_translator = OpTranslator::instance();
for (auto op : block.AllOps()) { for (auto op : block.AllOps()) {
OpTranslateFn& fn = op_translator[op->Type()]; OpTranslateFn& fn = op_translator[op->Type()];
ir::Operation* operation = fn(ctx, &param_map, program, *op); ir::Operation* operation = fn(ctx_, &param_map_, program_, *op);
VLOG(10) << "[op translated][special]" << operation; VLOG(10) << "[op translated][special]" << operation;
} }
} }
...@@ -148,15 +148,15 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { ...@@ -148,15 +148,15 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
for (const auto& n : (*op_desc)->Outputs()) { for (const auto& n : (*op_desc)->Outputs()) {
const auto& output_var_names = n.second; const auto& output_var_names = n.second;
for (const auto& var_name : output_var_names) { for (const auto& var_name : output_var_names) {
bool need_set_parameter_op = (parameter_name_mappings.find(var_name) != bool need_set_parameter_op = (parameter_name_mappings_.find(var_name) !=
parameter_name_mappings.end()); parameter_name_mappings_.end());
need_set_parameter_op &= (parameter_visited.count(var_name) == 0); need_set_parameter_op &= (parameter_visited_.count(var_name) == 0);
if (need_set_parameter_op) { if (need_set_parameter_op) {
ir::OpResult defining_op_result = param_map[var_name].value; ir::OpResult defining_op_result = param_map_[var_name].value;
ir::Operation* op = InsertSetParamaterOp( ir::Operation* op = InsertSetParamaterOp(
ctx, defining_op_result, parameter_name_mappings[var_name]); ctx_, defining_op_result, parameter_name_mappings_[var_name]);
ir::Block* block = program->block(); ir::Block* block = program_->block();
ir::Block::iterator insert_pos = std::find( ir::Block::iterator insert_pos = std::find(
block->begin(), block->end(), defining_op_result.owner()); block->begin(), block->end(), defining_op_result.owner());
...@@ -169,8 +169,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { ...@@ -169,8 +169,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
block->insert(insert_pos, op); block->insert(insert_pos, op);
VLOG(10) << "[op translated][set parameter]" << op; VLOG(10) << "[op translated][set parameter]" << op;
program->SetParameter(var_name, nullptr); program_->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name); parameter_visited_.insert(var_name);
} }
} }
} }
......
...@@ -59,13 +59,13 @@ class ProgramTranslator { ...@@ -59,13 +59,13 @@ class ProgramTranslator {
void Translate(); void Translate();
private: private:
const ProgramDesc* legacy_program; // not owned const ProgramDesc* legacy_program_; // not owned
ir::Program* program; // not owned ir::Program* program_; // not owned
ir::IrContext* ctx; // not owned ir::IrContext* ctx_; // not owned
TranslationContext param_map; TranslationContext param_map_;
std::unordered_map<std::string, VarDesc*> parameter_name_mappings; std::unordered_map<std::string, VarDesc*> parameter_name_mappings_;
std::unordered_set<std::string> parameter_visited; std::unordered_set<std::string> parameter_visited_;
/// In the legacy program desc, there are two special named varibales: /// In the legacy program desc, there are two special named varibales:
/// 1. "feed", the input variable of feed op /// 1. "feed", the input variable of feed op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册