From 55849d4e2c4a54b73162e9ef895bd3c644731f1e Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 11 Jun 2018 16:06:25 +0800 Subject: [PATCH] Refine the VGG-SSD network. (#977) --- fluid/face_detection/infer.py | 2 +- fluid/face_detection/pyramidbox.py | 111 ++++++++++++++++++++++------- 2 files changed, 87 insertions(+), 26 deletions(-) diff --git a/fluid/face_detection/infer.py b/fluid/face_detection/infer.py index b807d0eb..f4401bfb 100644 --- a/fluid/face_detection/infer.py +++ b/fluid/face_detection/infer.py @@ -17,7 +17,7 @@ add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.") -add_arg('confs_threshold', float, 0.15, "Confidence threshold to draw bbox.") +add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.") add_arg('image_path', str, '', "The data root path.") add_arg('model_dir', str, '', "The model path.") add_arg('resize_h', int, 0, "The resized image height.") diff --git a/fluid/face_detection/pyramidbox.py b/fluid/face_detection/pyramidbox.py index abac6b40..37339a26 100644 --- a/fluid/face_detection/pyramidbox.py +++ b/fluid/face_detection/pyramidbox.py @@ -242,7 +242,7 @@ class PyramidBox(object): box, var = fluid.layers.prior_box( input, self.image, - min_sizes=[self.min_sizes[1]], + min_sizes=[self.min_sizes[i]], steps=[self.steps[i]] * 2, aspect_ratios=[1.], offset=0.5) @@ -266,25 +266,75 @@ class PyramidBox(object): self.conv4_norm = self._l2_norm_scale(self.conv4, init_scale=8.) self.conv5_norm = self._l2_norm_scale(self.conv5, init_scale=5.) - mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head( - inputs=[ - self.conv3_norm, self.conv4_norm, self.conv5_norm, self.conv6, - self.conv7, self.conv8 - ], - image=self.image, - num_classes=self.num_classes, - min_sizes=[16.0, 32.0, 64.0, 128.0, 256.0, 512.0], - max_sizes=[[], [], [], [], [], []], - aspect_ratios=[[1.], [1.], [1.], [1.], [1.], [1.]], - steps=[4.0, 8.0, 16.0, 32.0, 64.0, 128.0], - base_size=self.data_shape[2], - offset=0.5, - flip=False) - - self.face_mbox_loc = mbox_locs - self.face_mbox_conf = mbox_confs - self.prior_boxes = box - self.box_vars = box_var + def permute_and_reshape(input, last_dim): + trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1]) + new_shape = [ + trans.shape[0], np.prod(trans.shape[1:]) / last_dim, last_dim + ] + return fluid.layers.reshape(trans, shape=new_shape) + + locs, confs = [], [] + boxes, vars = [], [] + b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.)) + + # conv3 + mbox_loc = fluid.layers.conv2d( + self.conv3_norm, 4, 3, 1, 1, bias_attr=b_attr) + loc = permute_and_reshape(mbox_loc, 4) + mbox_conf = fluid.layers.conv2d( + self.conv3_norm, 4, 3, 1, 1, bias_attr=b_attr) + conf1, conf3 = fluid.layers.split( + mbox_conf, num_or_sections=[1, 3], dim=1) + conf3_maxin = fluid.layers.reduce_max(conf3, dim=1, keep_dim=True) + conf = fluid.layers.concat([conf1, conf3_maxin], axis=1) + conf = permute_and_reshape(conf, 2) + box, var = fluid.layers.prior_box( + self.conv3_norm, + self.image, + min_sizes=[16.], + steps=[4, 4], + aspect_ratios=[1.], + clip=False, + flip=True, + offset=0.5) + box = fluid.layers.reshape(box, shape=[-1, 4]) + var = fluid.layers.reshape(var, shape=[-1, 4]) + + locs.append(loc) + confs.append(conf) + boxes.append(box) + vars.append(var) + + min_sizes = [32., 64., 128., 256., 512.] + steps = [8., 16., 32., 64., 128.] + inputs = [ + self.conv4_norm, self.conv5_norm, self.conv6, self.conv7, self.conv8 + ] + for i, input in enumerate(inputs): + mbox_loc = fluid.layers.conv2d(input, 4, 3, 1, 1, bias_attr=b_attr) + loc = permute_and_reshape(mbox_loc, 4) + + mbox_conf = fluid.layers.conv2d(input, 2, 3, 1, 1, bias_attr=b_attr) + conf = permute_and_reshape(mbox_conf, 2) + box, var = fluid.layers.prior_box( + input, + self.image, + min_sizes=[min_sizes[i]], + steps=[steps[i]] * 2, + aspect_ratios=[1.], + offset=0.5) + box = fluid.layers.reshape(box, shape=[-1, 4]) + var = fluid.layers.reshape(var, shape=[-1, 4]) + + locs.append(loc) + confs.append(conf) + boxes.append(box) + vars.append(var) + + self.face_mbox_loc = fluid.layers.concat(locs, axis=1) + self.face_mbox_conf = fluid.layers.concat(confs, axis=1) + self.prior_boxes = fluid.layers.concat(boxes) + self.box_vars = fluid.layers.concat(vars) def vgg_ssd_loss(self): loss = fluid.layers.ssd_loss( @@ -297,16 +347,27 @@ class PyramidBox(object): overlap_threshold=0.35, neg_overlap=0.35) loss = fluid.layers.reduce_sum(loss) - return loss def train(self): face_loss = fluid.layers.ssd_loss( - self.face_mbox_loc, self.face_mbox_conf, self.face_box, - self.gt_label, self.prior_boxes, self.box_vars) + self.face_mbox_loc, + self.face_mbox_conf, + self.face_box, + self.gt_label, + self.prior_boxes, + self.box_vars, + overlap_threshold=0.35, + neg_overlap=0.35) head_loss = fluid.layers.ssd_loss( - self.head_mbox_loc, self.head_mbox_conf, self.head_box, - self.gt_label, self.prior_boxes, self.box_vars) + self.head_mbox_loc, + self.head_mbox_conf, + self.head_box, + self.gt_label, + self.prior_boxes, + self.box_vars, + overlap_threshold=0.35, + neg_overlap=0.35) face_loss = fluid.layers.reduce_sum(face_loss) head_loss = fluid.layers.reduce_sum(head_loss) total_loss = face_loss + head_loss -- GitLab