未验证 提交 9dd442ab 编写于 作者: W wenbin 提交者: GitHub

disable int8 if there is no quant info (#36900)

* disable int8

* size_t to int
上级 a522f8c6
...@@ -148,12 +148,21 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -148,12 +148,21 @@ void TensorRTEngine::FreezeNetwork() {
// and outputs have scales,
// this layer's precision and output type are set to float32.
// This step has no effect if this layer is fused during TRT optimization.
int layers_no_int8 = 0;
for (int i = 0; i < network()->getNbLayers(); i++) { for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i); auto layer = network()->getLayer(i);
if (!is_layer_int8(layer)) { if (!is_layer_int8(layer)) {
layer->setPrecision(nvinfer1::DataType::kFLOAT); layer->setPrecision(nvinfer1::DataType::kFLOAT);
++layers_no_int8;
} }
} }
// If none of the layers is int8, disable int8 mode; otherwise building the engine fails.
if (layers_no_int8 == network()->getNbLayers()) {
nvinfer1::BuilderFlags flags = infer_builder_config_->getFlags();
flags = flags & ~(1U << static_cast<int>(nvinfer1::BuilderFlag::kINT8));
// reset flags
infer_builder_config_->setFlags(flags);
}
#else #else
LOG(WARNING) << "If your TensorRT version is lower than 5.1.2.2, you " LOG(WARNING) << "If your TensorRT version is lower than 5.1.2.2, you "
"must provide quantization scales for all tensors using " "must provide quantization scales for all tensors using "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册