diff --git a/fluid/face_detection/.gitignore b/fluid/face_detection/.gitignore
index 92c29d5e44f932477917e7b77721836024724d0e..13d42af893162c1908a39fea1d072a22929e5430 100644
--- a/fluid/face_detection/.gitignore
+++ b/fluid/face_detection/.gitignore
@@ -1,16 +1,7 @@
-# saved model
 model/
-
-# pretrained model
 pretrained/
-
-# used data and label
 data/
 label/
-
-# log and swap files
 *.swp
 *.log
-
-# infer
 infer_results/
diff --git a/fluid/face_detection/infer.py b/fluid/face_detection/infer.py
index 32aeeac72680e9a98677a650e885dc6cf65ac0e3..71a878cb39f9888e3c308ee24e34dd6c3a073d33 100644
--- a/fluid/face_detection/infer.py
+++ b/fluid/face_detection/infer.py
@@ -181,9 +181,10 @@ def detect_face(image, shrink):
 
 
 def flip_test(image, shrink):
-    image = image.transpose(Image.FLIP_LEFT_RIGHT)
-    det_f = detect_face(image, shrink)
+    img = image.transpose(Image.FLIP_LEFT_RIGHT)
+    det_f = detect_face(img, shrink)
     det_t = np.zeros(det_f.shape)
+    # image.size: [width, height]
     det_t[:, 0] = image.size[0] - det_f[:, 2]
     det_t[:, 1] = det_f[:, 1]
     det_t[:, 2] = image.size[0] - det_f[:, 0]
@@ -263,6 +264,7 @@ def infer(args, batch_size, data_args):
 
         image = img[0][0]
         image_path = img[0][1]
+        # image.size: [width, height]
         image_shape = [3, image.size[1], image.size[0]]
 
         shrink, max_shrink = get_im_shrink(image_shape)
diff --git a/fluid/face_detection/train.py b/fluid/face_detection/train.py
index c0c8efd2421cb560562ac002464aebbc85235602..c10722b9e33d6c9d05f961d3b2cf73a859b9da3c 100644
--- a/fluid/face_detection/train.py
+++ b/fluid/face_detection/train.py
@@ -16,11 +16,11 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 
 # yapf: disable
 add_arg('parallel', bool, True, "parallel")
-add_arg('learning_rate', float, 0.0001, "Learning rate.")
-add_arg('batch_size', int, 16, "Minibatch size.")
+add_arg('learning_rate', float, 0.001, "Learning rate.")
+add_arg('batch_size', int, 12, "Minibatch size.")
 add_arg('num_passes', int, 120, "Epoch number.")
 add_arg('use_gpu', bool, True, "Whether use GPU.")
-add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
+add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
 add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
 add_arg('model_save_dir', str, 'model', "The path to save model.")
 add_arg('pretrained_model', str, './pretrained/', "The init model path.")
@@ -50,10 +50,10 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
     fetches = [loss]
 
     epocs = 12880 / batch_size
-    boundaries = [epocs * 100, epocs * 125, epocs * 150]
+    boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
     values = [
-        learning_rate, learning_rate * 0.1, learning_rate * 0.01,
-        learning_rate * 0.001
+        learning_rate, learning_rate * 0.5, learning_rate * 0.25,
+        learning_rate * 0.1, learning_rate * 0.01
     ]
 
     if optimizer_method == "momentum":
@@ -70,12 +70,19 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
         )
 
     optimizer.minimize(loss)
+    # fluid.memory_optimize(fluid.default_main_program())
 
     place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    start_pass = 0
     if pretrained_model:
+        if pretrained_model.isdigit():
+            start_pass = int(pretrained_model) + 1
+            pretrained_model = os.path.join(args.model_save_dir, pretrained_model)
+            print("Resume from %s " %(pretrained_model))
+
         if not os.path.exists(pretrained_model):
             raise ValueError("The pre-trained model path [%s] does not exist."
                              % (pretrained_model))
@@ -98,14 +105,14 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
         print 'save models to %s' % (model_path)
         fluid.io.save_persistables(exe, model_path)
 
-    for pass_id in range(num_passes):
+    for pass_id in range(start_pass, num_passes):
         start_time = time.time()
         prev_start_time = start_time
         end_time = 0
         for batch_id, data in enumerate(train_reader()):
             prev_start_time = start_time
             start_time = time.time()
-            if len(data) < devices_num: continue
+            if len(data) < 2 * devices_num: continue
             if args.parallel:
                 fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
                                            feed=feeder.feed(data))
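
Note on the revised schedule in train.py: with the 12880 WIDER FACE training images and the new default batch_size of 12, the boundaries above decay the learning rate after roughly 40, 60, 80, and 100 epochs instead of 100, 125, and 150, with gentler intermediate factors (0.5, 0.25). The following is a minimal standalone sketch (plain Python, no Paddle dependency; the helper name lr_at_step is hypothetical) of how a piecewise-decay lookup over these boundaries/values resolves the rate for a given iteration:

# Sketch only: mirrors the boundaries/values set up in train.py above,
# assuming the defaults introduced by this patch (batch_size=12, learning_rate=0.001).
def lr_at_step(step, batch_size=12, learning_rate=0.001):
    epocs = 12880 // batch_size  # iterations per epoch (12880 WIDER FACE training images)
    boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
    values = [
        learning_rate, learning_rate * 0.5, learning_rate * 0.25,
        learning_rate * 0.1, learning_rate * 0.01
    ]
    # values has one more entry than boundaries: the last value applies
    # once every boundary has been passed.
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]

# e.g. lr_at_step(0) -> 0.001, lr_at_step(50 * (12880 // 12)) -> 0.0005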