Unverified commit 2afa546f authored by Liufang Sang, committed by GitHub

fix quant_aware for 1.6 (#152)

Parent f437281b
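The diff below changes the user-facing `quant_aware` config: `not_quant_pattern` becomes a plain string instead of a list, and the `quantize_op_types` key is dropped, so `QuantizationTransformPass` falls back to its own default op set. A minimal sketch of a post-change call, assuming the paddleslim 1.x API shown in this diff (the config values are illustrative, not from the commit):

```python
import paddle.fluid as fluid
from paddleslim.quant import quant_aware

place = fluid.CPUPlace()
main_prog = fluid.default_main_program()  # placeholder; a real run builds a network here

config = {
    'weight_quantize_type': 'abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'weight_bits': 8,
    'activation_bits': 8,
    # after this commit: a single name_scope string, not a list,
    # and no 'quantize_op_types' key
    'not_quant_pattern': 'skip_quant',
}

# Inserts fake-quant/dequant ops for quantization-aware training.
train_quant_prog = quant_aware(main_prog, place, config, for_test=False)
```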
@@ -12,7 +12,7 @@ sys.path.append(sys.path[0] + "/../../../")
 sys.path.append(sys.path[0] + "/../../")
 from paddleslim.common import get_logger
 from paddleslim.analysis import flops
-from paddleslim.quant import quant_aware, quant_post, convert
+from paddleslim.quant import quant_aware, convert
 import models
 from utility import add_arguments, print_arguments
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 import copy
 import logging
+import paddle
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
@@ -22,6 +23,9 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid import core
+from ..common import get_logger
+_logger = get_logger(__name__, level=logging.INFO)
+
 WEIGHT_QUANTIZATION_TYPES = [
     'abs_max', 'channel_wise_abs_max', 'range_abs_max',
     'moving_average_abs_max'
@@ -40,10 +44,8 @@ _quant_config_default = {
     'weight_bits': 8,
     # activation quantize bit num, default is 8
     'activation_bits': 8,
-    # ops of name_scope in not_quant_pattern list, will not be quantized
-    'not_quant_pattern': ['skip_quant'],
-    # ops of type in quantize_op_types, will be quantized
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+    # ops of name_scope in not_quant_pattern , will not be quantized
+    'not_quant_pattern': 'skip_quant',
     # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
     'dtype': 'int8',
     # window size for 'range_abs_max' quantization. defaulf is 10000
@@ -84,12 +86,9 @@ def _parse_configs(user_config):
     assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
         "activation_bits should be between 1 and 16."
-    assert isinstance(configs['not_quant_pattern'], list), \
+    assert isinstance(configs['not_quant_pattern'], str), \
         "not_quant_pattern must be a list"
-    assert isinstance(configs['quantize_op_types'], list), \
-        "quantize_op_types must be a list"
     assert isinstance(configs['dtype'], str), \
         "dtype must be a str."
@@ -102,11 +101,6 @@ def _parse_configs(user_config):
     assert isinstance(configs['moving_rate'], float), \
         "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
-    assert isinstance(configs['quant_weight_only'], bool), \
-        "quant_weight_only must be bool value, if set quant_weight_only True, " \
-        "then only quantize parameters of layers which need to be quantized, " \
-        " and activations will not be quantized."
     return configs
@@ -142,7 +136,6 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
         weight_quantize_type=config['weight_quantize_type'],
         window_size=config['window_size'],
         moving_rate=config['moving_rate'],
-        quantizable_op_type=config['quantize_op_types'],
         skip_pattern=config['not_quant_pattern'])
     transform_pass.apply(main_graph)
@@ -185,6 +178,8 @@ def convert(program, place, config=None, scope=None, save_int8=False):
     freeze_pass = QuantizationFreezePass(
         scope=scope,
         place=place,
+        weight_bits=config['weight_bits'],
+        activation_bits=config['activation_bits'],
         weight_quantize_type=config['weight_quantize_type'])
     freeze_pass.apply(test_graph)
     freezed_program = test_graph.to_program()
......
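On the export side, `convert` now forwards `weight_bits` and `activation_bits` from the config into `QuantizationFreezePass` (last hunk above). A hedged sketch of the export flow built on that, following the `quant_aware`/`convert` signatures visible in this diff; `val_prog` is a hypothetical placeholder, and unset config keys are filled from `_quant_config_default`:

```python
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert

place = fluid.CPUPlace()
config = {'weight_bits': 8, 'activation_bits': 8}

# Placeholder test program; in a real run this is the eval network.
val_prog = fluid.default_main_program().clone(for_test=True)

# Same transform as training, but built for inference graphs.
val_quant_prog = quant_aware(val_prog, place, config, for_test=True)

# Freeze the fake-quant graph; with save_int8=True, convert returns both
# a float program with quantized weights and an int8 program.
float_prog, int8_prog = convert(val_quant_prog, place, config, save_int8=True)
```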