From 94a57f1d83035ef8eca9016b2fbfebf655830f93 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Wed, 19 Sep 2018 07:08:22 +0000 Subject: [PATCH] add trt config to arguments --- .../analysis/data_flow_graph_to_fluid_pass.cc | 8 ++++++-- .../fluid/inference/analysis/subgraph_splitter.cc | 3 ++- paddle/fluid/inference/analysis/subgraph_splitter.h | 9 +++++++-- .../inference/analysis/tensorrt_subgraph_pass.cc | 2 +- .../inference/analysis/tensorrt_subgraph_pass.h | 6 +++++- .../inference/api/api_tensorrt_subgraph_engine.cc | 12 ++++++++++-- paddle/fluid/inference/api/paddle_inference_api.h | 9 +++++++++ .../fluid/inference/tests/api/trt_models_tester.cc | 2 +- paddle/fluid/operators/tensorrt_engine_op.cc | 4 ++-- paddle/fluid/operators/tensorrt_engine_op.h | 13 ++++++------- 10 files changed, 49 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 5652940ec6d..9913439604f 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -97,8 +97,9 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { } } -void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, +void CreateTrtEngineOp(Node *node, Argument *argument, framework::proto::BlockDesc *block) { + const DataFlowGraph &graph = *(argument->main_dfg); static int counter{0}; PADDLE_ENFORCE(node->IsFunctionBlock()); framework::OpDesc desc; @@ -204,7 +205,10 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc"); // Set attrs + SetAttr(desc.Proto(), "subgraph", block->SerializeAsString()); + SetAttr(desc.Proto(), "max_batch_size", argument->Get("max_batch_size")); + SetAttr(desc.Proto(), "workspace_size", argument->Get("workspace_size")); SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); SetAttr(desc.Proto(), "output_name_mapping", output_mapping); @@ -248,7 +252,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { *block_desc.Proto()->mutable_vars() = argument_->origin_program_desc->blocks(0).vars(); PADDLE_ENFORCE(!block_desc.Proto()->vars().empty()); - CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto()); + CreateTrtEngineOp(node, argument_, block_desc.Proto()); auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto *op = main_block->add_ops(); PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block"); diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.cc b/paddle/fluid/inference/analysis/subgraph_splitter.cc index efc14439412..e0a7a1969cb 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter.cc @@ -309,7 +309,8 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); } void SubGraphFuse::ReplaceNodesWithSubGraphs() { auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)(); for (auto &subgraph : subgraphs) { - if (subgraph.size() <= 3) continue; + if (subgraph.size() <= argument_->Get("minimun_subgraph_size")) + continue; std::unordered_set subgraph_uniq(subgraph.begin(), subgraph.end()); // replace this sub-graph with the first node. Two steps: 1. Create a Block // Node that contains this subgraph 2. Mark the nodes inside the sub-graph diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.h b/paddle/fluid/inference/analysis/subgraph_splitter.h index a31afbe6933..76e4fda0249 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter.h +++ b/paddle/fluid/inference/analysis/subgraph_splitter.h @@ -20,6 +20,7 @@ limitations under the License. */ #include +#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/node.h" @@ -63,8 +64,11 @@ class SubGraphFuse { public: using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller; - SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller) - : graph_(graph), node_inside_subgraph_teller_(teller) {} + SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller, + Argument *argument) + : graph_(graph), + node_inside_subgraph_teller_(teller), + argument_(argument) {} // The main method which run all the logic. void operator()(); @@ -76,6 +80,7 @@ class SubGraphFuse { private: DataFlowGraph *graph_; NodeInsideSubgraphTeller node_inside_subgraph_teller_; + Argument *argument_; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc index faf876de6d6..cc1746ecb34 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc @@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass( : node_inside_subgraph_teller_(teller) {} void TensorRTSubGraphPass::Run(DataFlowGraph *graph) { - SubGraphFuse(graph, node_inside_subgraph_teller_)(); + SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)(); VLOG(4) << "debug info " << graph->HumanReadableInfo(false /*show_values*/, true /*show_functions*/); diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h index 219e3f5470f..3545da9109d 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h @@ -33,7 +33,10 @@ class TensorRTSubGraphPass : public DataFlowGraphPass { explicit TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller); - bool Initialize(Argument* argument) override { return true; } + bool Initialize(Argument* argument) override { + argument_ = argument; + return true; + } // This class get a sub-graph as input and determine whether to transform this // sub-graph into TensorRT. @@ -46,6 +49,7 @@ class TensorRTSubGraphPass : public DataFlowGraphPass { private: NodeInsideSubgraphTeller node_inside_subgraph_teller_; + Argument* argument_; }; } // namespace analysis diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index d9d6e139b87..945b85b7f82 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -34,8 +34,6 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { bool Init(const std::shared_ptr& parent_scope) { FLAGS_IA_enable_tensorrt_subgraph_engine = true; VLOG(3) << "Predictor::init()"; - FLAGS_tensorrt_max_batch_size = config_.max_batch_size; - FLAGS_tensorrt_workspace_size = config_.workspace_size; if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); } else { @@ -91,6 +89,16 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { void OptimizeInferenceProgram() { // Analyze inference_program Argument argument; + + int* minimum_subgraph_size = new int(config_.minimun_subgraph_size); + int* max_batch_size = new int(config_.max_batch_size); + int* workspace_size = new int(config_.workspace_size); + std::string* precision_mode = new std::string(config_.precision_mode); + argument.Set("minimun_subgraph_size", minimum_subgraph_size); + argument.Set("max_batch_size", max_batch_size); + argument.Set("workspace_size", workspace_size); + argument.Set("precision_mode", precision_mode); + if (!config_.model_dir.empty()) { argument.fluid_model_dir.reset(new std::string(config_.model_dir)); } else { diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 55a07ca705f..084da823e07 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -150,6 +150,15 @@ struct TensorRTConfig : public NativeConfig { // For workspace_size, refer it from here: // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting int workspace_size{1 << 30}; + // We transform the Ops that can be converted into TRT layer in the model, + // and aggregate these Ops into subgraphs for TRT execution. + // We set this variable to control the minimum number of nodes in the + // subgraph, 3 as + // default value. + int minimun_subgraph_size = 3; + // Reserved configuration + // We just support "FP32" now, "FP16" and "INT8" will be supported. + std::string precision_mode = "FP32"; }; // NOTE WIP, not stable yet. diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index 79ee9b23a94..966f21c437f 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -99,7 +99,7 @@ TEST(trt_models_test, main) { std::vector infer_models = {"mobilenet", "resnet50", "resnext50"}; for (auto &model_dir : infer_models) { - CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + model_dir); + CompareTensorRTWithFluid(5, FLAGS_dirname + "/" + model_dir); } } } // namespace paddle diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 1048d301714..b34fa55210c 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -22,8 +22,6 @@ namespace paddle { DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT"); -DEFINE_int32(tensorrt_max_batch_size, 1, "TensorRT maximum batch size"); -DEFINE_int32(tensorrt_workspace_size, 16 << 20, "TensorRT workspace size"); namespace operators { @@ -34,6 +32,8 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Ys", "A list of outputs").AsDuplicable(); AddAttr("subgraph", "the subgraph."); AddAttr("engine_uniq_key", "unique key for the TRT engine."); + AddAttr("max_batch_size", "the maximum batch size."); + AddAttr("workspace_size", "the maximum batch size."); AddComment("TensorRT engine operator."); } }; diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 79e75ea9a03..d4ba0f9c33c 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -28,8 +28,6 @@ namespace paddle { DECLARE_int32(tensorrt_engine_batch_size); -DECLARE_int32(tensorrt_max_batch_size); -DECLARE_int32(tensorrt_workspace_size); namespace operators { @@ -92,14 +90,14 @@ class TensorRTEngineKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto engine_name = context.Attr("engine_uniq_key"); + int max_batch_size = context.Attr("max_batch_size"); if (!Singleton::Global().HasEngine(engine_name)) { Prepare(context); } auto* engine = Singleton::Global().Get(engine_name); auto input_names = context.op().Inputs("Xs"); PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); - PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, - FLAGS_tensorrt_max_batch_size); + PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, max_batch_size); std::vector output_maps = context.Attr>("output_name_mapping"); @@ -173,8 +171,9 @@ class TensorRTEngineKernel : public framework::OpKernel { // Get the ProgramDesc and pass to convert. framework::proto::BlockDesc block_desc; block_desc.ParseFromString(context.Attr("subgraph")); - int max_batch = FLAGS_tensorrt_max_batch_size; - auto max_workspace = FLAGS_tensorrt_workspace_size; + int max_batch_size = context.Attr("max_batch_size"); + int workspace_size = context.Attr("workspace_size"); + auto params = context.Attr>("parameters"); std::unordered_set parameters; for (const auto& param : params) { @@ -186,7 +185,7 @@ class TensorRTEngineKernel : public framework::OpKernel { // TODO(Superjomn) replace this with a different stream auto* engine = Singleton::Global().Create( - max_batch, max_workspace, nullptr /*engine hold its own stream*/, + max_batch_size, workspace_size, nullptr /*engine hold its own stream*/, context.Attr("engine_uniq_key"), boost::get(context.GetPlace()).device); -- GitLab