未验证 提交 cac9635a 编写于 作者: P Pei Yang 提交者: GitHub

[Paddle-TRT] Fix engine key in trt int8 calibration (#31513)

* fix engine key in trt int8 calibration

* fix unit test
上级 50ac7dbf
......@@ -86,7 +86,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::string &predictor_id,
const std::string &max_batch_size,
const std::string &precision,
const std::string &use_calib_mode) {
const bool for_calibration) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
......@@ -97,12 +97,13 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
engine_hash_key += "#";
}
engine_hash_key += predictor_id;
engine_hash_key += "#";
engine_hash_key += max_batch_size;
if (!for_calibration) {
engine_hash_key += "#";
engine_hash_key += max_batch_size;
}
engine_hash_key += "#";
engine_hash_key += precision;
engine_hash_key += "#";
engine_hash_key += use_calib_mode;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
VLOG(2) << "TRT engine hash key: " << engine_hash_key;
VLOG(2) << "TRT engine key: " << engine_key;
......@@ -258,24 +259,31 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// TODO(NHZlX)
// There are models with the same structure but the different parameters,
// when running in the 'use_serialize' mode, there is a bug.
// serialization is affected by max_batch_size, but calibration is not.
// So we use seperate engine keys in serialization and calibration.
auto engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0),
std::to_string(Get<int>("max_batch_size")),
std::to_string(static_cast<int>(precision_mode)),
std::to_string(static_cast<int>(use_calib_mode)));
std::to_string(static_cast<int>(precision_mode)), false);
auto calibration_engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0),
std::to_string(Get<int>("max_batch_size")),
std::to_string(static_cast<int>(precision_mode)), true);
auto predictor_id = Get<int>("predictor_id");
// Get "" when there is no cached calibration table data.
std::string calibration_data = "";
if (enable_int8 && use_calib_mode) {
calibration_data = GetTrtCalibTableData(
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
calibration_data =
GetTrtCalibTableData(Get<std::string>("model_opt_cache_dir"),
calibration_engine_key, enable_int8);
}
op_desc->SetAttr("calibration_data", calibration_data);
op_desc->SetAttr("enable_int8", enable_int8);
op_desc->SetAttr("enable_fp16", enable_fp16);
op_desc->SetAttr("use_calib_mode", use_calib_mode);
op_desc->SetAttr("engine_key", engine_key);
op_desc->SetAttr("calibration_engine_key", calibration_engine_key);
op_desc->SetAttr("predictor_id", predictor_id);
std::string trt_engine_serialized_data = "";
......
......@@ -1017,8 +1017,8 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
auto &block = inference_program_->Block(0);
for (auto &op_desc : block.AllOps()) {
if (op_desc->Type() == "tensorrt_engine") {
std::string engine_name =
BOOST_GET_CONST(std::string, op_desc->GetAttr("engine_key"));
std::string engine_name = BOOST_GET_CONST(
std::string, op_desc->GetAttr("calibration_engine_key"));
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
LOG(ERROR) << "You should run the predictor(with trt) on the real data "
"to generate calibration info";
......
......@@ -89,6 +89,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
bool use_calib_mode_;
std::string calibration_data_;
std::string engine_key_;
std::string calibration_engine_key_;
bool calibration_mode_;
int predictor_id_;
int device_id_;
......@@ -109,6 +110,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
use_calib_mode_ = Attr<bool>("use_calib_mode");
calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key");
calibration_engine_key_ = Attr<std::string>("calibration_engine_key");
predictor_id_ = Attr<int>("predictor_id");
auto params = Attr<std::vector<std::string>>("parameters");
......@@ -172,9 +174,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
"Paddle TRT int8...";
int runtime_batch = 1;
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) {
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(
calibration_engine_key_)) {
TRTCalibratorEngine *calib_res =
Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_);
Singleton<TRTCalibratorEngineManager>::Global().Create(
calibration_engine_key_);
std::unordered_map<std::string, size_t> calib_buffers;
for (auto &x : input_names_) {
if (param_names_.count(x)) continue;
......@@ -185,7 +189,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
runtime_batch = t_shape[0];
}
calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, engine_key_, dev_place));
calib_buffers, runtime_batch, calibration_engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, precision_mode_,
......@@ -198,7 +202,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
TRTInt8Calibrator *temp_calibrator =
Singleton<TRTCalibratorEngineManager>::Global()
.Get(engine_key_)
.Get(calibration_engine_key_)
->calib_.get();
std::unordered_map<std::string, void *> calib_data;
......
......@@ -102,6 +102,8 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20));
engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
engine_op_desc.SetAttr("calibration_engine_key",
std::string("a_calib_engine"));
engine_op_desc.SetAttr("predictor_id", 1);
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
......@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
engine_op_desc.SetAttr("calibration_engine_key",
std::string("b_calib_engine"));
engine_op_desc.SetAttr("predictor_id", 1);
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册