Unverified commit 8bbae468, authored by Guanghua Yu, committed by GitHub

Add observer attribute in qdq node & Add quant config for different backends. (#46887)

Parent 07db4a9f
@@ -200,6 +200,12 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(true);
AddAttr<bool>(
"only_observer",
"(bool, default false) Whether to only observe or not. If "
"only_observer=false, the fake quant or dequant output will be calculated. "
"If only_observer=true, only the scale information will be calibrated.")
.SetDefault(false);
AddComment(R"DOC(
The scale of QuantizeLinear operator is a vector.
In detail, each channel of the input X has a scale value.
......
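To make the new `only_observer` attribute concrete, here is a minimal NumPy sketch (an illustration, not the Paddle kernel) of the two behaviours of a quantize_linear-style op, assuming an abs-max scale and a signed 8-bit range:

```python
import numpy as np

def quantize_linear(x, scale, bit_length=8, only_observer=False):
    bnt = (1 << (bit_length - 1)) - 1  # 127 for 8-bit
    if only_observer:
        # Observer mode: the tensor passes through unchanged; the op is only
        # used to calibrate/record the scale information.
        return x.copy()
    # Fake-quant mode: clip to [-scale, scale] and round onto the integer grid.
    return np.round(np.clip(x / scale, -1.0, 1.0) * bnt)

x = np.array([0.5, -1.2, 0.9], dtype=np.float32)
print(quantize_linear(x, scale=1.2))                      # [53., -127., 95.]
print(quantize_linear(x, scale=1.2, only_observer=True))  # [0.5, -1.2, 0.9]
```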
@@ -61,6 +61,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis");
bool is_test = context.Attr<bool>("is_test");
bool only_observer = context.Attr<bool>("only_observer");
auto& dev_ctx = context.template device_context<DeviceContext>();
if (quant_axis < 0) {
@@ -91,11 +92,19 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
out_state,
out_accum,
out_scale);
- ClipAndFakeQuantFunctor<DeviceContext, T>()(
-     dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
+ if (only_observer) {
+   framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+ } else {
+   ClipAndFakeQuantFunctor<DeviceContext, T>()(
+       dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
+ }
} else {
- ClipAndFakeQuantFunctor<DeviceContext, T>()(
-     dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
+ if (only_observer) {
+   framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+ } else {
+   ClipAndFakeQuantFunctor<DeviceContext, T>()(
+       dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
+ }
}
} else {
if (!is_test) {
@@ -103,11 +112,19 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
FindChannelAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, *in, quant_axis, out_scale_data);
- ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
-     dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
+ if (only_observer) {
+   framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+ } else {
+   ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
+       dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
+ }
} else {
- ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
-     dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
+ if (only_observer) {
+   framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+ } else {
+   ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
+       dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
+ }
}
}
}
@@ -132,6 +149,12 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
int bit_length = context.Attr<int>("bit_length");
auto quant_axis = context.Attr<int>("quant_axis");
dev_ctx.template Alloc<D>(out, out->numel() * sizeof(D));
bool only_observer = context.Attr<bool>("only_observer");
if (only_observer) {
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
return;
}
if (quant_axis < 0) {
float max_range = (std::pow(2, bit_length - 1) - 1);
......
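The dequantize side behaves the same way: with only_observer=true both halves of a quant/dequant (QDQ) pair are pass-throughs and only carry calibration information, while without it the round trip shows the usual quantization error. A NumPy sketch under the same assumptions as above:

```python
import numpy as np

def dequantize_linear(q, scale, bit_length=8, only_observer=False):
    bnt = (1 << (bit_length - 1)) - 1
    if only_observer:
        return q.copy()      # pass-through; the scale is informational only
    return q * scale / bnt   # map the integer grid back to floating point

x = np.array([0.5, -1.2, 0.9], dtype=np.float32)
scale = 1.2
q = np.round(np.clip(x / scale, -1.0, 1.0) * 127)

print(dequantize_linear(q, scale))                      # ~[0.5008, -1.2, 0.8976]
print(dequantize_linear(x, scale, only_observer=True))  # exactly x
```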
@@ -24,15 +24,19 @@ from paddle.static.quantization import (
AddQuantDequantPassV2,
OutScaleForTrainingPass,
QuantizationTransformPassV2,
- utils,
+ quant_config,
)
from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from .pass_base import PassBase, register_pass
- TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
- QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
+ TRANSFORM_PASS_OP_TYPES = list(
+     quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()
+ )
+ QUANT_DEQUANT_PASS_OP_TYPES = list(
+     quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()
+ )
def _node_id(node):
......
@@ -35,7 +35,15 @@ from ..log_helper import get_logger
from . import utils
from .adaround import run_adaround
from .cal_kl_threshold import cal_kl_threshold
from .quant_config import (
SUPPORT_QUANTIZATION_OP_DICT,
ARMCPUQuantizer,
BaseQuantizer,
MKLDNNQuantizer,
TensorRTQuantizer,
)
from .quantization_pass import (
AddQuantDequantForInferencePass,
AddQuantDequantPass,
AddQuantDequantPassV2,
QuantizationFreezePass,
@@ -127,7 +135,7 @@ class PostTrainingQuantization:
batch_nums=None,
algo="KL",
hist_percent=0.99999,
- quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+ quantizable_op_type=[],
round_type='round',
learning_rate=0.001,
is_full_quantize=False,
@@ -145,6 +153,7 @@ class PostTrainingQuantization:
cache_dir=None,
scale_dict=None,
return_graph=False,
deploy_backend=None,
):
'''
Constructor.
@@ -190,8 +199,9 @@ class PostTrainingQuantization:
hist_percent(float, optional): The threshold of algo 'hist' for activations.
Default is 0.99999.
quantizable_op_type(list[str], optional): List the type of ops
- that will be quantized. Default is ["conv2d", "depthwise_conv2d",
- "mul"].
+ that will be quantized. Default is []. If quantizable_op_type is [],
+ the default quantizable op types of the quant config for the current
+ deploy_backend will be used.
round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the integer.
@@ -199,8 +209,8 @@ class PostTrainingQuantization:
learning_rate(float, optional): The learning rate of adaround method.
is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set
- is_full_quantized as False, only apply quantization to the op type
- according to the input quantizable_op_type.
+ is_full_quantized as False, quantization is only applied to the op types
+ given by the input quantizable_op_type or by the quant config of the deploy_backend.
bias_correction(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False.
activation_bits(int): quantization bit number for activation.
@@ -234,6 +244,9 @@ class PostTrainingQuantization:
quantization. Default False.
is_use_cache_file(bool, optional): This param is deprecated.
cache_dir(str, optional): This param is deprecated.
deploy_backend(str, optional): The deploy backend. It can be None, `TensorRT`,
`MKLDNN`, or `ARM`; more backends will be added over time. Default is None,
which means the default general quantization configuration is used.
Returns:
None
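For orientation, a sketch of how the new deploy_backend argument might be passed in practice; the model directory, file names, and calibration data loader below are placeholders, and the remaining arguments follow the existing PostTrainingQuantization API:

```python
import paddle
from paddle.static.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_model",            # placeholder path
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    data_loader=calib_loader,            # placeholder calibration data loader
    algo="avg",
    quantizable_op_type=[],              # [] -> use the backend's default op types
    onnx_format=True,
    deploy_backend="tensorrt",           # or "mkldnn", "arm", or None
)
ptq.quantize()
ptq.save_quantized_model("./int8_model")
```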
@@ -294,13 +307,6 @@ class PostTrainingQuantization:
self._round_type = round_type
self._learning_rate = learning_rate
self._dynamic_quantize_op_type = ['lstm']
- self._support_quantize_op_type = list(
-     set(
-         utils._weight_supported_quantizable_op_type
-         + utils._act_supported_quantizable_op_type
-         + self._dynamic_quantize_op_type
-     )
- )
# Check inputs
assert executor is not None, "The executor cannot be None."
@@ -355,15 +361,6 @@ class PostTrainingQuantization:
self._onnx_format = onnx_format
self._clip_extra = True if self._onnx_format else False
self._skip_tensor_list = skip_tensor_list
- self._is_full_quantize = is_full_quantize
- if is_full_quantize:
-     self._quantizable_op_type = self._support_quantize_op_type
- else:
-     self._quantizable_op_type = quantizable_op_type
-     for op_type in self._quantizable_op_type:
-         assert op_type in self._support_quantize_op_type, (
-             op_type + " is not supported for quantization."
-         )
self._optimize_model = optimize_model
# Define variables
@@ -373,7 +370,6 @@ class PostTrainingQuantization:
self._fetch_list = None
self._data_loader = data_loader
- self._out_scale_op_list = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._weight_op_pairs = {}
@@ -403,6 +399,43 @@ class PostTrainingQuantization:
if self._program is not None:
self.FLAG = True
self._is_full_quantize = is_full_quantize
if is_full_quantize:
quantizable_op_type = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
elif quantizable_op_type:
for op_type in quantizable_op_type:
assert op_type in list(SUPPORT_QUANTIZATION_OP_DICT.keys()), (
op_type + " is not supported for quantization."
)
assert (
activation_bits == weight_bits
), "activation_bits and weight_bits must be the same, other cases are not supported."
support_deploy_backend = [None, "tensorrt", "mkldnn", "arm"]
if not deploy_backend:
self.quant_config = BaseQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
elif deploy_backend.lower() == "tensorrt":
self.quant_config = TensorRTQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
elif deploy_backend.lower() == "mkldnn":
self.quant_config = MKLDNNQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
elif deploy_backend.lower() == "arm":
self.quant_config = ARMCPUQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
else:
raise ValueError(
"Deploy Backend {} is not supported, please choose one of {}.".format(
deploy_backend, support_deploy_backend
)
)
def quantize(self):
'''
Load the FP32 model, and use the calibrate data to calculate the forward-stage.
@@ -486,7 +519,7 @@ class PostTrainingQuantization:
self._save_output_threshold()
if any(
- op_type in self._quantizable_op_type
+ op_type in self.quant_config.activation_quant_operation_types
for op_type in self._dynamic_quantize_op_type
):
self._collect_dynamic_quantize_op_threshold(
@@ -652,9 +685,8 @@ class PostTrainingQuantization:
op._set_attr("op_namescope", "skip_quant")
op_type = op.type
- if (
-     self._is_full_quantize
-     and op_type not in self._quantizable_op_type
- ):
+ if self._is_full_quantize and op_type not in list(
+     SUPPORT_QUANTIZATION_OP_DICT.keys()
+ ):
_logger.warning(
op_type + " is not supported for quantization."
@@ -664,7 +696,12 @@ class PostTrainingQuantization:
in persistable_var_names
)
# For quantized ops, sample inputs and outputs
- if op_type in self._quantizable_op_type or is_conv1d_quant:
+ if (
+     op_type in self.quant_config.weight_quant_operation_types
+     or op_type
+     in self.quant_config.activation_quant_operation_types
+     or is_conv1d_quant
+ ):
collect_var_name(
utils._get_op_input_var_names(op),
persistable_var_names,
@@ -683,7 +720,7 @@ class PostTrainingQuantization:
in_var_name
] = out_var_name
# For other op, only sample output scale
- elif op_type in self._out_scale_op_list:
+ elif op_type in self.quant_config.observer_operation_types:
collect_var_name(
utils._get_op_output_var_names(op),
persistable_var_names,
@@ -1034,7 +1071,11 @@ class PostTrainingQuantization:
), "The algo should be min_max to save input threshold."
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
- if op.type in self._quantizable_op_type:
+ if (
+     op.type in self.quant_config.weight_quant_operation_types
+     or op.type
+     in self.quant_config.activation_quant_operation_types
+ ):
for var_name in utils._get_op_input_var_names(op):
assert var_name in self._quantized_var_min
assert var_name in self._quantized_var_max
@@ -1142,10 +1183,6 @@ class PostTrainingQuantization:
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
# use QuantizationTransformPass to insert fake_quant/fake_dequantize op
- major_quantizable_op_types = []
- for op_type in utils._weight_supported_quantizable_op_type:
-     if op_type in self._quantizable_op_type:
-         major_quantizable_op_types.append(op_type)
if not self._onnx_format:
transform_pass = QuantizationTransformPass(
scope=self._scope,
@@ -1154,7 +1191,7 @@ class PostTrainingQuantization:
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
- quantizable_op_type=major_quantizable_op_types,
+ quantizable_op_type=self.quant_config.weight_quant_operation_types,
)
else:
transform_pass = QuantizationTransformPassV2(
@@ -1164,7 +1201,7 @@ class PostTrainingQuantization:
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
- quantizable_op_type=major_quantizable_op_types,
+ quantizable_op_type=self.quant_config.weight_quant_operation_types,
)
for sub_graph in graph.all_sub_graphs():
@@ -1174,22 +1211,17 @@ class PostTrainingQuantization:
transform_pass.apply(sub_graph)
# use AddQuantDequantPass to insert fake_quant_dequant op
- minor_quantizable_op_types = []
- for op_type in utils._act_supported_quantizable_op_type:
-     if op_type in self._quantizable_op_type:
-         minor_quantizable_op_types.append(op_type)
if not self._onnx_format:
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
- quantizable_op_type=minor_quantizable_op_types,
+ quantizable_op_type=self.quant_config.activation_quant_operation_types,
)
else:
add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope,
place=self._place,
- quantizable_op_type=minor_quantizable_op_types,
- is_full_quantized=True,
+ quantizable_op_type=self.quant_config.activation_quant_operation_types,
)
for sub_graph in graph.all_sub_graphs():
@@ -1283,7 +1315,7 @@ class PostTrainingQuantization:
round_type=self._round_type,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
- quantizable_op_type=major_quantizable_op_types,
+ quantizable_op_type=self.quant_config.weight_quant_operation_types,
)
for sub_graph in graph.all_sub_graphs():
@@ -1295,6 +1327,22 @@ class PostTrainingQuantization:
sub_graph._for_test = True
quant_weight_pass.apply(sub_graph)
infer_pass_quant_op_types = (
self.quant_config.weight_quant_operation_types
+ self.quant_config.activation_quant_operation_types
+ self.quant_config.observer_operation_types
)
out_scale_infer_pass = AddQuantDequantForInferencePass(
scope=self._scope,
place=self._place,
quant_bits=self._activation_bits,
quantizable_op_type=infer_pass_quant_op_types,
calibration_range_dict=self._scale_dict,
)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
out_scale_infer_pass.apply(sub_graph)
self._program = graph.to_program()
def _save_output_threshold(self):
@@ -1339,7 +1387,12 @@ class PostTrainingQuantization:
threshold_map[out_var_name],
)
op_node._set_attr("with_quant_attr", True)
- if op_node.type in self._quantizable_op_type:
+ if (
+     op_node.type
+     in self.quant_config.weight_quant_operation_types
+     or op_node.type
+     in self.quant_config.activation_quant_operation_types
+ ):
op._set_attr("quantization_type", quantized_type)
def analysis_and_save_info(op_node, out_var_name):
@@ -1387,7 +1440,9 @@ class PostTrainingQuantization:
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
if op.type in (
- self._quantizable_op_type + self._out_scale_op_list
+ self.quant_config.weight_quant_operation_types
+ + self.quant_config.activation_quant_operation_types
+ + self.quant_config.observer_operation_types
):
out_var_names = utils._get_op_output_var_names(op)
for var_name in out_var_names:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A dict of operators that contain weights and support quantization,
# including operator names, actual input and output names.
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]],
}
# A dict of operators that supports quantization and has only activation inputs,
# including operator names, actual input and output names.
SUPPORT_ACT_QUANTIZATION_OP_DICT = {
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
"softmax": [["X"], ["Out"]],
"argmax": [["X"], ["Out"]],
"transpose": [["X"], ["Out"]],
"equal": [["X", "Y"], ["Out"]],
"gather": [["X"], ["Out"]],
"greater_equal": [["X", "Y"], ["Out"]],
"greater_than": [["X", "Y"], ["Out"]],
"less_equal": [["X", "Y"], ["Out"]],
"less_than": [["X", "Y"], ["Out"]],
"mean": [["X"], ["Out"]],
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"transpose2": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
"squeeze": [["X"], ["Out"]],
"elementwise_sub": [["X", "Y"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X", "Alpha"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
"dropout": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"layer_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
"lstm": [["Input", "Weight"], ["Hidden"]],
"pad2d": [["X"], ["Out"]],
"pad3d": [["X"], ["Out"]],
"flatten": [["X"], ["Out"]],
"flatten2": [["X"], ["Out"]],
"unsqueeze2": [["X"], ["Out"]],
"flatten_contiguous_range": [["X"], ["Out"]],
"split": [["X"], ["Out"]],
"squeeze2": [["X"], ["Out"]],
"nearest_interp_v2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"bilinear_interp_v2": [["X"], ["Out"]],
"fill_constant_batch_size_like": [["Input"], ["Out"]],
"arg_max": [["X"], ["Out"]],
"abs": [["X"], ["Out"]],
"assign": [["X"], ["Out"]],
"cast": [["X"], ["Out"]],
"clip": [["X"], ["Out"]],
"box_coder": [["PriorBox"], ["OutputBox"]],
"crop": [["X"], ["Out"]],
"cumsum": [["X"], ["Out"]],
"expand_v2": [["X"], ["Out"]],
"fill_any_like": [["X"], ["Out"]],
"fill_constant": [[], ["Out"]],
"gelu": [["X"], ["Out"]],
"instance_norm": [["X"], ["Y"]],
"lookup_table": [["W", "Ids"], ["Out"]],
"lookup_table_v2": [["W", "Ids"], ["Out"]],
"norm": [["X"], ["Norm"]],
"p_norm": [["X"], ["Out"]],
"pow": [["X"], ["Out"]],
"reduce_mean": [["X"], ["Out"]],
"stack": [["X"], ["Y"]],
"top_k_v2": [["X"], ["Out", "Indices"]],
"logical_and": [["X", "Y"], ["Out"]],
"logical_not": [["X"], ["Out"]],
"meshgrid": [["X"], ["Out"]],
"roi_align": [["X", "ROIs"], ["Out"]],
"strided_slice": [["Input"], ["Out"]],
"where": [["Condition", "X", "Y"], ["Out"]],
"grid_sampler": [["X", "Grid"], ["Output"]],
"tile": [["X"], ["Out"]],
"group_norm": [["X"], ["Y", "Mean", "Variance"]],
"reduce_sum": [["X"], ["Out"]],
"square": [["X"], ["Out"]],
"softplus": [["X"], ["Out"]],
"shuffle_channel": [["X"], ["Out"]],
"reduce_max": [["X"], ["Out"]],
"scale": [["X"], ["Out"]],
}
# A full dict of operators that supports quantization,
# including operator names, actual input and output names.
SUPPORT_QUANTIZATION_OP_DICT = SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.copy()
SUPPORT_QUANTIZATION_OP_DICT.update(SUPPORT_ACT_QUANTIZATION_OP_DICT)
class BaseQuantizer:
"""
Basic quantization configuration class, which configures some hyperparameters
required for quantization, including the list of op types to be quantized,
quantization bit number for weight and activation and the range of quantization values.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -128
self._quant_max = 127
@property
def weight_quant_operation_types(self):
"""
Operation type list which should support weight quantization.
Quant/dequant nodes will be inserted before these ops.
"""
base_weight_op_type_list = list(
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()
)
if self._quantizable_op_type:
weight_list = []
for _op_type in self._quantizable_op_type:
if _op_type in base_weight_op_type_list:
weight_list.append(_op_type)
return weight_list
else:
return base_weight_op_type_list
@property
def activation_quant_operation_types(self):
"""
Operation type list which should support activation quantization.
Quant/dequant nodes will be inserted before these ops.
"""
base_act_op_type_list = list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
act_quant_op_list = []
if self._quantizable_op_type:
for _op_type in self._quantizable_op_type:
if _op_type in base_act_op_type_list:
act_quant_op_list.append(_op_type)
else:
act_quant_op_list = [
'mul',
'matmul',
'matmul_v2',
]
return act_quant_op_list
@property
def observer_operation_types(self):
"""
Operation type list for observers in quantization. These ops only collect the
calibration scale and do not take part in fake quantization.
To ease deployment to inference engines, quant and dequant nodes will be
inserted after these ops when the model is exported.
"""
return list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
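As a quick illustration of how these three properties interact, a sketch assuming the module is importable from paddle.static.quantization:

```python
from paddle.static.quantization.quant_config import BaseQuantizer

q = BaseQuantizer(quantizable_op_type=["conv2d", "mul", "pool2d"], quant_bits=8)

# Only op types that carry weights survive the weight filter.
print(q.weight_quant_operation_types)      # ['conv2d', 'mul']
# Only activation-quantizable op types survive the activation filter.
print(q.activation_quant_operation_types)  # ['mul', 'pool2d']
# Observers always cover every activation-quantizable op type, not the filtered list.
print("conv2d" in q.observer_operation_types)  # False: observers track activation ops only
```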
class TensorRTQuantizer(BaseQuantizer):
"""
TensorRT quantization configuration class.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
super().__init__()
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -128
self._quant_max = 127
@property
def activation_quant_operation_types(self):
"""
Operation type list which should support activation quantization.
Quant/dequant nodes will be inserted before these ops.
"""
return [
"pool2d",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_pow",
"concat",
"softmax",
"argmax",
"mean",
"relu",
"relu6",
"leaky_relu",
"tanh",
"swish",
"softplus",
"gelu",
"hard_sigmoid",
"hard_swish",
"sigmoid",
"layer_norm",
"matmul_v2",
"split",
"bilinear_interp",
"nearest_interp",
"trilinear_interp",
"nearest_interp_v2",
"bilinear_interp",
"bilinear_interp_v2",
"clip",
"pow",
"reduce_mean",
"reduce_sum",
"reduce_max",
]
class MKLDNNQuantizer(BaseQuantizer):
"""
MKLDNN quantization configuration class.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
super().__init__()
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -128
self._quant_max = 127
@property
def activation_quant_operation_types(self):
"""
Operation type list which should support activation quantization.
Quant/dequant nodes will be inserted before these ops.
"""
return [
"pool2d",
"elementwise_add",
"elementwise_mul",
"concat",
"nearest_interp",
"nearest_interp_v2",
"split",
]
class ARMCPUQuantizer(BaseQuantizer):
"""
ARM CPU with Paddle Lite quantization configuration class.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
super().__init__()
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -127
self._quant_max = 127
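To see how the backend-specific configs narrow or widen the defaults, a small comparison sketch under the same import-path assumption:

```python
from paddle.static.quantization.quant_config import (
    BaseQuantizer,
    MKLDNNQuantizer,
    TensorRTQuantizer,
)

# With no explicit op list, each quantizer falls back to its own defaults.
print(BaseQuantizer().activation_quant_operation_types)    # ['mul', 'matmul', 'matmul_v2']
print(MKLDNNQuantizer().activation_quant_operation_types)  # ['pool2d', 'elementwise_add', ...]
print(len(TensorRTQuantizer().activation_quant_operation_types))  # the broader TensorRT list

# The weight-quantizable op types are shared by all backends.
print(
    TensorRTQuantizer().weight_quant_operation_types
    == BaseQuantizer().weight_quant_operation_types
)  # True
```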
@@ -28,6 +28,11 @@ from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name
from . import utils
from .quant_config import (
SUPPORT_ACT_QUANTIZATION_OP_DICT,
SUPPORT_QUANTIZATION_OP_DICT,
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT,
)
_fake_quant_op_list = [
'fake_quantize_abs_max',
@@ -231,7 +236,7 @@ class QuantizationTransformPass:
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
- assert op in utils._weight_supported_quantizable_op_type, (
+ assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
op + " is not supported for quantization."
)
self._quantizable_grad_ops = [
@@ -1594,7 +1599,7 @@ class OutScaleForTrainingPass:
self._place = _get_paddle_place(place)
self._moving_rate = moving_rate
self._is_test = is_test
- self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
+ self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
self._scale_dict = scale_dict
def apply(self, graph):
@@ -1749,7 +1754,7 @@ class OutScaleForInferencePass:
scope(static.Scope): The scope is used to initialize these new parameters.
"""
self._scope = scope
- self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
+ self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
def apply(self, graph):
"""
@@ -1830,7 +1835,6 @@ class AddQuantDequantPass:
quant_bits=8,
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"],
- is_full_quantized=False,
is_test=None,
scale_dict=None,
):
@@ -1851,10 +1855,6 @@ class AddQuantDequantPass:
Default is 'skip_quant'.
quantizable_op_type(list[str], optional): List the type of ops that will be
quantized. Default is ["elementwise_add", "pool2d"].
- is_full_quantized(bool, optional): If set is_full_quantized as True, apply
-     quantization to all supported quantizable op type. If set is_full_quantized
-     as False, only apply quantization to the op type according to the input
-     quantizable_op_type.
""" """
self._scope = scope self._scope = scope
self._place = _get_paddle_place(place) self._place = _get_paddle_place(place)
...@@ -1864,14 +1864,11 @@ class AddQuantDequantPass: ...@@ -1864,14 +1864,11 @@ class AddQuantDequantPass:
self._skip_pattern = skip_pattern self._skip_pattern = skip_pattern
self._scale_dict = scale_dict self._scale_dict = scale_dict
if is_full_quantized: self._quantizable_op_type = quantizable_op_type
self._quantizable_op_type = utils._act_supported_quantizable_op_type for op_type in self._quantizable_op_type:
else: assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
self._quantizable_op_type = quantizable_op_type op_type + " is not supported for quantization."
for op_type in quantizable_op_type: )
assert op_type in utils._act_supported_quantizable_op_type, (
op_type + " is not supported for quantization."
)
self._quantizable_grad_op_type = [ self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type '%s_grad' % (op) for op in self._quantizable_op_type
] ]
@@ -2485,7 +2482,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
- assert op in utils._weight_supported_quantizable_op_type, (
+ assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
op + " is not supported for quantization."
)
self._quantizable_grad_ops = [
@@ -2763,7 +2760,6 @@ class AddQuantDequantPassV2:
quant_bits=8,
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"],
- is_full_quantized=False,
is_test=None,
scale_dict=None,
):
@@ -2782,10 +2778,6 @@ class AddQuantDequantPassV2:
Default is 'skip_quant'.
quantizable_op_type(list[str], optional): List the type of ops that will be
quantized. Default is ["elementwise_add", "pool2d"].
- is_full_quantized(bool, optional): If set is_full_quantized as True, apply
-     quantization to all supported quantizable op type. If set is_full_quantized
-     as False, only apply quantization to the op type according to the input
-     quantizable_op_type.
scale_dict(dict, optional): calibration ranges of tensors output.
Examples:
@@ -2811,14 +2803,11 @@ class AddQuantDequantPassV2:
self._skip_pattern = skip_pattern
self._scale_dict = scale_dict
- if is_full_quantized:
-     self._quantizable_op_type = utils._act_supported_quantizable_op_type
- else:
-     self._quantizable_op_type = quantizable_op_type
-     for op_type in quantizable_op_type:
-         assert op_type in utils._act_supported_quantizable_op_type, (
-             op_type + " is not supported for quantization."
-         )
+ self._quantizable_op_type = quantizable_op_type
+ for op_type in self._quantizable_op_type:
+     assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
+         op_type + " is not supported for quantization."
+     )
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
]
@@ -3243,7 +3232,15 @@ class AddQuantDequantForInferencePass:
When exporting the quantized model, this pass traverses the graph to find the output of each op and inserts a quant/dequant op after it.
"""
- def __init__(self, scope, place, quant_bits=8):
+ def __init__(
+     self,
+     scope,
+     place,
+     quant_bits=8,
+     quantizable_op_type=[],
+     calibration_range_dict=None,
+     only_observer=True,
+ ):
""" """
Args: Args:
scope(static.Scope): The scope is used to initialize these new parameters. scope(static.Scope): The scope is used to initialize these new parameters.
...@@ -3254,7 +3251,13 @@ class AddQuantDequantForInferencePass: ...@@ -3254,7 +3251,13 @@ class AddQuantDequantForInferencePass:
self._scope = scope self._scope = scope
self._place = place self._place = place
self._quant_bits = quant_bits self._quant_bits = quant_bits
self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST self._only_observer = only_observer
self._teller_set = (
quantizable_op_type
if quantizable_op_type
else list(SUPPORT_QUANTIZATION_OP_DICT.keys())
)
self._calibration_range_dict = calibration_range_dict
def apply(self, graph):
"""
@@ -3321,9 +3324,31 @@ class AddQuantDequantForInferencePass:
shape=var_node.shape(),
var_dtype=var_node.dtype(),
)
- scale_var_node = graph._find_node_by_name(
-     graph.all_persistable_nodes(), self._scale_name(var_name)
- )
+ if not self._calibration_range_dict:
+     scale_var_node = graph._find_node_by_name(
+         graph.all_persistable_nodes(), self._scale_name(var_name)
+     )
+ elif var_name in self._calibration_range_dict:
+     scale_value = self._calibration_range_dict[var_name]
+     scale_var_node = graph.create_persistable_node(
+         name=self._scale_name(var_name),
+         var_type=var_node.type(),
+         shape=[1],
+         var_dtype=var_node.dtype(),
+     )
+     data_type = (
+         'float64'
+         if var_node.dtype() == core.VarDesc.VarType.FP64
+         else 'float32'
+     )
+     _init_var_node(
+         scale_var_node,
+         np.array(scale_value, dtype=data_type),
+         self._scope,
+         self._place,
+     )
+ else:
+     return None
try:
zero_point_node = graph._find_node_by_name(
graph.all_persistable_nodes(),
@@ -3347,7 +3372,11 @@ class AddQuantDequantForInferencePass:
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
- attrs = {"quant_axis": quant_axis, "bit_length": self._quant_bits}
+ attrs = {
+     "quant_axis": quant_axis,
+     "bit_length": self._quant_bits,
+     "only_observer": self._only_observer,
+ }
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
outputs = {"Y": quant_var_node}
@@ -3376,7 +3405,11 @@ class AddQuantDequantForInferencePass:
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
- attrs = {"quant_axis": -1, "bit_length": self._quant_bits}
+ attrs = {
+     "quant_axis": -1,
+     "bit_length": self._quant_bits,
+     "only_observer": self._only_observer,
+ }
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
dequant_op_node = graph.create_op_node(
......
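For orientation, a sketch of driving the extended AddQuantDequantForInferencePass directly, mirroring the call site added in post_training_quantization.py above; the calibrated program and scale dictionary are placeholders, and the import paths are assumptions:

```python
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.static.quantization.quantization_pass import (
    AddQuantDequantForInferencePass,
)

paddle.enable_static()
place = paddle.CPUPlace()
scope = paddle.static.global_scope()

# `calibrated_program` (a paddle.static.Program with collected scales) and
# `calib_ranges` ({tensor_name: scale}) are placeholders for this sketch.
graph = IrGraph(core.Graph(calibrated_program.desc), for_test=True)
out_scale_infer_pass = AddQuantDequantForInferencePass(
    scope=scope,
    place=place,
    quant_bits=8,
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    calibration_range_dict=calib_ranges,
    only_observer=True,  # the inserted q/dq nodes only carry scale information
)
for sub_graph in graph.all_sub_graphs():
    sub_graph._for_test = True
    out_scale_infer_pass.apply(sub_graph)
exported_program = graph.to_program()
```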
@@ -277,6 +277,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_optimize_model=False,
batch_nums=10,
onnx_format=False,
deploy_backend=None,
):
try:
os.system("mkdir " + self.int8_model)
@@ -305,6 +306,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
optimize_model=is_optimize_model,
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file,
deploy_backend=deploy_backend,
)
ptq.quantize()
ptq.save_quantized_model(
@@ -329,6 +331,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
diff_threshold,
onnx_format=False,
batch_nums=10,
deploy_backend=None,
):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
@@ -361,6 +364,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_optimize_model,
batch_nums,
onnx_format,
deploy_backend,
)
print(
@@ -571,5 +575,131 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
)
class TestPostTrainingAvgONNXFormatForMobilenetv1TensorRT(
TestPostTrainingQuantization
):
def test_post_training_onnx_format_mobilenetv1_tensorrt(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
]
data_md5s = ['5ee2b1775b11dc233079236cdc216c2e']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
onnx_format = True
diff_threshold = 0.05
batch_nums = 10
deploy_backend = "tensorrt"
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format,
batch_nums=batch_nums,
deploy_backend=deploy_backend,
)
class TestPostTrainingKLONNXFormatForMobilenetv1MKLDNN(
TestPostTrainingQuantization
):
def test_post_training_onnx_format_mobilenetv1_mkldnn(self):
model = "MobileNet-V1"
algo = "ptf"
round_type = "round"
data_urls = [
'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
]
data_md5s = ['5ee2b1775b11dc233079236cdc216c2e']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
onnx_format = True
diff_threshold = 0.05
batch_nums = 2
deploy_backend = "mkldnn"
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format,
batch_nums=batch_nums,
deploy_backend=deploy_backend,
)
class TestPostTrainingAvgONNXFormatForMobilenetv1ARMCPU(
TestPostTrainingQuantization
):
def test_post_training_onnx_format_mobilenetv1_armcpu(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
]
data_md5s = ['5ee2b1775b11dc233079236cdc216c2e']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
onnx_format = True
diff_threshold = 0.05
batch_nums = 3
deploy_backend = "arm"
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format,
batch_nums=batch_nums,
deploy_backend=deploy_backend,
)
if __name__ == '__main__':
unittest.main()
@@ -188,7 +188,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
):
origin_model_path = self.download_model(data_url, data_md5, model_name)
- # origin_model_path = os.path.join(origin_model_path, model_name)
print(
"Start FP32 inference for {0} on {1} images ...".format(
......
@@ -17,115 +17,7 @@ import sys
import numpy as np
from ...fluid.framework import IrNode, Operator
from .quant_config import SUPPORT_QUANTIZATION_OP_DICT
_weight_supported_quantizable_op_type = [
'conv2d',
'depthwise_conv2d',
'conv2d_transpose',
'mul',
'matmul',
'matmul_v2',
]
_act_supported_quantizable_op_type = [
"pool2d",
"elementwise_add",
"concat",
"softmax",
"argmax",
"transpose",
"equal",
"gather",
"greater_equal",
"greater_than",
"less_equal",
"less_than",
"mean",
"not_equal",
"reshape",
"reshape2",
"dropout",
"bilinear_interp",
"nearest_interp",
"trilinear_interp",
"slice",
"squeeze",
"elementwise_sub",
"mul",
"matmul",
"relu",
"relu6",
"leaky_relu",
"tanh",
"swish",
"transpose",
"transpose2",
"sigmoid",
"pad2d",
"flatten",
"flatten2",
"batch_norm",
"layer_norm",
"matmul_v2",
"split",
"flatten_contiguous_range",
"squeeze2",
"nearest_interp_v2",
"bilinear_interp",
"bilinear_interp_v2",
"fill_constant_batch_size_like",
"arg_max",
"abs",
"assign",
"cast",
"clip",
"box_coder",
"crop",
"cumsum",
"elementwise_mul",
"elementwise_pow",
"expand_v2",
"fill_any_like",
"fill_constant",
"gelu",
"hard_sigmoid",
"hard_swish",
"instance_norm",
"lookup_table",
"lookup_table_v2",
"norm",
"p_norm",
"pad3d",
"pow",
"prelu",
"reduce_mean",
"unsqueeze",
"unsqueeze2",
"logical_and",
"logical_not",
"meshgrid",
"roi_align",
"strided_slice",
"where",
"grid_sampler",
"tile",
"group_norm",
"reduce_sum",
"square",
"softplus",
"shuffle_channel",
"reduce_max",
"scale",
]
QUANT_SUPPORTED_OP_TYPE_LIST = list(
set(
_weight_supported_quantizable_op_type
+ _act_supported_quantizable_op_type
)
)
_out_scale_op_list = QUANT_SUPPORTED_OP_TYPE_LIST
_channelwise_quant_axis1_ops = [
'conv2d_transpose',
@@ -134,102 +26,6 @@ _channelwise_quant_axis1_ops = [
'matmul_v2',
]
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
"softmax": [["X"], ["Out"]],
"argmax": [["X"], ["Out"]],
"transpose": [["X"], ["Out"]],
"equal": [["X", "Y"], ["Out"]],
"gather": [["X"], ["Out"]],
"greater_equal": [["X", "Y"], ["Out"]],
"greater_than": [["X", "Y"], ["Out"]],
"less_equal": [["X", "Y"], ["Out"]],
"less_than": [["X", "Y"], ["Out"]],
"mean": [["X"], ["Out"]],
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"transpose2": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
"squeeze": [["X"], ["Out"]],
"elementwise_sub": [["X", "Y"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X", "Alpha"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
"dropout": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"layer_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
"lstm": [["Input", "Weight"], ["Hidden"]],
"pad2d": [["X"], ["Out"]],
"pad3d": [["X"], ["Out"]],
"flatten": [["X"], ["Out"]],
"flatten2": [["X"], ["Out"]],
"unsqueeze2": [["X"], ["Out"]],
"flatten_contiguous_range": [["X"], ["Out"]],
"split": [["X"], ["Out"]],
"squeeze2": [["X"], ["Out"]],
"nearest_interp_v2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"bilinear_interp_v2": [["X"], ["Out"]],
"fill_constant_batch_size_like": [["Input"], ["Out"]],
"arg_max": [["X"], ["Out"]],
"abs": [["X"], ["Out"]],
"assign": [["X"], ["Out"]],
"cast": [["X"], ["Out"]],
"clip": [["X"], ["Out"]],
"box_coder": [["PriorBox"], ["OutputBox"]],
"crop": [["X"], ["Out"]],
"cumsum": [["X"], ["Out"]],
"expand_v2": [["X"], ["Out"]],
"fill_any_like": [["X"], ["Out"]],
"fill_constant": [[], ["Out"]],
"gelu": [["X"], ["Out"]],
"instance_norm": [["X"], ["Y"]],
"lookup_table": [["W", "Ids"], ["Out"]],
"lookup_table_v2": [["W", "Ids"], ["Out"]],
"norm": [["X"], ["Norm"]],
"p_norm": [["X"], ["Out"]],
"pow": [["X"], ["Out"]],
"reduce_mean": [["X"], ["Out"]],
"stack": [["X"], ["Y"]],
"top_k_v2": [["X"], ["Out", "Indices"]],
"logical_and": [["X", "Y"], ["Out"]],
"logical_not": [["X"], ["Out"]],
"meshgrid": [["X"], ["Out"]],
"roi_align": [["X", "ROIs"], ["Out"]],
"strided_slice": [["Input"], ["Out"]],
"where": [["Condition", "X", "Y"], ["Out"]],
"grid_sampler": [["X", "Grid"], ["Output"]],
"tile": [["X"], ["Out"]],
"group_norm": [["X"], ["Y", "Mean", "Variance"]],
"reduce_sum": [["X"], ["Out"]],
"square": [["X"], ["Out"]],
"softplus": [["X"], ["Out"]],
"shuffle_channel": [["X"], ["Out"]],
"reduce_max": [["X"], ["Out"]],
"scale": [["X"], ["Out"]],
}
def _get_op_input_var_names(op):
"""
@@ -244,10 +40,10 @@ def _get_op_input_var_names(op):
), "The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) else op.type
- if op_name not in _op_real_in_out_name:
+ if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return []
- name_list = _op_real_in_out_name[op_name][0]
+ name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][0]
for name in name_list:
var_name = op.input(name)
if isinstance(var_name, list):
@@ -264,10 +60,10 @@ def _get_op_output_var_names(op):
), "The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) else op.type
- if op_name not in _op_real_in_out_name:
+ if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return []
- name_list = _op_real_in_out_name[op_name][1]
+ name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
for name in name_list:
var_name = op.output(name)
if isinstance(var_name, list):
@@ -283,11 +79,11 @@ def _get_input_name_index(op, input_var_name):
op, (IrNode, Operator)
), "The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) else op.type
- if op_name not in _op_real_in_out_name:
+ if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return None
res = None
- for argname in _op_real_in_out_name[op_name][0]:
+ for argname in SUPPORT_QUANTIZATION_OP_DICT[op_name][0]:
var_names = op.input(argname)
for index, name in enumerate(var_names):
if name == input_var_name:
@@ -301,10 +97,10 @@ def _get_output_name_index(op, output_var_name):
op, (IrNode, Operator)
), "The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) else op.type
- if op_name not in _op_real_in_out_name:
+ if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return None
- name_list = _op_real_in_out_name[op_name][1]
+ name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
res = None
for name in name_list:
var_name = op.output(name)
@@ -347,7 +143,7 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
if isinstance(scale, list) and len(scale) == 1:
scale = scale[0]
if isinstance(scale, list):
- assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
+ assert quant_axis in [-1, 0, 1], 'quant_axis should be -1, 0 or 1 for now.'
for i, s in enumerate(scale):
if s == 0.0:
s = 1e-8
......
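For reference, a small sketch of how the helpers above read the real input/output slots from SUPPORT_QUANTIZATION_OP_DICT, which replaces the removed _op_real_in_out_name table (values taken from quant_config.py; the import path is an assumption):

```python
from paddle.static.quantization.quant_config import SUPPORT_QUANTIZATION_OP_DICT

# Each entry lists the *real* tensor slots, so helpers such as
# _get_op_input_var_names can skip auxiliary inputs like AxisTensor.
print(SUPPORT_QUANTIZATION_OP_DICT["conv2d"])  # [['Input', 'Filter'], ['Output']]
print(SUPPORT_QUANTIZATION_OP_DICT["slice"])   # [['Input'], ['Out']]

in_slots, out_slots = SUPPORT_QUANTIZATION_OP_DICT["matmul_v2"]
print(in_slots)   # ['X', 'Y']
print(out_slots)  # ['Out']
```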