Unverified commit 70a36428, authored by Guanghua Yu, committed by GitHub

add ptq data-free method (#1026)

* add ptq data-free method
Parent d31a202a
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass


def post_quant_fake(executor,
                    model_dir,
                    model_filename=None,
                    params_filename=None,
                    save_model_path=None,
                    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                    is_full_quantize=False,
                    activation_bits=8,
                    weight_bits=8):
"""
Utilizing post training quantization methon to quantize the FP32 model,
and it not uses calibrate data and the fake model cannot be used in practice.
Usage:
paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
post_quant_fake(executor=exe,
model_dir='./inference_model/MobileNet/',
model_filename='model',
params_filename='params',
save_model_path='fake_quant')
"""
    activation_quantize_type = 'range_abs_max'
    weight_quantize_type = 'channel_wise_abs_max'
    _dynamic_quantize_op_type = ['lstm']
    _weight_supported_quantizable_op_type = QuantizationTransformPass._supported_quantizable_op_type
    _act_supported_quantizable_op_type = AddQuantDequantPass._supported_quantizable_op_type
    _support_quantize_op_type = list(
        set(_weight_supported_quantizable_op_type +
            _act_supported_quantizable_op_type + _dynamic_quantize_op_type))

    _place = executor.place
    _scope = paddle.static.global_scope()
    if is_full_quantize:
        _quantizable_op_type = _support_quantize_op_type
    else:
        _quantizable_op_type = quantizable_op_type
        for op_type in _quantizable_op_type:
            assert op_type in _support_quantize_op_type, \
                op_type + " is not supported for quantization."
    _program, _feed_list, _fetch_list = paddle.fluid.io.load_inference_model(
        model_dir,
        executor,
        model_filename=model_filename,
        params_filename=params_filename)

    graph = IrGraph(core.Graph(_program.desc), for_test=True)

    # use QuantizationTransformPass to insert fake_quant/fake_dequantize op
    major_quantizable_op_types = []
    for op_type in _weight_supported_quantizable_op_type:
        if op_type in _quantizable_op_type:
            major_quantizable_op_types.append(op_type)
    transform_pass = QuantizationTransformPass(
        scope=_scope,
        place=_place,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        activation_quantize_type=activation_quantize_type,
        weight_quantize_type=weight_quantize_type,
        quantizable_op_type=major_quantizable_op_types)
    for sub_graph in graph.all_sub_graphs():
        # The fake_quant/fake_dequantize ops must be inserted into a test
        # graph, so mark every sub-graph with _for_test = True.
        sub_graph._for_test = True
        transform_pass.apply(sub_graph)
    # use AddQuantDequantPass to insert fake_quant_dequant op
    minor_quantizable_op_types = []
    for op_type in _act_supported_quantizable_op_type:
        if op_type in _quantizable_op_type:
            minor_quantizable_op_types.append(op_type)
    add_quant_dequant_pass = AddQuantDequantPass(
        scope=_scope,
        place=_place,
        quantizable_op_type=minor_quantizable_op_types)

    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        add_quant_dequant_pass.apply(sub_graph)
    # apply QuantizationFreezePass, and obtain the final quant model
    freeze_pass = QuantizationFreezePass(
        scope=_scope,
        place=_place,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        weight_quantize_type=weight_quantize_type,
        quantizable_op_type=major_quantizable_op_types)

    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        freeze_pass.apply(sub_graph)
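
    # Convert the rewritten IrGraph back into a Program and save it as an
    # inference model under save_model_path.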
    _program = graph.to_program()

    paddle.fluid.io.save_inference_model(
        dirname=save_model_path,
        model_filename=model_filename,
        params_filename=params_filename,
        feeded_var_names=_feed_list,
        target_vars=_fetch_list,
        executor=executor,
        main_program=_program)
    print("The quantized model is saved in: " + save_model_path)