提交 06eef983 编写于 作者: L Liufang Sang 提交者: Bai Yifan

add default config and quantize_op_types description for tensorrt and paddle-lite (#23)

* refine quanter

* add details in doc

* check type of for_tensorrt and is_full_quantize

* add description for tensorrt
上级 00f971c1
...@@ -20,8 +20,7 @@ quant_config = { ...@@ -20,8 +20,7 @@ quant_config = {
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
'dtype': 'int8', 'dtype': 'int8',
'window_size': 10000, 'window_size': 10000,
'moving_rate': 0.9, 'moving_rate': 0.9
'quant_weight_only': False
} }
``` ```
...@@ -49,7 +48,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel( ...@@ -49,7 +48,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(
### 4. freeze program ### 4. freeze program
``` ```
float_program, int8_program = convert(val_program, float_program, int8_program = convert(val_program,
place, place,
quant_config, quant_config,
scope=None, scope=None,
......
...@@ -78,27 +78,24 @@ def compress(args): ...@@ -78,27 +78,24 @@ def compress(args):
# 1. quantization configs # 1. quantization configs
############################################################################################################ ############################################################################################################
quant_config = { quant_config = {
# weight quantize type, default is 'abs_max' # weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'abs_max', 'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'abs_max' # activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8 # weight quantize bit num, default is 8
'weight_bits': 8, 'weight_bits': 8,
# activation quantize bit num, default is 8 # activation quantize bit num, default is 8
'activation_bits': 8, 'activation_bits': 8,
# op of name_scope in not_quant_pattern list, will not quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
# op of types in quantize_op_types, will quantized # ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, default is 'int8' # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8', 'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000 # window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000, 'window_size': 10000,
# The decay coefficient of moving average, default is 0.9 # The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9, 'moving_rate': 0.9,
# if set quant_weight_only True, then only quantize parameters of layers which need quantization,
# and insert anti-quantization op for parameters of these layers.
'quant_weight_only': False
} }
train_reader = None train_reader = None
...@@ -141,8 +138,10 @@ def compress(args): ...@@ -141,8 +138,10 @@ def compress(args):
# According to the weight and activation quantization type, the graph will be added # According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators. # some fake quantize operators and fake dequantize operators.
############################################################################################################ ############################################################################################################
val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True) val_program = quant_aware(
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False) val_program, place, quant_config, scope=None, for_test=True)
compiled_train_prog = quant_aware(
train_prog, place, quant_config, scope=None, for_test=False)
opt = create_optimizer(args) opt = create_optimizer(args)
opt.minimize(avg_cost) opt.minimize(avg_cost)
...@@ -152,7 +151,8 @@ def compress(args): ...@@ -152,7 +151,8 @@ def compress(args):
if args.pretrained_model: if args.pretrained_model:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name)) return os.path.exists(
os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
...@@ -199,9 +199,9 @@ def compress(args): ...@@ -199,9 +199,9 @@ def compress(args):
build_strategy.sync_batch_norm = False build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel( compiled_train_prog = compiled_train_prog.with_data_parallel(
loss_name=avg_cost.name, loss_name=avg_cost.name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
batch_id = 0 batch_id = 0
for data in train_reader(): for data in train_reader():
...@@ -242,8 +242,8 @@ def compress(args): ...@@ -242,8 +242,8 @@ def compress(args):
# 4. Save inference model # 4. Save inference model
############################################################################################################ ############################################################################################################
model_path = os.path.join(quantization_model_save_dir, args.model, model_path = os.path.join(quantization_model_save_dir, args.model,
'act_' + quant_config['activation_quantize_type'] + '_w_' + quant_config[ 'act_' + quant_config['activation_quantize_type']
'weight_quantize_type']) + '_w_' + quant_config['weight_quantize_type'])
float_path = os.path.join(model_path, 'float') float_path = os.path.join(model_path, 'float')
int8_path = os.path.join(model_path, 'int8') int8_path = os.path.join(model_path, 'int8')
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
...@@ -252,7 +252,8 @@ def compress(args): ...@@ -252,7 +252,8 @@ def compress(args):
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname=float_path, dirname=float_path,
feeded_var_names=[image.name], feeded_var_names=[image.name],
target_vars=[out], executor=exe, target_vars=[out],
executor=exe,
main_program=float_program, main_program=float_program,
model_filename=float_path + '/model', model_filename=float_path + '/model',
params_filename=float_path + '/params') params_filename=float_path + '/params')
...@@ -260,7 +261,8 @@ def compress(args): ...@@ -260,7 +261,8 @@ def compress(args):
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname=int8_path, dirname=int8_path,
feeded_var_names=[image.name], feeded_var_names=[image.name],
target_vars=[out], executor=exe, target_vars=[out],
executor=exe,
main_program=int8_program, main_program=int8_program,
model_filename=int8_path + '/model', model_filename=int8_path + '/model',
params_filename=int8_path + '/params') params_filename=int8_path + '/params')
......
...@@ -4,29 +4,50 @@ ...@@ -4,29 +4,50 @@
通过字典配置量化参数 通过字典配置量化参数
``` ```
quant_config_default = { TENSORRT_OP_TYPES = [
'weight_quantize_type': 'abs_max', 'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
'activation_quantize_type': 'abs_max', 'leaky_relu'
]
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
QUANT_DEQUANT_PASS_OP_TYPES = [
"pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub", "relu", "relu6", "leaky_relu", "tanh", "swish"
]
_quant_config_default = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8, 'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8, 'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized # ops of type in quantize_op_types, will be quantized
'quantize_op_types': 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8', 'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000 # window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000, 'window_size': 10000,
# The decay coefficient of moving average, default is 0.9 # The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9, 'moving_rate': 0.9,
# if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
'for_tensorrt': False,
# if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
'is_full_quantize': False
} }
``` ```
**参数:** **参数:**
- **weight_quantize_type(str)** - 参数量化方式。可选``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'`` 默认``'abs_max'`` - **weight_quantize_type(str)** - 参数量化方式。可选``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``如果使用``TensorRT``加载量化后的模型来预测,请使用``'channel_wise_abs_max'``。 默认``'channel_wise_abs_max'``
- **activation_quantize_type(str)** - 激活量化方式,可选``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``,默认``'abs_max'`` - **activation_quantize_type(str)** - 激活量化方式,可选``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``。如果使用``TensorRT``加载量化后的模型来预测,请使用``'range_abs_max', 'moving_average_abs_max'``。,默认``'moving_average_abs_max'``
- **weight_bits(int)** - 参数量化bit数,默认8, 推荐设为8。 - **weight_bits(int)** - 参数量化bit数,默认8, 推荐设为8。
- **activation_bits(int)** - 激活量化bit数,默认8, 推荐设为8。 - **activation_bits(int)** - 激活量化bit数,默认8, 推荐设为8。
- **not_quant_pattern(str | list[str])** - 所有``name_scope``包含``'not_quant_pattern'``字符串的``op``,都不量化, 设置方式请参考[*fluid.name_scope*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/name_scope_cn.html#name-scope) - **not_quant_pattern(str | list[str])** - 所有``name_scope``包含``'not_quant_pattern'``字符串的``op``,都不量化, 设置方式请参考[*fluid.name_scope*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/name_scope_cn.html#name-scope)
...@@ -34,7 +55,12 @@ quant_config_default = { ...@@ -34,7 +55,12 @@ quant_config_default = {
- **dtype(int8)** - 量化后的参数类型,默认 ``int8``, 目前仅支持``int8`` - **dtype(int8)** - 量化后的参数类型,默认 ``int8``, 目前仅支持``int8``
- **window_size(int)** - ``'range_abs_max'``量化方式的``window size``,默认10000。 - **window_size(int)** - ``'range_abs_max'``量化方式的``window size``,默认10000。
- **moving_rate(int)** - ``'moving_average_abs_max'``量化方式的衰减系数,默认 0.9。 - **moving_rate(int)** - ``'moving_average_abs_max'``量化方式的衰减系数,默认 0.9。
- **for_tensorrt(bool)** - 量化后的模型是否使用``TensorRT``进行预测。如果是的话,量化op类型为:``TENSORRT_OP_TYPES``。默认值为False.
- **is_full_quantize(bool)** - 是否量化所有可支持op类型。默认值为False.
!!! note "注意事项"
- 目前``Paddle-Lite``有int8 kernel来加速的op只有 ``['conv2d', 'depthwise_conv2d', 'mul']``, 其他op的int8 kernel将陆续支持。
## quant_aware ## quant_aware
paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
...@@ -67,7 +93,7 @@ paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False) ...@@ -67,7 +93,7 @@ paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)
## convert ## convert
paddleslim.quant.convert(program, place, config, scope=None, save_int8=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) paddleslim.quant.convert(program, place, config, scope=None, save_int8=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
...@@ -135,7 +161,7 @@ inference_prog = quant.convert(quant_eval_program, place, config) ...@@ -135,7 +161,7 @@ inference_prog = quant.convert(quant_eval_program, place, config)
更详细的用法请参考 <a href='https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_aware'>量化训练demo</a> 更详细的用法请参考 <a href='https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_aware'>量化训练demo</a>
## quant_post ## quant_post
paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_generator, model_filename=None, params_filename=None, batch_size=16,batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py) paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_generator, model_filename=None, params_filename=None, batch_size=16,batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], is_full_quantize=False, is_use_cache_file=False, cache_dir="./temp_post_training")[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
: 对保存在``${model_dir}``下的模型进行量化,使用``sample_generator``的数据进行参数校正。 : 对保存在``${model_dir}``下的模型进行量化,使用``sample_generator``的数据进行参数校正。
...@@ -152,6 +178,9 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene ...@@ -152,6 +178,9 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene
- **scope(fluid.Scope, optional)** - 用来获取和写入``Variable``, 如果设置为``None``,则使用[*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html). 默认值是``None``. - **scope(fluid.Scope, optional)** - 用来获取和写入``Variable``, 如果设置为``None``,则使用[*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html). 默认值是``None``.
- **algo(str)** - 量化时使用的算法名称,可为``'KL'``或者``'direct'``。该参数仅针对激活值的量化,因为参数值的量化使用的方式为``'channel_wise_abs_max'``. 当``algo`` 设置为``'direct'``时,使用校正数据的激活值的绝对值的最大值当作``Scale``值,当设置为``'KL'``时,则使用``KL``散度的方法来计算``Scale``值。默认值为``'KL'`` - **algo(str)** - 量化时使用的算法名称,可为``'KL'``或者``'direct'``。该参数仅针对激活值的量化,因为参数值的量化使用的方式为``'channel_wise_abs_max'``. 当``algo`` 设置为``'direct'``时,使用校正数据的激活值的绝对值的最大值当作``Scale``值,当设置为``'KL'``时,则使用``KL``散度的方法来计算``Scale``值。默认值为``'KL'``
- **quantizable_op_type(list[str])** - 需要量化的``op``类型列表。默认值为``["conv2d", "depthwise_conv2d", "mul"]`` - **quantizable_op_type(list[str])** - 需要量化的``op``类型列表。默认值为``["conv2d", "depthwise_conv2d", "mul"]``
- **is_full_quantize(bool)** - 是否量化所有可支持的op类型。如果设置为False, 则按照 ``'quantizable_op_type'`` 的设置进行量化。
- **is_use_cache_file(bool)** - 是否使用硬盘对中间结果进行存储。如果为False, 则将中间结果存储在内存中。
- **cache_dir(str)** - 如果 ``'is_use_cache_file'``为True, 则将中间结果存储在此参数设置的路径下。
**返回** **返回**
...@@ -159,7 +188,8 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene ...@@ -159,7 +188,8 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene
!!! note "注意事项" !!! note "注意事项"
因为该接口会收集校正数据的所有的激活值,所以使用的校正图片不能太多。``'KL'``散度的计算也比较耗时。 - 因为该接口会收集校正数据的所有的激活值,当校正图片比较多时,请设置``'is_use_cache_file'``为True, 将中间结果存储在硬盘中。另外,``'KL'``散度的计算比较耗时。
- 目前``Paddle-Lite``有int8 kernel来加速的op只有 ``['conv2d', 'depthwise_conv2d', 'mul']``, 其他op的int8 kernel将陆续支持。
**代码示例** **代码示例**
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import copy import copy
import logging
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
...@@ -24,22 +26,37 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization ...@@ -24,22 +26,37 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core from paddle.fluid import core
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
WEIGHT_QUANTIZATION_TYPES = [ WEIGHT_QUANTIZATION_TYPES = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max' 'moving_average_abs_max'
] ]
WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
ACTIVATION_QUANTIZATION_TYPES = [ ACTIVATION_QUANTIZATION_TYPES = [
'abs_max', 'range_abs_max', 'moving_average_abs_max' 'abs_max', 'range_abs_max', 'moving_average_abs_max'
] ]
ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
'range_abs_max', 'moving_average_abs_max'
]
VALID_DTYPES = ['int8'] VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul'] TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d'] QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type + \
AddQuantDequantPass._activation_type
TENSORRT_OP_TYPES = [
'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
'leaky_relu'
]
_quant_config_default = { _quant_config_default = {
# weight quantize type, default is 'abs_max' # weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'abs_max', 'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'abs_max' # activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'abs_max', 'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8 # weight quantize bit num, default is 8
'weight_bits': 8, 'weight_bits': 8,
# activation quantize bit num, default is 8 # activation quantize bit num, default is 8
...@@ -47,25 +64,25 @@ _quant_config_default = { ...@@ -47,25 +64,25 @@ _quant_config_default = {
# ops of name_scope in not_quant_pattern list, will not be quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized # ops of type in quantize_op_types, will be quantized
'quantize_op_types': 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8', 'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000 # window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000, 'window_size': 10000,
# The decay coefficient of moving average, default is 0.9 # The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9, 'moving_rate': 0.9,
# if set quant_weight_only True, then only quantize parameters of layers which need to be quantized, # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
# and activations will not be quantized. 'for_tensorrt': False,
'quant_weight_only': False # if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
'is_full_quantize': False
} }
def _parse_configs(user_config): def _parse_configs(user_config):
""" """
check user configs is valid, and set default value if user not config. check if user's configs are valid.
Args: Args:
user_config(dict):the config of user. user_config(dict): user's config.
Return: Return:
configs(dict): final configs will be used. configs(dict): final configs will be used.
""" """
...@@ -73,12 +90,26 @@ def _parse_configs(user_config): ...@@ -73,12 +90,26 @@ def _parse_configs(user_config):
configs = copy.deepcopy(_quant_config_default) configs = copy.deepcopy(_quant_config_default)
configs.update(user_config) configs.update(user_config)
# check configs is valid assert isinstance(configs['for_tensorrt'], bool) and isinstance(
assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \ configs['is_full_quantize'],
"Unknown weight_quantize_type: '%s'. It can only be " + " ".join(WEIGHT_QUANTIZATION_TYPES) bool), "'for_tensorrt' and 'is_full_quantize' must both be bool'"
# check if configs is valid
if configs['for_tensorrt']:
weight_types = WEIGHT_QUANTIZATION_TYPES_TENSORRT
activation_types = ACTIVATION_QUANTIZATION_TYPES_TENSORRT
platform = 'TensorRT'
else:
weight_types = WEIGHT_QUANTIZATION_TYPES
activation_types = WEIGHT_QUANTIZATION_TYPES
platform = 'PaddleLite'
assert configs['weight_quantize_type'] in weight_types, \
"Unknown weight_quantize_type: {}. {} only supports {} ".format(configs['weight_quantize_type'],
platform, weight_types)
assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \ assert configs['activation_quantize_type'] in activation_types, \
"Unknown activation_quantize_type: '%s'. It can only be " + " ".join(ACTIVATION_QUANTIZATION_TYPES) "Unknown activation_quantize_type: {}. {} only supports {}".format(configs['activation_quantize_type'],
platform, activation_types)
assert isinstance(configs['weight_bits'], int), \ assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value." "weight_bits must be int value."
...@@ -92,17 +123,24 @@ def _parse_configs(user_config): ...@@ -92,17 +123,24 @@ def _parse_configs(user_config):
assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \ assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16." "activation_bits should be between 1 and 16."
assert isinstance(configs['not_quant_pattern'], list), \ assert isinstance(configs['not_quant_pattern'], (list, str)), \
"not_quant_pattern must be a list" "not_quant_pattern must be list or str"
assert isinstance(configs['quantize_op_types'], list), \ assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list" "quantize_op_types must be a list"
for op_type in configs['quantize_op_types']: if configs['for_tensorrt']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or ( configs['quantize_op_types'] = TENSORRT_OP_TYPES
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \ elif configs['is_full_quantize']:
now support op types are {}".format( configs[
op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES) 'quantize_op_types'] = TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
else:
for op_type in configs['quantize_op_types']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \
now support op types are {}".format(
op_type,
TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
assert isinstance(configs['dtype'], str), \ assert isinstance(configs['dtype'], str), \
"dtype must be a str." "dtype must be a str."
...@@ -116,36 +154,31 @@ def _parse_configs(user_config): ...@@ -116,36 +154,31 @@ def _parse_configs(user_config):
assert isinstance(configs['moving_rate'], float), \ assert isinstance(configs['moving_rate'], float), \
"moving_rate must be float value, The decay coefficient of moving average, default is 0.9." "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
assert isinstance(configs['quant_weight_only'], bool), \
"quant_weight_only must be bool value, if set quant_weight_only True, " \
"then only quantize parameters of layers which need to be quantized, " \
" and activations will not be quantized."
return configs return configs
def quant_aware(program, place, config, scope=None, for_test=False): def quant_aware(program, place, config=None, scope=None, for_test=False):
""" """
add trainable quantization ops in program. add trainable quantization ops in program.
Args: Args:
program(fluid.Program): program program(fluid.Program): program to quant
scope(fluid.Scope): the scope to store var, it's should be the value of program's scope, usually it's fluid.global_scope(). place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device
place(fluid.CPUPlace or fluid.CUDAPlace): place config(dict, optional): configs for quantization. if None, will use default config. Default is None.
config(dict): configs for quantization, default values are in quant_config_default dict. scope(fluid.Scope): the scope to store var, it should be program's scope. if None, will use fluid.global_scope().
for_test: if program is test program, for_test should be set True, else False. default is None.
for_test(bool): if program is test program, set True when program is for test, False when program is for train. Default is False.
Return: Return:
fluid.Program: user can finetune this quantization program to enhance the accuracy. fluid.Program: user can finetune this quantization program to enhance the accuracy.
""" """
scope = fluid.global_scope() if not scope else scope scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict" if config is None:
config = _quant_config_default
assert 'weight_quantize_type' in config.keys( else:
), 'weight_quantize_type must be configured' assert isinstance(config, dict), "config must be dict"
assert 'activation_quantize_type' in config.keys( config = _parse_configs(config)
), 'activation_quantize_type must be configured' _logger.info("quant_aware config {}".format(config))
config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass_ops = [] transform_pass_ops = []
...@@ -197,7 +230,10 @@ def quant_post(executor, ...@@ -197,7 +230,10 @@ def quant_post(executor,
batch_nums=None, batch_nums=None,
scope=None, scope=None,
algo='KL', algo='KL',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]): quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
""" """
The function utilizes post training quantization method to quantize the The function utilizes post training quantization method to quantize the
fp32 model. It uses calibrate data to calculate the scale factor of fp32 model. It uses calibrate data to calculate the scale factor of
...@@ -232,6 +268,11 @@ def quant_post(executor, ...@@ -232,6 +268,11 @@ def quant_post(executor,
quantizable_op_type(list[str], optional): The list of op types quantizable_op_type(list[str], optional): The list of op types
that will be quantized. Default is ["conv2d", "depthwise_conv2d", that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"]. "mul"].
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type.
If False, only apply quantization to the input quantizable_op_type. Default is False.
is_use_cache_file(bool): If False, all temp data will be saved in memory. If True,
all temp data will be saved to disk. Defalut is False.
cache_dir(str): When 'is_use_cache_file' is True, temp data will be save in 'cache_dir'. Default is './temp_post_training'.
Returns: Returns:
None None
""" """
...@@ -246,41 +287,64 @@ def quant_post(executor, ...@@ -246,41 +287,64 @@ def quant_post(executor,
scope=scope, scope=scope,
algo=algo, algo=algo,
quantizable_op_type=quantizable_op_type, quantizable_op_type=quantizable_op_type,
is_full_quantize=False) is_full_quantize=is_full_quantize,
is_use_cache_file=is_use_cache_file,
cache_dir=cache_dir)
post_training_quantization.quantize() post_training_quantization.quantize()
post_training_quantization.save_quantized_model(quantize_model_path) post_training_quantization.save_quantized_model(quantize_model_path)
def convert(program, place, config, scope=None, save_int8=False): def convert(program, place, config=None, scope=None, save_int8=False):
""" """
add quantization ops in program. the program returned is not trainable. change quantization ops order in program. return program that can used by Paddle-Lite.
Args: Args:
program(fluid.Program): program program(fluid.Program): program that returned by quant_aware
scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope() place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device
place(fluid.CPUPlace or fluid.CUDAPlace): place scope(fluid.Scope, optional): the scope to store var, it should be program's scope. if None, will use fluid.global_scope().
config(dict): configs for quantization, default values are in quant_config_default dict. default is None.
save_int8: is export int8 freezed program. config(dict, optional): configs for convert. if set None, will use default config. Default is None.\
It must be same with config that used in 'quant_aware'.
save_int8: if return int8 freezed program. Int8 program can only be used to check size of model weights. \
It cannot be used in Fluid or Paddle-Lite.
Return: Return:
fluid.Program: freezed program which can be used for inference. freezed_program(fluid.Program): freezed program which can be used for inference.
parameters is float32 type, but it's value in int8 range. parameters is float32 type, but it's value in int8 range.
fluid.Program: freezed int8 program which can be used for inference. freezed_program_int8(fluid.Program): freezed int8 program.
if save_int8 is False, this value is None. when save_int8 is False, return freezed_program.
when save_int8 is True, return freezed_program and freezed_program_int8
""" """
scope = fluid.global_scope() if not scope else scope scope = fluid.global_scope() if not scope else scope
if config is None:
config = _quant_config_default
else:
assert isinstance(config, dict), "config must be dict"
config = _parse_configs(config)
_logger.info("convert config {}".format(config))
test_graph = IrGraph(core.Graph(program.desc), for_test=True) test_graph = IrGraph(core.Graph(program.desc), for_test=True)
support_op_types = []
for op in config['quantize_op_types']:
if op in QuantizationFreezePass._supported_quantizable_op_type:
support_op_types.append(op)
# Freeze the graph after training by adjusting the quantize # Freeze the graph after training by adjusting the quantize
# operators' order for the inference. # operators' order for the inference.
freeze_pass = QuantizationFreezePass( freeze_pass = QuantizationFreezePass(
scope=scope, scope=scope,
place=place, place=place,
weight_quantize_type=config['weight_quantize_type']) weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
weight_quantize_type=config['weight_quantize_type'],
quantizable_op_type=support_op_types)
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program() freezed_program = test_graph.to_program()
if save_int8: if save_int8:
convert_int8_pass = ConvertToInt8Pass( convert_int8_pass = ConvertToInt8Pass(
scope=fluid.global_scope(), place=place) scope=fluid.global_scope(),
place=place,
quantizable_op_type=support_op_types)
convert_int8_pass.apply(test_graph) convert_int8_pass.apply(test_graph)
freezed_program_int8 = test_graph.to_program() freezed_program_int8 = test_graph.to_program()
return freezed_program, freezed_program_int8 return freezed_program, freezed_program_int8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册