Unverified · Commit 8bbae468 authored by Guanghua Yu, committed by GitHub

Add observer attribute in qdq node & Add quant config for different backends. (#46887)

Parent 07db4a9f
......@@ -200,6 +200,12 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(true);
AddAttr<bool>(
"only_observer",
"(bool, default false) Whether to only observer or not. If "
"only_observer=false, it will calculate fake quant or dequant output. "
"If only_observer=true, it will only calibrate scale information.")
.SetDefault(false);
AddComment(R"DOC(
The scale of QuantizeLinear operator is a vector.
In detail, each channel of the input X has a scale value.
......
......@@ -61,6 +61,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis");
bool is_test = context.Attr<bool>("is_test");
bool only_observer = context.Attr<bool>("only_observer");
auto& dev_ctx = context.template device_context<DeviceContext>();
if (quant_axis < 0) {
......@@ -91,11 +92,19 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
out_state,
out_accum,
out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
if (only_observer) {
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
} else {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
}
} else {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
if (only_observer) {
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
} else {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
}
}
} else {
if (!is_test) {
......@@ -103,11 +112,19 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
FindChannelAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, *in, quant_axis, out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
if (only_observer) {
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
} else {
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
}
} else {
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
if (only_observer) {
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
} else {
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
}
}
}
}
......@@ -132,6 +149,12 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
int bit_length = context.Attr<int>("bit_length");
auto quant_axis = context.Attr<int>("quant_axis");
dev_ctx.template Alloc<D>(out, out->numel() * sizeof(D));
bool only_observer = context.Attr<bool>("only_observer");
if (only_observer) {
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
return;
}
if (quant_axis < 0) {
float max_range = (std::pow(2, bit_length - 1) - 1);
......
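A minimal NumPy sketch of the new only_observer behavior in the kernels above: when only_observer is true the input tensor is copied to the output unchanged while the abs-max scale is still collected, and when it is false the usual fake-quant output is computed. The helper name and the per-tensor abs-max policy are illustrative assumptions, not the kernel's exact code.

import numpy as np

def quantize_linear_ref(x, bit_length=8, only_observer=False):
    # Per-tensor calibration scale (the quant_axis < 0 branch above).
    bin_cnt = 2 ** (bit_length - 1) - 1        # 127 for 8-bit
    scale = float(np.abs(x).max())
    if only_observer:
        # Observer-only: pass the input through, keep the scale for calibration.
        return x.copy(), scale
    # Fake quantization: clip to [-scale, scale] and round onto the integer grid.
    q = np.round(np.clip(x, -scale, scale) / scale * bin_cnt)
    return q, scale

x = np.array([-1.5, 0.2, 0.9], dtype=np.float32)
observed, s = quantize_linear_ref(x, only_observer=True)    # observed == x, s == 1.5
quantized, _ = quantize_linear_ref(x, only_observer=False)  # fake-quantized values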
......@@ -24,15 +24,19 @@ from paddle.static.quantization import (
AddQuantDequantPassV2,
OutScaleForTrainingPass,
QuantizationTransformPassV2,
utils,
quant_config,
)
from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
TRANSFORM_PASS_OP_TYPES = list(
quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()
)
QUANT_DEQUANT_PASS_OP_TYPES = list(
quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()
)
def _node_id(node):
......
......@@ -35,7 +35,15 @@ from ..log_helper import get_logger
from . import utils
from .adaround import run_adaround
from .cal_kl_threshold import cal_kl_threshold
from .quant_config import (
SUPPORT_QUANTIZATION_OP_DICT,
ARMCPUQuantizer,
BaseQuantizer,
MKLDNNQuantizer,
TensorRTQuantizer,
)
from .quantization_pass import (
AddQuantDequantForInferencePass,
AddQuantDequantPass,
AddQuantDequantPassV2,
QuantizationFreezePass,
......@@ -127,7 +135,7 @@ class PostTrainingQuantization:
batch_nums=None,
algo="KL",
hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
quantizable_op_type=[],
round_type='round',
learning_rate=0.001,
is_full_quantize=False,
......@@ -145,6 +153,7 @@ class PostTrainingQuantization:
cache_dir=None,
scale_dict=None,
return_graph=False,
deploy_backend=None,
):
'''
Constructor.
......@@ -190,8 +199,9 @@ class PostTrainingQuantization:
hist_percent(float, optional): The threshold of algo 'hist' for activations.
Default is 0.99999.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config for
the current deploy_backend.
round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the integer.
......@@ -199,8 +209,8 @@ class PostTrainingQuantization:
learning_rate(float, optional): The learning rate of adaround method.
is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set
is_full_quantized as False, only apply quantization to the op type
according to the input quantizable_op_type.
is_full_quantized as False, it will apply quantization to the op types
according to the input quantizable_op_type or the quant config of the deploy_backend.
bias_correction(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False.
activation_bits(int): quantization bit number for activation.
......@@ -234,6 +244,9 @@ class PostTrainingQuantization:
quantization. Default False.
is_use_cache_file(bool, optional): This param is deprecated.
cache_dir(str, optional): This param is deprecated.
deploy_backend(str, optional): The deploy backend. It can be None, `TensorRT`,
`MKLDNN` or `ARM`, and new backends will be supported in the future. Default is None,
which means the default general quantization configuration is used.
Returns:
None
......@@ -294,13 +307,6 @@ class PostTrainingQuantization:
self._round_type = round_type
self._learning_rate = learning_rate
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = list(
set(
utils._weight_supported_quantizable_op_type
+ utils._act_supported_quantizable_op_type
+ self._dynamic_quantize_op_type
)
)
# Check inputs
assert executor is not None, "The executor cannot be None."
......@@ -355,15 +361,6 @@ class PostTrainingQuantization:
self._onnx_format = onnx_format
self._clip_extra = True if self._onnx_format else False
self._skip_tensor_list = skip_tensor_list
self._is_full_quantize = is_full_quantize
if is_full_quantize:
self._quantizable_op_type = self._support_quantize_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in self._support_quantize_op_type, (
op_type + " is not supported for quantization."
)
self._optimize_model = optimize_model
# Define variables
......@@ -373,7 +370,6 @@ class PostTrainingQuantization:
self._fetch_list = None
self._data_loader = data_loader
self._out_scale_op_list = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._weight_op_pairs = {}
......@@ -403,6 +399,43 @@ class PostTrainingQuantization:
if self._program is not None:
self.FLAG = True
self._is_full_quantize = is_full_quantize
if is_full_quantize:
quantizable_op_type = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
elif quantizable_op_type:
for op_type in quantizable_op_type:
assert op_type in list(SUPPORT_QUANTIZATION_OP_DICT.keys()), (
op_type + " is not supported for quantization."
)
assert (
activation_bits == weight_bits
), "activation_bits and weight_bits must be the same, other cases are not supported."
support_deploy_backend = [None, "tensorrt", "mkldnn", "arm"]
if not deploy_backend:
self.quant_config = BaseQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
elif deploy_backend.lower() == "tensorrt":
self.quant_config = TensorRTQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
elif deploy_backend.lower() == "mkldnn":
self.quant_config = MKLDNNQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
elif deploy_backend.lower() == "arm":
self.quant_config = ARMCPUQuantizer(
quantizable_op_type=quantizable_op_type,
quant_bits=weight_bits,
)
else:
raise ValueError(
"Deploy backend {} is not supported, please choose one of {}.".format(
deploy_backend, support_deploy_backend
)
)
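A hedged usage sketch of the new deploy_backend argument, assuming the existing PostTrainingQuantization constructor arguments (executor, model_dir, data_loader, ...) and the import path paddle.static.quantization: with an empty quantizable_op_type, the selected backend's quant config decides which op types receive quant/dequant nodes. The model paths and the calibration loader below are placeholders.

import numpy as np
import paddle
from paddle.static.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

def calib_loader():
    # Placeholder calibration data; replace with real preprocessed samples
    # matching the model's input shape and feed order.
    for _ in range(10):
        yield [np.random.rand(1, 3, 224, 224).astype("float32")]

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_model",              # placeholder path
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    data_loader=calib_loader(),
    algo="avg",
    quantizable_op_type=[],                # empty -> use the backend's default op types
    deploy_backend="tensorrt",             # or "mkldnn", "arm", None
    onnx_format=True,
)
ptq.quantize()
ptq.save_quantized_model("./int8_model")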
def quantize(self):
'''
Load the FP32 model, and use the calibration data to run the forward stage.
......@@ -486,7 +519,7 @@ class PostTrainingQuantization:
self._save_output_threshold()
if any(
op_type in self._quantizable_op_type
op_type in self.quant_config.activation_quant_operation_types
for op_type in self._dynamic_quantize_op_type
):
self._collect_dynamic_quantize_op_threshold(
......@@ -652,9 +685,8 @@ class PostTrainingQuantization:
op._set_attr("op_namescope", "skip_quant")
op_type = op.type
if (
self._is_full_quantize
and op_type not in self._quantizable_op_type
if self._is_full_quantize and op_type not in list(
SUPPORT_QUANTIZATION_OP_DICT.keys()
):
_logger.warning(
op_type + " is not supported for quantization."
......@@ -664,7 +696,12 @@ class PostTrainingQuantization:
in persistable_var_names
)
# For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type or is_conv1d_quant:
if (
op_type in self.quant_config.weight_quant_operation_types
or op_type
in self.quant_config.activation_quant_operation_types
or is_conv1d_quant
):
collect_var_name(
utils._get_op_input_var_names(op),
persistable_var_names,
......@@ -683,7 +720,7 @@ class PostTrainingQuantization:
in_var_name
] = out_var_name
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
elif op_type in self.quant_config.observer_operation_types:
collect_var_name(
utils._get_op_output_var_names(op),
persistable_var_names,
......@@ -1034,7 +1071,11 @@ class PostTrainingQuantization:
), "The algo should be min_max to save input threshold."
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
if op.type in self._quantizable_op_type:
if (
op.type in self.quant_config.weight_quant_operation_types
or op.type
in self.quant_config.activation_quant_operation_types
):
for var_name in utils._get_op_input_var_names(op):
assert var_name in self._quantized_var_min
assert var_name in self._quantized_var_max
......@@ -1142,10 +1183,6 @@ class PostTrainingQuantization:
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
# use QuantizationTransformPass to insert fake_quant/fake_dequantize op
major_quantizable_op_types = []
for op_type in utils._weight_supported_quantizable_op_type:
if op_type in self._quantizable_op_type:
major_quantizable_op_types.append(op_type)
if not self._onnx_format:
transform_pass = QuantizationTransformPass(
scope=self._scope,
......@@ -1154,7 +1191,7 @@ class PostTrainingQuantization:
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
quantizable_op_type=self.quant_config.weight_quant_operation_types,
)
else:
transform_pass = QuantizationTransformPassV2(
......@@ -1164,7 +1201,7 @@ class PostTrainingQuantization:
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
quantizable_op_type=self.quant_config.weight_quant_operation_types,
)
for sub_graph in graph.all_sub_graphs():
......@@ -1174,22 +1211,17 @@ class PostTrainingQuantization:
transform_pass.apply(sub_graph)
# use AddQuantDequantPass to insert fake_quant_dequant op
minor_quantizable_op_types = []
for op_type in utils._act_supported_quantizable_op_type:
if op_type in self._quantizable_op_type:
minor_quantizable_op_types.append(op_type)
if not self._onnx_format:
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
quantizable_op_type=self.quant_config.activation_quant_operation_types,
)
else:
add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=True,
quantizable_op_type=self.quant_config.activation_quant_operation_types,
)
for sub_graph in graph.all_sub_graphs():
......@@ -1283,7 +1315,7 @@ class PostTrainingQuantization:
round_type=self._round_type,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
quantizable_op_type=self.quant_config.weight_quant_operation_types,
)
for sub_graph in graph.all_sub_graphs():
......@@ -1295,6 +1327,22 @@ class PostTrainingQuantization:
sub_graph._for_test = True
quant_weight_pass.apply(sub_graph)
infer_pass_quant_op_types = (
self.quant_config.weight_quant_operation_types
+ self.quant_config.activation_quant_operation_types
+ self.quant_config.observer_operation_types
)
out_scale_infer_pass = AddQuantDequantForInferencePass(
scope=self._scope,
place=self._place,
quant_bits=self._activation_bits,
quantizable_op_type=infer_pass_quant_op_types,
calibration_range_dict=self._scale_dict,
)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
out_scale_infer_pass.apply(sub_graph)
self._program = graph.to_program()
def _save_output_threshold(self):
......@@ -1339,7 +1387,12 @@ class PostTrainingQuantization:
threshold_map[out_var_name],
)
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
if (
op_node.type
in self.quant_config.weight_quant_operation_types
or op_node.type
in self.quant_config.activation_quant_operation_types
):
op._set_attr("quantization_type", quantized_type)
def analysis_and_save_info(op_node, out_var_name):
......@@ -1387,7 +1440,9 @@ class PostTrainingQuantization:
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
if op.type in (
self._quantizable_op_type + self._out_scale_op_list
self.quant_config.weight_quant_operation_types
+ self.quant_config.activation_quant_operation_types
+ self.quant_config.observer_operation_types
):
out_var_names = utils._get_op_output_var_names(op)
for var_name in out_var_names:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A dict of operators that contain weights and support quantization,
# including operator names, actual input and output names.
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]],
}
# A dict of operators that support quantization and have only activation inputs,
# including operator names, actual input and output names.
SUPPORT_ACT_QUANTIZATION_OP_DICT = {
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
"softmax": [["X"], ["Out"]],
"argmax": [["X"], ["Out"]],
"transpose": [["X"], ["Out"]],
"equal": [["X", "Y"], ["Out"]],
"gather": [["X"], ["Out"]],
"greater_equal": [["X", "Y"], ["Out"]],
"greater_than": [["X", "Y"], ["Out"]],
"less_equal": [["X", "Y"], ["Out"]],
"less_than": [["X", "Y"], ["Out"]],
"mean": [["X"], ["Out"]],
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"transpose2": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
"squeeze": [["X"], ["Out"]],
"elementwise_sub": [["X", "Y"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X", "Alpha"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
"dropout": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"layer_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
"lstm": [["Input", "Weight"], ["Hidden"]],
"pad2d": [["X"], ["Out"]],
"pad3d": [["X"], ["Out"]],
"flatten": [["X"], ["Out"]],
"flatten2": [["X"], ["Out"]],
"unsqueeze2": [["X"], ["Out"]],
"flatten_contiguous_range": [["X"], ["Out"]],
"split": [["X"], ["Out"]],
"squeeze2": [["X"], ["Out"]],
"nearest_interp_v2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"bilinear_interp_v2": [["X"], ["Out"]],
"fill_constant_batch_size_like": [["Input"], ["Out"]],
"arg_max": [["X"], ["Out"]],
"abs": [["X"], ["Out"]],
"assign": [["X"], ["Out"]],
"cast": [["X"], ["Out"]],
"clip": [["X"], ["Out"]],
"box_coder": [["PriorBox"], ["OutputBox"]],
"crop": [["X"], ["Out"]],
"cumsum": [["X"], ["Out"]],
"expand_v2": [["X"], ["Out"]],
"fill_any_like": [["X"], ["Out"]],
"fill_constant": [[], ["Out"]],
"gelu": [["X"], ["Out"]],
"instance_norm": [["X"], ["Y"]],
"lookup_table": [["W", "Ids"], ["Out"]],
"lookup_table_v2": [["W", "Ids"], ["Out"]],
"norm": [["X"], ["Norm"]],
"p_norm": [["X"], ["Out"]],
"pow": [["X"], ["Out"]],
"reduce_mean": [["X"], ["Out"]],
"stack": [["X"], ["Y"]],
"top_k_v2": [["X"], ["Out", "Indices"]],
"logical_and": [["X", "Y"], ["Out"]],
"logical_not": [["X"], ["Out"]],
"meshgrid": [["X"], ["Out"]],
"roi_align": [["X", "ROIs"], ["Out"]],
"strided_slice": [["Input"], ["Out"]],
"where": [["Condition", "X", "Y"], ["Out"]],
"grid_sampler": [["X", "Grid"], ["Output"]],
"tile": [["X"], ["Out"]],
"group_norm": [["X"], ["Y", "Mean", "Variance"]],
"reduce_sum": [["X"], ["Out"]],
"square": [["X"], ["Out"]],
"softplus": [["X"], ["Out"]],
"shuffle_channel": [["X"], ["Out"]],
"reduce_max": [["X"], ["Out"]],
"scale": [["X"], ["Out"]],
}
# A full dict of operators that support quantization,
# including operator names, actual input and output names.
SUPPORT_QUANTIZATION_OP_DICT = SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.copy()
SUPPORT_QUANTIZATION_OP_DICT.update(SUPPORT_ACT_QUANTIZATION_OP_DICT)
class BaseQuantizer:
"""
Basic quantization configuration class, which configures some hyperparameters
required for quantization, including the list of op types to be quantized,
the quantization bit number for weight and activation, and the range of quantization values.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -128
self._quant_max = 127
@property
def weight_quant_operation_types(self):
"""
Operation type list that supports weight quantization.
Quant-dequant nodes will be inserted before these ops.
"""
base_weight_op_type_list = list(
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()
)
if self._quantizable_op_type:
weight_list = []
for _op_type in self._quantizable_op_type:
if _op_type in base_weight_op_type_list:
weight_list.append(_op_type)
return weight_list
else:
return base_weight_op_type_list
@property
def activation_quant_operation_types(self):
"""
Operation type list that supports activation quantization.
Quant-dequant nodes will be inserted before these ops.
"""
base_act_op_type_list = list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
act_quant_op_list = []
if self._quantizable_op_type:
for _op_type in self._quantizable_op_type:
if _op_type in base_act_op_type_list:
act_quant_op_list.append(_op_type)
else:
act_quant_op_list = [
'mul',
'matmul',
'matmul_v2',
]
return act_quant_op_list
@property
def observer_operation_types(self):
"""
Operation type list for observers in quantization. These nodes only collect the
calibration boundary scale and do not take part in fake quantization.
To ease deployment on prediction engines, quant and dequant nodes
will be inserted after these ops when the model is exported.
"""
return list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
class TensorRTQuantizer(BaseQuantizer):
"""
TensorRT quantization configuration class.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
super().__init__()
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -128
self._quant_max = 127
@property
def activation_quant_operation_types(self):
"""
Operation type list that supports activation quantization.
Quant-dequant nodes will be inserted before these ops.
"""
return [
"pool2d",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_pow",
"concat",
"softmax",
"argmax",
"mean",
"relu",
"relu6",
"leaky_relu",
"tanh",
"swish",
"softplus",
"gelu",
"hard_sigmoid",
"hard_swish",
"sigmoid",
"layer_norm",
"matmul_v2",
"split",
"bilinear_interp",
"nearest_interp",
"trilinear_interp",
"nearest_interp_v2",
"bilinear_interp",
"bilinear_interp_v2",
"clip",
"pow",
"reduce_mean",
"reduce_sum",
"reduce_max",
]
class MKLDNNQuantizer(BaseQuantizer):
"""
MKLDNN quantization configuration class.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
super().__init__()
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -128
self._quant_max = 127
@property
def activation_quant_operation_types(self):
"""
Operation type list that supports activation quantization.
Quant-dequant nodes will be inserted before these ops.
"""
return [
"pool2d",
"elementwise_add",
"elementwise_mul",
"concat",
"nearest_interp",
"nearest_interp_v2",
"split",
]
class ARMCPUQuantizer(BaseQuantizer):
"""
ARM CPU with Paddle Lite quantization configuration class.
Args:
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is []. If quantizable_op_type is [],
it will use the default quantizable op types of the quant config in
the current Quantizer.
quant_bits(int, optional): Quantization bit number for weight and activation.
Default is 8.
"""
def __init__(
self,
quantizable_op_type=[],
quant_bits=8,
):
super().__init__()
self._quantizable_op_type = quantizable_op_type
self._quant_bits = quant_bits
self._quant_min = -127
self._quant_max = 127
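A short sketch of how the quantizer configs above behave; the import path assumes quant_config is importable from paddle.static.quantization, as in the auto-parallel pass earlier in this change.

from paddle.static.quantization.quant_config import (
    BaseQuantizer,
    TensorRTQuantizer,
)

# Default config: all weight-supported op types, plus a small activation list.
base = BaseQuantizer()
print(base.weight_quant_operation_types)       # conv2d, depthwise_conv2d, ...
print(base.activation_quant_operation_types)   # ['mul', 'matmul', 'matmul_v2']

# A user-provided quantizable_op_type filters the weight list, while the
# TensorRT activation list is fixed by the backend and ignores it.
trt = TensorRTQuantizer(quantizable_op_type=["conv2d", "matmul_v2"])
print(trt.weight_quant_operation_types)        # ['conv2d', 'matmul_v2']
print("pool2d" in trt.activation_quant_operation_types)   # True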
......@@ -28,6 +28,11 @@ from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name
from . import utils
from .quant_config import (
SUPPORT_ACT_QUANTIZATION_OP_DICT,
SUPPORT_QUANTIZATION_OP_DICT,
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT,
)
_fake_quant_op_list = [
'fake_quantize_abs_max',
......@@ -231,7 +236,7 @@ class QuantizationTransformPass:
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in utils._weight_supported_quantizable_op_type, (
assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
op + " is not supported for quantization."
)
self._quantizable_grad_ops = [
......@@ -1594,7 +1599,7 @@ class OutScaleForTrainingPass:
self._place = _get_paddle_place(place)
self._moving_rate = moving_rate
self._is_test = is_test
self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
self._scale_dict = scale_dict
def apply(self, graph):
......@@ -1749,7 +1754,7 @@ class OutScaleForInferencePass:
scope(static.Scope): The scope is used to initialize these new parameters.
"""
self._scope = scope
self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
def apply(self, graph):
"""
......@@ -1830,7 +1835,6 @@ class AddQuantDequantPass:
quant_bits=8,
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False,
is_test=None,
scale_dict=None,
):
......@@ -1851,10 +1855,6 @@ class AddQuantDequantPass:
Default is 'skip_quant'.
quantizable_op_type(list[str], optional): List the type of ops that will be
quantized. Default is ["elementwise_add", "pool2d"].
is_full_quantized(bool, optional): If set is_full_quantized as True, apply
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
quantizable_op_type.
"""
self._scope = scope
self._place = _get_paddle_place(place)
......@@ -1864,14 +1864,11 @@ class AddQuantDequantPass:
self._skip_pattern = skip_pattern
self._scale_dict = scale_dict
if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in quantizable_op_type:
assert op_type in utils._act_supported_quantizable_op_type, (
op_type + " is not supported for quantization."
)
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
op_type + " is not supported for quantization."
)
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
]
......@@ -2485,7 +2482,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in utils._weight_supported_quantizable_op_type, (
assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
op + " is not supported for quantization."
)
self._quantizable_grad_ops = [
......@@ -2763,7 +2760,6 @@ class AddQuantDequantPassV2:
quant_bits=8,
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False,
is_test=None,
scale_dict=None,
):
......@@ -2782,10 +2778,6 @@ class AddQuantDequantPassV2:
Default is 'skip_quant'.
quantizable_op_type(list[str], optional): List the type of ops that will be
quantized. Default is ["elementwise_add", "pool2d"].
is_full_quantized(bool, optional): If set is_full_quantized as True, apply
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
quantizable_op_type.
scale_dict(dict, optional): calibration ranges of tensors output.
Examples:
......@@ -2811,14 +2803,11 @@ class AddQuantDequantPassV2:
self._skip_pattern = skip_pattern
self._scale_dict = scale_dict
if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in quantizable_op_type:
assert op_type in utils._act_supported_quantizable_op_type, (
op_type + " is not supported for quantization."
)
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
op_type + " is not supported for quantization."
)
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
]
......@@ -3243,7 +3232,15 @@ class AddQuantDequantForInferencePass:
When exporting the quant model, it will traverse the graph to find the output of each op and then insert the quant/dequant op after it.
"""
def __init__(self, scope, place, quant_bits=8):
def __init__(
self,
scope,
place,
quant_bits=8,
quantizable_op_type=[],
calibration_range_dict=None,
only_observer=True,
):
"""
Args:
scope(static.Scope): The scope is used to initialize these new parameters.
......@@ -3254,7 +3251,13 @@ class AddQuantDequantForInferencePass:
self._scope = scope
self._place = place
self._quant_bits = quant_bits
self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._only_observer = only_observer
self._teller_set = (
quantizable_op_type
if quantizable_op_type
else list(SUPPORT_QUANTIZATION_OP_DICT.keys())
)
self._calibration_range_dict = calibration_range_dict
def apply(self, graph):
"""
......@@ -3321,9 +3324,31 @@ class AddQuantDequantForInferencePass:
shape=var_node.shape(),
var_dtype=var_node.dtype(),
)
scale_var_node = graph._find_node_by_name(
graph.all_persistable_nodes(), self._scale_name(var_name)
)
if not self._calibration_range_dict:
scale_var_node = graph._find_node_by_name(
graph.all_persistable_nodes(), self._scale_name(var_name)
)
elif var_name in self._calibration_range_dict:
scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node(
name=self._scale_name(var_name),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype(),
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
_init_var_node(
scale_var_node,
np.array(scale_value, dtype=data_type),
self._scope,
self._place,
)
else:
return None
try:
zero_point_node = graph._find_node_by_name(
graph.all_persistable_nodes(),
......@@ -3347,7 +3372,11 @@ class AddQuantDequantForInferencePass:
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": quant_axis, "bit_length": self._quant_bits}
attrs = {
"quant_axis": quant_axis,
"bit_length": self._quant_bits,
"only_observer": self._only_observer,
}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
outputs = {"Y": quant_var_node}
......@@ -3376,7 +3405,11 @@ class AddQuantDequantForInferencePass:
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": -1, "bit_length": self._quant_bits}
attrs = {
"quant_axis": -1,
"bit_length": self._quant_bits,
"only_observer": self._only_observer,
}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
dequant_op_node = graph.create_op_node(
......
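A brief, hedged sketch of wiring the extended AddQuantDequantForInferencePass with an external calibration dict, mirroring the export step added to PostTrainingQuantization above; the module import path, the graph variable, and the dict entry are placeholders/assumptions.

import paddle
from paddle.static.quantization.quantization_pass import (
    AddQuantDequantForInferencePass,
)

# Maps an op output variable name to its calibration scale; for names found here
# a persistable scale node is created from the value instead of being looked up
# in the graph (see the _calibration_range_dict branch above).
calibration_range_dict = {"conv2d_0.tmp_0": 0.25}   # hypothetical name/value

out_scale_infer_pass = AddQuantDequantForInferencePass(
    scope=paddle.static.global_scope(),   # placeholder scope
    place=paddle.CPUPlace(),              # placeholder place
    quant_bits=8,
    quantizable_op_type=["conv2d", "pool2d"],
    calibration_range_dict=calibration_range_dict,
    only_observer=True,                   # exported qdq nodes only carry scale info
)
for sub_graph in graph.all_sub_graphs():  # graph: an IrGraph built from the program
    sub_graph._for_test = True
    out_scale_infer_pass.apply(sub_graph)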
......@@ -277,6 +277,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_optimize_model=False,
batch_nums=10,
onnx_format=False,
deploy_backend=None,
):
try:
os.system("mkdir " + self.int8_model)
......@@ -305,6 +306,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
optimize_model=is_optimize_model,
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file,
deploy_backend=deploy_backend,
)
ptq.quantize()
ptq.save_quantized_model(
......@@ -329,6 +331,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
diff_threshold,
onnx_format=False,
batch_nums=10,
deploy_backend=None,
):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
......@@ -361,6 +364,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_optimize_model,
batch_nums,
onnx_format,
deploy_backend,
)
print(
......@@ -571,5 +575,131 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
)
class TestPostTrainingAvgONNXFormatForMobilenetv1TensorRT(
TestPostTrainingQuantization
):
def test_post_training_onnx_format_mobilenetv1_tensorrt(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
]
data_md5s = ['5ee2b1775b11dc233079236cdc216c2e']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
onnx_format = True
diff_threshold = 0.05
batch_nums = 10
deploy_backend = "tensorrt"
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format,
batch_nums=batch_nums,
deploy_backend=deploy_backend,
)
class TestPostTrainingKLONNXFormatForMobilenetv1MKLDNN(
TestPostTrainingQuantization
):
def test_post_training_onnx_format_mobilenetv1_mkldnn(self):
model = "MobileNet-V1"
algo = "ptf"
round_type = "round"
data_urls = [
'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
]
data_md5s = ['5ee2b1775b11dc233079236cdc216c2e']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
onnx_format = True
diff_threshold = 0.05
batch_nums = 2
deploy_backend = "mkldnn"
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format,
batch_nums=batch_nums,
deploy_backend=deploy_backend,
)
class TestPostTrainingAvgONNXFormatForMobilenetv1ARMCPU(
TestPostTrainingQuantization
):
def test_post_training_onnx_format_mobilenetv1_armcpu(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
]
data_md5s = ['5ee2b1775b11dc233079236cdc216c2e']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
onnx_format = True
diff_threshold = 0.05
batch_nums = 3
deploy_backend = "arm"
self.run_test(
model,
'inference.pdmodel',
'inference.pdiparams',
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format,
batch_nums=batch_nums,
deploy_backend=deploy_backend,
)
if __name__ == '__main__':
unittest.main()
......@@ -188,7 +188,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
):
origin_model_path = self.download_model(data_url, data_md5, model_name)
# origin_model_path = os.path.join(origin_model_path, model_name)
print(
"Start FP32 inference for {0} on {1} images ...".format(
......
......@@ -17,115 +17,7 @@ import sys
import numpy as np
from ...fluid.framework import IrNode, Operator
_weight_supported_quantizable_op_type = [
'conv2d',
'depthwise_conv2d',
'conv2d_transpose',
'mul',
'matmul',
'matmul_v2',
]
_act_supported_quantizable_op_type = [
"pool2d",
"elementwise_add",
"concat",
"softmax",
"argmax",
"transpose",
"equal",
"gather",
"greater_equal",
"greater_than",
"less_equal",
"less_than",
"mean",
"not_equal",
"reshape",
"reshape2",
"dropout",
"bilinear_interp",
"nearest_interp",
"trilinear_interp",
"slice",
"squeeze",
"elementwise_sub",
"mul",
"matmul",
"relu",
"relu6",
"leaky_relu",
"tanh",
"swish",
"transpose",
"transpose2",
"sigmoid",
"pad2d",
"flatten",
"flatten2",
"batch_norm",
"layer_norm",
"matmul_v2",
"split",
"flatten_contiguous_range",
"squeeze2",
"nearest_interp_v2",
"bilinear_interp",
"bilinear_interp_v2",
"fill_constant_batch_size_like",
"arg_max",
"abs",
"assign",
"cast",
"clip",
"box_coder",
"crop",
"cumsum",
"elementwise_mul",
"elementwise_pow",
"expand_v2",
"fill_any_like",
"fill_constant",
"gelu",
"hard_sigmoid",
"hard_swish",
"instance_norm",
"lookup_table",
"lookup_table_v2",
"norm",
"p_norm",
"pad3d",
"pow",
"prelu",
"reduce_mean",
"unsqueeze",
"unsqueeze2",
"logical_and",
"logical_not",
"meshgrid",
"roi_align",
"strided_slice",
"where",
"grid_sampler",
"tile",
"group_norm",
"reduce_sum",
"square",
"softplus",
"shuffle_channel",
"reduce_max",
"scale",
]
QUANT_SUPPORTED_OP_TYPE_LIST = list(
set(
_weight_supported_quantizable_op_type
+ _act_supported_quantizable_op_type
)
)
_out_scale_op_list = QUANT_SUPPORTED_OP_TYPE_LIST
from .quant_config import SUPPORT_QUANTIZATION_OP_DICT
_channelwise_quant_axis1_ops = [
'conv2d_transpose',
......@@ -134,102 +26,6 @@ _channelwise_quant_axis1_ops = [
'matmul_v2',
]
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"matmul_v2": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
"softmax": [["X"], ["Out"]],
"argmax": [["X"], ["Out"]],
"transpose": [["X"], ["Out"]],
"equal": [["X", "Y"], ["Out"]],
"gather": [["X"], ["Out"]],
"greater_equal": [["X", "Y"], ["Out"]],
"greater_than": [["X", "Y"], ["Out"]],
"less_equal": [["X", "Y"], ["Out"]],
"less_than": [["X", "Y"], ["Out"]],
"mean": [["X"], ["Out"]],
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"transpose2": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
"squeeze": [["X"], ["Out"]],
"elementwise_sub": [["X", "Y"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X", "Alpha"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
"dropout": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"layer_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
"lstm": [["Input", "Weight"], ["Hidden"]],
"pad2d": [["X"], ["Out"]],
"pad3d": [["X"], ["Out"]],
"flatten": [["X"], ["Out"]],
"flatten2": [["X"], ["Out"]],
"unsqueeze2": [["X"], ["Out"]],
"flatten_contiguous_range": [["X"], ["Out"]],
"split": [["X"], ["Out"]],
"squeeze2": [["X"], ["Out"]],
"nearest_interp_v2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"bilinear_interp_v2": [["X"], ["Out"]],
"fill_constant_batch_size_like": [["Input"], ["Out"]],
"arg_max": [["X"], ["Out"]],
"abs": [["X"], ["Out"]],
"assign": [["X"], ["Out"]],
"cast": [["X"], ["Out"]],
"clip": [["X"], ["Out"]],
"box_coder": [["PriorBox"], ["OutputBox"]],
"crop": [["X"], ["Out"]],
"cumsum": [["X"], ["Out"]],
"expand_v2": [["X"], ["Out"]],
"fill_any_like": [["X"], ["Out"]],
"fill_constant": [[], ["Out"]],
"gelu": [["X"], ["Out"]],
"instance_norm": [["X"], ["Y"]],
"lookup_table": [["W", "Ids"], ["Out"]],
"lookup_table_v2": [["W", "Ids"], ["Out"]],
"norm": [["X"], ["Norm"]],
"p_norm": [["X"], ["Out"]],
"pow": [["X"], ["Out"]],
"reduce_mean": [["X"], ["Out"]],
"stack": [["X"], ["Y"]],
"top_k_v2": [["X"], ["Out", "Indices"]],
"logical_and": [["X", "Y"], ["Out"]],
"logical_not": [["X"], ["Out"]],
"meshgrid": [["X"], ["Out"]],
"roi_align": [["X", "ROIs"], ["Out"]],
"strided_slice": [["Input"], ["Out"]],
"where": [["Condition", "X", "Y"], ["Out"]],
"grid_sampler": [["X", "Grid"], ["Output"]],
"tile": [["X"], ["Out"]],
"group_norm": [["X"], ["Y", "Mean", "Variance"]],
"reduce_sum": [["X"], ["Out"]],
"square": [["X"], ["Out"]],
"softplus": [["X"], ["Out"]],
"shuffle_channel": [["X"], ["Out"]],
"reduce_max": [["X"], ["Out"]],
"scale": [["X"], ["Out"]],
}
def _get_op_input_var_names(op):
"""
......@@ -244,10 +40,10 @@ def _get_op_input_var_names(op):
), "The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) else op.type
if op_name not in _op_real_in_out_name:
if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return []
name_list = _op_real_in_out_name[op_name][0]
name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][0]
for name in name_list:
var_name = op.input(name)
if isinstance(var_name, list):
......@@ -264,10 +60,10 @@ def _get_op_output_var_names(op):
), "The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) else op.type
if op_name not in _op_real_in_out_name:
if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return []
name_list = _op_real_in_out_name[op_name][1]
name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
for name in name_list:
var_name = op.output(name)
if isinstance(var_name, list):
......@@ -283,11 +79,11 @@ def _get_input_name_index(op, input_var_name):
op, (IrNode, Operator)
), "The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) else op.type
if op_name not in _op_real_in_out_name:
if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return None
res = None
for argname in _op_real_in_out_name[op_name][0]:
for argname in SUPPORT_QUANTIZATION_OP_DICT[op_name][0]:
var_names = op.input(argname)
for index, name in enumerate(var_names):
if name == input_var_name:
......@@ -301,10 +97,10 @@ def _get_output_name_index(op, output_var_name):
op, (IrNode, Operator)
), "The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) else op.type
if op_name not in _op_real_in_out_name:
if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
return None
name_list = _op_real_in_out_name[op_name][1]
name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
res = None
for name in name_list:
var_name = op.output(name)
......@@ -347,7 +143,7 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
if isinstance(scale, list) and len(scale) == 1:
scale = scale[0]
if isinstance(scale, list):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
assert quant_axis in [-1, 0, 1], 'quant_axis should be -1, 0 or 1 for now.'
for i, s in enumerate(scale):
if s == 0.0:
s = 1e-8
......
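An illustrative NumPy sketch of what the per-channel scales consumed by quant_tensor above look like for quant_axis 0 versus 1; the weight layout convention in the comments is an assumption made for the example.

import numpy as np

def channel_wise_abs_max(w, quant_axis):
    # One abs-max scale per slice along quant_axis.
    reduce_axes = tuple(i for i in range(w.ndim) if i != quant_axis)
    return np.abs(w).max(axis=reduce_axes)

w = np.random.randn(8, 4, 3, 3).astype("float32")     # assumed [out_c, in_c, kh, kw]
scales_axis0 = channel_wise_abs_max(w, quant_axis=0)  # shape (8,), e.g. conv2d weights
scales_axis1 = channel_wise_abs_max(w, quant_axis=1)  # shape (4,), e.g. conv2d_transpose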