Commit 496ff37b, authored by Xingyuan Bu, committed by qingqing01

COCO dataset for SSD and update README.md (#844)

* ready to coco_reader

* complete coco_reader.py & coco_train.py

* complete coco reader

* rename file

* use argparse instead of explicit assignment

* fix

* fix reader bug for some gray image in coco data

* ready to train coco

* fix bug in test()

* fix bug in test()

* change coco dataset to coco2017 dataset

* change dataset from coco to coco2017

* change learning rate

* fix bug in gt label (category id 2 label)

* fix bug in background label

* save model when train finished

* use coco map

* adding coco year version args: 2014 or 2017

* add coco dataset download, and README.md

* fix

* fix image truncated IOError, map version error

* add test config

* add eval.py for evaluate trained model 

* fix

* fix bug when cocoMAP

* update README.md

* fix cocoMAP bug

* find strange with test_program = fluid.default_main_program().clone(for_test=True)

* add inference and visualize, awa, README.md

* upload infer&visual example image

* refine image

* refine

* fix bug after merge

* follow yapf

* follow comments

* fix bug after separate eval and eval_cocoMAP

* follow yapf

* follow comments

* follow yapf

* follow yapf
Parent 806cff7c
......@@ -6,7 +6,7 @@ The minimum PaddlePaddle version needed for the code sample in this directory is
### Introduction
[Single Shot MultiBox Detector (SSD)](https://arxiv.org/abs/1512.02325) framework for object detection is based on a feed-forward convolutional network. The early network is a standard convolutional architecture for image classification, such as VGG, ResNet, or MobileNet, which is als called base network. In this tutorial we used [MobileNet](https://arxiv.org/abs/1704.04861).
[Single Shot MultiBox Detector (SSD)](https://arxiv.org/abs/1512.02325) framework for object detection is based on a feed-forward convolutional network. The early network is a standard convolutional architecture for image classification, such as VGG, ResNet, or MobileNet, which is also called base network. In this tutorial we used [MobileNet](https://arxiv.org/abs/1704.04861).
### Data Preparation
......@@ -52,30 +52,51 @@ Declaration: the MobileNet-v1 SSD model is converted by [TensorFlow model](https
#### Train on PASCAL VOC
- Train on one device (/GPU).
```python
env CUDA_VISIBLE_DEVICES=0 python -u train.py --parallel=False --data='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
env CUDA_VISIBLE_DEVICES=0 python -u train.py --parallel=False --dataset='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
```
- Train on multi devices (/GPUs).
```python
env CUDA_VISIBLE_DEVICES=0,1 python -u train.py --batch_size=64 --data='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
env CUDA_VISIBLE_DEVICES=0,1 python -u train.py --batch_size=64 --dataset='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
```
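For PASCAL VOC, train.py in this commit decays the learning rate piecewise: it assumes 19200 training images and drops the rate at passes 40, 60, 80 and 100. The sketch below only illustrates that schedule in plain Python. The helper names `voc_schedule` and `piecewise_lr` are hypothetical, and integer division is assumed for the iterations per pass.

```python
import bisect


def voc_schedule(learning_rate, batch_size):
    """Return (boundaries, values) for the PASCAL VOC schedule.

    Hypothetical helper: it restates the constants used in train.py,
    assuming integer division for the number of iterations per pass.
    """
    iters_per_pass = 19200 // batch_size
    boundaries = [iters_per_pass * p for p in (40, 60, 80, 100)]
    values = [learning_rate * s for s in (1.0, 0.5, 0.25, 0.1, 0.01)]
    return boundaries, values


def piecewise_lr(step, boundaries, values):
    """Pick the learning rate for a global step.

    values[i] applies while step < boundaries[i]; the last value applies
    once the step reaches the final boundary.
    """
    return values[bisect.bisect_right(boundaries, step)]


boundaries, values = voc_schedule(learning_rate=0.001, batch_size=32)
for step in (0, 24000, 36000, 48000, 60000):
    print(step, piecewise_lr(step, boundaries, values))
```

With the default `--batch_size=32`, one pass is 600 iterations, so the boundaries fall at iterations 24000, 36000, 48000 and 60000.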
#### Train on MS-COCO
- Train on one device (/GPU).
```python
env CUDA_VISIBLE_DEVICES=0 python -u train.py --parallel=False --data='coco' --pretrained_model='pretrained/mobilenet_imagenet/'
env CUDA_VISIBLE_DEVICES=0 python -u train.py --parallel=False --dataset='coco2014' --pretrained_model='pretrained/mobilenet_imagenet/'
```
- Train on multi devices (/GPUs).
```python
env CUDA_VISIBLE_DEVICES=0,1 python -u train.py --batch_size=64 --data='coco' --pretrained_model='pretrained/mobilenet_imagenet/'
env CUDA_VISIBLE_DEVICES=0,1 python -u train.py --batch_size=64 --dataset='coco2014' --pretrained_model='pretrained/mobilenet_imagenet/'
```
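The `--dataset` value selects both the class count and the default annotation list through substring checks (`'coco' in dataset`, `'2014' in dataset`, and so on). A minimal sketch of that mapping, restating the defaults used by eval.py in this commit; the helper name `dataset_defaults` is hypothetical and does not exist in the repository.

```python
def dataset_defaults(dataset):
    """Return (num_classes, data_dir, test_list) for a --dataset value.

    Hypothetical helper: it only restates the defaults that eval.py
    picks when --data_dir and --test_list are left empty.
    """
    if 'coco' in dataset:
        # COCO category ids are used directly as labels in this commit,
        # so 91 slots cover ids up to 90 plus background at 0.
        num_classes = 91
        data_dir = './data/coco'
        if '2014' in dataset:
            return num_classes, data_dir, 'annotations/instances_minival2014.json'
        if '2017' in dataset:
            return num_classes, data_dir, 'annotations/instances_val2017.json'
        raise ValueError("expected coco2014 or coco2017, got %r" % dataset)
    if 'pascalvoc' in dataset:
        # 20 object classes plus background.
        return 21, 'data/pascalvoc', 'test.txt'
    raise ValueError("unknown dataset %r" % dataset)


for name in ('pascalvoc', 'coco2014', 'coco2017'):
    print(name, dataset_defaults(name))
```

Raising on an unknown name is a choice made for this sketch; the scripts in the commit fall through without setting the values.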
TBD
### Evaluate
You can evaluate a trained model on both the PASCAL VOC and COCO datasets with either the 11point or the integral AP metric. In addition, we provide eval_coco_map.py, which reports the COCO-specific mAP metric defined by the [COCO committee](http://cocodataset.org/#detections-eval); this metric averages AP over several IoU thresholds. Running eval_coco_map.py requires [cocoapi](https://github.com/cocodataset/cocoapi).
Install the cocoapi:
```
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python2 setup.py install --user
```
Note that the default test list is the dataset's test/val list. To evaluate on your own list, set the test_list argument.
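The two values accepted by `--ap_version` differ in how they summarize a precision/recall curve. The sketch below illustrates the two conventions on a list of (recall, precision) points sorted by increasing recall. It is a plain-Python illustration under that assumption, not the PaddlePaddle implementation, and the function names are hypothetical.

```python
def ap_11point(recalls, precisions):
    """11-point interpolated AP.

    For each threshold t in 0.0, 0.1, ..., 1.0, take the best precision
    among points with recall >= t (0 if there is none), then average.
    """
    total = 0.0
    for i in range(11):
        t = i / 10.0
        candidates = [p for r, p in zip(recalls, precisions) if r >= t]
        total += max(candidates) if candidates else 0.0
    return total / 11.0


def ap_integral(recalls, precisions):
    """Integral AP: precision weighted by each increase in recall."""
    total, prev_recall = 0.0, 0.0
    for r, p in zip(recalls, precisions):
        total += p * (r - prev_recall)
        prev_recall = r
    return total


recalls = [0.5, 1.0]
precisions = [1.0, 0.5]
print(ap_11point(recalls, precisions), ap_integral(recalls, precisions))
```

For these two points the 11-point value is 8.5/11 (about 0.773) and the integral value is 0.75, so the two metrics are not directly comparable.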
#### Evaluate on PASCAL VOC
```python
env CUDA_VISIBLE_DEVICES=0 python eval.py --dataset='pascalvoc' --model_dir='train_pascal_model/90' --data_dir='data/pascalvoc' --test_list='test.txt' --ap_version='11point'
```
#### Evaluate on MS-COCO
```python
env CUDA_VISIBLE_DEVICES=0 python eval.py --model='model/90' --test_list=''
env CUDA_VISIBLE_DEVICES=0 python eval.py --dataset='coco2014' --nms_threshold=0.5 --model_dir='train_coco_model/40' --test_list='annotations/instances_minival2014.json' --ap_version='integral'
env CUDA_VISIBLE_DEVICES=0 python eval_coco_map.py --dataset='coco2017' --nms_threshold=0.5 --model_dir='train_coco_model/40' --test_list='annotations/instances_minival2017.json'
```
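eval_coco_map.py in this commit writes its detections to detection_result.json before handing them to cocoapi. Each detection row has the layout `[category_id, score, xmin, ymin, xmax, ymax]` with corners normalized to [0, 1], while COCO results expect pixel `[x, y, width, height]` boxes. A minimal sketch of that conversion for one row; the helper name `to_coco_result` and the sample values are assumptions made for illustration.

```python
import json


def to_coco_result(image_id, image_width, image_height, detection):
    """Convert one normalized detection row into a COCO-style result dict."""
    category_id, score, xmin, ymin, xmax, ymax = detection

    def clip(value):
        # Keep predicted corners inside the image.
        return max(min(value, 1.0), 0.0)

    xmin = clip(xmin) * image_width
    ymin = clip(ymin) * image_height
    xmax = clip(xmax) * image_width
    ymax = clip(ymax) * image_height
    return {
        'image_id': int(image_id),
        'category_id': int(category_id),
        # COCO boxes are [x, y, width, height] in pixels.
        'bbox': [xmin, ymin, xmax - xmin, ymax - ymin],
        'score': float(score),
    }


# Example values only: one detection on a 640x426 image.
row = [1.0, 0.9, 0.25, 0.5, 0.75, 1.2]
print(json.dumps([to_coco_result(139, 640, 426, row)]))
```

The resulting list of dicts is what gets serialized to JSON and loaded with `COCO.loadRes` in the script.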
TBD
......@@ -83,8 +104,16 @@ TBD
### Infer and Visualize
```python
env CUDA_VISIBLE_DEVICES=0 python infer.py --batch_size=2 --model='model/90' --test_list=''
env CUDA_VISIBLE_DEVICES=0 python infer.py --model_dir='train_coco_model/20' --image_path='./data/coco/val2014/COCO_val2014_000000000139.jpg'
```
Below are example results produced by running python infer.py to run inference and visualize the model output.
<p align="center">
<img src="images/COCO_val2014_000000000139.jpg" height=300 width=400 hspace='10'/>
<img src="images/COCO_val2014_000000000785.jpg" height=300 width=400 hspace='10'/>
<img src="images/COCO_val2014_000000142324.jpg" height=300 width=400 hspace='10'/>
<img src="images/COCO_val2014_000000144003.jpg" height=300 width=400 hspace='10'/> <br />
MobileNet-SSD300x300 Visualization Examples
</p>
TBD
......
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
# Download the data.
echo "Downloading..."
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
# Extract the data.
echo "Extracting..."
unzip train2014.zip
unzip val2014.zip
unzip train2017.zip
unzip val2017.zip
unzip annotations_trainval2014.zip
unzip annotations_trainval2017.zip
......@@ -13,27 +13,27 @@ from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('data_dir', str, '', "The data root path.")
add_arg('test_list', str, '', "The testing data lists.")
add_arg('label_file', str, '', "The label file, which save the real name and is only used for Pascal VOC.")
add_arg('model_dir', str, '', "The model path.")
add_arg('ap_version', str, '11point', "11point or integral")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
add_arg('model_dir', str, '', "The model path.")
add_arg('nms_threshold', float, 0.45, "NMS threshold.")
add_arg('ap_version', str, '11point', "integral, 11point.")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will be subtracted.") #123.68
add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78
add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94
# yapf: enable
def eval(args, data_args, test_list, batch_size, model_dir=None):
image_shape = [3, data_args.resize_h, data_args.resize_w]
if data_args.dataset == 'coco':
num_classes = 81
elif data_args.dataset == 'pascalvoc':
if 'coco' in data_args.dataset:
num_classes = 91
elif 'pascalvoc' in data_args.dataset:
num_classes = 21
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
......@@ -46,61 +46,77 @@ def eval(args, data_args, test_list, batch_size, model_dir=None):
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=args.ap_version)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if model_dir:
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, predicate=if_exist)
# yapf: enable
test_reader = paddle.batch(
reader.test(data_args, test_list), batch_size=batch_size)
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
for idx, data in enumerate(test_reader()):
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if idx % 50 == 0:
print("Batch {0}, map {1}".format(idx, test_map[0]))
print("Test model {0}, map {1}".format(model_dir, test_map[0]))
def test():
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=args.ap_version)
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
for batch_id, data in enumerate(test_reader()):
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
print("Batch {0}, map {1}".format(batch_id, test_map[0]))
print("Test model {0}, map {1}".format(model_dir, test_map[0]))
test()
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = 'data/pascalvoc'
test_list = 'test.txt'
label_file = 'label_list'
if 'coco' in args.dataset:
data_dir = './data/coco'
if '2014' in args.dataset:
test_list = 'annotations/instances_minival2014.json'
elif '2017' in args.dataset:
test_list = 'annotations/instances_val2017.json'
data_args = reader.Settings(
dataset=args.dataset,
data_dir=args.data_dir,
label_file=args.label_file,
data_dir=args.data_dir if len(args.data_dir) > 0 else data_dir,
label_file=label_file,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R])
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
apply_distort=False,
apply_expand=False,
ap_version=args.ap_version,
toy=0)
eval(
args,
test_list=args.test_list,
data_args=data_args,
test_list=args.test_list if len(args.test_list) > 0 else test_list,
batch_size=args.batch_size,
model_dir=args.model_dir)
import os
import time
import numpy as np
import argparse
import functools
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
# A special mAP metric for the COCO dataset, which averages AP over different IoU thresholds.
# To use eval_coco_map.py, cocoapi (https://github.com/cocodataset/cocoapi) is needed.
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset', str, 'coco2014', "coco2014, coco2017.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('data_dir', str, '', "The data root path.")
add_arg('test_list', str, '', "The testing data lists.")
add_arg('model_dir', str, '', "The model path.")
add_arg('nms_threshold', float, 0.5, "NMS threshold.")
add_arg('ap_version', str, 'cocoMAP', "cocoMAP.")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will be subtracted.") #123.68
add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78
add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94
# yapf: enable
def eval(args, data_args, test_list, batch_size, model_dir=None):
image_shape = [3, data_args.resize_h, data_args.resize_w]
num_classes = 91
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1)
gt_iscrowd = fluid.layers.data(
name='gt_iscrowd', shape=[1], dtype='int32', lod_level=1)
gt_image_info = fluid.layers.data(
name='gt_image_id', shape=[3], dtype='int32', lod_level=1)
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if model_dir:
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, predicate=if_exist)
# yapf: enable
test_reader = paddle.batch(
reader.test(data_args, test_list), batch_size=batch_size)
feeder = fluid.DataFeeder(
place=place,
feed_list=[image, gt_box, gt_label, gt_iscrowd, gt_image_info])
    def get_dt_res(nmsed_out_v, data):
        # Convert one batch of NMS output into COCO-format detection dicts.
        # `data` is the batch that was fed, used to look up image id and size.
        dts_res = []
        lod = nmsed_out_v[0].lod()[0]
        nmsed_out_v = np.array(nmsed_out_v[0])
        real_batch_size = min(batch_size, len(data))
        assert (len(lod) == real_batch_size + 1), \
            "Error Lod Tensor offset dimension. Lod({}) vs. batch_size({})".format(len(lod), batch_size)
        k = 0
        for i in range(real_batch_size):
            dt_num_this_img = lod[i + 1] - lod[i]
            image_id = int(data[i][4][0])
            image_width = int(data[i][4][1])
            image_height = int(data[i][4][2])
            for j in range(dt_num_this_img):
                dt = nmsed_out_v[k]
                k = k + 1
                category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
                # Clip normalized corners to [0, 1], then scale to pixels.
                xmin = max(min(xmin, 1.0), 0.0) * image_width
                ymin = max(min(ymin, 1.0), 0.0) * image_height
                xmax = max(min(xmax, 1.0), 0.0) * image_width
                ymax = max(min(ymax, 1.0), 0.0) * image_height
                w = xmax - xmin
                h = ymax - ymin
                bbox = [xmin, ymin, w, h]
                dt_res = {
                    'image_id': image_id,
                    'category_id': category_id,
                    'bbox': bbox,
                    'score': score
                }
                dts_res.append(dt_res)
        return dts_res

    def test():
        dts_res = []
        for batch_id, data in enumerate(test_reader()):
            nmsed_out_v = exe.run(fluid.default_main_program(),
                                  feed=feeder.feed(data),
                                  fetch_list=[nmsed_out],
                                  return_numpy=False)
            if batch_id % 20 == 0:
                print("Batch {0}".format(batch_id))
            dts_res += get_dt_res(nmsed_out_v, data)

        with open("detection_result.json", 'w') as outfile:
            json.dump(dts_res, outfile)
        print("start evaluate using coco api")
        cocoGt = COCO(os.path.join(data_args.data_dir, test_list))
        cocoDt = cocoGt.loadRes("detection_result.json")
        cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

    test()
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = './data/coco'
if '2014' in args.dataset:
test_list = 'annotations/instances_minival2014.json'
elif '2017' in args.dataset:
test_list = 'annotations/instances_val2017.json'
data_args = reader.Settings(
dataset=args.dataset,
data_dir=args.data_dir if len(args.data_dir) > 0 else data_dir,
label_file='',
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
apply_distort=False,
apply_expand=False,
ap_version=args.ap_version,
toy=0)
eval(
args,
data_args=data_args,
test_list=args.test_list if len(args.test_list) > 0 else test_list,
batch_size=args.batch_size,
model_dir=args.model_dir)
from PIL import Image, ImageEnhance
from PIL import Image, ImageEnhance, ImageDraw
from PIL import ImageFile
import numpy as np
import random
import math
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Otherwise PIL raises IOError when an image file is truncated.
class sampler():
def __init__(self, max_sample, max_trial, min_scale, max_scale,
......@@ -144,7 +147,8 @@ def transform_labels(bbox_labels, sample_bbox):
sample_label.append(float(proj_bbox.ymin))
sample_label.append(float(proj_bbox.xmax))
sample_label.append(float(proj_bbox.ymax))
sample_label.append(bbox_labels[i][5])
#sample_label.append(bbox_labels[i][5])
sample_label = sample_label + bbox_labels[i][5:]
sample_labels.append(sample_label)
return sample_labels
......
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset', str, 'pascalvoc', "coco and pascalvoc.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('image_path', str, '', "The image used to inference and visualize.")
add_arg('model_dir', str, '', "The model path.")
add_arg('nms_threshold', float, 0.45, "NMS threshold.")
add_arg('confs_threshold', float, 0.2, "Confidence threshold to draw bbox.")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will be subtracted.") #123.68
add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78
add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94
# yapf: enable
def infer(args, data_args, image_path, model_dir):
image_shape = [3, data_args.resize_h, data_args.resize_w]
if 'coco' in data_args.dataset:
num_classes = 91
elif 'pascalvoc' in data_args.dataset:
num_classes = 21
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=args.nms_threshold)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if model_dir:
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, predicate=if_exist)
# yapf: enable
infer_reader = reader.infer(data_args, image_path)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
def infer():
data = infer_reader()
nmsed_out_v = exe.run(fluid.default_main_program(),
feed=feeder.feed([[data]]),
fetch_list=[nmsed_out],
return_numpy=False)
nmsed_out_v = np.array(nmsed_out_v[0])
draw_bounding_box_on_image(image_path, nmsed_out_v,
args.confs_threshold)
for dt in nmsed_out_v:
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
infer()
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
for dt in nms_out:
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
if score < confs_threshold:
continue
bbox = dt[2:]
xmin, ymin, xmax, ymax = bbox
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=4,
fill='red')
image_name = image_path.split('/')[-1]
print("Image with bounding boxes drawn saved as {}".format(image_name))
image.save(image_name)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_args = reader.Settings(
dataset=args.dataset,
data_dir='',
label_file='',
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
apply_distort=False,
apply_expand=False,
ap_version='',
toy=0)
infer(
args,
data_args=data_args,
image_path=args.image_path,
model_dir=args.model_dir)
......@@ -34,11 +34,13 @@ class Settings(object):
mean_value=[127.5, 127.5, 127.5],
apply_distort=True,
apply_expand=True,
ap_version='11point',
toy=0):
self._dataset = dataset
self._ap_version = ap_version
self._toy = toy
self._data_dir = data_dir
if dataset == "pascalvoc":
if 'pascalvoc' in dataset:
self._label_list = []
label_fpath = os.path.join(data_dir, label_file)
for line in open(label_fpath):
......@@ -65,6 +67,10 @@ class Settings(object):
def dataset(self):
return self._dataset
@property
def ap_version(self):
return self._ap_version
@property
def toy(self):
return self._toy
......@@ -187,17 +193,17 @@ def coco(settings, file_list, mode, shuffle):
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
im_id = image['id']
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd |
# origin_coco_bbox | segmentation | area | image_id | annotation_id
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd
bbox_labels = []
annIds = coco.getAnnIds(imgIds=image['id'])
anns = coco.loadAnns(annIds)
for ann in anns:
bbox_sample = []
# start from 1, leave 0 to background
bbox_sample.append(
float(category_ids.index(ann['category_id'])) + 1)
bbox_sample.append(float(ann['category_id']))
#float(category_ids.index(ann['category_id'])) + 1)
bbox = ann['bbox']
xmin, ymin, w, h = bbox
xmax = xmin + w
......@@ -214,8 +220,12 @@ def coco(settings, file_list, mode, shuffle):
im = im.astype('float32')
boxes = sample_labels[:, 1:5]
lbls = sample_labels[:, 0].astype('int32')
difficults = sample_labels[:, -1].astype('int32')
yield im, boxes, lbls, difficults
iscrowd = sample_labels[:, -1].astype('int32')
if 'cocoMAP' in settings.ap_version:
yield im, boxes, lbls, iscrowd, \
[im_id, im_width, im_height]
else:
yield im, boxes, lbls, iscrowd
return reader
......@@ -268,40 +278,9 @@ def pascalvoc(settings, file_list, mode, shuffle):
return reader
def draw_bounding_box_on_image(image,
sample_labels,
image_name,
category_names,
color='red',
thickness=4,
with_text=True,
normalized=True):
image = Image.fromarray(image)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
if not normalized:
im_width, im_height = 1, 1
for item in sample_labels:
label = item[0]
category_name = category_names[int(label)]
bbox = item[1:5]
xmin, ymin, xmax, ymax = bbox
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=thickness,
fill=color)
if with_text:
if image.mode == 'RGB':
draw.text((left, top), category_name, (255, 255, 0))
image.save(image_name)
def train(settings, file_list, shuffle=True):
file_list = os.path.join(settings.data_dir, file_list)
if settings.dataset == 'coco':
if 'coco' in settings.dataset:
train_settings = copy.copy(settings)
if '2014' in file_list:
sub_dir = "train2014"
......@@ -315,7 +294,7 @@ def train(settings, file_list, shuffle=True):
def test(settings, file_list):
file_list = os.path.join(settings.data_dir, file_list)
if settings.dataset == 'coco':
if 'coco' in settings.dataset:
test_settings = copy.copy(settings)
if '2014' in file_list:
sub_dir = "val2014"
......@@ -329,10 +308,10 @@ def test(settings, file_list):
def infer(settings, image_path):
def reader():
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
img = Image.open(image_path)
if img.mode == 'L':
img = img.convert('RGB')
im_width, im_height = img.size
img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
img = np.array(img)
......@@ -345,6 +324,6 @@ def infer(settings, image_path):
img = img.astype('float32')
img -= settings.img_mean
img = img * 0.007843
yield img
return img
return reader
......@@ -17,23 +17,21 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('num_passes', int, 120, "Epoch number.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('use_nccl', bool, False, "Whether to use NCCL or not.")
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort', bool, True, "Whether apply distort")
add_arg('apply_expand', bool, True, "Whether appley expand")
add_arg('ap_version', str, '11point', "11point or integral")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n sample")
# yapf: enable
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort', bool, True, "Whether apply distort.")
add_arg('apply_expand', bool, False, "Whether apply expand.")
add_arg('nms_threshold', float, 0.45, "NMS threshold.")
add_arg('ap_version', str, 'integral', "integral, 11point.")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will be subtracted.") #123.68
add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78
add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94
add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n samples.")
#yapf: enable
def parallel_do(args,
train_file_list,
......@@ -118,10 +116,8 @@ def parallel_do(args,
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_reader = paddle.batch(
......@@ -139,7 +135,7 @@ def parallel_do(args,
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
print("Test {0}, map {1}".format(pass_id, test_map[0]))
print("Pass {0}, test map {1}".format(pass_id, test_map[0]))
for pass_id in range(num_passes):
start_time = time.time()
......@@ -170,12 +166,12 @@ def parallel_exe(args,
learning_rate,
batch_size,
num_passes,
model_save_dir='model',
model_save_dir,
pretrained_model=None):
image_shape = [3, data_args.resize_h, data_args.resize_w]
if data_args.dataset == 'coco':
num_classes = 81
elif data_args.dataset == 'pascalvoc':
if 'coco' in data_args.dataset:
num_classes = 91
elif 'pascalvoc' in data_args.dataset:
num_classes = 21
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
......@@ -188,11 +184,16 @@ def parallel_exe(args,
name='gt_label', shape=[1], dtype='int32', lod_level=1)
difficult = fluid.layers.data(
name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
gt_iscrowd = fluid.layers.data(
name='gt_iscrowd', shape=[1], dtype='int32', lod_level=1)
gt_image_info = fluid.layers.data(
name='gt_image_id', shape=[3], dtype='int32', lod_level=1)
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
......@@ -207,7 +208,7 @@ def parallel_exe(args,
evaluate_difficult=False,
ap_version=args.ap_version)
if data_args.dataset == 'coco':
if 'coco' in data_args.dataset:
# learning rate decay in 12, 19 pass, respectively
if '2014' in train_file_list:
epocs = 82783 / batch_size
......@@ -215,13 +216,16 @@ def parallel_exe(args,
elif '2017' in train_file_list:
epocs = 118287 / batch_size
boundaries = [epocs * 12, epocs * 19]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25
]
elif data_args.dataset == 'pascalvoc':
epocs = 19200 / batch_size
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
]
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), )
......@@ -233,10 +237,8 @@ def parallel_exe(args,
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
if args.parallel:
......@@ -262,15 +264,16 @@ def parallel_exe(args,
def test(pass_id, best_map):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
test_map = None
for data in test_reader():
for batch_id, data in enumerate(test_reader()):
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
print("Batch {0}, map {1}".format(batch_id, test_map[0]))
if test_map[0] > best_map:
best_map = test_map[0]
save_model('best_model')
print("Test {0}, map {1}".format(pass_id, test_map[0]))
print("Pass {0}, test map {1}".format(pass_id, test_map[0]))
for pass_id in range(num_passes):
start_time = time.time()
......@@ -307,23 +310,26 @@ if __name__ == '__main__':
val_file_list = 'test.txt'
label_file = 'label_list'
model_save_dir = args.model_save_dir
if args.dataset == 'coco':
data_dir = './data/COCO17'
train_file_list = 'annotations/instances_train2017.json'
val_file_list = 'annotations/instances_val2017.json'
label_file = 'label_list'
if 'coco' in args.dataset:
data_dir = './data/coco'
if '2014' in args.dataset:
train_file_list = 'annotations/instances_train2014.json'
val_file_list = 'annotations/instances_minival2014.json'
elif '2017' in args.dataset:
train_file_list = 'annotations/instances_train2017.json'
val_file_list = 'annotations/instances_val2017.json'
data_args = reader.Settings(
dataset=args.dataset,
data_dir=data_dir,
label_file=label_file,
apply_distort=args.apply_distort,
apply_expand=args.apply_expand,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
apply_distort=args.apply_distort,
apply_expand=args.apply_expand,
ap_version=args.ap_version,
toy=args.is_toy)
#method = parallel_do
method = parallel_exe
method(
args,
......