“4d8345e3ac0ba79a17359a72e940ade284c0b1a9”上不存在“paddle/phi/core/lod_utils.h”
未验证 提交 11814d1c 编写于 作者: G Guanghua Yu 提交者: GitHub

fix fake quant demo (#1397)

上级 a0d87b27
......@@ -2,7 +2,7 @@ import os
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, QuantizationTransformPassV2, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, QuantWeightPass
try:
from paddle.fluid.contrib.slim.quantization import utils
......@@ -23,7 +23,8 @@ def post_quant_fake(executor,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
activation_bits=8,
weight_bits=8):
weight_bits=8,
onnx_format=False):
"""
Use the post-training quantization method to quantize the FP32 model.
It does not use calibration data, so the resulting fake model cannot be used in practice.
......@@ -67,14 +68,24 @@ def post_quant_fake(executor,
for op_type in _weight_supported_quantizable_op_type:
if op_type in _quantizable_op_type:
major_quantizable_op_types.append(op_type)
transform_pass = QuantizationTransformPass(
scope=_scope,
place=_place,
weight_bits=weight_bits,
activation_bits=activation_bits,
activation_quantize_type=activation_quantize_type,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
if onnx_format:
transform_pass = QuantizationTransformPassV2(
scope=_scope,
place=_place,
weight_bits=weight_bits,
activation_bits=activation_bits,
activation_quantize_type=activation_quantize_type,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
else:
transform_pass = QuantizationTransformPass(
scope=_scope,
place=_place,
weight_bits=weight_bits,
activation_bits=activation_bits,
activation_quantize_type=activation_quantize_type,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
for sub_graph in graph.all_sub_graphs():
# The fake_quant/fake_dequantize ops must be inserted into a test graph, so
......@@ -87,30 +98,66 @@ def post_quant_fake(executor,
for op_type in _act_supported_quantizable_op_type:
if op_type in _quantizable_op_type:
minor_quantizable_op_types.append(op_type)
add_quant_dequant_pass = AddQuantDequantPass(
scope=_scope,
place=_place,
quantizable_op_type=minor_quantizable_op_types)
if onnx_format:
add_quant_dequant_pass = AddQuantDequantPassV2(
scope=_scope,
place=_place,
quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=is_full_quantize)
else:
add_quant_dequant_pass = AddQuantDequantPass(
scope=_scope,
place=_place,
quantizable_op_type=minor_quantizable_op_types)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
add_quant_dequant_pass.apply(sub_graph)
# apply QuantizationFreezePass, and obtain the final quant model
freeze_pass = QuantizationFreezePass(
scope=_scope,
place=_place,
weight_bits=weight_bits,
activation_bits=activation_bits,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
if onnx_format:
quant_weight_pass = QuantWeightPass(_scope, _place)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
quant_weight_pass.apply(sub_graph)
else:
freeze_pass = QuantizationFreezePass(
scope=_scope,
place=_place,
weight_bits=weight_bits,
activation_bits=activation_bits,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
freeze_pass.apply(sub_graph)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
freeze_pass.apply(sub_graph)
_program = graph.to_program()
def save_info(op_node, out_var_name, out_info_name, quantized_type):
    """Attach placeholder quantization attributes to ``op_node``.

    Args:
        op_node: Operator to annotate; must expose ``type`` and
            ``_set_attr(name, value)``.
        out_var_name: Name of the output variable the info belongs to.
            Not used here; kept so the signature stays unchanged.
        out_info_name: Attribute name under which the threshold is stored.
        quantized_type: Value written to ``quantization_type`` when the
            op's type is listed in ``_quantizable_op_type``.
    """
    # Constant threshold: no value is computed from data here.
    op_node._set_attr(out_info_name, 0.001)
    op_node._set_attr("with_quant_attr", True)
    if op_node.type in _quantizable_op_type:
        # Write to the op passed in, rather than relying on a variable named
        # ``op`` that happens to exist in the enclosing scope.
        op_node._set_attr("quantization_type", quantized_type)
def analysis_and_save_info(op_node, out_var_name):
    """Record threshold info on ``op_node`` for one of its output variables.

    Looks up where ``out_var_name`` sits among the op's outputs via
    ``utils._get_output_name_index`` and then calls ``save_info`` twice:
    once under the generic ``out_threshold`` name and once under a name
    built from the lookup result followed by ``_threshold``.

    Raises:
        AssertionError: if the lookup returns ``None``.
    """
    name_and_index = utils._get_output_name_index(op_node, out_var_name)
    assert name_and_index is not None, \
        out_var_name + " is not the output of the op"

    quant_kind = "post_avg"
    save_info(op_node, out_var_name, "out_threshold", quant_kind)

    # Per-output attribute name: first lookup element, then the stringified
    # second element, then the "_threshold" suffix.
    per_output_name = (
        name_and_index[0] + str(name_and_index[1]) + "_threshold")
    save_info(op_node, out_var_name, per_output_name, quant_kind)
for block_id in range(len(_program.blocks)):
for op in _program.blocks[block_id].ops:
if op.type in (_quantizable_op_type + utils._out_scale_op_list):
out_var_names = utils._get_op_output_var_names(op)
for var_name in out_var_names:
analysis_and_save_info(op, var_name)
feed_vars = [_program.global_block().var(name) for name in _feed_list]
model_name = model_filename.split('.')[
0] if model_filename is not None else 'model'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册