diff --git a/fluid/face_detection/.gitignore b/fluid/face_detection/.gitignore
index 27735faca6e555e439300fca5dccd893f70ef9a0..13d42af893162c1908a39fea1d072a22929e5430 100644
--- a/fluid/face_detection/.gitignore
+++ b/fluid/face_detection/.gitignore
@@ -1,5 +1,7 @@
 model/
+pretrained/
 data/
 label/
-pretrained/
 *.swp
+*.log
+infer_results/
diff --git a/fluid/face_detection/infer.py b/fluid/face_detection/infer.py
index f4401bfba7bfacab44aab2aa0c33f8e5ffd36498..71a878cb39f9888e3c308ee24e34dd6c3a073d33 100644
--- a/fluid/face_detection/infer.py
+++ b/fluid/face_detection/infer.py
@@ -11,7 +11,6 @@ import paddle.fluid as fluid
 import reader
 from pyramidbox import PyramidBox
 from utility import add_arguments, print_arguments
-
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
@@ -20,73 +19,272 @@ add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
 add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.")
 add_arg('image_path', str, '', "The data root path.")
 add_arg('model_dir', str, '', "The model path.")
-add_arg('resize_h', int, 0, "The resized image height.")
-add_arg('resize_w', int, 0, "The resized image height.")
 # yapf: enable
 
 
 def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
     image = Image.open(image_path)
     draw = ImageDraw.Draw(image)
-    im_width, im_height = image.size
-
     for dt in nms_out:
-        category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
+        xmin, ymin, xmax, ymax, score = dt
         if score < confs_threshold:
             continue
-        bbox = dt[2:]
-        xmin, ymin, xmax, ymax = bbox
-        (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
-                                      ymin * im_height, ymax * im_height)
+        (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
         draw.line(
             [(left, top), (left, bottom), (right, bottom), (right, top),
              (left, top)],
             width=4,
             fill='red')
     image_name = image_path.split('/')[-1]
+    image_class = image_path.split('/')[-2]
     print("image with bbox drawed saved as {}".format(image_name))
-    image.save(image_name)
+    image.save('./infer_results/' + image_class.encode('utf-8') + '/' +
+               image_name.encode('utf-8'))
 
 
-def infer(args, data_args):
-    num_classes = 2
-    infer_reader = reader.infer(data_args, args.image_path)
-    data = infer_reader()
+def write_to_txt(image_path, f, nms_out):
+    image_name = image_path.split('/')[-1]
+    image_class = image_path.split('/')[-2]
+    f.write('{:s}\n'.format(
+        image_class.encode('utf-8') + '/' + image_name.encode('utf-8')))
+    f.write('{:d}\n'.format(nms_out.shape[0]))
+    for dt in nms_out:
+        xmin, ymin, xmax, ymax, score = dt
+        f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, (
+            xmax - xmin + 1), (ymax - ymin + 1), score))
+    print("image infer result saved {}".format(image_name[:-4]))
 
-    if args.resize_h and args.resize_w:
-        image_shape = [3, args.resize_h, args.resize_w]
-    else:
-        image_shape = data.shape[1:]
 
-    fetches = []
+def get_round(x, loc):
+    str_x = str(x)
+    if '.' in str_x:
+        len_after = len(str_x.split('.')[1])
+        str_before = str_x.split('.')[0]
+        str_after = str_x.split('.')[1]
+        if len_after >= 3:
+            str_final = str_before + '.' + str_after[0:loc]
+            return float(str_final)
+        else:
+            return x
+
+
+def bbox_vote(det):
+    order = det[:, 4].ravel().argsort()[::-1]
+    det = det[order, :]
+    if det.shape[0] == 0:
+        dets = np.array([[10, 10, 20, 20, 0.002]])
+        det = np.empty(shape=[0, 5])
+    while det.shape[0] > 0:
+        # IOU
+        area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
+        xx1 = np.maximum(det[0, 0], det[:, 0])
+        yy1 = np.maximum(det[0, 1], det[:, 1])
+        xx2 = np.minimum(det[0, 2], det[:, 2])
+        yy2 = np.minimum(det[0, 3], det[:, 3])
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        o = inter / (area[0] + area[:] - inter)
+
+        # get needed merge det and delete these det
+        merge_index = np.where(o >= 0.3)[0]
+        det_accu = det[merge_index, :]
+        det = np.delete(det, merge_index, 0)
+        if merge_index.shape[0] <= 1:
+            if det.shape[0] == 0:
+                try:
+                    dets = np.row_stack((dets, det_accu))
+                except:
+                    dets = det_accu
+            continue
+        det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
+        max_score = np.max(det_accu[:, 4])
+        det_accu_sum = np.zeros((1, 5))
+        det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4],
+                                      axis=0) / np.sum(det_accu[:, -1:])
+        det_accu_sum[:, 4] = max_score
+        try:
+            dets = np.row_stack((dets, det_accu_sum))
+        except:
+            dets = det_accu_sum
+    dets = dets[0:750, :]
+    return dets
+
+
+def image_preprocess(image):
+    img = np.array(image)
+    # HWC to CHW
+    if len(img.shape) == 3:
+        img = np.swapaxes(img, 1, 2)
+        img = np.swapaxes(img, 1, 0)
+    # RBG to BGR
+    img = img[[2, 1, 0], :, :]
+    img = img.astype('float32')
+    img -= np.array(
+        [104., 117., 123.])[:, np.newaxis, np.newaxis].astype('float32')
+    img = img * 0.007843
+    img = [img]
+    img = np.array(img)
+    return img
 
-    network = PyramidBox(
-        image_shape,
-        num_classes,
-        sub_network=args.use_pyramidbox,
-        is_infer=True)
-    infer_program, nmsed_out = network.infer()
-    fetches = [nmsed_out]
 
+def detect_face(image, shrink):
+    image_shape = [3, image.size[1], image.size[0]]
+    num_classes = 2
     place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
-    model_dir = args.model_dir
-    if not os.path.exists(model_dir):
-        raise ValueError("The model path [%s] does not exist." % (model_dir))
+    if shrink != 1:
+        image = image.resize((int(image_shape[2] * shrink),
+                              int(image_shape[1] * shrink)), Image.ANTIALIAS)
+        image_shape = [
+            image_shape[0], int(image_shape[1] * shrink),
+            int(image_shape[2] * shrink)
+        ]
+    print "image_shape:", image_shape
+    img = image_preprocess(image)
+
+    scope = fluid.core.Scope()
+    main_program = fluid.Program()
+    startup_program = fluid.Program()
+
+    with fluid.scope_guard(scope):
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main_program, startup_program):
+                fetches = []
+                network = PyramidBox(
+                    image_shape,
+                    num_classes,
+                    sub_network=args.use_pyramidbox,
+                    is_infer=True)
+                infer_program, nmsed_out = network.infer(main_program)
+                fetches = [nmsed_out]
+                fluid.io.load_persistables(
+                    exe, args.model_dir, main_program=main_program)
+
+                detection, = exe.run(infer_program,
+                                     feed={'image': img},
+                                     fetch_list=fetches,
+                                     return_numpy=False)
+                detection = np.array(detection)
+    # layout: xmin, ymin, xmax. ymax, score
+    det_conf = detection[:, 1]
+    det_xmin = image_shape[2] * detection[:, 2] / shrink
+    det_ymin = image_shape[1] * detection[:, 3] / shrink
+    det_xmax = image_shape[2] * detection[:, 4] / shrink
+    det_ymax = image_shape[1] * detection[:, 5] / shrink
+
+    det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
+    keep_index = np.where(det[:, 4] >= 0)[0]
+    det = det[keep_index, :]
+    return det
+
+
+def flip_test(image, shrink):
+    img = image.transpose(Image.FLIP_LEFT_RIGHT)
+    det_f = detect_face(img, shrink)
+    det_t = np.zeros(det_f.shape)
+    # image.size: [width, height]
+    det_t[:, 0] = image.size[0] - det_f[:, 2]
+    det_t[:, 1] = det_f[:, 1]
+    det_t[:, 2] = image.size[0] - det_f[:, 0]
+    det_t[:, 3] = det_f[:, 3]
+    det_t[:, 4] = det_f[:, 4]
+    return det_t
+
+
+def multi_scale_test(image, max_shrink):
+    # shrink detecting and shrink only detect big face
+    st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
+    det_s = detect_face(image, st)
+    index = np.where(
+        np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1)
+        > 30)[0]
+    det_s = det_s[index, :]
+    # enlarge one times
+    bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
+    det_b = detect_face(image, bt)
+
+    # enlarge small image x times for small face
+    if max_shrink > 2:
+        bt *= 2
+        while bt < max_shrink:
+            det_b = np.row_stack((det_b, detect_face(image, bt)))
+            bt *= 2
+        det_b = np.row_stack((det_b, detect_face(image, max_shrink)))
+
+    # enlarge only detect small face
+    if bt > 1:
+        index = np.where(
+            np.minimum(det_b[:, 2] - det_b[:, 0] + 1,
+                       det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
+        det_b = det_b[index, :]
+    else:
+        index = np.where(
+            np.maximum(det_b[:, 2] - det_b[:, 0] + 1,
+                       det_b[:, 3] - det_b[:, 1] + 1) > 30)[0]
+        det_b = det_b[index, :]
+    return det_s, det_b
+
+
+def get_im_shrink(image_shape):
+    max_shrink_v1 = (0x7fffffff / 577.0 /
+                     (image_shape[1] * image_shape[2]))**0.5
+    max_shrink_v2 = (
+        (678 * 1024 * 2.0 * 2.0) / (image_shape[1] * image_shape[2]))**0.5
+    max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
+
+    if max_shrink >= 1.5 and max_shrink < 2:
+        max_shrink = max_shrink - 0.1
+    elif max_shrink >= 2 and max_shrink < 3:
+        max_shrink = max_shrink - 0.2
+    elif max_shrink >= 3 and max_shrink < 4:
+        max_shrink = max_shrink - 0.3
+    elif max_shrink >= 4 and max_shrink < 5:
+        max_shrink = max_shrink - 0.4
+    elif max_shrink >= 5:
+        max_shrink = max_shrink - 0.5
+
+    print 'max_shrink = ', max_shrink
+    shrink = max_shrink if max_shrink < 1 else 1
+    print "shrink = ", shrink
+
+    return shrink, max_shrink
+
+
+def infer(args, batch_size, data_args):
+    if not os.path.exists(args.model_dir):
+        raise ValueError("The model path [%s] does not exist." %
+                         (args.model_dir))
+
+    infer_reader = paddle.batch(
+        reader.test(data_args, file_list), batch_size=batch_size)
+
+    for batch_id, img in enumerate(infer_reader()):
+        image = img[0][0]
+        image_path = img[0][1]
+
+        # image.size: [width, height]
+        image_shape = [3, image.size[1], image.size[0]]
+
+        shrink, max_shrink = get_im_shrink(image_shape)
 
-    def if_exist(var):
-        return os.path.exists(os.path.join(model_dir, var.name))
+        det0 = detect_face(image, shrink)
+        det1 = flip_test(image, shrink)
+        [det2, det3] = multi_scale_test(image, max_shrink)
+        det = np.row_stack((det0, det1, det2, det3))
+        dets = bbox_vote(det)
 
-    fluid.io.load_vars(exe, model_dir, predicate=if_exist)
+        image_name = image_path.split('/')[-1]
+        image_class = image_path.split('/')[-2]
+        if not os.path.exists('./infer_results/' + image_class.encode('utf-8')):
+            os.makedirs('./infer_results/' + image_class.encode('utf-8'))
 
-    feed = {'image': fluid.create_lod_tensor(data, [], place)}
-    predict, = exe.run(infer_program,
-                       feed=feed,
-                       fetch_list=fetches,
-                       return_numpy=False)
-    predict = np.array(predict)
-    draw_bounding_box_on_image(args.image_path, predict, args.confs_threshold)
+        f = open('./infer_results/' + image_class.encode('utf-8') + '/' +
+                 image_name.encode('utf-8')[:-4] + '.txt', 'w')
+        write_to_txt(image_path, f, dets)
+        # draw_bounding_box_on_image(image_path, dets, args.confs_threshold)
+    print "Done"
 
 
 if __name__ == '__main__':
@@ -98,10 +296,8 @@ if __name__ == '__main__':
 
     data_args = reader.Settings(
         data_dir=data_dir,
-        resize_h=args.resize_h,
-        resize_w=args.resize_w,
         mean_value=[104., 117., 123],
         apply_distort=False,
         apply_expand=False,
         ap_version='11point')
-    infer(args, data_args=data_args)
+    infer(args, batch_size=1, data_args=data_args)
diff --git a/fluid/face_detection/pyramidbox.py b/fluid/face_detection/pyramidbox.py
index 4bcce5e080b9dc7cba58d52f01343d563a114170..ce01cb7a113219e08d4deb2984d2a12b2590faa5 100644
--- a/fluid/face_detection/pyramidbox.py
+++ b/fluid/face_detection/pyramidbox.py
@@ -39,7 +39,11 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):
             act='relu')
     if with_pool:
         pool = fluid.layers.pool2d(
-            input=conv, pool_size=2, pool_type='max', pool_stride=2)
+            input=conv,
+            pool_size=2,
+            pool_type='max',
+            pool_stride=2,
+            ceil_mode=True)
         return conv, pool
     else:
         return conv
@@ -148,6 +152,8 @@ class PyramidBox(object):
             b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.))
             conv2 = fluid.layers.conv2d(
                 up_to, ch, 1, act='relu', bias_attr=b_attr)
+            if self.is_infer:
+                upsampling = fluid.layers.crop(upsampling, shape=conv2)
             # eltwise mul
             conv_fuse = upsampling * conv2
             return conv_fuse
@@ -393,8 +399,11 @@ class PyramidBox(object):
         total_loss = face_loss + head_loss
         return face_loss, head_loss, total_loss
 
-    def infer(self):
-        test_program = fluid.default_main_program().clone(for_test=True)
+    def infer(self, main_program=None):
+        if main_program is None:
+            test_program = fluid.default_main_program().clone(for_test=True)
+        else:
+            test_program = main_program.clone(for_test=True)
         with fluid.program_guard(test_program):
             face_nmsed_out = fluid.layers.detection_output(
                 self.face_mbox_loc,
diff --git a/fluid/face_detection/reader.py b/fluid/face_detection/reader.py
index 165d80961270a06df60b69e3a6451809b7f2d503..42109b1194cad071c6571ffa1eb590526a688033 100644
--- a/fluid/face_detection/reader.py
+++ b/fluid/face_detection/reader.py
@@ -238,34 +238,38 @@ def pyramidbox(settings, file_list, mode, shuffle):
             im_width, im_height = im.size
 
             # layout: label | xmin | ymin | xmax | ymax
-            bbox_labels = []
-            for index_box in range(len(dict_input_txt[index_image])):
-                if index_box >= 2:
-                    bbox_sample = []
-                    temp_info_box = dict_input_txt[index_image][
-                        index_box].split(' ')
-                    xmin = float(temp_info_box[0])
-                    ymin = float(temp_info_box[1])
-                    w = float(temp_info_box[2])
-                    h = float(temp_info_box[3])
-                    xmax = xmin + w
-                    ymax = ymin + h
-
-                    bbox_sample.append(1)
-                    bbox_sample.append(float(xmin) / im_width)
-                    bbox_sample.append(float(ymin) / im_height)
-                    bbox_sample.append(float(xmax) / im_width)
-                    bbox_sample.append(float(ymax) / im_height)
-                    bbox_labels.append(bbox_sample)
-
-            im, sample_labels = preprocess(im, bbox_labels, mode, settings)
-            sample_labels = np.array(sample_labels)
-            if len(sample_labels) == 0: continue
-            im = im.astype('float32')
-            boxes = sample_labels[:, 1:5]
-            lbls = [1] * len(boxes)
-            difficults = [1] * len(boxes)
-            yield im, boxes, expand_bboxes(boxes), lbls, difficults
+            if mode == 'train':
+                bbox_labels = []
+                for index_box in range(len(dict_input_txt[index_image])):
+                    if index_box >= 2:
+                        bbox_sample = []
+                        temp_info_box = dict_input_txt[index_image][
+                            index_box].split(' ')
+                        xmin = float(temp_info_box[0])
+                        ymin = float(temp_info_box[1])
+                        w = float(temp_info_box[2])
+                        h = float(temp_info_box[3])
+                        xmax = xmin + w
+                        ymax = ymin + h
+
+                        bbox_sample.append(1)
+                        bbox_sample.append(float(xmin) / im_width)
+                        bbox_sample.append(float(ymin) / im_height)
+                        bbox_sample.append(float(xmax) / im_width)
+                        bbox_sample.append(float(ymax) / im_height)
+                        bbox_labels.append(bbox_sample)
+
+                im, sample_labels = preprocess(im, bbox_labels, mode, settings)
+                sample_labels = np.array(sample_labels)
+                if len(sample_labels) == 0: continue
+                im = im.astype('float32')
+                boxes = sample_labels[:, 1:5]
+                lbls = [1] * len(boxes)
+                difficults = [1] * len(boxes)
+                yield im, boxes, expand_bboxes(boxes), lbls, difficults
+
+            if mode == 'test':
+                yield im, image_path
 
     return reader
 
@@ -274,6 +278,10 @@ def train(settings, file_list, shuffle=True):
     return pyramidbox(settings, file_list, 'train', shuffle)
 
 
+def test(settings, file_list):
+    return pyramidbox(settings, file_list, 'test', False)
+
+
 def infer(settings, image_path):
     def batch_reader():
         img = Image.open(image_path)
diff --git a/fluid/face_detection/train.py b/fluid/face_detection/train.py
index c0c8efd2421cb560562ac002464aebbc85235602..c10722b9e33d6c9d05f961d3b2cf73a859b9da3c 100644
--- a/fluid/face_detection/train.py
+++ b/fluid/face_detection/train.py
@@ -16,11 +16,11 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 
 # yapf: disable
 add_arg('parallel', bool, True, "parallel")
-add_arg('learning_rate', float, 0.0001, "Learning rate.")
-add_arg('batch_size', int, 16, "Minibatch size.")
+add_arg('learning_rate', float, 0.001, "Learning rate.")
+add_arg('batch_size', int, 12, "Minibatch size.")
 add_arg('num_passes', int, 120, "Epoch number.")
 add_arg('use_gpu', bool, True, "Whether use GPU.")
-add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
+add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
 add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
 add_arg('model_save_dir', str, 'model', "The path to save model.")
 add_arg('pretrained_model', str, './pretrained/', "The init model path.")
@@ -50,10 +50,10 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
     fetches = [loss]
 
     epocs = 12880 / batch_size
-    boundaries = [epocs * 100, epocs * 125, epocs * 150]
+    boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
     values = [
-        learning_rate, learning_rate * 0.1, learning_rate * 0.01,
-        learning_rate * 0.001
+        learning_rate, learning_rate * 0.5, learning_rate * 0.25,
+        learning_rate * 0.1, learning_rate * 0.01
     ]
 
     if optimizer_method == "momentum":
@@ -70,12 +70,19 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
     )
 
     optimizer.minimize(loss)
+    # fluid.memory_optimize(fluid.default_main_program())
 
     place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    start_pass = 0
     if pretrained_model:
+        if pretrained_model.isdigit():
+            start_pass = int(pretrained_model) + 1
+            pretrained_model = os.path.join(args.model_save_dir, pretrained_model)
+            print("Resume from %s " %(pretrained_model))
+
         if not os.path.exists(pretrained_model):
             raise ValueError("The pre-trained model path [%s] does not exist."
                              % (pretrained_model))
@@ -98,14 +105,14 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
         print 'save models to %s' % (model_path)
         fluid.io.save_persistables(exe, model_path)
 
-    for pass_id in range(num_passes):
+    for pass_id in range(start_pass, num_passes):
         start_time = time.time()
         prev_start_time = start_time
         end_time = 0
         for batch_id, data in enumerate(train_reader()):
             prev_start_time = start_time
             start_time = time.time()
-            if len(data) < devices_num: continue
+            if len(data) < 2 * devices_num: continue
             if args.parallel:
                 fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
                                            feed=feeder.feed(data))
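Note (not part of the patch): the sketch below only illustrates, with invented shapes and values, the coordinate bookkeeping that `detect_face()` in the patched `infer.py` performs on the NMS output: each row of `detection_output` is laid out as [label, score, xmin, ymin, xmax, ymax] with normalized coordinates, and the patch scales them by the shrunk image's width/height and divides by `shrink` to land back on the original image before `bbox_vote()` merges the boxes.

```python
# Illustrative sketch only; the array contents and sizes are made up.
import numpy as np

shrink = 0.5                 # assumed shrink factor applied before inference
image_shape = [3, 300, 400]  # [C, H, W] of the shrunk image (original would be 800x600)

# Two fake NMS rows: [label, score, xmin, ymin, xmax, ymax], coords in [0, 1].
detection = np.array([
    [1., 0.90, 0.10, 0.20, 0.30, 0.40],
    [1., 0.75, 0.50, 0.50, 0.70, 0.80],
])

# Same arithmetic as detect_face(): scale by the shrunk W/H, then divide by
# shrink to map the boxes back onto the original (un-shrunk) image.
det_conf = detection[:, 1]
det_xmin = image_shape[2] * detection[:, 2] / shrink
det_ymin = image_shape[1] * detection[:, 3] / shrink
det_xmax = image_shape[2] * detection[:, 4] / shrink
det_ymax = image_shape[1] * detection[:, 5] / shrink

det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
print(det)  # rows of [xmin, ymin, xmax, ymax, score] in original-image pixels
```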