Unverified · Commit 11814d1c · authored by Guanghua Yu · committed by GitHub

fix fake quant demo (#1397)

Parent: a0d87b27
@@ -2,7 +2,7 @@ import os
 import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.framework import core
-from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, QuantizationTransformPassV2, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, QuantWeightPass
 try:
     from paddle.fluid.contrib.slim.quantization import utils
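Note: as far as I can tell, the V2 passes plus QuantWeightPass implement Paddle's newer quantization representation (quantize_linear/dequantize_linear ops) that can be exported to ONNX, while the original passes emit the legacy fake_quant/fake_dequant ops. A minimal sketch of the dispatch this commit wires up; the helper `pick_passes` is illustrative and not part of the file:

```python
from paddle.fluid.contrib.slim.quantization import (
    QuantizationTransformPass, QuantizationTransformPassV2,
    AddQuantDequantPass, AddQuantDequantPassV2,
    QuantizationFreezePass, QuantWeightPass)

# Illustrative helper (not in the file): map the new onnx_format flag to the
# pass classes that each branch of the diff instantiates.
def pick_passes(onnx_format):
    if onnx_format:
        # V2 passes + QuantWeightPass produce the ONNX-exportable format.
        return (QuantizationTransformPassV2, AddQuantDequantPassV2,
                QuantWeightPass)
    # Legacy passes + QuantizationFreezePass produce the old fake-quant format.
    return (QuantizationTransformPass, AddQuantDequantPass,
            QuantizationFreezePass)
```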
@@ -23,7 +23,8 @@ def post_quant_fake(executor,
                     quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                     is_full_quantize=False,
                     activation_bits=8,
-                    weight_bits=8):
+                    weight_bits=8,
+                    onnx_format=False):
     """
     Utilize the post training quantization method to quantize the FP32 model.
     It does not use calibration data, and the fake quantized model cannot be used in practice.
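The docstring makes clear this produces a fake-quantized model for benchmarking only. Below is a hypothetical call using the new flag; only `executor`, `quantizable_op_type`, `is_full_quantize`, `activation_bits`, `weight_bits`, and `onnx_format` are confirmed by this hunk, so the remaining keyword names (`model_dir`, `params_filename`, `save_model_path`) are assumptions based on PaddleSlim's usual post-training-quantization interface:

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

post_quant_fake(
    exe,
    model_dir='./inference_model',         # assumed keyword
    model_filename='model.pdmodel',        # assumed keyword (the variable appears later in the file)
    params_filename='model.pdiparams',     # assumed keyword
    save_model_path='./fake_quant_model',  # assumed keyword
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    is_full_quantize=False,
    activation_bits=8,
    weight_bits=8,
    onnx_format=True)  # new in this commit: emit the ONNX-exportable format
```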
@@ -67,14 +68,24 @@ def post_quant_fake(executor,
     for op_type in _weight_supported_quantizable_op_type:
         if op_type in _quantizable_op_type:
             major_quantizable_op_types.append(op_type)
-    transform_pass = QuantizationTransformPass(
-        scope=_scope,
-        place=_place,
-        weight_bits=weight_bits,
-        activation_bits=activation_bits,
-        activation_quantize_type=activation_quantize_type,
-        weight_quantize_type=weight_quantize_type,
-        quantizable_op_type=major_quantizable_op_types)
+    if onnx_format:
+        transform_pass = QuantizationTransformPassV2(
+            scope=_scope,
+            place=_place,
+            weight_bits=weight_bits,
+            activation_bits=activation_bits,
+            activation_quantize_type=activation_quantize_type,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=major_quantizable_op_types)
+    else:
+        transform_pass = QuantizationTransformPass(
+            scope=_scope,
+            place=_place,
+            weight_bits=weight_bits,
+            activation_bits=activation_bits,
+            activation_quantize_type=activation_quantize_type,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=major_quantizable_op_types)
 
     for sub_graph in graph.all_sub_graphs():
         # Insert fake_quant/fake_dequantize op must in test graph, so
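The hunk elides how `graph` was built. Here is a minimal sketch of the standard idiom, assuming the elided code wraps the loaded inference Program in an IrGraph the way Paddle's quantization tooling usually does (IrGraph, core, and to_program() are all used in this file):

```python
import paddle
from paddle.framework import core
from paddle.fluid.framework import IrGraph

paddle.enable_static()
program = paddle.static.default_main_program()  # stand-in for the loaded model

# Wrap the Program so the passes above can mutate it, then convert back.
graph = IrGraph(core.Graph(program.desc), for_test=True)
# ... transform_pass.apply(...) / add_quant_dequant_pass.apply(...) run here ...
quantized_program = graph.to_program()
```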
@@ -87,30 +98,66 @@ def post_quant_fake(executor,
     for op_type in _act_supported_quantizable_op_type:
         if op_type in _quantizable_op_type:
             minor_quantizable_op_types.append(op_type)
-    add_quant_dequant_pass = AddQuantDequantPass(
-        scope=_scope,
-        place=_place,
-        quantizable_op_type=minor_quantizable_op_types)
+    if onnx_format:
+        add_quant_dequant_pass = AddQuantDequantPassV2(
+            scope=_scope,
+            place=_place,
+            quantizable_op_type=minor_quantizable_op_types,
+            is_full_quantized=is_full_quantize)
+    else:
+        add_quant_dequant_pass = AddQuantDequantPass(
+            scope=_scope,
+            place=_place,
+            quantizable_op_type=minor_quantizable_op_types)
 
     for sub_graph in graph.all_sub_graphs():
         sub_graph._for_test = True
         add_quant_dequant_pass.apply(sub_graph)
 
     # apply QuantizationFreezePass, and obtain the final quant model
-    freeze_pass = QuantizationFreezePass(
-        scope=_scope,
-        place=_place,
-        weight_bits=weight_bits,
-        activation_bits=activation_bits,
-        weight_quantize_type=weight_quantize_type,
-        quantizable_op_type=major_quantizable_op_types)
-
-    for sub_graph in graph.all_sub_graphs():
-        sub_graph._for_test = True
-        freeze_pass.apply(sub_graph)
+    if onnx_format:
+        quant_weight_pass = QuantWeightPass(_scope, _place)
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            quant_weight_pass.apply(sub_graph)
+    else:
+        freeze_pass = QuantizationFreezePass(
+            scope=_scope,
+            place=_place,
+            weight_bits=weight_bits,
+            activation_bits=activation_bits,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=major_quantizable_op_types)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            freeze_pass.apply(sub_graph)
 
     _program = graph.to_program()
+
+    def save_info(op_node, out_var_name, out_info_name, quantized_type):
+        # Stamp a placeholder threshold so the fake model carries quant attrs.
+        op_node._set_attr(out_info_name, 0.001)
+        op_node._set_attr("with_quant_attr", True)
+        if op_node.type in _quantizable_op_type:
+            op_node._set_attr("quantization_type", quantized_type)
+
+    def analysis_and_save_info(op_node, out_var_name):
+        argname_index = utils._get_output_name_index(op_node, out_var_name)
+        assert argname_index is not None, \
+            out_var_name + " is not the output of the op"
+
+        # save the output threshold under both attribute-name conventions
+        save_info(op_node, out_var_name, "out_threshold", "post_avg")
+        save_info(op_node, out_var_name,
+                  argname_index[0] + str(argname_index[1]) + "_threshold",
+                  "post_avg")
+
+    for block_id in range(len(_program.blocks)):
+        for op in _program.blocks[block_id].ops:
+            if op.type in (_quantizable_op_type + utils._out_scale_op_list):
+                out_var_names = utils._get_op_output_var_names(op)
+                for var_name in out_var_names:
+                    analysis_and_save_info(op, var_name)
 
     feed_vars = [_program.global_block().var(name) for name in _feed_list]
     model_name = model_filename.split('.')[
         0] if model_filename is not None else 'model'
...
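The added save_info/analysis_and_save_info helpers stamp every quantizable (and out-scale) op output with a placeholder threshold of 0.001 plus a "post_avg" quantization_type, which is what lets downstream tooling treat the fake model like a real post-training-quantization result. A small sketch, assuming `program` is the `_program` produced above, of how those attributes can be inspected:

```python
# Walk the quantized program and print the attributes written by save_info().
for block in program.blocks:
    for op in block.ops:
        if op.has_attr("out_threshold"):
            qtype = (op.attr("quantization_type")
                     if op.has_attr("quantization_type") else "n/a")
            print(op.type, op.attr("out_threshold"), qtype)
```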