提交 3f0b97df 编写于 作者: S Shixiaowei02

update tensorrt subgraph_util test=release/1.4

(cherry picked from commit bddb2cd3)
上级 88770542
...@@ -70,11 +70,13 @@ void RenameAndGetOutputs( ...@@ -70,11 +70,13 @@ void RenameAndGetOutputs(
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/> std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
same_hierarchy_conv2d_num_map; same_hierarchy_conv2d_num_map;
auto set_var_shape = [&](const std::string &arg_value) { auto add_block_var = [&](const std::string &graph_arg,
auto arg_var_node = graph_var_map.find(arg_value); const std::string &block_arg) {
auto arg_var_node = graph_var_map.find(graph_arg);
PADDLE_ENFORCE(arg_var_node != graph_var_map.end()); PADDLE_ENFORCE(arg_var_node != graph_var_map.end());
auto *var_t = block_desc->Var(arg_value); auto *var_t = block_desc->Var(block_arg);
var_t->SetShape(arg_var_node->second->Var()->GetShape()); var_t->SetShape(arg_var_node->second->Var()->GetShape());
var_t->SetDataType(arg_var_node->second->Var()->GetDataType());
}; };
for (size_t index = 0; index < block_desc->OpSize(); ++index) { for (size_t index = 0; index < block_desc->OpSize(); ++index) {
...@@ -99,15 +101,16 @@ void RenameAndGetOutputs( ...@@ -99,15 +101,16 @@ void RenameAndGetOutputs(
const std::string arg_value_with_id = const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]); arg_value + std::to_string(var2id[arg_value]);
bool is_var_in_graph = graph_var_map.count(arg_value);
if (input_names_with_id.count(arg_value_with_id)) { if (input_names_with_id.count(arg_value_with_id)) {
replaced_names.push_back(arg_value); replaced_names.push_back(arg_value);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value);
}
} else { } else {
replaced_names.push_back(arg_value_with_id); replaced_names.push_back(arg_value_with_id);
} if (graph_var_map.count(arg_value)) {
if (is_var_in_graph) { add_block_var(arg_value, arg_value_with_id);
set_var_shape(arg_value); }
} }
} }
in_var->clear_arguments(); in_var->clear_arguments();
...@@ -147,11 +150,9 @@ void RenameAndGetOutputs( ...@@ -147,11 +150,9 @@ void RenameAndGetOutputs(
const std::string arg_value_with_id = const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]); arg_value + std::to_string(var2id[arg_value]);
bool is_var_in_graph = graph_var_map.count(arg_value); if (graph_var_map.count(arg_value)) {
if (is_var_in_graph) { add_block_var(arg_value, arg_value_with_id);
set_var_shape(arg_value);
} }
if (output_names_with_id->count(arg_value_with_id)) { if (output_names_with_id->count(arg_value_with_id)) {
(*output_name_map)[arg_value] = arg_value_with_id; (*output_name_map)[arg_value] = arg_value_with_id;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册