diff --git a/demo/models/__init__.py b/demo/models/__init__.py
index 9be359324d76017548ad438481880e456d93d349..b6771d7086bb150742c4a7198f2224f63d603e8e 100644
--- a/demo/models/__init__.py
+++ b/demo/models/__init__.py
@@ -5,6 +5,7 @@ from .resnet_vd import ResNet50_vd, ResNet101_vd
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2
from .pvanet import PVANet
from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75
+from .mobilenet_v3 import *
__all__ = [
"model_list", "MobileNet", "ResNet34", "ResNet50", "MobileNetV2", "PVANet",
"ResNet50_vd", "ResNet101_vd", "MobileNetV2_x0_25"
@@ -13,3 +14,6 @@ model_list = [
'MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet',
'ResNet50_vd', "ResNet101_vd", "MobileNetV2_x0_25"
]
+
+__all__ += mobilenet_v3.__all__
+model_list += mobilenet_v3.__all__
diff --git a/demo/models/mobilenet_v3.py b/demo/models/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3276b352744a199ee858d193cb46e1b5ce36bca7
--- /dev/null
+++ b/demo/models/mobilenet_v3.py
@@ -0,0 +1,303 @@
+import paddle.fluid as fluid
+from paddle.fluid.initializer import MSRA
+from paddle.fluid.param_attr import ParamAttr
+import math
+
+__all__ = [
+ 'MobileNetV3', 'MobileNetV3_small_x0_25', 'MobileNetV3_small_x0_5',
+ 'MobileNetV3_small_x0_75', 'MobileNetV3_small_x1_0',
+ 'MobileNetV3_small_x1_25', 'MobileNetV3_large_x0_25',
+ 'MobileNetV3_large_x0_5', 'MobileNetV3_large_x0_75',
+ 'MobileNetV3_large_x1_0', 'MobileNetV3_large_x1_25',
+ 'MobileNetV3_large_x2_0'
+]
+
+
+class MobileNetV3():
+ def __init__(self, scale=1.0, model_name='small'):
+ self.scale = scale
+ self.inplanes = 16
+ if model_name == "large":
+ self.cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, 'relu', 1],
+ [3, 64, 24, False, 'relu', 2],
+ [3, 72, 24, False, 'relu', 1],
+ [5, 72, 40, True, 'relu', 2],
+ [5, 120, 40, True, 'relu', 1],
+ [5, 120, 40, True, 'relu', 1],
+ [3, 240, 80, False, 'hard_swish', 2],
+ [3, 200, 80, False, 'hard_swish', 1],
+ [3, 184, 80, False, 'hard_swish', 1],
+ [3, 184, 80, False, 'hard_swish', 1],
+ [3, 480, 112, True, 'hard_swish', 1],
+ [3, 672, 112, True, 'hard_swish', 1],
+ [5, 672, 160, True, 'hard_swish', 2],
+ [5, 960, 160, True, 'hard_swish', 1],
+ [5, 960, 160, True, 'hard_swish', 1],
+ ]
+ self.cls_ch_squeeze = 960
+ self.cls_ch_expand = 1280
+ elif model_name == "small":
+ self.cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, 'relu', 2],
+ [3, 72, 24, False, 'relu', 2],
+ [3, 88, 24, False, 'relu', 1],
+ [5, 96, 40, True, 'hard_swish', 2],
+ [5, 240, 40, True, 'hard_swish', 1],
+ [5, 240, 40, True, 'hard_swish', 1],
+ [5, 120, 48, True, 'hard_swish', 1],
+ [5, 144, 48, True, 'hard_swish', 1],
+ [5, 288, 96, True, 'hard_swish', 2],
+ [5, 576, 96, True, 'hard_swish', 1],
+ [5, 576, 96, True, 'hard_swish', 1],
+ ]
+ self.cls_ch_squeeze = 576
+ self.cls_ch_expand = 1280
+ else:
+ raise NotImplementedError
+
+ def net(self, input, class_dim=1000):
+ scale = self.scale
+ inplanes = self.inplanes
+ cfg = self.cfg
+ cls_ch_squeeze = self.cls_ch_squeeze
+ cls_ch_expand = self.cls_ch_expand
+
+ #conv1
+ conv = self.conv_bn_layer(
+ input,
+ filter_size=3,
+ #num_filters=int(scale*inplanes),
+ num_filters=inplanes if scale <= 1.0 else int(inplanes * scale),
+ stride=2,
+ padding=1,
+ num_groups=1,
+ if_act=True,
+ act='hard_swish',
+ name='conv1')
+ print(conv.shape)
+ i = 0
+ for layer_cfg in cfg:
+ conv = self.residual_unit(
+ input=conv,
+ num_in_filter=inplanes,
+ num_mid_filter=int(scale * layer_cfg[1]),
+ num_out_filter=int(scale * layer_cfg[2]),
+ act=layer_cfg[4],
+ stride=layer_cfg[5],
+ filter_size=layer_cfg[0],
+ use_se=layer_cfg[3],
+ name='conv' + str(i + 2))
+
+ inplanes = int(scale * layer_cfg[2])
+ i += 1
+
+ conv = self.conv_bn_layer(
+ input=conv,
+ filter_size=1,
+ num_filters=int(scale * cls_ch_squeeze),
+ stride=1,
+ padding=0,
+ num_groups=1,
+ if_act=True,
+ act='hard_swish',
+ name='conv_last')
+ conv = fluid.layers.pool2d(
+ input=conv, pool_type='avg', global_pooling=True, use_cudnn=False)
+ conv = fluid.layers.conv2d(
+ input=conv,
+ num_filters=cls_ch_expand,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act=None,
+ param_attr=ParamAttr(name='last_1x1_conv_weights'),
+ bias_attr=False)
+ #conv = fluid.layers.hard_swish(conv)
+ conv = self.hard_swish(conv)
+ out = fluid.layers.fc(input=conv,
+ size=class_dim,
+ act='softmax',
+ param_attr=ParamAttr(name='fc_weights'),
+ bias_attr=ParamAttr(name='fc_offset'))
+ return out
+
+ def conv_bn_layer(self,
+ input,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ num_groups=1,
+ if_act=True,
+ act=None,
+ name=None,
+ use_cudnn=True):
+ conv = fluid.layers.conv2d(
+ input=input,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ act=None,
+ use_cudnn=use_cudnn,
+ param_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+ bn_name = name + '_bn'
+ bn = fluid.layers.batch_norm(
+ input=conv,
+ param_attr=ParamAttr(
+ name=bn_name + "_scale",
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=0.0)),
+ bias_attr=ParamAttr(
+ name=bn_name + "_offset",
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=0.0)),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance')
+ if if_act:
+ if act == 'relu':
+ bn = fluid.layers.relu(bn)
+ elif act == 'hard_swish':
+ #bn = fluid.layers.hard_swish(bn)
+ bn = self.hard_swish(bn)
+ return bn
+
+ def hard_swish(self, x):
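+        # h-swish activation from the MobileNetV3 paper: x * relu6(x + 3) / 6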
+ return x * fluid.layers.relu6(x + 3) / 6.
+
+ def se_block(self, input, num_out_filter, ratio=4, name=None):
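+        # squeeze-and-excitation: global average pool, 1x1 conv (relu),
+        # 1x1 conv (hard_sigmoid), then channel-wise rescaling of the input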
+ num_mid_filter = int(num_out_filter // ratio)
+ pool = fluid.layers.pool2d(
+ input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
+ conv1 = fluid.layers.conv2d(
+ input=pool,
+ filter_size=1,
+ num_filters=num_mid_filter,
+ act='relu',
+ param_attr=ParamAttr(name=name + '_1_weights'),
+ bias_attr=ParamAttr(name=name + '_1_offset'))
+ conv2 = fluid.layers.conv2d(
+ input=conv1,
+ filter_size=1,
+ num_filters=num_out_filter,
+ act='hard_sigmoid',
+ param_attr=ParamAttr(name=name + '_2_weights'),
+ bias_attr=ParamAttr(name=name + '_2_offset'))
+
+ scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
+ return scale
+
+ def residual_unit(self,
+ input,
+ num_in_filter,
+ num_mid_filter,
+ num_out_filter,
+ stride,
+ filter_size,
+ act=None,
+ use_se=False,
+ name=None):
+
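+        # MobileNetV3 bottleneck: 1x1 expand -> depthwise conv -> optional SE ->
+        # 1x1 linear projection; the identity shortcut is added only when the
+        # input/output channel counts match and stride == 1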
+ input_data = input
+ conv0 = self.conv_bn_layer(
+ input=input,
+ filter_size=1,
+ num_filters=num_mid_filter,
+ stride=1,
+ padding=0,
+ if_act=True,
+ act=act,
+ name=name + '_expand')
+
+ conv1 = self.conv_bn_layer(
+ input=conv0,
+ filter_size=filter_size,
+ num_filters=num_mid_filter,
+ stride=stride,
+ padding=int((filter_size - 1) // 2),
+ if_act=True,
+ act=act,
+ num_groups=num_mid_filter,
+ use_cudnn=False,
+ name=name + '_depthwise')
+
+ if use_se:
+ with fluid.name_scope('se_block_skip'):
+ conv1 = self.se_block(
+ input=conv1,
+ num_out_filter=num_mid_filter,
+ name=name + '_se')
+
+ conv2 = self.conv_bn_layer(
+ input=conv1,
+ filter_size=1,
+ num_filters=num_out_filter,
+ stride=1,
+ padding=0,
+ if_act=False,
+ name=name + '_linear')
+ if num_in_filter != num_out_filter or stride != 1:
+ return conv2
+ else:
+ return fluid.layers.elementwise_add(
+ x=input_data, y=conv2, act=None)
+
+
+def MobileNetV3_small_x0_25():
+ model = MobileNetV3(model_name='small', scale=0.25)
+ return model
+
+
+def MobileNetV3_small_x0_5():
+ model = MobileNetV3(model_name='small', scale=0.5)
+ return model
+
+
+def MobileNetV3_small_x0_75():
+ model = MobileNetV3(model_name='small', scale=0.75)
+ return model
+
+
+def MobileNetV3_small_x1_0():
+ model = MobileNetV3(model_name='small', scale=1.0)
+ return model
+
+
+def MobileNetV3_small_x1_25():
+ model = MobileNetV3(model_name='small', scale=1.25)
+ return model
+
+
+def MobileNetV3_large_x0_25():
+ model = MobileNetV3(model_name='large', scale=0.25)
+ return model
+
+
+def MobileNetV3_large_x0_5():
+ model = MobileNetV3(model_name='large', scale=0.5)
+ return model
+
+
+def MobileNetV3_large_x0_75():
+ model = MobileNetV3(model_name='large', scale=0.75)
+ return model
+
+
+def MobileNetV3_large_x1_0():
+ model = MobileNetV3(model_name='large', scale=1.0)
+ return model
+
+
+def MobileNetV3_large_x1_25():
+ model = MobileNetV3(model_name='large', scale=1.25)
+ return model
+
+
+def MobileNetV3_large_x2_0():
+ model = MobileNetV3(model_name='large', scale=2.0)
+ return model
diff --git a/demo/quant/pact_quant_aware/README.md b/demo/quant/pact_quant_aware/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c494d42cf94d6ce4d66a776a9e3896fa8fddd8b6
--- /dev/null
+++ b/demo/quant/pact_quant_aware/README.md
@@ -0,0 +1,198 @@
+# Example: Using a Custom Quantization Method
+
+This example shows how to use a custom quantization method, taking PACT as an example, to quantize the trained classification model MobileNetV3, which reduces the model's storage size and GPU memory usage.
+
+## Method
+PACT (Parameterized Clipping Activation for Quantized Neural Networks, [paper](https://arxiv.org/abs/1805.06085)) improves quantization accuracy by removing outliers from the activations before they are quantized. The PACT formula given in the paper is:
+
+![PACT](./image/pact.png)
+
+that is, y = PACT(x) = 0.5 * (|x| - |x - α| + α), which clips the activation to the range [0, α] with a learnable threshold α.
+
+The idea in the paper is to use the PACT formula in place of the ReLU activation function. In practice, however, the activation to be quantized does not necessarily come from a ReLU; it may come from other activation functions or from elementwise ops. This demo therefore inserts an improved PACT operation between the activation and the quantization op, with the following formula:
+
+![improved PACT](./image/pact_our.png)
+
+that is, y = x - relu(x - α) + relu(-α - x), which clips the activation to the range [-α, α].
+
+
+The change is needed because the activations to be quantized are not necessarily positive, while quantization derives its scale from the maximum activation value, so values below zero also have to be constrained.
+
+### Defining the PACT function
+
+Custom quantization supports defining a preprocessing step for activations or weights, as well as defining the quantization method itself. In the `quant_aware` API, the relevant parameters and their meanings are as follows (a minimal sketch of a custom quantize function is given right after this list):
+
+- `weight_quantize_func`: a custom weight-quantization function. Its input is the weight to be quantized and its output is the dequantized weight, which makes it easy to quickly verify whether the quantization function works. When set, it replaces the method defined by `weight_quantize_type` in the quantization config; when not set, the `weight_quantize_type` method is used.
+- `act_quantize_func`: a custom activation-quantization function. Its input is the activation to be quantized and its output is the dequantized activation, which makes it easy to quickly verify whether the quantization function works. When set, it replaces the method defined by `activation_quantize_type` in the quantization config; when not set, the `activation_quantize_type` method is used.
+
+- `weight_preprocess_func`: a custom function applied to weights before they are quantized. Network parameters are not always suited to direct quantization, and preprocessing the parameter distribution before quantization may improve quantization accuracy.
+
+- `act_preprocess_func`: a custom function applied to activations before they are quantized. Activations are not always suited to direct quantization, and preprocessing them before quantization may improve quantization accuracy.
+
+- `optimizer_func`: a function that returns an optimizer. The returned optimizer is used to optimize the parameters created inside the custom functions above.
+
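+For comparison, here is a minimal sketch of a custom `act_quantize_func` (a hypothetical `act_fake_quant_dequant`, shown for illustration only and not part of this demo): a symmetric 8-bit abs-max fake quant-dequant whose input is the float activation and whose output is the dequantized activation.
+
+```
+import paddle.fluid as fluid
+
+def act_fake_quant_dequant(x, name=None):
+    # hypothetical example: per-tensor abs-max scale, quantize to the int8 range,
+    # then dequantize back; note that round() has no gradient, so a version used
+    # for training would also need a straight-through estimator
+    scale = fluid.layers.reduce_max(fluid.layers.abs(x)) + 1e-6
+    quantized = fluid.layers.round(x / scale * 127.0)
+    return quantized * scale / 127.0
+```
+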
+The PACT method is implemented as a custom `act_preprocess_func`; its input is the activation that is about to be quantized.
+
+It can be defined as follows:
+
+```
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.layer_helper import LayerHelper
+
+def pact(x, name=None):
+ helper = LayerHelper("pact", **locals())
+ dtype = 'float32'
+    # initial PACT clipping threshold
+ init_thres = 20
+ u_param_attr = fluid.ParamAttr(
+ name=x.name + '_pact',
+ initializer=fluid.initializer.ConstantInitializer(value=init_thres),
+ regularizer=fluid.regularizer.L2Decay(0.0001),
+ learning_rate=1)
+ u_param = helper.create_parameter(
+ attr=u_param_attr, shape=[1], dtype=dtype)
+ x = fluid.layers.elementwise_sub(
+ x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
+ x = fluid.layers.elementwise_add(
+ x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
+
+ return x
+```
+
+The function defines the initial threshold and the L2 regularization coefficient applied to it; during training, the threshold is then trained to a suitable value through gradient propagation.
+
+The optimizer function is as follows:
+
+```
+def get_optimizer():
+ return fluid.optimizer.MomentumOptimizer(0.001, 0.9)
+```
+Since all parameters other than the PACT thresholds are already trained, the learning rate of the PACT thresholds can be set somewhat larger during training.
+
+> Note: PACT only removes outliers during quantization, which affects the choice of the quantization scale. After training with PACT, the parameters can therefore be loaded and evaluated with the ordinary quantization method, so PACT does not affect inference.
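+
+As a minimal sketch (using the variables defined in the sections below; this is an illustration, not a script shipped with the demo), evaluating a PACT-trained checkpoint with the ordinary quantization method could look like:
+
+```
+# build the eval program without act_preprocess_func, then load the PACT-trained weights;
+# the extra '*_pact' threshold variables in the checkpoint directory are simply ignored
+val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
+fluid.io.load_persistables(
+    exe, dirname='./output/MobileNetV3_large_x1_0/best_model', main_program=val_program)
+```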
+
+## Quantization-aware training workflow for MobileNetV3
+
+### Preparing the data
+
+Create a ``data`` folder under the ``demo`` folder and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
+- the ``'train'`` folder with the training images
+- the ``'train_list.txt'`` file
+- the ``'val'`` folder with the validation images
+- the ``'val_list.txt'`` file
+
+### Preparing the model to be quantized
+
+We quantize the most accurate MobileNetV3 model provided by the PaddlePaddle classification library [PaddleClas](https://github.com/PaddlePaddle/PaddleClas). Its top-1 accuracy before quantization is 78.9%.
+
+```
+mkdir pretrain
+cd pretrain
+wget https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
+tar xf MobileNetV3_large_x1_0_ssld_pretrained.tar
+cd ..
+```
+
+This model is chosen because this ssld-distilled MobileNetV3 has many outliers in its activations, which makes it a good test of PACT's effect. The figure below is a histogram of the distribution of one of MobileNetV3's intermediate activations:
+
+![activation distribution](./image/activation_dist.png)
+
+The x-axis of the histogram spans the minimum and maximum of the activation distribution: the minimum is around -60 and the maximum around 80, but most values fall between -20 and 20.
+
+
+### Enabling the gradient of `image`
+
+Due to the current implementation, the gradient of `image` has to be enabled:
+
+```
+image.stop_gradient = False
+```
+
+### Configuring the quantization parameters
+
+```
+quant_config = {
+ 'weight_quantize_type': 'channel_wise_abs_max',
+ 'activation_quantize_type': 'moving_average_abs_max',
+ 'weight_bits': 8,
+ 'activation_bits': 8,
+ 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+ 'dtype': 'int8',
+ 'window_size': 10000,
+ 'moving_rate': 0.9
+}
+```
+
+### Inserting trainable quantization ops into the training and test programs
+
+Ordinary quantization:
+```
+val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
+
+compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
+```
+
+Quantization with PACT:
+```
+val_program = quant_aware(val_program, place, quant_config, scope=None, act_preprocess_func=pact, executor=exe, for_test=True)
+
+compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, act_preprocess_func=pact, optimizer_func=get_optimizer, executor=exe, for_test=False)
+```
+
+### Disabling certain build strategies
+
+```
+build_strategy = fluid.BuildStrategy()
+build_strategy.fuse_all_reduce_ops = False
+build_strategy.sync_batch_norm = False
+exec_strategy = fluid.ExecutionStrategy()
+compiled_train_prog = compiled_train_prog.with_data_parallel(
+ loss_name=avg_cost.name,
+ build_strategy=build_strategy,
+ exec_strategy=exec_strategy)
+```
+
+
+### Training commands
+
+Ordinary quantization:
+```
+python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact False
+
+```
+
+The output is:
+```
+2020-06-05 15:14:15,319-INFO: epoch[0]-batch[10] - loss: 2.50413322449; acc_top1: 0.515625; acc_top5: 0.75; time: 1.29066705704
+2020-06-05 15:14:28,950-INFO: epoch[0]-batch[20] - loss: 3.14219880104; acc_top1: 0.3828125; acc_top5: 0.62890625; time: 1.29546618462
+2020-06-05 15:14:42,479-INFO: epoch[0]-batch[30] - loss: 3.34660744667; acc_top1: 0.3671875; acc_top5: 0.609375; time: 1.20717287064
+2020-06-05 15:14:56,196-INFO: epoch[0]-batch[40] - loss: 3.69098854065; acc_top1: 0.2890625; acc_top5: 0.5546875; time: 1.29232215881
+2020-06-05 15:15:09,815-INFO: epoch[0]-batch[50] - loss: 3.5337202549; acc_top1: 0.30078125; acc_top5: 0.5546875; time: 1.34358000755
+2020-06-05 15:15:23,550-INFO: epoch[0]-batch[60] - loss: 3.22006082535; acc_top1: 0.359375; acc_top5: 0.609375; time: 1.34181118011
+2020-06-05 15:15:37,425-INFO: epoch[0]-batch[70] - loss: 3.06894540787; acc_top1: 0.4375; acc_top5: 0.65625; time: 1.33122491837
+2020-06-05 15:15:51,161-INFO: epoch[0]-batch[80] - loss: 3.00548839569; acc_top1: 0.3828125; acc_top5: 0.6328125; time: 1.27601099014
+2020-06-05 15:16:05,158-INFO: epoch[0]-batch[90] - loss: 2.52197813988; acc_top1: 0.484375; acc_top5: 0.71484375; time: 1.28280210495
+```
+Ordinary quantization-aware training is clearly unstable: the loss fluctuates, and by the second epoch it becomes NaN.
+
+Quantization-aware training with PACT:
+```
+python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
+```
+
+The output is:
+```
+2020-06-05 15:25:37,647-INFO: epoch[0]-batch[10] - loss: 1.60160636902; acc_top1: 0.65625; acc_top5: 0.890625; time: 1.56788897514
+2020-06-05 15:25:53,191-INFO: epoch[0]-batch[20] - loss: 1.4748904705; acc_top1: 0.6484375; acc_top5: 0.84375; time: 1.4936029911
+2020-06-05 15:26:08,598-INFO: epoch[0]-batch[30] - loss: 1.427333951; acc_top1: 0.6953125; acc_top5: 0.875; time: 1.51066279411
+2020-06-05 15:26:24,009-INFO: epoch[0]-batch[40] - loss: 1.43955898285; acc_top1: 0.6640625; acc_top5: 0.8671875; time: 1.49221611023
+2020-06-05 15:26:39,501-INFO: epoch[0]-batch[50] - loss: 1.29342699051; acc_top1: 0.6953125; acc_top5: 0.90625; time: 1.50851297379
+2020-06-05 15:26:54,927-INFO: epoch[0]-batch[60] - loss: 1.49478590488; acc_top1: 0.6171875; acc_top5: 0.875; time: 1.50131177902
+2020-06-05 15:27:10,250-INFO: epoch[0]-batch[70] - loss: 1.34970903397; acc_top1: 0.7109375; acc_top5: 0.890625; time: 1.51333618164
+2020-06-05 15:27:25,309-INFO: epoch[0]-batch[80] - loss: 1.51600492001; acc_top1: 0.6796875; acc_top5: 0.859375; time: 1.44952607155
+2020-06-05 15:27:40,273-INFO: epoch[0]-batch[90] - loss: 1.5926772356; acc_top1: 0.6328125; acc_top5: 0.859375; time: 1.45620679855
+2020-06-05 15:27:55,660-INFO: epoch[0]-batch[100] - loss: 1.40280032158; acc_top1: 0.671875; acc_top5: 0.875; time: 1.50846099854
+```
+The loss is much more stable, and in our experiments this yields a quantized model with 77.5% top-1 accuracy. Besides the settings in the command above, the initial `pact` threshold was set to 20. The quantized model can be downloaded from [this link](https://paddlemodels.bj.bcebos.com/PaddleSlim/mobilenetv3_pact_quant.tar).
diff --git a/demo/quant/pact_quant_aware/image/activation_dist.png b/demo/quant/pact_quant_aware/image/activation_dist.png
new file mode 100644
index 0000000000000000000000000000000000000000..9e133f8c6d9628d33410ce82d6ad4fa2233dd323
Binary files /dev/null and b/demo/quant/pact_quant_aware/image/activation_dist.png differ
diff --git a/demo/quant/pact_quant_aware/image/pact.png b/demo/quant/pact_quant_aware/image/pact.png
new file mode 100644
index 0000000000000000000000000000000000000000..86e0733fac37a968df73e24f1c9d2870be3e0988
Binary files /dev/null and b/demo/quant/pact_quant_aware/image/pact.png differ
diff --git a/demo/quant/pact_quant_aware/image/pact_our.png b/demo/quant/pact_quant_aware/image/pact_our.png
new file mode 100644
index 0000000000000000000000000000000000000000..62eefdb46bd634b832e0daf1d949dbe35d871406
Binary files /dev/null and b/demo/quant/pact_quant_aware/image/pact_our.png differ
diff --git a/demo/quant/pact_quant_aware/pact.py b/demo/quant/pact_quant_aware/pact.py
new file mode 100644
index 0000000000000000000000000000000000000000..26a2a5efd6e9b819db9b7134a62a1ac8c1fc296f
--- /dev/null
+++ b/demo/quant/pact_quant_aware/pact.py
@@ -0,0 +1,30 @@
+import sys
+import paddle
+import paddle.fluid as fluid
+from paddleslim.quant import quant_aware, convert
+import numpy as np
+
+from paddle.fluid.layer_helper import LayerHelper
+
+
+def pact(x, name=None):
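+    # PACT-style preprocessing: learn a clipping threshold u_param and clip the
+    # activation to [-u_param, u_param] before it reaches the fake-quant op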
+ helper = LayerHelper("pact", **locals())
+ dtype = 'float32'
+ init_thres = 20
+ u_param_attr = fluid.ParamAttr(
+ name=x.name + '_pact',
+ initializer=fluid.initializer.ConstantInitializer(value=init_thres),
+ regularizer=fluid.regularizer.L2Decay(0.0001),
+ learning_rate=1)
+ u_param = helper.create_parameter(
+ attr=u_param_attr, shape=[1], dtype=dtype)
+ x = fluid.layers.elementwise_sub(
+ x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
+ x = fluid.layers.elementwise_add(
+ x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
+
+ return x
+
+
+def get_optimizer():
+ return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
diff --git a/demo/quant/pact_quant_aware/train.py b/demo/quant/pact_quant_aware/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc8f0178e837b9416847ec5767e2dfd0cc9288b3
--- /dev/null
+++ b/demo/quant/pact_quant_aware/train.py
@@ -0,0 +1,345 @@
+import os
+import sys
+import logging
+import paddle
+import argparse
+import functools
+import math
+import time
+import numpy as np
+import paddle.fluid as fluid
+sys.path[0] = os.path.join(
+    os.path.dirname(__file__), os.path.pardir, os.path.pardir)
+from paddleslim.common import get_logger
+from paddleslim.analysis import flops
+from paddleslim.quant import quant_aware, quant_post, convert
+import models
+from utility import add_arguments, print_arguments
+from pact import *
+quantization_model_save_dir = './quantization_models/'
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 64 * 4,
+ "Minibatch size.")
+add_arg('use_gpu', bool, True,
+ "Whether to use GPU or not.")
+add_arg('model', str, "MobileNet",
+ "The target model.")
+add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained",
+        "Path to the pretrained model.")
+add_arg('lr', float, 0.0001,
+ "The learning rate used to fine-tune pruned model.")
+add_arg('lr_strategy', str, "piecewise_decay",
+ "The learning rate decay strategy.")
+add_arg('l2_decay', float, 3e-5,
+ "The l2_decay parameter.")
+add_arg('momentum_rate', float, 0.9,
+ "The value of momentum_rate.")
+add_arg('num_epochs', int, 1,
+ "The number of total epochs.")
+add_arg('total_images', int, 1281167,
+ "The number of total training images.")
+parser.add_argument('--step_epochs', nargs='+', type=int,
+ default=[30, 60, 90],
+ help="piecewise decay step")
+add_arg('config_file', str, None,
+ "The config file for compression with yaml format.")
+add_arg('data', str, "imagenet",
+ "Which data to use. 'mnist' or 'imagenet'")
+add_arg('log_period', int, 10,
+ "Log period in batches.")
+add_arg('checkpoint_dir', str, "output",
+ "checkpoint save dir")
+add_arg('use_pact', bool, True,
+ "Whether to use PACT or not.")
+
+# yapf: enable
+
+model_list = [m for m in dir(models) if "__" not in m]
+
+
+def piecewise_decay(args):
+ step = int(math.ceil(float(args.total_images) / args.batch_size))
+ bd = [step * e for e in args.step_epochs]
+ lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
+ learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=learning_rate,
+ momentum=args.momentum_rate,
+ regularization=fluid.regularizer.L2Decay(args.l2_decay))
+ return optimizer
+
+
+def cosine_decay(args):
+ step = int(math.ceil(float(args.total_images) / args.batch_size))
+ learning_rate = fluid.layers.cosine_decay(
+ learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=learning_rate,
+ momentum=args.momentum_rate,
+ regularization=fluid.regularizer.L2Decay(args.l2_decay))
+ return optimizer
+
+
+def create_optimizer(args):
+ if args.lr_strategy == "piecewise_decay":
+ return piecewise_decay(args)
+ elif args.lr_strategy == "cosine_decay":
+ return cosine_decay(args)
+
+
+def compress(args):
+ # 1. quantization configs
+ quant_config = {
+ # weight quantize type, default is 'channel_wise_abs_max'
+ 'weight_quantize_type': 'channel_wise_abs_max',
+ # activation quantize type, default is 'moving_average_abs_max'
+ 'activation_quantize_type': 'moving_average_abs_max',
+ # weight quantize bit num, default is 8
+ 'weight_bits': 8,
+ # activation quantize bit num, default is 8
+ 'activation_bits': 8,
+ # ops of name_scope in not_quant_pattern list, will not be quantized
+ 'not_quant_pattern': ['skip_quant'],
+ # ops of type in quantize_op_types, will be quantized
+ 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+ # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
+ 'dtype': 'int8',
+        # window size for 'range_abs_max' quantization. default is 10000
+ 'window_size': 10000,
+ # The decay coefficient of moving average, default is 0.9
+ 'moving_rate': 0.9,
+ }
+
+ train_reader = None
+ test_reader = None
+ if args.data == "mnist":
+ import paddle.dataset.mnist as reader
+ train_reader = reader.train()
+ val_reader = reader.test()
+ class_dim = 10
+ image_shape = "1,28,28"
+ elif args.data == "imagenet":
+ import imagenet_reader as reader
+ train_reader = reader.train()
+ val_reader = reader.val()
+ class_dim = 1000
+ image_shape = "3,224,224"
+ else:
+ raise ValueError("{} is not supported.".format(args.data))
+
+ image_shape = [int(m) for m in image_shape.split(",")]
+ assert args.model in model_list, "{} is not in lists: {}".format(
+ args.model, model_list)
+ image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+ if args.use_pact:
+ image.stop_gradient = False
+ label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+ # model definition
+ model = models.__dict__[args.model]()
+ out = model.net(input=image, class_dim=class_dim)
+ cost = fluid.layers.cross_entropy(input=out, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+ acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+
+ train_prog = fluid.default_main_program()
+ val_program = fluid.default_main_program().clone(for_test=True)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ opt = create_optimizer(args)
+ opt.minimize(avg_cost)
+
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+ # 2. quantization transform programs (training aware)
+ # Make some quantization transforms in the graph before training and testing.
+ # According to the weight and activation quantization type, the graph will be added
+ # some fake quantize operators and fake dequantize operators.
+
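+    # PACT introduces trainable clipping thresholds, so quant_aware needs an
+    # executor to initialize them and an optimizer_func to update them in training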
+ if args.use_pact:
+ act_preprocess_func = pact
+ optimizer_func = get_optimizer
+ executor = exe
+ else:
+ act_preprocess_func = None
+ optimizer_func = None
+ executor = None
+
+ val_program = quant_aware(
+ val_program,
+ place,
+ quant_config,
+ scope=None,
+ act_preprocess_func=act_preprocess_func,
+ optimizer_func=optimizer_func,
+ executor=executor,
+ for_test=True)
+ compiled_train_prog = quant_aware(
+ train_prog,
+ place,
+ quant_config,
+ scope=None,
+ act_preprocess_func=act_preprocess_func,
+ optimizer_func=optimizer_func,
+ executor=executor,
+ for_test=False)
+
+ assert os.path.exists(
+ args.pretrained_model), "pretrained_model doesn't exist"
+
+ if args.pretrained_model:
+
+ def if_exist(var):
+ return os.path.exists(
+ os.path.join(args.pretrained_model, var.name))
+
+ fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
+
+ val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
+ train_reader = paddle.fluid.io.batch(
+ train_reader, batch_size=args.batch_size, drop_last=True)
+
+    train_feeder = fluid.DataFeeder([image, label], place)
+    val_feeder = fluid.DataFeeder(
+        [image, label], place, program=val_program)
+
+ def test(epoch, program):
+ batch_id = 0
+ acc_top1_ns = []
+ acc_top5_ns = []
+ for data in val_reader():
+ start_time = time.time()
+ acc_top1_n, acc_top5_n = exe.run(
+ program,
+ feed=train_feeder.feed(data),
+ fetch_list=[acc_top1.name, acc_top5.name])
+ end_time = time.time()
+ if batch_id % args.log_period == 0:
+ _logger.info(
+ "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
+ format(epoch, batch_id,
+ np.mean(acc_top1_n),
+ np.mean(acc_top5_n), end_time - start_time))
+ acc_top1_ns.append(np.mean(acc_top1_n))
+ acc_top5_ns.append(np.mean(acc_top5_n))
+ batch_id += 1
+
+ _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
+ format(epoch,
+ np.mean(np.array(acc_top1_ns)),
+ np.mean(np.array(acc_top5_ns))))
+ return np.mean(np.array(acc_top1_ns))
+
+ def train(epoch, compiled_train_prog):
+
+ batch_id = 0
+ for data in train_reader():
+ start_time = time.time()
+ loss_n, acc_top1_n, acc_top5_n = exe.run(
+ compiled_train_prog,
+ feed=train_feeder.feed(data),
+ fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
+ end_time = time.time()
+ loss_n = np.mean(loss_n)
+ acc_top1_n = np.mean(acc_top1_n)
+ acc_top5_n = np.mean(acc_top5_n)
+ if batch_id % args.log_period == 0:
+ _logger.info(
+ "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
+ format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
+ end_time - start_time))
+
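+            # every 1000 batches, print the learned PACT clipping thresholds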
+ if args.use_pact and batch_id % 1000 == 0:
+ threshold = {}
+ for var in val_program.list_vars():
+ if 'pact' in var.name:
+ array = np.array(fluid.global_scope().find_var(
+ var.name).get_tensor())
+ threshold[var.name] = array[0]
+ print(threshold)
+
+ batch_id += 1
+
+ build_strategy = fluid.BuildStrategy()
+ build_strategy.memory_optimize = False
+ build_strategy.enable_inplace = False
+ build_strategy.fuse_all_reduce_ops = False
+ build_strategy.sync_batch_norm = False
+ exec_strategy = fluid.ExecutionStrategy()
+ compiled_train_prog = compiled_train_prog.with_data_parallel(
+ loss_name=avg_cost.name,
+ build_strategy=build_strategy,
+ exec_strategy=exec_strategy)
+
+ # train loop
+ best_acc1 = 0.0
+ best_epoch = 0
+ for i in range(args.num_epochs):
+ train(i, compiled_train_prog)
+ acc1 = test(i, val_program)
+ fluid.io.save_persistables(
+ exe,
+ dirname=os.path.join(args.checkpoint_dir, str(i)),
+ main_program=val_program)
+ if acc1 > best_acc1:
+ best_acc1 = acc1
+ best_epoch = i
+ fluid.io.save_persistables(
+ exe,
+ dirname=os.path.join(args.checkpoint_dir, 'best_model'),
+ main_program=val_program)
+ if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')):
+ fluid.io.load_persistables(
+ exe,
+ dirname=os.path.join(args.checkpoint_dir, 'best_model'),
+ main_program=val_program)
+ # 3. Freeze the graph after training by adjusting the quantize
+ # operators' order for the inference.
+ # The dtype of float_program's weights is float32, but in int8 range.
+ float_program, int8_program = convert(val_program, place, quant_config, \
+ scope=None, \
+ save_int8=True)
+ print("eval best_model after convert")
+ final_acc1 = test(best_epoch, float_program)
+ # 4. Save inference model
+ model_path = os.path.join(quantization_model_save_dir, args.model,
+ 'act_' + quant_config['activation_quantize_type']
+ + '_w_' + quant_config['weight_quantize_type'])
+ float_path = os.path.join(model_path, 'float')
+ int8_path = os.path.join(model_path, 'int8')
+ if not os.path.isdir(model_path):
+ os.makedirs(model_path)
+
+ fluid.io.save_inference_model(
+ dirname=float_path,
+ feeded_var_names=[image.name],
+ target_vars=[out],
+ executor=exe,
+ main_program=float_program,
+ model_filename=float_path + '/model',
+ params_filename=float_path + '/params')
+
+ fluid.io.save_inference_model(
+ dirname=int8_path,
+ feeded_var_names=[image.name],
+ target_vars=[out],
+ executor=exe,
+ main_program=int8_program,
+ model_filename=int8_path + '/model',
+ params_filename=int8_path + '/params')
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ compress(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/docs/zh_cn/api_cn/quantization_api.rst b/docs/zh_cn/api_cn/quantization_api.rst
index 83a34a80bfafd22066b91766a33721d4830af81c..1b3358b16cf772550990e2f824872f901bc2c4bb 100644
--- a/docs/zh_cn/api_cn/quantization_api.rst
+++ b/docs/zh_cn/api_cn/quantization_api.rst
@@ -23,7 +23,7 @@ quant_post_dynamic
.. py:function:: paddleslim.quant.quant_post_dynamic(model_dir, save_model_dir, model_filename=None, params_filename=None, save_model_filename=None, save_params_filename=None, quantizable_op_type=["conv2d", "mul"], weight_bits=8, generate_test_model=False)
-`源代码 `_
+`源代码 `_
动态离线量化,将模型中特定OP的权重从FP32类型量化成INT8/16类型。
@@ -99,7 +99,7 @@ quant_post_static
.. py:function:: paddleslim.quant.quant_post_static(executor,model_dir, quantize_model_path, batch_generator=None, sample_generator=None, model_filename=None, params_filename=None, save_model_filename='__model__', save_params_filename='__params__', batch_size=16, batch_nums=None, scope=None, algo='KL', quantizable_op_type=["conv2d","depthwise_conv2d","mul"], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', is_use_cache_file=False, cache_dir="./temp_post_training")
-`源代码 `_
+`源代码 `_
静态离线量化,使用少量校准数据计算量化因子,可以快速得到量化模型。使用该量化模型进行预测,可以减少计算量、降低计算内存、减小模型大小。
@@ -184,7 +184,7 @@ quant_post_static
batch_size=16,
batch_nums=10)
-更详细的用法请参考 `离线量化demo `_ 。
+更详细的用法请参考 `离线量化demo `_ 。
@@ -192,9 +192,9 @@ quant_post_static
quant_aware
------------
-.. py:function:: paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False)
+.. py:function:: paddleslim.quant.quant_aware(program, place, config, scope=None, for_test=False, weight_quantize_func=None, act_quantize_func=None, weight_preprocess_func=None, act_preprocess_func=None, optimizer_func=None, executor=None))
-`源代码 `_
+`源代码 `_
在 program 中加入量化和反量化op, 用于量化训练。
@@ -206,6 +206,14 @@ quant_aware
- **config(dict)** - 量化配置表。
- **scope(fluid.Scope, optional)** - 传入用于存储 ``Variable`` 的 ``scope`` ,需要传入 ``program`` 所使用的 ``scope`` ,一般情况下,是 `fluid.global_scope() `_ 。设置为 ``None`` 时将使用 `fluid.global_scope() `_ ,默认值为 ``None`` 。
- **for_test(bool)** - 如果 ``program`` 参数是一个测试 ``program`` , ``for_test`` 应设为True,否则设为False 。
+- **weight_quantize_func(function)** - A custom weight-quantization function. Its input is the weight to be quantized and its output is the dequantized weight, which makes it easy to quickly verify whether the quantization function works. When set, it replaces the method defined by `weight_quantize_type` in the quantization config; when not set, the `weight_quantize_type` method is used. Default: None.
+- **act_quantize_func(function)** - A custom activation-quantization function. Its input is the activation to be quantized and its output is the dequantized activation, which makes it easy to quickly verify whether the quantization function works. When set, it replaces the method defined by `activation_quantize_type` in the quantization config; when not set, the `activation_quantize_type` method is used. Default: None.
+- **weight_preprocess_func(function)** - A custom function applied to weights before they are quantized. Network parameters are not always suited to direct quantization, and preprocessing the parameter distribution before quantization may improve quantization accuracy. Default: None.
+- **act_preprocess_func(function)** - A custom function applied to activations before they are quantized. Activations are not always suited to direct quantization, and preprocessing them before quantization may improve quantization accuracy. Default: None.
+- **optimizer_func(function)** - A function that returns an optimizer. The returned optimizer is used to optimize the parameters created inside the custom functions above. Default: None.
+- **executor(fluid.Executor)** - Used to initialize the variables created inside the custom functions above. Default: None.
**返回**
@@ -230,7 +238,7 @@ convert
.. py:function:: paddleslim.quant.convert(program, place, config, scope=None, save_int8=False)
-`源代码 `_
+`源代码 `_
把训练好的量化 program ,转换为可用于保存 ``inference model`` 的 program 。
@@ -294,7 +302,7 @@ convert
inference_prog = quant.convert(quant_eval_program, place, config)
-更详细的用法请参考 `量化训练demo `_ 。
+更详细的用法请参考 `量化训练demo `_ 。
量化训练方法的参数配置
@@ -367,7 +375,7 @@ quant_embedding
.. py:function:: paddleslim.quant.quant_embedding(program, place, config=None, scope=None)
-`源代码 `_
+`源代码 `_
对 ``Embedding`` 参数进行量化。
@@ -418,6 +426,6 @@ fluid.Program
}
quant_program = quant.quant_embedding(infer_program, place, config)
-更详细的用法请参考 `Embedding量化demo `_
+更详细的用法请参考 `Embedding量化demo `_
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index b6b749d8bb02d6cff1cca1d9742182408be089d9..462f1c754a8d4da8e06af971f44bc1a73f5c1338 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -160,7 +160,17 @@ def _parse_configs(user_config):
return configs
-def quant_aware(program, place, config=None, scope=None, for_test=False):
+def quant_aware(program,
+ place,
+ config=None,
+ scope=None,
+ for_test=False,
+ weight_quantize_func=None,
+ act_quantize_func=None,
+ weight_preprocess_func=None,
+ act_preprocess_func=None,
+ optimizer_func=None,
+ executor=None):
"""Add quantization and dequantization operators to "program"
for quantization training or testing.
@@ -175,7 +185,32 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
`fluid.global_scope `_. When ``None`` will use `fluid.global_scope() `_ . Default: ``None``.
for_test(bool): If the 'program' parameter is a test program, this parameter should be set to ``True``.
Otherwise, set to ``False``.Default: False
-
+        weight_quantize_func(function): Function that defines how to quantize the weight. Using this,
+            users can quickly test whether their quantization method works. The function should
+            implement both quantization and dequantization: its input is the non-quantized weight
+            and it returns the dequantized weight. If None, the quantization op defined by
+            'weight_quantize_type' is used. Default is None.
+        act_quantize_func(function): Function that defines how to quantize the activation. Using this,
+            users can quickly test whether their quantization method works. The function should
+            implement both quantization and dequantization: its input is the non-quantized activation
+            and it returns the dequantized activation. If None, the quantization op defined by
+            'activation_quantize_type' is used. Default is None.
+        weight_preprocess_func(function): Function that defines how to preprocess the weight before
+            quantization. Using this, users can quickly test whether their preprocess method works.
+            Its input is the non-quantized weight and it returns the processed weight to be quantized.
+            If None, the weight is quantized directly. Default is None.
+        act_preprocess_func(function): Function that defines how to preprocess the activation before
+            quantization. Using this, users can quickly test whether their preprocess method works.
+            Its input is the non-quantized activation and it returns the processed activation to be
+            quantized. If None, the activation is quantized directly. Default is None.
+        optimizer_func(function): Function that returns an optimizer. It is used to optimize the
+            parameters introduced by the self-defined quantization and preprocess functions, and must
+            be set when 'for_test' is False and any of those functions is used. Default is None.
+        executor(fluid.Executor): Executor used to initialize the variables introduced by the
+            self-defined quantization and preprocess functions; it must be set when any of those
+            functions is used. Default is None.
Returns:
fluid.CompiledProgram | fluid.Program: Program with quantization and dequantization ``operators``
"""
@@ -208,7 +243,13 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=transform_pass_ops,
- skip_pattern=config['not_quant_pattern'])
+ skip_pattern=config['not_quant_pattern'],
+ weight_quantize_func=weight_quantize_func,
+ act_quantize_func=act_quantize_func,
+ weight_preprocess_func=weight_preprocess_func,
+ act_preprocess_func=act_preprocess_func,
+ optimizer_func=optimizer_func,
+ executor=executor)
transform_pass.apply(main_graph)
diff --git a/tests/test_quant_aware_user_defined.py b/tests/test_quant_aware_user_defined.py
new file mode 100644
index 0000000000000000000000000000000000000000..d59c2ee16a9c920d86265b8eb7f8ca7517e3c329
--- /dev/null
+++ b/tests/test_quant_aware_user_defined.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle
+import paddle.fluid as fluid
+from paddleslim.quant import quant_aware, convert
+sys.path.append("../demo")
+from models import MobileNet
+from layers import conv_bn_layer
+import paddle.dataset.mnist as reader
+from paddle.fluid.framework import IrGraph
+from paddle.fluid import core
+import numpy as np
+
+from paddle.fluid.layer_helper import LayerHelper
+
+
+def pact(x, name=None):
+ helper = LayerHelper("pact", **locals())
+ dtype = 'float32'
+ init_thres = 20
+ u_param_attr = fluid.ParamAttr(
+ name=x.name + '_pact',
+ initializer=fluid.initializer.ConstantInitializer(value=init_thres),
+ regularizer=fluid.regularizer.L2Decay(0.0001),
+ learning_rate=1)
+ u_param = helper.create_parameter(
+ attr=u_param_attr, shape=[1], dtype=dtype)
+ x = fluid.layers.elementwise_sub(
+ x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
+ x = fluid.layers.elementwise_add(
+ x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
+
+ return x
+
+
+def get_optimizer():
+ return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
+
+
+class TestQuantAwareCase1(unittest.TestCase):
+ def get_model(self):
+ image = fluid.layers.data(
+ name='image', shape=[1, 28, 28], dtype='float32')
+ label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+ model = MobileNet()
+ out = model.net(input=image, class_dim=10)
+ cost = fluid.layers.cross_entropy(input=out, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ startup_prog = fluid.default_startup_program()
+ train_prog = fluid.default_main_program()
+ return startup_prog, train_prog
+
+ def test_accuracy(self):
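+        # train an FP32 MobileNet on MNIST, then fine-tune it with PACT
+        # quant-aware training and compare accuracy before and after quantization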
+ image = fluid.layers.data(
+ name='image', shape=[1, 28, 28], dtype='float32')
+ image.stop_gradient = False
+ label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+ model = MobileNet()
+ out = model.net(input=image, class_dim=10)
+ cost = fluid.layers.cross_entropy(input=out, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+ acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+ optimizer = fluid.optimizer.Momentum(
+ momentum=0.9,
+ learning_rate=0.01,
+ regularization=fluid.regularizer.L2Decay(4e-5))
+ optimizer.minimize(avg_cost)
+ main_prog = fluid.default_main_program()
+ val_prog = main_prog.clone(for_test=True)
+
+ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
+ ) else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+ feeder = fluid.DataFeeder([image, label], place, program=main_prog)
+ train_reader = paddle.fluid.io.batch(
+ paddle.dataset.mnist.train(), batch_size=64)
+ eval_reader = paddle.fluid.io.batch(
+ paddle.dataset.mnist.test(), batch_size=64)
+
+ def train(program):
+ iter = 0
+ for data in train_reader():
+ cost, top1, top5 = exe.run(
+ program,
+ feed=feeder.feed(data),
+ fetch_list=[avg_cost, acc_top1, acc_top5])
+ iter += 1
+ if iter % 100 == 0:
+ print(
+ 'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
+ format(iter, cost, top1, top5))
+
+ def test(program):
+ iter = 0
+ result = [[], [], []]
+ for data in eval_reader():
+ cost, top1, top5 = exe.run(
+ program,
+ feed=feeder.feed(data),
+ fetch_list=[avg_cost, acc_top1, acc_top5])
+ iter += 1
+ if iter % 100 == 0:
+ print(
+ 'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
+ format(iter, cost, top1, top5))
+ result[0].append(cost)
+ result[1].append(top1)
+ result[2].append(top5)
+ print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
+ np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
+ return np.mean(result[1]), np.mean(result[2])
+
+ train(main_prog)
+ top1_1, top5_1 = test(main_prog)
+
+ config = {
+ 'weight_quantize_type': 'channel_wise_abs_max',
+ 'activation_quantize_type': 'moving_average_abs_max',
+ 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
+ }
+ quant_train_prog_pact = quant_aware(
+ main_prog,
+ place,
+ config,
+ for_test=False,
+ act_preprocess_func=pact,
+ optimizer_func=get_optimizer,
+ executor=exe)
+
+ quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
+ train(quant_train_prog_pact)
+ quant_eval_prog, int8_prog = convert(
+ quant_eval_prog, place, config, save_int8=True)
+ top1_2, top5_2 = test(quant_eval_prog)
+ # values before quantization and after quantization should be close
+ print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
+ print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
+
+
+if __name__ == '__main__':
+ unittest.main()