Commit de9eb717 authored by Liufang Sang, committed by whs

fix classification quantization (#3483)

Parent 5108c1c1
......@@ -58,33 +58,58 @@ class ResNet():
             pool_padding=1,
             pool_type='max')
-        for block in range(len(depth)):
-            for i in range(depth[block]):
-                if layers in [101, 152] and block == 2:
-                    if i == 0:
-                        conv_name = "res" + str(block + 2) + "a"
-                    else:
-                        conv_name = "res" + str(block + 2) + "b" + str(i)
-                else:
-                    conv_name = "res" + str(block + 2) + chr(97 + i)
-                conv_name = prefix_name + conv_name
-                conv = self.bottleneck_block(
-                    input=conv,
-                    num_filters=num_filters[block],
-                    stride=2 if i == 0 and block != 0 else 1,
-                    name=conv_name)
-        pool = fluid.layers.pool2d(
-            input=conv, pool_size=7, pool_type='avg', global_pooling=True)
-        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
-        fc_name = fc_name if fc_name is None else prefix_name + fc_name
-        out = fluid.layers.fc(input=pool,
-                              size=class_dim,
-                              act='softmax',
-                              name=fc_name,
-                              param_attr=fluid.param_attr.ParamAttr(
-                                  initializer=fluid.initializer.Uniform(-stdv,
-                                                                        stdv)))
+        if layers >= 50:
+            for block in range(len(depth)):
+                for i in range(depth[block]):
+                    if layers in [101, 152] and block == 2:
+                        if i == 0:
+                            conv_name = "res" + str(block + 2) + "a"
+                        else:
+                            conv_name = "res" + str(block + 2) + "b" + str(i)
+                    else:
+                        conv_name = "res" + str(block + 2) + chr(97 + i)
+                    conv_name = prefix_name + conv_name
+                    conv = self.bottleneck_block(
+                        input=conv,
+                        num_filters=num_filters[block],
+                        stride=2 if i == 0 and block != 0 else 1,
+                        name=conv_name)
+            pool = fluid.layers.pool2d(
+                input=conv, pool_size=7, pool_type='avg', global_pooling=True)
+            stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
+            fc_name = fc_name if fc_name is None else prefix_name + fc_name
+            out = fluid.layers.fc(input=pool,
+                                  size=class_dim,
+                                  act='softmax',
+                                  name=fc_name,
+                                  param_attr=fluid.param_attr.ParamAttr(
+                                      initializer=fluid.initializer.Uniform(
+                                          -stdv, stdv)))
+        else:
+            for block in range(len(depth)):
+                for i in range(depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    conv_name = prefix_name + conv_name
+                    conv = self.basic_block(
+                        input=conv,
+                        num_filters=num_filters[block],
+                        stride=2 if i == 0 and block != 0 else 1,
+                        is_first=block == i == 0,
+                        name=conv_name)
+            pool = fluid.layers.pool2d(
+                input=conv, pool_type='avg', global_pooling=True)
+            stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
+            fc_name = fc_name if fc_name is None else prefix_name + fc_name
+            out = fluid.layers.fc(
+                input=pool,
+                size=class_dim,
+                act='softmax',
+                name=fc_name,
+                param_attr=fluid.param_attr.ParamAttr(
+                    initializer=fluid.initializer.Uniform(-stdv, stdv)))
         return out

     def conv_bn_layer(self,
......@@ -126,9 +151,9 @@ class ResNet():
             moving_mean_name=bn_name + '_mean',
             moving_variance_name=bn_name + '_variance', )

-    def shortcut(self, input, ch_out, stride, name):
+    def shortcut(self, input, ch_out, stride, is_first, name):
         ch_in = input.shape[1]
-        if ch_in != ch_out or stride != 1:
+        if ch_in != ch_out or stride != 1 or is_first == True:
             return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
         else:
             return input
......@@ -155,10 +180,29 @@ class ResNet():
             name=name + "_branch2c")
         short = self.shortcut(
-            input, num_filters * 4, stride, name=name + "_branch1")
+            input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
         return fluid.layers.elementwise_add(
             x=short, y=conv2, act='relu', name=name + ".add.output.5")

+    def basic_block(self, input, num_filters, stride, is_first, name):
+        conv0 = self.conv_bn_layer(
+            input=input,
+            num_filters=num_filters,
+            filter_size=3,
+            act='relu',
+            stride=stride,
+            name=name + "_branch2a")
+        conv1 = self.conv_bn_layer(
+            input=conv0,
+            num_filters=num_filters,
+            filter_size=3,
+            act=None,
+            name=name + "_branch2b")
+        short = self.shortcut(
+            input, num_filters, stride, is_first, name=name + "_branch1")
+        return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')

 def ResNet34(prefix_name=''):
......
......@@ -64,22 +64,10 @@ The PaddlePaddle framework has four quantization-related IrPasses, namely QuantizationTra
 > Note: the settings in the configuration file are not saved in checkpoints, so changes made to the configuration file before a restart will take effect.

-## Evaluation
-If `checkpoint_path` is set in the configuration file, a quantized model for evaluation is saved after each epoch
-under `${checkpoint_path}/${epoch_id}/eval_model/`, consisting of two files: `__model__` and `__params__`.
-`__model__` stores the model structure and `__params__` stores the parameters; the model structure is the same as during training.
-If you do not need to save an evaluation model, set the `save_eval_model` option to False (the default is True) when constructing the Compressor object.
-The script <a href="../eval.py">PaddleSlim/classification/eval.py</a> shows how to evaluate this model on the evaluation dataset.
-## Inference
-If `float_model_save_path`, `int8_model_save_path`, or `mobile_model_save_path` is set in the quantization strategy of the configuration file, models for inference after quantization compression are saved when training ends. The differences between these three inference models are described below.
-### float inference model
+### Saving models for evaluation and inference
+If `float_model_save_path`, `int8_model_save_path`, or `mobile_model_save_path` is set in the quantization strategy of the configuration file, models for evaluation and inference after quantization compression are saved when training ends. The differences between these three models are described below.
+#### float model
 As introduced in the section on the model structure during quantization training, the PaddlePaddle framework has four quantization-related IrPasses: QuantizationTransformPass, QuantizationFreezePass, ConvertToInt8Pass, and TransformForMobilePass. The float inference model is the model saved after applying QuantizationFreezePass and removing the redundant operators from eval_program.
 QuantizationFreezePass mainly changes the order of the quantize and dequantize ops in the IrGraph, i.e. it rearranges the layout of Figure 1 into that of Figure 2. In addition, QuantizationFreezePass quantizes the weights of operators such as `conv2d`, `depthwise_conv2d`, and `mul` offline to values within the int8_t range (while keeping the data type float32), which reduces the weight-quantization work during inference, as shown in Figure 2:
......@@ -89,7 +77,7 @@ QuantizationFreezePass mainly changes the order of the quantize and dequantize ops in the IrGraph
 <strong>Figure 2: Result after applying QuantizationFreezePass</strong>
 </p>
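For concreteness, below is a condensed sketch of how the freeze.py script added in this commit applies this pass; `val_program` and `place` are assumed to be already set up by loading the evaluation model, as freeze.py does:
```
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass

# val_program: the eval program loaded from ${checkpoint_path}/${epoch_id}/eval_model/
test_graph = IrGraph(core.Graph(val_program.desc), for_test=True)
freeze_pass = QuantizationFreezePass(
    scope=fluid.global_scope(),
    place=place,
    weight_quantize_type='abs_max')
freeze_pass.apply(test_graph)
float_program = test_graph.to_program()  # this program is saved as the float model
```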
-### int8 inference model
+#### int8 model
 After QuantizationFreezePass has been applied to the training network, ConvertToInt8Pass is executed.
 Its main purpose is to change the type of the weights output by QuantizationFreezePass from `FP32` to `INT8`. In other words, users can choose to save the quantized weights either as float32 (skip ConvertToInt8Pass) or as int8_t (run ConvertToInt8Pass), as shown in Figure 3:
......@@ -98,7 +86,7 @@ QuantizationFreezePass mainly changes the order of the quantize and dequantize ops in the IrGraph
 <strong>Figure 3: Result after applying ConvertToInt8Pass</strong>
 </p>
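Continuing the sketch above, again condensed from this commit's freeze.py and applied to the same `test_graph` after QuantizationFreezePass has run:
```
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass

# Re-save the quantized weights as int8_t instead of float32.
convert_int8_pass = ConvertToInt8Pass(scope=fluid.global_scope(), place=place)
convert_int8_pass.apply(test_graph)
int8_program = test_graph.to_program()  # this program is saved as the int8 model
```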
-### mobile inference model
+#### mobile model
 After conversion by TransformForMobilePass, users obtain a quantized model compatible with the [paddle-lite](https://github.com/PaddlePaddle/Paddle-Lite) mobile inference library. In paddle-mobile, the quantize and dequantize ops are named `quantize` and `dequantize`. The `quantize` op is similar in function to the `fake_quantize_abs_max` op family in the PaddlePaddle framework, and the `dequantize` op is functionally identical to the `fake_dequantize_max_abs` op family. To run a quantization-trained model with paddle-mobile, ops such as `fake_quantize_abs_max` must be replaced with `quantize`, and ops such as `fake_dequantize_max_abs` with `dequantize`, as shown in Figure 4:
 <p align="center">
......@@ -106,6 +94,30 @@ QuantizationFreezePass mainly changes the order of the quantize and dequantize ops in the IrGraph
 <strong>Figure 4: Result after applying TransformForMobilePass</strong>
 </p>
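The last step of the same sketch (as in this commit's freeze.py, TransformForMobilePass takes no constructor arguments):
```
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass

# Rename the fake_quantize_*/fake_dequantize_* ops to paddle-lite's
# quantize/dequantize ops.
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
mobile_program = test_graph.to_program()  # this program is saved as the mobile model
```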
+## Evaluation
+### Evaluation model saved every epoch
+Because the final quantized model is saved only once, at end_epoch, there is no guarantee that the saved model is the best one. Therefore,
+if `checkpoint_path` is set in the configuration file, a quantized model for evaluation is saved after each epoch
+under `${checkpoint_path}/${epoch_id}/eval_model/`, consisting of two files: `__model__` and `__params__`.
+`__model__` stores the model structure and `__params__` stores the parameters; the model structure is the same as during training.
+If you do not need to save an evaluation model, set the `save_eval_model` option to False (the default is True) when constructing the Compressor object.
+The script <a href="../eval.py">PaddleSlim/classification/eval.py</a> shows how to evaluate this model on the evaluation dataset.
+After evaluation, pick the model from the best-performing epoch and use the script <a href='./freeze.py'>PaddleSlim/classification/freeze.py</a> to convert it into the three models introduced above (float model, int8 model, and mobile model), as in the example below. The parameters to configure are:
+- model_path: the path of the model to load, i.e. `${checkpoint_path}/${epoch_id}/eval_model/`
+- weight_quant_type: the quantization type of the model weights; keep it consistent with the type in the configuration file
+- save_path: the save path of the `float`, `int8`, and `mobile` models, which are written to `${save_path}/float/`, `${save_path}/int8/`, and `${save_path}/mobile/` respectively
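A hypothetical invocation might look like this (the epoch id 29 is only an example; substitute whichever epoch scored best in your evaluation — the paths follow the resnet34 config in this commit):
```
python freeze.py \
    --model_path ./checkpoints/resnet34/29/eval_model/ \
    --weight_quant_type abs_max \
    --save_path ./output/resnet34
```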
+### Final evaluation model
+The final evaluation model is the float model; the script <a href="../eval.py">PaddleSlim/classification/eval.py</a> shows how to evaluate it on the evaluation dataset.
+## Inference
+### Python inference
+The float inference model can be used directly with the native PaddlePaddle Fluid inference APIs, for example:
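A minimal sketch, assuming the float model was saved by freeze.py above (which uses the file names `model` and `params`) and the standard 1x3x224x224 ImageNet input:
```
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# Load the float inference model written by freeze.py.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    './output/resnet34/float', exe,
    model_filename='model', params_filename='params')
# Run one dummy batch through the network.
image = np.random.rand(1, 3, 224, 224).astype('float32')
outputs = exe.run(program,
                  feed={feed_names[0]: image},
                  fetch_list=fetch_targets)
print(outputs[0].shape)  # (1, class_dim) softmax probabilities
```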
......@@ -139,7 +151,7 @@ fluid.optimizer.Momentum(momentum=0.9,
         values=[0.0001, 0.00001]),
     regularization=fluid.regularizer.L2Decay(1e-4))
 ```
-batch size 1024
+8 GPUs, batch size 1024, 30 epochs, best result selected
### MobileNetV2
......@@ -171,6 +183,7 @@ fluid.optimizer.Momentum(momentum=0.9,
         values=[0.0001, 0.00001]),
     regularization=fluid.regularizer.L2Decay(1e-4))
 ```
-batch size 1024
+8 GPUs, batch size 1024, 30 epochs, best result selected
## FAQ
......@@ -53,12 +53,12 @@ def compress(args):

     val_program = fluid.default_main_program().clone()
     # quantization usually uses a small learning rate
-    values = [1e-4, 1e-5, 1e-6]
+    values = [1e-4, 1e-5]
     opt = fluid.optimizer.Momentum(
         momentum=0.9,
         learning_rate=fluid.layers.piecewise_decay(
-            boundaries=[5000 * 30, 5000 * 60], values=values),
-        regularization=fluid.regularizer.L2Decay(4e-5))
+            boundaries=[5000 * 12], values=values),
+        regularization=fluid.regularizer.L2Decay(1e-4))

     place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
......
......@@ -3,7 +3,7 @@ strategies:
     quantization_strategy:
         class: 'QuantizationStrategy'
         start_epoch: 0
-        end_epoch: 0
+        end_epoch: 29
         float_model_save_path: './output/mobilenet_v1/float'
         mobile_model_save_path: './output/mobilenet_v1/mobile'
         int8_model_save_path: './output/mobilenet_v1/int8'
......@@ -14,7 +14,7 @@ strategies:
         save_in_nodes: ['image']
         save_out_nodes: ['fc_0.tmp_2']
 compressor:
-    epoch: 1
+    epoch: 30
     checkpoint_path: './checkpoints/mobilenet_v1/'
     strategies:
         - quantization_strategy
......@@ -3,7 +3,7 @@ strategies:
     quantization_strategy:
         class: 'QuantizationStrategy'
         start_epoch: 0
-        end_epoch: 0
+        end_epoch: 29
         float_model_save_path: './output/mobilenet_v2/float'
         mobile_model_save_path: './output/mobilenet_v2/mobile'
         int8_model_save_path: './output/mobilenet_v2/int8'
......@@ -14,7 +14,7 @@ strategies:
         save_in_nodes: ['image']
         save_out_nodes: ['fc_0.tmp_2']
 compressor:
-    epoch: 1
+    epoch: 30
     checkpoint_path: './checkpoints/mobilenet_v2/'
     strategies:
         - quantization_strategy
......@@ -3,10 +3,10 @@ strategies:
     quantization_strategy:
         class: 'QuantizationStrategy'
         start_epoch: 0
-        end_epoch: 0
-        float_model_save_path: './output/resnet50/float'
-        mobile_model_save_path: './output/resnet50/mobile'
-        int8_model_save_path: './output/resnet50/int8'
+        end_epoch: 29
+        float_model_save_path: './output/resnet34/float'
+        mobile_model_save_path: './output/resnet34/mobile'
+        int8_model_save_path: './output/resnet34/int8'
         weight_bits: 8
         activation_bits: 8
         weight_quantize_type: 'abs_max'
......@@ -14,7 +14,7 @@ strategies:
         save_in_nodes: ['image']
         save_out_nodes: ['fc_0.tmp_2']
 compressor:
-    epoch: 2
-    checkpoint_path: './checkpoints/resnet50/'
+    epoch: 30
+    checkpoint_path: './checkpoints/resnet34/'
     strategies:
         - quantization_strategy
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import numpy as np
import argparse
import functools
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass

sys.path.append("..")
import imagenet_reader as reader
sys.path.append("../../")
from utility import add_arguments, print_arguments

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu',           bool, True,       "Whether to use GPU or not.")
add_arg('model_path',        str,  "./pruning/checkpoints/resnet50/2/eval_model/", "Path of the evaluation model to load.")
add_arg('save_path',         str,  './output', "Path to save the inference models.")
add_arg('weight_quant_type', str,  'abs_max',  "Quantization type for weights.")
# yapf: enable


def eval(args):
    # parameters from arguments
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the quantization-aware evaluation model saved during training.
    val_program, feed_names, fetch_targets = fluid.io.load_inference_model(
        args.model_path,
        exe,
        model_filename="__model__",
        params_filename="__params__")
    val_reader = paddle.batch(reader.val(), batch_size=128)
    feeder = fluid.DataFeeder(
        place=place, feed_list=feed_names, program=val_program)

    # Evaluate the loaded model on the validation set.
    results = []
    for batch_id, data in enumerate(val_reader()):
        # top1_acc, top5_acc
        result = exe.run(val_program,
                         feed=feeder.feed(data),
                         fetch_list=fetch_targets)
        result = [np.mean(r) for r in result]
        results.append(result)
    result = np.mean(np.array(results), axis=0)
    print("top1_acc/top5_acc= {}".format(result))
    sys.stdout.flush()

    _logger.info("freeze the graph for inference")
    test_graph = IrGraph(core.Graph(val_program.desc), for_test=True)

    freeze_pass = QuantizationFreezePass(
        scope=fluid.global_scope(),
        place=place,
        weight_quantize_type=args.weight_quant_type)
    freeze_pass.apply(test_graph)
    server_program = test_graph.to_program()
    fluid.io.save_inference_model(
        dirname=os.path.join(args.save_path, 'float'),
        feeded_var_names=feed_names,
        target_vars=fetch_targets,
        executor=exe,
        main_program=server_program,
        model_filename='model',
        params_filename='params')

    _logger.info("convert the weights into int8 type")
    convert_int8_pass = ConvertToInt8Pass(
        scope=fluid.global_scope(), place=place)
    convert_int8_pass.apply(test_graph)
    server_int8_program = test_graph.to_program()
    fluid.io.save_inference_model(
        dirname=os.path.join(args.save_path, 'int8'),
        feeded_var_names=feed_names,
        target_vars=fetch_targets,
        executor=exe,
        main_program=server_int8_program,
        model_filename='model',
        params_filename='params')

    _logger.info("transform the frozen graph for paddle-lite execution")
    mobile_pass = TransformForMobilePass()
    mobile_pass.apply(test_graph)
    mobile_program = test_graph.to_program()
    fluid.io.save_inference_model(
        dirname=os.path.join(args.save_path, 'mobile'),
        feeded_var_names=feed_names,
        target_vars=fetch_targets,
        executor=exe,
        main_program=mobile_program,
        model_filename='model',
        params_filename='params')


def main():
    args = parser.parse_args()
    print_arguments(args)
    eval(args)


if __name__ == '__main__':
    main()
......@@ -4,7 +4,7 @@
 root_url="http://paddle-imagenet-models-name.bj.bcebos.com"
 MobileNetV1="MobileNetV1_pretrained.tar"
 MobileNetV2="MobileNetV2_pretrained.tar"
-ResNet50="ResNet50_pretrained.tar"
+ResNet34="ResNet34_pretrained.tar"
 pretrain_dir='../pretrain'

 if [ ! -d ${pretrain_dir} ]; then
......@@ -23,9 +23,9 @@ if [ ! -f ${MobileNetV2} ]; then
     tar xf ${MobileNetV2}
 fi

-if [ ! -f ${ResNet50} ]; then
-    wget ${root_url}/${ResNet50}
-    tar xf ${ResNet50}
+if [ ! -f ${ResNet34} ]; then
+    wget ${root_url}/${ResNet34}
+    tar xf ${ResNet34}
 fi

 cd -
......@@ -37,14 +37,14 @@ export FLAGS_eager_delete_tensor_gb=0.0
 export CUDA_VISIBLE_DEVICES=0

 ## for quantization of mobilenet_v1
-python -u compress.py \
-    --model "MobileNet" \
-    --use_gpu 1 \
-    --batch_size 32 \
-    --pretrained_model ../pretrain/MobileNetV1_pretrained \
-    --config_file "./configs/mobilenet_v1.yaml" \
-    > mobilenet_v1.log 2>&1 &
-tailf mobilenet_v1.log
+#python -u compress.py \
+#    --model "MobileNet" \
+#    --use_gpu 1 \
+#    --batch_size 256 \
+#    --pretrained_model ../pretrain/MobileNetV1_pretrained \
+#    --config_file "./configs/mobilenet_v1.yaml" \
+#> mobilenet_v1.log 2>&1 &
+#tailf mobilenet_v1.log

 ## for quantization of mobilenet_v2
 #python -u compress.py \
......@@ -56,12 +56,12 @@ tailf mobilenet_v1.log
 # > mobilenet_v2.log 2>&1 &
 #tailf mobilenet_v2.log

-# for compression of resnet50
-#python -u compress.py \
-#    --model "ResNet50" \
-#    --use_gpu 1 \
-#    --batch_size 32 \
-#    --pretrained_model ../pretrain/ResNet50_pretrained \
-#    --config_file "./configs/resnet50.yaml" \
-# > resnet50.log 2>&1 &
-#tailf resnet50.log
+# for compression of resnet34
+python -u compress.py \
+    --model "ResNet34" \
+    --use_gpu 1 \
+    --batch_size 32 \
+    --pretrained_model ../pretrain/ResNet34_pretrained \
+    --config_file "./configs/resnet34.yaml" \
+    > resnet34.log 2>&1 &
+tailf resnet34.log