未验证 提交 cac9635a 编写于 作者: P Pei Yang 提交者: GitHub

[Paddle-TRT] Fix engine key in trt int8 calibration (#31513)

* fix engine key in trt int8 calibration

* fix unit test
上级 50ac7dbf
...@@ -86,7 +86,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, ...@@ -86,7 +86,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::string &predictor_id, const std::string &predictor_id,
const std::string &max_batch_size, const std::string &max_batch_size,
const std::string &precision, const std::string &precision,
const std::string &use_calib_mode) { const bool for_calibration) {
std::string engine_hash_key = ""; std::string engine_hash_key = "";
for (auto name : engine_inputs) { for (auto name : engine_inputs) {
engine_hash_key += name; engine_hash_key += name;
...@@ -97,12 +97,13 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, ...@@ -97,12 +97,13 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
engine_hash_key += "#"; engine_hash_key += "#";
} }
engine_hash_key += predictor_id; engine_hash_key += predictor_id;
engine_hash_key += "#"; if (!for_calibration) {
engine_hash_key += max_batch_size; engine_hash_key += "#";
engine_hash_key += max_batch_size;
}
engine_hash_key += "#"; engine_hash_key += "#";
engine_hash_key += precision; engine_hash_key += precision;
engine_hash_key += "#";
engine_hash_key += use_calib_mode;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key)); auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
VLOG(2) << "TRT engine hash key: " << engine_hash_key; VLOG(2) << "TRT engine hash key: " << engine_hash_key;
VLOG(2) << "TRT engine key: " << engine_key; VLOG(2) << "TRT engine key: " << engine_key;
...@@ -258,24 +259,31 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -258,24 +259,31 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// TODO(NHZlX) // TODO(NHZlX)
// There are models with the same structure but the different parameters, // There are models with the same structure but the different parameters,
// when running in the 'use_serialize' mode, there is a bug. // when running in the 'use_serialize' mode, there is a bug.
// serialization is affected by max_batch_size, but calibration is not.
// So we use seperate engine keys in serialization and calibration.
auto engine_key = GenerateEngineKey( auto engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0), input_names_with_id, output_names_with_id, std::to_string(0),
std::to_string(Get<int>("max_batch_size")), std::to_string(Get<int>("max_batch_size")),
std::to_string(static_cast<int>(precision_mode)), std::to_string(static_cast<int>(precision_mode)), false);
std::to_string(static_cast<int>(use_calib_mode))); auto calibration_engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0),
std::to_string(Get<int>("max_batch_size")),
std::to_string(static_cast<int>(precision_mode)), true);
auto predictor_id = Get<int>("predictor_id"); auto predictor_id = Get<int>("predictor_id");
// Get "" when there is no cached calibration table data. // Get "" when there is no cached calibration table data.
std::string calibration_data = ""; std::string calibration_data = "";
if (enable_int8 && use_calib_mode) { if (enable_int8 && use_calib_mode) {
calibration_data = GetTrtCalibTableData( calibration_data =
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8); GetTrtCalibTableData(Get<std::string>("model_opt_cache_dir"),
calibration_engine_key, enable_int8);
} }
op_desc->SetAttr("calibration_data", calibration_data); op_desc->SetAttr("calibration_data", calibration_data);
op_desc->SetAttr("enable_int8", enable_int8); op_desc->SetAttr("enable_int8", enable_int8);
op_desc->SetAttr("enable_fp16", enable_fp16); op_desc->SetAttr("enable_fp16", enable_fp16);
op_desc->SetAttr("use_calib_mode", use_calib_mode); op_desc->SetAttr("use_calib_mode", use_calib_mode);
op_desc->SetAttr("engine_key", engine_key); op_desc->SetAttr("engine_key", engine_key);
op_desc->SetAttr("calibration_engine_key", calibration_engine_key);
op_desc->SetAttr("predictor_id", predictor_id); op_desc->SetAttr("predictor_id", predictor_id);
std::string trt_engine_serialized_data = ""; std::string trt_engine_serialized_data = "";
......
...@@ -1017,8 +1017,8 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() { ...@@ -1017,8 +1017,8 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
auto &block = inference_program_->Block(0); auto &block = inference_program_->Block(0);
for (auto &op_desc : block.AllOps()) { for (auto &op_desc : block.AllOps()) {
if (op_desc->Type() == "tensorrt_engine") { if (op_desc->Type() == "tensorrt_engine") {
std::string engine_name = std::string engine_name = BOOST_GET_CONST(
BOOST_GET_CONST(std::string, op_desc->GetAttr("engine_key")); std::string, op_desc->GetAttr("calibration_engine_key"));
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) { if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
LOG(ERROR) << "You should run the predictor(with trt) on the real data " LOG(ERROR) << "You should run the predictor(with trt) on the real data "
"to generate calibration info"; "to generate calibration info";
......
...@@ -89,6 +89,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -89,6 +89,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
bool use_calib_mode_; bool use_calib_mode_;
std::string calibration_data_; std::string calibration_data_;
std::string engine_key_; std::string engine_key_;
std::string calibration_engine_key_;
bool calibration_mode_; bool calibration_mode_;
int predictor_id_; int predictor_id_;
int device_id_; int device_id_;
...@@ -109,6 +110,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -109,6 +110,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
use_calib_mode_ = Attr<bool>("use_calib_mode"); use_calib_mode_ = Attr<bool>("use_calib_mode");
calibration_data_ = Attr<std::string>("calibration_data"); calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key"); engine_key_ = Attr<std::string>("engine_key");
calibration_engine_key_ = Attr<std::string>("calibration_engine_key");
predictor_id_ = Attr<int>("predictor_id"); predictor_id_ = Attr<int>("predictor_id");
auto params = Attr<std::vector<std::string>>("parameters"); auto params = Attr<std::vector<std::string>>("parameters");
...@@ -172,9 +174,11 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -172,9 +174,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
"Paddle TRT int8..."; "Paddle TRT int8...";
int runtime_batch = 1; int runtime_batch = 1;
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) { if (!Singleton<TRTCalibratorEngineManager>::Global().Has(
calibration_engine_key_)) {
TRTCalibratorEngine *calib_res = TRTCalibratorEngine *calib_res =
Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_); Singleton<TRTCalibratorEngineManager>::Global().Create(
calibration_engine_key_);
std::unordered_map<std::string, size_t> calib_buffers; std::unordered_map<std::string, size_t> calib_buffers;
for (auto &x : input_names_) { for (auto &x : input_names_) {
if (param_names_.count(x)) continue; if (param_names_.count(x)) continue;
...@@ -185,7 +189,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -185,7 +189,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
runtime_batch = t_shape[0]; runtime_batch = t_shape[0];
} }
calib_res->calib_.reset(new TRTInt8Calibrator( calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, engine_key_, dev_place)); calib_buffers, runtime_batch, calibration_engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() { calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(new TensorRTEngine( calib_res->engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, precision_mode_, max_batch_size_, workspace_size_, precision_mode_,
...@@ -198,7 +202,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -198,7 +202,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
TRTInt8Calibrator *temp_calibrator = TRTInt8Calibrator *temp_calibrator =
Singleton<TRTCalibratorEngineManager>::Global() Singleton<TRTCalibratorEngineManager>::Global()
.Get(engine_key_) .Get(calibration_engine_key_)
->calib_.get(); ->calib_.get();
std::unordered_map<std::string, void *> calib_data; std::unordered_map<std::string, void *> calib_data;
......
...@@ -102,6 +102,8 @@ TEST(TensorRTEngineOp, manual) { ...@@ -102,6 +102,8 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20)); engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20));
engine_op_desc.SetAttr("parameters", std::vector<std::string>({})); engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
engine_op_desc.SetAttr("engine_key", std::string("a_engine")); engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
engine_op_desc.SetAttr("calibration_engine_key",
std::string("a_calib_engine"));
engine_op_desc.SetAttr("predictor_id", 1); engine_op_desc.SetAttr("predictor_id", 1);
engine_op_desc.SetAttr("calibration_data", std::string("")); engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false)); engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
...@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("parameters", engine_op_desc.SetAttr("parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"})); std::vector<std::string>({"y0", "y1", "y2", "y3"}));
engine_op_desc.SetAttr("engine_key", std::string("b_engine")); engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
engine_op_desc.SetAttr("calibration_engine_key",
std::string("b_calib_engine"));
engine_op_desc.SetAttr("predictor_id", 1); engine_op_desc.SetAttr("predictor_id", 1);
engine_op_desc.SetAttr("calibration_data", std::string("")); engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false)); engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册