提交 06eef983 编写于 作者: L Liufang Sang 提交者: Bai Yifan

add default config and quantize_op_types description for tensorrt and paddle-lite (#23)

* refine quanter

* add details in doc

* check type of for_tensorrt and is_full_quantize

* add description for tensorrt
上级 00f971c1
......@@ -20,8 +20,7 @@ quant_config = {
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
'dtype': 'int8',
'window_size': 10000,
'moving_rate': 0.9,
'quant_weight_only': False
'moving_rate': 0.9
}
```
......@@ -49,7 +48,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(
### 4. freeze program
```
float_program, int8_program = convert(val_program,
float_program, int8_program = convert(val_program,
place,
quant_config,
scope=None,
......
......@@ -78,27 +78,24 @@ def compress(args):
# 1. quantization configs
############################################################################################################
quant_config = {
# weight quantize type, default is 'abs_max'
'weight_quantize_type': 'abs_max',
# activation quantize type, default is 'abs_max'
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# op of name_scope in not_quant_pattern list, will not quantized
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# op of types in quantize_op_types, will quantized
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, default is 'int8'
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# if set quant_weight_only True, then only quantize parameters of layers which need quantization,
# and insert anti-quantization op for parameters of these layers.
'quant_weight_only': False
}
train_reader = None
......@@ -141,8 +138,10 @@ def compress(args):
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
############################################################################################################
val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
val_program = quant_aware(
val_program, place, quant_config, scope=None, for_test=True)
compiled_train_prog = quant_aware(
train_prog, place, quant_config, scope=None, for_test=False)
opt = create_optimizer(args)
opt.minimize(avg_cost)
......@@ -152,7 +151,8 @@ def compress(args):
if args.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name))
return os.path.exists(
os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
......@@ -199,9 +199,9 @@ def compress(args):
build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
batch_id = 0
for data in train_reader():
......@@ -242,8 +242,8 @@ def compress(args):
# 4. Save inference model
############################################################################################################
model_path = os.path.join(quantization_model_save_dir, args.model,
'act_' + quant_config['activation_quantize_type'] + '_w_' + quant_config[
'weight_quantize_type'])
'act_' + quant_config['activation_quantize_type']
+ '_w_' + quant_config['weight_quantize_type'])
float_path = os.path.join(model_path, 'float')
int8_path = os.path.join(model_path, 'int8')
if not os.path.isdir(model_path):
......@@ -252,7 +252,8 @@ def compress(args):
fluid.io.save_inference_model(
dirname=float_path,
feeded_var_names=[image.name],
target_vars=[out], executor=exe,
target_vars=[out],
executor=exe,
main_program=float_program,
model_filename=float_path + '/model',
params_filename=float_path + '/params')
......@@ -260,7 +261,8 @@ def compress(args):
fluid.io.save_inference_model(
dirname=int8_path,
feeded_var_names=[image.name],
target_vars=[out], executor=exe,
target_vars=[out],
executor=exe,
main_program=int8_program,
model_filename=int8_path + '/model',
params_filename=int8_path + '/params')
......
......@@ -4,29 +4,50 @@
通过字典配置量化参数
```
quant_config_default = {
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'abs_max',
TENSORRT_OP_TYPES = [
'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
'leaky_relu'
]
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
QUANT_DEQUANT_PASS_OP_TYPES = [
"pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub", "relu", "relu6", "leaky_relu", "tanh", "swish"
]
_quant_config_default = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types':
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
'for_tensorrt': False,
# if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
'is_full_quantize': False
}
```
**参数:**
- **weight_quantize_type(str)** - 参数量化方式。可选``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'`` 默认``'abs_max'``
- **activation_quantize_type(str)** - 激活量化方式,可选``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``,默认``'abs_max'``
- **weight_quantize_type(str)** - 参数量化方式。可选``'abs_max'``, ``'channel_wise_abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``如果使用``TensorRT``加载量化后的模型来预测,请使用``'channel_wise_abs_max'``。 默认``'channel_wise_abs_max'``
- **activation_quantize_type(str)** - 激活量化方式,可选``'abs_max'``, ``'range_abs_max'``, ``'moving_average_abs_max'``。如果使用``TensorRT``加载量化后的模型来预测,请使用``'range_abs_max', 'moving_average_abs_max'``。,默认``'moving_average_abs_max'``
- **weight_bits(int)** - 参数量化bit数,默认8, 推荐设为8。
- **activation_bits(int)** - 激活量化bit数,默认8, 推荐设为8。
- **not_quant_pattern(str | list[str])** - 所有``name_scope``包含``'not_quant_pattern'``字符串的``op``,都不量化, 设置方式请参考[*fluid.name_scope*](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/name_scope_cn.html#name-scope)
......@@ -34,7 +55,12 @@ quant_config_default = {
- **dtype(int8)** - 量化后的参数类型,默认 ``int8``, 目前仅支持``int8``
- **window_size(int)** - ``'range_abs_max'``量化方式的``window size``,默认10000。
- **moving_rate(int)** - ``'moving_average_abs_max'``量化方式的衰减系数,默认 0.9。
- **for_tensorrt(bool)** - 量化后的模型是否使用``TensorRT``进行预测。如果是的话,量化op类型为:``TENSORRT_OP_TYPES``。默认值为False.
- **is_full_quantize(bool)** - 是否量化所有可支持op类型。默认值为False.
!!! note "注意事项"
- 目前``Paddle-Lite``有int8 kernel来加速的op只有 ``['conv2d', 'depthwise_conv2d', 'mul']``, 其他op的int8 kernel将陆续支持。
## quant_aware
paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
......@@ -67,7 +93,7 @@ paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)
## convert
## convert
paddleslim.quant.convert(program, place, config, scope=None, save_int8=False)[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
......@@ -135,7 +161,7 @@ inference_prog = quant.convert(quant_eval_program, place, config)
更详细的用法请参考 <a href='https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_aware'>量化训练demo</a>
## quant_post
paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_generator, model_filename=None, params_filename=None, batch_size=16,batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_generator, model_filename=None, params_filename=None, batch_size=16,batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], is_full_quantize=False, is_use_cache_file=False, cache_dir="./temp_post_training")[[源代码]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py)
: 对保存在``${model_dir}``下的模型进行量化,使用``sample_generator``的数据进行参数校正。
......@@ -152,6 +178,9 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene
- **scope(fluid.Scope, optional)** - 用来获取和写入``Variable``, 如果设置为``None``,则使用[*fluid.global_scope()*](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html). 默认值是``None``.
- **algo(str)** - 量化时使用的算法名称,可为``'KL'``或者``'direct'``。该参数仅针对激活值的量化,因为参数值的量化使用的方式为``'channel_wise_abs_max'``. 当``algo`` 设置为``'direct'``时,使用校正数据的激活值的绝对值的最大值当作``Scale``值,当设置为``'KL'``时,则使用``KL``散度的方法来计算``Scale``值。默认值为``'KL'``
- **quantizable_op_type(list[str])** - 需要量化的``op``类型列表。默认值为``["conv2d", "depthwise_conv2d", "mul"]``
- **is_full_quantize(bool)** - 是否量化所有可支持的op类型。如果设置为False, 则按照 ``'quantizable_op_type'`` 的设置进行量化。
- **is_use_cache_file(bool)** - 是否使用硬盘对中间结果进行存储。如果为False, 则将中间结果存储在内存中。
- **cache_dir(str)** - 如果 ``'is_use_cache_file'``为True, 则将中间结果存储在此参数设置的路径下。
**返回**
......@@ -159,7 +188,8 @@ paddleslim.quant.quant_post(executor, model_dir, quantize_model_path,sample_gene
!!! note "注意事项"
因为该接口会收集校正数据的所有的激活值,所以使用的校正图片不能太多。``'KL'``散度的计算也比较耗时。
- 因为该接口会收集校正数据的所有的激活值,当校正图片比较多时,请设置``'is_use_cache_file'``为True, 将中间结果存储在硬盘中。另外,``'KL'``散度的计算比较耗时。
- 目前``Paddle-Lite``有int8 kernel来加速的op只有 ``['conv2d', 'depthwise_conv2d', 'mul']``, 其他op的int8 kernel将陆续支持。
**代码示例**
......
......@@ -13,6 +13,8 @@
# limitations under the License.
import copy
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
......@@ -24,22 +26,37 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
WEIGHT_QUANTIZATION_TYPES = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
]
WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
ACTIVATION_QUANTIZATION_TYPES = [
'abs_max', 'range_abs_max', 'moving_average_abs_max'
]
ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
'range_abs_max', 'moving_average_abs_max'
]
VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d']
TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type + \
AddQuantDequantPass._activation_type
TENSORRT_OP_TYPES = [
'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
'leaky_relu'
]
_quant_config_default = {
# weight quantize type, default is 'abs_max'
'weight_quantize_type': 'abs_max',
# activation quantize type, default is 'abs_max'
'activation_quantize_type': 'abs_max',
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
......@@ -47,25 +64,25 @@ _quant_config_default = {
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types':
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# if set quant_weight_only True, then only quantize parameters of layers which need to be quantized,
# and activations will not be quantized.
'quant_weight_only': False
# if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
'for_tensorrt': False,
# if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
'is_full_quantize': False
}
def _parse_configs(user_config):
"""
check user configs is valid, and set default value if user not config.
check if user's configs are valid.
Args:
user_config(dict):the config of user.
user_config(dict): user's config.
Return:
configs(dict): final configs will be used.
"""
......@@ -73,12 +90,26 @@ def _parse_configs(user_config):
configs = copy.deepcopy(_quant_config_default)
configs.update(user_config)
# check configs is valid
assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
"Unknown weight_quantize_type: '%s'. It can only be " + " ".join(WEIGHT_QUANTIZATION_TYPES)
assert isinstance(configs['for_tensorrt'], bool) and isinstance(
configs['is_full_quantize'],
bool), "'for_tensorrt' and 'is_full_quantize' must both be bool'"
# check if configs is valid
if configs['for_tensorrt']:
weight_types = WEIGHT_QUANTIZATION_TYPES_TENSORRT
activation_types = ACTIVATION_QUANTIZATION_TYPES_TENSORRT
platform = 'TensorRT'
else:
weight_types = WEIGHT_QUANTIZATION_TYPES
activation_types = WEIGHT_QUANTIZATION_TYPES
platform = 'PaddleLite'
assert configs['weight_quantize_type'] in weight_types, \
"Unknown weight_quantize_type: {}. {} only supports {} ".format(configs['weight_quantize_type'],
platform, weight_types)
assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
"Unknown activation_quantize_type: '%s'. It can only be " + " ".join(ACTIVATION_QUANTIZATION_TYPES)
assert configs['activation_quantize_type'] in activation_types, \
"Unknown activation_quantize_type: {}. {} only supports {}".format(configs['activation_quantize_type'],
platform, activation_types)
assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value."
......@@ -92,17 +123,24 @@ def _parse_configs(user_config):
assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16."
assert isinstance(configs['not_quant_pattern'], list), \
"not_quant_pattern must be a list"
assert isinstance(configs['not_quant_pattern'], (list, str)), \
"not_quant_pattern must be list or str"
assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list"
for op_type in configs['quantize_op_types']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \
now support op types are {}".format(
op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
if configs['for_tensorrt']:
configs['quantize_op_types'] = TENSORRT_OP_TYPES
elif configs['is_full_quantize']:
configs[
'quantize_op_types'] = TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
else:
for op_type in configs['quantize_op_types']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \
now support op types are {}".format(
op_type,
TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
assert isinstance(configs['dtype'], str), \
"dtype must be a str."
......@@ -116,36 +154,31 @@ def _parse_configs(user_config):
assert isinstance(configs['moving_rate'], float), \
"moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
assert isinstance(configs['quant_weight_only'], bool), \
"quant_weight_only must be bool value, if set quant_weight_only True, " \
"then only quantize parameters of layers which need to be quantized, " \
" and activations will not be quantized."
return configs
def quant_aware(program, place, config, scope=None, for_test=False):
def quant_aware(program, place, config=None, scope=None, for_test=False):
"""
add trainable quantization ops in program.
Args:
program(fluid.Program): program
scope(fluid.Scope): the scope to store var, it's should be the value of program's scope, usually it's fluid.global_scope().
place(fluid.CPUPlace or fluid.CUDAPlace): place
config(dict): configs for quantization, default values are in quant_config_default dict.
for_test: if program is test program, for_test should be set True, else False.
program(fluid.Program): program to quant
place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device
config(dict, optional): configs for quantization. if None, will use default config. Default is None.
scope(fluid.Scope): the scope to store var, it should be program's scope. if None, will use fluid.global_scope().
default is None.
for_test(bool): if program is test program, set True when program is for test, False when program is for train. Default is False.
Return:
fluid.Program: user can finetune this quantization program to enhance the accuracy.
"""
scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict"
assert 'weight_quantize_type' in config.keys(
), 'weight_quantize_type must be configured'
assert 'activation_quantize_type' in config.keys(
), 'activation_quantize_type must be configured'
if config is None:
config = _quant_config_default
else:
assert isinstance(config, dict), "config must be dict"
config = _parse_configs(config)
_logger.info("quant_aware config {}".format(config))
config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass_ops = []
......@@ -197,7 +230,10 @@ def quant_post(executor,
batch_nums=None,
scope=None,
algo='KL',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]):
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
"""
The function utilizes post training quantization method to quantize the
fp32 model. It uses calibrate data to calculate the scale factor of
......@@ -232,6 +268,11 @@ def quant_post(executor,
quantizable_op_type(list[str], optional): The list of op types
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type.
If False, only apply quantization to the input quantizable_op_type. Default is False.
is_use_cache_file(bool): If False, all temp data will be saved in memory. If True,
all temp data will be saved to disk. Defalut is False.
cache_dir(str): When 'is_use_cache_file' is True, temp data will be save in 'cache_dir'. Default is './temp_post_training'.
Returns:
None
"""
......@@ -246,41 +287,64 @@ def quant_post(executor,
scope=scope,
algo=algo,
quantizable_op_type=quantizable_op_type,
is_full_quantize=False)
is_full_quantize=is_full_quantize,
is_use_cache_file=is_use_cache_file,
cache_dir=cache_dir)
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(quantize_model_path)
def convert(program, place, config, scope=None, save_int8=False):
def convert(program, place, config=None, scope=None, save_int8=False):
"""
add quantization ops in program. the program returned is not trainable.
change quantization ops order in program. return program that can used by Paddle-Lite.
Args:
program(fluid.Program): program
scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope()
place(fluid.CPUPlace or fluid.CUDAPlace): place
config(dict): configs for quantization, default values are in quant_config_default dict.
save_int8: is export int8 freezed program.
program(fluid.Program): program that returned by quant_aware
place(fluid.CPUPlace or fluid.CUDAPlace): CPU or CUDA device
scope(fluid.Scope, optional): the scope to store var, it should be program's scope. if None, will use fluid.global_scope().
default is None.
config(dict, optional): configs for convert. if set None, will use default config. Default is None.\
It must be same with config that used in 'quant_aware'.
save_int8: if return int8 freezed program. Int8 program can only be used to check size of model weights. \
It cannot be used in Fluid or Paddle-Lite.
Return:
fluid.Program: freezed program which can be used for inference.
freezed_program(fluid.Program): freezed program which can be used for inference.
parameters is float32 type, but it's value in int8 range.
fluid.Program: freezed int8 program which can be used for inference.
if save_int8 is False, this value is None.
freezed_program_int8(fluid.Program): freezed int8 program.
when save_int8 is False, return freezed_program.
when save_int8 is True, return freezed_program and freezed_program_int8
"""
scope = fluid.global_scope() if not scope else scope
if config is None:
config = _quant_config_default
else:
assert isinstance(config, dict), "config must be dict"
config = _parse_configs(config)
_logger.info("convert config {}".format(config))
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
support_op_types = []
for op in config['quantize_op_types']:
if op in QuantizationFreezePass._supported_quantizable_op_type:
support_op_types.append(op)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type=config['weight_quantize_type'])
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
weight_quantize_type=config['weight_quantize_type'],
quantizable_op_type=support_op_types)
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
if save_int8:
convert_int8_pass = ConvertToInt8Pass(
scope=fluid.global_scope(), place=place)
scope=fluid.global_scope(),
place=place,
quantizable_op_type=support_op_types)
convert_int8_pass.apply(test_graph)
freezed_program_int8 = test_graph.to_program()
return freezed_program, freezed_program_int8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册