diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index 3b5be7f3ee33c73a9704bafa9f1b736c8a3cd9ea..f90910ac0d0a897ef01d4ca2bd0bca575baf4c40 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -37,12 +37,16 @@ TEST(Analyzer, analysis_without_tensorrt) {
 TEST(Analyzer, analysis_with_tensorrt) {
   FLAGS_IA_enable_tensorrt_subgraph_engine = true;
   Argument argument;
+  argument.Set("minimum_subgraph_size", new int(0));
+  argument.Set("max_batch_size", new int(3));
+  argument.Set("workspace_size", new int(1 << 20));
+  argument.Set("precision_mode", new std::string("FP32"));
   argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
   Analyzer analyser;
   analyser.Run(&argument);
 }
 
-void TestWord2vecPrediction(const std::string &model_path) {
+void TestWord2vecPrediction(const std::string& model_path) {
   NativeConfig config;
   config.model_dir = model_path;
   config.use_gpu = false;
@@ -73,8 +77,8 @@ void TestWord2vecPrediction(const std::string &model_path) {
   // The outputs' buffers are in CPU memory.
   for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
     LOG(INFO) << "data: "
-              << static_cast<float *>(outputs.front().data.data())[i];
-    PADDLE_ENFORCE(static_cast<float *>(outputs.front().data.data())[i],
+              << static_cast<float*>(outputs.front().data.data())[i];
+    PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
                    result[i]);
   }
 }
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
index 5652940ec6d4cc7ba9a1d3a3e65f7dca1690d8c4..cb549f4b50cf56154a951d16b58b022dbad3e990 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
@@ -97,8 +97,10 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
   }
 }
 
-void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
+void CreateTrtEngineOp(Node *node, Argument *argument,
                        framework::proto::BlockDesc *block) {
+  PADDLE_ENFORCE(argument->main_dfg.get());
+  const DataFlowGraph &graph = *(argument->main_dfg);
   static int counter{0};
   PADDLE_ENFORCE(node->IsFunctionBlock());
   framework::OpDesc desc;
@@ -204,7 +206,10 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
   PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
 
   // Set attrs
+  SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
+  SetAttr(desc.Proto(), "max_batch_size", argument->Get<int>("max_batch_size"));
+  SetAttr(desc.Proto(), "workspace_size", argument->Get<int>("workspace_size"));
   SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
   SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
   SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
@@ -248,7 +253,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
   *block_desc.Proto()->mutable_vars() =
       argument_->origin_program_desc->blocks(0).vars();
   PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
-  CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto());
+  CreateTrtEngineOp(node, argument_, block_desc.Proto());
   auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
   auto *op = main_block->add_ops();
   PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.cc b/paddle/fluid/inference/analysis/subgraph_splitter.cc
index b879067d2f2f6294c50e0adb21f9399a7c36698a..526bbbadfe90c3064d7c620cc22e30f7fef99088 100644
--- a/paddle/fluid/inference/analysis/subgraph_splitter.cc
+++ b/paddle/fluid/inference/analysis/subgraph_splitter.cc
@@ -309,6 +309,8 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
 void SubGraphFuse::ReplaceNodesWithSubGraphs() {
   auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
   for (auto &subgraph : subgraphs) {
+    if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size"))
+      continue;
     std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
     // replace this sub-graph with the first node. Two steps: 1. Create a Block
     // Node that contains this subgraph 2. Mark the nodes inside the sub-graph
diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.h b/paddle/fluid/inference/analysis/subgraph_splitter.h
index a31afbe6933da8d3c7a88142cc12d63b98b55796..76e4fda0249e03c617d1b37c079dcd97f21387c1 100644
--- a/paddle/fluid/inference/analysis/subgraph_splitter.h
+++ b/paddle/fluid/inference/analysis/subgraph_splitter.h
@@ -20,6 +20,7 @@ limitations under the License. */
 
 #include <vector>
 
+#include "paddle/fluid/inference/analysis/argument.h"
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
 #include "paddle/fluid/inference/analysis/node.h"
 
@@ -63,8 +64,11 @@ class SubGraphFuse {
  public:
   using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller;
 
-  SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
-      : graph_(graph), node_inside_subgraph_teller_(teller) {}
+  SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller,
+               Argument *argument)
+      : graph_(graph),
+        node_inside_subgraph_teller_(teller),
+        argument_(argument) {}
 
   // The main method which run all the logic.
   void operator()();
 
@@ -76,6 +80,7 @@ class SubGraphFuse {
  private:
   DataFlowGraph *graph_;
   NodeInsideSubgraphTeller node_inside_subgraph_teller_;
+  Argument *argument_;
 };
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
index 531a170512f727d891aa6644ee08a60c25f16876..e1dc89fab5fb76d456b07c316ab1cabe6de23b26 100644
--- a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
+++ b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
@@ -66,10 +66,12 @@ TEST(SubGraphSplitter, Split) {
 TEST(SubGraphSplitter, Fuse) {
   auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
   auto dfg = ProgramDescToDFG(desc);
+  Argument argument;
+  argument.Set("minimum_subgraph_size", new int(3));
 
   size_t count0 = dfg.nodes.size();
 
-  SubGraphFuse fuse(&dfg, teller);
+  SubGraphFuse fuse(&dfg, teller, &argument);
   fuse();
 
   int count1 = 0;
diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc
index faf876de6d65d20cf7a084cd97392cfc8d791a42..cc1746ecb34c983d219693bcec17c8789c38fa9f 100644
--- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc
@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
     : node_inside_subgraph_teller_(teller) {}
 
 void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
-  SubGraphFuse(graph, node_inside_subgraph_teller_)();
+  SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)();
   VLOG(4) << "debug info "
           << graph->HumanReadableInfo(false /*show_values*/,
                                       true /*show_functions*/);
diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h
index 219e3f5470f627e81005aabf94f9c72c33fd2eed..3545da9109d79964f36c3d7e738620cc2e0f9a6c 100644
--- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h
+++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h
@@ -33,7 +33,10 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
 
   explicit TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller);
 
-  bool Initialize(Argument* argument) override { return true; }
+  bool Initialize(Argument* argument) override {
+    argument_ = argument;
+    return true;
+  }
 
   // This class get a sub-graph as input and determine whether to transform this
   // sub-graph into TensorRT.
@@ -46,6 +49,7 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
 
  private:
   NodeInsideSubgraphTeller node_inside_subgraph_teller_;
+  Argument* argument_;
 };
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc
index 67a5af83d89b771536ea11be51b35244ff5c09d6..9748e24b06295a4e7c2995429e6588cd0f225fe6 100644
--- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc
+++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc
@@ -36,6 +36,10 @@ TEST(TensorRTSubGraphPass, main) {
   };
 
   Argument argument(FLAGS_inference_model_dir);
+  argument.Set("minimum_subgraph_size", new int(0));
+  argument.Set("max_batch_size", new int(3));
+  argument.Set("workspace_size", new int(1 << 20));
+  argument.Set("precision_mode", new std::string("FP32"));
 
   DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
   DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
index 6c7e63971b2d93f58e219dbd93637c8d389deb7c..5ee6a5a93168f58770067f76ca7f6bb6f67b2965 100644
--- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
+++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
@@ -35,8 +35,6 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
   bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
     FLAGS_IA_enable_tensorrt_subgraph_engine = true;
     VLOG(3) << "Predictor::init()";
-    FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
-    FLAGS_tensorrt_workspace_size = config_.workspace_size;
     if (config_.use_gpu) {
       place_ = paddle::platform::CUDAPlace(config_.device);
     } else {
@@ -92,6 +90,14 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
   void OptimizeInferenceProgram() {
     // Analyze inference_program
     Argument argument;
+
+    argument.Set("minimum_subgraph_size",
+                 new int(config_.minimum_subgraph_size));
+    argument.Set("max_batch_size", new int(config_.max_batch_size));
+    argument.Set("workspace_size", new int(config_.workspace_size));
+    argument.Set("precision_mode",
+                 new std::string(config_.precision_mode));
+
     if (!config_.model_dir.empty()) {
       argument.fluid_model_dir.reset(new std::string(config_.model_dir));
     } else {
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 2b4e5ed73704041e18bdbce32338405f3601e082..01ea0d9c3ad37b3bcebe6853de77373810333776 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -194,6 +194,14 @@ struct MixedRTConfig : public NativeConfig {
   // For workspace_size, refer it from here:
   // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
   int workspace_size{1 << 30};
+  // We transform the Ops that can be converted into TRT layers in the model,
+  // and aggregate these Ops into subgraphs for TRT execution.
+  // This variable controls the minimum number of nodes in such a subgraph;
+  // the default value is 3.
+  int minimum_subgraph_size = 3;
+  // Reserved configuration.
+  // Only "FP32" is supported now; "FP16" and "INT8" will be supported later.
+  std::string precision_mode = "FP32";
 };
 
 // NOTE WIP, not stable yet.
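A minimal usage sketch (not part of the patch) of the new MixedRTConfig fields added above; the model path and sizes are placeholders, the engine kind follows trt_models_tester.cc, and precision_mode currently only accepts "FP32":

    paddle::contrib::MixedRTConfig config;
    config.model_dir = "/path/to/model";  // placeholder path
    config.use_gpu = true;
    config.fraction_of_gpu_memory = 0.2;
    config.device = 0;
    config.max_batch_size = 3;
    config.workspace_size = 1 << 20;
    config.minimum_subgraph_size = 3;  // subgraphs with fewer nodes stay in Fluid
    config.precision_mode = "FP32";    // "FP16" and "INT8" are not supported yet
    auto predictor = paddle::CreatePaddlePredictor<
        paddle::contrib::MixedRTConfig,
        paddle::PaddleEngineKind::kAutoMixedTensorRT>(config);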
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index d7ab2ac980af2cf3bd9d95bfdbfa1887ef9a64d7..70f9e397c96cf3fe92779778950f3df71b5a67c9 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -90,3 +90,13 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
              DEPS inference_anakin_api_shared dynload_cuda SERIAL)
   endif()
 endif()
+
+if(WITH_GPU AND TENSORRT_FOUND)
+  set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt")
+  if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
+    inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
+  endif()
+  cc_test(test_trt_models SRCS trt_models_tester.cc
+    ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models
+    DEPS paddle_inference_tensorrt_subgraph_engine)
+endif()
diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bf320a0cbc2fff5f973c48768281e26d0fde232b
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc
@@ -0,0 +1,106 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+
+namespace paddle {
+using paddle::contrib::MixedRTConfig;
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+
+NativeConfig GetConfigNative() {
+  NativeConfig config;
+  config.model_dir = FLAGS_dirname;
+  // LOG(INFO) << "dirname " << config.model_dir;
+  config.fraction_of_gpu_memory = 0.45;
+  config.use_gpu = true;
+  config.device = 0;
+  return config;
+}
+
+MixedRTConfig GetConfigTRT() {
+  MixedRTConfig config;
+  config.model_dir = FLAGS_dirname;
+  config.use_gpu = true;
+  config.fraction_of_gpu_memory = 0.2;
+  config.device = 0;
+  config.max_batch_size = 3;
+  return config;
+}
+
+void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
+  NativeConfig config0 = GetConfigNative();
+  config0.model_dir = model_dirname;
+
+  MixedRTConfig config1 = GetConfigTRT();
+  config1.model_dir = model_dirname;
+  config1.max_batch_size = batch_size;
+
+  auto predictor0 =
+      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
+  auto predictor1 =
+      CreatePaddlePredictor<MixedRTConfig,
+                            PaddleEngineKind::kAutoMixedTensorRT>(config1);
+  // Prepare inputs
+  int height = 224;
+  int width = 224;
+  float *data = new float[batch_size * 3 * height * width];
+  memset(data, 0, sizeof(float) * (batch_size * 3 * height * width));
+  data[0] = 1.0f;
+
+  // Prepare inputs
+  PaddleTensor tensor;
+  tensor.name = "input_0";
+  tensor.shape = std::vector<int>({batch_size, 3, height, width});
+  tensor.data = PaddleBuf(static_cast<void *>(data),
+                          sizeof(float) * (batch_size * 3 * height * width));
+  tensor.dtype = PaddleDType::FLOAT32;
+  std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
+
+  // Prepare outputs
+  std::vector<PaddleTensor> outputs0;
+  std::vector<PaddleTensor> outputs1;
+  CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
+
+  CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
+
+  // Get output.
+  ASSERT_EQ(outputs0.size(), 1UL);
+  ASSERT_EQ(outputs1.size(), 1UL);
+
+  const size_t num_elements = outputs0.front().data.length() / sizeof(float);
+  const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
+  EXPECT_EQ(num_elements, num_elements1);
+
+  auto *data0 = static_cast<float *>(outputs0.front().data.data());
+  auto *data1 = static_cast<float *>(outputs1.front().data.data());
+
+  ASSERT_GT(num_elements, 0UL);
+  for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
+    EXPECT_NEAR(data0[i], data1[i], 1e-3);
+  }
+}
+
+TEST(trt_models_test, main) {
+  std::vector<std::string> infer_models = {"mobilenet", "resnet50",
+                                           "resnext50"};
+  for (auto &model_dir : infer_models) {
+    CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + model_dir);
+  }
+}
+}  // namespace paddle
diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc
index 1048d3017140c9e31426a1580b2862667116a024..41a5786fe8c3295390144732221280e152d0a15a 100644
--- a/paddle/fluid/operators/tensorrt_engine_op.cc
+++ b/paddle/fluid/operators/tensorrt_engine_op.cc
@@ -22,8 +22,6 @@ namespace paddle {
 
 DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT");
-DEFINE_int32(tensorrt_max_batch_size, 1, "TensorRT maximum batch size");
-DEFINE_int32(tensorrt_workspace_size, 16 << 20, "TensorRT workspace size");
 
 namespace operators {
 
@@ -34,6 +32,8 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Ys", "A list of outputs").AsDuplicable();
     AddAttr<std::string>("subgraph", "the subgraph.");
     AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
+    AddAttr<int>("max_batch_size", "the maximum batch size.");
+    AddAttr<int>("workspace_size", "the workspace size.");
     AddComment("TensorRT engine operator.");
   }
 };
diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h
index 69173ff5178d32634f9ab291b7d709a3f91cb368..3c78c29c1a30d74947be84cd2b52ad308e732a2d 100644
--- a/paddle/fluid/operators/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt_engine_op.h
@@ -28,8 +28,6 @@ namespace paddle {
 
 DECLARE_int32(tensorrt_engine_batch_size);
-DECLARE_int32(tensorrt_max_batch_size);
-DECLARE_int32(tensorrt_workspace_size);
 
 namespace operators {
 
@@ -92,14 +90,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto engine_name = context.Attr<std::string>("engine_uniq_key");
+    int max_batch_size = context.Attr<int>("max_batch_size");
     if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
       Prepare(context);
     }
     auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
     auto input_names = context.op().Inputs("Xs");
     PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
-    PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
-                      FLAGS_tensorrt_max_batch_size);
+    PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, max_batch_size);
 
     std::vector<std::string> output_maps =
         context.Attr<std::vector<std::string>>("output_name_mapping");
@@ -173,8 +171,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
     // Get the ProgramDesc and pass to convert.
     framework::proto::BlockDesc block_desc;
     block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
-    int max_batch = FLAGS_tensorrt_max_batch_size;
-    auto max_workspace = FLAGS_tensorrt_workspace_size;
+    int max_batch_size = context.Attr<int>("max_batch_size");
+    int workspace_size = context.Attr<int>("workspace_size");
+
     auto params = context.Attr<std::vector<std::string>>("parameters");
     std::unordered_set<std::string> parameters;
     for (const auto& param : params) {
@@ -186,7 +185,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
     // TODO(Superjomn) replace this with a different stream
     auto* engine = Singleton<TRT_EngineManager>::Global().Create(
-        max_batch, max_workspace, nullptr /*engine hold its own stream*/,
+        max_batch_size, workspace_size, nullptr /*engine hold its own stream*/,
         context.Attr<std::string>("engine_uniq_key"),
         boost::get<platform::CUDAPlace>(context.GetPlace()).device);
diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc
index 27c1d29762b3de5e57f877b271aae52e71eb7cf9..e21101e8d12f210af08284dbcebe5c14c1af6dd3 100644
--- a/paddle/fluid/operators/tensorrt_engine_op_test.cc
+++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc
@@ -58,8 +58,6 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
 using inference::analysis::SetAttr;
 
 TEST(TensorRTEngineOp, manual) {
-  FLAGS_tensorrt_engine_batch_size = 2;
-  FLAGS_tensorrt_max_batch_size = 2;
   framework::ProgramDesc program;
   auto* block_ = program.Proto()->add_blocks();
   block_->set_idx(0);
@@ -101,6 +99,8 @@ TEST(TensorRTEngineOp, manual) {
   engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
   SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
                        block_->SerializeAsString());
+  SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
+  SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
   SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
   SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
                                     std::vector<std::string>({}));
@@ -129,8 +129,6 @@ TEST(TensorRTEngineOp, manual) {
 }
 
 void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
-  FLAGS_tensorrt_engine_batch_size = batch_size;
-  FLAGS_tensorrt_max_batch_size = batch_size;
   framework::ProgramDesc program;
   framework::Scope scope;
   platform::CUDAPlace place;
@@ -195,8 +193,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
 
   SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
                        block_->SerializeAsString());
-  SetAttr<int>(engine_op_desc.Proto(), "max_batch", batch_size);
-  SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 2 << 10);
+  SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
+  SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
   SetAttr<std::vector<std::string>>(
       engine_op_desc.Proto(), "parameters",
       std::vector<std::string>({"y0", "y1", "y2", "y3"}));
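For reference, a condensed sketch (not part of the patch) of the Argument-based configuration that the analysis passes now consume instead of the removed global flags; the keys and values mirror the tests in this change, with the surrounding setup omitted:

    Argument argument;
    argument.Set("minimum_subgraph_size", new int(3));
    argument.Set("max_batch_size", new int(3));
    argument.Set("workspace_size", new int(1 << 20));
    argument.Set("precision_mode", new std::string("FP32"));
    // Passes read the values back by key, e.g. in SubGraphFuse:
    //   if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size")) continue;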