提交 3f0b97df 编写于 作者: S Shixiaowei02

update tensorrt subgraph_util test=release/1.4

(cherry picked from commit bddb2cd3)
上级 88770542
......@@ -70,11 +70,13 @@ void RenameAndGetOutputs(
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
same_hierarchy_conv2d_num_map;
auto set_var_shape = [&](const std::string &arg_value) {
auto arg_var_node = graph_var_map.find(arg_value);
auto add_block_var = [&](const std::string &graph_arg,
const std::string &block_arg) {
auto arg_var_node = graph_var_map.find(graph_arg);
PADDLE_ENFORCE(arg_var_node != graph_var_map.end());
auto *var_t = block_desc->Var(arg_value);
auto *var_t = block_desc->Var(block_arg);
var_t->SetShape(arg_var_node->second->Var()->GetShape());
var_t->SetDataType(arg_var_node->second->Var()->GetDataType());
};
for (size_t index = 0; index < block_desc->OpSize(); ++index) {
......@@ -99,15 +101,16 @@ void RenameAndGetOutputs(
const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]);
bool is_var_in_graph = graph_var_map.count(arg_value);
if (input_names_with_id.count(arg_value_with_id)) {
replaced_names.push_back(arg_value);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value);
}
} else {
replaced_names.push_back(arg_value_with_id);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value_with_id);
}
if (is_var_in_graph) {
set_var_shape(arg_value);
}
}
in_var->clear_arguments();
......@@ -147,11 +150,9 @@ void RenameAndGetOutputs(
const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]);
bool is_var_in_graph = graph_var_map.count(arg_value);
if (is_var_in_graph) {
set_var_shape(arg_value);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value_with_id);
}
if (output_names_with_id->count(arg_value_with_id)) {
(*output_name_map)[arg_value] = arg_value_with_id;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册