Unverified commit ae84c603 authored by Leo Chen, committed by GitHub

Integrate TRT qdq layers (#54803)

* Integrate quantize/dequantize linear and add config for explicit quantization

* Fix the build error

* Add macro for TRT version < 8.0

* Remove qdq UT from windows

* Fix UT failure

* Check TRT version in qdq UT

* Test tensorrt_explicit_enabled API

* Disable QDQ UT if TRT version < 8.5

* Add quantization postfix into public APIs

* Apply code formatter

* Fix the UT failure for explicit quantization

* Apply code formatter on modified files

* Correct the year in copyright
Parent 64f18fa1
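Before the diff, a minimal usage sketch of the switch this commit adds, written against the Python bindings introduced below. The model paths are placeholders; `Config`, `enable_use_gpu`, `enable_tensorrt_engine`, and `create_predictor` are existing paddle.inference APIs, and only the two `*_explicit_quantization` methods come from this patch.

```python
import paddle.inference as paddle_infer

# Placeholder files; substitute a real QDQ (quantize/dequantize) model.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)  # 256 MB initial GPU memory pool on device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    precision_mode=paddle_infer.PrecisionType.Int8,
    # Explicit mode takes scales from the QDQ nodes, so no calibration.
    use_calib_mode=False,
)
# New in this commit: keep QDQ nodes and let TensorRT quantize explicitly.
config.enable_tensorrt_explicit_quantization()
assert config.tensorrt_explicit_quantization_enabled()

predictor = paddle_infer.create_predictor(config)
```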
......@@ -271,6 +271,9 @@ struct Argument {
TensorRtAllowBuildAtRuntime,
bool);
DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_explicit_quantization,
TensorRtUseExplicitQuantization,
bool);
DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
......
......@@ -476,6 +476,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(collect_shape_range_info_);
CP_MEMBER(shape_range_info_path_);
CP_MEMBER(trt_use_inspector_);
CP_MEMBER(trt_use_explicit_quantization_);
CP_MEMBER(trt_engine_memory_sharing_);
CP_MEMBER(trt_engine_memory_sharing_identifier_);
// Dlnne related
......@@ -838,6 +839,11 @@ void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
void AnalysisConfig::EnableTensorRtInspector() { trt_use_inspector_ = true; }
void AnalysisConfig::EnableTensorRtExplicitQuantization() {
trt_use_explicit_quantization_ = true;
Update();
}
void AnalysisConfig::Exp_DisableTensorRtOPs(
const std::vector<std::string> &ops) {
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
......@@ -914,6 +920,13 @@ void AnalysisConfig::Update() {
(pass == "conv_bn_fuse_pass")) {
continue;
}
// The following two IR passes remove QDQ nodes, which explicit
// quantization needs to keep, so skip them in that mode.
if (trt_use_explicit_quantization_ &&
(pass == "trt_delete_weight_dequant_linear_op_pass" ||
pass == "delete_quant_dequant_linear_op_pass")) {
continue;
}
pass_builder()->AppendPass(pass);
}
}
......
......@@ -1405,6 +1405,8 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetTensorRtAllowBuildAtRuntime(
config_.trt_allow_build_at_runtime());
argument_->SetTensorRtUseInspector(config_.trt_use_inspector_);
argument_->SetTensorRtUseExplicitQuantization(
config_.trt_use_explicit_quantization_);
argument_->SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing());
}
......@@ -2950,6 +2952,10 @@ USE_TRT_CONVERTER(temporal_shift)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
#endif
#if IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(quantize_linear)
USE_TRT_CONVERTER(dequantize_linear)
#endif
#endif
namespace paddle_infer {
......
......@@ -843,10 +843,26 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool tensorrt_dla_enabled() { return trt_use_dla_; }
///
/// \brief A boolean state telling whether to show TensorRT inspector
/// information.
///
/// \return bool Whether to show TensorRT inspector information.
///
void EnableTensorRtInspector();
bool tensorrt_inspector_enabled() { return trt_use_inspector_; }
///
/// \brief A boolean state telling whether to use TensorRT explicit
/// quantization.
///
/// \return bool Whether to use TensorRT explicit quantization.
///
void EnableTensorRtExplicitQuantization();
bool tensorrt_explicit_quantization_enabled() {
return trt_use_explicit_quantization_;
}
void EnableDlnne(
int min_subgraph_size = 3,
int max_batch_size = 1,
......@@ -1226,6 +1242,7 @@ struct PD_INFER_DECL AnalysisConfig {
// tune to get dynamic_shape info.
bool trt_tuned_dynamic_shape_{false};
bool trt_use_inspector_{false};
bool trt_use_explicit_quantization_{false};
// In CollectShapeInfo mode, we will collect the shape information of
// all intermediate tensors in the compute graph and calculate the
......
......@@ -109,7 +109,9 @@ list(
einsum_op.cc
unbind_op.cc
assign_op.cc
flip_op.cc)
flip_op.cc
quantize_linear_op.cc
dequantize_linear_op.cc)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class DequantizeLinearOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(8000)
VLOG(4) << "convert a dequantize_linear op to tensorrt IDequantizeLayer";
// Declare inputs and attributes
framework::OpDesc op_desc(op, nullptr);
auto* x = engine_->GetITensor(op_desc.Input("X")[0]);
auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("quant_axis"));
// Create constant layer for scale
PADDLE_ENFORCE_NOT_NULL(
scale_var,
platform::errors::NotFound("Can not find %s presistale var in scope.",
op_desc.Input("Scale")[0]));
auto* scale_t = scale_var->GetMutable<phi::DenseTensor>();
int n_scale = scale_t->numel();
std::vector<float> scale_data(n_scale, 0.0f);
for (int i = 0; i < n_scale; ++i) {
scale_data[i] = scale_t->data<float>()[i] / 127.0f;
}
nvinfer1::Dims scale_dim{1, { n_scale }};
auto* scale = AddConstantLayer(scale_data.data(), scale_dim);
// Add dequantize layer
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Dequantize, *x, *scale);
layer->setAxis(axis);
auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(
layer, "dequantize_linear", {output_name}, test_model);
#else
PADDLE_THROW(
platform::errors::Fatal("Paddle-TRT explicit quantization does not "
"support Paddle compiled with TRT < 8.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(dequantize_linear, DequantizeLinearOpConverter);
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class QuantizeLinearOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(8000)
VLOG(4) << "convert a quantize_linear op to tensorrt IQuantizeLayer";
// Declare inputs and attributes
framework::OpDesc op_desc(op, nullptr);
auto* x = engine_->GetITensor(op_desc.Input("X")[0]);
auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("quant_axis"));
// Create constant layer for scale
PADDLE_ENFORCE_NOT_NULL(
scale_var,
platform::errors::NotFound("Can not find %s presistale var in scope.",
op_desc.Input("Scale")[0]));
auto* scale_t = scale_var->GetMutable<phi::DenseTensor>();
int n_scale = scale_t->numel();
std::vector<float> scale_data(n_scale, 0.0f);
for (int i = 0; i < n_scale; ++i) {
scale_data[i] = scale_t->data<float>()[i] / 127.0f;
}
nvinfer1::Dims scale_dim{1, { n_scale }};
auto* scale = AddConstantLayer(scale_data.data(), scale_dim);
// Add quantize layer
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Quantize, *x, *scale);
layer->setAxis(axis);
auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(
layer, "quantize_linear", {output_name}, test_model);
#else
PADDLE_THROW(
platform::errors::Fatal("Paddle-TRT explicit quantization does not "
"support Paddle compiled with TRT < 8.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(quantize_linear, QuantizeLinearOpConverter);
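Both converters above divide the framework-stored scale by 127.0f before building the constant layer. A hedged numeric sketch of why, assuming Paddle's quantize_linear stores the max-abs range as its scale (quantization is round(x / scale * 127)) while TensorRT's IQuantizeLayer computes round(x / scale_trt):

```python
import numpy as np

x = np.array([0.5, -1.2, 3.3], dtype=np.float32)
paddle_scale = 3.5                # per-tensor scale as stored by Paddle
trt_scale = paddle_scale / 127.0  # what the converters feed to TensorRT

# The two quantization conventions agree exactly after rescaling by 1/127.
q_paddle = np.clip(np.round(x / paddle_scale * 127.0), -127, 127)
q_trt = np.clip(np.round(x / trt_scale), -127, 127)
assert np.array_equal(q_paddle, q_trt)
```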
......@@ -2721,6 +2721,15 @@ struct SimpleOpTypeSetTeller : public Teller {
}
#endif
}
if (op_type == "quantize_linear" || op_type == "dequantize_linear") {
#if !IS_TRT_VERSION_GE(8000)
VLOG(3) << "quantize / dequantize linear is not supported when TensorRT "
"< 8.0";
return false;
#else
return true;
#endif
}
if (op_type == "flip") {
if (!with_dynamic_shape) {
......@@ -2905,7 +2914,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"cumsum",
"unbind",
"assign",
"flip"};
"flip",
"quantize_linear",
"dequantize_linear"};
std::unordered_set<std::string> teller_set{
"matrix_multiply",
......@@ -3070,7 +3081,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"cumsum",
"unbind",
"assign",
"flip"};
"flip",
"quantize_linear",
"dequantize_linear"};
};
struct GenericPluginTeller : public Teller {
......
......@@ -918,6 +918,10 @@ void BindAnalysisConfig(py::module *m) {
&AnalysisConfig::EnableTensorRtInspector)
.def("tensorrt_inspector_enabled",
&AnalysisConfig::tensorrt_inspector_enabled)
.def("enable_tensorrt_explicit_quantization",
&AnalysisConfig::EnableTensorRtExplicitQuantization)
.def("tensorrt_explicit_quantization_enabled",
&AnalysisConfig::tensorrt_explicit_quantization_enabled)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("enable_dlnne",
&AnalysisConfig::EnableDlnne,
......
......@@ -44,6 +44,8 @@ if(WIN32)
"test_trt_convert_elementwiseadd_transpose")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_split_layernorm_to_math_ops_pass")
list(REMOVE_ITEM TEST_TRT_CONVERTER
"test_trt_convert_quantize_dequantize_linear")
endif()
# Only for cpu(mkl + openblas)
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import enum
import os
......@@ -716,7 +717,9 @@ class TrtLayerAutoScanTest(AutoScanTest):
dic["use_trt"] = False
return str(dic)
def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
def run_test(
self, quant=False, explicit=False, skip_baseline=False, *args, **kwargs
):
all_passes = True
def random_to_skip():
......@@ -781,9 +784,17 @@ class TrtLayerAutoScanTest(AutoScanTest):
pred_config.tensorrt_precision_mode()
== paddle_infer.PrecisionType.Int8
)
if (not is_fp8 and quant) or (is_fp8 and not quant):
if (not is_fp8 and quant) or (
is_fp8 and not (quant or explicit)
):
continue
if explicit:
pred_config.enable_tensorrt_explicit_quantization()
self.assertTrue(
pred_config.tensorrt_explicit_quantization_enabled()
)
ignore_flag = False
for teller, reason, note in self.ignore_cases:
if teller(prog_config, pred_config):
......
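For reference, a minimal sketch of how a converter UT drives the new flag; the class body is illustrative, while `TrtLayerAutoScanTest` and `run_test` are the harness pieces patched above:

```python
# Illustrative only: with quant=False and explicit=True, Int8 predictor
# configs are no longer skipped (see the is_fp8 condition above) and the
# harness calls enable_tensorrt_explicit_quantization() on each of them.
class MyQdqConverterTest(TrtLayerAutoScanTest):
    def test(self):
        self.run_test(quant=False, explicit=True)
```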
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertQuantizeDequantizeTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
ver = paddle_infer.get_trt_compile_version()
# only TRT > 8.0 has quantize / dequantize layers
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 8517:
return False
return True
def sample_program_configs(self):
self.trt_param.workspace_size = 1073741824
def generate_input1(shape):
return np.random.random(shape).astype(np.float32)
def generate_add(shape):
return np.ones(shape).astype(np.float32)
def generate_scale():
return np.ones([1]).astype(np.float32) + 2.521234002
def generate_zeropoint():
return np.zeros([1]).astype(np.float32)
desc = [{"quant_axis": -1}]
ops_config = [
{
"op_type": "quantize_linear",
"op_inputs": {
"X": ["input_data_1"],
"Scale": ["scale_data_1"],
"ZeroPoint": ["zeropoint_data_1"],
},
"op_outputs": {
"Y": ["y_data_1"],
},
"op_attrs": desc[0],
},
{
"op_type": "dequantize_linear",
"op_inputs": {
"X": ["y_data_1"],
"Scale": ["scale_data_2"],
"ZeroPoint": ["zeropoint_data_2"],
},
"op_outputs": {
"Y": ["y_data_2"],
},
"op_attrs": desc[0],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["y_data_2"],
"Y": ["add"],
},
"op_outputs": {
"Out": ["y_data_3"],
},
"op_attrs": {"axis": -1},
"outputs_dtype": {"output_data": np.float32},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"scale_data_1": TensorConfig(data_gen=partial(generate_scale)),
"zeropoint_data_1": TensorConfig(
data_gen=partial(generate_zeropoint)
),
"scale_data_2": TensorConfig(data_gen=partial(generate_scale)),
"zeropoint_data_2": TensorConfig(
data_gen=partial(generate_zeropoint)
),
"add": TensorConfig(
data_gen=partial(generate_add, [1, 8, 32, 32])
),
},
inputs={
"input_data_1": TensorConfig(
data_gen=partial(generate_input1, [1, 8, 32, 32])
)
},
outputs=["y_data_3"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data_1": [1, 8, 32, 32],
"add": [1, 8, 32, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data_1": [16, 8, 32, 32],
"add": [16, 8, 32, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data_1": [16, 8, 32, 32],
"add": [16, 8, 32, 32],
}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Int8
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (1e-2, 1e-2)
def test(self):
self.run_test(quant=False, explicit=True)
if __name__ == "__main__":
unittest.main()