Unverified commit e7246bb0, authored by 周周周, committed by GitHub

[Paddle Inference] rename vars in subgraph (#56995)

Parent 95983a62
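The change below swaps the bare name + id concatenation used to build a fused subgraph's *_with_id name sets for a new helper, RenameVarBeUnique, which joins the two parts with a "_subgraph_" infix. As a minimal standalone sketch of why a separator helps (the collision scenario is an illustrative assumption, not a case taken from the commit itself):

// Standalone sketch, not Paddle source: why an infix makes the
// (name, id) key unambiguous.
#include <iostream>
#include <string>

// Mirrors the helper added in this commit.
std::string RenameVarBeUnique(std::string original_var_name,
                              std::string var_id) {
  return original_var_name + "_subgraph_" + var_id;
}

int main() {
  // Old scheme: plain concatenation. Distinct (name, id) pairs can
  // produce the same key: "tmp_" + "01" == "tmp_0" + "1".
  std::cout << "tmp_" + std::string("01") << "\n";  // tmp_01
  std::cout << "tmp_0" + std::string("1") << "\n";  // tmp_01 (collision)

  // New scheme: the infix keeps the pairs apart.
  std::cout << RenameVarBeUnique("tmp_", "01") << "\n";  // tmp__subgraph_01
  std::cout << RenameVarBeUnique("tmp_0", "1") << "\n";  // tmp_0_subgraph_1
  return 0;
}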
@@ -166,7 +166,7 @@ void RenameAndGetOutputs(
     for (int k = 0; k < in_var->arguments_size(); k++) {  // all the arguments
       const std::string arg_value = in_var->arguments(k);
       const std::string arg_value_with_id =
-          arg_value + std::to_string(var2id[arg_value]);
+          RenameVarBeUnique(arg_value, std::to_string(var2id[arg_value]));
       if (input_names_with_id.count(arg_value_with_id)) {
         replaced_names.push_back(arg_value);
         if (graph_var_map.count(arg_value)) {
@@ -199,7 +199,8 @@ void RenameAndGetOutputs(
           PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
       if (same_hierarchy_conv2d_num_map[input_var_name] > 0) {
         (*output_names_with_id)
-            .insert(out_var_name + std::to_string(var2id[out_var_name]));
+            .insert(RenameVarBeUnique(out_var_name,
+                                      std::to_string(var2id[out_var_name])));
         (*output_names).insert(out_var_name);
       } else if (filter_shape[2] == 1 && filter_shape[3] == 1 &&
                  strides[0] == 1 && strides[1] == 1 && paddings[0] == 0 &&
@@ -214,7 +215,7 @@ void RenameAndGetOutputs(
     for (int k = 0; k < out_var->arguments_size(); k++) {
       const std::string arg_value = out_var->arguments(k);
       const std::string arg_value_with_id =
-          arg_value + std::to_string(var2id[arg_value]);
+          RenameVarBeUnique(arg_value, std::to_string(var2id[arg_value]));
       if (graph_var_map.count(arg_value)) {
         add_block_var(arg_value, arg_value_with_id);
       }
@@ -231,6 +232,11 @@ void RenameAndGetOutputs(
   }
 }
 
+std::string RenameVarBeUnique(std::string original_var_name,
+                              std::string var_id) {
+  return original_var_name + "_subgraph_" + var_id;
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
@@ -62,6 +62,11 @@ void RenameAndGetOutputs(
     const std::unordered_map<std::string, framework::ir::Node *> &graph_var_map,
     bool trt_and_not_int8 = false);
 
+// When fusing some ops into one subgraph, we need to rename all the vars
+// within the subgraph (excluding its inputs and outputs) to unique names.
+std::string RenameVarBeUnique(std::string original_var_name,
+                              std::string var_id);
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
@@ -327,7 +327,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
   // The node->inputs contains input tensors and parameters.
   for (auto *x : node->inputs) {
     input_names.insert(x->Name());
-    input_names_with_id.insert(x->Name() + std::to_string(x->id()));
+    input_names_with_id.insert(
+        RenameVarBeUnique(x->Name(), std::to_string(x->id())));
     if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
       parameters.push_back(x->Name());
     }
@@ -357,7 +358,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
   // https://github.com/PaddlePaddle/Paddle/pull/53184
   for (auto *n : graph->Nodes()) {
     if (n->IsVar() && input_names.count(n->Name())) {
-      input_names_with_id.insert(n->Name() + std::to_string(n->id()));
+      input_names_with_id.insert(
+          RenameVarBeUnique(n->Name(), std::to_string(n->id())));
     }
   }
@@ -412,7 +414,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
   for (auto *x : node->outputs) {
     output_names.insert(x->Name());
-    output_names_with_id.insert(x->Name() + std::to_string(x->id()));
+    output_names_with_id.insert(
+        RenameVarBeUnique(x->Name(), std::to_string(x->id())));
     origin_name_output_rank[x->Name()] = x->Var()->GetShape().size();
     trt_outputs.insert(x);
     map_origin_outputs_dtype[x->Name()] =
...
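For context, every call site in this diff follows the same pattern: pair a node's name with its graph node id via the helper so the *_with_id sets stay distinct even when several nodes share a name. A hypothetical, self-contained rendering of that loop (node names and ids are invented for illustration; framework::ir::Node is stood in for by a simple pair):

#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

std::string RenameVarBeUnique(std::string original_var_name,
                              std::string var_id) {
  return original_var_name + "_subgraph_" + var_id;
}

int main() {
  // (name, id) pairs standing in for the subgraph's input nodes.
  std::vector<std::pair<std::string, int>> inputs = {
      {"conv2d_out", 12}, {"conv2d_out", 34}, {"scale_in", 7}};

  std::set<std::string> input_names;
  std::set<std::string> input_names_with_id;
  for (const auto &x : inputs) {
    input_names.insert(x.first);
    input_names_with_id.insert(
        RenameVarBeUnique(x.first, std::to_string(x.second)));
  }
  // input_names collapses to 2 entries, but input_names_with_id
  // keeps all 3 nodes distinct.
  for (const auto &s : input_names_with_id) std::cout << s << "\n";
  return 0;
}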