Unverified commit 372d4d14, authored by qingqing01, committed by GitHub

Quantization-aware training for SSD. (#1409)

* Quantization-aware training for SSD.
* Clean code and refine doc.
Parent 88098fff
Running the examples in this directory requires the latest develop branch of PaddlePaddle. If your installed version of PaddlePaddle is older than this, please follow the instructions in the [installation documentation](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html) to update your PaddlePaddle installation.
---
## Pyramidbox Face Detection
## Table of Contents
......
......@@ -21,3 +21,4 @@ data/pascalvoc/trainval.txt
log*
*.log
ssd_mobilenet_v1_pascalvoc*
quant_model
The minimum PaddlePaddle version needed for the code sample in this directory is the latest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
## SSD Object Detection
## Table of Contents
......
Running the examples in this directory requires the latest develop branch of PaddlePaddle. If your installed version of PaddlePaddle is older than this, please follow the instructions in the [installation documentation](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html) to update your PaddlePaddle installation.
---
## SSD Object Detection
## Table of Contents
......
## Quantization-aware training for SSD
### Introduction
The quantization-aware training used in these experiments is introduced in the [fixed-point quantization design](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/quantization/fixed_point_quantization.md). Since quantization-aware training is still an active area of research and experimentation,
here we only give a simple quantization-aware training usage in Fluid based on the MobileNet-SSD model; many other experiments are still needed, for example quantization-aware training that takes into account fusing batch normalization with convolution/fully-connected layers, channel-wise quantization of weights, and so on.
A Python transpiler is used to rewrite the Fluid training program or evaluation program for quantization-aware training:
```python
#startup_prog = fluid.Program()
#train_prog = fluid.Program()
#loss = build_program(
# main_prog=train_prog,
# startup_prog=startup_prog,
# is_train=True)
#build_program(
# main_prog=test_prog,
# startup_prog=startup_prog,
# is_train=False)
#test_prog = test_prog.clone(for_test=True)
# the above is pseudocode
transpiler = fluid.contrib.QuantizeTranspiler(
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max', # or 'range_abs_max'
weight_quantize_type='abs_max')
# note: transpiler.training_transpile rewrites train_prog in place
# startup_prog is needed because the transpiler inserts and initializes
# some state variables
transpiler.training_transpile(train_prog, startup_prog)
transpiler.training_transpile(test_prog, startup_prog)
```
According to the design above, this transpiler inserts fake quantization and de-quantization operations for each convolution operation (including depthwise convolution) and each fully-connected operation. These quantization operations take effect on both weights and activations.
In the design, we introduce dynamic and static quantization strategies for different activation quantization methods. In the experiments, setting `activation_quantize_type` to `abs_max` selects dynamic quantization, i.e. the quantization scale (the maximum of the absolute value) of an activation is calculated for each mini-batch during inference. Setting `activation_quantize_type` to `range_abs_max` selects static quantization, in which a quantization scale for the inference period is calculated during training. A minimal sketch of the two strategies is given below; the next part introduces how to train.
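The sketch below is illustrative NumPy only, not the Fluid kernels; in particular, the moving-average update for `range_abs_max` is an assumption of this sketch rather than the exact implementation:
```python
import numpy as np

def fake_quant_dequant_abs_max(x, bits=8):
    # dynamic: the scale is recomputed from the current mini-batch
    rng = 2 ** (bits - 1) - 1                  # 127 for 8 bits
    scale = np.abs(x).max()
    q = np.round(x / scale * rng)              # fake-quantize
    return q * scale / rng                     # de-quantize back to float

def update_range_abs_max_scale(x, running_scale, momentum=0.9):
    # static: a running scale is accumulated during training and
    # reused as a constant at inference time
    return momentum * running_scale + (1.0 - momentum) * np.abs(x).max()
```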
### Quantization-aware training
The training fine-tunes the well-trained MobileNet-SSD model, so download the model first:
```bash
wget http://paddlemodels.bj.bcebos.com/ssd_mobilenet_v1_pascalvoc.tar.gz
```
- dynamic quantization:
```bash
python main_quant.py \
--data_dir=$PascalVOC_DIR$ \
--mode='train' \
--init_model=ssd_mobilenet_v1_pascalvoc \
--act_quant_type='abs_max' \
--epoc_num=20 \
--learning_rate=0.0001 \
--batch_size=64 \
--model_save_dir=$OUTPUT_DIR$
```
Since we fine-tune from a well-trained model, we use a small initial learning rate of 0.0001 and train for 20 epochs.
- static quantization:
```bash
python main_quant.py \
--data_dir=$PascalVOC_DIR$ \
--mode='train' \
--init_model=ssd_mobilenet_v1_pascalvoc \
--act_quant_type='range_abs_max' \
--epoc_num=80 \
--learning_rate=0.001 \
--lr_epochs=30,60 \
--lr_decay_rates=1,0.1,0.01 \
--batch_size=64 \
--model_save_dir=$OUTPUT_DIR$
```
Here we train for 80 epochs; the learning rate decays by a factor of 0.1 at epochs 30 and 60. Users can adjust these hyper-parameters.
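For reference, the sketch below (a hypothetical helper, not the repository's exact `optimizer_setting`) shows how `--lr_epochs` and `--lr_decay_rates` typically translate into a piecewise learning rate schedule in Fluid:
```python
import paddle.fluid as fluid

def piecewise_lr(base_lr, lr_epochs, lr_decay_rates, iters_per_epoch):
    # boundaries are expressed in iterations; values are the absolute
    # learning rates used between boundaries, e.g. base_lr * [1, 0.1, 0.01]
    boundaries = [e * iters_per_epoch for e in lr_epochs]
    values = [base_lr * r for r in lr_decay_rates]
    return fluid.layers.piecewise_decay(boundaries=boundaries, values=values)
```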
### Convert to inference model
As described in the design documentation, the inference graph differs slightly from the training graph: the difference is whether the de-quantization operation is placed before or after the conv/fc operation. The two placements are equivalent during training, because conv/fc and de-quantization are linear operations and therefore commute. For inference, however, the graph needs to be converted explicitly; `fluid.contrib.QuantizeTranspiler.freeze_program` is used to do this:
```python
#startup_prog = fluid.Program()
#test_prog = fluid.Program()
#test_py_reader, map_eval, nmsed_out, image = build_program(
# main_prog=test_prog,
# startup_prog=startup_prog,
# train_params=configs,
# is_train=False)
#test_prog = test_prog.clone(for_test=True)
#transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8,
# activation_bits=8,
# activation_quantize_type=act_quant_type,
# weight_quantize_type='abs_max')
#transpiler.training_transpile(test_prog, startup_prog)
#place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
#exe = fluid.Executor(place)
#exe.run(startup_prog)
def if_exist(var):
return os.path.exists(os.path.join(init_model, var.name))
fluid.io.load_vars(exe, init_model, main_program=test_prog,
predicate=if_exist)
# freeze the rewritten program
# freeze after loading the parameters; this quantizes the weights
transpiler.freeze_program(test_prog, place)
```
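As a quick numeric illustration of the commutativity argument above (illustrative only, not part of the repository), scaling commutes with a linear layer, so moving the de-quantization across a matrix multiplication does not change the result:
```python
import numpy as np

x = np.random.rand(4, 8).astype('float32')
w = np.random.rand(8, 16).astype('float32')
scale = np.abs(w).max()
w_q = np.round(w / scale * 127.0)            # fake-quantized weights in the int8 range

dequant_before = x.dot(w_q * scale / 127.0)  # de-quantize weights, then apply the linear op
dequant_after = x.dot(w_q) * scale / 127.0   # apply the linear op, then de-quantize
assert np.allclose(dequant_before, dequant_after, atol=1e-4)
```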
Users can evaluate the converted model by:
```bash
python main_quant.py \
--data_dir=$PascalVOC_DIR$ \
--mode='test' \
--init_model=$MODEL_DIR$ \
--model_save_dir=$MobileNet_SSD_8BIT_MODEL$
```
You can also check the 8-bit model with the inference script:
```bash
python main_quant.py \
--mode='infer' \
--init_model=$MobileNet_SSD_8BIT_MODEL$ \
--image_path='/data/PascalVOC/VOCdevkit/VOC2007/JPEGImages/002271.jpg'
```
See 002271.jpg for the visualized image with bounding boxes.
### Results
Results of the MobileNet-v1-SSD 300x300 model on the PascalVOC dataset:
| Model | mAP |
|:---------------------------------------:|:------------------:|
|Floating point: 32bit | 73.32% |
|Fixed point: 8bit, dynamic quantization | 72.77% |
|Fixed point: 8bit, static quantization | 72.45% |
As mentioned above, other experiments remain to be done, such as quantization-aware training that takes into account fusing batch normalization with convolution/fully-connected layers, channel-wise quantization of weights, and storing quantized weights as uint8 instead of int8.
......@@ -77,12 +77,13 @@ def eval(args, data_args, test_list, batch_size, model_dir=None):
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
# yapf: disable
if model_dir:
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, main_program=test_prog, predicate=if_exist)
# yapf: enable
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(
exe, model_dir, main_program=test_prog, predicate=if_exist)
test_reader = reader.test(data_args, test_list, batch_size=batch_size)
test_py_reader.decorate_paddle_reader(test_reader)
......@@ -96,7 +97,7 @@ def eval(args, data_args, test_list, batch_size, model_dir=None):
if batch_id % 10 == 0:
print("Batch {0}, map {1}".format(batch_id, test_map))
batch_id += 1
except fluid.core.EOFException:
except (fluid.core.EOFException, StopIteration):
test_py_reader.reset()
print("Test model {0}, map {1}".format(model_dir, test_map))
......
import os
import time
import numpy as np
import argparse
import functools
import shutil
import math
import multiprocessing
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
from train import build_program
from train import train_parameters
from infer import draw_bounding_box_on_image
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate', float, 0.0001, "Learning rate.")
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('epoc_num', int, 20, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('parallel', bool, True, "Whether train in parallel on multi-devices.")
add_arg('model_save_dir', str, 'quant_model', "The path to save model.")
add_arg('init_model', str, 'ssd_mobilenet_v1_pascalvoc', "The init model path.")
add_arg('ap_version', str, '11point', "mAP version can be integral or 11point.")
add_arg('image_shape', str, '3,300,300', "Input image shape.")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('lr_epochs', str, '30,60', "The epochs at which the learning rate decays.")
add_arg('lr_decay_rates', str, '1,0.1,0.01', "The learning rate decay rates for each step.")
add_arg('data_dir', str, 'data/pascalvoc', "Data directory")
add_arg('act_quant_type', str, 'abs_max', "Quantization type of the activation, which can be abs_max or range_abs_max.")
add_arg('image_path', str, '', "The image used to inference and visualize.")
add_arg('confs_threshold', float, 0.5, "Confidence threshold to draw bbox.")
add_arg('mode', str, 'train', "Job mode can be one of ['train', 'test', 'infer'].")
# yapf: enable
def test(exe, test_prog, map_eval, test_py_reader):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
test_py_reader.start()
try:
batch = 0
while True:
test_map, = exe.run(test_prog, fetch_list=[accum_map])
if batch % 10 == 0:
print("Batch {0}, map {1}".format(batch, test_map))
batch += 1
except fluid.core.EOFException:
test_py_reader.reset()
finally:
test_py_reader.reset()
print("Test map {0}".format(test_map))
return test_map
def save_model(exe, main_prog, model_save_dir, postfix):
model_path = os.path.join(model_save_dir, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
fluid.io.save_persistables(exe, model_path, main_program=main_prog)
def train(args,
data_args,
train_params,
train_file_list,
val_file_list):
model_save_dir = args.model_save_dir
init_model = args.init_model
epoc_num = args.epoc_num
use_gpu = args.use_gpu
parallel = args.parallel
is_shuffle = True
act_quant_type = args.act_quant_type
if use_gpu:
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
batch_size = train_params['batch_size']
batch_size_per_device = batch_size // devices_num
iters_per_epoc = train_params["train_images"] // batch_size
num_workers = 4
startup_prog = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
train_py_reader, loss = build_program(
main_prog=train_prog,
startup_prog=startup_prog,
train_params=train_params,
is_train=True)
test_py_reader, map_eval, _, _ = build_program(
main_prog=test_prog,
startup_prog=startup_prog,
train_params=train_params,
is_train=False)
test_prog = test_prog.clone(for_test=True)
transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8,
activation_bits=8,
activation_quantize_type=act_quant_type,
weight_quantize_type='abs_max')
transpiler.training_transpile(train_prog, startup_prog)
transpiler.training_transpile(test_prog, startup_prog)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
if init_model:
print('Load init model %s.' % init_model)
def if_exist(var):
return os.path.exists(os.path.join(init_model, var.name))
fluid.io.load_vars(exe, init_model, main_program=train_prog,
predicate=if_exist)
else:
print('There is no init model.')
if parallel:
train_exe = fluid.ParallelExecutor(main_program=train_prog,
use_cuda=use_gpu, loss_name=loss.name)
train_reader = reader.train(data_args,
train_file_list,
batch_size_per_device,
shuffle=is_shuffle,
use_multiprocessing=True,
num_workers=num_workers,
max_queue=24)
test_reader = reader.test(data_args, val_file_list, batch_size)
train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader)
train_py_reader.start()
best_map = 0.
try:
for epoc in range(epoc_num):
if epoc == 0:
# test quantized model without quantization-aware training.
test_map = test(exe, test_prog, map_eval, test_py_reader)
# train
for batch in range(iters_per_epoc):
start_time = time.time()
if parallel:
outs = train_exe.run(fetch_list=[loss.name])
else:
outs = exe.run(train_prog, fetch_list=[loss])
end_time = time.time()
avg_loss = np.mean(np.array(outs[0]))
if batch % 20 == 0:
print("Epoc {:d}, batch {:d}, loss {:.6f}, time {:.5f}".format(
epoc , batch, avg_loss, end_time - start_time))
end_time = time.time()
test_map = test(exe, test_prog, map_eval, test_py_reader)
save_model(exe, train_prog, model_save_dir, str(epoc))
if test_map > best_map:
best_map = test_map
save_model(exe, train_prog, model_save_dir, 'best_map')
print("Best test map {0}".format(best_map))
except (fluid.core.EOFException, StopIteration):
train_py_reader.reset()
def eval(args, data_args, configs, val_file_list):
init_model = args.init_model
use_gpu = args.use_gpu
act_quant_type = args.act_quant_type
model_save_dir = args.model_save_dir
batch_size = configs['batch_size']
batch_size_per_device = batch_size
startup_prog = fluid.Program()
test_prog = fluid.Program()
test_py_reader, map_eval, nmsed_out, image = build_program(
main_prog=test_prog,
startup_prog=startup_prog,
train_params=configs,
is_train=False)
test_prog = test_prog.clone(for_test=True)
transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8,
activation_bits=8,
activation_quantize_type=act_quant_type,
weight_quantize_type='abs_max')
transpiler.training_transpile(test_prog, startup_prog)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
def if_exist(var):
return os.path.exists(os.path.join(init_model, var.name))
fluid.io.load_vars(exe, init_model, main_program=test_prog,
predicate=if_exist)
# freeze after load parameters
transpiler.freeze_program(test_prog, place)
test_reader = reader.test(data_args, val_file_list, batch_size)
test_py_reader.decorate_paddle_reader(test_reader)
test_map = test(exe, test_prog, map_eval, test_py_reader)
print("Test model {0}, map {1}".format(init_model, test_map))
fluid.io.save_inference_model(model_save_dir, [image.name],
[nmsed_out], exe, test_prog)
def infer(args, data_args):
model_dir = args.init_model
image_path = args.image_path
confs_threshold = args.confs_threshold
voc_labels = 'data/pascalvoc/label_list'
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed , fetch] = fluid.io.load_inference_model(
dirname=model_dir,
executor=exe,
model_filename='__model__')
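# Debug output: inspect one convolution weight tensor and its maximum value
# from the loaded 8-bit inference model.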
print(np.array(fluid.global_scope().find_var('conv2d_20.w_0').get_tensor()))
print(np.max(np.array(fluid.global_scope().find_var('conv2d_20.w_0').get_tensor())))
infer_reader = reader.infer(data_args, image_path)
data = infer_reader()
data = data.reshape((1,) + data.shape)
outs = exe.run(inference_program,
feed={feed[0]: data},
fetch_list=fetch,
return_numpy=False)
out = np.array(outs[0])
draw_bounding_box_on_image(image_path, out, confs_threshold, voc_labels)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
# for pascalvoc
label_file = 'label_list'
train_list = 'trainval.txt'
val_list = 'test.txt'
dataset = 'pascalvoc'
mean_BGR = [float(m) for m in args.mean_BGR.split(",")]
image_shape = [int(m) for m in args.image_shape.split(",")]
lr_epochs = [int(m) for m in args.lr_epochs.split(",")]
lr_rates = [float(m) for m in args.lr_decay_rates.split(",")]
train_parameters[dataset]['image_shape'] = image_shape
train_parameters[dataset]['batch_size'] = args.batch_size
train_parameters[dataset]['lr'] = args.learning_rate
train_parameters[dataset]['epoc_num'] = args.epoc_num
train_parameters[dataset]['ap_version'] = args.ap_version
train_parameters[dataset]['lr_epochs'] = lr_epochs
train_parameters[dataset]['lr_decay'] = lr_rates
data_args = reader.Settings(
dataset=dataset,
data_dir=args.data_dir,
label_file=label_file,
resize_h=image_shape[1],
resize_w=image_shape[2],
mean_value=mean_BGR,
apply_distort=True,
apply_expand=True,
ap_version = args.ap_version)
if args.mode == 'train':
train(args, data_args, train_parameters[dataset], train_list, val_list)
elif args.mode == 'test':
eval(args, data_args, train_parameters[dataset], val_list)
else:
infer(args, data_args)
......@@ -5,6 +5,7 @@ import argparse
import functools
import shutil
import math
import multiprocessing
import paddle
import paddle.fluid as fluid
......@@ -16,18 +17,18 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('batch_size', int, 64, "Minibatch size of all devices.")
add_arg('epoc_num', int, 120, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('parallel', bool, True, "Parallel.")
add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('parallel', bool, True, "Whether train in parallel on multi-devices.")
add_arg('dataset', str, 'pascalvoc', "dataset can be coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('ap_version', str, '11point', "Integral, 11point.")
add_arg('ap_version', str, '11point', "mAP version can be integral or 11point.")
add_arg('image_shape', str, '3,300,300', "Input image shape.")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('data_dir', str, 'data/pascalvoc', "data directory")
add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('data_dir', str, 'data/pascalvoc', "Data directory.")
add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model.")
# yapf: enable
train_parameters = {
......@@ -81,6 +82,7 @@ def build_program(main_prog, startup_prog, train_params, is_train):
image_shape = train_params['image_shape']
class_num = train_params['class_num']
ap_version = train_params['ap_version']
outs = []
with fluid.program_guard(main_prog, startup_prog):
py_reader = fluid.layers.py_reader(
capacity=64,
......@@ -98,11 +100,12 @@ def build_program(main_prog, startup_prog, train_params, is_train):
loss = fluid.layers.reduce_sum(loss)
optimizer = optimizer_setting(train_params)
optimizer.minimize(loss)
outs = [py_reader, loss]
else:
with fluid.unique_name.guard("inference"):
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.evaluator.DetectionMAP(
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
......@@ -111,7 +114,9 @@ def build_program(main_prog, startup_prog, train_params, is_train):
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=ap_version)
return py_reader, loss
# nmsed_out and image are used to save the model for inference
outs = [py_reader, map_eval, nmsed_out, image]
return outs
def train(args,
......@@ -127,8 +132,12 @@ def train(args,
enable_ce = args.enable_ce
is_shuffle = True
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
if not use_gpu:
devices_num = int(os.environ.get('CPU_NUM',
multiprocessing.cpu_count()))
else:
devices_num = fluid.core.get_cuda_device_count()
batch_size = train_params['batch_size']
epoc_num = train_params['epoc_num']
batch_size_per_device = batch_size // devices_num
......@@ -153,7 +162,7 @@ def train(args,
startup_prog=startup_prog,
train_params=train_params,
is_train=True)
test_py_reader, map_eval = build_program(
test_py_reader, map_eval, _, _ = build_program(
main_prog=test_prog,
startup_prog=startup_prog,
train_params=train_params,
......@@ -258,11 +267,9 @@ def train(args,
print("kpis train_speed_card%s %f" %
(devices_num, total_time / epoch_idx))
except fluid.core.EOFException:
train_py_reader.reset()
except StopIteration:
except (fluid.core.EOFException, StopIteration):
train_reader().close()
train_py_reader.reset()
train_py_reader.reset()
if __name__ == '__main__':
......@@ -291,6 +298,7 @@ if __name__ == '__main__':
train_parameters[dataset]['batch_size'] = args.batch_size
train_parameters[dataset]['lr'] = args.learning_rate
train_parameters[dataset]['epoc_num'] = args.epoc_num
train_parameters[dataset]['ap_version'] = args.ap_version
data_args = reader.Settings(
dataset=args.dataset,
......