diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 8eac84343dca4b08e4dbdae9720a2ff41159722c..aaf7ca67011fb7bd4a74f6d8f57317594c528ca4 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -87,7 +87,7 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { } void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, - framework::proto::BlockDesc &block) { + framework::proto::BlockDesc *block) { static int counter{0}; PADDLE_ENFORCE(node->IsFunctionBlock()); framework::OpDesc desc; @@ -112,11 +112,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, desc.SetType("tensorrt_engine"); std::unordered_map output_name_map; - auto subgraph_nodes = func->subgraph; - for (int index = 0; index < block.ops_size(); index++) { - framework::proto::OpDesc *op = block.mutable_ops(index); - // auto &op = block.mutable_ops(index); + // The following procedure is used to rename all the intermediate + // variables and the output variables of the subgraph. + // Why we do this? + // During the transition from fluid OP to tensorrt OP, we map + // the input and output Tensor(fluid data structure) of fluid OP + // to the correspondin ITensor (trt data structure) through the + // Tensor name. When we set up ITensor for an variable, we must + // ensure that it has not been set before. + // If there is variable in the fluid graph, which is not only the + // input of a OP, but also the output of a Op, there will be problems. + // So we have to rename the variable in the subgraph to make sure + // it is either an OP's input or an OP's output. + + auto subgraph_nodes = func->subgraph; + for (int index = 0; index < block->ops_size(); index++) { + framework::proto::OpDesc *op = block->mutable_ops(index); auto correspond_node = subgraph_nodes[index]; PADDLE_ENFORCE_EQ(correspond_node->name(), op->type()); @@ -124,10 +136,9 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, for (auto *in_var : correspond_node->inlinks) { var2id[in_var->name()] = in_var->id(); } - // TODO(zhaolong): add comments + // rename for the input variables of op inside subgraph for (int i = 0; i < op->inputs_size(); i++) { framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i); - // auto &in_var = op->mutable_inputs(i); std::vector replaced_names; for (int k = 0; k < in_var->arguments_size(); k++) { std::string arg_value = in_var->arguments(k); @@ -148,6 +159,7 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, var2id[out_var->name()] = out_var->id(); } + // rename for the output variables of op inside subgraph for (int i = 0; i < op->outputs_size(); i++) { framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i); std::vector replaced_names; @@ -165,15 +177,18 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, } } } + // When tensorrt engine runs at the end of the operation, + // output_mapping help us copy the data from the renamed ITensor + // to Tensor. std::vector output_mapping; for (auto name : output_names) { PADDLE_ENFORCE(output_name_map.count(name) != 0); output_mapping.push_back(output_name_map[name]); } - PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc"); + PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc"); // Set attrs - SetAttr(desc.Proto(), "subgraph", block.SerializeAsString()); + SetAttr(desc.Proto(), "subgraph", block->SerializeAsString()); SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize); SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size); @@ -220,7 +235,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { *block_desc.Proto()->mutable_vars() = argument_->origin_program_desc->blocks(0).vars(); PADDLE_ENFORCE(!block_desc.Proto()->vars().empty()); - CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto()); + CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto()); auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto *op = main_block->add_ops(); PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");