Unverified Commit 2afa546f authored by Liufang Sang, committed by GitHub

fix quant_aware for 1.6 (#152)

Parent f437281b
...
@@ -12,7 +12,7 @@ sys.path.append(sys.path[0] + "/../../../")
 sys.path.append(sys.path[0] + "/../../")
 from paddleslim.common import get_logger
 from paddleslim.analysis import flops
-from paddleslim.quant import quant_aware, quant_post, convert
+from paddleslim.quant import quant_aware, convert
 import models
 from utility import add_arguments, print_arguments
...
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 import copy
+import logging
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
...
@@ -22,6 +23,9 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid import core
+from ..common import get_logger
+
+_logger = get_logger(__name__, level=logging.INFO)
 WEIGHT_QUANTIZATION_TYPES = [
     'abs_max', 'channel_wise_abs_max', 'range_abs_max',
     'moving_average_abs_max'
...
@@ -40,10 +44,8 @@ _quant_config_default = {
     'weight_bits': 8,
     # activation quantize bit num, default is 8
     'activation_bits': 8,
-    # ops of name_scope in not_quant_pattern list, will not be quantized
-    'not_quant_pattern': ['skip_quant'],
-    # ops of type in quantize_op_types, will be quantized
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+    # ops of name_scope in not_quant_pattern, will not be quantized
+    'not_quant_pattern': 'skip_quant',
     # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
     'dtype': 'int8',
     # window size for 'range_abs_max' quantization. default is 10000
...
@@ -84,12 +86,9 @@ def _parse_configs(user_config):
     assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
         "activation_bits should be between 1 and 16."
-    assert isinstance(configs['not_quant_pattern'], list), \
-        "not_quant_pattern must be a list"
-    assert isinstance(configs['quantize_op_types'], list), \
-        "quantize_op_types must be a list"
+    assert isinstance(configs['not_quant_pattern'], str), \
+        "not_quant_pattern must be a str"
     assert isinstance(configs['dtype'], str), \
         "dtype must be a str."
...
@@ -102,11 +101,6 @@ def _parse_configs(user_config):
     assert isinstance(configs['moving_rate'], float), \
         "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
-    assert isinstance(configs['quant_weight_only'], bool), \
-        "quant_weight_only must be bool value, if set quant_weight_only True, " \
-        "then only quantize parameters of layers which need to be quantized, " \
-        " and activations will not be quantized."
     return configs
...
@@ -142,7 +136,6 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
         weight_quantize_type=config['weight_quantize_type'],
         window_size=config['window_size'],
         moving_rate=config['moving_rate'],
-        quantizable_op_type=config['quantize_op_types'],
         skip_pattern=config['not_quant_pattern'])
     transform_pass.apply(main_graph)
...
@@ -185,6 +178,8 @@ def convert(program, place, config=None, scope=None, save_int8=False):
 freeze_pass = QuantizationFreezePass(
     scope=scope,
     place=place,
+    weight_bits=config['weight_bits'],
+    activation_bits=config['activation_bits'],
     weight_quantize_type=config['weight_quantize_type'])
 freeze_pass.apply(test_graph)
 freezed_program = test_graph.to_program()
...
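
For readers tracking the API change, here is a minimal usage sketch of the flow after this commit. Only the config keys and the quant_aware/convert signatures come from the diff above; the toy network, executor setup, and the training-loop placeholder are illustrative assumptions for a Paddle 1.6-era fluid program, not part of this commit.

import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert

# Toy network (illustrative assumption): one conv + fc is enough to
# exercise the quantization passes.
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    conv = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
    out = fluid.layers.fc(conv, size=10)

test_prog = main_prog.clone(for_test=True)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)

# After this commit, 'not_quant_pattern' is a single name_scope string
# (previously a list) and 'quantize_op_types' is gone from the config.
config = {
    'weight_bits': 8,
    'activation_bits': 8,
    'not_quant_pattern': 'skip_quant',
    'dtype': 'int8',
}

# Insert fake-quant ops: one program for training, one for evaluation.
quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
quant_test_prog = quant_aware(test_prog, place, config, for_test=True)

# ... run quantization-aware training on quant_train_prog here ...

# Freeze the evaluation graph into an int8 inference program; convert()
# now forwards weight_bits/activation_bits to QuantizationFreezePass.
inference_prog = convert(quant_test_prog, place, config)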