You need to sign in or sign up before continuing.
未验证 提交 2a6b7dc9 编写于 作者: Q qingqing01 提交者: GitHub

Data sampling enhancement. (#947)

* Support groups in de-conv and fix bug.

* Clean code.

* Fix model config.

* Some mirror changes in data argumentations.

* Set learning rate by input arguments.
上级 86235262
...@@ -8,9 +8,16 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True #otherwise IOError raised image file is ...@@ -8,9 +8,16 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True #otherwise IOError raised image file is
class sampler(): class sampler():
def __init__(self, max_sample, max_trial, min_scale, max_scale, def __init__(self,
min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap, max_sample,
max_jaccard_overlap): max_trial,
min_scale,
max_scale,
min_aspect_ratio,
max_aspect_ratio,
min_jaccard_overlap,
max_jaccard_overlap,
use_square=False):
self.max_sample = max_sample self.max_sample = max_sample
self.max_trial = max_trial self.max_trial = max_trial
self.min_scale = min_scale self.min_scale = min_scale
...@@ -19,6 +26,7 @@ class sampler(): ...@@ -19,6 +26,7 @@ class sampler():
self.max_aspect_ratio = max_aspect_ratio self.max_aspect_ratio = max_aspect_ratio
self.min_jaccard_overlap = min_jaccard_overlap self.min_jaccard_overlap = min_jaccard_overlap
self.max_jaccard_overlap = max_jaccard_overlap self.max_jaccard_overlap = max_jaccard_overlap
self.use_square = use_square
class bbox(): class bbox():
...@@ -35,7 +43,7 @@ def bbox_area(src_bbox): ...@@ -35,7 +43,7 @@ def bbox_area(src_bbox):
return width * height return width * height
def generate_sample(sampler): def generate_sample(sampler, image_width, image_height):
scale = random.uniform(sampler.min_scale, sampler.max_scale) scale = random.uniform(sampler.min_scale, sampler.max_scale)
aspect_ratio = random.uniform(sampler.min_aspect_ratio, aspect_ratio = random.uniform(sampler.min_aspect_ratio,
sampler.max_aspect_ratio) sampler.max_aspect_ratio)
...@@ -44,6 +52,14 @@ def generate_sample(sampler): ...@@ -44,6 +52,14 @@ def generate_sample(sampler):
bbox_width = scale * (aspect_ratio**0.5) bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5) bbox_height = scale / (aspect_ratio**0.5)
# guarantee a squared image patch after cropping
if sampler.use_square:
if image_height < image_width:
bbox_width = bbox_height * image_height / image_width
else:
bbox_height = bbox_width * image_width / image_height
xmin_bound = 1 - bbox_width xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height ymin_bound = 1 - bbox_height
xmin = random.uniform(0, xmin_bound) xmin = random.uniform(0, xmin_bound)
...@@ -79,6 +95,7 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): ...@@ -79,6 +95,7 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
for i in range(len(bbox_labels)): for i in range(len(bbox_labels)):
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1], object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3]) bbox_labels[i][2], bbox_labels[i][3])
# now only support constraint by jaccard overlap
overlap = jaccard_overlap(sample_bbox, object_bbox) overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \ if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap: overlap < sampler.min_jaccard_overlap:
...@@ -90,7 +107,8 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): ...@@ -90,7 +107,8 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
return False return False
def generate_batch_samples(batch_sampler, bbox_labels): def generate_batch_samples(batch_sampler, bbox_labels, image_width,
image_height):
sampled_bbox = [] sampled_bbox = []
index = [] index = []
c = 0 c = 0
...@@ -99,7 +117,7 @@ def generate_batch_samples(batch_sampler, bbox_labels): ...@@ -99,7 +117,7 @@ def generate_batch_samples(batch_sampler, bbox_labels):
for i in range(sampler.max_trial): for i in range(sampler.max_trial):
if found >= sampler.max_sample: if found >= sampler.max_sample:
break break
sample_bbox = generate_sample(sampler) sample_bbox = generate_sample(sampler, image_width, image_height)
if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
sampled_bbox.append(sample_bbox) sampled_bbox.append(sample_bbox)
found = found + 1 found = found + 1
...@@ -127,15 +145,14 @@ def meet_emit_constraint(src_bbox, sample_bbox): ...@@ -127,15 +145,14 @@ def meet_emit_constraint(src_bbox, sample_bbox):
return False return False
def transform_labels(bbox_labels, sample_bbox): def project_bbox(object_bbox, sample_bbox):
if object_bbox.xmin >= sample_bbox.xmax or \
object_bbox.xmax <= sample_bbox.xmin or \
object_bbox.ymin >= sample_bbox.ymax or \
object_bbox.ymax <= sample_bbox.ymin:
return False
else:
proj_bbox = bbox(0, 0, 0, 0) proj_bbox = bbox(0, 0, 0, 0)
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
if not meet_emit_constraint(object_bbox, sample_bbox):
continue
sample_width = sample_bbox.xmax - sample_bbox.xmin sample_width = sample_bbox.xmax - sample_bbox.xmin
sample_height = sample_bbox.ymax - sample_bbox.ymin sample_height = sample_bbox.ymax - sample_bbox.ymin
proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width
...@@ -144,12 +161,26 @@ def transform_labels(bbox_labels, sample_bbox): ...@@ -144,12 +161,26 @@ def transform_labels(bbox_labels, sample_bbox):
proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height
proj_bbox = clip_bbox(proj_bbox) proj_bbox = clip_bbox(proj_bbox)
if bbox_area(proj_bbox) > 0: if bbox_area(proj_bbox) > 0:
return proj_bbox
else:
return False
def transform_labels(bbox_labels, sample_bbox):
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
if not meet_emit_constraint(object_bbox, sample_bbox):
continue
proj_bbox = project_bbox(object_bbox, sample_bbox)
if proj_bbox:
sample_label.append(bbox_labels[i][0]) sample_label.append(bbox_labels[i][0])
sample_label.append(float(proj_bbox.xmin)) sample_label.append(float(proj_bbox.xmin))
sample_label.append(float(proj_bbox.ymin)) sample_label.append(float(proj_bbox.ymin))
sample_label.append(float(proj_bbox.xmax)) sample_label.append(float(proj_bbox.xmax))
sample_label.append(float(proj_bbox.ymax)) sample_label.append(float(proj_bbox.ymax))
#sample_label.append(bbox_labels[i][5])
sample_label = sample_label + bbox_labels[i][5:] sample_label = sample_label + bbox_labels[i][5:]
sample_labels.append(sample_label) sample_labels.append(sample_label)
return sample_labels return sample_labels
......
...@@ -29,9 +29,9 @@ class Settings(object): ...@@ -29,9 +29,9 @@ class Settings(object):
dataset=None, dataset=None,
data_dir=None, data_dir=None,
label_file=None, label_file=None,
resize_h=300, resize_h=None,
resize_w=300, resize_w=None,
mean_value=[127.5, 127.5, 127.5], mean_value=[104., 117., 123.],
apply_distort=True, apply_distort=True,
apply_expand=True, apply_expand=True,
ap_version='11point', ap_version='11point',
...@@ -55,6 +55,8 @@ class Settings(object): ...@@ -55,6 +55,8 @@ class Settings(object):
self._saturation_prob = 0.5 self._saturation_prob = 0.5
self._saturation_delta = 0.5 self._saturation_delta = 0.5
self._brightness_prob = 0.5 self._brightness_prob = 0.5
# _brightness_delta is the normalized value by 256
# self._brightness_delta = 32
self._brightness_delta = 0.125 self._brightness_delta = 0.125
@property @property
...@@ -115,17 +117,17 @@ def preprocess(img, bbox_labels, mode, settings): ...@@ -115,17 +117,17 @@ def preprocess(img, bbox_labels, mode, settings):
batch_sampler = [] batch_sampler = []
# hard-code here # hard-code here
batch_sampler.append( batch_sampler.append(
image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, True))
batch_sampler.append( batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
batch_sampler.append( batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
batch_sampler.append( batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
batch_sampler.append( batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
sampled_bbox = image_util.generate_batch_samples(batch_sampler, sampled_bbox = image_util.generate_batch_samples(
bbox_labels) batch_sampler, bbox_labels, img_width, img_height)
img = np.array(img) img = np.array(img)
if len(sampled_bbox) > 0: if len(sampled_bbox) > 0:
......
...@@ -29,7 +29,7 @@ add_arg('resize_w', int, 640, "The resized image height.") ...@@ -29,7 +29,7 @@ add_arg('resize_w', int, 640, "The resized image height.")
def train(args, data_args, learning_rate, batch_size, pretrained_model, def train(args, data_args, learning_rate, batch_size, pretrained_model,
num_passes): num_passes, optimizer_method):
num_classes = 2 num_classes = 2
...@@ -51,8 +51,15 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, ...@@ -51,8 +51,15 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
learning_rate, learning_rate * 0.1, learning_rate * 0.01, learning_rate, learning_rate * 0.1, learning_rate * 0.01,
learning_rate * 0.001 learning_rate * 0.001
] ]
#print('main program ', fluid.default_main_program())
if optimizer_method == "momentum":
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=boundaries, values=values),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(0.0005),
)
else:
optimizer = fluid.optimizer.RMSProp( optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values), learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.0005), regularization=fluid.regularizer.L2Decay(0.0005),
...@@ -131,11 +138,14 @@ if __name__ == '__main__': ...@@ -131,11 +138,14 @@ if __name__ == '__main__':
data_dir=data_dir, data_dir=data_dir,
resize_h=args.resize_h, resize_h=args.resize_h,
resize_w=args.resize_w, resize_w=args.resize_w,
apply_expand=False,
mean_value=[104., 117., 123],
ap_version='11point') ap_version='11point')
train( train(
args, args,
data_args=data_args, data_args=data_args,
learning_rate=0.01, learning_rate=args.learning_rate,
batch_size=args.batch_size, batch_size=args.batch_size,
pretrained_model=args.pretrained_model, pretrained_model=args.pretrained_model,
num_passes=args.num_passes) num_passes=args.num_passes,
optimizer_method="momentum")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册