提交 35d10239 编写于 作者: I itminner

fix readme

上级 11de67b9
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
本示例介绍如何使用在线量化接口,来对训练好的分类模型进行量化, 可以减少模型的存储空间和显存占用。 本示例介绍如何使用在线量化接口,来对训练好的分类模型进行量化, 可以减少模型的存储空间和显存占用。
## 接口介绍 ## 接口介绍
``` ```
quant_config_default = { quant_config_default = {
'weight_quantize_type': 'abs_max', 'weight_quantize_type': 'abs_max',
...@@ -25,6 +26,7 @@ quant_config_default = { ...@@ -25,6 +26,7 @@ quant_config_default = {
'quant_weight_only': False 'quant_weight_only': False
} }
``` ```
量化配置表。 量化配置表。
参数说明: 参数说明:
- weight_quantize_type(str): 参数量化方式。可选'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max',默认'abs_max'。 - weight_quantize_type(str): 参数量化方式。可选'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max',默认'abs_max'。
...@@ -32,11 +34,12 @@ quant_config_default = { ...@@ -32,11 +34,12 @@ quant_config_default = {
- weight_bits(int): 参数量化bit数,默认8。 - weight_bits(int): 参数量化bit数,默认8。
- activation_bits(int): 激活量化bit数,默认8。 - activation_bits(int): 激活量化bit数,默认8。
- not_quant_pattern(str or str list): 所有name_scope包含not_quant_pattern字符串的op,都不量化。 - not_quant_pattern(str or str list): 所有name_scope包含not_quant_pattern字符串的op,都不量化。
- quantize_op_types(str of list): 需要进行量化的op类型。 - quantize_op_types(str of list): 需要进行量化的op类型,目前支持'conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'
- dtype(int8): 量化后的参数类型,默认int8。 - dtype(int8): 量化后的参数类型,默认int8。
- window_size(int): 'range_abs_max'量化的window size,默认10000。 - window_size(int): 'range_abs_max'量化的window size,默认10000。
- moving_rate(int): moving_average_abs_max 量化的衰减系数,默认 0.9。 - moving_rate(int): moving_average_abs_max 量化的衰减系数,默认 0.9。
- quant_weight_only(bool): 是否只量化参数,如果设为True,则激活不进行量化,默认False。 - quant_weight_only(bool): 是否只量化参数,如果设为True,则激活不进行量化,默认False。
``` ```
def quant_aware(program, def quant_aware(program,
place, place,
...@@ -44,10 +47,11 @@ def quant_aware(program, ...@@ -44,10 +47,11 @@ def quant_aware(program,
scope=None, scope=None,
for_test=False) for_test=False)
``` ```
该接口会对传入的program插入可训练量化op。 该接口会对传入的program插入可训练量化op。
参数介绍: 参数介绍:
- program (fluid.program): 传入训练或测试program。 - program (fluid.program): 传入训练或测试program。
- place(fluid.CPUPlace or fluid.CUDAPlace(N): 该参数表示Executor执行所在的设备,这里的N为GPU对应的ID - place(fluid.CPUPlace or fluid.CUDAPlace): 该参数表示Executor执行所在的设备
- config(dict): 量化配置表。 - config(dict): 量化配置表。
- scope(fluid.Scope): 传入用于存储var的scope,需要传入program所使用的scope,一般情况下,是fluid.global_scope()。 - scope(fluid.Scope): 传入用于存储var的scope,需要传入program所使用的scope,一般情况下,是fluid.global_scope()。
- for_test(bool): 如果program参数是一个测试用program,for_test应设为True,否则设为False。 - for_test(bool): 如果program参数是一个测试用program,for_test应设为True,否则设为False。
...@@ -63,14 +67,15 @@ def convert(program, ...@@ -63,14 +67,15 @@ def convert(program,
scope=None, scope=None,
save_int8=False) save_int8=False)
``` ```
把训练好的量化program,转换为可用于保存inference model的program。 把训练好的量化program,转换为可用于保存inference model的program。
注意,本接口返回的program,不可用于训练。 注意,本接口返回的program,不可用于训练。
参数介绍: 参数介绍:
- program (fluid.program): 传入测试program。 - program (fluid.program): 传入测试program。
- place(fluid.CPUPlace or fluid.CUDAPlace(N): 该参数表示Executor执行所在的设备,这里的N为GPU对应的ID - place(fluid.CPUPlace or fluid.CUDAPlace): 该参数表示Executor执行所在的设备
- config(dict): 量化配置表。 - config(dict): 量化配置表。
- -scope(fluid.Scope): 传入用于存储var的scope,需要传入program所使用的scope,一般情况下,是fluid.global_scope()。 - scope(fluid.Scope): 传入用于存储var的scope,需要传入program所使用的scope,一般情况下,是fluid.global_scope()。
- save_int8(bool): 是否需要导出参数为int8的program。 - save_int8(bool): 是否需要导出参数为int8的program。(该功能目前只能用于确认模型大小)
返回参数: 返回参数:
- program (fluid.program): freezed program,可用于保存inference model,参数为float32类型,但其数值范围可用int8表示。 - program (fluid.program): freezed program,可用于保存inference model,参数为float32类型,但其数值范围可用int8表示。
...@@ -104,11 +109,10 @@ val_program = quant_aware(val_program, place, quant_config, scope=None, for_test ...@@ -104,11 +109,10 @@ val_program = quant_aware(val_program, place, quant_config, scope=None, for_test
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False) compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
``` ```
###3.关掉指定build策略 ### 3.关掉指定build策略
``` ```
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
...@@ -117,7 +121,9 @@ compiled_train_prog = compiled_train_prog.with_data_parallel( ...@@ -117,7 +121,9 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
``` ```
###4. freeze program
### 4. freeze program
``` ```
float_program, int8_program = convert(val_program, float_program, int8_program = convert(val_program,
place, place,
...@@ -125,7 +131,9 @@ float_program, int8_program = convert(val_program, ...@@ -125,7 +131,9 @@ float_program, int8_program = convert(val_program,
scope=None, scope=None,
save_int8=True) save_int8=True)
``` ```
###5.保存预测模型
### 5.保存预测模型
``` ```
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname=float_path, dirname=float_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册