Unverified · Commit 31efe00a · Authored by: Wangzheee · Committed by: GitHub

[Paddle-Inference] remove int8 fallback (#45762)

* remove int8 fallback
Parent commit: efccf896
......@@ -72,11 +72,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
}
auto* shape_tensor = Shape(mask_id_tensor);
std::vector<nvinfer1::ITensor*> start_vec_tensor;
std::vector<nvinfer1::ITensor*> size_vec_tensor;
for (int i = 0; i < mask_dims.nbDims; i++) {
start_vec_tensor.push_back(Add1DConstantLayer(0));
size_vec_tensor.push_back(Add1DConstantLayer(1));
}
size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
auto start_tensor = Concat(start_vec_tensor);
auto size_tensor = Concat(size_vec_tensor);
auto slice_layer =
......@@ -86,6 +90,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
slice_start_dims,
slice_start_dims,
slice_stride_dims); // unuseful slice_start_dims
slice_layer->setInput(1, *start_tensor);
slice_layer->setInput(2, *size_tensor);
slice_layer->setName(
("Embeltwise_slice_layer (Output: slice_max_seqlen " +
......
......@@ -89,9 +89,7 @@ void TensorRTEngine::Execute(int batch_size,
if (!with_dynamic_shape()) {
infer_context->enqueue(batch_size, buffers->data(), stream, nullptr);
} else {
#if IS_TRT_VERSION_GE(6000)
infer_context->enqueueV2(buffers->data(), stream, nullptr);
#endif
}
SetRuntimeBatch(batch_size);
}
......@@ -134,7 +132,6 @@ void TensorRTEngine::FreezeNetwork() {
} else {
infer_builder_config_->setInt8Calibrator(nullptr);
#if IS_TRT_VERSION_GE(5000)
for (auto &quant_range : quant_dynamic_range_) {
auto tensor = quant_range.first;
float range = quant_range.second;
......@@ -160,72 +157,6 @@ void TensorRTEngine::FreezeNetwork() {
<< ", this might be ok when trt does not need this range";
}
}
#if IS_TRT_VERSION_GE(5122)
// Predicate: should `layer` fall back from INT8 to FP32?
// Returns true when the layer touches at least one non-INT32 tensor whose
// dynamic range (quantization scale) has not been set — running such a
// layer in INT8 would require a scale TRT does not have.
auto layer_int8_fallback = [&](nvinfer1::ILayer *layer) -> bool {
// Shape layers carry no quantized data; never fall back.
if (layer->getType() == nvinfer1::LayerType::kSHAPE) {
return false;
}
// Detect pure-integer layers: if every input AND output tensor is INT32,
// the layer is not quantized at all, so no fallback is needed.
bool all_int = true;
for (int j = 0; j < layer->getNbInputs(); j++) {
auto *temp_in = layer->getInput(j);
if (temp_in->getType() != nvinfer1::DataType::kINT32) {
all_int = false;
}
}
for (int j = 0; j < layer->getNbOutputs(); j++) {
auto *temp_out = layer->getOutput(j);
if (temp_out->getType() != nvinfer1::DataType::kINT32) {
all_int = false;
}
}
if (all_int) return false;
// Any input without a dynamic range forces the layer to FP32.
for (int j = 0; j < layer->getNbInputs(); j++) {
auto *temp_in = layer->getInput(j);
if (!temp_in->dynamicRangeIsSet()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its input("
<< temp_in->getName() << ") doesn't have dynamic range.";
return true;
}
}
// Likewise for outputs: a missing output range also forces FP32.
for (int j = 0; j < layer->getNbOutputs(); j++) {
auto *temp_out = layer->getOutput(j);
if (!temp_out->dynamicRangeIsSet()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its output("
<< temp_out->getName() << ") doesn't have dynamic range.";
return true;
}
}
// All tensors have ranges set: the layer may stay in INT8.
return false;
};
// If a layer's output is the network's output, or not all of its inputs
// and outputs have scales,
// this layer's precision and output type are set to float32.
// This step has no effect if this layer is fused during TRT optimization.
int layers_no_int8 = 0;
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
if (layer_int8_fallback(layer)) {
layer->setPrecision(nvinfer1::DataType::kFLOAT);
++layers_no_int8;
}
}
// Disable int8 or build engine failed if all layers aren't int8
if (layers_no_int8 == network()->getNbLayers()) {
nvinfer1::BuilderFlags flags = infer_builder_config_->getFlags();
flags = flags & ~(1U << static_cast<int>(nvinfer1::BuilderFlag::kINT8));
// reset flags
infer_builder_config_->setFlags(flags);
}
#else
LOG(WARNING) << "If your TensorRT version is lower than 5.1.2.2, you "
"must provide quantization scales for all tensors using "
"TRT to run.";
#endif
#endif
}
}
......@@ -265,7 +196,6 @@ void TensorRTEngine::FreezeNetwork() {
}
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
for (int i = 0; i < max_profile_num_; i++) {
for (auto &input : min_input_shape_) {
......@@ -310,7 +240,6 @@ void TensorRTEngine::FreezeNetwork() {
"'config.SetDynamicShapeInfo(min_shape, max_shape, "
"opt_shape, false /*disable_trt_plugin_fp16*/)'";
}
#endif
}
#if IS_TRT_VERSION_GE(8200)
if (use_inspector_) {
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.