Unverified commit 0a4dbe16, authored by Liufang Sang, committed by GitHub

add user defined function support in quantization and unittest and demo (#319)

Parent 8b0005bd
......@@ -5,6 +5,7 @@ from .resnet_vd import ResNet50_vd, ResNet101_vd
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2
from .pvanet import PVANet
from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75
from .mobilenet_v3 import *
__all__ = [
"model_list", "MobileNet", "ResNet34", "ResNet50", "MobileNetV2", "PVANet",
"ResNet50_vd", "ResNet101_vd", "MobileNetV2_x0_25"
......@@ -13,3 +14,6 @@ model_list = [
'MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet',
'ResNet50_vd', "ResNet101_vd", "MobileNetV2_x0_25"
]
__all__ += mobilenet_v3.__all__
model_list += mobilenet_v3.__all__
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
import math
__all__ = [
'MobileNetV3', 'MobileNetV3_small_x0_25', 'MobileNetV3_small_x0_5',
'MobileNetV3_small_x0_75', 'MobileNetV3_small_x1_0',
'MobileNetV3_small_x1_25', 'MobileNetV3_large_x0_25',
'MobileNetV3_large_x0_5', 'MobileNetV3_large_x0_75',
'MobileNetV3_large_x1_0', 'MobileNetV3_large_x1_25',
'MobileNetV3_large_x2_0'
]
class MobileNetV3():
def __init__(self, scale=1.0, model_name='small'):
self.scale = scale
self.inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', 2],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', 2],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
else:
raise NotImplementedError
def net(self, input, class_dim=1000):
scale = self.scale
inplanes = self.inplanes
cfg = self.cfg
cls_ch_squeeze = self.cls_ch_squeeze
cls_ch_expand = self.cls_ch_expand
#conv1
conv = self.conv_bn_layer(
input,
filter_size=3,
#num_filters=int(scale*inplanes),
num_filters=inplanes if scale <= 1.0 else int(inplanes * scale),
stride=2,
padding=1,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv1')
print(conv.shape)
i = 0
for layer_cfg in cfg:
conv = self.residual_unit(
input=conv,
num_in_filter=inplanes,
num_mid_filter=int(scale * layer_cfg[1]),
num_out_filter=int(scale * layer_cfg[2]),
act=layer_cfg[4],
stride=layer_cfg[5],
filter_size=layer_cfg[0],
use_se=layer_cfg[3],
name='conv' + str(i + 2))
inplanes = int(scale * layer_cfg[2])
i += 1
conv = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=int(scale * cls_ch_squeeze),
stride=1,
padding=0,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv_last')
conv = fluid.layers.pool2d(
input=conv, pool_type='avg', global_pooling=True, use_cudnn=False)
conv = fluid.layers.conv2d(
input=conv,
num_filters=cls_ch_expand,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name='last_1x1_conv_weights'),
bias_attr=False)
#conv = fluid.layers.hard_swish(conv)
conv = self.hard_swish(conv)
out = fluid.layers.fc(input=conv,
size=class_dim,
act='softmax',
param_attr=ParamAttr(name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
return out
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
act=None,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
if act == 'relu':
bn = fluid.layers.relu(bn)
elif act == 'hard_swish':
#bn = fluid.layers.hard_swish(bn)
bn = self.hard_swish(bn)
return bn
def hard_swish(self, x):
return x * fluid.layers.relu6(x + 3) / 6.
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act='relu',
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def residual_unit(self,
input,
num_in_filter,
num_mid_filter,
num_out_filter,
stride,
filter_size,
act=None,
use_se=False,
name=None):
input_data = input
conv0 = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_mid_filter,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + '_expand')
conv1 = self.conv_bn_layer(
input=conv0,
filter_size=filter_size,
num_filters=num_mid_filter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
act=act,
num_groups=num_mid_filter,
use_cudnn=False,
name=name + '_depthwise')
if use_se:
with fluid.name_scope('se_block_skip'):
conv1 = self.se_block(
input=conv1,
num_out_filter=num_mid_filter,
name=name + '_se')
conv2 = self.conv_bn_layer(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear')
if num_in_filter != num_out_filter or stride != 1:
return conv2
else:
return fluid.layers.elementwise_add(
x=input_data, y=conv2, act=None)
def MobileNetV3_small_x0_25():
model = MobileNetV3(model_name='small', scale=0.25)
return model
def MobileNetV3_small_x0_5():
model = MobileNetV3(model_name='small', scale=0.5)
return model
def MobileNetV3_small_x0_75():
model = MobileNetV3(model_name='small', scale=0.75)
return model
def MobileNetV3_small_x1_0():
model = MobileNetV3(model_name='small', scale=1.0)
return model
def MobileNetV3_small_x1_25():
model = MobileNetV3(model_name='small', scale=1.25)
return model
def MobileNetV3_large_x0_25():
model = MobileNetV3(model_name='large', scale=0.25)
return model
def MobileNetV3_large_x0_5():
model = MobileNetV3(model_name='large', scale=0.5)
return model
def MobileNetV3_large_x0_75():
model = MobileNetV3(model_name='large', scale=0.75)
return model
def MobileNetV3_large_x1_0():
model = MobileNetV3(model_name='large', scale=1.0)
return model
def MobileNetV3_large_x1_25():
model = MobileNetV3(model_name='large', scale=1.25)
return model
def MobileNetV3_large_x2_0():
model = MobileNetV3(model_name='large', scale=2.0)
return model
# Example: Custom Quantization Functions
This example shows how to use a custom quantization function. Taking the PACT method as an example, we quantize a trained MobileNetV3 classification model, which reduces model storage and GPU memory usage.
## Method
PACT (Parameterized Clipping Activation for Quantized Neural Networks, [paper](https://arxiv.org/abs/1805.06085)) improves quantization accuracy by removing outliers from the activations before they are quantized. The PACT formula given in the paper is:
<p align="center">
<img src="./image/pact.png" height=50 width=420 hspace='10'/> <br />
</p>
In the paper, the PACT formula replaces the ReLU activation function. In practice, however, the activations to be quantized do not necessarily come from ReLU; they may come from other activation functions or from elementwise ops. Therefore, in this demo a modified PACT function is inserted between the activation and the quantization op. The formula is:
<p align="center">
<img src="./image/pact_our.png" height=50 width=360 hspace='10'/> <br />
</p>
The modification is needed because the activations to be quantized are not necessarily all positive, while quantization searches for the maximum of the activation values, so values below zero must also be constrained.
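For reference, the two formulas can be written out as follows, where α is the learnable clipping threshold (a reconstruction based on the paper and on the `pact()` implementation shown later; the images above remain the authoritative rendering):
```
% Original PACT from the paper: clip the activation to [0, \alpha]
y = \mathrm{PACT}(x) = 0.5\,\bigl(|x| - |x - \alpha| + \alpha\bigr) =
\begin{cases}
0,      & x < 0 \\
x,      & 0 \le x < \alpha \\
\alpha, & x \ge \alpha
\end{cases}

% Modified PACT used in this demo: clip the activation to [-\alpha, \alpha]
y = \mathrm{clip}(x, -\alpha, \alpha) = x - \mathrm{ReLU}(x - \alpha) + \mathrm{ReLU}(-\alpha - x)
```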
### Defining the PACT Function
Custom quantization supports user-defined preprocessing of activations and weights, as well as user-defined quantization functions. In the `quant_aware` API, the relevant parameters are:
- `weight_quantize_func`: a custom weight quantization function. Its input is the weight to be quantized and its output is the dequantized weight, which makes it easy to quickly verify that the quantization function works. When set, it replaces the method defined by `weight_quantize_type` in the quantization config; when not set, the method defined by `weight_quantize_type` is used.
- `act_quantize_func`: a custom activation quantization function. Its input is the activation to be quantized and its output is the dequantized activation, which makes it easy to quickly verify that the quantization function works. When set, it replaces the method defined by `activation_quantize_type`; when not set, the method defined by `activation_quantize_type` is used.
- `weight_preprocess_func`: a custom function that preprocesses weights before they are quantized. The rationale is that network parameters are not necessarily suitable for direct quantization; transforming their distribution before quantization may improve accuracy (see the sketch after this list).
- `act_preprocess_func`: a custom function that preprocesses activations before they are quantized. The rationale is that activations are not necessarily suitable for direct quantization; transforming them before quantization may improve accuracy.
- `optimizer_func`: a function that returns an optimizer. The returned optimizer is used to optimize the trainable parameters introduced by the custom functions above.
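As a minimal illustration of these hooks (a hypothetical sketch, not part of this demo; the function name is illustrative), a `weight_preprocess_func` that simply clips outlier weights to a fixed range could look like this. It would be passed to `quant_aware` as `weight_preprocess_func=clip_weight_preprocess`:
```
import paddle.fluid as fluid

def clip_weight_preprocess(w, name=None):
    # Hypothetical example: clip outlier weights to [-1, 1] before quantization.
    # The input is the full-precision weight tensor; the returned tensor is the
    # one that will actually be quantized.
    return fluid.layers.clip(w, min=-1.0, max=1.0)
```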
The PACT method is implemented as a custom `act_preprocess_func`; its input is the activation that is about to be quantized.
It can be defined as follows:
```
import paddle
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
# define the initial PACT threshold
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
```
The function defines the initial threshold and the L2 regularization coefficient applied to it; during training, the threshold is learned to a suitable value via back-propagation.
The optimizer function is:
```
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.001, 0.9)
```
Because all parameters other than the PACT thresholds are already trained, the learning rate of the PACT thresholds can be increased during training.
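For example (a hypothetical helper, not the exact setting used in this demo; `make_pact_threshold_attr` and `lr_scale` are illustrative names), the threshold's relative learning rate can be raised through the `learning_rate` field of `ParamAttr`, which scales the global learning rate for that parameter:
```
import paddle.fluid as fluid

def make_pact_threshold_attr(var_name, init_thres=20.0, lr_scale=10.0):
    # Build the ParamAttr for a PACT threshold whose learning_rate multiplies
    # the global learning rate for this parameter only, so the threshold adapts
    # faster than the already-trained network weights.
    return fluid.ParamAttr(
        name=var_name + '_pact',
        initializer=fluid.initializer.ConstantInitializer(value=init_thres),
        regularizer=fluid.regularizer.L2Decay(0.0001),
        learning_rate=lr_scale)
```
Inside `pact()`, `u_param_attr` could then be built with such a helper.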
> Note: PACT only removes outliers at quantization time, which only affects the choice of the quantization scale. After training with PACT, the parameters can therefore be loaded and evaluated with the ordinary quantization method, so inference is unaffected.
## Quantization-Aware Training of MobileNetV3
### Prepare the Data
Create a ``data`` folder under the ``demo`` folder and extract the ``ImageNet`` dataset into it. After extraction, ``data/ILSVRC2012`` should contain the following (see the layout sketch after this list):
- the ``train`` folder with training images
- the ``train_list.txt`` file
- the ``val`` folder with validation images
- the ``val_list.txt`` file
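That is, a layout like:
```
data/ILSVRC2012/
├── train/
├── train_list.txt
├── val/
└── val_list.txt
```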
### Prepare the Model to Quantize
We quantize the highest-accuracy MobileNetV3 model from the PaddlePaddle image classification library [PaddleClas](https://github.com/PaddlePaddle/PaddleClas). Its top-1 accuracy before quantization is 78.9%.
```
mkdir pretrain
cd pretrain
wget https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
tar xf MobileNetV3_large_x1_0_ssld_pretrained.tar
cd ..
```
This model is chosen because this SSLD-distilled MobileNetV3 has many outliers among its activations, which makes it a good case for validating the effect of PACT. The figure below is a histogram of one of the intermediate activation distributions of MobileNetV3:
<p align="center">
<img src="./image/activation_dist.png" height=400 width=500 hspace='10'/> <br />
</p>
The x-axis of the histogram spans the minimum and maximum of the activation distribution. The minimum is around -60 and the maximum around 80, but most values fall between -20 and 20.
### Enable Gradients for `image`
Due to the current implementation, gradients must be enabled on the `image` variable.
```
image.stop_gradient = False
```
### Configure the Quantization Parameters
```
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8,
'activation_bits': 8,
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
'dtype': 'int8',
'window_size': 10000,
'moving_rate': 0.9
}
```
### Insert Trainable Quantization Ops into the Training and Test Programs
Ordinary quantization:
```
val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
```
Quantization with PACT:
```
val_program = quant_aware(val_program, place, quant_config, scope=None, act_preprocess_func=pact, executor=exe, for_test=True)
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, act_preprocess_func=pact, optimizer_func=get_optimizer, executor=exe, for_test=False)
```
### Disable Specific Build Strategies
```
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
```
### Training Commands
Ordinary quantization:
```
python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact False
```
Sample output:
```
2020-06-05 15:14:15,319-INFO: epoch[0]-batch[10] - loss: 2.50413322449; acc_top1: 0.515625; acc_top5: 0.75; time: 1.29066705704
2020-06-05 15:14:28,950-INFO: epoch[0]-batch[20] - loss: 3.14219880104; acc_top1: 0.3828125; acc_top5: 0.62890625; time: 1.29546618462
2020-06-05 15:14:42,479-INFO: epoch[0]-batch[30] - loss: 3.34660744667; acc_top1: 0.3671875; acc_top5: 0.609375; time: 1.20717287064
2020-06-05 15:14:56,196-INFO: epoch[0]-batch[40] - loss: 3.69098854065; acc_top1: 0.2890625; acc_top5: 0.5546875; time: 1.29232215881
2020-06-05 15:15:09,815-INFO: epoch[0]-batch[50] - loss: 3.5337202549; acc_top1: 0.30078125; acc_top5: 0.5546875; time: 1.34358000755
2020-06-05 15:15:23,550-INFO: epoch[0]-batch[60] - loss: 3.22006082535; acc_top1: 0.359375; acc_top5: 0.609375; time: 1.34181118011
2020-06-05 15:15:37,425-INFO: epoch[0]-batch[70] - loss: 3.06894540787; acc_top1: 0.4375; acc_top5: 0.65625; time: 1.33122491837
2020-06-05 15:15:51,161-INFO: epoch[0]-batch[80] - loss: 3.00548839569; acc_top1: 0.3828125; acc_top5: 0.6328125; time: 1.27601099014
2020-06-05 15:16:05,158-INFO: epoch[0]-batch[90] - loss: 2.52197813988; acc_top1: 0.484375; acc_top5: 0.71484375; time: 1.28280210495
```
The loss of ordinary quantization is unstable, and in our experiment it became NaN after 2 epochs; ordinary quantization is quite unstable on this model.
Quantization-aware training with PACT:
```
python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
```
Sample output:
```
2020-06-05 15:25:37,647-INFO: epoch[0]-batch[10] - loss: 1.60160636902; acc_top1: 0.65625; acc_top5: 0.890625; time: 1.56788897514
2020-06-05 15:25:53,191-INFO: epoch[0]-batch[20] - loss: 1.4748904705; acc_top1: 0.6484375; acc_top5: 0.84375; time: 1.4936029911
2020-06-05 15:26:08,598-INFO: epoch[0]-batch[30] - loss: 1.427333951; acc_top1: 0.6953125; acc_top5: 0.875; time: 1.51066279411
2020-06-05 15:26:24,009-INFO: epoch[0]-batch[40] - loss: 1.43955898285; acc_top1: 0.6640625; acc_top5: 0.8671875; time: 1.49221611023
2020-06-05 15:26:39,501-INFO: epoch[0]-batch[50] - loss: 1.29342699051; acc_top1: 0.6953125; acc_top5: 0.90625; time: 1.50851297379
2020-06-05 15:26:54,927-INFO: epoch[0]-batch[60] - loss: 1.49478590488; acc_top1: 0.6171875; acc_top5: 0.875; time: 1.50131177902
2020-06-05 15:27:10,250-INFO: epoch[0]-batch[70] - loss: 1.34970903397; acc_top1: 0.7109375; acc_top5: 0.890625; time: 1.51333618164
2020-06-05 15:27:25,309-INFO: epoch[0]-batch[80] - loss: 1.51600492001; acc_top1: 0.6796875; acc_top5: 0.859375; time: 1.44952607155
2020-06-05 15:27:40,273-INFO: epoch[0]-batch[90] - loss: 1.5926772356; acc_top1: 0.6328125; acc_top5: 0.859375; time: 1.45620679855
2020-06-05 15:27:55,660-INFO: epoch[0]-batch[100] - loss: 1.40280032158; acc_top1: 0.671875; acc_top5: 0.875; time: 1.50846099854
```
The loss is clearly more stable, and in our experiments this run produced a quantized model with 77.5% top-1 accuracy. In addition to the options in the command above, the initial `pact` threshold is set to 20. The quantized model can be downloaded from this [link](https://paddlemodels.bj.bcebos.com/PaddleSlim/mobilenetv3_pact_quant.tar).
import sys
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments
from pact import *
quantization_model_save_dir = './quantization_models/'
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4,
"Minibatch size.")
add_arg('use_gpu', bool, True,
"Whether to use GPU or not.")
add_arg('model', str, "MobileNet",
"The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained",
"Whether to use pretrained model.")
add_arg('lr', float, 0.0001,
"The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay",
"The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5,
"The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9,
"The value of momentum_rate.")
add_arg('num_epochs', int, 1,
"The number of total epochs.")
add_arg('total_images', int, 1281167,
"The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int,
default=[30, 60, 90],
help="piecewise decay step")
add_arg('config_file', str, None,
"The config file for compression with yaml format.")
add_arg('data', str, "imagenet",
"Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10,
"Log period in batches.")
add_arg('checkpoint_dir', str, "output",
"checkpoint save dir")
add_arg('use_pact', bool, True,
"Whether to use PACT or not.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
bd = [step * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
return optimizer
def cosine_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
learning_rate = fluid.layers.cosine_decay(
learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
return optimizer
def create_optimizer(args):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(args)
elif args.lr_strategy == "cosine_decay":
return cosine_decay(args)
def compress(args):
# 1. quantization configs
quant_config = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
}
train_reader = None
test_reader = None
if args.data == "mnist":
import paddle.dataset.mnist as reader
train_reader = reader.train()
val_reader = reader.test()
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_reader = reader.train()
val_reader = reader.val()
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(
args.model, model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
if args.use_pact:
image.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
train_prog = fluid.default_main_program()
val_program = fluid.default_main_program().clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
opt = create_optimizer(args)
opt.minimize(avg_cost)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# 2. quantization transform programs (training aware)
# Make some quantization transforms in the graph before training and testing.
# According to the weight and activation quantization types, fake quantize and
# fake dequantize operators will be inserted into the graph.
if args.use_pact:
act_preprocess_func = pact
optimizer_func = get_optimizer
executor = exe
else:
act_preprocess_func = None
optimizer_func = None
executor = None
val_program = quant_aware(
val_program,
place,
quant_config,
scope=None,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=True)
compiled_train_prog = quant_aware(
train_prog,
place,
quant_config,
scope=None,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=False)
assert os.path.exists(
args.pretrained_model), "pretrained_model doesn't exist"
if args.pretrained_model:
def if_exist(var):
return os.path.exists(
os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
train_reader = paddle.fluid.io.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_feeder = feeder = fluid.DataFeeder([image, label], place)
val_feeder = feeder = fluid.DataFeeder(
[image, label], place, program=val_program)
def test(epoch, program):
batch_id = 0
acc_top1_ns = []
acc_top5_ns = []
for data in val_reader():
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed=train_feeder.feed(data),
fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id,
np.mean(acc_top1_n),
np.mean(acc_top5_n), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1_n))
acc_top5_ns.append(np.mean(acc_top5_n))
batch_id += 1
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
format(epoch,
np.mean(np.array(acc_top1_ns)),
np.mean(np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def train(epoch, compiled_train_prog):
batch_id = 0
for data in train_reader():
start_time = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
compiled_train_prog,
feed=train_feeder.feed(data),
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
end_time = time.time()
loss_n = np.mean(loss_n)
acc_top1_n = np.mean(acc_top1_n)
acc_top5_n = np.mean(acc_top5_n)
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
end_time - start_time))
if args.use_pact and batch_id % 1000 == 0:
threshold = {}
for var in val_program.list_vars():
if 'pact' in var.name:
array = np.array(fluid.global_scope().find_var(
var.name).get_tensor())
threshold[var.name] = array[0]
print(threshold)
batch_id += 1
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
# train loop
best_acc1 = 0.0
best_epoch = 0
for i in range(args.num_epochs):
train(i, compiled_train_prog)
acc1 = test(i, val_program)
fluid.io.save_persistables(
exe,
dirname=os.path.join(args.checkpoint_dir, str(i)),
main_program=val_program)
if acc1 > best_acc1:
best_acc1 = acc1
best_epoch = i
fluid.io.save_persistables(
exe,
dirname=os.path.join(args.checkpoint_dir, 'best_model'),
main_program=val_program)
if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')):
fluid.io.load_persistables(
exe,
dirname=os.path.join(args.checkpoint_dir, 'best_model'),
main_program=val_program)
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
float_program, int8_program = convert(val_program, place, quant_config, \
scope=None, \
save_int8=True)
print("eval best_model after convert")
final_acc1 = test(best_epoch, float_program)
# 4. Save inference model
model_path = os.path.join(quantization_model_save_dir, args.model,
'act_' + quant_config['activation_quantize_type']
+ '_w_' + quant_config['weight_quantize_type'])
float_path = os.path.join(model_path, 'float')
int8_path = os.path.join(model_path, 'int8')
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_inference_model(
dirname=float_path,
feeded_var_names=[image.name],
target_vars=[out],
executor=exe,
main_program=float_program,
model_filename=float_path + '/model',
params_filename=float_path + '/params')
fluid.io.save_inference_model(
dirname=int8_path,
feeded_var_names=[image.name],
target_vars=[out],
executor=exe,
main_program=int8_program,
model_filename=int8_path + '/model',
params_filename=int8_path + '/params')
def main():
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
......@@ -23,7 +23,7 @@ quant_post_dynamic
.. py:function:: paddleslim.quant.quant_post_dynamic(model_dir, save_model_dir, model_filename=None, params_filename=None, save_model_filename=None, save_params_filename=None, quantizable_op_type=["conv2d", "mul"], weight_bits=8, generate_test_model=False)
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/quant/quanter.py>`_
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py>`_
Dynamic post-training quantization: quantizes the weights of specific ops in the model from FP32 to INT8/16.
......@@ -99,7 +99,7 @@ quant_post_static
.. py:function:: paddleslim.quant.quant_post_static(executor,model_dir, quantize_model_path, batch_generator=None, sample_generator=None, model_filename=None, params_filename=None, save_model_filename='__model__', save_params_filename='__params__', batch_size=16, batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d","depthwise_conv2d","mul"], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', is_use_cache_file=False, cache_dir="./temp_post_training")
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/quant/quanter.py>`_
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py>`_
Static post-training quantization: uses a small amount of calibration data to compute quantization scales and quickly produce a quantized model. Inference with the quantized model reduces computation, memory footprint, and model size.
......@@ -184,7 +184,7 @@ quant_post_static
batch_size=16,
batch_nums=10)
For more detailed usage, see the `post-training quantization demo <https://github.com/PaddlePaddle/PaddleSlim/tree/release/1.1.0/demo/quant/quant_post>`_ .
For more detailed usage, see the `post-training quantization demo <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post>`_ .
......@@ -192,9 +192,9 @@ quant_post_static
quant_aware
------------
.. py:function:: paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)
.. py:function:: paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False, weight_quantize_func=None, act_quantize_func=None, weight_preprocess_func=None, act_preprocess_func=None, optimizer_func=None, executor=None))
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/quant/quanter.py>`_
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py>`_
Adds quantization and dequantization ops to the program for quantization-aware training.
......@@ -206,6 +206,14 @@ quant_aware
- **config(dict)** - The quantization configuration.
- **scope(fluid.Scope, optional)** - The ``scope`` that stores the ``Variable``s. It must be the ``scope`` used by ``program``, which is usually `fluid.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ . When set to ``None``, `fluid.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ is used. Default: ``None``.
- **for_test(bool)** - If ``program`` is a test program, ``for_test`` should be set to True; otherwise set it to False.
- **weight_quantize_func(function)** - A custom weight quantization function. Its input is the weight to be quantized and its output is the dequantized weight, so the function can be quickly validated. When set, it replaces the method defined by `weight_quantize_type` in the quantization config; when not set, the method defined by `weight_quantize_type` is used. Default: None.
- **act_quantize_func(function)** - A custom activation quantization function. Its input is the activation to be quantized and its output is the dequantized activation, so the function can be quickly validated. When set, it replaces the method defined by `activation_quantize_type`; when not set, the method defined by `activation_quantize_type` is used. Default: None.
- **weight_preprocess_func(function)** - A custom function that preprocesses weights before they are quantized. The rationale is that network parameters are not necessarily suitable for direct quantization; transforming their distribution first may improve quantization accuracy. Default: None.
- **act_preprocess_func(function)** - A custom function that preprocesses activations before they are quantized. The rationale is that network activations are not necessarily suitable for direct quantization; transforming them first may improve quantization accuracy. Default: None.
- **optimizer_func(function)** - A function that returns an optimizer. The returned optimizer is used to optimize the parameters introduced by the custom functions above. Default: None.
- **executor(fluid.Executor)** - Used to initialize the variables created by the custom functions above. Default: None.
**Returns**
......@@ -230,7 +238,7 @@ convert
.. py:function:: paddleslim.quant.convert(program, place, config, scope=None, save_int8=False)
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/quant/quanter.py>`_
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py>`_
Converts a trained quantization program into a program that can be saved as an ``inference model``.
......@@ -294,7 +302,7 @@ convert
inference_prog = quant.convert(quant_eval_program, place, config)
For more detailed usage, see the `quantization-aware training demo <https://github.com/PaddlePaddle/PaddleSlim/tree/release/1.1.0/demo/quant/quant_aware>`_ .
For more detailed usage, see the `quantization-aware training demo <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_aware>`_ .
Parameter configuration for quantization-aware training
......@@ -367,7 +375,7 @@ quant_embedding
.. py:function:: paddleslim.quant.quant_embedding(program, place, config=None, scope=None)
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/quant/quant_embedding.py>`_
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quant_embedding.py>`_
Quantizes the ``Embedding`` parameters.
......@@ -418,6 +426,6 @@ fluid.Program
}
quant_program = quant.quant_embedding(infer_program, place, config)
For more detailed usage, see the `Embedding quantization demo <https://github.com/PaddlePaddle/PaddleSlim/tree/release/1.1.0/demo/quant/quant_embedding>`_
For more detailed usage, see the `Embedding quantization demo <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_embedding>`_
......@@ -160,7 +160,17 @@ def _parse_configs(user_config):
return configs
def quant_aware(program, place, config=None, scope=None, for_test=False):
def quant_aware(program,
place,
config=None,
scope=None,
for_test=False,
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
act_preprocess_func=None,
optimizer_func=None,
executor=None):
"""Add quantization and dequantization operators to "program"
for quantization training or testing.
......@@ -175,7 +185,32 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
`fluid.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_. When ``None``, `fluid.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ will be used. Default: ``None``.
for_test(bool): If the 'program' parameter is a test program, this parameter should be set to ``True``.
Otherwise, set to ``False``. Default: False
weight_quantize_func(function): Function that defines how to quantize weight. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization function and dequantization function, that is, the function's input
is non-quantized weight and function returns dequantized weight. If None, will use
quantization op defined by 'weight_quantize_type'.
Default is None.
act_quantize_func(function): Function that defines how to quantize activation. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization and dequantization process, that is, the function's input
is non-quantized activation and function returns dequantized activation. If None, will use
quantization op defined by 'activation_quantize_type'.
Default is None.
weight_preprocess_func(function): Function that defines how to preprocess weight before quantization. Using this
can quickly test if user's preprocess method works or not. The function's input
is non-quantized weight and function returns processed weight to be quantized. If None, the weight will
be quantized directly.
Default is None.
act_preprocess_func(function): Function that defines how to preprocess activation before quantization. Using this
can quickly test if user's preprocess method works or not. The function's input
is non-quantized activation and function returns processed activation to be quantized. If None, the activation will
be quantized directly.
Default is None.
optimizer_func(function): Function that returns an optimizer. When 'is_test' is False and the user wants to use self-defined
quantization and preprocess functions, this function must be set. Default is None.
executor(fluid.Executor): If the user wants to use self-defined quantization and preprocess functions, executor must be set for
initialization. Default is None.
Returns:
fluid.CompiledProgram | fluid.Program: Program with quantization and dequantization ``operators``
"""
......@@ -208,7 +243,13 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=transform_pass_ops,
skip_pattern=config['not_quant_pattern'])
skip_pattern=config['not_quant_pattern'],
weight_quantize_func=weight_quantize_func,
act_quantize_func=act_quantize_func,
weight_preprocess_func=weight_preprocess_func,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor)
transform_pass.apply(main_graph)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
class TestQuantAwareCase1(unittest.TestCase):
def get_model(self):
image = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
startup_prog = fluid.default_startup_program()
train_prog = fluid.default_main_program()
return startup_prog, train_prog
def test_accuracy(self):
image = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = fluid.default_main_program()
val_prog = main_prog.clone(for_test=True)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feeder = fluid.DataFeeder([image, label], place, program=main_prog)
train_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=64)
eval_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.test(), batch_size=64)
def train(program):
iter = 0
for data in train_reader():
cost, top1, top5 = exe.run(
program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
iter += 1
if iter % 100 == 0:
print(
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
def test(program):
iter = 0
result = [[], [], []]
for data in eval_reader():
cost, top1, top5 = exe.run(
program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
iter += 1
if iter % 100 == 0:
print(
'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
return np.mean(result[1]), np.mean(result[2])
train(main_prog)
top1_1, top5_1 = test(main_prog)
config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
}
quant_train_prog_pact = quant_aware(
main_prog,
place,
config,
for_test=False,
act_preprocess_func=pact,
optimizer_func=get_optimizer,
executor=exe)
quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
train(quant_train_prog_pact)
quant_eval_prog, int8_prog = convert(
quant_eval_prog, place, config, save_int8=True)
top1_2, top5_2 = test(quant_eval_prog)
# values before quantization and after quantization should be close
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
if __name__ == '__main__':
unittest.main()