提交 ee7c8d90 编写于 作者: G gaoyuan

add ssd latest data augmentation

上级 f454c647
from PIL import Image from PIL import Image, ImageEnhance
import numpy as np import numpy as np
import random import random
import math import math
...@@ -159,3 +159,77 @@ def crop_image(img, bbox_labels, sample_bbox, image_width, image_height): ...@@ -159,3 +159,77 @@ def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
sample_img = img[ymin:ymax, xmin:xmax] sample_img = img[ymin:ymax, xmin:xmax]
sample_labels = transform_labels(bbox_labels, sample_bbox) sample_labels = transform_labels(bbox_labels, sample_bbox)
return sample_img, sample_labels return sample_img, sample_labels
def random_brightness(img, settings):
    """Randomly rescale the brightness of a PIL image.

    A uniform draw from [0, 1] is compared against
    ``settings._brightness_prob``; when the draw is smaller, the image is
    passed through ``ImageEnhance.Brightness`` with a factor sampled
    uniformly from ``1 +/- settings._brightness_delta``.

    Args:
        img: image accepted by ``PIL.ImageEnhance.Brightness``.
        settings: object exposing ``_brightness_prob`` and
            ``_brightness_delta``.

    Returns:
        The enhanced image, or ``img`` itself when the draw skips the step.
    """
    apply_roll = random.uniform(0, 1)
    if apply_roll < settings._brightness_prob:
        spread = settings._brightness_delta
        # A factor of 1 leaves the image unchanged, so sample around 1.
        factor = 1 + random.uniform(-spread, spread)
        return ImageEnhance.Brightness(img).enhance(factor)
    return img
def random_contrast(img, settings):
    """Randomly rescale the contrast of a PIL image.

    A uniform draw from [0, 1] is compared against
    ``settings._contrast_prob``; when the draw is smaller, the image is
    passed through ``ImageEnhance.Contrast`` with a factor sampled uniformly
    from ``1 +/- settings._contrast_delta``.

    Args:
        img: image accepted by ``PIL.ImageEnhance.Contrast``.
        settings: object exposing ``_contrast_prob`` and ``_contrast_delta``.

    Returns:
        The enhanced image, or ``img`` itself when the draw skips the step.
    """
    apply_roll = random.uniform(0, 1)
    if apply_roll < settings._contrast_prob:
        spread = settings._contrast_delta
        # A factor of 1 leaves the image unchanged, so sample around 1.
        factor = 1 + random.uniform(-spread, spread)
        return ImageEnhance.Contrast(img).enhance(factor)
    return img
def random_saturation(img, settings):
    """Randomly rescale the colour saturation of a PIL image.

    A uniform draw from [0, 1] is compared against
    ``settings._saturation_prob``; when the draw is smaller, the image is
    passed through ``ImageEnhance.Color`` with a factor sampled uniformly
    from ``1 +/- settings._saturation_delta``.

    Args:
        img: image accepted by ``PIL.ImageEnhance.Color``.
        settings: object exposing ``_saturation_prob`` and
            ``_saturation_delta``.

    Returns:
        The enhanced image, or ``img`` itself when the draw skips the step.
    """
    apply_roll = random.uniform(0, 1)
    if apply_roll < settings._saturation_prob:
        spread = settings._saturation_delta
        # A factor of 1 leaves the image unchanged, so sample around 1.
        factor = 1 + random.uniform(-spread, spread)
        return ImageEnhance.Color(img).enhance(factor)
    return img
def random_hue(img, settings):
    """Randomly rotate the hue of a PIL image.

    A uniform draw from [0, 1] is compared against ``settings._hue_prob``;
    when the draw is smaller, the image is converted to HSV, a shift sampled
    uniformly from ``+/- settings._hue_delta`` is added to channel 0, and
    the result is converted back to RGB.

    Args:
        img: PIL image supporting ``convert('HSV')``.
        settings: object exposing ``_hue_prob`` and ``_hue_delta``.

    Returns:
        A new RGB image with shifted hue, or ``img`` itself when the draw
        skips the step.
    """
    prob = random.uniform(0, 1)
    if prob < settings._hue_prob:
        delta = random.uniform(-settings._hue_delta, settings._hue_delta)
        img_hsv = np.array(img.convert('HSV'))
        # Pillow stores HSV as three 8-bit channels, so hue lives on a
        # 0..255 circle. Adding the float shift straight into the uint8
        # channel relied on an out-of-range float->uint8 cast for pixels near
        # either end of the range, which is platform dependent. Do the
        # arithmetic in float64, floor (same as the old truncation for
        # in-range values) and wrap explicitly so the rotation is cyclic.
        shifted = np.floor(img_hsv[:, :, 0].astype('float64') + delta)
        img_hsv[:, :, 0] = np.mod(shifted, 256).astype('uint8')
        img = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
    return img
def distort_image(img, settings):
    """Run the four photometric distortions in one of two fixed orders.

    A uniform draw from [0, 1] picks the order: above 0.5 contrast is
    applied right after brightness, otherwise it is applied last. Each step
    decides on its own, from ``settings``, whether it actually changes the
    image.

    Args:
        img: image handed to each distortion step in turn.
        settings: object forwarded unchanged to every step.

    Returns:
        The image returned by the last step.
    """
    order_roll = random.uniform(0, 1)
    if order_roll > 0.5:
        pipeline = (random_brightness, random_contrast, random_saturation,
                    random_hue)
    else:
        pipeline = (random_brightness, random_saturation, random_hue,
                    random_contrast)
    for step in pipeline:
        img = step(img, settings)
    return img
def expand_image(img, bbox_labels, img_width, img_height, settings):
    """Randomly paste the image onto a larger canvas filled with the mean.

    A uniform draw from [0, 1] is compared against
    ``settings._expand_prob``. When the draw is smaller, an expansion ratio
    is sampled uniformly from ``[1, settings._expand_max_ratio]``; ratios
    within 0.01 of 1 are treated as "no expansion". Otherwise a canvas of
    ``int(img_height * ratio)`` by ``int(img_width * ratio)`` pixels is
    filled with ``settings._img_mean``, the image is pasted at a random
    offset, and the labels are remapped through ``transform_labels``.

    Args:
        img: PIL image to paste onto the canvas.
        bbox_labels: labels forwarded to ``transform_labels``.
        img_width: width of ``img`` in pixels.
        img_height: height of ``img`` in pixels.
        settings: object exposing ``_expand_prob``, ``_expand_max_ratio``
            and ``_img_mean``.

    Returns:
        ``(expanded_image, transformed_labels)`` when the expansion is
        applied, otherwise the inputs ``(img, bbox_labels)`` unchanged.
    """
    prob = random.uniform(0, 1)
    # Gate on the dedicated expand probability. This previously read
    # settings._hue_prob, which left _expand_prob unused and coupled the
    # expansion rate to the hue-distortion setting.
    if prob < settings._expand_prob:
        expand_ratio = random.uniform(1, settings._expand_max_ratio)
        if expand_ratio - 1 >= 0.01:
            height = int(img_height * expand_ratio)
            width = int(img_width * expand_ratio)
            # Top-left corner of the original image inside the canvas.
            h_off = math.floor(random.uniform(0, height - img_height))
            w_off = math.floor(random.uniform(0, width - img_width))
            # Canvas corners expressed as fractions of the original image
            # size, relative to the original image's top-left corner.
            expand_bbox = bbox(-w_off / img_width, -h_off / img_height,
                               (width - w_off) / img_width,
                               (height - h_off) / img_height)
            expand_img = np.ones((height, width, 3))
            # _img_mean is squeezed so it broadcasts over the last axis.
            expand_img = np.uint8(expand_img * np.squeeze(settings._img_mean))
            expand_img = Image.fromarray(expand_img)
            expand_img.paste(img, (int(w_off), int(h_off)))
            bbox_labels = transform_labels(bbox_labels, expand_bbox)
            return expand_img, bbox_labels
    return img, bbox_labels
...@@ -22,17 +22,38 @@ import os ...@@ -22,17 +22,38 @@ import os
class Settings(object): class Settings(object):
def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value): def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value,
apply_distort, apply_expand):
self._data_dir = data_dir self._data_dir = data_dir
self._label_list = [] self._label_list = []
label_fpath = os.path.join(data_dir, label_file) label_fpath = os.path.join(data_dir, label_file)
for line in open(label_fpath): for line in open(label_fpath):
self._label_list.append(line.strip()) self._label_list.append(line.strip())
self._apply_distort = apply_distort
self._apply_expand = apply_expand
self._resize_height = resize_h self._resize_height = resize_h
self._resize_width = resize_w self._resize_width = resize_w
self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype( self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
'float32') 'float32')
self._expand_prob = 0.5
self._expand_max_ratio = 4
self._hue_prob = 0.5
self._hue_delta = 18
self._contrast_prob = 0.5
self._contrast_delta = 0.5
self._saturation_prob = 0.5
self._saturation_delta = 0.5
self._brightness_prob = 0.5
self._brightness_delta = 0.125
@property
def apply_expand(self):
    """Return the ``apply_expand`` flag stored by the constructor."""
    # This getter used to be named apply_distort as well, so it was shadowed
    # by the definition below and no apply_expand property existed.
    return self._apply_expand

@property
def apply_distort(self):
    """Return the ``apply_distort`` flag stored by the constructor."""
    return self._apply_distort
@property @property
def data_dir(self): def data_dir(self):
...@@ -71,7 +92,6 @@ def _reader_creator(settings, file_list, mode, shuffle): ...@@ -71,7 +92,6 @@ def _reader_creator(settings, file_list, mode, shuffle):
img = Image.open(img_path) img = Image.open(img_path)
img_width, img_height = img.size img_width, img_height = img.size
img = np.array(img)
# layout: label | xmin | ymin | xmax | ymax | difficult # layout: label | xmin | ymin | xmax | ymax | difficult
if mode == 'train' or mode == 'test': if mode == 'train' or mode == 'test':
...@@ -99,6 +119,12 @@ def _reader_creator(settings, file_list, mode, shuffle): ...@@ -99,6 +119,12 @@ def _reader_creator(settings, file_list, mode, shuffle):
sample_labels = bbox_labels sample_labels = bbox_labels
if mode == 'train': if mode == 'train':
if settings._apply_distort:
img = image_util.distort_image(img, settings)
if settings._apply_expand:
img, bbox_labels = image_util.expand_image(
img, bbox_labels, img_width, img_height,
settings)
batch_sampler = [] batch_sampler = []
# hard-code here # hard-code here
batch_sampler.append( batch_sampler.append(
...@@ -126,6 +152,7 @@ def _reader_creator(settings, file_list, mode, shuffle): ...@@ -126,6 +152,7 @@ def _reader_creator(settings, file_list, mode, shuffle):
sampled_bbox = image_util.generate_batch_samples( sampled_bbox = image_util.generate_batch_samples(
batch_sampler, bbox_labels, img_width, img_height) batch_sampler, bbox_labels, img_width, img_height)
img = np.array(img)
if len(sampled_bbox) > 0: if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox))) idx = int(random.uniform(0, len(sampled_bbox)))
img, sample_labels = image_util.crop_image( img, sample_labels = image_util.crop_image(
......
...@@ -45,13 +45,10 @@ def train(train_file_list, ...@@ -45,13 +45,10 @@ def train(train_file_list,
evaluate_difficult=False, evaluate_difficult=False,
ap_version='11point') ap_version='11point')
optimizer = fluid.optimizer.Momentum( boundaries = [40000, 60000]
learning_rate=fluid.layers.exponential_decay( values = [0.001, 0.0005, 0.00025]
learning_rate=learning_rate, optimizer = fluid.optimizer.RMSProp(
decay_steps=40000, learning_rate=fluid.layers.piecewise_decay(boundaries, values),
decay_rate=0.1,
staircase=True),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(0.00005), ) regularization=fluid.regularizer.L2Decay(0.00005), )
optimizer.minimize(loss) optimizer.minimize(loss)
...@@ -60,7 +57,8 @@ def train(train_file_list, ...@@ -60,7 +57,8 @@ def train(train_file_list,
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
load_model.load_paddlev1_vars(place) load_model.load_and_set_vars(place)
#load_model.load_paddlev1_vars(place)
train_reader = paddle.batch( train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size) reader.train(data_args, train_file_list), batch_size=batch_size)
test_reader = paddle.batch( test_reader = paddle.batch(
...@@ -85,6 +83,7 @@ def train(train_file_list, ...@@ -85,6 +83,7 @@ def train(train_file_list,
loss_v = exe.run(fluid.default_main_program(), loss_v = exe.run(fluid.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
if batch_id % 20 == 0:
print("Pass {0}, batch {1}, loss {2}" print("Pass {0}, batch {1}, loss {2}"
.format(pass_id, batch_id, loss_v[0])) .format(pass_id, batch_id, loss_v[0]))
test(pass_id) test(pass_id)
...@@ -100,6 +99,8 @@ if __name__ == '__main__': ...@@ -100,6 +99,8 @@ if __name__ == '__main__':
data_args = reader.Settings( data_args = reader.Settings(
data_dir='./data', data_dir='./data',
label_file='label_list', label_file='label_list',
apply_distort=True,
apply_expand=True,
resize_h=300, resize_h=300,
resize_w=300, resize_w=300,
mean_value=[127.5, 127.5, 127.5]) mean_value=[127.5, 127.5, 127.5])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册