未验证 提交 b6c505b8 编写于 作者: Q qingqing01 提交者: GitHub

Add head bbox for Pyramid-Box model. (#963)

* Add head bbox for Pyramid-Box model.

* Fix bug in transform_labels and satisfy_sample_constraint function.
上级 271d9585
......@@ -120,14 +120,21 @@ def jaccard_overlap(sample_bbox, object_bbox):
def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
has_jaccard_overlap = False if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0 else True
has_object_coverage = False if sampler.min_object_coverage == 0 and sampler.max_object_coverage == 0 else True
if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0:
has_jaccard_overlap = False
else:
has_jaccard_overlap = True
if sampler.min_object_coverage == 0 and sampler.max_object_coverage == 0:
has_object_coverage = False
else:
has_object_coverage = True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(bbox_labels)):
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
......@@ -214,8 +221,8 @@ def transform_labels(bbox_labels, sample_bbox):
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
if not meet_emit_constraint(object_bbox, sample_bbox):
continue
proj_bbox = project_bbox(object_bbox, sample_bbox)
......
......@@ -50,6 +50,7 @@ class PyramidBox(object):
self.min_sizes = [16., 32., 64., 128., 256., 512.]
self.steps = [4., 8., 16., 32., 64., 128.]
self.is_infer = is_infer
self.sub_network = sub_network
# the base network is VGG with atrous layers
self._input()
......@@ -59,12 +60,23 @@ class PyramidBox(object):
self._cpm_module()
self._pyramidbox()
def feeds(self):
if self.is_infer:
return [self.image]
else:
return [
self.image, self.face_box, self.head_box, self.gt_label,
self.difficult
]
def _input(self):
self.image = fluid.layers.data(
name='image', shape=self.data_shape, dtype='float32')
if not self.is_infer:
self.gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
self.face_box = fluid.layers.data(
name='face_box', shape=[4], dtype='float32', lod_level=1)
self.head_box = fluid.layers.data(
name='head_box', shape=[4], dtype='float32', lod_level=1)
self.gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1)
self.difficult = fluid.layers.data(
......@@ -267,7 +279,7 @@ class PyramidBox(object):
# locs, confs, box, box_var = vgg_extra_net(num_classes, image, image_shape)
# nmsed_out = fluid.layers.detection_output(
# locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(mbox_locs, mbox_confs, self.gt_box,
loss = fluid.layers.ssd_loss(mbox_locs, mbox_confs, self.face_box,
self.gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)
......@@ -275,11 +287,11 @@ class PyramidBox(object):
def train(self):
face_loss = fluid.layers.ssd_loss(
self.face_mbox_loc, self.face_mbox_conf, self.gt_box, self.gt_label,
self.prior_boxes, self.box_vars)
self.face_mbox_loc, self.face_mbox_conf, self.face_box,
self.gt_label, self.prior_boxes, self.box_vars)
head_loss = fluid.layers.ssd_loss(
self.head_mbox_loc, self.head_mbox_conf, self.gt_box, self.gt_label,
self.prior_boxes, self.box_vars)
self.head_mbox_loc, self.head_mbox_conf, self.head_box,
self.gt_label, self.prior_boxes, self.box_vars)
face_loss = fluid.layers.reduce_sum(face_loss)
head_loss = fluid.layers.reduce_sum(head_loss)
total_loss = face_loss + head_loss
......@@ -303,14 +315,14 @@ class PyramidBox(object):
face_map_eval = fluid.evaluator.DetectionMAP(
face_nmsed_out,
self.gt_label,
self.gt_box,
self.face_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
head_map_eval = fluid.evaluator.DetectionMAP(
head_nmsed_out,
self.gt_label,
self.gt_box,
self.head_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
......
......@@ -195,6 +195,30 @@ def put_txt_in_dict(input_txt):
return dict_input_txt
def expand_bboxes(bboxes,
expand_left=2.,
expand_up=2.,
expand_right=2.,
expand_down=2.):
"""
Expand bboxes, expand 2 times by defalut.
"""
expand_boxes = []
for bbox in bboxes:
xmin = bbox[0]
ymin = bbox[1]
xmax = bbox[2]
ymax = bbox[3]
w = xmax - xmin
h = ymax - ymin
ex_xmin = max(xmin - w / expand_left, 0.)
ex_ymin = max(ymin - h / expand_up, 0.)
ex_xmax = min(xmax + w / expand_right, 1.)
ex_ymax = min(ymax + h / expand_down, 1.)
expand_boxes.append([ex_xmin, ex_ymin, ex_xmax, ex_ymax])
return expand_boxes
def pyramidbox(settings, file_list, mode, shuffle):
dict_input_txt = {}
......@@ -241,15 +265,10 @@ def pyramidbox(settings, file_list, mode, shuffle):
boxes = sample_labels[:, 1:5]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
yield im, boxes, lbls, difficults
yield im, boxes, expand_bboxes(boxes), lbls, difficults
return reader
def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)
def test(settings, file_list):
return pyramidbox(settings, file_list, 'test', False)
......@@ -39,12 +39,15 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
image_shape = [3, data_args.resize_h, data_args.resize_w]
fetches = []
if args.use_pyramidbox:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
face_loss, head_loss, loss = network.train()
fetches = [face_loss, head_loss]
else:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
loss = network.vgg_ssd(num_classes, image_shape)
fetches = [loss]
epocs = 12880 / batch_size
boundaries = [epocs * 100, epocs * 125, epocs * 150]
......@@ -73,9 +76,11 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
exe.run(fluid.default_startup_program())
if pretrained_model:
if not os.path.exists(pretrained_model):
raise ValueError("The pre-trained model path [%s] does not exist." %
(pretrained_model))
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
print('Load pre-trained model.')
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
if args.parallel:
......@@ -84,11 +89,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
feeder = fluid.DataFeeder(
place=place,
feed_list=[
network.image, network.gt_box, network.gt_label, network.difficult
])
feeder = fluid.DataFeeder(place=place, feed_list=network.feeds())
def save_model(postfix):
model_path = os.path.join(model_save_dir, postfix)
......@@ -97,8 +98,6 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
print 'save models to %s' % (model_path)
fluid.io.save_persistables(exe, model_path)
best_map = 0.
for pass_id in range(num_passes):
start_time = time.time()
prev_start_time = start_time
......@@ -108,20 +107,27 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
start_time = time.time()
if len(data) < devices_num: continue
if args.parallel:
loss_v, = train_exe.run(fetch_list=[loss.name],
feed=feeder.feed(data))
fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
feed=feeder.feed(data))
else:
loss_v, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
fetch_vars = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=fetches)
end_time = time.time()
loss_v = np.mean(np.array(loss_v))
fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
if batch_id % 1 == 0:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, loss_v, start_time - prev_start_time))
if not args.use_pyramidbox:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, fetch_vars[0],
start_time - prev_start_time))
else:
print("Pass {0}, batch {1}, face loss {2}, head loss {3}, " \
"time {4}".format(pass_id,
batch_id, fetch_vars[0], fetch_vars[1],
start_time - prev_start_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id))
print("Best test map {0}".format(best_map))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册