Unverified · Commit 25ff4007 · authored by qingqing01, committed by GitHub

Support to set multiprocess in SSD model (#2776)

* Support to set multiprocess in SSD model
* clean code
Parent 3f32248e
......@@ -91,6 +91,7 @@ tar -xf vgg_ilsvrc_16_fc_reduced.tar.gz && rm -f vgg_ilsvrc_16_fc_reduced.tar.gz
python -u train.py --batch_size=16 --pretrained_model=vgg_ilsvrc_16_fc_reduced
```
- Set `export CUDA_VISIBLE_DEVICES=0,1,2,3` to choose which GPUs to use; `batch_size` defaults to 12 or 16.
- **Note**: when training on **Windows**, set `--use_multiprocess=False`, because Python multiprocessing for data loading currently fails on Windows (a usage sketch of this flag follows this section).
- For more optional arguments, see:
```bash
python train.py --help
......
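For orientation, a minimal usage sketch (not part of the commit) of the new `use_multiprocess` switch as exposed by `reader.train` in the diff below; `settings` is assumed to be the usual reader configuration object and the file list path is a placeholder:

```python
import reader  # the face-detection reader module changed in this commit

# Both branches return a callable that yields training batches, so the rest
# of the training code does not care which one it received.
multi_proc = reader.train(settings, 'train_file_list.txt', batch_size=16,
                          shuffle=True, use_multiprocess=True, num_workers=8)
single_proc = reader.train(settings, 'train_file_list.txt', batch_size=16,
                           shuffle=True, use_multiprocess=False)  # e.g. on Windows

for batch in single_proc():
    break  # each batch has the same structure regardless of the flag
```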
......@@ -280,14 +280,25 @@ def train_generator(settings, file_list, batch_size, shuffle=True):
return reader
def train(settings, file_list, batch_size, shuffle=True, num_workers=8):
def train(settings,
file_list,
batch_size,
shuffle=True,
use_multiprocess=True,
num_workers=8):
file_lists = load_file_list(file_list)
if use_multiprocess:
n = int(math.ceil(len(file_lists) // num_workers))
split_lists = [file_lists[i:i + n] for i in range(0, len(file_lists), n)]
split_lists = [
file_lists[i:i + n] for i in range(0, len(file_lists), n)
]
readers = []
for iterm in split_lists:
readers.append(train_generator(settings, iterm, batch_size, shuffle))
readers.append(
train_generator(settings, iterm, batch_size, shuffle))
return paddle.reader.multiprocess_reader(readers, False)
else:
return train_generator(settings, file_lists, batch_size, shuffle)
def test(settings, file_list):
......
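A small worked example (illustrative only) of the shard-size arithmetic used above. Because `len(files) // num_workers` already floors, the surrounding `math.ceil` does not round up; any leftover files simply land in one extra, shorter shard:

```python
import math

files = ["img_%d.jpg" % i for i in range(100)]   # pretend file list
num_workers = 8

n = int(math.ceil(len(files) // num_workers))    # 100 // 8 == 12; ceil is a no-op here
shards = [files[i:i + n] for i in range(0, len(files), n)]

print([len(s) for s in shards])                  # [12, 12, 12, 12, 12, 12, 12, 12, 4]
```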
......@@ -9,6 +9,20 @@ import time
import argparse
import functools
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import paddle
import paddle.fluid as fluid
from pyramidbox import PyramidBox
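A brief usage note on the helper above (a sketch, assuming nothing conflicting is already exported): `set_paddle_flags` only fills in variables that are still unset, so a value exported in the shell wins, and the call must happen before `import paddle` for the flag to be read:

```python
import os

def set_paddle_flags(**kwargs):
    # Same helper as above: only set a flag the user has not already exported.
    for key, value in kwargs.items():
        if os.environ.get(key, None) is None:
            os.environ[key] = str(value)

set_paddle_flags(FLAGS_eager_delete_tensor_gb=0)
# If the shell already had `export FLAGS_eager_delete_tensor_gb=0.5`,
# the value stays "0.5"; otherwise it is now "0".
print(os.environ["FLAGS_eager_delete_tensor_gb"])

import paddle  # noqa: E402 -- must come after the flags are placed in os.environ
```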
......@@ -32,6 +46,7 @@ add_arg('mean_BGR', str, '104., 117., 123.', "Mean value for B,G,R cha
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, './vgg_ilsvrc_16_fc_reduced/', "The init model path.")
add_arg('data_dir', str, 'data', "The base dir of dataset")
add_arg('use_multiprocess', bool, True, "Whether to use multi-process for data preprocessing.")
parser.add_argument('--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
parser.add_argument('--batch_num', type=int, help="batch num for ce")
parser.add_argument('--num_devices', type=int, default=1, help='Number of GPU devices')
......@@ -163,7 +178,8 @@ def train(args, config, train_params, train_file_list):
train_file_list,
batch_size_per_device,
shuffle = is_shuffle,
num_workers = num_workers)
use_multiprocess=args.use_multiprocess,
num_workers=num_workers)
train_py_reader.decorate_paddle_reader(train_reader)
if args.parallel:
......
......@@ -23,7 +23,7 @@ SSD is readily pluggable into a wide variant standard convolutional network, suc
Please download [PASCAL VOC dataset](http://host.robots.ox.ac.uk/pascal/VOC/) at first, skip this step if you already have one.
```bash
cd data/pascalvoc
./download.sh
```
......@@ -36,7 +36,7 @@ The command `download.sh` will also create the training and testing file lists.
We provide two pre-trained models. The first is a MobileNet-v1 SSD trained on the COCO dataset, with the COCO-specific convolutional predictors removed; it can be used to initialize models for other datasets, such as PASCAL VOC. The second is a MobileNet-v1 trained on the ImageNet 2012 dataset, with the weights and bias of the last fully connected layer removed. Download the MobileNet-v1 SSD:
```bash
./pretrained/download_coco.sh
```
......@@ -46,13 +46,14 @@ Declaration: the MobileNet-v1 SSD model is converted by [TensorFlow model](https
#### Train on PASCAL VOC
`train.py` is the main caller of the training module. Examples of usage are shown below.
```bash
python -u train.py --batch_size=64 --dataset=pascalvoc --pretrained_model=pretrained/ssd_mobilenet_v1_coco/
```
- Set ```export CUDA_VISIBLE_DEVICES=0,1``` to specify which GPUs to use (a per-device batch size sketch follows this list).
- **Note**: set `--use_multiprocess=False` when training on **Windows**, since Python multiprocessing for data loading does not yet work reliably there.
- For more help on arguments:
```bash
python train.py --help
```
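For orientation (a sketch, not the repo's exact code): with several GPUs visible, these scripts split the global `--batch_size` evenly across devices, which is where the `batch_size_per_device` value passed to the reader in the diffs below comes from:

```python
import os

# Hypothetical illustration of the per-device batch size computation.
devices = os.getenv("CUDA_VISIBLE_DEVICES", "0").split(",")
batch_size = 64                                    # value of --batch_size
batch_size_per_device = batch_size // len(devices)
print(len(devices), batch_size_per_device)         # e.g. 2 devices -> 32 per device
```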
......@@ -69,13 +70,13 @@ We used RMSProp optimizer with mini-batch size 64 to train the MobileNet-SSD. Th
You can evaluate your trained model with metrics such as 11point and integral on both the PASCAL VOC and COCO datasets. Note that the default test list is the dataset's test/val list; you can supply your own with the ```--test_list``` argument.
`eval.py` is the main caller of the evaluating module. Examples of usage are shown below.
```bash
python eval.py --dataset=pascalvoc --model_dir=model/best_model --data_dir=data/pascalvoc --test_list=test.txt
```
### Infer and Visualize
`infer.py` is the main caller of the inferring module. Examples of usage are shown below.
```bash
python infer.py --dataset=pascalvoc --nms_threshold=0.45 --model_dir=model/best_model --image_path=./data/pascalvoc/VOCdevkit/VOC2007/JPEGImages/009963.jpg
```
Below are the examples of running the inference and visualizing the model result.
......
......@@ -24,7 +24,7 @@ SSD is readily pluggable into any standard convolutional network, such as VGG, Res
Please download the [PASCAL VOC dataset](http://host.robots.ox.ac.uk/pascal/VOC/) first with the commands below:
```bash
cd data/pascalvoc
./download.sh
```
......@@ -38,7 +38,7 @@ cd data/pascalvoc
We provide two pre-trained models. The first is a MobileNet-v1 SSD pre-trained on the COCO dataset, with its prediction heads removed so that it can be trained on datasets other than COCO. The second is a MobileNet-v1 pre-trained on the ImageNet 2012 dataset, with the final fully connected layer removed for detection training. Download the MobileNet-v1 SSD:
```bash
./pretrained/download_coco.sh
```
......@@ -48,13 +48,14 @@ cd data/pascalvoc
#### Train
`train.py` is the main script of the training module. Example usage:
```bash
python -u train.py --batch_size=64 --dataset=pascalvoc --pretrained_model=pretrained/ssd_mobilenet_v1_coco/
```
- Set ```export CUDA_VISIBLE_DEVICES=0,1``` to specify which GPUs to use.
- **Note**: when training on **Windows**, set `--use_multiprocess=False`, because Python multiprocessing for data loading currently fails on Windows.
- For more optional arguments, see:
```bash
python train.py --help
```
......@@ -71,15 +72,16 @@ cd data/pascalvoc
You can evaluate the trained model on the PASCAL VOC dataset with metrics such as 11point and integral. The sample code uses the dataset's own test list by default; you can point to your own list with ```--test_list```.
`eval.py` is the main script of the evaluation module. Example usage:
```bash
python eval.py --dataset=pascalvoc --model_dir=model/best_model --data_dir=data/pascalvoc --test_list=test.txt
```
### Inference and Visualization
`infer.py` is the main script of the inference and visualization module. Example usage:
```bash
python infer.py --dataset=pascalvoc --nms_threshold=0.45 --model_dir=model/best_model --image_path=./data/pascalvoc/VOCdevkit/VOC2007/JPEGImages/009963.jpg
```
The figure below visualizes the model's predictions:
<p align="center">
......
......@@ -283,6 +283,7 @@ def train(settings,
file_list,
batch_size,
shuffle=True,
use_multiprocess=True,
num_workers=8,
enable_ce=False):
file_path = os.path.join(settings.data_dir, file_list)
......@@ -294,14 +295,15 @@ def train(settings,
image_ids = coco_api.getImgIds()
images = coco_api.loadImgs(image_ids)
np.random.shuffle(images)
n = int(math.ceil(len(images) // num_workers))
image_lists = [images[i:i + n] for i in range(0, len(images), n)]
if '2014' in file_list:
sub_dir = "train2014"
elif '2017' in file_list:
sub_dir = "train2017"
data_dir = os.path.join(settings.data_dir, sub_dir)
n = int(math.ceil(len(images) // num_workers)) if use_multiprocess \
else len(images)
image_lists = [images[i:i + n] for i in range(0, len(images), n)]
for l in image_lists:
readers.append(
coco(settings, coco_api, l, 'train', batch_size, shuffle,
......@@ -309,11 +311,16 @@ def train(settings,
else:
images = [line.strip() for line in open(file_path)]
np.random.shuffle(images)
n = int(math.ceil(len(images) // num_workers))
n = int(math.ceil(len(images) // num_workers)) if use_multiprocess \
else len(images)
image_lists = [images[i:i + n] for i in range(0, len(images), n)]
for l in image_lists:
readers.append(pascalvoc(settings, l, 'train', batch_size, shuffle))
print("use_multiprocess ", use_multiprocess)
if use_multiprocess:
return paddle.reader.multiprocess_reader(readers, False)
else:
return readers[0]
def test(settings, file_list, batch_size):
......
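For context, a minimal, self-contained sketch of what `paddle.reader.multiprocess_reader` does with the per-shard readers built above (the toy readers below are not from the repo); the second argument mirrors the `False` passed in the code, and this multiprocess path is exactly what the new flag lets Windows users avoid:

```python
import paddle

def make_counter_reader(start):
    # Toy per-shard reader; in the repo each reader yields preprocessed samples.
    def _reader():
        for i in range(start, start + 3):
            yield [i]
    return _reader

readers = [make_counter_reader(0), make_counter_reader(100)]

# Each sub-reader runs in its own process; samples arrive interleaved in
# whatever order the worker processes produce them.
combined = paddle.reader.multiprocess_reader(readers, False)

if __name__ == "__main__":
    for sample in combined():
        print(sample)
```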
......@@ -7,6 +7,20 @@ import shutil
import math
import multiprocessing
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import paddle
import paddle.fluid as fluid
import reader
......@@ -28,6 +42,7 @@ add_arg('ap_version', str, '11point', "mAP version can be inte
add_arg('image_shape', str, '3,300,300', "Input image shape.")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('data_dir', str, 'data/pascalvoc', "Data directory.")
add_arg('use_multiprocess', bool, True, "Whether to use multi-process for data preprocessing.")
add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model.")
#yapf: enable
......@@ -185,14 +200,8 @@ def train(args,
build_strategy.memory_optimize = True
train_exe = fluid.ParallelExecutor(main_program=train_prog,
use_cuda=use_gpu, loss_name=loss.name, build_strategy=build_strategy)
train_reader = reader.train(data_args,
train_file_list,
batch_size_per_device,
shuffle=is_shuffle,
num_workers=num_workers,
enable_ce=enable_ce)
test_reader = reader.test(data_args, val_file_list, batch_size)
train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader)
def save_model(postfix, main_prog):
......@@ -233,6 +242,7 @@ def train(args,
train_file_list,
batch_size_per_device,
shuffle=is_shuffle,
use_multiprocess=args.use_multiprocess,
num_workers=num_workers,
enable_ce=enable_ce)
train_py_reader.decorate_paddle_reader(train_reader)
......
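Finally, a hedged sketch of how the decorated reader is typically driven in these scripts (the standard Paddle 1.x `py_reader` pattern); `train_py_reader`, `train_exe`, `loss`, and `num_passes` are assumed to be the objects from the surrounding training code rather than redefined here:

```python
import paddle.fluid as fluid

# Start feeding, run until the reader is exhausted, then reset for the next pass.
for pass_id in range(num_passes):
    train_py_reader.start()
    try:
        while True:
            loss_v, = train_exe.run(fetch_list=[loss.name])
    except fluid.core.EOFException:
        train_py_reader.reset()
```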