diff --git a/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc b/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc index 33b6d0980b712db3b0b94ecb654a4320fa35e9ad..7c4aab06a1d2b3fadc76b46c7e95cea7818c56e2 100644 --- a/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc +++ b/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc @@ -70,11 +70,13 @@ void RenameAndGetOutputs( std::unordered_map same_hierarchy_conv2d_num_map; - auto set_var_shape = [&](const std::string &arg_value) { - auto arg_var_node = graph_var_map.find(arg_value); + auto add_block_var = [&](const std::string &graph_arg, + const std::string &block_arg) { + auto arg_var_node = graph_var_map.find(graph_arg); PADDLE_ENFORCE(arg_var_node != graph_var_map.end()); - auto *var_t = block_desc->Var(arg_value); + auto *var_t = block_desc->Var(block_arg); var_t->SetShape(arg_var_node->second->Var()->GetShape()); + var_t->SetDataType(arg_var_node->second->Var()->GetDataType()); }; for (size_t index = 0; index < block_desc->OpSize(); ++index) { @@ -99,15 +101,16 @@ void RenameAndGetOutputs( const std::string arg_value_with_id = arg_value + std::to_string(var2id[arg_value]); - bool is_var_in_graph = graph_var_map.count(arg_value); - if (input_names_with_id.count(arg_value_with_id)) { replaced_names.push_back(arg_value); + if (graph_var_map.count(arg_value)) { + add_block_var(arg_value, arg_value); + } } else { replaced_names.push_back(arg_value_with_id); - } - if (is_var_in_graph) { - set_var_shape(arg_value); + if (graph_var_map.count(arg_value)) { + add_block_var(arg_value, arg_value_with_id); + } } } in_var->clear_arguments(); @@ -147,11 +150,9 @@ void RenameAndGetOutputs( const std::string arg_value_with_id = arg_value + std::to_string(var2id[arg_value]); - bool is_var_in_graph = graph_var_map.count(arg_value); - if (is_var_in_graph) { - set_var_shape(arg_value); + if (graph_var_map.count(arg_value)) { + add_block_var(arg_value, arg_value_with_id); } - if (output_names_with_id->count(arg_value_with_id)) { (*output_name_map)[arg_value] = arg_value_with_id; }