diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 88e6749223008f688c2d06020de532b060ec29b3..8a43229af7971f877d897bb7470fb6f4184b1082 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -231,6 +231,7 @@ struct Argument {
                       TensorRtUseStaticEngine,
                       bool);
   DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
+  DECL_ARGUMENT_FIELD(tensorrt_use_cuda_graph, TensorRtUseCudaGraph, bool);
   DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
   DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
   DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 1d87edcd3404c2aa166da434f5a1ea8a72cc2184..4051511906b1be2b735cb2985faed7f9eb910e1b 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -165,6 +165,8 @@ void IRPassManager::CreatePasses(Argument *argument,
                 new AnalysisConfig::Precision(precision_mode));
       pass->Set("context_memory_sharing",
                 new bool(argument->trt_engine_memory_sharing()));
+      pass->Set("use_cuda_graph",
+                new bool(argument->tensorrt_use_cuda_graph()));
       bool use_static_engine = argument->tensorrt_use_static_engine();
       bool model_from_memory = argument->model_from_memory();
       std::string optim_cache_dir = argument->optim_cache_dir();
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index b64d2bd16ce1aa30e8de3ccdae9593fe267f4a9e..33e5622436a0ff27d859789f972065898236434d 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -101,6 +101,22 @@ void OutputProcess(framework::ir::Graph *graph,
   }
 }
 
+// Determine whether the whole graph is offloaded to TensorRT. If so, we can
+// try to enable optimizations such as CUDA graph.
+bool AllNodesLowerToTrtPostProcess(framework::ir::Graph *graph) {
+  std::unordered_set<std::string> trt_nodes_set{
+      "feed", "fetch", "tensorrt_engine"};
+  bool all_nodes_offload_to_trt = true;
+  for (auto *node : graph->Nodes()) {
+    if (node->IsOp()) {
+      if (!trt_nodes_set.count(node->Op()->Type())) {
+        all_nodes_offload_to_trt = false;
+        break;
+      }
+    }
+  }
+  return all_nodes_offload_to_trt;
+}
 }  // namespace
 
 using framework::ir::Node;
@@ -124,6 +140,7 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
 
   auto enable_int8 = Get<bool>("enable_int8");
   auto use_calib_mode = Get<bool>("use_calib_mode");
+  bool use_cuda_graph = Get<bool>("use_cuda_graph");
   bool no_calib_int8 = enable_int8 && !(use_calib_mode);
   auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
   auto with_dynamic_shape = Get<bool>("with_dynamic_shape");
@@ -165,13 +182,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   // those parameter already exist in trt, and should not have another copy in
   // fluid.
   std::vector<std::string> repetitive_params;
+  std::vector<std::string> engine_names;
   for (auto *node : graph->Nodes()) {
     if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
-      CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params);
-      std::unordered_set<const Node *> nodes2remove(
-          framework::ir::Agent(node).subgraph()->begin(),
-          framework::ir::Agent(node).subgraph()->end());
-      framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
+      engine_names.push_back(CreateTensorRTOp(
+          node, graph, graph_param_names, &repetitive_params, use_cuda_graph));
     }
   }
@@ -184,6 +199,32 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
   graph->Set(framework::ir::kRepetitiveParamAttr,
              new std::vector<std::string>(repetitive_params));
+
+  bool all_nodes_offload_to_trt = AllNodesLowerToTrtPostProcess(graph);
+  if (all_nodes_offload_to_trt) {
+    LOG(INFO) << "The entire graph is offloaded to TensorRT.";
+  }
+  if (use_cuda_graph && !all_nodes_offload_to_trt) {
+    LOG_FIRST_N(WARNING, 1)
+        << "You have enabled CudaGraph, but the entire graph is not "
+           "offloaded to TRT; falling back to normal mode.";
+    use_cuda_graph = false;
+  }
+  if (use_cuda_graph && all_nodes_offload_to_trt) {
+    for (auto &name : engine_names) {
+      PADDLE_ENFORCE_EQ(
+          paddle::inference::Singleton<
+              inference::tensorrt::TRTEngineManager>::Global()
+              .Has(name),
+          true,
+          platform::errors::PreconditionNotMet(
+              "TRTEngineManager should have engine %s, but not found.", name));
+      paddle::inference::Singleton<
+          inference::tensorrt::TRTEngineManager>::Global()
+          .Get(name)
+          ->SetAllNodesLowerToTrt(use_cuda_graph);
+    }
+  }
 }
 
 std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
@@ -191,6 +232,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
                               const std::string &predictor_id,
                               const std::string &max_batch_size,
                               const std::string &precision,
+                              bool use_cuda_graph,
                               const bool for_calibration) {
   std::string engine_hash_key = "";
   for (auto name : engine_inputs) {
@@ -209,17 +251,21 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
   engine_hash_key += "#";
   engine_hash_key += precision;
 
+  engine_hash_key += "#";
+  engine_hash_key += use_cuda_graph;
+
   auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
   VLOG(2) << "TRT engine hash key: " << engine_hash_key;
   VLOG(2) << "TRT engine key: " << engine_key;
   return engine_key;
 }
 
-void TensorRtSubgraphPass::CreateTensorRTOp(
+std::string TensorRtSubgraphPass::CreateTensorRTOp(
     framework::ir::Node *node,
     framework::ir::Graph *graph,
     const std::vector<std::string> &graph_params,
-    std::vector<std::string> *repetitive_params) const {
+    std::vector<std::string> *repetitive_params,
+    bool use_cuda_graph) const {
   auto *op_desc = node->Op();
   auto &subgraph = *framework::ir::Agent(node).subgraph();
   PADDLE_ENFORCE_EQ(subgraph.empty(),
@@ -506,6 +552,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
                         std::to_string(0),
                         std::to_string(max_batch_size),
                         std::to_string(static_cast<int>(precision_mode)),
+                        use_cuda_graph,
                         false);
   auto calibration_engine_key =
       GenerateEngineKey(input_names_with_id,
@@ -513,6 +560,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
                         std::to_string(0),
                         std::to_string(max_batch_size),
                         std::to_string(static_cast<int>(precision_mode)),
+                        use_cuda_graph,
                         true);
   auto predictor_id = Get<int>("predictor_id");
@@ -547,7 +595,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       (enable_int8 && calibration_data.size() == 0 && use_calib_mode);
   if (calibration_mode) {
     // calibraion mode means generate int8 calibration table data process.
-    return;
+    return calibration_engine_key;
   }
 
   std::copy(params_not_shared.begin(),
@@ -582,6 +630,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
                     "recommend using the same TRT version at runtime.";
   }
 
+  std::unordered_set<const Node *> nodes2remove(
+      framework::ir::Agent(node).subgraph()->begin(),
+      framework::ir::Agent(node).subgraph()->end());
+  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
+
   // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will not
   // run fp16.
   // When running fp16, the output accuracy of the model will be affected,
@@ -628,7 +681,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       LOG(INFO) << "Load TRT Optimized Info from "
                 << GetTrtEngineSerializedPath(
                        Get<std::string>("model_opt_cache_dir"), engine_key);
-      return;
+      return engine_key + std::to_string(predictor_id);
     } catch (const std::exception &exp) {
       LOG(WARNING)
          << "Fail to load TRT Optimized Info from "
@@ -643,7 +696,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   // If with_dynamic_shape is configured,but min_input_shape is empty,
   // create trt engine in runtime instead of in pass.
   if (with_dynamic_shape && min_input_shape.empty()) {
-    return;
+    return engine_key + std::to_string(predictor_id);
   }
 
   // the following code will NOT run in following situation:
@@ -676,6 +729,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
               << GetTrtEngineSerializedPath(
                      Get<std::string>("model_opt_cache_dir"), engine_key);
   }
+
+  return engine_key + std::to_string(predictor_id);
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h
index 1bc86d554b50f4b08e51de2fcda3d4e3d1b94c12..a79c41f6a075f19e50d7410170136bc0903a6cde 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h
@@ -42,10 +42,11 @@ class TensorRtSubgraphPass : public framework::ir::FusePassBase {
   void ApplyImpl(framework::ir::Graph *graph) const override;
 
  private:
-  void CreateTensorRTOp(framework::ir::Node *x,
-                        framework::ir::Graph *graph,
-                        const std::vector<std::string> &graph_params,
-                        std::vector<std::string> *repetitive_params) const;
+  std::string CreateTensorRTOp(framework::ir::Node *x,
+                               framework::ir::Graph *graph,
+                               const std::vector<std::string> &graph_params,
+                               std::vector<std::string> *repetitive_params,
+                               bool use_cuda_graph) const;
 
   void CleanIntermediateOutputs(framework::ir::Node *node);
 };
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 3fa947ce27daf58ed14fbac499b047d0eb63ea8a..4f73fb23c6a5a81294b455c6afc27d09c6cac592 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -16,6 +16,7 @@
 #include
 #include
+#include "glog/logging.h"
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_analysis_config.h"
 #include "paddle/fluid/inference/api/paddle_pass_builder.h"
@@ -442,6 +443,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(trt_dla_core_);
   CP_MEMBER(trt_use_static_engine_);
   CP_MEMBER(trt_use_calib_mode_);
+  CP_MEMBER(trt_use_cuda_graph_);
   CP_MEMBER(trt_use_varseqlen_);
   CP_MEMBER(trt_with_interleaved_);
   CP_MEMBER(tensorrt_transformer_posid_);
@@ -719,7 +721,8 @@ void AnalysisConfig::EnableTensorRtEngine(
     int min_subgraph_size,
     AnalysisConfig::Precision precision_mode,
     bool use_static,
-    bool use_calib_mode) {
+    bool use_calib_mode,
+    bool use_cuda_graph) {
 #ifdef PADDLE_WITH_TENSORRT
   if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableUseGpu() first"; @@ -733,6 +736,11 @@ void AnalysisConfig::EnableTensorRtEngine( tensorrt_precision_mode_ = precision_mode; trt_use_static_engine_ = use_static; trt_use_calib_mode_ = use_calib_mode; + trt_use_cuda_graph_ = use_cuda_graph; + if (use_cuda_graph) { + LOG_FIRST_N(INFO, 1) << "You have enabled Trt Cuda Graph, you must ensure " + "that the input Shape remains unchanged."; + } Update(); #else @@ -1313,6 +1321,8 @@ std::string AnalysisConfig::Summary() { trt_use_static_engine_ ? "true" : "false"}); os.InsertRow( {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"}); + os.InsertRow( + {"tensorrt_use_cuda_graph", trt_use_cuda_graph_ ? "true" : "false"}); // dynamic_shape os.InsertRow({"tensorrt_enable_dynamic_shape", diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 466eee7d9414f633fe0fbecb01c86bcb7af5f7bf..5495f929e8895b79e3c82b93dce42029d6823e98 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1352,6 +1352,7 @@ void AnalysisPredictor::PrepareArgument() { argument_->SetTensorRtDLACore(config_.trt_dla_core_); argument_->SetTensorRtUseStaticEngine(config_.trt_use_static_engine_); argument_->SetTensorRtUseCalibMode(config_.trt_use_calib_mode_); + argument_->SetTensorRtUseCudaGraph(config_.trt_use_cuda_graph_); argument_->SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_); argument_->SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path()); argument_->SetTensorRtAllowBuildAtRuntime( diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 3300fe8a9b8dcc1142d79de7b9e0a1fade68212c..d810442810af754b09a63135a7b336e39960bc67 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -586,6 +586,9 @@ struct PD_INFER_DECL AnalysisConfig { /// \param use_static Serialize optimization information to disk for reusing. /// \param use_calib_mode Use TRT int8 calibration(post training /// quantization). + /// \param use_cuda_graph Use CudaGraph to reduce the time consumption of + /// enqueue. Note that this option can only be enabled when your input is + /// constant (including the batch dimension). /// /// void EnableTensorRtEngine(int64_t workspace_size = 1 << 30, @@ -593,7 +596,8 @@ struct PD_INFER_DECL AnalysisConfig { int min_subgraph_size = 3, Precision precision = Precision::kFloat32, bool use_static = false, - bool use_calib_mode = true); + bool use_calib_mode = true, + bool use_cuda_graph = false); /// /// \brief A boolean state telling whether the TensorRT engine is used. /// @@ -1114,6 +1118,7 @@ struct PD_INFER_DECL AnalysisConfig { Precision tensorrt_precision_mode_{Precision::kFloat32}; bool trt_use_static_engine_{false}; bool trt_use_calib_mode_{true}; + bool trt_use_cuda_graph_{false}; bool trt_use_varseqlen_{false}; bool trt_with_interleaved_{false}; std::string tensorrt_transformer_posid_{""}; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 4710966030d67618c2d95f149c7b5f2ec32afb19..b2ee3f47965958b0571151fdee5a61a89514a4e7 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -25,6 +25,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/enforce.h" namespace paddle { namespace inference { @@ -129,12 +130,60 @@ void TensorRTEngine::Execute(int batch_size, phi::Stream(reinterpret_cast(stream))); infer_context->setDeviceMemory(context_memory); } + + // TODO(wilber): Is cudaGraph has conflict with memory sharing? + if (startup_with_cudagraph_ && !cudagraph_inited_) { + // Avoid capturing initialization calls by executing the enqueue function at + // least once before starting CUDA graph capture. + const auto ret = Enqueue(infer_context, buffers, batch_size, stream); + PADDLE_ENFORCE_EQ( + ret, + true, + phi::errors::PreconditionNotMet("Trt CudaGraph test run failed.")); + cudaStreamSynchronize(stream); + + cuda_graph_.BeginCapture(stream); + // The built TRT engine may contain operations that are not permitted under + // CUDA graph capture mode. When the stream is capturing, the call may + // return false if the current CUDA graph capture fails. + if (Enqueue(infer_context, buffers, batch_size, stream)) { + cuda_graph_.EndCapture(stream); + cudagraph_inited_ = true; + } else { + cuda_graph_.EndCaptureOnError(stream); + // Ensure any CUDA error has been cleaned up. + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError()); + LOG(WARNING) << "The built TensorRT engine contains operations that are " + "not permitted under " + "CUDA graph capture mode. The specified UseCudaGraph " + "flag has been ignored. The inference will be " + "launched without using CUDA graph launch."; + cudagraph_inited_ = false; + } + startup_with_cudagraph_ = false; + } + + Enqueue(infer_context, buffers, batch_size, stream); +} + +bool TensorRTEngine::Enqueue(nvinfer1::IExecutionContext *context, + std::vector *buffers, + int batch_size, + cudaStream_t stream) { + if (cudagraph_inited_) { + VLOG(1) << "cuda_graph init success, so we will use cuda graph launch the " + "entire graph."; + return cuda_graph_.Launch(stream); + } + + bool ret; if (!with_dynamic_shape()) { - infer_context->enqueue(batch_size, buffers->data(), stream, nullptr); + ret = context->enqueue(batch_size, buffers->data(), stream, nullptr); } else { - infer_context->enqueueV2(buffers->data(), stream, nullptr); + ret = context->enqueueV2(buffers->data(), stream, nullptr); } SetRuntimeBatch(batch_size); + return ret; } void TensorRTEngine::FreezeNetwork() { diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 8e1531352137f1bdae86712bdd90caa2171435a1..1906b8be5171e1f6be1522cffc1a4582f07e9eba 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -49,6 +49,64 @@ namespace paddle { namespace inference { namespace tensorrt { +// The code is mainly from TensorRT, thanks to the project. 
+class TrtCudaGraph {
+ public:
+  TrtCudaGraph() = default;
+  ~TrtCudaGraph() {
+    if (cuda_graph_exec_) {
+      cudaGraphExecDestroy(cuda_graph_exec_);
+    }
+  }
+
+  void BeginCapture(cudaStream_t stream) {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
+  }
+
+  bool Launch(cudaStream_t stream) {
+    return cudaGraphLaunch(cuda_graph_exec_, stream) == cudaSuccess;
+  }
+
+  void EndCapture(cudaStream_t stream) {
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamEndCapture(stream, &cuda_graph_));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphInstantiate(
+        &cuda_graph_exec_, cuda_graph_, nullptr, nullptr, 0));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(cuda_graph_));
+  }
+
+  void EndCaptureOnError(cudaStream_t stream) {
+    // There are two possibilities why stream capture would fail:
+    // (1) stream is in cudaErrorStreamCaptureInvalidated state.
+    // (2) TRT reports a failure.
+    // In case (1), the returned cuda_graph_ should be nullptr.
+    // In case (2), the returned cuda_graph_ is not nullptr, but it should not
+    // be used.
+    const auto ret = cudaStreamEndCapture(stream, &cuda_graph_);
+    if (ret == cudaErrorStreamCaptureInvalidated) {
+      PADDLE_ENFORCE_EQ(cuda_graph_ == nullptr,
+                        true,
+                        platform::errors::PreconditionNotMet(
+                            "CudaGraph capture stream failed."));
+    } else {
+      PADDLE_ENFORCE_GPU_SUCCESS(ret);
+      PADDLE_ENFORCE_NOT_NULL(
+          cuda_graph_,
+          phi::errors::PreconditionNotMet("CudaGraph capture stream failed."));
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(cuda_graph_));
+      cuda_graph_ = nullptr;
+    }
+    // Clean up any cuda error.
+    cudaGetLastError();
+    LOG(WARNING) << "The TRT CUDA graph capture on the stream has failed.";
+  }
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(TrtCudaGraph);
+  cudaGraph_t cuda_graph_{};
+  cudaGraphExec_t cuda_graph_exec_{};
+};
+
 namespace plugin {
 class PluginTensorRT;
 }  // namespace plugin
@@ -445,6 +503,11 @@ class TensorRTEngine {
                std::vector<void*>* buffers,
                cudaStream_t stream = nullptr);
 
+  bool Enqueue(nvinfer1::IExecutionContext* context,
+               std::vector<void*>* buffers,
+               int batch,
+               cudaStream_t stream);
+
   nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
 
   ShapeMapType min_input_shape() { return min_input_shape_; }
@@ -682,6 +745,11 @@ class TensorRTEngine {
     context_memory_sharing_ = context_memory_sharing;
   }
 
+  void SetAllNodesLowerToTrt(bool all_nodes_offload_to_trt) {
+    // All nodes are lowered to TRT, so cudaGraph can optimize the runtime.
+    startup_with_cudagraph_ = all_nodes_offload_to_trt;
+  }
+
  private:
   // Each ICudaEngine object is bound to a specific GPU when it is instantiated,
   // ensure that the thread is associated with the correct device by calling
@@ -744,6 +812,11 @@ class TensorRTEngine {
   infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
   std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
 
+  // cudagraph related
+  TrtCudaGraph cuda_graph_;
+  bool cudagraph_inited_{false};
+  bool startup_with_cudagraph_{false};
+
   std::unordered_map<std::string, paddle::any> attrs_;
   std::unordered_map<std::string, std::function<void(void)>> attr_dels_;
 #if IS_TRT_VERSION_GE(6000)
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
index 9a06b2e65ef100a4903d732d3e02058197f2a972..23a0df7316724a9ab6eece77a73a42e1199839d4 100644
--- a/paddle/fluid/inference/tensorrt/test_engine.cc
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -274,6 +274,7 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
   buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
   buffers[1] = reinterpret_cast<void *>(y_gpu_data);
 
+  engine_->SetAllNodesLowerToTrt(true);
   engine_->Execute(2, &buffers, ctx_->stream());
 
   LOG(INFO) << "to get output";
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index a8bfd5a9917e7d55b93648a49c804f181d44582f..e861c5b5bbe90b9d1799864a31489c8273c31f22 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -840,7 +840,8 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("min_subgraph_size") = 3,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
            py::arg("use_static") = false,
-           py::arg("use_calib_mode") = true)
+           py::arg("use_calib_mode") = true,
+           py::arg("use_cuda_graph") = false)
       .def("enable_tensorrt_memory_optim",
            &AnalysisConfig::EnableTensorRTMemoryOptim,
            py::arg("engine_memory_sharing") = true,
diff --git a/test/cpp/inference/api/CMakeLists.txt b/test/cpp/inference/api/CMakeLists.txt
index 01010715c8104b8d192abbbe51f1584420e3de95..c4045e9ca603abb27b148d20eaeb91ad50e6d028 100644
--- a/test/cpp/inference/api/CMakeLists.txt
+++ b/test/cpp/inference/api/CMakeLists.txt
@@ -1377,7 +1377,7 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST)
     set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120)
   endif()
   if(WITH_GPU AND TENSORRT_FOUND)
-    set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 120)
+    set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 240)
     if(WITH_MKLDNN)
       set_tests_properties(test_analyzer_bfloat16_resnet50 PROPERTIES TIMEOUT
                            120)
diff --git a/test/cpp/inference/api/trt_mobilenet_test.cc b/test/cpp/inference/api/trt_mobilenet_test.cc
index 5c0519c067f56b9d4934496dee169e1ea2fdb865..7cae99e0d3479a552876b36b8bae728bc67b7c4b 100644
--- a/test/cpp/inference/api/trt_mobilenet_test.cc
+++ b/test/cpp/inference/api/trt_mobilenet_test.cc
@@ -99,4 +99,30 @@ TEST(PredictorPool, use_gpu) {
   predictor->Run();
 }
 
+TEST(PredictorPool, use_trt_cuda_graph) {
+  std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
+  Config config;
+  config.EnableUseGpu(100, 0);
+  config.SetModel(model_dir);
+  config.EnableTensorRtEngine(
+      1 << 20, 1, 3, PrecisionType::kFloat32, false, false, true);
+  config.Exp_DisableTensorRtOPs({"fc"});
+  config.EnableTensorRtDLA(0);
+  services::PredictorPool pred_pool(config, 1);
+
+  auto predictor = pred_pool.Retrive(0);
+  auto input_names = predictor->GetInputNames();
+  auto input_t = predictor->GetInputHandle(input_names[0]);
+  std::vector<int> in_shape = {1, 3, 224, 224};
+  int in_num =
+      std::accumulate(in_shape.begin(), in_shape.end(), 1, [](int &a, int &b) {
+        return a * b;
+      });
+
+  std::vector<float> input(in_num, 0);
+  input_t->Reshape(in_shape);
+  input_t->CopyFromCpu(input.data());
+  predictor->Run();
+}
+
 }  // namespace paddle_infer
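
Note on the runtime mechanism: the TrtCudaGraph helper added to engine.h wraps the standard CUDA stream-capture workflow that TensorRTEngine::Execute now drives (one warm-up Enqueue, BeginCapture, a captured Enqueue, EndCapture/instantiate, then Launch on every later call). The standalone sketch below is illustrative only and is not part of this patch: it replays a plain cudaMemsetAsync instead of a TensorRT enqueue, the file name and buffer size are made up, and it uses the same pre-CUDA-12 five-argument cudaGraphInstantiate call that the patch itself uses.

// trt_cuda_graph_sketch.cu (hypothetical example; not part of this patch)
// Build: nvcc trt_cuda_graph_sketch.cu -o trt_cuda_graph_sketch
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

#define CHECK_CUDA(call)                                        \
  do {                                                          \
    cudaError_t _e = (call);                                    \
    if (_e != cudaSuccess) {                                    \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",          \
                   cudaGetErrorString(_e), __FILE__, __LINE__); \
      std::exit(EXIT_FAILURE);                                  \
    }                                                           \
  } while (0)

int main() {
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));

  constexpr size_t kBytes = 1 << 20;  // arbitrary device buffer size
  void *buffer = nullptr;
  CHECK_CUDA(cudaMalloc(&buffer, kBytes));

  // 1. Warm-up run outside of capture, mirroring the extra Enqueue() that
  //    TensorRTEngine::Execute performs before it starts capturing.
  CHECK_CUDA(cudaMemsetAsync(buffer, 0, kBytes, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));

  // 2. Capture the stream work into a graph; this is what BeginCapture and
  //    EndCapture do around the second Enqueue() in the patch.
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;
  CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
  CHECK_CUDA(cudaMemsetAsync(buffer, 1, kBytes, stream));  // stand-in for enqueueV2
  CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
  CHECK_CUDA(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
  CHECK_CUDA(cudaGraphDestroy(graph));

  // 3. Replay the captured work; this is what TrtCudaGraph::Launch does on
  //    every inference call once cudagraph_inited_ is true.
  for (int i = 0; i < 10; ++i) {
    CHECK_CUDA(cudaGraphLaunch(graph_exec, stream));
  }
  CHECK_CUDA(cudaStreamSynchronize(stream));
  std::printf("replayed the captured graph 10 times\n");

  CHECK_CUDA(cudaGraphExecDestroy(graph_exec));
  CHECK_CUDA(cudaFree(buffer));
  CHECK_CUDA(cudaStreamDestroy(stream));
  return 0;
}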