提交 900ede56 编写于 作者: B baiyf 提交者: qingqing01

Add object_coverage constraint in data sampler & code fix (#962)

* Add object_coverage constraint & fix bbox_labels data layout bug

* code clean
上级 6e091a42
model/
data/
label/
pretrained/
*.swp
......@@ -17,6 +17,8 @@ class sampler():
max_aspect_ratio,
min_jaccard_overlap,
max_jaccard_overlap,
min_object_coverage,
max_object_coverage,
use_square=False):
self.max_sample = max_sample
self.max_trial = max_trial
......@@ -26,6 +28,8 @@ class sampler():
self.max_aspect_ratio = max_aspect_ratio
self.min_jaccard_overlap = min_jaccard_overlap
self.max_jaccard_overlap = max_jaccard_overlap
self.min_object_coverage = min_object_coverage
self.max_object_coverage = max_object_coverage
self.use_square = use_square
......@@ -37,10 +41,36 @@ class bbox():
self.ymax = ymax
def intersect_bbox(bbox1, bbox2):
if bbox2.xmin > bbox1.xmax or bbox2.xmax < bbox1.xmin or \
bbox2.ymin > bbox1.ymax or bbox2.ymax < bbox1.ymin:
intersection_box = bbox(0.0, 0.0, 0.0, 0.0)
else:
intersection_box = bbox(
max(bbox1.xmin, bbox2.xmin),
max(bbox1.ymin, bbox2.ymin),
min(bbox1.xmax, bbox2.xmax), min(bbox1.ymax, bbox2.ymax))
return intersection_box
def bbox_coverage(bbox1, bbox2):
inter_box = intersect_bbox(bbox1, bbox2)
intersect_size = bbox_area(inter_box)
if intersect_size > 0:
bbox1_size = bbox_area(bbox1)
return intersect_size / bbox1_size
else:
return 0.
def bbox_area(src_bbox):
width = src_bbox.xmax - src_bbox.xmin
height = src_bbox.ymax - src_bbox.ymin
return width * height
if src_bbox.xmax < src_bbox.xmin or src_bbox.ymax < src_bbox.ymin:
return 0.
else:
width = src_bbox.xmax - src_bbox.xmin
height = src_bbox.ymax - src_bbox.ymin
return width * height
def generate_sample(sampler, image_width, image_height):
......@@ -90,21 +120,35 @@ def jaccard_overlap(sample_bbox, object_bbox):
def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0:
has_jaccard_overlap = False if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0 else True
has_object_coverage = False if sampler.min_object_coverage == 0 and sampler.max_object_coverage == 0 else True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(bbox_labels)):
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
# now only support constraint by jaccard overlap
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap:
continue
if sampler.max_jaccard_overlap != 0 and \
overlap > sampler.max_jaccard_overlap:
continue
return True
return False
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap:
continue
if sampler.max_jaccard_overlap != 0 and \
overlap > sampler.max_jaccard_overlap:
continue
found = True
if has_object_coverage:
object_coverage = bbox_coverage(object_bbox, sample_bbox)
if sampler.min_object_coverage != 0 and \
object_coverage < sampler.min_object_coverage:
continue
if sampler.max_object_coverage != 0 and \
object_coverage > sampler.max_object_coverage:
continue
found = True
if found:
return True
return found
def generate_batch_samples(batch_sampler, bbox_labels, image_width,
......
......@@ -51,7 +51,7 @@ class PyramidBox(object):
self.steps = [4., 8., 16., 32., 64., 128.]
self.is_infer = is_infer
# the base network is VGG with atrus layers
# the base network is VGG with atrous layers
self._input()
self._vgg()
if sub_network:
......
......@@ -72,7 +72,7 @@ class Settings(object):
return self._toy
@property
def apply_distort(self):
def apply_expand(self):
return self._apply_expand
@property
......@@ -117,15 +117,20 @@ def preprocess(img, bbox_labels, mode, settings):
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True))
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
True))
sampled_bbox = image_util.generate_batch_samples(
batch_sampler, bbox_labels, img_width, img_height)
......@@ -208,12 +213,11 @@ def pyramidbox(settings, file_list, mode, shuffle):
im = im.convert('RGB')
im_width, im_height = im.size
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd
# layout: label | xmin | ymin | xmax | ymax
bbox_labels = []
for index_box in range(len(dict_input_txt[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = dict_input_txt[index_image][
index_box].split(' ')
xmin = float(temp_info_box[0])
......@@ -223,6 +227,7 @@ def pyramidbox(settings, file_list, mode, shuffle):
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(1)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
......@@ -233,7 +238,7 @@ def pyramidbox(settings, file_list, mode, shuffle):
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 0:4]
boxes = sample_labels[:, 1:5]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
......
import os
import shutil
import numpy as np
import time
import argparse
......@@ -22,7 +23,7 @@ add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, './vgg_model/', "The init model path.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image height.")
#yapf: enable
......@@ -71,7 +72,6 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe)
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册