From 1321c47950c3286e2548a2aaca550ec1bedb5d7b Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Fri, 5 Mar 2021 13:31:52 +0800 Subject: [PATCH] add more info in trt engine serialization (#31434) --- .../ir_passes/tensorrt_subgraph_pass.cc | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 75111701f1f..8a14e168ca4 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -83,16 +83,29 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( std::string GenerateEngineKey(const std::set &engine_inputs, const std::set &engine_outputs, - const std::string &predictor_id) { + const std::string &predictor_id, + const std::string &max_batch_size, + const std::string &precision, + const std::string &use_calib_mode) { std::string engine_hash_key = ""; for (auto name : engine_inputs) { engine_hash_key += name; + engine_hash_key += "#"; } for (auto name : engine_outputs) { engine_hash_key += name; + engine_hash_key += "#"; } engine_hash_key += predictor_id; + engine_hash_key += "#"; + engine_hash_key += max_batch_size; + engine_hash_key += "#"; + engine_hash_key += precision; + engine_hash_key += "#"; + engine_hash_key += use_calib_mode; auto engine_key = std::to_string(std::hash()(engine_hash_key)); + VLOG(2) << "TRT engine hash key: " << engine_hash_key; + VLOG(2) << "TRT engine key: " << engine_key; return engine_key; } @@ -245,8 +258,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( // TODO(NHZlX) // There are models with the same structure but the different parameters, // when running in the 'use_serialize' mode, there is a bug. - auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id, - std::to_string(0)); + auto engine_key = GenerateEngineKey( + input_names_with_id, output_names_with_id, std::to_string(0), + std::to_string(Get("max_batch_size")), + std::to_string(static_cast(precision_mode)), + std::to_string(static_cast(use_calib_mode))); auto predictor_id = Get("predictor_id"); // Get "" when there is no cached calibration table data. @@ -359,6 +375,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp( GetTrtEngineSerializedPath(Get("model_opt_cache_dir"), engine_key), trt_engine_serialized_data); + LOG(INFO) << "Save TRT Optimized Info to " + << GetTrtEngineSerializedPath( + Get("model_opt_cache_dir"), engine_key); } } -- GitLab