Commit 76526e51 authored by Xingyuan Bu, committed by qingqing01

SSD COCO reader (#767)

* get ready for coco_reader

* complete coco_reader.py & coco_train.py

* complete coco reader

* rename file

* use argparse instead of explicit assignment

* fix

* fix reader bug for grayscale images in COCO data

* ready to train on COCO

* fix bug in test()

* fix bug in test()

* change coco dataset to coco2017 dataset

* change dataset from coco to coco2017

* change learning rate

* fix bug in gt label (category id to label mapping)

* fix bug in background label

* save model when training finishes
Parent a4d00a37
@@ -13,7 +13,7 @@ def conv_bn(input,
num_groups=1,
act='relu',
use_cudnn=True):
parameter_attr = ParamAttr(learning_rate=0.1, initializer=MSRA())
parameter_attr = ParamAttr(initializer=MSRA())
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
@@ -25,14 +25,11 @@ def conv_bn(input,
use_cudnn=use_cudnn,
param_attr=parameter_attr,
bias_attr=False)
parameter_attr = ParamAttr(learning_rate=0.1, initializer=MSRA())
bias_attr = ParamAttr(learning_rate=0.2)
return fluid.layers.batch_norm(
input=conv,
act=act,
epsilon=0.00001,
param_attr=parameter_attr,
bias_attr=bias_attr)
#parameter_attr = ParamAttr(learning_rate=0.1, initializer=MSRA())
#bias_attr = ParamAttr(learning_rate=0.2)
return fluid.layers.batch_norm(input=conv, act=act, epsilon=0.00001)
#param_attr=parameter_attr,
#bias_attr=bias_attr)
def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
@@ -76,7 +73,7 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
return normal_conv
def mobile_net(img, img_shape, scale=1.0):
def mobile_net(num_classes, img, img_shape, scale=1.0):
# 300x300
tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3)
# 150x150
@@ -104,10 +101,11 @@ def mobile_net(img, img_shape, scale=1.0):
module16 = extra_block(module15, 128, 256, 1, 2, scale)
# 2x2
module17 = extra_block(module16, 64, 128, 1, 2, scale)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[module11, module13, module14, module15, module16, module17],
image=img,
num_classes=21,
num_classes=num_classes,
min_ratio=20,
max_ratio=90,
min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
......
@@ -16,19 +16,29 @@ import image_util
from paddle.utils.image_util import *
import random
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
# cocoapi
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
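Note: the reader now depends on the COCO API; the pycocotools package (installable with pip install pycocotools) is assumed to be available wherever the COCO code paths below run.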
class Settings(object):
def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value,
apply_distort, apply_expand):
def __init__(self, dataset, toy, data_dir, label_file, resize_h, resize_w,
mean_value, apply_distort, apply_expand):
self._dataset = dataset
self._toy = toy
self._data_dir = data_dir
self._label_list = []
label_fpath = os.path.join(data_dir, label_file)
for line in open(label_fpath):
self._label_list.append(line.strip())
if dataset == "pascalvoc":
self._label_list = []
label_fpath = os.path.join(data_dir, label_file)
for line in open(label_fpath):
self._label_list.append(line.strip())
self._apply_distort = apply_distort
self._apply_expand = apply_expand
@@ -47,6 +57,14 @@ class Settings(object):
self._brightness_prob = 0.5
self._brightness_delta = 0.125
@property
def dataset(self):
return self._dataset
@property
def toy(self):
return self._toy
@property
def apply_distort(self):
return self._apply_distort
@@ -59,6 +77,10 @@ class Settings(object):
def data_dir(self):
return self._data_dir
@data_dir.setter
def data_dir(self, data_dir):
self._data_dir = data_dir
@property
def label_list(self):
return self._label_list
@@ -78,23 +100,72 @@ class Settings(object):
def _reader_creator(settings, file_list, mode, shuffle):
def reader():
with open(file_list) as flist:
lines = [line.strip() for line in flist]
if shuffle:
random.shuffle(lines)
for line in lines:
if settings.dataset == 'coco':
coco = COCO(file_list)
image_ids = coco.getImgIds()
images = coco.loadImgs(image_ids)
category_ids = coco.getCatIds()
category_names = [
item['name'] for item in coco.loadCats(category_ids)
]
elif settings.dataset == 'pascalvoc':
flist = open(file_list)
images = [line.strip() for line in flist]
if not settings.toy == 0:
images = images[:settings.toy] if len(
images) > settings.toy else images
print("{} on {} with {} images".format(mode, settings.dataset,
len(images)))
if shuffle:
random.shuffle(images)
for image in images:
if settings.dataset == 'coco':
image_name = image['file_name']
image_path = os.path.join(settings.data_dir, image_name)
elif settings.dataset == 'pascalvoc':
if mode == 'train' or mode == 'test':
img_path, label_path = line.split()
img_path = os.path.join(settings.data_dir, img_path)
image_path, label_path = image.split()
image_path = os.path.join(settings.data_dir, image_path)
label_path = os.path.join(settings.data_dir, label_path)
elif mode == 'infer':
img_path = os.path.join(settings.data_dir, line)
image_path = os.path.join(settings.data_dir, image)
img = Image.open(img_path)
img_width, img_height = img.size
img = Image.open(image_path)
if img.mode == 'L':
img = img.convert('RGB')
img_width, img_height = img.size
# layout: label | xmin | ymin | xmax | ymax | difficult
if mode == 'train' or mode == 'test':
if mode == 'train' or mode == 'test':
if settings.dataset == 'coco':
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd | origin_coco_bbox | segmentation | area | image_id | annotation_id
bbox_labels = []
annIds = coco.getAnnIds(imgIds=image['id'])
anns = coco.loadAnns(annIds)
for ann in anns:
bbox_sample = []
# start from 1, leave 0 to background
bbox_sample.append(
float(category_ids.index(ann['category_id'])) + 1)
bbox = ann['bbox']
xmin, ymin, w, h = bbox
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(float(xmin) / img_width)
bbox_sample.append(float(ymin) / img_height)
bbox_sample.append(float(xmax) / img_width)
bbox_sample.append(float(ymax) / img_height)
bbox_sample.append(float(ann['iscrowd']))
#bbox_sample.append(ann['bbox'])
#bbox_sample.append(ann['segmentation'])
#bbox_sample.append(ann['area'])
#bbox_sample.append(ann['image_id'])
#bbox_sample.append(ann['id'])
bbox_labels.append(bbox_sample)
elif settings.dataset == 'pascalvoc':
# layout: label | xmin | ymin | xmax | ymax | difficult
bbox_labels = []
root = xml.etree.ElementTree.parse(label_path).getroot()
for object in root.findall('object'):
@@ -117,91 +188,138 @@ def _reader_creator(settings, file_list, mode, shuffle):
bbox_sample.append(difficult)
bbox_labels.append(bbox_sample)
sample_labels = bbox_labels
if mode == 'train':
if settings._apply_distort:
img = image_util.distort_image(img, settings)
if settings._apply_expand:
img, bbox_labels = image_util.expand_image(
img, bbox_labels, img_width, img_height,
settings)
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0,
1.0))
""" random crop """
sampled_bbox = image_util.generate_batch_samples(
batch_sampler, bbox_labels, img_width, img_height)
img = np.array(img)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sample_labels = image_util.crop_image(
img, bbox_labels, sampled_bbox[idx], img_width,
img_height)
img = Image.fromarray(img)
img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
img = np.array(img)
sample_labels = bbox_labels
if mode == 'train':
mirror = int(random.uniform(0, 2))
if mirror == 1:
img = img[:, ::-1, :]
for i in xrange(len(sample_labels)):
tmp = sample_labels[i][1]
sample_labels[i][1] = 1 - sample_labels[i][3]
sample_labels[i][3] = 1 - tmp
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= settings.img_mean
img = img.flatten()
img = img * 0.007843
sample_labels = np.array(sample_labels)
if mode == 'train' or mode == 'test':
if mode == 'train' and len(sample_labels) == 0: continue
yield img.astype(
'float32'
), sample_labels[:, 1:5], sample_labels[:, 0].astype(
'int32'), sample_labels[:, -1].astype('int32')
elif mode == 'infer':
yield img.astype('float32')
if settings._apply_distort:
img = image_util.distort_image(img, settings)
if settings._apply_expand:
img, bbox_labels = image_util.expand_image(
img, bbox_labels, img_width, img_height, settings)
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0))
""" random crop """
sampled_bbox = image_util.generate_batch_samples(
batch_sampler, bbox_labels, img_width, img_height)
img = np.array(img)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sample_labels = image_util.crop_image(
img, bbox_labels, sampled_bbox[idx], img_width,
img_height)
img = Image.fromarray(img)
img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
img = np.array(img)
if mode == 'train':
mirror = int(random.uniform(0, 2))
if mirror == 1:
img = img[:, ::-1, :]
for i in xrange(len(sample_labels)):
tmp = sample_labels[i][1]
sample_labels[i][1] = 1 - sample_labels[i][3]
sample_labels[i][3] = 1 - tmp
#draw_bounding_box_on_image(img, sample_labels, image_name, category_names, normalized=True)
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RGB to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= settings.img_mean
img = img.flatten()
img = img * 0.007843
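# note: 0.007843 is approximately 1 / 127.5, so with the default 127.5 mean
# the flattened pixel values land roughly in [-1, 1]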
sample_labels = np.array(sample_labels)
if mode == 'train' or mode == 'test':
if mode == 'train' and len(sample_labels) == 0: continue
if mode == 'test' and len(sample_labels) == 0: continue
yield img.astype(
'float32'
), sample_labels[:, 1:5], sample_labels[:, 0].astype(
'int32'), sample_labels[:, -1].astype('int32')
elif mode == 'infer':
yield img.astype('float32')
return reader
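The heart of the new COCO branch, and of the "category id to label" fix in the commit log, is the label and box conversion: COCO category ids are sparse (1 to 90, with gaps, for 80 categories), so each id is mapped to its index in coco.getCatIds() plus one, keeping label 0 for the background class, while COCO's [x, y, w, h] boxes are turned into the normalized corner layout the rest of the pipeline expects. A standalone sketch of that conversion (the function name is illustrative, not part of the diff):

def coco_ann_to_bbox_sample(ann, category_ids, img_width, img_height):
    # map the sparse COCO category id to a contiguous label in 1..80;
    # 0 stays reserved for the background class
    label = float(category_ids.index(ann['category_id'])) + 1
    # COCO stores boxes as [xmin, ymin, width, height]
    xmin, ymin, w, h = ann['bbox']
    xmax, ymax = xmin + w, ymin + h
    # normalize the corners to [0, 1]
    return [label,
            float(xmin) / img_width, float(ymin) / img_height,
            float(xmax) / img_width, float(ymax) / img_height,
            float(ann['iscrowd'])]

Each training or test sample the reader yields then unpacks this layout: the image, columns 1:5 as the boxes, column 0 as the labels, and the last column as the iscrowd (COCO) or difficult (VOC) flag.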
def draw_bounding_box_on_image(image,
sample_labels,
image_name,
category_names,
color='red',
thickness=4,
with_text=True,
normalized=True):
image = Image.fromarray(image)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
if not normalized:
im_width, im_height = 1, 1
for item in sample_labels:
label = item[0]
category_name = category_names[int(label)]
bbox = item[1:5]
xmin, ymin, xmax, ymax = bbox
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=thickness,
fill=color)
#draw.rectangle([xmin, ymin, xmax, ymax], outline=color)
if with_text:
if image.mode == 'RGB':
draw.text((left, top), category_name, (255, 255, 0))
image.save(image_name)
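draw_bounding_box_on_image is a debugging aid rather than part of the training path; the commented-out call in the COCO branch of the reader above shows the intended use, saving each augmented sample with its normalized boxes drawn on top:

# inside the reader, right after the mirror step (kept commented out above):
# draw_bounding_box_on_image(img, sample_labels, image_name, category_names,
#                            normalized=True)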
def train(settings, file_list, shuffle=True):
return _reader_creator(settings, file_list, 'train', shuffle)
if settings.dataset == 'coco':
train_settings = copy.copy(settings)
if '2014' in file_list:
sub_dir = "train2014"
elif '2017' in file_list:
sub_dir = "train2017"
train_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
file_list = os.path.join(settings.data_dir, file_list)
return _reader_creator(train_settings, file_list, 'train', shuffle)
elif settings.dataset == 'pascalvoc':
return _reader_creator(settings, file_list, 'train', shuffle)
def test(settings, file_list):
return _reader_creator(settings, file_list, 'test', False)
if settings.dataset == 'coco':
test_settings = copy.copy(settings)
if '2014' in file_list:
sub_dir = "val2014"
elif '2017' in file_list:
sub_dir = "val2017"
test_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
file_list = os.path.join(settings.data_dir, file_list)
return _reader_creator(test_settings, file_list, 'test', False)
elif settings.dataset == 'pascalvoc':
return _reader_creator(settings, file_list, 'test', False)
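For COCO, train() and test() shallow-copy the settings so data_dir can point at the split-specific image folder while the annotation file stays under the dataset root. With the default arguments this resolves roughly as follows (paths shown for illustration):

# reader.train(data_args, 'annotations/instances_train2017.json') reads
#   annotations: ./data/COCO17/annotations/instances_train2017.json
#   images:      ./data/COCO17/train2017/<file_name from each annotation>
# reader.test swaps in val2017 / instances_val2017.json the same way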
def infer(settings, file_list):
......
import paddle.v2 as paddle
import paddle
import paddle.fluid as fluid
import reader
import load_model as load_model
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
import os
import time
import numpy as np
import argparse
import functools
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
# yapf: disable
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('num_passes', int, 25, "Epoch number.")
add_arg('parallel', bool, True, "Whether to use parallel training.")
add_arg('use_gpu', bool, True, "Whether to use GPU.")
add_arg('data_dir', str, './data/COCO17', "Root path of data")
add_arg('train_file_list', str, 'annotations/instances_train2017.json',
        "Train file list")
add_arg('val_file_list', str, 'annotations/instances_val2017.json',
        "Validation file list")
add_arg('model_save_dir', str, 'model_COCO17', "Directory to save models")
add_arg('dataset', str, 'coco', "coco or pascalvoc")
add_arg(
    'is_toy', int, 0,
    "Toy mode for quick debugging: 0 means use all data, n means use only the first n samples"
)
add_arg('label_file', str, 'label_list',
        "Label file which lists all label names")
add_arg('apply_distort', bool, True, "Whether to apply distortion")
add_arg('apply_expand', bool, False, "Whether to apply expansion")
add_arg('resize_h', int, 300, "Resized image height")
add_arg('resize_w', int, 300, "Resized image width")
add_arg('mean_value_B', float, 127.5,
"mean value which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5,
"mean value which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5,
"mean value which will be subtracted") #103.94
def train(args,
@@ -28,6 +53,10 @@ def train(args,
model_save_dir='model',
init_model_path=None):
image_shape = [3, data_args.resize_h, data_args.resize_w]
if data_args.dataset == 'coco':
num_classes = 81
elif data_args.dataset == 'pascalvoc':
num_classes = 21
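# 80 COCO categories plus one background class = 81;
# 20 PASCAL VOC classes plus one background class = 21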
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
gt_box = fluid.layers.data(
@@ -45,9 +74,10 @@ def train(args,
gt_box_ = pd.read_input(gt_box)
gt_label_ = pd.read_input(gt_label)
difficult_ = pd.read_input(difficult)
locs, confs, box, box_var = mobile_net(image_, image_shape)
loss = fluid.layers.ssd_loss(locs, confs, gt_box_, gt_label_,
box, box_var)
locs, confs, box, box_var = mobile_net(num_classes, image_,
image_shape)
loss = fluid.layers.ssd_loss(locs, confs, gt_box_, gt_label_, box,
box_var)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.reduce_sum(loss)
@@ -57,11 +87,11 @@ def train(args,
loss, nmsed_out = pd()
loss = fluid.layers.mean(loss)
else:
locs, confs, box, box_var = mobile_net(image, image_shape)
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label,
box, box_var)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
@@ -71,13 +101,20 @@ def train(args,
gt_label,
gt_box,
difficult,
21,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version='11point')
boundaries = [40000, 60000]
values = [0.001, 0.0005, 0.00025]
ap_version='integral')
if data_args.dataset == 'coco':
# learning rate decays at the 12th and 19th pass, respectively
if '2014' in train_file_list:
boundaries = [82783 / batch_size * 12, 82783 / batch_size * 19]
elif '2017' in train_file_list:
boundaries = [118287 / batch_size * 12, 118287 / batch_size * 19]
elif data_args.dataset == 'pascalvoc':
boundaries = [40000, 60000]
values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), )
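The decay boundaries are now derived from the dataset size, so the schedule tracks epochs instead of fixed iteration counts (note the Python 2 floor division). A worked example with the defaults, COCO2017 and batch_size 32:

# 118287 / 32 = 3696 batches per pass (floor division)
# boundaries = [3696 * 12, 3696 * 19] = [44352, 70224]
# values     = [0.001, 0.0005, 0.00025]   # learning_rate * [1, 0.5, 0.25]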
@@ -88,8 +125,8 @@ def train(args,
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
load_model.load_and_set_vars(place)
#load_model.load_paddlev1_vars(place)
#load_model.load_and_set_vars(place)
load_model.load_paddlev1_vars(place)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
test_reader = paddle.batch(
@@ -108,16 +145,23 @@ def train(args,
print("Test {0}, map {1}".format(pass_id, test_map[0]))
for pass_id in range(num_passes):
start_time = time.time()
prev_start_time = start_time
end_time = 0
for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
#print("Batch {} start at {:.2f}".format(batch_id, start_time))
loss_v = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
end_time = time.time()
if batch_id % 20 == 0:
print("Pass {0}, batch {1}, loss {2}"
.format(pass_id, batch_id, loss_v[0]))
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, loss_v[0], start_time - prev_start_time))
test(pass_id)
if pass_id % 10 == 0:
if pass_id % 10 == 0 or pass_id == num_passes - 1:
model_path = os.path.join(model_save_dir, str(pass_id))
print 'save models to %s' % (model_path)
fluid.io.save_inference_model(model_path, ['image'], [nmsed_out],
@@ -128,17 +172,21 @@ if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_args = reader.Settings(
data_dir='./data',
label_file='label_list',
apply_distort=True,
apply_expand=True,
resize_h=300,
resize_w=300,
mean_value=[127.5, 127.5, 127.5])
train(args,
train_file_list='./data/trainval.txt',
val_file_list='./data/test.txt',
data_args=data_args,
learning_rate=0.001,
batch_size=args.batch_size,
num_passes=300)
dataset=args.dataset, # coco or pascalvoc
toy=args.is_toy,
data_dir=args.data_dir,
label_file=args.label_file,
apply_distort=args.apply_distort,
apply_expand=args.apply_expand,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R])
train(
args,
train_file_list=args.train_file_list,
val_file_list=args.val_file_list,
data_args=data_args,
learning_rate=args.learning_rate,
batch_size=args.batch_size,
num_passes=args.num_passes,
model_save_dir=args.model_save_dir)