Commit e76927fe authored by Dang Qingqing

Refine reader.py and train.py

Parent 547b3918
reader.py
@@ -102,169 +102,169 @@ class Settings(object):
        return self._img_mean
def preprocess(img, bbox_labels, mode, settings):
    img_width, img_height = img.size
    sampled_labels = bbox_labels
    if mode == 'train':
        if settings._apply_distort:
            img = image_util.distort_image(img, settings)
        if settings._apply_expand:
            img, bbox_labels, img_width, img_height = image_util.expand_image(
                img, bbox_labels, img_width, img_height, settings)
        # sampling
        batch_sampler = []
        # hard-code here
        batch_sampler.append(
            image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0))
        batch_sampler.append(
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0))
        batch_sampler.append(
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0))
        batch_sampler.append(
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0))
        batch_sampler.append(
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0))
        batch_sampler.append(
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0))
        batch_sampler.append(
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0))
        sampled_bbox = image_util.generate_batch_samples(batch_sampler,
                                                         bbox_labels)

        img = np.array(img)
        if len(sampled_bbox) > 0:
            idx = int(random.uniform(0, len(sampled_bbox)))
            img, sampled_labels = image_util.crop_image(
                img, bbox_labels, sampled_bbox[idx], img_width, img_height)

        img = Image.fromarray(img)
    img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
    img = np.array(img)

    if mode == 'train':
        mirror = int(random.uniform(0, 2))
        if mirror == 1:
            img = img[:, ::-1, :]
            for i in xrange(len(sampled_labels)):
                tmp = sampled_labels[i][1]
                sampled_labels[i][1] = 1 - sampled_labels[i][3]
                sampled_labels[i][3] = 1 - tmp

    # HWC to CHW
    if len(img.shape) == 3:
        img = np.swapaxes(img, 1, 2)
        img = np.swapaxes(img, 1, 0)
    # RGB to BGR
    img = img[[2, 1, 0], :, :]
    img = img.astype('float32')
    img -= settings.img_mean
    #img = img.flatten()
    img = img * 0.007843
    return img, sampled_labels
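# Note: the 0.007843 scale above is approximately 1 / 127.5, so assuming a
# per-channel mean near 127.5 (the usual setting for this SSD pipeline), the
# normalized pixel values land roughly in [-1, 1]:
#   (0   - 127.5) * 0.007843 ~= -1.0
#   (255 - 127.5) * 0.007843 ~= +1.0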
def coco(settings, file_list, mode, shuffle):
    # cocoapi
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    coco = COCO(file_list)
    image_ids = coco.getImgIds()
    images = coco.loadImgs(image_ids)
    category_ids = coco.getCatIds()
    category_names = [item['name'] for item in coco.loadCats(category_ids)]

    if not settings.toy == 0:
        images = images[:settings.toy] if len(images) > settings.toy else images
    print("{} on {} with {} images".format(mode, settings.dataset, len(images)))
    def reader():
        if mode == 'train' and shuffle:
            random.shuffle(images)
        for image in images:
            image_name = image['file_name']
            image_path = os.path.join(settings.data_dir, image_name)

            im = Image.open(image_path)
            if im.mode == 'L':
                im = im.convert('RGB')
            im_width, im_height = im.size

            # layout: category_id | xmin | ymin | xmax | ymax | iscrowd |
            # origin_coco_bbox | segmentation | area | image_id | annotation_id
            bbox_labels = []
            annIds = coco.getAnnIds(imgIds=image['id'])
            anns = coco.loadAnns(annIds)
            for ann in anns:
                bbox_sample = []
                # start from 1, leave 0 to background
                bbox_sample.append(
                    float(category_ids.index(ann['category_id'])) + 1)
                bbox = ann['bbox']
                xmin, ymin, w, h = bbox
                xmax = xmin + w
                ymax = ymin + h
                bbox_sample.append(float(xmin) / im_width)
                bbox_sample.append(float(ymin) / im_height)
                bbox_sample.append(float(xmax) / im_width)
                bbox_sample.append(float(ymax) / im_height)
                bbox_sample.append(float(ann['iscrowd']))
                bbox_labels.append(bbox_sample)
            im, sample_labels = preprocess(im, bbox_labels, mode, settings)
            sample_labels = np.array(sample_labels)
            if len(sample_labels) == 0: continue
            im = im.astype('float32')
            boxes = sample_labels[:, 1:5]
            lbls = sample_labels[:, 0].astype('int32')
            difficults = sample_labels[:, -1].astype('int32')
            yield im, boxes, lbls, difficults

    return reader

def pascalvoc(settings, file_list, mode, shuffle):
    flist = open(file_list)
    images = [line.strip() for line in flist]
    if not settings.toy == 0:
        images = images[:settings.toy] if len(images) > settings.toy else images
    print("{} on {} with {} images".format(mode, settings.dataset, len(images)))

    def reader():
        if mode == 'train' and shuffle:
            random.shuffle(images)
        for image in images:
            image_path, label_path = image.split()
            image_path = os.path.join(settings.data_dir, image_path)
            label_path = os.path.join(settings.data_dir, label_path)

            im = Image.open(image_path)
            if im.mode == 'L':
                im = im.convert('RGB')
            im_width, im_height = im.size

            # layout: label | xmin | ymin | xmax | ymax | difficult
            bbox_labels = []
            root = xml.etree.ElementTree.parse(label_path).getroot()
            for object in root.findall('object'):
                bbox_sample = []
                # start from 1
                bbox_sample.append(
                    float(settings.label_list.index(object.find('name').text)))
                bbox = object.find('bndbox')
                difficult = float(object.find('difficult').text)
                bbox_sample.append(float(bbox.find('xmin').text) / im_width)
                bbox_sample.append(float(bbox.find('ymin').text) / im_height)
                bbox_sample.append(float(bbox.find('xmax').text) / im_width)
                bbox_sample.append(float(bbox.find('ymax').text) / im_height)
                bbox_sample.append(difficult)
                bbox_labels.append(bbox_sample)
            im, sample_labels = preprocess(im, bbox_labels, mode, settings)
            sample_labels = np.array(sample_labels)
            if len(sample_labels) == 0: continue
            im = im.astype('float32')
            boxes = sample_labels[:, 1:5]
            lbls = sample_labels[:, 0].astype('int32')
            difficults = sample_labels[:, -1].astype('int32')
            yield im, boxes, lbls, difficults

    return reader

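As the split() call above implies, each line of the PASCAL VOC file_list is expected to hold an image path and its XML annotation path separated by whitespace, both relative to data_dir, for example (illustrative paths only):

    VOCdevkit/VOC2007/JPEGImages/000001.jpg VOCdevkit/VOC2007/Annotations/000001.xml
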
@@ -309,9 +309,9 @@ def train(settings, file_list, shuffle=True):
        elif '2017' in file_list:
            sub_dir = "train2017"
        train_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
        return coco(train_settings, file_list, 'train', shuffle)
    else:
        return pascalvoc(settings, file_list, 'train', shuffle)


def test(settings, file_list):
@@ -323,10 +323,24 @@ def test(settings, file_list):
        elif '2017' in file_list:
            sub_dir = "val2017"
        test_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
        return coco(test_settings, file_list, 'test', False)
    else:
        return pascalvoc(settings, file_list, 'test', False)

def infer(settings, image_path):
    im = Image.open(image_path)
    if im.mode == 'L':
        im = im.convert('RGB')
    im_width, im_height = im.size
    img = im.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
    img = np.array(img)
    # HWC to CHW
    if len(img.shape) == 3:
        img = np.swapaxes(img, 1, 2)
        img = np.swapaxes(img, 1, 0)
    # RGB to BGR
    img = img[[2, 1, 0], :, :]
    img = img.astype('float32')
    img -= settings.img_mean
    img = img * 0.007843
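For context, a minimal sketch of how these reader creators are typically consumed on the training side. The helper build_settings() and the file name 'trainval.txt' are illustrative placeholders, not part of this commit; only reader.train() and paddle.batch() are taken as given:

    import paddle
    import reader

    # Hypothetical helper standing in for a configured reader.Settings instance
    # (dataset, data_dir, resize_w/h, img_mean, ...).
    data_args = build_settings()
    train_list = 'trainval.txt'  # hypothetical annotation list file

    # train() returns a reader creator; paddle.batch groups its samples into minibatches.
    train_reader = paddle.batch(reader.train(data_args, train_list), batch_size=32)

    for batch in train_reader():
        # Each sample is (CHW float32 image, gt boxes, labels, difficult/iscrowd flags),
        # matching the yield in coco() and pascalvoc() above.
        break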
train.py
@@ -3,6 +3,7 @@ import time
import numpy as np
import argparse
import functools
import shutil

import paddle
import paddle.fluid as fluid
@@ -205,7 +206,6 @@ def parallel_exe(args,
            evaluate_difficult=False,
            ap_version=args.ap_version)

    if data_args.dataset == 'coco':
        # learning rate decay in 12, 19 pass, respectively
        if '2014' in train_file_list:
@@ -243,7 +243,15 @@ def parallel_exe(args,
    feeder = fluid.DataFeeder(
        place=place, feed_list=[image, gt_box, gt_label, difficult])

    def save_model(postfix):
        model_path = os.path.join(model_save_dir, postfix)
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print 'save models to %s' % (model_path)
        fluid.io.save_persistables(exe, model_path)

    best_map = 0.

    def test(pass_id, best_map):
        _, accum_map = map_eval.get_map_var()
        map_eval.reset(exe)
        test_map = None
@@ -251,13 +259,15 @@ def parallel_exe(args,
            test_map = exe.run(test_program,
                               feed=feeder.feed(data),
                               fetch_list=[accum_map])
        if test_map[0] > best_map:
            best_map = test_map[0]
            save_model('best_model')
        print("Test {0}, map {1}".format(pass_id, test_map[0]))

    for pass_id in range(num_passes):
        start_time = time.time()
        prev_start_time = start_time
        end_time = 0
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
@@ -269,11 +279,10 @@ def parallel_exe(args,
            if batch_id % 20 == 0:
                print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                    pass_id, batch_id, loss_v, start_time - prev_start_time))
        test(pass_id, best_map)
        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id))
    print("Best test map {0}".format(best_map))


if __name__ == '__main__':
    args = parser.parse_args()
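As a usage note, a minimal sketch of reloading the checkpoint written by save_model('best_model'). It assumes the SSD network has already been rebuilt in the default program exactly as in train.py, and 'model/best_model' stands in for whatever model_save_dir was used; both are assumptions, not values from this diff:

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Counterpart of the fluid.io.save_persistables() call inside save_model();
    # 'model/best_model' is a hypothetical path.
    fluid.io.load_persistables(exe, 'model/best_model')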