未验证 提交 a038747c 编写于 作者: G Guanghua Yu 提交者: GitHub

update quantization new format (#46305)

上级 7fb20b46
...@@ -223,7 +223,8 @@ class ImperativeQuantAware(object): ...@@ -223,7 +223,8 @@ class ImperativeQuantAware(object):
self._quantize_inputs = ImperativeQuantizeInputs(**kwargs) self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
self._quantize_outputs = ImperativeQuantizeOutputs(moving_rate) self._quantize_outputs = ImperativeQuantizeOutputs(
moving_rate, activation_bits)
def quantize(self, model): def quantize(self, model):
""" """
...@@ -412,16 +413,18 @@ class ImperativeQuantizeOutputs(object): ...@@ -412,16 +413,18 @@ class ImperativeQuantizeOutputs(object):
Calculate the output scales for target layers. Calculate the output scales for target layers.
""" """
def __init__(self, moving_rate=0.9): def __init__(self, moving_rate=0.9, activation_bits=8):
""" """
The constructor for ImperativeQuantizeOutputs. The constructor for ImperativeQuantizeOutputs.
Args: Args:
moving_rate(float): The decay coefficient of moving average. moving_rate(float): The decay coefficient of moving average.
The default value is 0.9. The default value is 0.9.
activation_bits(int, optional): quantization bit number for activation. Default is 8.
""" """
super(ImperativeQuantizeOutputs, self).__init__() super(ImperativeQuantizeOutputs, self).__init__()
self._moving_rate = moving_rate self._moving_rate = moving_rate
self._activation_bits = activation_bits
def apply(self, model): def apply(self, model):
""" """
...@@ -478,7 +481,7 @@ class ImperativeQuantizeOutputs(object): ...@@ -478,7 +481,7 @@ class ImperativeQuantizeOutputs(object):
the saved model. Default None. the saved model. Default None.
onnx_format (bool, optional): Whether to export the quantized model onnx_format (bool, optional): Whether to export the quantized model
with format of ONNX. Default is False. with format of ONNX. Default is False.
**configs (dict, optional): Other save configuration options for **config (dict, optional): Other save configuration options for
compatibility. We do not recommend using these configurations, compatibility. We do not recommend using these configurations,
they may be removed in the future. If not necessary, DO NOT use they may be removed in the future. If not necessary, DO NOT use
them. Default None. them. Default None.
...@@ -518,27 +521,30 @@ class ImperativeQuantizeOutputs(object): ...@@ -518,27 +521,30 @@ class ImperativeQuantizeOutputs(object):
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename)) params_filename=params_filename))
self._gather_scales(infer_program, scope, fetch_targets) if not onnx_format:
self._gather_scales(infer_program, scope, fetch_targets)
# Remove `moving_average_abs_max_scale` node in sub graphs. # Remove `moving_average_abs_max_scale` node in sub graphs.
graph = IrGraph(core.Graph(infer_program.desc), for_test=False) graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
for sub_graph in graph.all_sub_graphs(): for sub_graph in graph.all_sub_graphs():
for _op in sub_graph.all_op_nodes(): for _op in sub_graph.all_op_nodes():
if _op.name() == "moving_average_abs_max_scale": if _op.name() == "moving_average_abs_max_scale":
sub_graph.safe_remove_nodes(_op) sub_graph.safe_remove_nodes(_op)
sub_graph.resolve_hazard() sub_graph.resolve_hazard()
infer_program = graph.to_program() infer_program = graph.to_program()
self._set_skip_quant_attr(infer_program) self._set_skip_quant_attr(infer_program)
clip_extra = False clip_extra = False
if onnx_format: else:
graph = IrGraph(core.Graph(infer_program.desc), for_test=False) graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
transform_pass = ReplaceFakeQuantDequantPass(scope, place) transform_pass = ReplaceFakeQuantDequantPass(
scope, place, quant_bits=self._activation_bits)
transform_pass.apply(graph) transform_pass.apply(graph)
quant_weight_pass = QuantWeightPass(scope, place) quant_weight_pass = QuantWeightPass(scope, place)
quant_weight_pass.apply(graph) quant_weight_pass.apply(graph)
infer_program = graph.to_program() infer_program = graph.to_program()
clip_extra = True clip_extra = True
......
...@@ -344,7 +344,7 @@ class PostTrainingQuantization(object): ...@@ -344,7 +344,7 @@ class PostTrainingQuantization(object):
self._fetch_list = None self._fetch_list = None
self._data_loader = data_loader self._data_loader = data_loader
self._out_scale_op_list = utils._out_scale_op_list self._out_scale_op_list = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._quantized_weight_var_name = set() self._quantized_weight_var_name = set()
self._quantized_act_var_name = set() self._quantized_act_var_name = set()
self._weight_op_pairs = {} self._weight_op_pairs = {}
...@@ -843,9 +843,6 @@ class PostTrainingQuantization(object): ...@@ -843,9 +843,6 @@ class PostTrainingQuantization(object):
hist, _ = np.histogram(var_tensor_abs, bins=bins) hist, _ = np.histogram(var_tensor_abs, bins=bins)
self._sampling_act_histogram[var_name][0] += hist self._sampling_act_histogram[var_name][0] += hist
def l2_loss(self, gt, pred):
return ((gt - pred)**2).mean()
def _sample_ptf(self): def _sample_ptf(self):
""" """
The following code are modified from: The following code are modified from:
...@@ -885,10 +882,10 @@ class PostTrainingQuantization(object): ...@@ -885,10 +882,10 @@ class PostTrainingQuantization(object):
q_max) * scale4 q_max) * scale4
quant_dequant_var_scale8 = np.clip(np.round(var_tensor / scale8), 0, quant_dequant_var_scale8 = np.clip(np.round(var_tensor / scale8), 0,
q_max) * scale8 q_max) * scale8
score1 = self.l2_loss(var_tensor, quant_dequant_var_scale1) score1 = utils.l2_loss(var_tensor, quant_dequant_var_scale1)
score2 = self.l2_loss(var_tensor, quant_dequant_var_scale2) score2 = utils.l2_loss(var_tensor, quant_dequant_var_scale2)
score4 = self.l2_loss(var_tensor, quant_dequant_var_scale4) score4 = utils.l2_loss(var_tensor, quant_dequant_var_scale4)
score8 = self.l2_loss(var_tensor, quant_dequant_var_scale8) score8 = utils.l2_loss(var_tensor, quant_dequant_var_scale8)
score = [score1, score2, score4, score8] score = [score1, score2, score4, score8]
mask = 2**score.index(min(score)) mask = 2**score.index(min(score))
scale = scale1 * mask scale = scale1 * mask
...@@ -1035,7 +1032,7 @@ class PostTrainingQuantization(object): ...@@ -1035,7 +1032,7 @@ class PostTrainingQuantization(object):
scope=self._scope, scope=self._scope,
place=self._place, place=self._place,
quantizable_op_type=minor_quantizable_op_types, quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=self._is_full_quantize) is_full_quantized=True)
for sub_graph in graph.all_sub_graphs(): for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True sub_graph._for_test = True
......
...@@ -44,6 +44,7 @@ __all__ = [ ...@@ -44,6 +44,7 @@ __all__ = [
'AddQuantDequantPassV2', 'AddQuantDequantPassV2',
'ReplaceFakeQuantDequantPass', 'ReplaceFakeQuantDequantPass',
'QuantWeightPass', 'QuantWeightPass',
'AddQuantDequantForInferencePass',
] ]
_fake_quant_op_list = [ _fake_quant_op_list = [
...@@ -1437,7 +1438,7 @@ class OutScaleForTrainingPass(object): ...@@ -1437,7 +1438,7 @@ class OutScaleForTrainingPass(object):
self._place = _get_paddle_place(place) self._place = _get_paddle_place(place)
self._moving_rate = moving_rate self._moving_rate = moving_rate
self._is_test = is_test self._is_test = is_test
self._teller_set = utils._out_scale_op_list self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
self._scale_dict = scale_dict self._scale_dict = scale_dict
def apply(self, graph): def apply(self, graph):
...@@ -1567,7 +1568,7 @@ class OutScaleForInferencePass(object): ...@@ -1567,7 +1568,7 @@ class OutScaleForInferencePass(object):
scope(fluid.Scope): The scope is used to initialize these new parameters. scope(fluid.Scope): The scope is used to initialize these new parameters.
""" """
self._scope = scope self._scope = scope
self._teller_set = utils._out_scale_op_list self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
def apply(self, graph): def apply(self, graph):
""" """
...@@ -1852,6 +1853,7 @@ class InsertQuantizeLinear(object): ...@@ -1852,6 +1853,7 @@ class InsertQuantizeLinear(object):
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False. channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
moving_rate(float): the rate for 'moving average' method. moving_rate(float): the rate for 'moving average' method.
is_test(bool, optional): Whether quantization with training or not. Default is True. is_test(bool, optional): Whether quantization with training or not. Default is True.
scale_dict(dict, optional): calibration ranges of tensors output.
""" """
def __init__(self, def __init__(self,
...@@ -1861,7 +1863,8 @@ class InsertQuantizeLinear(object): ...@@ -1861,7 +1863,8 @@ class InsertQuantizeLinear(object):
quant_axis=-1, quant_axis=-1,
channel_wise=False, channel_wise=False,
moving_rate=0.9, moving_rate=0.9,
is_test=True): is_test=True,
scale_dict=None):
self._place = place self._place = place
self._scope = scope self._scope = scope
self.quant_bits = quant_bits self.quant_bits = quant_bits
...@@ -1869,6 +1872,7 @@ class InsertQuantizeLinear(object): ...@@ -1869,6 +1872,7 @@ class InsertQuantizeLinear(object):
self.channel_wise = channel_wise self.channel_wise = channel_wise
self._is_test = is_test self._is_test = is_test
self._moving_rate = moving_rate self._moving_rate = moving_rate
self._scale_dict = scale_dict
def insert_quant_op(self, graph, var_node, var_name=None): def insert_quant_op(self, graph, var_node, var_name=None):
assert var_node.is_var(), '{} is not a var'.format(var_node.name()) assert var_node.is_var(), '{} is not a var'.format(var_node.name())
...@@ -1880,16 +1884,24 @@ class InsertQuantizeLinear(object): ...@@ -1880,16 +1884,24 @@ class InsertQuantizeLinear(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
scale_name = self._quantized_scale_name(var_name)
if self.channel_wise: if self.channel_wise:
scale_var_shape = var_node.shape()[self.quant_axis] scale_var_shape = var_node.shape()[self.quant_axis]
scale_var_type = core.VarDesc.VarType.LOD_TENSOR scale_var_type = core.VarDesc.VarType.LOD_TENSOR
init_scale_value = np.zeros(scale_var_shape, dtype=data_type) init_scale_value = np.ones(scale_var_shape,
dtype=data_type) * _SCALE_DEFAULT_VALUE
else: else:
scale_var_shape = 1 scale_var_shape = 1
scale_var_type = var_node.type() scale_var_type = var_node.type()
init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
if self._scale_dict is not None and var_node.name(
) in self._scale_dict.keys():
init_scale_value = np.array([self._scale_dict[var_node.name()]],
dtype=data_type)
scale_var_node = graph.create_persistable_node( scale_var_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_name), name=scale_name,
var_type=scale_var_type, var_type=scale_var_type,
shape=[scale_var_shape], shape=[scale_var_shape],
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
...@@ -2346,7 +2358,8 @@ class AddQuantDequantPassV2(object): ...@@ -2346,7 +2358,8 @@ class AddQuantDequantPassV2(object):
skip_pattern=["skip_quant"], skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"], quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False, is_full_quantized=False,
is_test=None): is_test=None,
scale_dict=None):
""" """
Args: Args:
scope(paddle.Scope): The scope is used to initialize these new parameters. scope(paddle.Scope): The scope is used to initialize these new parameters.
...@@ -2366,6 +2379,7 @@ class AddQuantDequantPassV2(object): ...@@ -2366,6 +2379,7 @@ class AddQuantDequantPassV2(object):
quantization to all supported quantizable op type. If set is_full_quantized quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input as False, only apply quantization to the op type according to the input
quantizable_op_type. quantizable_op_type.
scale_dict(dict, optional): calibration ranges of tensors output.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -2388,6 +2402,7 @@ class AddQuantDequantPassV2(object): ...@@ -2388,6 +2402,7 @@ class AddQuantDequantPassV2(object):
self._quant_bits = quant_bits self._quant_bits = quant_bits
self._is_test = is_test self._is_test = is_test
self._skip_pattern = skip_pattern self._skip_pattern = skip_pattern
self._scale_dict = scale_dict
if is_full_quantized: if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type self._quantizable_op_type = utils._act_supported_quantizable_op_type
...@@ -2444,8 +2459,6 @@ class AddQuantDequantPassV2(object): ...@@ -2444,8 +2459,6 @@ class AddQuantDequantPassV2(object):
if is_skip or is_quantized: if is_skip or is_quantized:
continue continue
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
arg_names = utils._get_op_input_var_names(op_node) arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names: for arg_name in arg_names:
in_node = graph._find_node_by_name( in_node = graph._find_node_by_name(
...@@ -2462,7 +2475,8 @@ class AddQuantDequantPassV2(object): ...@@ -2462,7 +2475,8 @@ class AddQuantDequantPassV2(object):
quant_axis=-1, quant_axis=-1,
channel_wise=False, channel_wise=False,
moving_rate=self._moving_rate, moving_rate=self._moving_rate,
is_test=self._is_test) is_test=self._is_test,
scale_dict=self._scale_dict)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node) graph, in_node)
dequant_var_node = insert_quant_pass.insert_dequant_op( dequant_var_node = insert_quant_pass.insert_dequant_op(
...@@ -2491,13 +2505,14 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2491,13 +2505,14 @@ class ReplaceFakeQuantDequantPass(object):
replace quant-dequant ops with quantize_linear and dequantize_linear ops. replace quant-dequant ops with quantize_linear and dequantize_linear ops.
""" """
def __init__(self, scope, place): def __init__(self, scope, place, quant_bits=8):
r""" r"""
Args: Args:
scope(paddle.Scope): The scope is used to initialize these new parameters. scope(paddle.Scope): The scope is used to initialize these new parameters.
place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
parameters described above. If ``place`` is string, it can be It can be ``cpu`` parameters described above. If ``place`` is string, it can be It can be ``cpu``
or ``gpu:x``, where ``x`` is the index of the GPUs. or ``gpu:x``, where ``x`` is the index of the GPUs.
quant_bits(int, optional): quantization bit number for activation. Default is 8.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -2516,6 +2531,7 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2516,6 +2531,7 @@ class ReplaceFakeQuantDequantPass(object):
""" """
self._place = _get_paddle_place(place) self._place = _get_paddle_place(place)
self._scope = scope self._scope = scope
self._quant_bits = quant_bits
assert self._scope != None, "scope must not be None." assert self._scope != None, "scope must not be None."
assert self._place != None, "place must not be None." assert self._place != None, "place must not be None."
...@@ -2525,7 +2541,8 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2525,7 +2541,8 @@ class ReplaceFakeQuantDequantPass(object):
fake_quant_dequant_ops = [] fake_quant_dequant_ops = []
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in _fake_quant_dequant_op_list: if op.name() in _fake_quant_dequant_op_list or op.name(
) == "moving_average_abs_max_scale":
fake_quant_dequant_ops.append(op) fake_quant_dequant_ops.append(op)
for _op in fake_quant_dequant_ops: for _op in fake_quant_dequant_ops:
...@@ -2544,7 +2561,7 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2544,7 +2561,7 @@ class ReplaceFakeQuantDequantPass(object):
quant_axis = op.op().attr("quant_axis") if op.op().has_attr( quant_axis = op.op().attr("quant_axis") if op.op().has_attr(
"quant_axis") else -1 "quant_axis") else -1
bit_length = op.op().attr("bit_length") if op.op().has_attr( bit_length = op.op().attr("bit_length") if op.op().has_attr(
"bit_length") else 8 "bit_length") else self._quant_bits
zero_point_node = None zero_point_node = None
quanted_node = x_node quanted_node = x_node
...@@ -2733,3 +2750,140 @@ class QuantWeightPass(object): ...@@ -2733,3 +2750,140 @@ class QuantWeightPass(object):
def _restore_var(self, name, array): def _restore_var(self, name, array):
tensor = self._scope.find_var(name).get_tensor() tensor = self._scope.find_var(name).get_tensor()
tensor.set(array, self._place) tensor.set(array, self._place)
class AddQuantDequantForInferencePass(object):
"""
When export quant model, it will traverse to find the output of each op, and then insert the quant/dequant op after it.
"""
def __init__(self, scope, place, quant_bits=8):
"""
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
If it's string, it can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
quant_bits(int, optional): quantization bit number for weight. Default is 8.
"""
self._scope = scope
self._place = place
self._quant_bits = quant_bits
self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST
def apply(self, graph):
"""
Args:
graph(IrGraph): the target graph.
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
dequant_node_map = {}
dequantized_vars_map = collections.OrderedDict()
for op_node in graph.all_op_nodes():
if op_node.name() in self._teller_set:
var_names = utils._get_op_output_var_names(op_node)
for var_name in var_names:
out_node = graph._find_node_by_name(op_node.outputs,
var_name)
if out_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
if var_name in dequantized_vars_map:
dequant_var_node = dequantized_vars_map[var_name]
else:
dequant_var_node = self._insert_quant_dequant_op(
graph, out_node)
dequantized_vars_map[var_name] = dequant_var_node
dequant_node_map[var_name] = dequant_var_node
# remove unuse node and link act quant/dequant linear to op node
for op_node in graph.all_op_nodes():
if op_node.name() == 'moving_average_abs_max_scale':
graph.safe_remove_nodes(op_node)
else:
var_names = utils._get_op_input_var_names(op_node)
for var_name in var_names:
if var_name in dequant_node_map:
in_node = graph._find_node_by_name(
op_node.inputs, var_name)
graph.update_input_link(in_node,
dequant_node_map[var_name],
op_node)
return graph
def _scale_name(self, var_name):
"""
Return the scale name for the var named `var_name`.
"""
return "%s@scale" % (var_name)
def _insert_quant_dequant_op(self, graph, var_node):
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
var_name = var_node.name()
quant_axis = -1
quant_var_node = graph.create_var_node(
name="{}.quantized".format(var_name),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_var_node = graph._find_node_by_name(graph.all_persistable_nodes(),
self._scale_name(var_name))
try:
zero_point_node = graph._find_node_by_name(
graph.all_persistable_nodes(),
"{}@zero_point".format(quant_var_node.name()))
except:
zero_point_node = graph.create_persistable_node(
name="{}@zero_point".format(quant_var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=scale_var_node.shape(),
var_dtype=core.VarDesc.VarType.INT32)
_init_var_node(zero_point_node,
np.zeros(scale_var_node.shape(), dtype="int32"),
self._scope, self._place)
inputs = {"X": var_node, "Scale": scale_var_node}
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": quant_axis, "bit_length": self._quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
outputs = {"Y": quant_var_node}
quant_op_node = graph.create_op_node(op_type="quantize_linear",
attrs=attrs,
inputs=inputs,
outputs=outputs)
graph.link_to(var_node, quant_op_node)
graph.link_to(scale_var_node, quant_op_node)
if zero_point_node is not None:
graph.link_to(zero_point_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
# add dequant_linear node
dequant_var_node = graph.create_var_node(
name="{}.dequantized".format(quant_var_node.name()),
var_type=quant_var_node.type(),
shape=quant_var_node.shape(),
var_dtype=quant_var_node.dtype())
inputs = {"X": quant_var_node, "Scale": scale_var_node}
if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node
attrs = {"quant_axis": -1, "bit_length": self._quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
dequant_op_node = graph.create_op_node(op_type="dequantize_linear",
attrs=attrs,
inputs=inputs,
outputs={"Y": dequant_var_node})
graph.link_to(quant_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
if zero_point_node is not None:
graph.link_to(zero_point_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
...@@ -38,6 +38,7 @@ _act_supported_quantizable_op_type = [ ...@@ -38,6 +38,7 @@ _act_supported_quantizable_op_type = [
"mean", "mean",
"not_equal", "not_equal",
"reshape", "reshape",
"reshape2",
"dropout", "dropout",
"bilinear_interp", "bilinear_interp",
"nearest_interp", "nearest_interp",
...@@ -112,10 +113,12 @@ _act_supported_quantizable_op_type = [ ...@@ -112,10 +113,12 @@ _act_supported_quantizable_op_type = [
"scale", "scale",
] ]
_out_scale_op_list = list( QUANT_SUPPORTED_OP_TYPE_LIST = list(
set(_weight_supported_quantizable_op_type + set(_weight_supported_quantizable_op_type +
_act_supported_quantizable_op_type)) _act_supported_quantizable_op_type))
_out_scale_op_list = QUANT_SUPPORTED_OP_TYPE_LIST
_channelwise_quant_axis1_ops = [ _channelwise_quant_axis1_ops = [
'conv2d_transpose', 'mul', 'matmul', 'matmul_v2' 'conv2d_transpose', 'mul', 'matmul', 'matmul_v2'
] ]
...@@ -430,6 +433,10 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): ...@@ -430,6 +433,10 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
return cos_sim return cos_sim
def l2_loss(gt, pred):
return ((gt - pred)**2).mean()
class tqdm(object): class tqdm(object):
def __init__(self, total, bar_format='Loading|{bar}', ncols=80): def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
......
...@@ -292,24 +292,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -292,24 +292,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file=is_use_cache_file) is_use_cache_file=is_use_cache_file)
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(self.int8_model) ptq.save_quantized_model(self.int8_model)
if onnx_format:
try:
collect_dict = ptq._calibration_scales
save_quant_table_path = os.path.join(self.int8_model,
'calibration_table.txt')
with open(save_quant_table_path, 'w') as txt_file:
for tensor_name in collect_dict.keys():
write_line = '{} {}'.format(
tensor_name,
collect_dict[tensor_name]['scale']) + '\n'
txt_file.write(write_line)
print(
"Quantization clip ranges of tensors is save in: {}".format(
save_quant_table_path))
except:
print(
"Unable to generate `calibration_table.txt`, please update PaddlePaddle >= 2.3.3"
)
def run_test(self, def run_test(self,
model, model,
...@@ -429,36 +411,6 @@ class TestMKLDNNInt8ForMobilenetv1Avg(TestPostTrainingQuantization): ...@@ -429,36 +411,6 @@ class TestMKLDNNInt8ForMobilenetv1Avg(TestPostTrainingQuantization):
onnx_format=False) onnx_format=False)
class TestMKLDNNInt8ForMobilenetv1AbsMaxONNXFormat(TestPostTrainingQuantization
):
def test_onnx_format_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) maybe bigger
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=True)
class TestMKLDNNInt8ForMobilenetv1AbsMax(TestPostTrainingQuantization): class TestMKLDNNInt8ForMobilenetv1AbsMax(TestPostTrainingQuantization):
def test_abs_max_mobilenetv1(self): def test_abs_max_mobilenetv1(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册