Unverified commit 604c0bbe, authored by ceci3, committed by GitHub

optimization yaml (#1157)

* optimization yaml

* update yaml

* update yaml

* fix unittest

* fix unittest
Parent 769c28f5
@@ -8,27 +8,13 @@ Global:
   params_filename: model.pdiparams
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_concat_15.tmp_0
+  node:
   - concat_15.tmp_0
-  - teacher_concat_14.tmp_0
   - concat_14.tmp_0
-  merge_feed: true
-  teacher_model_dir: ./ppyoloe_crn_l_300e_coco/
-  teacher_model_filename: model.pdmodel
-  teacher_params_filename: model.pdiparams
 Quantization:
   use_pact: true
-  activation_bits: 8
-  weight_bits: 8
   activation_quantize_type: 'range_abs_max'
-  weight_quantize_type: 'channel_wise_abs_max'
-  is_full_quantize: false
-  not_quant_pattern:
-  - skip_quant
   quantize_op_types:
   - conv2d
   - depthwise_conv2d
@@ -37,7 +23,8 @@ TrainConfig:
   train_iter: 3000
   eval_iter: 1000
   learning_rate: 0.00001
-  optimizer: SGD
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 4.0e-05
@@ -7,33 +7,25 @@ Global:
   params_filename: model.pdiparams
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_conv2d_441.tmp_0
+  alpha: 1.0
+  loss: l2
+  node:
   - conv2d_441.tmp_0
-  merge_feed: true
-  teacher_model_dir: ./tinypose_128x96/
-  teacher_model_filename: model.pdmodel
-  teacher_params_filename: model.pdiparams
 Quantization:
-  activation_bits: 8
-  is_full_quantize: false
   activation_quantize_type: 'range_abs_max'
   weight_quantize_type: 'abs_max'
-  not_quant_pattern:
-  - skip_quant
   quantize_op_types:
   - conv2d
   - depthwise_conv2d
-  weight_bits: 8
 TrainConfig:
   epochs: 1
   eval_iter: 1000
   learning_rate: 0.0001
-  optimizer: SGD
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 4.0e-05
   #origin_metric: 0.291
@@ -7,28 +7,15 @@ Global:
   params_filename: model.pdiparams
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_conv2d_84.tmp_0
+  alpha: 1.0
+  loss: l2
+  node:
   - conv2d_84.tmp_0
-  - teacher_conv2d_85.tmp_0
   - conv2d_85.tmp_0
-  - teacher_conv2d_86.tmp_0
   - conv2d_86.tmp_0
-  merge_feed: true
-  teacher_model_dir: ./yolov3_mobilenet_v1_270e_coco/
-  teacher_model_filename: model.pdmodel
-  teacher_params_filename: model.pdiparams
 Quantization:
-  activation_bits: 8
-  weight_bits: 8
   activation_quantize_type: 'range_abs_max'
-  weight_quantize_type: 'channel_wise_abs_max'
-  is_full_quantize: false
-  not_quant_pattern:
-  - skip_quant
   quantize_op_types:
   - conv2d
   - depthwise_conv2d
@@ -37,8 +24,9 @@ TrainConfig:
   train_iter: 3000
   eval_iter: 1000
   learning_rate: 0.0001
-  optimizer: SGD
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 4.0e-05
   #origin_metric: 0.289
@@ -9,29 +9,16 @@ Global:
   params_filename: model.pdiparams
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_conv2d_106.tmp_1
+  alpha: 1.0
+  loss: l2
+  node:
   - conv2d_106.tmp_1
-  - teacher_conv2d_113.tmp_1
   - conv2d_113.tmp_1
-  - teacher_conv2d_119.tmp_1
   - conv2d_119.tmp_1
-  merge_feed: true
-  teacher_model_dir: ./yolov5s_infer/
-  teacher_model_filename: model.pdmodel
-  teacher_params_filename: model.pdiparams
 Quantization:
   use_pact: true
-  activation_bits: 8
-  weight_bits: 8
   activation_quantize_type: 'range_abs_max'
-  weight_quantize_type: 'channel_wise_abs_max'
-  is_full_quantize: false
-  not_quant_pattern:
-  - skip_quant
   quantize_op_types:
   - conv2d
   - depthwise_conv2d
@@ -40,7 +27,8 @@ TrainConfig:
   train_iter: 3000
   eval_iter: 1000
   learning_rate: 0.00001
-  optimizer: SGD
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 4.0e-05
   target_metric: 0.365
@@ -13,28 +13,30 @@ Quantization:
   weight_bits: 8                                # number of bits for weight quantization
   activation_quantize_type: 'range_abs_max'     # activation quantization method
   weight_quantize_type: 'channel_wise_abs_max'  # weight quantization method
-  is_full_quantize: false                       # whether to quantize all supported ops
   not_quant_pattern: [skip_quant]               # name_scope pattern of layers to skip (keep the default)
   quantize_op_types: [conv2d, depthwise_conv2d] # list of ops to quantize
+  dtype: 'int8'                                 # data type after quantization; default int8, currently only int8 is supported
+  window_size: 10000                            # window size of the 'range_abs_max' method; default 10000
+  moving_rate: 0.9                              # decay coefficient of the 'moving_average_abs_max' method; default 0.9
+  for_tensorrt: false                           # whether the quantized model will run with TensorRT; if true, the quantized op types are TENSORRT_OP_TYPES; default false
+  is_full_quantize: false                       # whether to quantize all supported ops
 ```
 #### Configuring a custom distillation strategy
-The distillation parameters mainly set the distillation nodes (`distill_node_pair`) and the path of the teacher inference model, as shown below:
+The distillation parameters mainly set the distillation nodes (`node`) and the path of the teacher inference model, as shown below:
 ```yaml
 Distillation:
-  # distill_lambda: weight of each distillation loss; multiple values are supported so that different nodes can use different lambda values
-  distill_lambda: 1.0
-  # distill_loss: distillation loss; multiple losses are supported so that different nodes can use different losses
-  distill_loss: l2_loss
-  # distill_node_pair: distillation nodes, i.e. names of the output variables of the chosen layers; must contain the teacher node and the matching student node,
-  # where the teacher node name automatically gets the "teacher_" prefix in the program;
-  # multiple node pairs are supported for multi-node distillation
-  distill_node_pair:
-  - teacher_relu_30.tmp_0
-  - relu_30.tmp_0
-  # merge_feed: true if the teacher and the student share the same inputs, false otherwise
-  merge_feed: true
+  # alpha: weight of each distillation loss; multiple values are supported so that different nodes can use different alpha values
+  alpha: 1.0
+  # loss: distillation loss; multiple losses are supported so that different nodes can use different losses
+  loss: l2
+  # node: distillation nodes, i.e. names of the output variables of the chosen layers. Two options:
+  # 1. for self-distillation, the list only needs to contain student nodes; multi-node distillation is supported;
+  # 2. for other kinds of distillation, the list must contain the teacher nodes and the matching student nodes,
+  #    where every two entries form a pair belonging to the teacher model and the student model; multi-node distillation is supported
+  node:
+  - relu_30.tmp_0
   # teacher_model_dir: directory holding the teacher inference model and parameter files
   teacher_model_dir: ./inference_model
   # teacher_model_filename: inference model file, *.pdmodel or __model__
@@ -43,16 +45,14 @@ Distillation:
   teacher_params_filename: model.pdiparams
 ```
-- Currently supported distillation losses: fsp_loss, l2_loss, soft_label_loss; custom losses can also be added. See the [knowledge distillation API documentation](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/dist/single_distiller_api.html) for definitions and usage.
+- Currently supported distillation losses: fsp, l2, soft_label; custom losses can also be added. See the [knowledge distillation API documentation](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/dist/single_distiller_api.html) for definitions and usage.
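To make the two accepted shapes of `node` concrete, here is an illustrative sketch; `relu_30.tmp_0` comes from the example above, while `relu_22.tmp_0` is only a placeholder for a second student node:

```yaml
# teacher-student distillation: entries are consumed in (teacher, student) pairs
Distillation:
  node:
  - teacher_relu_30.tmp_0
  - relu_30.tmp_0

# self-distillation: student nodes only; the "teacher_"-prefixed pairs are derived automatically
Distillation:
  node:
  - relu_30.tmp_0
  - relu_22.tmp_0   # placeholder for an additional node
```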
 #### Configuring a custom structured (channel) pruning strategy
 The structured pruning parameters are set as follows:
 ```yaml
-Prune:
-  # prune_algo: pruning algorithm
-  prune_algo: prune
+ChannelPrune:
   # pruned_ratio: pruning ratio
   pruned_ratio: 0.25
   # prune_params_name: names of the parameters to prune
@@ -61,9 +61,27 @@ Prune:
   # criterion: criterion used to estimate the importance of the channels inside a conv layer
   criterion: l1_norm
 ```
-- prune_algo currently supports: prune, asp and transformer_pruner.
 - criterion currently supports: l1_norm, bn_scale, geometry_median. See the [structured pruning API documentation](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/prune/prune_api.html) for definitions and usage.
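For reference, a filled-in `ChannelPrune` block under the new schema might look like the sketch below; the parameter names are placeholders for whichever convolution weights you actually want to prune:

```yaml
ChannelPrune:
  pruned_ratio: 0.25
  prune_params_name:
  - conv2d_1.w_0      # placeholder parameter name
  - conv2d_2.w_0      # placeholder parameter name
  criterion: l1_norm
```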
+#### Configuring a custom ASP semi-structured sparsity strategy
+The semi-structured sparsity parameters are set as follows:
+```yaml
+ASPPrune:
+  # prune_params_name: names of the parameters to prune
+  prune_params_name:
+  - conv1_weights
+```
+#### Configuring a custom structured pruning strategy for Transformer structures
+The structured pruning parameters for Transformer structures are set as follows:
+```yaml
+TransformerPrune:
+  # pruned_ratio: pruning ratio of each fully-connected layer
+  pruned_ratio: 0.25
+```
 #### Configuring a custom unstructured sparsity strategy
 The unstructured sparsity parameters are set as follows:
@@ -73,8 +91,8 @@ UnstructurePrune:
   prune_strategy: gmp
   # prune_mode: sparsity mode, either 'ratio' or 'threshold'
   prune_mode: ratio
-  # pruned_ratio: sparsity ratio, only effective when prune_mode=='ratio'
-  pruned_ratio: 0.75
+  # ratio: sparsity ratio, only effective when prune_mode=='ratio'
+  ratio: 0.75
   # threshold: sparsity threshold, only effective when prune_mode=='threshold'
   threshold: 0.001
   # gmp_config: extra training hyper-parameters that guide the GMP training process
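The `gmp_config` entry above takes the hyper-parameters of the gradual magnitude pruning (GMP) schedule. A minimal sketch using only the keys that appear in the pp_humanseg example later in this change (other GMP keys may exist):

```yaml
UnstructurePrune:
  prune_strategy: gmp
  prune_mode: ratio
  ratio: 0.75
  gmp_config:
    stable_iterations: 0       # iterations trained before pruning starts
    pruning_iterations: 4500   # iterations over which the sparsity is ramped up
    initial_ratio: 0.15        # sparsity ratio at the start of the ramp
```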
@@ -112,9 +130,11 @@ TrainConfig:
   epochs: 14
   eval_iter: 400
   learning_rate: 5.0e-03
-  optimizer: SGD
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 0.0005
 ```
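The detection examples earlier in this change set `train_iter` instead of `epochs`; under the new schema only one of the two needs to be given. An illustrative sketch with placeholder values:

```yaml
TrainConfig:
  train_iter: 3000     # number of batches to train on, used instead of epochs
  eval_iter: 1000
  learning_rate: 0.0001
  optimizer_builder:
    optimizer:
      type: SGD
    weight_decay: 4.0e-05
```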
 - Learning-rate decay strategy: this mainly sets the scheduler class name and its parameters, as shown below. Several decay strategies are already implemented in Paddle, see the [lr documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.2/api/paddle/optimizer/lr/LRScheduler_cn.html); the strategy parameters are the init arguments of that class.
 ```yaml
......
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_softmax_0.tmp_0
+  alpha: 1.0
+  loss: l2
+  node:
   - softmax_0.tmp_0
-  merge_feed: true
   teacher_model_dir: MobileNetV1_infer
   teacher_model_filename: inference.pdmodel
   teacher_params_filename: inference.pdiparams
@@ -23,7 +21,8 @@ TrainConfig:
   epochs: 1
   eval_iter: 500
   learning_rate: 0.004
-  optimizer: Momentum
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: Momentum
     weight_decay: 0.00003
   origin_metric: 0.70898
\ No newline at end of file
@@ -115,9 +115,10 @@ TrainConfig:
   epochs: 6
   eval_iter: 1070
   learning_rate: 2.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
   origin_metric: 0.7403
 ```
......
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 6
   eval_iter: 1070
   learning_rate: 2.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
   origin_metric: 0.7403
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 100
   eval_iter: 70
   learning_rate: 1.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
   origin_metric: 0.8421
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 6
   eval_iter: 2000
   learning_rate: 3.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
-  origin_metric: 0.8098
\ No newline at end of file
+  origin_metric: 0.8098
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 16
   eval_iter: 1000
   learning_rate: 1.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
   origin_metric: 0.7736
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 12
   eval_iter: 750
   learning_rate: 2.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
   origin_metric: 0.6021
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 20
   eval_iter: 1050
   learning_rate: 3.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
-  origin_metric: 0.7620
\ No newline at end of file
+  origin_metric: 0.7620
@@ -2,7 +2,8 @@ TrainConfig:
   epochs: 6
   eval_iter: 1110
   learning_rate: 2.0e-5
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: AdamW
     weight_decay: 0.01
-  optimizer: AdamW
-  origin_metric: 0.5666
\ No newline at end of file
+  origin_metric: 0.5666
@@ -243,8 +243,8 @@ if __name__ == '__main__':
     paddle.enable_static()
     compress_config, train_config, _ = load_config(args.config_path)
-    if train_config is not None and 'optim_args' in train_config:
-        train_config['optim_args'][
+    if train_config is not None:
+        train_config.optimizer_builder[
             'apply_decay_param_fun'] = apply_decay_param_fun
     train_dataloader, eval_dataloader = reader()
......
@@ -5,7 +5,8 @@ TrainConfig:
   epochs: 14
   eval_iter: 400
   learning_rate: 5.0e-03
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 0.0005
-  optimizer: SGD
@@ -2,28 +2,21 @@ Global:
   reader_config: configs/pp_humanseg_lite.yaml
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_batch_norm_47.tmp_2
+  alpha: 1.0
+  loss: l2
+  node:
   - batch_norm_47.tmp_2
-  merge_feed: true
-  teacher_model_dir: ./ppseg_lite_portrait_398x224_with_softmax
-  teacher_model_filename: model.pdmodel
-  teacher_params_filename: model.pdiparams
 Quantization:
-  activation_bits: 8
-  is_full_quantize: false
-  not_quant_pattern:
-  - skip_quant
   quantize_op_types:
   - conv2d
   - depthwise_conv2d
-  weight_bits: 8
 TrainConfig:
   epochs: 1
   eval_iter: 400
   learning_rate: 0.0005
-  optimizer: SGD
-  optim_args:
-    weight_decay: 4.0e-05
+  optimizer_builder:
+    optimizer:
+      type: SGD
+    weight_decay: 0.0005
@@ -2,19 +2,15 @@ Global:
   reader_config: configs/pp_humanseg_lite.yaml
 Distillation:
-  distill_lambda: 1.0
-  distill_loss: l2_loss
-  distill_node_pair:
-  - teacher_batch_norm_47.tmp_2
+  alpha: 1.0
+  loss: l2
+  node:
   - batch_norm_47.tmp_2
-  merge_feed: true
-  teacher_model_dir: ./ppseg_lite_portrait_398x224_with_softmax
-  teacher_model_filename: model.pdmodel
-  teacher_params_filename: model.pdiparams
 UnstructurePrune:
   prune_strategy: gmp
   prune_mode: ratio
-  pruned_ratio: 0.75
+  ratio: 0.75
   gmp_config:
     stable_iterations: 0
     pruning_iterations: 4500
@@ -24,6 +20,7 @@ UnstructurePrune:
     initial_ratio: 0.15
   prune_params_type: conv1x1_only
   local_sparsity: True
 TrainConfig:
   epochs: 14
   eval_iter: 400
@@ -31,7 +28,7 @@ TrainConfig:
     type: PiecewiseDecay
     boundaries: [4500]
     values: [0.005, 0.0005]
-  optim_args:
+  optimizer_builder:
+    optimizer:
+      type: SGD
     weight_decay: 0.0005
-  optimizer: SGD
@@ -13,7 +13,7 @@ import numpy as np
 sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
 import models
 from utility import add_arguments, print_arguments, _download, _decompress
-from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
+from paddleslim.dist import merge, l2, soft_label, fsp
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
@@ -173,8 +173,8 @@ def compress(args):
     merge(teacher_program, student_program, data_name_map, place)
     with paddle.static.program_guard(student_program, s_startup):
-        distill_loss = soft_label_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0",
+        distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                        student_program)
         loss = avg_cost + distill_loss
         lr, opt = create_optimizer(args)
         opt.minimize(loss)
......
@@ -89,7 +89,7 @@ In order to ensure that the data of the teacher network and the student network
 data_name_map = {'image': 'image'}
 main = slim.dist.merge(teacher_program, student_program, data_name_map, fluid.CPUPlace())
 with fluid.program_guard(student_program, student_startup):
-    l2_loss = slim.dist.l2_loss('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)
+    l2_loss = slim.dist.l2('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)
     loss = l2_loss + avg_cost
     opt = fluid.optimizer.Momentum(0.01, 0.9)
     opt.minimize(loss)
......
@@ -27,10 +27,12 @@ AutoCompression
   Currently the keys only support the following strategy combinations or single strategies:
   1) ``Quantization`` & ``HyperParameterOptimization``: post-training quantization with hyper-parameter search;
   2) ``Quantization`` & ``Distillation``: quantization-aware training with distillation;
-  3) ``Prune`` & ``Distillation``: structured pruning with distillation;
-  4) ``UnstructurePrune`` & ``Distillation``: unstructured sparsity with distillation;
-  5) ``Distillation``: single-teacher distillation only;
-  6) ``MultiTeacherDistillation``: multi-teacher distillation.
+  3) ``ChannelPrune`` & ``Distillation``: structured (channel) pruning with distillation;
+  4) ``ASPPrune`` & ``Distillation``: ASP semi-structured pruning with distillation;
+  5) ``TransformerPrune`` & ``Distillation``: structured pruning of Transformer structures with distillation;
+  6) ``UnstructurePrune`` & ``Distillation``: unstructured sparsity with distillation;
+  7) ``Distillation``: single-teacher distillation only;
+  8) ``MultiTeacherDistillation``: multi-teacher distillation.
   If set to None, a strategy is chosen automatically. Default: None.
 - **eval_callback(function, optional)** - evaluation callback used to monitor how the model trains; see `<//github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/auto-compression/custom_function.rst>`_ for how to write it. ``eval_callback`` and ``eval_dataloader`` cannot both be None. Default: None.
 - **eval_dataloader(paddle.io.Dataloader, optional)** - if a test data loader is passed in, the ``EMD`` distance is used to judge the difference between the model before and after compression; currently only post-training quantization hyper-parameter search judges the compression this way.
@@ -62,11 +64,11 @@ AutoCompression
         default_distill_config = {
-            "distill_loss": args.distill_loss,
-            "distill_node_pair": args.distill_node_pair,
-            "distill_lambda": args.distill_lambda,
+            "loss": args.loss,
+            "node": args.node,
+            "alpha": args.alpha,
             "teacher_model_dir": args.teacher_model_dir,
@@ -84,7 +86,7 @@ AutoCompression
             strategy_config={"Quantization": Quantization(**default_ptq_config),
-            "HyperParameterOptimization": HyperParameterOptimization(**default_hpo_config)}, \
+            "Distillation": Distillation(**default_distill_config)}, \
             train_config=None, train_dataloader=train_dataloader, eval_callback=eval_dataloader, devices='gpu')
@@ -104,12 +106,14 @@ TrainConfig
 **Parameters:**
 - **epochs(int)** - number of training epochs, i.e. how many passes over the dataset.
-- **learning_rate(float|LRScheduler)** - learning rate used during optimization.
-- **optimizer(str)** - optimizer to use; must be the name of an optimizer in ``paddle.optimizer``, e.g. ``SGD``.
-- **optim_args(dict)** - optimizer arguments. The following can be specified:
-  ``grad_clip``: the gradient-clipping method, which must be the name of a gradient-clipping class in ``paddle.nn``, e.g. ``ClipGradByValue``;
-  ``grad_clip_args``: the arguments of the chosen gradient-clipping method, e.g. for ``ClipGradByValue`` they can be ``max`` and ``min``, see `ClipGradByValue <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/ClipGradByValue_cn.html#clipgradbyvalue>`_;
-  other arguments the optimizer may need, e.g. ``beta1``, ``beta2``, ``apply_decay_param_fun``, see `AdamW <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/optimizer/AdamW_cn.html#adamw>`_.
+- **train_iter(int, optional)** - number of training iterations, i.e. how many batches to iterate over; only one of ``train_iter`` and ``epochs`` needs to be set.
+- **learning_rate(float|dict)** - learning rate used during optimization. If a dict, its keys are: ``type``: the class name of the learning-rate scheduler, see the classes in ``paddle.optimizer.lr``; the remaining keys follow the parameters of the chosen scheduler.
+- **optimizer_builder(dict)** - the optimizer and its related configuration (see the YAML sketch after this parameter list). Its keys are:
+  ``optimizer(dict)``: the key ``type`` must be the class name of an optimizer in ``paddle.optimizer``, e.g. ``SGD``; the remaining keys follow the parameters of the chosen optimizer.
+  ``weight_decay(float, optional)``: weight decay applied during compression training.
+  ``regularizer(dict)``: the key ``type`` must be the class name of a weight-decay regularizer in ``paddle.regularizer``; the remaining keys follow the parameters of the chosen class.
+  ``grad_clip(dict)``: the gradient-clipping method; the key ``type`` must be the name of a gradient-clipping class in ``paddle.nn``, e.g. ``ClipGradByValue``; the remaining keys follow the parameters of the chosen class.
 - **eval_iter(int)** - run an evaluation every this many training batches.
 - **logging_iter(int)** - print logs every this many training batches.
@@ -124,7 +128,7 @@ TrainConfig
   see the `amp_config <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#amp_configs>`_ interface for the corresponding settings.
 - **recompute_config(dict, optional)** - recompute memory optimization can be used when the fleet API is used; configure it as described in the fleet `recompute_configs <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#recompute_configs>`_ interface.
 - **sharding_config(dict, optional)** - the sharding strategy can be used when the fleet API is used; configure it as described in the fleet `sharding_configs <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#sharding_configs>`_ interface.
+- **sparse_model(bool, optional)** - set ``sparse_model`` to True to remove the redundant mask tensors from a model produced by unstructured sparsity. Default: False.
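To tie the ``learning_rate`` and ``optimizer_builder`` descriptions above together (as referenced in the parameter list), a TrainConfig block under the new schema could look like the following sketch; all values are illustrative only:

.. code-block:: yaml

    TrainConfig:
      epochs: 14
      eval_iter: 400
      learning_rate:
        type: PiecewiseDecay        # any scheduler class from paddle.optimizer.lr
        boundaries: [4500]
        values: [0.005, 0.0005]
      optimizer_builder:
        optimizer:
          type: SGD                 # any optimizer class from paddle.optimizer
        weight_decay: 0.0005
        grad_clip:
          type: ClipGradByValue     # any gradient-clipping class from paddle.nn
          max: 1.0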
 Quantization
 ----------
@@ -147,14 +151,13 @@ Distillation
 **Parameters:**
-- **distill_loss(str|list[str])** - name of the distillation loss; any distillation loss supported by PaddleSlim can be set: ``fsp_loss``, ``l2_loss``, ``soft_label_loss``. If you need another loss, you can add its computation to the `distillation loss file <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py>`_ for now, or open an issue and we will help.
-- **distill_node_pair(list[str])** - list of distillation node names; every two entries form a pair belonging to the teacher model and the student model.
-- **distill_lambda(float|list[float])** - weight of each distillation loss; its length must match the length of ``distill_loss``.
+- **loss(str|list[str])** - name of the distillation loss; any distillation loss supported by PaddleSlim can be set: ``fsp``, ``l2``, ``soft_label``. If you need another loss, you can add its computation to the `distillation loss file <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py>`_ for now, or open an issue and we will help.
+- **node(list[str])** - list of distillation node names. Two options: 1. for self-distillation, the list only needs to contain student nodes (multi-node distillation is supported); 2. for other kinds of distillation, the list must contain the teacher nodes and the matching student nodes, where every two entries form a pair belonging to the teacher model and the student model.
+- **alpha(float|list[float])** - weight of each distillation loss; its length must match the length of ``loss``.
 - **teacher_model_dir(str)** - directory of the teacher model.
 - **teacher_model_filename(str)** - model file name of the teacher model.
 - **teacher_params_filename(str)** - parameter file name of the teacher model.
-- **merge_feed(bool)** - whether the distillation shares the same input data. Default: ``True``.
 MultiTeacherDistillation
@@ -164,14 +167,13 @@ MultiTeacherDistillation
 **Parameters:**
-- **distill_loss(list[str])** - names of the distillation losses; any distillation loss supported by PaddleSlim can be set: ``fsp_loss``, ``l2_loss``, ``soft_label_loss``. If you need another loss, you can add its computation to the `distillation loss file <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py>`_ for now, or open an issue and we will help.
-- **distill_node_pair(list[list[str]])** - nested list of distillation node names; the number of teacher models must equal the length of the outer list. Each inner list gives the distillation nodes between one teacher model and the student model, where every two entries form a pair belonging to the teacher model and the student model.
-- **distill_lambda(list[float])** - weight of each distillation loss; its length must match the length of ``distill_loss``.
+- **loss(list[str])** - names of the distillation losses; any distillation loss supported by PaddleSlim can be set: ``fsp``, ``l2``, ``soft_label``. If you need another loss, you can add its computation to the `distillation loss file <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py>`_ for now, or open an issue and we will help.
+- **node(list[list[str]])** - nested list of distillation node names; the number of teacher models must equal the length of the outer list. Each inner list gives the distillation nodes between one teacher model and the student model, where every two entries form a pair belonging to the teacher model and the student model.
+- **alpha(list[float])** - weight of each distillation loss; its length must match the length of ``loss``.
 - **teacher_model_dir(list[str])** - list of teacher model directories.
 - **teacher_model_filename(list[str])** - list of teacher model file names.
 - **teacher_params_filename(list[str])** - list of teacher parameter file names.
-- **merge_feed(bool)** - whether the distillation shares the same input data. Default: ``True``.
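For illustration only, a two-teacher configuration could be written as the sketch below; the node names and model paths are placeholders, and the outer ``node`` list holds one inner list per teacher:

.. code-block:: yaml

    MultiTeacherDistillation:
      loss: [l2, l2]
      alpha: [1.0, 1.0]
      node:
      - [teacher_node_0, student_node_0]   # placeholder names
      - [teacher_node_1, student_node_1]   # placeholder names
      teacher_model_dir: [./teacher_0_infer, ./teacher_1_infer]
      teacher_model_filename: [model.pdmodel, model.pdmodel]
      teacher_params_filename: [model.pdiparams, model.pdiparams]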
 HyperParameterOptimization
......
@@ -49,16 +49,16 @@ merge
               data_name_map, place)
-fsp_loss
+fsp
 ---------
-.. py:function:: paddleslim.dist.fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name, student_var2_name, program=None)
+.. py:function:: paddleslim.dist.fsp(teacher_var1_name, teacher_var2_name, student_var1_name, student_var2_name, program=None)
 `[Source] <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L90>`_
-Adds fsp_loss between a teacher var and a student var inside the program.
-fsp_loss comes from the paper `A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning <http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf>`_
+Adds fsp loss between a teacher var and a student var inside the program.
+fsp loss comes from the paper `A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning <http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf>`_
 **Parameters:**
@@ -70,7 +70,7 @@ fsp loss comes from the paper `A Gift from Knowledge Distillation: Fast Optimiz
 **Returns:**
-- (Variable): the fsp_loss obtained from teacher_var1, teacher_var2, student_var1 and student_var2
+- (Variable): the fsp loss obtained from teacher_var1, teacher_var2, student_var1 and student_var2
 **Example:**
@@ -96,15 +96,15 @@ fsp loss comes from the paper `A Gift from Knowledge Distillation: Fast Optimiz
   place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
   dist.merge(teacher_program, student_program, data_name_map, place)
   with fluid.program_guard(student_program):
-      distillation_loss = dist.fsp_loss('teacher_t1.tmp_1', 'teacher_t2.tmp_1',
+      distillation_loss = dist.fsp('teacher_t1.tmp_1', 'teacher_t2.tmp_1',
                                         's1.tmp_1', 's2.tmp_1', student_program)
-l2_loss
+l2
 ------------
-.. py:function:: paddleslim.dist.l2_loss(teacher_var_name, student_var_name, program=None)
+.. py:function:: paddleslim.dist.l2(teacher_var_name, student_var_name, program=None)
 `[Source] <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L118>`_
@@ -144,15 +144,15 @@ l2_loss
   place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
   dist.merge(teacher_program, student_program, data_name_map, place)
   with fluid.program_guard(student_program):
-      distillation_loss = dist.l2_loss('teacher_t2.tmp_1', 's2.tmp_1',
+      distillation_loss = dist.l2('teacher_t2.tmp_1', 's2.tmp_1',
                                        student_program)
-soft_label_loss
+soft_label
 -------------------
-.. py:function:: paddleslim.dist.soft_label_loss(teacher_var_name, student_var_name, program=None, teacher_temperature=1., student_temperature=1.)
+.. py:function:: paddleslim.dist.soft_label(teacher_var_name, student_var_name, program=None, teacher_temperature=1., student_temperature=1.)
 `[Source] <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L136>`_
@@ -170,7 +170,7 @@ soft_label_loss comes from the paper `Distilling the Knowledge in a Neural Netw
 **Returns:**
-- (Variable): the soft_label_loss obtained from teacher_var and student_var
+- (Variable): the soft label loss obtained from teacher_var and student_var
 **Example:**
......
@@ -92,7 +92,7 @@ The merge operation adds all Tensors and Ops of student_program and teacher_program
 data_name_map = {'image': 'image'}
 main = slim.dist.merge(teacher_program, student_program, data_name_map, paddle.CPUPlace())
 with paddle.static.program_guard(student_program, student_startup):
-    l2_loss = slim.dist.l2_loss('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)
+    l2_loss = slim.dist.l2('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)
     loss = l2_loss + avg_cost
     opt = paddle.optimizer.Momentum(0.01, 0.9)
     opt.minimize(loss)
......
@@ -92,7 +92,6 @@ def create_strategy_config(strategy_str, model_type):
             ### default prune config
             default_prune_config = {
                 'pruned_ratio': float(tmp_s[1]),
-                'prune_algo': 'prune',
                 'criterion': 'l1_norm'
             }
         else:
@@ -105,10 +104,12 @@ def create_strategy_config(strategy_str, model_type):
                 'local_sparsity': True,
                 'prune_params_type': 'conv1x1_only'
             }
-        tmp_s[0] = tmp_s[0].replace('prune', 'Prune')
+        if model_type == 'transformer':
+            tmp_s[0] = tmp_s[0].replace('prune', 'TransformerPrune')
+            default_prune_config = {'pruned_ratio': float(tmp_s[1])}
+        else:
+            tmp_s[0] = tmp_s[0].replace('prune', 'Prune')
         tmp_s[0] = tmp_s[0].replace('sparse', 'UnstructurePrune')
-        if model_type == 'transformer' and tmp_s[0] == 'Prune':
-            default_prune_config['prune_algo'] = 'transformer_pruner'
         prune_config = eval(tmp_s[0])(**default_prune_config)
         configs.append({tmp_s[0]: prune_config, 'Distillation': dis_config})
......
@@ -82,15 +82,21 @@ class AutoCompression:
         2. set ``Quantization`` and ``HyperParameterOptimization`` to get quant_post and hyperparameter optimization compress config.
             The Quantization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L24`_ .
             The HyperParameterOptimization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L73`_ .
-        3. set ``Prune`` and ``Distillation`` to get prune and distillation compress config.
-            The Prune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
+        3. set ``ChannelPrune`` and ``Distillation`` to get channel prune and distillation compress config.
+            The ChannelPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
             The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
-        4. set ``UnstructurePrune`` and ``Distillation`` to get unstructured prune and distillation compress config.
+        4. set ``ASPPrune`` and ``Distillation`` to get ASP prune and distillation compress config.
+            The ASPPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
+            The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
+        5. set ``TransformerPrune`` and ``Distillation`` to get transformer prune and distillation compress config.
+            The TransformerPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
+            The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
+        6. set ``UnstructurePrune`` and ``Distillation`` to get unstructured prune and distillation compress config.
             The UnstructurePrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L91`_ .
             The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
-        5. set ``Distillation`` to use one teacher model to distill the student model.
+        7. set ``Distillation`` to use one teacher model to distill the student model.
             The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
-        6. set ``MultiTeacherDistillation`` to use multi-teacher to distill the student model.
+        8. set ``MultiTeacherDistillation`` to use multi-teacher to distill the student model.
             The MultiTeacherDistillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L56`_ .
         If set to None, will choose a strategy automatically. Default: None.
@@ -149,6 +155,8 @@ class AutoCompression:
         self._strategy, self._config = self._prepare_strategy(
             self.strategy_config)
+        #print(self._strategy, self._config[0].__dict__)
+        #sys.exit()
         # If train_config is None, set default train_config
if self.train_config is None: if self.train_config is None:
...@@ -189,14 +197,15 @@ class AutoCompression: ...@@ -189,14 +197,15 @@ class AutoCompression:
for strategy_c in strategy_config: for strategy_c in strategy_config:
quant_config = strategy_c.get("Quantization", None) quant_config = strategy_c.get("Quantization", None)
hpo_config = strategy_c.get("HyperParameterOptimization", None) hpo_config = strategy_c.get("HyperParameterOptimization", None)
prune_config = strategy_c.get("Prune", None) prune_config = strategy_c.get("ChannelPrune", None)
asp_config = strategy_c.get("ASPPrune", None)
transformer_prune_config = strategy_c.get("TransformerPrune", None)
unstructure_prune_config = strategy_c.get("UnstructurePrune", None) unstructure_prune_config = strategy_c.get("UnstructurePrune", None)
single_teacher_distill_config = strategy_c.get("Distillation", None) single_teacher_distill_config = strategy_c.get("Distillation", None)
if single_teacher_distill_config is not None and single_teacher_distill_config.teacher_model_dir is None: if single_teacher_distill_config is not None and single_teacher_distill_config.teacher_model_dir is None:
single_teacher_distill_config = single_teacher_distill_config._replace( single_teacher_distill_config.teacher_model_dir = self.model_dir
teacher_model_dir=self.model_dir, single_teacher_distill_config.teacher_model_filename = self.model_filename
teacher_model_filename=self.model_filename, single_teacher_distill_config.teacher_params_filename = self.params_filename
teacher_params_filename=self.params_filename)
multi_teacher_distill_config = strategy_c.get( multi_teacher_distill_config = strategy_c.get(
"MultiTeacherDistillation", None) "MultiTeacherDistillation", None)
...@@ -219,17 +228,29 @@ class AutoCompression: ...@@ -219,17 +228,29 @@ class AutoCompression:
### case3: prune_config & distill config ### case3: prune_config & distill config
elif prune_config is not None and self._distill_config is not None: elif prune_config is not None and self._distill_config is not None:
strategy.append('prune_dis') strategy.append('channel_prune_dis')
config.append(merge_config(prune_config, self._distill_config)) config.append(merge_config(prune_config, self._distill_config))
### case4: unstructure_config & distill config ### case4: asp_config & distill config
elif asp_config is not None and self._distill_config is not None:
strategy.append('asp_prune_dis')
config.append(merge_config(asp_config, self._distill_config))
### case5: transformer_prune_config & distill config
elif transformer_prune_config is not None and self._distill_config is not None:
strategy.append('transformer_prune_dis')
config.append(
merge_config(transformer_prune_config,
self._distill_config))
### case6: unstructure_config & distill config
elif unstructure_prune_config is not None and self._distill_config is not None: elif unstructure_prune_config is not None and self._distill_config is not None:
strategy.append('unstructure_prune_dis') strategy.append('unstructure_prune_dis')
config.append( config.append(
merge_config(unstructure_prune_config, merge_config(unstructure_prune_config,
self._distill_config)) self._distill_config))
### case4: distill_config ### case7: distill_config
elif self._distill_config is not None: elif self._distill_config is not None:
if single_teacher_distill_config is not None: if single_teacher_distill_config is not None:
strategy.append('single_teacher_dis') strategy.append('single_teacher_dis')
...@@ -272,7 +293,7 @@ class AutoCompression: ...@@ -272,7 +293,7 @@ class AutoCompression:
train_program_info = ProgramInfo(startup_program, train_program, train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, fetch_targets) feed_target_names, fetch_targets)
config_dict = dict(config._asdict()) config_dict = config.__dict__
if "prune_strategy" in config_dict and config_dict[ if "prune_strategy" in config_dict and config_dict[
"prune_strategy"] == "gmp" and config_dict[ "prune_strategy"] == "gmp" and config_dict[
'gmp_config'] is None: 'gmp_config'] is None:
...@@ -313,7 +334,7 @@ class AutoCompression: ...@@ -313,7 +334,7 @@ class AutoCompression:
self._exe, self._exe,
self._places, self._places,
config_dict, config_dict,
self.train_config._asdict(), self.train_config.__dict__,
train_program_info, train_program_info,
pruner=self._pruner, pruner=self._pruner,
dist_strategy=dist_strategy, dist_strategy=dist_strategy,
...@@ -345,7 +366,7 @@ class AutoCompression: ...@@ -345,7 +366,7 @@ class AutoCompression:
train_program_info.optimizer.amp_init( train_program_info.optimizer.amp_init(
self._places, scope=paddle.static.global_scope()) self._places, scope=paddle.static.global_scope())
if 'prune_algo' in config_dict and config_dict['prune_algo'] == 'asp': if 'asp' in strategy:
### prune weight in scope ### prune weight in scope
self._pruner.prune_model(train_program_info.program) self._pruner.prune_model(train_program_info.program)
......
@@ -33,7 +33,7 @@ def load_config(config_path):
     compress_config = {}
     for key, value in cfg.items():
-        default_key = eval(key)(**value)
+        default_key = eval(key)(**value) if value is not None else eval(key)()
         compress_config[key] = default_key
     if compress_config.get('TrainConfig') != None:
......
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
import paddle import paddle
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import * from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node from ..common.recover_program import recover_inference_program, _remove_fetch_node
...@@ -44,37 +45,73 @@ def _create_lr_scheduler(train_config): ...@@ -44,37 +45,73 @@ def _create_lr_scheduler(train_config):
def _create_optimizer(train_config): def _create_optimizer(train_config):
"""create optimizer""" """create optimizer"""
opt = getattr(optimizer, train_config.get('optimizer') or
'SGD') ### default optimizer is SGD if 'optimizer_builder' not in train_config:
if 'optim_args' in train_config: train_config['optimizer_builder'] = {'optimizer': {'type': 'SGD'}}
if train_config[
'optim_args'] is not None and 'grad_clip' in train_config[ optimizer_builder = train_config['optimizer_builder']
'optim_args'] and train_config['optim_args'][
'grad_clip'] is not None: if 'grad_clip' in optimizer_builder:
grad_clip = getattr( g_clip_params = optimizer_builder['grad_clip']
paddle.nn, train_config['optim_args']['grad_clip'])( g_clip_type = g_clip_params.pop('type')
**train_config['optim_args']['grad_clip_args']) grad_clip = getattr(paddle.nn, g_clip_type)(**g_clip_params)
train_config['optim_args'].pop('grad_clip')
train_config['optim_args'].pop('grad_clip_args')
else:
grad_clip = None
if 'grad_clip' in train_config['optim_args'] and train_config[
'optim_args']['grad_clip'] is None:
train_config['optim_args'].pop('grad_clip')
train_config['optim_args'].pop('grad_clip_args')
else: else:
train_config['optim_args'] = {}
grad_clip = None grad_clip = None
### build regularization
if 'regularizer' in optimizer_builder:
reg_params = optimizer_builder['regularizer']
reg_type = reg_params.pop('type')
reg = getattr(regularizer, reg_type)(**reg_params)
elif 'weight_decay' in optimizer_builder:
reg = optimizer_builder.pop('weight_decay')
else:
reg = None
### build learning rate
lr = _create_lr_scheduler(train_config) lr = _create_lr_scheduler(train_config)
op = opt(learning_rate=lr,
grad_clip=grad_clip, ### build optimizer
**train_config['optim_args']) optim_params = optimizer_builder['optimizer']
return op, lr optim_type = optim_params.pop('type')
opt = getattr(optimizer, optim_type)(learning_rate=lr,
grad_clip=grad_clip,
weight_decay=reg,
**optim_params)
return opt, lr
def _get_distill_node(student_program, config):
node = config.get('node')
if len(node) == 0:
return None
### the type of node is list or list(list)
if isinstance(node[0], list):
test_node = node[0][0]
else:
test_node = node[0]
try:
test_var = student_program.global_block().var(test_node)
distill_node_pair = []
if isinstance(node[0], list):
for n_list in node:
tmp_node_pair = []
for n in n_list:
tmp_node_pair.append('teacher_' + n)
tmp_node_pair.append(n)
distill_node_pair.append(tmp_node_pair)
else:
for n in node:
distill_node_pair.append('teacher_' + n)
distill_node_pair.append(n)
return distill_node_pair
except:
return node
def _parse_distill_loss(distill_node_pair, def _parse_distill_loss(distill_node_pair,
distill_loss='l2_loss', distill_loss='l2',
distill_lambda=1.0): distill_lambda=1.0):
"""parse distill loss config""" """parse distill loss config"""
loss_dist = 0.0 loss_dist = 0.0
...@@ -135,9 +172,9 @@ def _load_program_and_merge(executor, ...@@ -135,9 +172,9 @@ def _load_program_and_merge(executor,
data_name_map = {} data_name_map = {}
if 'merge_feed' not in config or config['merge_feed'] == True: merge_feed = (
assert len(feed_target_names) == len(teacher_feed_target_names), \ sorted(feed_target_names) == sorted(teacher_feed_target_names))
"the number of feed nodes in the teacher model is not equal to the student model" if merge_feed == True:
for i, name in enumerate(feed_target_names): for i, name in enumerate(feed_target_names):
data_name_map[teacher_feed_target_names[i]] = name data_name_map[teacher_feed_target_names[i]] = name
...@@ -153,7 +190,7 @@ def _load_program_and_merge(executor, ...@@ -153,7 +190,7 @@ def _load_program_and_merge(executor,
place, place,
teacher_scope=new_scope, teacher_scope=new_scope,
name_prefix=teacher_name_prefix, name_prefix=teacher_name_prefix,
merge_feed=config.get('merge_feed') or True) merge_feed=merge_feed)
if teacher_idx == None or teacher_idx == 1: if teacher_idx == None or teacher_idx == 1:
return train_program, test_program, data_name_map return train_program, test_program, data_name_map
else: else:
...@@ -180,6 +217,9 @@ def build_distill_program(executor, ...@@ -180,6 +217,9 @@ def build_distill_program(executor,
feed_target_names = train_program_info.feed_target_names feed_target_names = train_program_info.feed_target_names
fetch_targets = train_program_info.fetch_targets fetch_targets = train_program_info.fetch_targets
distill_node_pair = _get_distill_node(train_program,
config) or default_distill_node_pair
teacher_model_dir = config[ teacher_model_dir = config[
"teacher_model_dir"] if "teacher_model_dir" in config else config[ "teacher_model_dir"] if "teacher_model_dir" in config else config[
"teacher_model_path_prefix"] "teacher_model_path_prefix"]
...@@ -270,16 +310,15 @@ def build_distill_program(executor, ...@@ -270,16 +310,15 @@ def build_distill_program(executor,
**train_config['amp_config']) **train_config['amp_config'])
distill_loss, losses = _parse_distill_loss( distill_loss, losses = _parse_distill_loss(
config.get('distill_node_pair') or default_distill_node_pair, distill_node_pair,
config.get('distill_loss') or config.get('loss') or 'l2', ### default loss is l2
'l2_loss', ### default loss is l2_loss config.get('alpha') or 1.0) ### default alpha is 1.0
config.get('distill_lambda') or 1.0) ### default lambda is 1.0
loss = paddle.mean(distill_loss) loss = paddle.mean(distill_loss)
loss.stop_gradient = False loss.stop_gradient = False
if 'prune_algo' in config: ### prune & asp if 'prune_params_name' in config: ### prune
if config['prune_algo'] == 'asp' and not train_config.get( if 'pruned_ratio' not in config and not train_config.get(
'use_fleet'): 'use_fleet'): ### asp
optimizer = pruner.decorate(optimizer) optimizer = pruner.decorate(optimizer)
optimizer.minimize(loss) optimizer.minimize(loss)
elif 'prune_strategy' in config: ###unstructure prune elif 'prune_strategy' in config: ###unstructure prune
...@@ -302,11 +341,8 @@ def build_quant_program(executor, place, config, train_program_info, ...@@ -302,11 +341,8 @@ def build_quant_program(executor, place, config, train_program_info,
scope = paddle.static.global_scope() scope = paddle.static.global_scope()
assert isinstance(config, dict), "quant config must be dict" assert isinstance(config, dict), "quant config must be dict"
default_config = _quant_config_default
default_config.update(config)
config = _parse_configs(default_config)
use_pact = config["use_pact"] use_pact = config.pop("use_pact")
if use_pact: if use_pact:
act_preprocess_func = pact act_preprocess_func = pact
optimizer_func = get_pact_optimizer optimizer_func = get_pact_optimizer
...@@ -364,13 +400,13 @@ def build_prune_program(executor, ...@@ -364,13 +400,13 @@ def build_prune_program(executor,
strategy, strategy,
patterns, patterns,
eval_dataloader=None): eval_dataloader=None):
if 'unstructure' in strategy: if strategy.startswith('unstructure'):
from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
if config["prune_strategy"] is None: if config["prune_strategy"] is None:
pruner = UnstructuredPruner( pruner = UnstructuredPruner(
train_program_info.program, train_program_info.program,
mode=config['prune_mode'], mode=config['prune_mode'],
ratio=config['pruned_ratio'], ratio=config['ratio'],
threshold=config['threshold'], threshold=config['threshold'],
prune_params_type=config['prune_params_type'], prune_params_type=config['prune_params_type'],
place=place, place=place,
...@@ -378,69 +414,65 @@ def build_prune_program(executor, ...@@ -378,69 +414,65 @@ def build_prune_program(executor,
elif config["prune_strategy"] == "gmp": elif config["prune_strategy"] == "gmp":
pruner = GMPUnstructuredPruner( pruner = GMPUnstructuredPruner(
train_program_info.program, train_program_info.program,
ratio=config['pruned_ratio'], ratio=config['ratio'],
prune_params_type=config['prune_params_type'], prune_params_type=config['prune_params_type'],
place=place, place=place,
local_sparsity=config['local_sparsity'], local_sparsity=config['local_sparsity'],
configs=config['gmp_config']) configs=config['gmp_config'])
elif strategy.startswith('channel_prune'):
from ..prune import Pruner
pruner = Pruner(config["criterion"])
params = []
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block().all_parameters():
if config['prune_params_name'] is not None and param.name in config[
'prune_params_name']:
params.append(param.name)
pruned_program, _, _ = pruner.prune(
train_program_info.program,
paddle.static.global_scope(),
params=params,
ratios=[config['pruned_ratio']] * len(params),
place=place)
train_program_info.program = pruned_program
elif strategy.startswith('asp'):
from paddle.static import sparsity
pruner = sparsity
excluded_params_name = []
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block().all_parameters():
if config[
'prune_params_name'] is not None and param.name not in config[
'prune_params_name']:
excluded_params_name.append(param.name)
if "teacher_" in param.name:
excluded_params_name.append(param.name)
pruner.set_excluded_layers(train_program_info.program,
excluded_params_name)
elif strategy.startswith('transformer_prune'):
from .transformer_pruner import TransformerPruner
assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
label_info = _get_label_info(eval_dataloader,
train_program_info.feed_target_names)
assert len(label_info) != 0, \
"maybe something wrong in get label name from eval_dataloader, please check your eval_dataloader"
pruner = TransformerPruner(
executor,
place,
train_program_info.program,
patterns,
label_info,
width_mult=(1.0 - config['pruned_ratio']),
dataloader=eval_dataloader,
fetch_targets=train_program_info.fetch_targets)
pruned_program = pruner.prune()
train_program_info.program = pruned_program
else: else:
if config['prune_algo'] == 'prune': raise NotImplementedError(
from ..prune import Pruner "prune_algo must be choice in [\"prune\", \"asp\"], {} is not support".
pruner = Pruner(config["criterion"]) format(config['prune_algo']))
params = []
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block(
).all_parameters():
if config[
'prune_params_name'] is not None and param.name in config[
'prune_params_name']:
params.append(param.name)
pruned_program, _, _ = pruner.prune(
train_program_info.program,
paddle.static.global_scope(),
params=params,
ratios=[config['pruned_ratio']] * len(params),
place=place)
train_program_info.program = pruned_program
elif config['prune_algo'] == 'asp':
from paddle.static import sparsity
pruner = sparsity
excluded_params_name = []
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block(
).all_parameters():
if config[
'prune_params_name'] is not None and param.name not in config[
'prune_params_name']:
excluded_params_name.append(param.name)
if "teacher_" in param.name:
excluded_params_name.append(param.name)
pruner.set_excluded_layers(train_program_info.program,
excluded_params_name)
elif config['prune_algo'] == 'transformer_pruner':
from .transformer_pruner import TransformerPruner
assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
label_info = _get_label_info(eval_dataloader,
train_program_info.feed_target_names)
assert len(label_info) != 0, \
"maybe something wrong in get label name from eval_dataloader, please check your eval_dataloader"
pruner = TransformerPruner(
executor,
place,
train_program_info.program,
patterns,
label_info,
width_mult=(1.0 - config['pruned_ratio']),
dataloader=eval_dataloader,
fetch_targets=train_program_info.fetch_targets)
pruned_program = pruner.prune()
train_program_info.program = pruned_program
else:
raise NotImplementedError(
"prune_algo must be choice in [\"prune\", \"asp\"], {} is not support".
format(config['prune_algo']))
return pruner, train_program_info return pruner, train_program_info
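For reference, a hedged illustration (not part of this commit) of the per-strategy config consumed by the dispatch above; the key names mirror the config lookups in the code, while the values, the criterion, and the parameter names are placeholders.

# Hypothetical per-strategy configs, assuming only the keys read by the branches above.
channel_prune_config = {
    "criterion": "l1_norm",                 # forwarded to Pruner(config["criterion"]); illustrative choice
    "prune_params_name": ["conv2d_1.w_0"],  # placeholder names of parameters to prune
    "pruned_ratio": 0.25,                   # one ratio applied to every listed parameter
}
asp_config = {
    # parameters not listed here, plus any "teacher_" parameter, are excluded from ASP
    "prune_params_name": ["fc_0.w_0"],      # placeholder
}
transformer_prune_config = {
    "pruned_ratio": 0.25,                   # width_mult becomes 1.0 - pruned_ratio
}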
......
...@@ -12,5 +12,5 @@ ...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .single_distiller import merge, fsp_loss, l2_loss, soft_label_loss, loss from .single_distiller import merge, fsp, l2, soft_label, loss
from .dml import DML from .dml import DML
...@@ -54,7 +54,8 @@ def merge(teacher_program, ...@@ -54,7 +54,8 @@ def merge(teacher_program,
teacher_program = teacher_program.clone(for_test=True) teacher_program = teacher_program.clone(for_test=True)
for teacher_var in teacher_program.list_vars(): for teacher_var in teacher_program.list_vars():
skip_rename = False skip_rename = False
if teacher_var.name != 'fetch' and (not merge_feed or teacher_var.name != 'feed'): if teacher_var.name != 'fetch' and (not merge_feed or
teacher_var.name != 'feed'):
if teacher_var.name in data_name_map.keys(): if teacher_var.name in data_name_map.keys():
new_name = data_name_map[teacher_var.name] new_name = data_name_map[teacher_var.name]
if new_name == teacher_var.name: if new_name == teacher_var.name:
...@@ -72,7 +73,8 @@ def merge(teacher_program, ...@@ -72,7 +73,8 @@ def merge(teacher_program,
teacher_var.name, new_name) teacher_var.name, new_name)
for teacher_var in teacher_program.list_vars(): for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and (not merge_feed or teacher_var.name != 'feed'): if teacher_var.name != 'fetch' and (not merge_feed or
teacher_var.name != 'feed'):
# student program add var # student program add var
new_var = student_program.global_block()._clone_variable( new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False) teacher_var, force_persistable=False)
...@@ -111,11 +113,11 @@ def merge(teacher_program, ...@@ -111,11 +113,11 @@ def merge(teacher_program,
op._op._set_attr("skip_quant", True) op._op._set_attr("skip_quant", True)
def fsp_loss(teacher_var1_name, def fsp(teacher_var1_name,
teacher_var2_name, teacher_var2_name,
student_var1_name, student_var1_name,
student_var2_name, student_var2_name,
program=None): program=None):
"""Combine variables from student model and teacher model by fsp-loss. """Combine variables from student model and teacher model by fsp-loss.
Args: Args:
...@@ -149,7 +151,7 @@ def fsp_loss(teacher_var1_name, ...@@ -149,7 +151,7 @@ def fsp_loss(teacher_var1_name,
return fsp_loss return fsp_loss
def l2_loss(teacher_var_name, student_var_name, program=None): def l2(teacher_var_name, student_var_name, program=None):
"""Combine variables from student model and teacher model by l2-loss. """Combine variables from student model and teacher model by l2-loss.
Args: Args:
...@@ -170,11 +172,11 @@ def l2_loss(teacher_var_name, student_var_name, program=None): ...@@ -170,11 +172,11 @@ def l2_loss(teacher_var_name, student_var_name, program=None):
return l2_loss return l2_loss
def soft_label_loss(teacher_var_name, def soft_label(teacher_var_name,
student_var_name, student_var_name,
program=None, program=None,
teacher_temperature=1., teacher_temperature=1.,
student_temperature=1.): student_temperature=1.):
"""Combine variables from student model and teacher model by soft-label-loss. """Combine variables from student model and teacher model by soft-label-loss.
Args: Args:
......
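A minimal usage sketch (not from this commit) of the renamed distillation helpers; merge and l2 and their signatures come from the diff above, while the wrapper function, its arguments, and the node-pair names are hypothetical.

import paddle
from paddleslim.dist import merge, l2

def build_l2_distill_loss(teacher_program, student_program, data_name_map, place, node_pairs):
    # Merge the teacher program into the student program, then sum one l2 loss per
    # (teacher_var_name, student_var_name) pair; the teacher names are expected to
    # already carry the "teacher_" prefix that merge() adds.
    merge(teacher_program, student_program, data_name_map, place)
    with paddle.static.program_guard(student_program):
        losses = [
            l2(teacher_name, student_name, program=student_program)
            for teacher_name, student_name in node_pairs
        ]
        total_loss = paddle.add_n(losses)
    return total_loss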
...@@ -27,7 +27,7 @@ import paddle.fluid as fluid ...@@ -27,7 +27,7 @@ import paddle.fluid as fluid
from ..common.recover_program import recover_inference_program from ..common.recover_program import recover_inference_program
from .quanter import _quant_config_default, _parse_configs, pact, get_pact_optimizer from .quanter import _quant_config_default, _parse_configs, pact, get_pact_optimizer
from .quanter import quant_aware, convert from .quanter import quant_aware, convert
from ..dist import merge, l2_loss, soft_label_loss, fsp_loss from ..dist import merge, l2, soft_label, fsp
from ..auto_compression.create_compressed_program import build_distill_program from ..auto_compression.create_compressed_program import build_distill_program
import logging import logging
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
...@@ -57,7 +57,7 @@ _train_config_default = { ...@@ -57,7 +57,7 @@ _train_config_default = {
and the teacher node and student node are arranged in pairs. and the teacher node and student node are arranged in pairs.
for example, ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"] for example, ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"]
""" """
"distill_node_pair": None "node": None
} }
...@@ -91,12 +91,10 @@ def _parse_train_configs(train_config): ...@@ -91,12 +91,10 @@ def _parse_train_configs(train_config):
"'teacher_model_path_prefix' must both be string" "'teacher_model_path_prefix' must both be string"
assert isinstance(configs['model_path_prefix'], str), \ assert isinstance(configs['model_path_prefix'], str), \
"'model_path_prefix' must both be str" "'model_path_prefix' must both be str"
assert isinstance(configs['distill_node_pair'], list), \ assert isinstance(configs['node'], list), \
"'distill_node_pair' must both be list" "'node' must both be list"
assert len(configs['distill_node_pair']) > 0, \ assert len(configs['node']) > 0, \
"'distill_node_pair' not configured with distillation nodes" "'node' not configured with distillation nodes"
assert len(configs['distill_node_pair']) % 2 == 0, \
"'distill_node_pair' distillation nodes need to be configured in pairs"
return train_config return train_config
...@@ -143,7 +141,7 @@ def quant_aware_with_infermodel(executor, ...@@ -143,7 +141,7 @@ def quant_aware_with_infermodel(executor,
train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate, train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate,
weight_decay, use_pact, quant_model_ckpt_path, weight_decay, use_pact, quant_model_ckpt_path,
model_path_prefix, teacher_model_path_prefix, model_path_prefix, teacher_model_path_prefix,
distill_node_pair(teacher_node_name1, node_name1, teacher_node_name2, teacher_node_name2, ...) node(node_name1, node_name2, ...)
test_callback(callback function): callback function include two params: compiled test quant program and checkpoint save filename. test_callback(callback function): callback function include two params: compiled test quant program and checkpoint save filename.
user can implement test logic. user can implement test logic.
Returns: Returns:
...@@ -261,7 +259,7 @@ def export_quant_infermodel( ...@@ -261,7 +259,7 @@ def export_quant_infermodel(
train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate, train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate,
weight_decay, use_pact, quant_model_ckpt_path, weight_decay, use_pact, quant_model_ckpt_path,
model_path_prefix, teacher_model_path_prefix, model_path_prefix, teacher_model_path_prefix,
distill_node_pair(teacher_node_name1, node_name1, teacher_node_name2, teacher_node_name2, ...) node(node_name1, node_name2, ...)
checkpoint_path(str): checkpoint path need to export quant infer model. checkpoint_path(str): checkpoint path need to export quant infer model.
export_inference_model_path_prefix(str): export infer model path prefix, storage directory of model + model name (excluding suffix). export_inference_model_path_prefix(str): export infer model path prefix, storage directory of model + model name (excluding suffix).
Returns: Returns:
......
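For comparison, a hedged example (not from this commit) of a train_config written against the renamed "node" key; the key set follows the docstring and the updated unit test, the path values are placeholders, and the test above suggests that only student-side node names are listed now.

train_config = {
    "num_epoch": 1,
    "save_iter_step": 100,
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "use_pact": False,
    "quant_model_ckpt_path": "./quantaware_with_infermodel_checkpoints/",
    "teacher_model_path_prefix": "./float_model/infer",  # placeholder path prefix
    "model_path_prefix": "./float_model/infer",          # placeholder path prefix
    # student node names only; the matching teacher nodes appear to be resolved
    # through the "teacher_" prefix after merging
    "node": ["fc_0.tmp_0", "batch_norm_24.tmp_4"],
}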
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle import paddle
from paddleslim.dist import merge, fsp_loss from paddleslim.dist import merge, fsp
from layers import conv_bn_layer from layers import conv_bn_layer
from static_case import StaticCase from static_case import StaticCase
...@@ -49,9 +49,8 @@ class TestFSPLoss(StaticCase): ...@@ -49,9 +49,8 @@ class TestFSPLoss(StaticCase):
for block in paddle.static.default_main_program().blocks: for block in paddle.static.default_main_program().blocks:
for op in block.ops: for op in block.ops:
merged_ops.append(op.type) merged_ops.append(op.type)
distill_loss = fsp_loss('teacher_conv1_out.tmp_1', distill_loss = fsp('teacher_conv1_out.tmp_1', 'teacher_conv6_out.tmp_0',
'teacher_conv6_out.tmp_0', 'conv1_out.tmp_0', 'conv1_out.tmp_0', 'conv2_out.tmp_0')
'conv2_out.tmp_0')
loss_ops = [] loss_ops = []
for block in paddle.static.default_main_program().blocks: for block in paddle.static.default_main_program().blocks:
for op in block.ops: for op in block.ops:
......
...@@ -16,7 +16,7 @@ sys.path.append("../") ...@@ -16,7 +16,7 @@ sys.path.append("../")
import unittest import unittest
import paddle import paddle
from static_case import StaticCase from static_case import StaticCase
from paddleslim.dist import merge, l2_loss from paddleslim.dist import merge, l2
from layers import conv_bn_layer from layers import conv_bn_layer
...@@ -48,8 +48,8 @@ class TestL2Loss(StaticCase): ...@@ -48,8 +48,8 @@ class TestL2Loss(StaticCase):
for block in paddle.static.default_main_program().blocks: for block in paddle.static.default_main_program().blocks:
for op in block.ops: for op in block.ops:
merged_ops.append(op.type) merged_ops.append(op.type)
distill_loss = l2_loss('teacher_conv6_bn_output.tmp_2', distill_loss = l2('teacher_conv6_bn_output.tmp_2',
'conv2_bn_output.tmp_2') 'conv2_bn_output.tmp_2')
loss_ops = [] loss_ops = []
for block in paddle.static.default_main_program().blocks: for block in paddle.static.default_main_program().blocks:
for op in block.ops: for op in block.ops:
......
...@@ -17,6 +17,7 @@ sys.path.append("../") ...@@ -17,6 +17,7 @@ sys.path.append("../")
sys.path.append(".") sys.path.append(".")
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir) sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import unittest import unittest
import copy
import paddle import paddle
from paddleslim.quant import quant_aware, convert from paddleslim.quant import quant_aware, convert
from paddleslim.quant import quant_aware_with_infermodel, export_quant_infermodel from paddleslim.quant import quant_aware_with_infermodel, export_quant_infermodel
...@@ -145,13 +146,10 @@ class TestQuantAwareWithInferModelCase1(StaticCase): ...@@ -145,13 +146,10 @@ class TestQuantAwareWithInferModelCase1(StaticCase):
"./quantaware_with_infermodel_checkpoints/", "./quantaware_with_infermodel_checkpoints/",
"teacher_model_path_prefix": float_infer_model_path_prefix, "teacher_model_path_prefix": float_infer_model_path_prefix,
"model_path_prefix": float_infer_model_path_prefix, "model_path_prefix": float_infer_model_path_prefix,
"distill_node_pair": [ "node": [
"teacher_fc_0.tmp_0", "fc_0.tmp_0", "fc_0.tmp_0", "batch_norm_24.tmp_4", "batch_norm_22.tmp_4",
"teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4", "batch_norm_18.tmp_4", "batch_norm_13.tmp_4",
"teacher_batch_norm_22.tmp_4", "batch_norm_22.tmp_4", "batch_norm_5.tmp_4"
"teacher_batch_norm_18.tmp_4", "batch_norm_18.tmp_4",
"teacher_batch_norm_13.tmp_4", "batch_norm_13.tmp_4",
"teacher_batch_norm_5.tmp_4", "batch_norm_5.tmp_4"
] ]
} }
...@@ -184,7 +182,7 @@ class TestQuantAwareWithInferModelCase1(StaticCase): ...@@ -184,7 +182,7 @@ class TestQuantAwareWithInferModelCase1(StaticCase):
scope=None, scope=None,
train_reader=train_loader, train_reader=train_loader,
quant_config=quant_config, quant_config=quant_config,
train_config=train_config, train_config=copy.deepcopy(train_config),
test_callback=test_callback) test_callback=test_callback)
def test_export_quant_infermodel(exe, place, checkpoint_path, def test_export_quant_infermodel(exe, place, checkpoint_path,
...@@ -194,7 +192,7 @@ class TestQuantAwareWithInferModelCase1(StaticCase): ...@@ -194,7 +192,7 @@ class TestQuantAwareWithInferModelCase1(StaticCase):
place, place,
scope=None, scope=None,
quant_config=quant_config, quant_config=quant_config,
train_config=train_config, train_config=copy.deepcopy(train_config),
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
export_inference_model_path_prefix=quant_infermodel_save_path) export_inference_model_path_prefix=quant_infermodel_save_path)
......
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle import paddle
from paddleslim.dist import merge, soft_label_loss from paddleslim.dist import merge, soft_label
from layers import conv_bn_layer from layers import conv_bn_layer
from static_case import StaticCase from static_case import StaticCase
...@@ -48,8 +48,8 @@ class TestSoftLabelLoss(StaticCase): ...@@ -48,8 +48,8 @@ class TestSoftLabelLoss(StaticCase):
for block in paddle.static.default_main_program().blocks: for block in paddle.static.default_main_program().blocks:
for op in block.ops: for op in block.ops:
merged_ops.append(op.type) merged_ops.append(op.type)
distill_loss = soft_label_loss('teacher_conv6_bn_output.tmp_2', distill_loss = soft_label('teacher_conv6_bn_output.tmp_2',
'conv2_bn_output.tmp_2') 'conv2_bn_output.tmp_2')
loss_ops = [] loss_ops = []
for block in paddle.static.default_main_program().blocks: for block in paddle.static.default_main_program().blocks:
for op in block.ops: for op in block.ops:
......