Unverified Commit 9336dd3e authored by JingZhuangzhuang, committed by GitHub

add trt int8 dynamic support (#44800)

* add trt int8 dynamic support

* just support trt7+

* just for trt7.1.3.a

* Update tensorrt_subgraph_pass.cc

* delete trt_engine when it is not used
Parent 210fa777
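For context, a minimal sketch of how this int8 + dynamic-shape path is exercised from the Paddle inference C++ API. The model directory, input name, and shape ranges below are placeholders chosen for illustration, not values from this commit:

#include "paddle_inference_api.h"
#include <map>
#include <vector>

int main() {
  paddle_infer::Config config;
  config.SetModel("./model");  // placeholder model directory
  config.EnableUseGpu(1000 /* init pool, MB */, 0 /* GPU id */);
  // Precision kInt8 with use_calib_mode=true takes the calibration
  // branch this commit modifies.
  config.EnableTensorRtEngine(1 << 30 /* workspace */, 1 /* max_batch */,
                              3 /* min_subgraph_size */,
                              paddle_infer::PrecisionType::kInt8,
                              false /* use_static */, true /* use_calib_mode */);
  // Dynamic-shape ranges; "x" is a placeholder input name.
  std::map<std::string, std::vector<int>> min_shape{{"x", {1, 3, 224, 224}}};
  std::map<std::string, std::vector<int>> max_shape{{"x", {4, 3, 224, 224}}};
  std::map<std::string, std::vector<int>> opt_shape{{"x", {1, 3, 224, 224}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}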
......@@ -40,21 +40,24 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
auto with_dynamic_shape = Get<bool>("with_dynamic_shape");
auto teller = [&](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op()) return false;
if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
if (find(trt_disabled_ops.begin(),
trt_disabled_ops.end(),
node->Op()->Type()) != trt_disabled_ops.end()) {
VLOG(3) << node->Op()->Type().c_str()
<< " is diabled by config in TensorRT";
return false;
}
bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
with_dynamic_shape);
bool is_ok = tensorrt::OpTeller::Global().Tell(
node, no_calib_int8, with_dynamic_shape);
if (!is_ok)
VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT";
return is_ok;
};
framework::ir::SubGraphFuser fuser(
graph, teller, Get<int>("min_subgraph_size") /*min subgraph size*/,
graph,
teller,
Get<int>("min_subgraph_size") /*min subgraph size*/,
"tensorrt_engine");
fuser();
......@@ -102,26 +105,27 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
engine_hash_key += "#";
}
engine_hash_key += predictor_id;
if (!for_calibration) {
engine_hash_key += "#";
engine_hash_key += max_batch_size;
}
engine_hash_key += "#";
engine_hash_key += precision;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
if (!for_calibration) {
engine_key += max_batch_size;
}
VLOG(2) << "TRT engine hash key: " << engine_hash_key;
VLOG(2) << "TRT engine key: " << engine_key;
return engine_key;
}
void TensorRtSubgraphPass::CreateTensorRTOp(
framework::ir::Node *node, framework::ir::Graph *graph,
framework::ir::Node *node,
framework::ir::Graph *graph,
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const {
auto *op_desc = node->Op();
auto &subgraph = *framework::ir::Agent(node).subgraph();
PADDLE_ENFORCE_EQ(subgraph.empty(), false,
PADDLE_ENFORCE_EQ(subgraph.empty(),
false,
platform::errors::PreconditionNotMet(
"The subgraph should not be empty."));
......@@ -208,7 +212,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if (trt_tuned_dynamic_shape) {
VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path;
inference::DeserializeShapeRangeInfo(shape_range_info_path,
&min_input_shape, &max_input_shape,
&min_input_shape,
&max_input_shape,
&opt_input_shape);
}
......@@ -224,9 +229,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// input of an OP, but also the output of an OP, there will be problems.
// So we have to rename the variable in the subgraph to make sure
// it is either an OP's input or an OP's output.
RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id,
&output_names_with_id, &output_names, &output_name_map,
graph_var_map, !enable_int8);
RenameAndGetOutputs(subgraph_nodes,
&block_desc,
input_names_with_id,
&output_names_with_id,
&output_names,
&output_name_map,
graph_var_map,
!enable_int8);
// When the tensorrt engine runs at the end of the operation,
// output_mapping helps us copy the data from the renamed ITensor
......@@ -234,17 +244,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::vector<std::string> output_mapping;
std::vector<int> renamed_output_dims;
for (auto name : output_names) {
PADDLE_ENFORCE_NE(output_name_map.count(name), 0,
PADDLE_ENFORCE_NE(output_name_map.count(name),
0,
platform::errors::PreconditionNotMet(
"The output_name_map should have %s", name));
output_mapping.push_back(output_name_map[name]);
renamed_output_dims.push_back(origin_name_output_dims[name]);
}
PADDLE_ENFORCE_EQ(output_mapping.empty(), false,
PADDLE_ENFORCE_EQ(output_mapping.empty(),
false,
platform::errors::PreconditionNotMet(
"The output_mapping should not be empty."));
PADDLE_ENFORCE_EQ(
!block_desc.Proto()->vars().empty(), true,
!block_desc.Proto()->vars().empty(),
true,
platform::errors::PreconditionNotMet("the block has no var-desc"));
// Set attrs
......@@ -287,14 +300,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// when running in the 'use_serialize' mode, there is a bug.
// serialization is affected by max_batch_size, but calibration is not.
// So we use separate engine keys in serialization and calibration.
auto engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0),
auto engine_key =
GenerateEngineKey(input_names_with_id,
output_names_with_id,
std::to_string(0),
std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)), false);
std::to_string(static_cast<int>(precision_mode)),
false);
auto calibration_engine_key =
GenerateEngineKey(input_names_with_id, output_names_with_id,
std::to_string(0), std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)), true);
GenerateEngineKey(input_names_with_id,
output_names_with_id,
std::to_string(0),
std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)),
true);
auto predictor_id = Get<int>("predictor_id");
// Get "" when there is no cached calibration table data.
......@@ -302,7 +321,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if (enable_int8 && use_calib_mode) {
calibration_data =
GetTrtCalibTableData(Get<std::string>("model_opt_cache_dir"),
calibration_engine_key, enable_int8);
calibration_engine_key,
enable_int8);
}
op_desc->SetAttr("calibration_data", calibration_data);
op_desc->SetAttr("enable_int8", enable_int8);
......@@ -325,13 +345,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// calibration table data.
bool calibration_mode =
(enable_int8 && calibration_data.size() == 0 && use_calib_mode);
if (calibration_mode) {
// calibration mode means the process of generating int8 calibration table data.
return;
}
std::copy(params_not_shared.begin(), params_not_shared.end(),
if (!calibration_mode) {
std::copy(params_not_shared.begin(),
params_not_shared.end(),
std::back_inserter(*repetitive_params));
}
// Check trt version for dynamic shape input.
......@@ -368,10 +387,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key + std::to_string(predictor_id), max_batch_size,
Get<int>("workspace_size"), precision_mode, calibrator.get(),
Get<int>("gpu_device_id"), min_input_shape, max_input_shape,
opt_input_shape, disable_trt_plugin_fp16);
.Create(engine_key + std::to_string(predictor_id),
max_batch_size,
Get<int>("workspace_size"),
precision_mode,
calibrator.get(),
Get<int>("gpu_device_id"),
min_input_shape,
max_input_shape,
opt_input_shape,
disable_trt_plugin_fp16);
trt_engine->SetUseOSS(Get<bool>("use_oss"));
trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
......@@ -384,6 +409,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
(graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)));
if (calibration_mode) {
// calibration mode means the process of generating int8 calibration table data.
return;
}
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key);
......@@ -408,9 +438,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::unordered_set<std::string> param_set(params.begin(), params.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(
&block_desc_temp, *scope,
&block_desc_temp,
*scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine);
param_set,
output_mapping,
trt_engine);
if (use_static_engine) {
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
......
......@@ -254,6 +254,11 @@ void TensorRTEngine::FreezeNetwork() {
nvinfer1::OptProfileSelector::kOPT,
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
#if IS_TRT_VERSION_GE(7130)
if (enable_int8) {
infer_builder_config_->setCalibrationProfile(optim_profiles_[i]);
}
#endif
infer_builder_config_->addOptimizationProfile(optim_profiles_[i]);
}
if (WithFp16() && disable_trt_plugin_fp16()) {
......
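For reference, here is the raw TensorRT pattern behind the new setCalibrationProfile call above, sketched outside Paddle's wrappers. setCalibrationProfile requires TensorRT >= 7.1; the input name and dimensions are placeholders:

#include <NvInfer.h>

void ConfigureInt8DynamicShape(nvinfer1::IBuilder *builder,
                               nvinfer1::IBuilderConfig *config) {
  nvinfer1::IOptimizationProfile *profile = builder->createOptimizationProfile();
  profile->setDimensions("x", nvinfer1::OptProfileSelector::kMIN,
                         nvinfer1::Dims4{1, 3, 224, 224});
  profile->setDimensions("x", nvinfer1::OptProfileSelector::kOPT,
                         nvinfer1::Dims4{1, 3, 224, 224});
  profile->setDimensions("x", nvinfer1::OptProfileSelector::kMAX,
                         nvinfer1::Dims4{4, 3, 224, 224});
  config->addOptimizationProfile(profile);
  config->setFlag(nvinfer1::BuilderFlag::kINT8);
  // Without a calibration profile, int8 calibration of a dynamic-shape
  // network has no input dimensions to calibrate with.
  config->setCalibrationProfile(profile);
}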
......@@ -614,6 +614,10 @@ class TensorRTEngine {
void SetUseInspector(bool use_inspector) { use_inspector_ = use_inspector; }
void SetCalibrator(TRTInt8Calibrator* calibrator) {
calibrator_ = calibrator;
}
private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated;
// ensure that the thread is associated with the correct device by calling
......
......@@ -52,23 +52,28 @@ namespace operators {
using inference::Singleton;
using inference::tensorrt::TensorRTEngine;
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
using inference::tensorrt::TRTInt8Calibrator;
static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
std::vector<int64_t> model_input_shape) {
auto comma_fold = [](std::string a, int b) {
return std::move(a) + ", " + std::to_string(b);
};
std::string model_input_shape_str = std::accumulate(
std::next(model_input_shape.begin()), model_input_shape.end(),
std::to_string(model_input_shape[0]), comma_fold);
std::string runtime_input_shape_str = std::accumulate(
std::next(runtime_input_shape.begin()), runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]), comma_fold);
std::string model_input_shape_str =
std::accumulate(std::next(model_input_shape.begin()),
model_input_shape.end(),
std::to_string(model_input_shape[0]),
comma_fold);
std::string runtime_input_shape_str =
std::accumulate(std::next(runtime_input_shape.begin()),
runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]),
comma_fold);
PADDLE_ENFORCE_EQ(
model_input_shape == runtime_input_shape, true,
model_input_shape == runtime_input_shape,
true,
platform::errors::InvalidArgument(
"Input shapes are inconsistent with the model. Expect [%s] in "
"model description, but got [%s] in runtime. TRT 5 "
......@@ -76,7 +81,8 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
"does not support dynamic input shapes. Please check and "
"modify "
"your input shapes.",
model_input_shape_str, runtime_input_shape_str));
model_input_shape_str,
runtime_input_shape_str));
}
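The comma_fold reduction used in these shape checks is a plain std::accumulate left fold seeded with the first element; a standalone sketch of what it produces:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

int main() {
  std::vector<int64_t> shape{1, 3, 224, 224};
  auto comma_fold = [](std::string a, int64_t b) {
    return std::move(a) + ", " + std::to_string(b);
  };
  // Starting from shape[0] avoids a leading ", " separator.
  std::string s = std::accumulate(std::next(shape.begin()), shape.end(),
                                  std::to_string(shape[0]), comma_fold);
  std::cout << s << "\n";  // prints: 1, 3, 224, 224
  return 0;
}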
static paddle::experimental::DataType TRT2FluidDataType(
......@@ -102,7 +108,8 @@ static paddle::experimental::DataType TRT2FluidDataType(
}
static void RuntimeDynamicShapeCheck(
const std::string &x, const std::vector<int32_t> &runtime_input_shape,
const std::string &x,
const std::vector<int32_t> &runtime_input_shape,
const std::vector<int32_t> &min_input_shape,
const std::vector<int32_t> &max_input_shape) {
// PADDLE_ENFORCE_EQ(
......@@ -111,8 +118,8 @@ static void RuntimeDynamicShapeCheck(
// "TRT engine runtime input %s dims size(%d) inconsistent "
// "with the dynamic shape size(%d)",
// x, runtime_input_shape.size(), min_input_shape.size()));
auto is_input_shape_valid = [&](
const std::vector<int32_t> &runtime_input_shape,
auto is_input_shape_valid =
[&](const std::vector<int32_t> &runtime_input_shape,
const std::vector<int32_t> &min_input_shape,
const std::vector<int32_t> &max_input_shape) -> bool {
for (size_t i = 0; i < runtime_input_shape.size(); i++) {
......@@ -128,17 +135,23 @@ static void RuntimeDynamicShapeCheck(
auto comma_fold = [](std::string a, int b) {
return std::move(a) + ", " + std::to_string(b);
};
std::string runtime_input_shape_str = std::accumulate(
std::next(runtime_input_shape.begin()), runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]), comma_fold);
std::string runtime_input_shape_str =
std::accumulate(std::next(runtime_input_shape.begin()),
runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]),
comma_fold);
std::string min_input_shape_str =
std::accumulate(std::next(min_input_shape.begin()), min_input_shape.end(),
std::to_string(min_input_shape[0]), comma_fold);
std::accumulate(std::next(min_input_shape.begin()),
min_input_shape.end(),
std::to_string(min_input_shape[0]),
comma_fold);
std::string max_input_shape_str =
std::accumulate(std::next(max_input_shape.begin()), max_input_shape.end(),
std::to_string(max_input_shape[0]), comma_fold);
PADDLE_ENFORCE_EQ(is_input_shape_valid(runtime_input_shape, min_input_shape,
max_input_shape),
std::accumulate(std::next(max_input_shape.begin()),
max_input_shape.end(),
std::to_string(max_input_shape[0]),
comma_fold);
PADDLE_ENFORCE_EQ(is_input_shape_valid(
runtime_input_shape, min_input_shape, max_input_shape),
true,
platform::errors::InvalidArgument(
"TRT runtime input shape of %s is invalid. Expect "
......@@ -146,7 +159,9 @@ static void RuntimeDynamicShapeCheck(
"configured in SetTRTDynamicShapeInfo(),"
"but got runtime input shape = [%s], min input shape = "
"[%s], max input shape = [%s].",
x, runtime_input_shape_str, min_input_shape_str,
x,
runtime_input_shape_str,
min_input_shape_str,
max_input_shape_str));
}
......@@ -158,7 +173,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
mutable TensorRTEngine *trt_engine_{nullptr};
int max_batch_size_;
int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_;
mutable std::unique_ptr<TRTInt8Calibrator> calibrator_;
bool enable_int8_;
bool enable_fp16_;
bool use_calib_mode_;
......@@ -260,7 +275,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Has(engine_key_ + std::to_string(predictor_id_));
if (!calibration_mode_ && has_engine) {
if (has_engine) {
trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_key_ + std::to_string(predictor_id_));
......@@ -287,8 +302,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("output_name_mapping");
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(&block_desc, scope, inputs, param_names_,
outputs, engine);
.ConvertBlockToTRTEngine(
&block_desc, scope, inputs, param_names_, outputs, engine);
}
protected:
......@@ -304,11 +319,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto *trt_engine = GetEngine(scope, dev_place);
if (calibration_mode_ == true) {
RunCalibration(scope, dev_place);
RunCalibration(scope, dev_place, trt_engine);
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.DeleteKey(engine_key_ + std::to_string(predictor_id_));
return;
}
auto *trt_engine = GetEngine(scope, dev_place);
if (use_inspector_) {
trt_engine->GetEngineInfo();
}
......@@ -331,15 +349,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine->max_input_shape();
for (auto &x : runtime_input_names_) {
PADDLE_ENFORCE_EQ(
min_input_shape.count(x), true,
min_input_shape.count(x),
true,
platform::errors::InvalidArgument(
"Input %s not found in TRT engine min_input_shape.", x));
PADDLE_ENFORCE_EQ(
max_input_shape.count(x), true,
max_input_shape.count(x),
true,
platform::errors::InvalidArgument(
"Input %s not found in TRT engine max_input_shape.", x));
RuntimeDynamicShapeCheck(x, runtime_input_shape[x],
min_input_shape[x], max_input_shape[x]);
RuntimeDynamicShapeCheck(x,
runtime_input_shape[x],
min_input_shape[x],
max_input_shape[x]);
}
} else {
// compare runtime_input_shape and trt_engine dynamic shapes.
......@@ -357,12 +379,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (anc == nullptr) {
anc = &scope;
}
if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
trt_engine->SetCalibrator(calibrator_.get());
}
PrepareTRTEngine(*anc, trt_engine);
// update shape_range_info_pbtxt
if (!shape_range_info_path_.empty()) {
inference::UpdateShapeRangeInfo(
shape_range_info_path_, trt_engine->min_input_shape(),
trt_engine->max_input_shape(), trt_engine->optim_input_shape(),
inference::UpdateShapeRangeInfo(shape_range_info_path_,
trt_engine->min_input_shape(),
trt_engine->max_input_shape(),
trt_engine->optim_input_shape(),
shape_changed_name);
}
......@@ -387,7 +414,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
void RunCalibration(const framework::Scope &scope,
const platform::Place &dev_place) const {
const platform::Place &dev_place,
TensorRTEngine *trt_engine) const {
// This process builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram of the distribution of activation values
// for each tensor.
......@@ -412,9 +440,15 @@ class TensorRTEngineOp : public framework::OperatorBase {
calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, calibration_engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, precision_mode_,
calib_res->calib_.get(), dev_place.device));
calib_res->engine_.reset(
new TensorRTEngine(max_batch_size_,
workspace_size_,
precision_mode_,
calib_res->calib_.get(),
dev_place.device,
trt_engine->min_input_shape(),
trt_engine->max_input_shape(),
trt_engine->optim_input_shape()));
VLOG(3) << "start the calib trt engine thread";
PrepareTRTEngine(scope, calib_res->engine_.get());
}));
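The substantive change here is that the calibration engine is now constructed with the parent engine's min/max/opt shape maps, so its FreezeNetwork can build optimization profiles and attach the calibration profile shown earlier. For orientation, a minimal skeleton of the kind of entropy calibrator TensorRT drives during this step (TRT 7-style signatures, which TensorRT 8 extends with noexcept; buffer feeding is elided and the device buffer is assumed to be filled elsewhere):

#include <NvInfer.h>
#include <cstddef>
#include <vector>

class SketchCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
 public:
  explicit SketchCalibrator(void *device_input) : device_input_(device_input) {}

  int getBatchSize() const override { return 1; }

  bool getBatch(void *bindings[], const char *names[], int nb_bindings) override {
    if (done_) return false;      // returning false ends calibration
    bindings[0] = device_input_;  // pre-filled device buffer (assumed)
    done_ = true;
    return true;
  }

  const void *readCalibrationCache(std::size_t &length) override {
    length = cache_.size();
    return cache_.empty() ? nullptr : cache_.data();
  }

  void writeCalibrationCache(const void *ptr, std::size_t length) override {
    cache_.assign(static_cast<const char *>(ptr),
                  static_cast<const char *>(ptr) + length);
  }

 private:
  void *device_input_{nullptr};
  bool done_{false};
  std::vector<char> cache_;
};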
......@@ -436,7 +470,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
RunNativeImpl(scope, dev_place);
}
void RunTrt(const framework::Scope &scope, const platform::Place &dev_place,
void RunTrt(const framework::Scope &scope,
const platform::Place &dev_place,
TensorRTEngine *engine) const {
int runtime_batch = -1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -445,7 +480,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
PADDLE_ENFORCE_EQ(
runtime_input_names_.empty(), false,
runtime_input_names_.empty(),
false,
platform::errors::PreconditionNotMet(
"TensorRT engine needs at least one input, but no input is found. "
"Please check if you set the input correctly."));
......@@ -488,13 +524,15 @@ class TensorRTEngineOp : public framework::OperatorBase {
const int bind_index =
engine->engine()->getBindingIndex(x.c_str()) + binding_offset;
PADDLE_ENFORCE_LT(
bind_index, num_bindings,
bind_index,
num_bindings,
platform::errors::InvalidArgument(
"Wrong TRT engine input binding index. Expected The "
"binding index of TRT engine input to be less than "
"the number of inputs and outputs. Received binding "
"index=%d >= total inputs and outputs=%d",
bind_index, num_bindings));
bind_index,
num_bindings));
if (!engine->with_dynamic_shape()) {
// check if the input shapes are consistent with model.
if (HasAttr(x + "_shape")) {
......@@ -507,7 +545,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape);
if (runtime_batch != -1) {
PADDLE_ENFORCE_EQ(
runtime_batch, t_shape[0],
runtime_batch,
t_shape[0],
platform::errors::InvalidArgument(
"Inputs of trt subgraphs has different batchsize. "
"It's not allowed in static shape mode. "
......@@ -589,12 +628,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
fluid_t->Resize(phi::make_ddim(ddim));
PADDLE_ENFORCE_LT(bind_index, num_bindings,
PADDLE_ENFORCE_LT(bind_index,
num_bindings,
platform::errors::InvalidArgument(
"The binding index in TRT engine should be less "
"than the number of bindings, but got binding "
"index = %d, number of bindings = %d.",
bind_index, num_bindings));
bind_index,
num_bindings));
auto trt_type = engine->engine()->getBindingDataType(bind_index);
// get adr and set type
buffers[bind_index] = static_cast<void *>(
......@@ -604,7 +645,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (!engine->with_dynamic_shape()) {
PADDLE_ENFORCE_LE(
runtime_batch, max_batch_size_,
runtime_batch,
max_batch_size_,
platform::errors::InvalidArgument(
"The runtime batch size (%d) is greater than the max batch "
"size(%d).\n"
......@@ -623,7 +665,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
"\tThe min_subgraph_size shouble to be greater than the number "
"of "
"nodes in the inconsistent subgraph.\n",
runtime_batch, max_batch_size_));
runtime_batch,
max_batch_size_));
}
// Execute the engine.
engine->Execute(runtime_batch, &buffers, stream);
......@@ -635,9 +678,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key_ + std::to_string(predictor_id_),
max_batch_size_, workspace_size_, precision_mode_,
calibrator_.get(), device_id_, min_input_shape_,
max_input_shape_, opt_input_shape_);
max_batch_size_,
workspace_size_,
precision_mode_,
calibrator_.get(),
device_id_,
min_input_shape_,
max_input_shape_,
opt_input_shape_);
PrepareTRTEngine(scope, trt_engine_);
}
return trt_engine_;
......