未验证 提交 9336dd3e 编写于 作者: J JingZhuangzhuang 提交者: GitHub

add trt int8 dynamic support (#44800)

* add trt int8 dynamic support

* just support trt7+

* just for trt7.1.3.a

* Update tensorrt_subgraph_pass.cc

* delete trt_engine when it is not in use
上级 210fa777
...@@ -40,21 +40,24 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -40,21 +40,24 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
auto with_dynamic_shape = Get<bool>("with_dynamic_shape"); auto with_dynamic_shape = Get<bool>("with_dynamic_shape");
auto teller = [&](const framework::ir::Node *node) { auto teller = [&](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op()) return false; if (!node->IsOp() || !node->Op()) return false;
if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(), if (find(trt_disabled_ops.begin(),
trt_disabled_ops.end(),
node->Op()->Type()) != trt_disabled_ops.end()) { node->Op()->Type()) != trt_disabled_ops.end()) {
VLOG(3) << node->Op()->Type().c_str() VLOG(3) << node->Op()->Type().c_str()
<< " is diabled by config in TensorRT"; << " is diabled by config in TensorRT";
return false; return false;
} }
bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8, bool is_ok = tensorrt::OpTeller::Global().Tell(
with_dynamic_shape); node, no_calib_int8, with_dynamic_shape);
if (!is_ok) if (!is_ok)
VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT"; VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT";
return is_ok; return is_ok;
}; };
framework::ir::SubGraphFuser fuser( framework::ir::SubGraphFuser fuser(
graph, teller, Get<int>("min_subgraph_size") /*min subgraph size*/, graph,
teller,
Get<int>("min_subgraph_size") /*min subgraph size*/,
"tensorrt_engine"); "tensorrt_engine");
fuser(); fuser();
...@@ -102,26 +105,27 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, ...@@ -102,26 +105,27 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
engine_hash_key += "#"; engine_hash_key += "#";
} }
engine_hash_key += predictor_id; engine_hash_key += predictor_id;
if (!for_calibration) {
engine_hash_key += "#";
engine_hash_key += max_batch_size;
}
engine_hash_key += "#"; engine_hash_key += "#";
engine_hash_key += precision; engine_hash_key += precision;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key)); auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
if (!for_calibration) {
engine_key += max_batch_size;
}
VLOG(2) << "TRT engine hash key: " << engine_hash_key; VLOG(2) << "TRT engine hash key: " << engine_hash_key;
VLOG(2) << "TRT engine key: " << engine_key; VLOG(2) << "TRT engine key: " << engine_key;
return engine_key; return engine_key;
} }
void TensorRtSubgraphPass::CreateTensorRTOp( void TensorRtSubgraphPass::CreateTensorRTOp(
framework::ir::Node *node, framework::ir::Graph *graph, framework::ir::Node *node,
framework::ir::Graph *graph,
const std::vector<std::string> &graph_params, const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const { std::vector<std::string> *repetitive_params) const {
auto *op_desc = node->Op(); auto *op_desc = node->Op();
auto &subgraph = *framework::ir::Agent(node).subgraph(); auto &subgraph = *framework::ir::Agent(node).subgraph();
PADDLE_ENFORCE_EQ(subgraph.empty(), false, PADDLE_ENFORCE_EQ(subgraph.empty(),
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The subgraph should not be empty.")); "The subgraph should not be empty."));
...@@ -208,7 +212,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -208,7 +212,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if (trt_tuned_dynamic_shape) { if (trt_tuned_dynamic_shape) {
VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path; VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path;
inference::DeserializeShapeRangeInfo(shape_range_info_path, inference::DeserializeShapeRangeInfo(shape_range_info_path,
&min_input_shape, &max_input_shape, &min_input_shape,
&max_input_shape,
&opt_input_shape); &opt_input_shape);
} }
...@@ -224,9 +229,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -224,9 +229,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// input of a OP, but also the output of a Op, there will be problems. // input of a OP, but also the output of a Op, there will be problems.
// So we have to rename the variable in the subgraph to make sure // So we have to rename the variable in the subgraph to make sure
// it is either an OP's input or an OP's output. // it is either an OP's input or an OP's output.
RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id, RenameAndGetOutputs(subgraph_nodes,
&output_names_with_id, &output_names, &output_name_map, &block_desc,
graph_var_map, !enable_int8); input_names_with_id,
&output_names_with_id,
&output_names,
&output_name_map,
graph_var_map,
!enable_int8);
// When tensorrt engine runs at the end of the operation, // When tensorrt engine runs at the end of the operation,
// output_mapping help us copy the data from the renamed ITensor // output_mapping help us copy the data from the renamed ITensor
...@@ -234,17 +244,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -234,17 +244,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::vector<std::string> output_mapping; std::vector<std::string> output_mapping;
std::vector<int> renamed_output_dims; std::vector<int> renamed_output_dims;
for (auto name : output_names) { for (auto name : output_names) {
PADDLE_ENFORCE_NE(output_name_map.count(name), 0, PADDLE_ENFORCE_NE(output_name_map.count(name),
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The output_name_map should have %s", name)); "The output_name_map should have %s", name));
output_mapping.push_back(output_name_map[name]); output_mapping.push_back(output_name_map[name]);
renamed_output_dims.push_back(origin_name_output_dims[name]); renamed_output_dims.push_back(origin_name_output_dims[name]);
} }
PADDLE_ENFORCE_EQ(output_mapping.empty(), false, PADDLE_ENFORCE_EQ(output_mapping.empty(),
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The output_mapping should not be empty.")); "The output_mapping should not be empty."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
!block_desc.Proto()->vars().empty(), true, !block_desc.Proto()->vars().empty(),
true,
platform::errors::PreconditionNotMet("the block has no var-desc")); platform::errors::PreconditionNotMet("the block has no var-desc"));
// Set attrs // Set attrs
...@@ -287,14 +300,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -287,14 +300,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// when running in the 'use_serialize' mode, there is a bug. // when running in the 'use_serialize' mode, there is a bug.
// serialization is affected by max_batch_size, but calibration is not. // serialization is affected by max_batch_size, but calibration is not.
// So we use seperate engine keys in serialization and calibration. // So we use seperate engine keys in serialization and calibration.
auto engine_key = GenerateEngineKey( auto engine_key =
input_names_with_id, output_names_with_id, std::to_string(0), GenerateEngineKey(input_names_with_id,
std::to_string(max_batch_size), output_names_with_id,
std::to_string(static_cast<int>(precision_mode)), false); std::to_string(0),
std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)),
false);
auto calibration_engine_key = auto calibration_engine_key =
GenerateEngineKey(input_names_with_id, output_names_with_id, GenerateEngineKey(input_names_with_id,
std::to_string(0), std::to_string(max_batch_size), output_names_with_id,
std::to_string(static_cast<int>(precision_mode)), true); std::to_string(0),
std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)),
true);
auto predictor_id = Get<int>("predictor_id"); auto predictor_id = Get<int>("predictor_id");
// Get "" when there is no cached calibration table data. // Get "" when there is no cached calibration table data.
...@@ -302,7 +321,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -302,7 +321,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if (enable_int8 && use_calib_mode) { if (enable_int8 && use_calib_mode) {
calibration_data = calibration_data =
GetTrtCalibTableData(Get<std::string>("model_opt_cache_dir"), GetTrtCalibTableData(Get<std::string>("model_opt_cache_dir"),
calibration_engine_key, enable_int8); calibration_engine_key,
enable_int8);
} }
op_desc->SetAttr("calibration_data", calibration_data); op_desc->SetAttr("calibration_data", calibration_data);
op_desc->SetAttr("enable_int8", enable_int8); op_desc->SetAttr("enable_int8", enable_int8);
...@@ -325,13 +345,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -325,13 +345,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// calibration table data. // calibration table data.
bool calibration_mode = bool calibration_mode =
(enable_int8 && calibration_data.size() == 0 && use_calib_mode); (enable_int8 && calibration_data.size() == 0 && use_calib_mode);
if (calibration_mode) {
// Calibration mode is the process of generating the int8 calibration table data.
return;
}
std::copy(params_not_shared.begin(), params_not_shared.end(), if (!calibration_mode) {
std::back_inserter(*repetitive_params)); std::copy(params_not_shared.begin(),
params_not_shared.end(),
std::back_inserter(*repetitive_params));
}
// Check trt version for dynamic shape input. // Check trt version for dynamic shape input.
...@@ -368,10 +387,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -368,10 +387,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16"); bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
tensorrt::TensorRTEngine *trt_engine = tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global() inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key + std::to_string(predictor_id), max_batch_size, .Create(engine_key + std::to_string(predictor_id),
Get<int>("workspace_size"), precision_mode, calibrator.get(), max_batch_size,
Get<int>("gpu_device_id"), min_input_shape, max_input_shape, Get<int>("workspace_size"),
opt_input_shape, disable_trt_plugin_fp16); precision_mode,
calibrator.get(),
Get<int>("gpu_device_id"),
min_input_shape,
max_input_shape,
opt_input_shape,
disable_trt_plugin_fp16);
trt_engine->SetUseOSS(Get<bool>("use_oss")); trt_engine->SetUseOSS(Get<bool>("use_oss"));
trt_engine->SetWithInterleaved(Get<bool>("with_interleaved")); trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
trt_engine->SetUseDLA(Get<bool>("trt_use_dla")); trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
...@@ -384,6 +409,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -384,6 +409,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
(graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass))); graph->Has(framework::ir::kMultiheadMatmulPass)));
if (calibration_mode) {
// Calibration mode is the process of generating the int8 calibration table data.
return;
}
if (use_static_engine) { if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData( trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key); Get<std::string>("model_opt_cache_dir"), engine_key);
...@@ -408,9 +438,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -408,9 +438,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::unordered_set<std::string> param_set(params.begin(), params.end()); std::unordered_set<std::string> param_set(params.begin(), params.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global() inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine( .ConvertBlockToTRTEngine(
&block_desc_temp, *scope, &block_desc_temp,
*scope,
std::vector<std::string>(input_names.begin(), input_names.end()), std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine); param_set,
output_mapping,
trt_engine);
if (use_static_engine) { if (use_static_engine) {
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize(); nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
......
...@@ -254,6 +254,11 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -254,6 +254,11 @@ void TensorRTEngine::FreezeNetwork() {
nvinfer1::OptProfileSelector::kOPT, nvinfer1::OptProfileSelector::kOPT,
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true)); Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
} }
#if IS_TRT_VERSION_GE(7130)
if (enable_int8) {
infer_builder_config_->setCalibrationProfile(optim_profiles_[i]);
}
#endif
infer_builder_config_->addOptimizationProfile(optim_profiles_[i]); infer_builder_config_->addOptimizationProfile(optim_profiles_[i]);
} }
if (WithFp16() && disable_trt_plugin_fp16()) { if (WithFp16() && disable_trt_plugin_fp16()) {
......
...@@ -614,6 +614,10 @@ class TensorRTEngine { ...@@ -614,6 +614,10 @@ class TensorRTEngine {
void SetUseInspector(bool use_inspector) { use_inspector_ = use_inspector; } void SetUseInspector(bool use_inspector) { use_inspector_ = use_inspector; }
// Sets the INT8 calibrator pointer held by this engine.
//
// Only the raw pointer is stored in calibrator_; nothing is copied and no
// validation is performed here.
// NOTE(review): the visible call site passes a pointer obtained from a
// std::unique_ptr that the caller keeps, so the engine presumably does not
// take ownership and the calibrator must outlive the engine's use of it --
// confirm against the class's other members.
void SetCalibrator(TRTInt8Calibrator* calibrator) {
calibrator_ = calibrator;
}
private: private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated, // Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling // ensure that the thread is associated with the correct device by calling
......
...@@ -52,23 +52,28 @@ namespace operators { ...@@ -52,23 +52,28 @@ namespace operators {
using inference::Singleton; using inference::Singleton;
using inference::tensorrt::TensorRTEngine; using inference::tensorrt::TensorRTEngine;
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine; using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager; using inference::tensorrt::TRTCalibratorEngineManager;
using inference::tensorrt::TRTInt8Calibrator;
static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape, static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
std::vector<int64_t> model_input_shape) { std::vector<int64_t> model_input_shape) {
auto comma_fold = [](std::string a, int b) { auto comma_fold = [](std::string a, int b) {
return std::move(a) + ", " + std::to_string(b); return std::move(a) + ", " + std::to_string(b);
}; };
std::string model_input_shape_str = std::accumulate( std::string model_input_shape_str =
std::next(model_input_shape.begin()), model_input_shape.end(), std::accumulate(std::next(model_input_shape.begin()),
std::to_string(model_input_shape[0]), comma_fold); model_input_shape.end(),
std::string runtime_input_shape_str = std::accumulate( std::to_string(model_input_shape[0]),
std::next(runtime_input_shape.begin()), runtime_input_shape.end(), comma_fold);
std::to_string(runtime_input_shape[0]), comma_fold); std::string runtime_input_shape_str =
std::accumulate(std::next(runtime_input_shape.begin()),
runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]),
comma_fold);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
model_input_shape == runtime_input_shape, true, model_input_shape == runtime_input_shape,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input shapes are inconsistent with the model. Expect [%s] in " "Input shapes are inconsistent with the model. Expect [%s] in "
"model description, but got [%s] in runtime. TRT 5 " "model description, but got [%s] in runtime. TRT 5 "
...@@ -76,7 +81,8 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape, ...@@ -76,7 +81,8 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
"does not support dynamic input shapes. Please check and " "does not support dynamic input shapes. Please check and "
"modify " "modify "
"your input shapes.", "your input shapes.",
model_input_shape_str, runtime_input_shape_str)); model_input_shape_str,
runtime_input_shape_str));
} }
static paddle::experimental::DataType TRT2FluidDataType( static paddle::experimental::DataType TRT2FluidDataType(
...@@ -102,7 +108,8 @@ static paddle::experimental::DataType TRT2FluidDataType( ...@@ -102,7 +108,8 @@ static paddle::experimental::DataType TRT2FluidDataType(
} }
static void RuntimeDynamicShapeCheck( static void RuntimeDynamicShapeCheck(
const std::string &x, const std::vector<int32_t> &runtime_input_shape, const std::string &x,
const std::vector<int32_t> &runtime_input_shape,
const std::vector<int32_t> &min_input_shape, const std::vector<int32_t> &min_input_shape,
const std::vector<int32_t> &max_input_shape) { const std::vector<int32_t> &max_input_shape) {
// PADDLE_ENFORCE_EQ( // PADDLE_ENFORCE_EQ(
...@@ -111,10 +118,10 @@ static void RuntimeDynamicShapeCheck( ...@@ -111,10 +118,10 @@ static void RuntimeDynamicShapeCheck(
// "TRT engine runtime input %s dims size(%d) inconsistent " // "TRT engine runtime input %s dims size(%d) inconsistent "
// "with the dynamic shape size(%d)", // "with the dynamic shape size(%d)",
// x, runtime_input_shape.size(), min_input_shape.size())); // x, runtime_input_shape.size(), min_input_shape.size()));
auto is_input_shape_valid = [&]( auto is_input_shape_valid =
const std::vector<int32_t> &runtime_input_shape, [&](const std::vector<int32_t> &runtime_input_shape,
const std::vector<int32_t> &min_input_shape, const std::vector<int32_t> &min_input_shape,
const std::vector<int32_t> &max_input_shape) -> bool { const std::vector<int32_t> &max_input_shape) -> bool {
for (size_t i = 0; i < runtime_input_shape.size(); i++) { for (size_t i = 0; i < runtime_input_shape.size(); i++) {
if (runtime_input_shape[i] <= max_input_shape[i] && if (runtime_input_shape[i] <= max_input_shape[i] &&
runtime_input_shape[i] >= min_input_shape[i]) { runtime_input_shape[i] >= min_input_shape[i]) {
...@@ -128,17 +135,23 @@ static void RuntimeDynamicShapeCheck( ...@@ -128,17 +135,23 @@ static void RuntimeDynamicShapeCheck(
auto comma_fold = [](std::string a, int b) { auto comma_fold = [](std::string a, int b) {
return std::move(a) + ", " + std::to_string(b); return std::move(a) + ", " + std::to_string(b);
}; };
std::string runtime_input_shape_str = std::accumulate( std::string runtime_input_shape_str =
std::next(runtime_input_shape.begin()), runtime_input_shape.end(), std::accumulate(std::next(runtime_input_shape.begin()),
std::to_string(runtime_input_shape[0]), comma_fold); runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]),
comma_fold);
std::string min_input_shape_str = std::string min_input_shape_str =
std::accumulate(std::next(min_input_shape.begin()), min_input_shape.end(), std::accumulate(std::next(min_input_shape.begin()),
std::to_string(min_input_shape[0]), comma_fold); min_input_shape.end(),
std::to_string(min_input_shape[0]),
comma_fold);
std::string max_input_shape_str = std::string max_input_shape_str =
std::accumulate(std::next(max_input_shape.begin()), max_input_shape.end(), std::accumulate(std::next(max_input_shape.begin()),
std::to_string(max_input_shape[0]), comma_fold); max_input_shape.end(),
PADDLE_ENFORCE_EQ(is_input_shape_valid(runtime_input_shape, min_input_shape, std::to_string(max_input_shape[0]),
max_input_shape), comma_fold);
PADDLE_ENFORCE_EQ(is_input_shape_valid(
runtime_input_shape, min_input_shape, max_input_shape),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"TRT runtime input shape of %s is invalid. Expect " "TRT runtime input shape of %s is invalid. Expect "
...@@ -146,7 +159,9 @@ static void RuntimeDynamicShapeCheck( ...@@ -146,7 +159,9 @@ static void RuntimeDynamicShapeCheck(
"configured in SetTRTDynamicShapeInfo()," "configured in SetTRTDynamicShapeInfo(),"
"but got runtime input shape = [%s], min input shape = " "but got runtime input shape = [%s], min input shape = "
"[%s], max input shape = [%s].", "[%s], max input shape = [%s].",
x, runtime_input_shape_str, min_input_shape_str, x,
runtime_input_shape_str,
min_input_shape_str,
max_input_shape_str)); max_input_shape_str));
} }
...@@ -158,7 +173,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -158,7 +173,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
mutable TensorRTEngine *trt_engine_{nullptr}; mutable TensorRTEngine *trt_engine_{nullptr};
int max_batch_size_; int max_batch_size_;
int workspace_size_; int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_; mutable std::unique_ptr<TRTInt8Calibrator> calibrator_;
bool enable_int8_; bool enable_int8_;
bool enable_fp16_; bool enable_fp16_;
bool use_calib_mode_; bool use_calib_mode_;
...@@ -260,7 +275,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -260,7 +275,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global() inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Has(engine_key_ + std::to_string(predictor_id_)); .Has(engine_key_ + std::to_string(predictor_id_));
if (!calibration_mode_ && has_engine) { if (has_engine) {
trt_engine_ = trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global() inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_key_ + std::to_string(predictor_id_)); .Get(engine_key_ + std::to_string(predictor_id_));
...@@ -287,8 +302,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -287,8 +302,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("output_name_mapping"); Attr<std::vector<std::string>>("output_name_mapping");
inference::Singleton<inference::tensorrt::OpConverter>::Global() inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(&block_desc, scope, inputs, param_names_, .ConvertBlockToTRTEngine(
outputs, engine); &block_desc, scope, inputs, param_names_, outputs, engine);
} }
protected: protected:
...@@ -304,11 +319,14 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -304,11 +319,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto *trt_engine = GetEngine(scope, dev_place);
if (calibration_mode_ == true) { if (calibration_mode_ == true) {
RunCalibration(scope, dev_place); RunCalibration(scope, dev_place, trt_engine);
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.DeleteKey(engine_key_ + std::to_string(predictor_id_));
return; return;
} }
auto *trt_engine = GetEngine(scope, dev_place);
if (use_inspector_) { if (use_inspector_) {
trt_engine->GetEngineInfo(); trt_engine->GetEngineInfo();
} }
...@@ -331,15 +349,19 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -331,15 +349,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine->max_input_shape(); trt_engine->max_input_shape();
for (auto &x : runtime_input_names_) { for (auto &x : runtime_input_names_) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
min_input_shape.count(x), true, min_input_shape.count(x),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input %s not found in TRT engine min_input_shape.", x)); "Input %s not found in TRT engine min_input_shape.", x));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
max_input_shape.count(x), true, max_input_shape.count(x),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input %s not found in TRT engine max_input_shape.", x)); "Input %s not found in TRT engine max_input_shape.", x));
RuntimeDynamicShapeCheck(x, runtime_input_shape[x], RuntimeDynamicShapeCheck(x,
min_input_shape[x], max_input_shape[x]); runtime_input_shape[x],
min_input_shape[x],
max_input_shape[x]);
} }
} else { } else {
// compare runtime_input_shape and trt_engine dynamic shapes. // compare runtime_input_shape and trt_engine dynamic shapes.
...@@ -357,13 +379,18 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -357,13 +379,18 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (anc == nullptr) { if (anc == nullptr) {
anc = &scope; anc = &scope;
} }
if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
trt_engine->SetCalibrator(calibrator_.get());
}
PrepareTRTEngine(*anc, trt_engine); PrepareTRTEngine(*anc, trt_engine);
// update shape_range_info_pbtxt // update shape_range_info_pbtxt
if (!shape_range_info_path_.empty()) { if (!shape_range_info_path_.empty()) {
inference::UpdateShapeRangeInfo( inference::UpdateShapeRangeInfo(shape_range_info_path_,
shape_range_info_path_, trt_engine->min_input_shape(), trt_engine->min_input_shape(),
trt_engine->max_input_shape(), trt_engine->optim_input_shape(), trt_engine->max_input_shape(),
shape_changed_name); trt_engine->optim_input_shape(),
shape_changed_name);
} }
if (use_static_engine_) { if (use_static_engine_) {
...@@ -387,7 +414,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -387,7 +414,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
} }
void RunCalibration(const framework::Scope &scope, void RunCalibration(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place,
TensorRTEngine *trt_engine) const {
// This process will builds a 32-bit trt engine, runs it on the calibration // This process will builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram for each // set, and records a histogram for each
// tensor of the distribution of activation values. // tensor of the distribution of activation values.
...@@ -412,9 +440,15 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -412,9 +440,15 @@ class TensorRTEngineOp : public framework::OperatorBase {
calib_res->calib_.reset(new TRTInt8Calibrator( calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, calibration_engine_key_, dev_place)); calib_buffers, runtime_batch, calibration_engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() { calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(new TensorRTEngine( calib_res->engine_.reset(
max_batch_size_, workspace_size_, precision_mode_, new TensorRTEngine(max_batch_size_,
calib_res->calib_.get(), dev_place.device)); workspace_size_,
precision_mode_,
calib_res->calib_.get(),
dev_place.device,
trt_engine->min_input_shape(),
trt_engine->max_input_shape(),
trt_engine->optim_input_shape()));
VLOG(3) << "start the calib trt engine thread"; VLOG(3) << "start the calib trt engine thread";
PrepareTRTEngine(scope, calib_res->engine_.get()); PrepareTRTEngine(scope, calib_res->engine_.get());
})); }));
...@@ -436,7 +470,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -436,7 +470,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
RunNativeImpl(scope, dev_place); RunNativeImpl(scope, dev_place);
} }
void RunTrt(const framework::Scope &scope, const platform::Place &dev_place, void RunTrt(const framework::Scope &scope,
const platform::Place &dev_place,
TensorRTEngine *engine) const { TensorRTEngine *engine) const {
int runtime_batch = -1; int runtime_batch = -1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
...@@ -445,7 +480,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -445,7 +480,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
runtime_input_names_.empty(), false, runtime_input_names_.empty(),
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"TensorRT engine needs at least one input, but no input is found. " "TensorRT engine needs at least one input, but no input is found. "
"Please check if you set the input correctly.")); "Please check if you set the input correctly."));
...@@ -488,13 +524,15 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -488,13 +524,15 @@ class TensorRTEngineOp : public framework::OperatorBase {
const int bind_index = const int bind_index =
engine->engine()->getBindingIndex(x.c_str()) + binding_offset; engine->engine()->getBindingIndex(x.c_str()) + binding_offset;
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
bind_index, num_bindings, bind_index,
num_bindings,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong TRT engine input binding index. Expected The " "Wrong TRT engine input binding index. Expected The "
"binding index of TRT engine input to be less than " "binding index of TRT engine input to be less than "
"the number of inputs and outputs. Received binding " "the number of inputs and outputs. Received binding "
"index=%d >= total inputs and outputs=%d", "index=%d >= total inputs and outputs=%d",
bind_index, num_bindings)); bind_index,
num_bindings));
if (!engine->with_dynamic_shape()) { if (!engine->with_dynamic_shape()) {
// check if the input shapes are consistent with model. // check if the input shapes are consistent with model.
if (HasAttr(x + "_shape")) { if (HasAttr(x + "_shape")) {
...@@ -507,7 +545,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -507,7 +545,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape); RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape);
if (runtime_batch != -1) { if (runtime_batch != -1) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
runtime_batch, t_shape[0], runtime_batch,
t_shape[0],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Inputs of trt subgraphs has different batchsize. " "Inputs of trt subgraphs has different batchsize. "
"It's not allowed in static shape mode. " "It's not allowed in static shape mode. "
...@@ -589,12 +628,14 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -589,12 +628,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>(); auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
fluid_t->Resize(phi::make_ddim(ddim)); fluid_t->Resize(phi::make_ddim(ddim));
PADDLE_ENFORCE_LT(bind_index, num_bindings, PADDLE_ENFORCE_LT(bind_index,
num_bindings,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The binding index in TRT engine should be less " "The binding index in TRT engine should be less "
"than the number of bindings, but got binding " "than the number of bindings, but got binding "
"index = %d, number of bindings = %d.", "index = %d, number of bindings = %d.",
bind_index, num_bindings)); bind_index,
num_bindings));
auto trt_type = engine->engine()->getBindingDataType(bind_index); auto trt_type = engine->engine()->getBindingDataType(bind_index);
// get adr and set type // get adr and set type
buffers[bind_index] = static_cast<void *>( buffers[bind_index] = static_cast<void *>(
...@@ -604,7 +645,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -604,7 +645,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (!engine->with_dynamic_shape()) { if (!engine->with_dynamic_shape()) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
runtime_batch, max_batch_size_, runtime_batch,
max_batch_size_,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The runtime batch size (%d) is greater than the max batch " "The runtime batch size (%d) is greater than the max batch "
"size(%d).\n" "size(%d).\n"
...@@ -623,7 +665,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -623,7 +665,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
"\tThe min_subgraph_size shouble to be greater than the number " "\tThe min_subgraph_size shouble to be greater than the number "
"of " "of "
"nodes in the inconsistent subgraph.\n", "nodes in the inconsistent subgraph.\n",
runtime_batch, max_batch_size_)); runtime_batch,
max_batch_size_));
} }
// Execute the engine. // Execute the engine.
engine->Execute(runtime_batch, &buffers, stream); engine->Execute(runtime_batch, &buffers, stream);
...@@ -635,9 +678,14 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -635,9 +678,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine_ = trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global() inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key_ + std::to_string(predictor_id_), .Create(engine_key_ + std::to_string(predictor_id_),
max_batch_size_, workspace_size_, precision_mode_, max_batch_size_,
calibrator_.get(), device_id_, min_input_shape_, workspace_size_,
max_input_shape_, opt_input_shape_); precision_mode_,
calibrator_.get(),
device_id_,
min_input_shape_,
max_input_shape_,
opt_input_shape_);
PrepareTRTEngine(scope, trt_engine_); PrepareTRTEngine(scope, trt_engine_);
} }
return trt_engine_; return trt_engine_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册