# Quantization-Aware Training (Online Quantization) Demo
This demo shows how to quantize a trained classification model with the quantization-aware training (online quantization) APIs, which reduces the model's storage size and GPU memory usage.
## API Overview
```
quant_config_default = {
    'weight_quantize_type': 'abs_max',
    'activation_quantize_type': 'abs_max',
    'weight_bits': 8,
    'activation_bits': 8,
    # ops whose name_scope contains a string in the not_quant_pattern list
    # will not be quantized
    'not_quant_pattern': ['skip_quant'],
    # ops of the types listed in quantize_op_types will be quantized
    'quantize_op_types':
    ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
    # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
    'dtype': 'int8',
    # window size for 'range_abs_max' quantization. default is 10000
    'window_size': 10000,
    # the decay coefficient of moving average, default is 0.9
    'moving_rate': 0.9,
    # if quant_weight_only is True, only the weights of the layers to be
    # quantized are quantized, and activations are left unquantized
    'quant_weight_only': False
}
```
The default quantization configuration.
Parameters:
- weight_quantize_type (str): quantization method for weights. One of 'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'. Default: 'abs_max'.
- activation_quantize_type (str): quantization method for activations. One of 'abs_max', 'range_abs_max', 'moving_average_abs_max'. Default: 'abs_max'.
- weight_bits (int): number of bits used to quantize weights. Default: 8.
- activation_bits (int): number of bits used to quantize activations. Default: 8.
- not_quant_pattern (str or list of str): any op whose name_scope contains one of these strings is left unquantized (see the sketch after this list).
- quantize_op_types (list of str): op types to quantize. Currently 'conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', and 'pool2d' are supported.
- dtype (str): data type of the quantized parameters. Default: 'int8'.
- window_size (int): window size for 'range_abs_max' quantization. Default: 10000.
- moving_rate (float): decay coefficient for 'moving_average_abs_max' quantization. Default: 0.9.
- quant_weight_only (bool): whether to quantize weights only. If True, activations are not quantized. Default: False.
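For example, to keep a specific layer out of quantization, build it inside a `name_scope` that matches an entry in `not_quant_pattern`. A minimal sketch (the variable names and layer here are hypothetical):
```
import paddle.fluid as fluid

image = fluid.layers.data(name='image', shape=[3, 224, 224], dtype='float32')
with fluid.name_scope('skip_quant'):
    # ops created under this name_scope match 'skip_quant' in
    # not_quant_pattern, so quant_aware leaves them unquantized
    feat = fluid.layers.conv2d(image, num_filters=16, filter_size=3)
```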
```
def quant_aware(program,
                place,
                config,
                scope=None,
                for_test=False)
```
This API inserts trainable quantization ops into the given program.
Parameters:
- program (fluid.Program): the training or test program.
- place (fluid.CPUPlace or fluid.CUDAPlace): the device on which the Executor runs.
- config (dict): the quantization configuration.
- scope (fluid.Scope): the scope storing the variables; it must be the scope used by the program, usually fluid.global_scope().
- for_test (bool): set to True if the given program is a test program, otherwise False.
Returns:
- program (fluid.Program): the program with quantization ops inserted.
Note: if for_test is False, the returned program is a fluid.CompiledProgram.
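A minimal sketch of this distinction, assuming `train_prog`, `val_program`, `place`, and `quant_config` are defined as in the workflow below:
```
quant_train_prog = quant_aware(train_prog, place, quant_config, for_test=False)
quant_val_prog = quant_aware(val_program, place, quant_config, for_test=True)
# with for_test=False the result is a fluid.CompiledProgram: it can be run by
# an Executor, but it cannot be cloned or saved directly
assert isinstance(quant_train_prog, fluid.CompiledProgram)
assert isinstance(quant_val_prog, fluid.Program)
```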
```
def convert(program,
            place,
            config,
            scope=None,
            save_int8=False)
```
Converts a trained quantization program into a program that can be saved as an inference model.
Note: the returned program cannot be used for training.
Parameters:
- program (fluid.Program): the test program.
- place (fluid.CPUPlace or fluid.CUDAPlace): the device on which the Executor runs.
- config (dict): the quantization configuration.
- scope (fluid.Scope): the scope storing the variables; it must be the scope used by the program, usually fluid.global_scope().
- save_int8 (bool): whether to also return a program whose parameters are int8. (Currently this is only useful for checking model size.)
Returns:
- program (fluid.Program): the frozen program, which can be saved as an inference model; its parameters are float32, but their values fit in the int8 range.
- int8_program (fluid.Program): the frozen program with int8 parameters, which can also be saved as an inference model.
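Since `int8_program` is currently only useful for checking model size, a quick on-disk comparison after step 5 below could look like the following sketch (assuming the two models were saved to `float_path` and `int8_path`):
```
import os

def dir_size_mb(path):
    # total size of all files under `path`, in MB
    total = sum(
        os.path.getsize(os.path.join(root, f))
        for root, _, files in os.walk(path) for f in files)
    return total / (1024.0 * 1024.0)

print('float model: %.2f MB' % dir_size_mb(float_path))
print('int8 model:  %.2f MB' % dir_size_mb(int8_path))
```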
## Quantization-Aware Training Workflow for a Classification Model
### 1. Configure the quantization parameters
```
quant_config = {
    'weight_quantize_type': 'abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'weight_bits': 8,
    'activation_bits': 8,
    'not_quant_pattern': ['skip_quant'],
    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
    'dtype': 'int8',
    'window_size': 10000,
    'moving_rate': 0.9,
    'quant_weight_only': False
}
```
### 2. Insert trainable quantization ops into the training and test programs
```
val_program = quant_aware(val_program, place, quant_config, scope=None, for_test=True)
compiled_train_prog = quant_aware(train_prog, place, quant_config, scope=None, for_test=False)
```
### 3. Disable certain build strategies
The following build strategies are not compatible with the inserted quantization ops, so turn them off before compiling the program with `with_data_parallel`.
```
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel(
    loss_name=avg_cost.name,
    build_strategy=build_strategy,
    exec_strategy=exec_strategy)
```
### 4. Freeze the program
```
float_program, int8_program = convert(val_program,
                                      place,
                                      quant_config,
                                      scope=None,
                                      save_int8=True)
```
### 5. Save the inference model
```
fluid.io.save_inference_model(
    dirname=float_path,
    feeded_var_names=[image.name],
    target_vars=[out],
    executor=exe,
    main_program=float_program,
    model_filename=float_path + '/model',
    params_filename=float_path + '/params')
fluid.io.save_inference_model(
    dirname=int8_path,
    feeded_var_names=[image.name],
    target_vars=[out],
    executor=exe,
    main_program=int8_program,
    model_filename=int8_path + '/model',
    params_filename=int8_path + '/params')
```
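To sanity-check the saved model, it can be loaded back with `fluid.io.load_inference_model` (a sketch; `exe` and `float_path` are the same as above, and `input_data` is a hypothetical float32 batch matching the model's input shape):
```
inference_program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname=float_path,
    executor=exe,
    model_filename=float_path + '/model',
    params_filename=float_path + '/params')
results = exe.run(inference_program,
                  feed={feed_names[0]: input_data},
                  fetch_list=fetch_targets)
```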
## Full Demo Script
The complete quantization-aware training script used in this demo:
```
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
# make paddleslim and the demo utilities importable when running from this directory
sys.path.append(os.path.join(sys.path[0], "../../../"))
sys.path.append(os.path.join(sys.path[0], "../../"))
from paddleslim.common import get_logger
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments

quantization_model_save_dir = './quantization_models/'
_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,   64 * 4,            "Minibatch size.")
add_arg('use_gpu',          bool,  True,              "Whether to use GPU or not.")
add_arg('model',            str,   "MobileNet",       "The target model.")
add_arg('pretrained_model', str,   "../pretrained_model/MobileNetV1_pretrained", "The path of the pretrained model.")
add_arg('lr',               float, 0.0001,            "The learning rate used to fine-tune the quantized model.")
add_arg('lr_strategy',      str,   "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay',         float, 3e-5,              "The l2_decay parameter.")
add_arg('momentum_rate',    float, 0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,   1,                 "The number of total epochs.")
add_arg('total_images',     int,   1281167,           "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file',      str,   None,              "The config file for compression with yaml format.")
add_arg('data',             str,   "imagenet",        "Which dataset to use: 'mnist' or 'imagenet'.")
add_arg('log_period',       int,   10,                "Log period in batches.")
add_arg('test_period',      int,   10,                "Test period in epochs.")
# yapf: enable

model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
    step = int(math.ceil(float(args.total_images) / args.batch_size))
    bd = [step * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
    return optimizer


def cosine_decay(args):
    step = int(math.ceil(float(args.total_images) / args.batch_size))
    learning_rate = fluid.layers.cosine_decay(
        learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
    return optimizer


def create_optimizer(args):
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)


def compress(args):
    ############################################################################
    # 1. quantization configs
    ############################################################################
    quant_config = {
        # weight quantize type, default is 'abs_max'
        'weight_quantize_type': 'abs_max',
        # activation quantize type, default is 'abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops whose name_scope contains a string in not_quant_pattern
        # will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of the types in quantize_op_types will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # the decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
        # if quant_weight_only is True, only the weights of the layers to be
        # quantized are quantized, and activations are left unquantized
        'quant_weight_only': False
    }

    train_reader = None
    test_reader = None
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

    train_prog = fluid.default_main_program()
    # the test program must be cloned before quant_aware is applied, so that
    # each program gets its own quantization ops
    val_program = fluid.default_main_program().clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    ############################################################################
    # 2. quantization transform programs (training aware)
    #    Apply the quantization transforms to the graph before training and
    #    testing. According to the weight and activation quantization types,
    #    fake quantize and fake dequantize operators are inserted into the graph.
    ############################################################################
    val_program = quant_aware(
        val_program, place, quant_config, scope=None, for_test=True)
    compiled_train_prog = quant_aware(
        train_prog, place, quant_config, scope=None, for_test=False)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if args.pretrained_model:

        def if_exist(var):
            return os.path.exists(
                os.path.join(args.pretrained_model, var.name))

        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    def test(epoch, program):
        batch_id = 0
        acc_top1_ns = []
        acc_top5_ns = []
        for data in val_reader():
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))
            batch_id += 1
        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
                     format(epoch,
                            np.mean(np.array(acc_top1_ns)),
                            np.mean(np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    # disable build strategies that conflict with the inserted quantization
    # ops, and compile the train program once before the epoch loop
    # (with_data_parallel may only be called once per compiled program)
    build_strategy = fluid.BuildStrategy()
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = False
    build_strategy.fuse_all_reduce_ops = False
    build_strategy.sync_batch_norm = False
    exec_strategy = fluid.ExecutionStrategy()
    compiled_train_prog = compiled_train_prog.with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    def train(epoch, compiled_train_prog):
        batch_id = 0
        for data in train_reader():
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                compiled_train_prog,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))
            batch_id += 1

    ############################################################################
    # train loop
    ############################################################################
    for i in range(args.num_epochs):
        train(i, compiled_train_prog)
        if i % args.test_period == 0:
            test(i, val_program)

    ############################################################################
    # 3. Freeze the graph after training by adjusting the quantize
    #    operators' order for inference.
    #    The weights of float_program are float32, but their values are
    #    within the int8 range.
    ############################################################################
    float_program, int8_program = convert(
        val_program, place, quant_config, scope=None, save_int8=True)

    ############################################################################
    # 4. Save inference model
    ############################################################################
    model_path = os.path.join(
        quantization_model_save_dir, args.model,
        'act_' + quant_config['activation_quantize_type'] + '_w_' +
        quant_config['weight_quantize_type'])
    float_path = os.path.join(model_path, 'float')
    int8_path = os.path.join(model_path, 'int8')
    if not os.path.isdir(model_path):
        os.makedirs(model_path)

    fluid.io.save_inference_model(
        dirname=float_path,
        feeded_var_names=[image.name],
        target_vars=[out],
        executor=exe,
        main_program=float_program,
        model_filename=float_path + '/model',
        params_filename=float_path + '/params')

    fluid.io.save_inference_model(
        dirname=int8_path,
        feeded_var_names=[image.name],
        target_vars=[out],
        executor=exe,
        main_program=int8_program,
        model_filename=int8_path + '/model',
        params_filename=int8_path + '/params')


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()
```