Commit 9406373b authored by itminner

add comments; add user config check

Parent d2c912d8
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
......@@ -7,43 +22,113 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid import core
QUANTIZATION_TYPES = ['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max']
quant_config_default = {
# weight quantize type, default is 'abs_max'
'weight_quantize_type': 'abs_max',
# activation quantize type, default is 'abs_max'
'activation_quantize_type': 'abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops whose name_scope contains an entry of not_quant_pattern will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops whose type is in quantize_op_types will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization, default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# if quant_weight_only is True, only the weights of the layers to be quantized are quantized,
# and activations are not quantized.
'quant_weight_only': False
}
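For reference, a user config only needs to list the keys that differ from these defaults; the choices below are a minimal illustrative sketch, not values taken from this commit:

user_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'not_quant_pattern': ['skip_quant'],
}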
def _parse_configs(user_config):
"""
Check whether the user config is valid and fill in default values for options the user did not set.
Args:
user_config(dict): quantization config supplied by the user.
Return:
configs(dict): the merged config that will actually be used.
"""
configs = copy.deepcopy(quant_config_default)
configs.update(user_config)
# check configs is valid
assert configs['weight_quantize_type'] in QUANTIZATION_TYPES, \
"Unknown weight_quantize_type: '%s'. It can only be one of %s." \
% (configs['weight_quantize_type'], QUANTIZATION_TYPES)
assert configs['activation_quantize_type'] in QUANTIZATION_TYPES, \
"Unknown activation_quantize_type: '%s'. It can only be one of %s." \
% (configs['activation_quantize_type'], QUANTIZATION_TYPES)
assert isinstance(configs['weight_bits'], int), \
"weight_bits must be an int value, such as 8, 16, 32, etc."
assert isinstance(configs['activation_bits'], int), \
"activation_bits must be an int value, such as 8, 16, 32, etc."
assert isinstance(configs['not_quant_pattern'], list), \
"not_quant_pattern must be a list"
assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list"
assert isinstance(configs['dtype'], str), \
"dtype must be a str, e.g. 'int8', 'uint8', 'int16'."
assert isinstance(configs['window_size'], int), \
"window_size must be an int; it is the window size for 'range_abs_max' quantization (default 10000)."
assert isinstance(configs['moving_rate'], float), \
"moving_rate must be a float; it is the decay coefficient of the moving average (default 0.9)."
assert isinstance(configs['quant_weight_only'], bool), \
"quant_weight_only must be a bool; if True, only the weights of the layers to be quantized " \
"are quantized, and activations are not quantized."
return configs
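A small sketch of how _parse_configs is meant to be used (illustrative only): unspecified keys fall back to quant_config_default, and an invalid quantize type trips one of the asserts above.

cfg = _parse_configs({'weight_quantize_type': 'abs_max',
                      'activation_quantize_type': 'abs_max'})
assert cfg['window_size'] == 10000  # default filled in

try:
    _parse_configs({'weight_quantize_type': 'min_max',
                    'activation_quantize_type': 'abs_max'})
except AssertionError as e:
    print(e)  # unknown weight_quantize_type is rejected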
def quant_aware(program, scope, place, config, for_test=False):
"""
Insert quantization ops into the program; the returned program is trainable and can be finetuned (quantization-aware training).
Args:
program(fluid.Program): the program to be quantized.
scope(fluid.Scope): the scope that stores variables; if None, fluid.global_scope() is used.
place(fluid.CPUPlace or fluid.CUDAPlace): the device the program runs on.
config(dict): quantization config; missing keys are filled from quant_config_default.
for_test(bool): whether program is a test (inference-only) program.
Return:
fluid.Program: a program with quantization ops inserted; the user can finetune it to recover accuracy.
"""
scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict"
assert 'weight_quantize_type' in config.keys(), 'weight_quantize_type must be configured'
assert 'activation_quantize_type' in config.keys(), 'activation_quantize_type must be configured'
config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass = QuantizationTransformPass(
scope=scope, place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
skip_pattern=''  # not_quant_pattern is not applied here yet
)
......@@ -57,19 +142,31 @@ def quant_aware(program, scope, place, config, for_test=False, loss_name=''):
return quant_program
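An end-to-end sketch of the intended quant-aware workflow; the toy network, place and config values below are illustrative placeholders, not part of this commit.

import paddle.fluid as fluid

train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    out = fluid.layers.fc(input=image, size=10, act='softmax')
    loss = fluid.layers.mean(fluid.layers.cross_entropy(input=out, label=label))
    fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)

config = {'weight_quantize_type': 'abs_max',
          'activation_quantize_type': 'moving_average_abs_max'}
# insert trainable quantization ops, then finetune quant_train_program as usual
quant_train_program = quant_aware(train_program, fluid.global_scope(), place,
                                  config, for_test=False)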
def quant_post(program, scope, place, config):
"""
Insert quantization ops into the program; the returned program is not trainable.
Args:
program(fluid.Program): the program to be quantized.
scope(fluid.Scope): the scope that stores variables; if None, fluid.global_scope() is used.
place(fluid.CPUPlace or fluid.CUDAPlace): the device the program runs on.
config(dict): quantization config; missing keys are filled from quant_config_default.
Return:
fluid.Program: the quantized program; it is not trainable.
"""
scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict"
assert 'weight_quantize_type' in config.keys(), 'weight_quantize_type must be configured'
assert 'activation_quantize_type' in config.keys(), 'activation_quantize_type must be configured'
config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=scope, place=place,
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'])
transform_pass.apply(main_graph)
......@@ -77,18 +174,28 @@ def quant_post(program, scope, place, config):
return quant_program
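quant_post follows the same calling pattern on a test-mode program; a minimal sketch reusing the names from the previous example (val_program is a hypothetical clone made by the caller):

val_program = train_program.clone(for_test=True)
quant_val_program = quant_post(val_program, fluid.global_scope(), place, config)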
def convert(program, scope, place, config, save_int8=False):
"""
Convert a quantization-aware program into freezed programs that can be used for inference; the returned programs are not trainable.
Args:
program(fluid.Program): the program with quantization ops to be converted.
scope(fluid.Scope): the scope that stores variables; if None, fluid.global_scope() is used.
place(fluid.CPUPlace or fluid.CUDAPlace): the device the program runs on.
config(dict): quantization config; see quant_config_default for the available options.
save_int8(bool): whether to also return a freezed int8 program.
Return:
fluid.Program: a freezed program for inference; its parameters are float32, but their values are in the int8 range.
fluid.Program: a freezed int8 program for inference, produced when save_int8 is True.
"""
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type=config['weight_quantize_type'])
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
freezed_program_int8 = None
......
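After finetuning, convert() freezes the quantized graph for inference; a sketch continuing the previous examples (the model directory and feed names are placeholders, not from this commit).

# with save_int8=True a float program and an int8 program are expected back
float_prog, int8_prog = convert(quant_val_program, fluid.global_scope(), place,
                                config, save_int8=True)
fluid.io.save_inference_model(dirname='./quant_infer_model',
                              feeded_var_names=['image'],
                              target_vars=[out],
                              executor=exe,
                              main_program=float_prog)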
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import random
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenet import MobileNet
from .resnet import ResNet34, ResNet50
from .mobilenet_v2 import MobileNetV2
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
......@@ -16,6 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import os
import numpy as np
......