Unverified commit 1321c479, authored by Pei Yang, committed via GitHub.

add more info in trt engine serialization (#31434)

Parent commit: 9ebf05b0
...@@ -83,16 +83,29 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -83,16 +83,29 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs, const std::set<std::string> &engine_outputs,
const std::string &predictor_id) { const std::string &predictor_id,
const std::string &max_batch_size,
const std::string &precision,
const std::string &use_calib_mode) {
std::string engine_hash_key = ""; std::string engine_hash_key = "";
for (auto name : engine_inputs) { for (auto name : engine_inputs) {
engine_hash_key += name; engine_hash_key += name;
engine_hash_key += "#";
} }
for (auto name : engine_outputs) { for (auto name : engine_outputs) {
engine_hash_key += name; engine_hash_key += name;
engine_hash_key += "#";
} }
engine_hash_key += predictor_id; engine_hash_key += predictor_id;
engine_hash_key += "#";
engine_hash_key += max_batch_size;
engine_hash_key += "#";
engine_hash_key += precision;
engine_hash_key += "#";
engine_hash_key += use_calib_mode;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key)); auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
VLOG(2) << "TRT engine hash key: " << engine_hash_key;
VLOG(2) << "TRT engine key: " << engine_key;
return engine_key; return engine_key;
} }
...@@ -245,8 +258,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -245,8 +258,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// TODO(NHZlX) // TODO(NHZlX)
// There are models with the same structure but the different parameters, // There are models with the same structure but the different parameters,
// when running in the 'use_serialize' mode, there is a bug. // when running in the 'use_serialize' mode, there is a bug.
auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id, auto engine_key = GenerateEngineKey(
std::to_string(0)); input_names_with_id, output_names_with_id, std::to_string(0),
std::to_string(Get<int>("max_batch_size")),
std::to_string(static_cast<int>(precision_mode)),
std::to_string(static_cast<int>(use_calib_mode)));
auto predictor_id = Get<int>("predictor_id"); auto predictor_id = Get<int>("predictor_id");
// Get "" when there is no cached calibration table data. // Get "" when there is no cached calibration table data.
...@@ -359,6 +375,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -359,6 +375,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"), GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
engine_key), engine_key),
trt_engine_serialized_data); trt_engine_serialized_data);
LOG(INFO) << "Save TRT Optimized Info to "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册