From e7684a4d2deb01b482ebddf316ee5400de681941 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Sat, 12 Oct 2019 18:48:48 +0800 Subject: [PATCH] [Face Detection] add facedetection config and eval code (#3466) * Add face detection configs and main code. * Integration multi-scale evaluation. --- configs/face_detection/blazeface.yml | 130 ++++++++++ configs/face_detection/blazeface_nas.yml | 132 ++++++++++ configs/face_detection/faceboxes.yml | 130 ++++++++++ configs/face_detection/faceboxes_lite.yml | 130 ++++++++++ ppdet/data/data_feed.py | 2 +- ppdet/data/transform/arrange_sample.py | 29 ++- ppdet/data/transform/op_helper.py | 3 +- ppdet/data/transform/operators.py | 6 +- ppdet/modeling/architectures/blazeface.py | 4 +- ppdet/modeling/architectures/faceboxes.py | 4 +- ppdet/modeling/backbones/faceboxnet.py | 1 - ppdet/utils/widerface_eval_utils.py | 227 +++++++++++++++++ tools/face_eval.py | 289 ++++++++++++++++++++++ tools/infer.py | 6 +- tools/train.py | 2 + 15 files changed, 1071 insertions(+), 24 deletions(-) create mode 100644 configs/face_detection/blazeface.yml create mode 100644 configs/face_detection/blazeface_nas.yml create mode 100644 configs/face_detection/faceboxes.yml create mode 100644 configs/face_detection/faceboxes_lite.yml create mode 100644 ppdet/utils/widerface_eval_utils.py create mode 100644 tools/face_eval.py diff --git a/configs/face_detection/blazeface.yml b/configs/face_detection/blazeface.yml new file mode 100644 index 000000000..8b27eae70 --- /dev/null +++ b/configs/face_detection/blazeface.yml @@ -0,0 +1,130 @@ +architecture: BlazeFace +max_iters: 320000 +train_feed: SSDTrainFeed +eval_feed: SSDEvalFeed +test_feed: SSDTestFeed +pretrain_weights: +use_gpu: true +snapshot_iter: 10000 +log_smooth_window: 20 +log_iter: 20 +metric: WIDERFACE +save_dir: output +weights: output/blazeface/model_final/ +# 1(label_class) + 1(background) +num_classes: 2 + +BlazeFace: + backbone: BlazeNet + output_decoder: + keep_top_k: 750 + nms_threshold: 0.3 + nms_top_k: 5000 + score_threshold: 0.01 + min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]] + use_density_prior_box: false + +BlazeNet: + with_extra_blocks: true + lite_edition: false + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [240000, 300000] + +OptimizerBuilder: + optimizer: + momentum: 0.0 + type: RMSPropOptimizer + regularizer: + factor: 0.0005 + type: L2 + +SSDTrainFeed: + batch_size: 8 + use_process: True + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_train_bbx_gt.txt + image_dir: WIDER_train/images + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !RandomDistort + brightness_lower: 0.875 + brightness_upper: 1.125 + is_order: true + - !ExpandImage + max_ratio: 4 + prob: 0.5 + - !CropImageWithDataAchorSampling + anchor_sampler: + - [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0] + batch_sampler: + - [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + target_size: 640 + - !RandomInterpImage + target_size: 640 + - !RandomFlipImage + is_normalized: true + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDEvalFeed: + batch_size: 1 + use_process: false + fields: ['image', 'im_id', 'gt_box'] + dataset: + dataset_dir: dataset/wider_face + annotation: annotFile.txt #wider_face_split/wider_face_val_bbx_gt.txt + image_dir: WIDER_val/images + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDTestFeed: + batch_size: 1 + use_process: false + dataset: + use_default_label: true + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] diff --git a/configs/face_detection/blazeface_nas.yml b/configs/face_detection/blazeface_nas.yml new file mode 100644 index 000000000..45356bda7 --- /dev/null +++ b/configs/face_detection/blazeface_nas.yml @@ -0,0 +1,132 @@ +architecture: BlazeFace +max_iters: 320000 +train_feed: SSDTrainFeed +eval_feed: SSDEvalFeed +test_feed: SSDTestFeed +pretrain_weights: +use_gpu: true +snapshot_iter: 10000 +log_smooth_window: 20 +log_iter: 20 +metric: WIDERFACE +save_dir: output +weights: output/blazeface_nas/model_final/ +# 1(label_class) + 1(background) +num_classes: 2 + +BlazeFace: + backbone: BlazeNet + output_decoder: + keep_top_k: 750 + nms_threshold: 0.3 + nms_top_k: 5000 + score_threshold: 0.01 + min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]] + use_density_prior_box: false + +BlazeNet: + blaze_filters: [[12, 12], [12, 12, 2], [12, 12]] + double_blaze_filters: [[12, 16, 24, 2], [24, 12, 24], [24, 16, 72, 2], [72, 12, 72]] + with_extra_blocks: true + lite_edition: false + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [240000, 300000] + +OptimizerBuilder: + optimizer: + momentum: 0.0 + type: RMSPropOptimizer + regularizer: + factor: 0.0005 + type: L2 + +SSDTrainFeed: + batch_size: 8 + use_process: True + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_train_bbx_gt.txt + image_dir: WIDER_train/images + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !RandomDistort + brightness_lower: 0.875 + brightness_upper: 1.125 + is_order: true + - !ExpandImage + max_ratio: 4 + prob: 0.5 + - !CropImageWithDataAchorSampling + anchor_sampler: + - [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0] + batch_sampler: + - [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + target_size: 640 + - !RandomInterpImage + target_size: 640 + - !RandomFlipImage + is_normalized: true + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDEvalFeed: + batch_size: 1 + use_process: false + fields: ['image', 'im_id', 'gt_box'] + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_val_bbx_gt.txt + image_dir: WIDER_val/images + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDTestFeed: + batch_size: 1 + use_process: false + dataset: + use_default_label: true + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] diff --git a/configs/face_detection/faceboxes.yml b/configs/face_detection/faceboxes.yml new file mode 100644 index 000000000..b27872329 --- /dev/null +++ b/configs/face_detection/faceboxes.yml @@ -0,0 +1,130 @@ +architecture: FaceBoxes +train_feed: SSDTrainFeed +eval_feed: SSDEvalFeed +test_feed: SSDTestFeed +pretrain_weights: +use_gpu: true +max_iters: 320000 +snapshot_iter: 10000 +log_smooth_window: 20 +log_iter: 20 +metric: WIDERFACE +save_dir: output +weights: output/faceboxes/model_final/ +# 1(label_class) + 1(background) +num_classes: 2 + +FaceBoxes: + backbone: FaceBoxNet + densities: [[4, 2, 1], [1], [1]] + fixed_sizes: [[32., 64., 128.], [256.], [512.]] + output_decoder: + keep_top_k: 750 + nms_threshold: 0.3 + nms_top_k: 5000 + score_threshold: 0.01 + +FaceBoxNet: + with_extra_blocks: true + lite_edition: false + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [240000, 300000] + +OptimizerBuilder: + optimizer: + momentum: 0.0 + type: RMSPropOptimizer + regularizer: + factor: 0.0005 + type: L2 + +SSDTrainFeed: + batch_size: 8 + use_process: True + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_train_bbx_gt.txt + image_dir: WIDER_train/images + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !RandomDistort + brightness_lower: 0.875 + brightness_upper: 1.125 + is_order: true + - !ExpandImage + max_ratio: 4 + prob: 0.5 + - !CropImageWithDataAchorSampling + anchor_sampler: + - [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0] + batch_sampler: + - [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + target_size: 640 + - !RandomInterpImage + target_size: 640 + - !RandomFlipImage + is_normalized: true + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDEvalFeed: + batch_size: 1 + use_process: false + fields: ['image', 'im_id', 'gt_box'] + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_val_bbx_gt.txt + image_dir: WIDER_val/images + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDTestFeed: + batch_size: 1 + use_process: false + dataset: + use_default_label: true + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] diff --git a/configs/face_detection/faceboxes_lite.yml b/configs/face_detection/faceboxes_lite.yml new file mode 100644 index 000000000..157f0337e --- /dev/null +++ b/configs/face_detection/faceboxes_lite.yml @@ -0,0 +1,130 @@ +architecture: FaceBoxes +train_feed: SSDTrainFeed +eval_feed: SSDEvalFeed +test_feed: SSDTestFeed +pretrain_weights: +use_gpu: true +max_iters: 320000 +snapshot_iter: 10000 +log_smooth_window: 20 +log_iter: 20 +metric: WIDERFACE +save_dir: output +weights: output/faceboxes_lite/model_final/ +# 1(label_class) + 1(background) +num_classes: 2 + +FaceBoxes: + backbone: FaceBoxNet + densities: [[2, 1, 1], [1, 1]] + fixed_sizes: [[16., 32., 64.], [96., 128.]] + output_decoder: + keep_top_k: 750 + nms_threshold: 0.3 + nms_top_k: 5000 + score_threshold: 0.01 + +FaceBoxNet: + with_extra_blocks: true + lite_edition: true + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [240000, 300000] + +OptimizerBuilder: + optimizer: + momentum: 0.0 + type: RMSPropOptimizer + regularizer: + factor: 0.0005 + type: L2 + +SSDTrainFeed: + batch_size: 8 + use_process: True + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_train_bbx_gt.txt + image_dir: WIDER_train/images + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !RandomDistort + brightness_lower: 0.875 + brightness_upper: 1.125 + is_order: true + - !ExpandImage + max_ratio: 4 + prob: 0.5 + - !CropImageWithDataAchorSampling + anchor_sampler: + - [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0] + batch_sampler: + - [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] + target_size: 640 + - !RandomInterpImage + target_size: 640 + - !RandomFlipImage + is_normalized: true + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDEvalFeed: + batch_size: 1 + use_process: false + fields: ['image', 'im_id', 'gt_box'] + dataset: + dataset_dir: dataset/wider_face + annotation: wider_face_split/wider_face_val_bbx_gt.txt + image_dir: WIDER_val/images + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] + +SSDTestFeed: + batch_size: 1 + use_process: false + dataset: + use_default_label: true + drop_last: false + image_shape: [3, 640, 640] + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !ResizeImage + interp: 1 + target_size: 640 + use_cv2: false + - !Permute {} + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [127.502231, 127.502231, 127.502231] diff --git a/ppdet/data/data_feed.py b/ppdet/data/data_feed.py index 4f67bed1a..c79c5e949 100644 --- a/ppdet/data/data_feed.py +++ b/ppdet/data/data_feed.py @@ -781,7 +781,7 @@ class SSDEvalFeed(DataFeed): bufsize=10, use_process=False, memsize=None): - sample_transforms.append(ArrangeEvalSSD()) + sample_transforms.append(ArrangeEvalSSD(fields)) super(SSDEvalFeed, self).__init__( dataset, fields, diff --git a/ppdet/data/transform/arrange_sample.py b/ppdet/data/transform/arrange_sample.py index 697995cd7..e082c2dd7 100644 --- a/ppdet/data/transform/arrange_sample.py +++ b/ppdet/data/transform/arrange_sample.py @@ -200,8 +200,9 @@ class ArrangeEvalSSD(BaseOperator): Transform dict to tuple format needed for training. """ - def __init__(self): + def __init__(self, fields): super(ArrangeEvalSSD, self).__init__() + self.fields = fields def __call__(self, sample, context=None): """ @@ -212,17 +213,25 @@ class ArrangeEvalSSD(BaseOperator): Returns: sample: a tuple containing the following items: (image) """ - im = sample['image'] + outs = [] if len(sample['gt_bbox']) != len(sample['gt_class']): raise ValueError("gt num mismatch: bbox and class.") - im_id = sample['im_id'] - h = sample['h'] - w = sample['w'] - im_shape = np.array((h, w)) - gt_bbox = sample['gt_bbox'] - gt_class = sample['gt_class'] - difficult = sample['difficult'] - outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult) + for field in self.fields: + if field == 'im_shape': + h = sample['h'] + w = sample['w'] + im_shape = np.array((h, w)) + outs.append(im_shape) + elif field == 'is_difficult': + outs.append(sample['difficult']) + elif field == 'gt_box': + outs.append(sample['gt_bbox']) + elif field == 'gt_label': + outs.append(sample['gt_class']) + else: + outs.append(sample[field]) + + outs = tuple(outs) return outs diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py index f46f9c4e5..838714f4d 100644 --- a/ppdet/data/transform/op_helper.py +++ b/ppdet/data/transform/op_helper.py @@ -102,7 +102,8 @@ def bbox_area_sampling(bboxes, labels, scores, target_size, min_size): else: new_bboxes.append(bbox) new_labels.append(labels[i]) - new_scores.append(scores[i]) + if scores is not None and scores.size != 0: + new_scores.append(scores[i]) bboxes = np.array(new_bboxes) labels = np.array(new_labels) scores = np.array(new_scores) diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index cd4ef4e8a..0a426e081 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -640,7 +640,7 @@ class CropImageWithDataAchorSampling(BaseOperator): self.sampling_prob = sampling_prob self.min_size = min_size self.avoid_no_bbox = avoid_no_bbox - self.scale_array = np.array(das_anchor_scales) + self.das_anchor_scales = np.array(das_anchor_scales) def __call__(self, sample, context): """ @@ -674,8 +674,8 @@ class CropImageWithDataAchorSampling(BaseOperator): if found >= sampler[0]: break sample_bbox = data_anchor_sampling( - gt_bbox, image_width, image_height, self.scale_array, - self.target_size) + gt_bbox, image_width, image_height, + self.das_anchor_scales, self.target_size) if sample_bbox == 0: break if satisfy_sample_constraint_coverage(sampler, sample_bbox, diff --git a/ppdet/modeling/architectures/blazeface.py b/ppdet/modeling/architectures/blazeface.py index d7a221307..cc9a2bb33 100644 --- a/ppdet/modeling/architectures/blazeface.py +++ b/ppdet/modeling/architectures/blazeface.py @@ -108,9 +108,7 @@ class BlazeFace(object): use_density_prior_box=False): def permute_and_reshape(input, last_dim): trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1]) - compile_shape = [ - trans.shape[0], np.prod(trans.shape[1:]) // last_dim, last_dim - ] + compile_shape = [0, -1, last_dim] return fluid.layers.reshape(trans, shape=compile_shape) def _is_list_or_tuple_(data): diff --git a/ppdet/modeling/architectures/faceboxes.py b/ppdet/modeling/architectures/faceboxes.py index 1963b692e..194b3a7e8 100644 --- a/ppdet/modeling/architectures/faceboxes.py +++ b/ppdet/modeling/architectures/faceboxes.py @@ -93,9 +93,7 @@ class FaceBoxes(object): def _multi_box_head(self, inputs, image, num_classes=2): def permute_and_reshape(input, last_dim): trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1]) - compile_shape = [ - trans.shape[0], np.prod(trans.shape[1:]) // last_dim, last_dim - ] + compile_shape = [0, -1, last_dim] return fluid.layers.reshape(trans, shape=compile_shape) def _is_list_or_tuple_(data): diff --git a/ppdet/modeling/backbones/faceboxnet.py b/ppdet/modeling/backbones/faceboxnet.py index cae1eed39..0b82c86b2 100644 --- a/ppdet/modeling/backbones/faceboxnet.py +++ b/ppdet/modeling/backbones/faceboxnet.py @@ -238,7 +238,6 @@ class FaceBoxNet(object): use_cudnn=use_cudnn, param_attr=parameter_attr, bias_attr=False) - print("{}:{}".format(name, conv.shape)) return fluid.layers.batch_norm(input=conv, act=act) def _conv_norm_crelu( diff --git a/ppdet/utils/widerface_eval_utils.py b/ppdet/utils/widerface_eval_utils.py new file mode 100644 index 000000000..a19cd0835 --- /dev/null +++ b/ppdet/utils/widerface_eval_utils.py @@ -0,0 +1,227 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from ppdet.data.source.widerface_loader import widerface_label +from ppdet.utils.coco_eval import bbox2out + +import logging +logger = logging.getLogger(__name__) + +__all__ = [ + 'get_shrink', 'bbox_vote', 'save_widerface_bboxes', 'save_fddb_bboxes', + 'to_chw_bgr', 'bbox2out', 'get_category_info' +] + + +def to_chw_bgr(image): + """ + Transpose image from HWC to CHW and from RBG to BGR. + Args: + image (np.array): an image with HWC and RBG layout. + """ + # HWC to CHW + if len(image.shape) == 3: + image = np.swapaxes(image, 1, 2) + image = np.swapaxes(image, 1, 0) + # RBG to BGR + image = image[[2, 1, 0], :, :] + return image + + +def bbox_vote(det): + order = det[:, 4].ravel().argsort()[::-1] + det = det[order, :] + if det.shape[0] == 0: + dets = np.array([[10, 10, 20, 20, 0.002]]) + det = np.empty(shape=[0, 5]) + while det.shape[0] > 0: + # IOU + area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1) + xx1 = np.maximum(det[0, 0], det[:, 0]) + yy1 = np.maximum(det[0, 1], det[:, 1]) + xx2 = np.minimum(det[0, 2], det[:, 2]) + yy2 = np.minimum(det[0, 3], det[:, 3]) + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + o = inter / (area[0] + area[:] - inter) + + # nms + merge_index = np.where(o >= 0.3)[0] + det_accu = det[merge_index, :] + det = np.delete(det, merge_index, 0) + if merge_index.shape[0] <= 1: + if det.shape[0] == 0: + try: + dets = np.row_stack((dets, det_accu)) + except: + dets = det_accu + continue + det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4)) + max_score = np.max(det_accu[:, 4]) + det_accu_sum = np.zeros((1, 5)) + det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], + axis=0) / np.sum(det_accu[:, -1:]) + det_accu_sum[:, 4] = max_score + try: + dets = np.row_stack((dets, det_accu_sum)) + except: + dets = det_accu_sum + dets = dets[0:750, :] + # Only keep 0.3 or more + keep_index = np.where(dets[:, 4] >= 0.01)[0] + dets = dets[keep_index, :] + return dets + + +def get_shrink(height, width): + """ + Args: + height (int): image height. + width (int): image width. + """ + # avoid out of memory + max_shrink_v1 = (0x7fffffff / 577.0 / (height * width))**0.5 + max_shrink_v2 = ((678 * 1024 * 2.0 * 2.0) / (height * width))**0.5 + + def get_round(x, loc): + str_x = str(x) + if '.' in str_x: + str_before, str_after = str_x.split('.') + len_after = len(str_after) + if len_after >= 3: + str_final = str_before + '.' + str_after[0:loc] + return float(str_final) + else: + return x + + max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3 + if max_shrink >= 1.5 and max_shrink < 2: + max_shrink = max_shrink - 0.1 + elif max_shrink >= 2 and max_shrink < 3: + max_shrink = max_shrink - 0.2 + elif max_shrink >= 3 and max_shrink < 4: + max_shrink = max_shrink - 0.3 + elif max_shrink >= 4 and max_shrink < 5: + max_shrink = max_shrink - 0.4 + elif max_shrink >= 5: + max_shrink = max_shrink - 0.5 + + shrink = max_shrink if max_shrink < 1 else 1 + return shrink, max_shrink + + +def save_widerface_bboxes(image_path, bboxes_scores, output_dir): + image_name = image_path.split('/')[-1] + image_class = image_path.split('/')[-2] + odir = os.path.join(output_dir, image_class) + if not os.path.exists(odir): + os.makedirs(odir) + + ofname = os.path.join(odir, '%s.txt' % (image_name[:-4])) + f = open(ofname, 'w') + f.write('{:s}\n'.format(image_class + '/' + image_name)) + f.write('{:d}\n'.format(bboxes_scores.shape[0])) + for box_score in bboxes_scores: + xmin, ymin, xmax, ymax, score = box_score + f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, ( + xmax - xmin + 1), (ymax - ymin + 1), score)) + f.close() + logger.info("The predicted result is saved as {}".format(ofname)) + + +def save_fddb_bboxes(bboxes_scores, + output_dir, + output_fname='pred_fddb_res.txt'): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + predict_file = os.path.join(output_dir, output_fname) + f = open(predict_file, 'w') + for image_path, dets in bboxes_scores.iteritems(): + f.write('{:s}\n'.format(image_path)) + f.write('{:d}\n'.format(dets.shape[0])) + for box_score in dets: + xmin, ymin, xmax, ymax, score = box_score + width, height = xmax - xmin, ymax - ymin + f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n' + .format(xmin, ymin, width, height, score)) + logger.info("The predicted result is saved as {}".format(predict_file)) + return predict_file + + +def get_category_info(anno_file=None, + with_background=True, + use_default_label=False): + if use_default_label or anno_file is None \ + or not os.path.exists(anno_file): + logger.info("Not found annotation file {}, load " + "wider-face categories.".format(anno_file)) + return widerfaceall_category_info(with_background) + else: + logger.info("Load categories from {}".format(anno_file)) + return get_category_info_from_anno(anno_file, with_background) + + +def get_category_info_from_anno(anno_file, with_background=True): + """ + Get class id to category id map and category id + to category name map from annotation file. + Args: + anno_file (str): annotation file path + with_background (bool, default True): + whether load background as class 0. + """ + cats = [] + with open(anno_file) as f: + for line in f.readlines(): + cats.append(line.strip()) + + if cats[0] != 'background' and with_background: + cats.insert(0, 'background') + if cats[0] == 'background' and not with_background: + cats = cats[1:] + + clsid2catid = {i: i for i in range(len(cats))} + catid2name = {i: name for i, name in enumerate(cats)} + + return clsid2catid, catid2name + + +def widerfaceall_category_info(with_background=True): + """ + Get class id to category id map and category id + to category name map of mixup wider_face dataset + + Args: + with_background (bool, default True): + whether load background as class 0. + """ + label_map = widerface_label(with_background) + label_map = sorted(label_map.items(), key=lambda x: x[1]) + cats = [l[0] for l in label_map] + + if with_background: + cats.insert(0, 'background') + + clsid2catid = {i: i for i in range(len(cats))} + catid2name = {i: name for i, name in enumerate(cats)} + + return clsid2catid, catid2name diff --git a/tools/face_eval.py b/tools/face_eval.py new file mode 100644 index 000000000..a049d26da --- /dev/null +++ b/tools/face_eval.py @@ -0,0 +1,289 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import paddle.fluid as fluid +import numpy as np +from PIL import Image +from collections import OrderedDict + +import ppdet.utils.checkpoint as checkpoint +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu +from ppdet.utils.widerface_eval_utils import get_shrink, bbox_vote, \ + save_widerface_bboxes, save_fddb_bboxes, to_chw_bgr +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.modeling.model_input import create_feed + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def face_img_process(image, + mean=[104., 117., 123.], + std=[127.502231, 127.502231, 127.502231]): + img = np.array(image) + img = to_chw_bgr(img) + img = img.astype('float32') + img -= np.array(mean)[:, np.newaxis, np.newaxis].astype('float32') + img /= np.array(std)[:, np.newaxis, np.newaxis].astype('float32') + img = [img] + img = np.array(img) + return img + + +def face_eval_run(exe, + compile_program, + fetches, + img_root_dir, + gt_file, + pred_dir='output/pred', + eval_mode='widerface'): + # load ground truth files + with open(gt_file, 'r') as f: + gt_lines = f.readlines() + imid2path = [] + pos_gt = 0 + while pos_gt < len(gt_lines): + name_gt = gt_lines[pos_gt].strip('\n\t').split()[0] + imid2path.append(name_gt) + pos_gt += 1 + n_gt = int(gt_lines[pos_gt].strip('\n\t').split()[0]) + pos_gt += 1 + n_gt + logger.info('The ground truth file load {} images'.format(len(imid2path))) + + dets_dist = OrderedDict() + for iter_id, im_path in enumerate(imid2path): + image_path = os.path.join(img_root_dir, im_path) + if eval_mode == 'fddb': + image_path += '.jpg' + image = Image.open(image_path).convert('RGB') + shrink, max_shrink = get_shrink(image.size[1], image.size[0]) + + det0 = detect_face(exe, compile_program, fetches, image, shrink) + det1 = flip_test(exe, compile_program, fetches, image, shrink) + [det2, det3] = multi_scale_test(exe, compile_program, fetches, image, + max_shrink) + det4 = multi_scale_test_pyramid(exe, compile_program, fetches, image, + max_shrink) + det = np.row_stack((det0, det1, det2, det3, det4)) + dets = bbox_vote(det) + if eval_mode == 'widerface': + save_widerface_bboxes(image_path, dets, pred_dir) + else: + dets_dist[im_path] = dets + if iter_id % 100 == 0: + logger.info('Test iter {}'.format(iter_id)) + if eval_mode == 'fddb': + save_fddb_bboxes(dets_dist, pred_dir) + logger.info("Finish evaluation.") + + +def detect_face(exe, compile_program, fetches, image, shrink): + image_shape = [3, image.size[1], image.size[0]] + if shrink != 1: + h, w = int(image_shape[1] * shrink), int(image_shape[2] * shrink) + image = image.resize((w, h), Image.ANTIALIAS) + image_shape = [3, h, w] + + img = face_img_process(image) + detection, = exe.run(compile_program, + feed={'image': img}, + fetch_list=[fetches['bbox']], + return_numpy=False) + detection = np.array(detection) + # layout: xmin, ymin, xmax. ymax, score + if np.prod(detection.shape) == 1: + logger.info("No face detected") + return np.array([[0, 0, 0, 0, 0]]) + det_conf = detection[:, 1] + det_xmin = image_shape[2] * detection[:, 2] / shrink + det_ymin = image_shape[1] * detection[:, 3] / shrink + det_xmax = image_shape[2] * detection[:, 4] / shrink + det_ymax = image_shape[1] * detection[:, 5] / shrink + + det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf)) + return det + + +def flip_test(exe, compile_program, fetches, image, shrink): + img = image.transpose(Image.FLIP_LEFT_RIGHT) + det_f = detect_face(exe, compile_program, fetches, img, shrink) + det_t = np.zeros(det_f.shape) + # image.size: [width, height] + det_t[:, 0] = image.size[0] - det_f[:, 2] + det_t[:, 1] = det_f[:, 1] + det_t[:, 2] = image.size[0] - det_f[:, 0] + det_t[:, 3] = det_f[:, 3] + det_t[:, 4] = det_f[:, 4] + return det_t + + +def multi_scale_test(exe, compile_program, fetches, image, max_shrink): + # Shrink detecting is only used to detect big faces + st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink + det_s = detect_face(exe, compile_program, fetches, image, st) + index = np.where( + np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1) + > 30)[0] + det_s = det_s[index, :] + # Enlarge one times + bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2 + det_b = detect_face(exe, compile_program, fetches, image, bt) + + # Enlarge small image x times for small faces + if max_shrink > 2: + bt *= 2 + while bt < max_shrink: + det_b = np.row_stack((det_b, detect_face(exe, compile_program, + fetches, image, bt))) + bt *= 2 + det_b = np.row_stack((det_b, detect_face(exe, compile_program, fetches, + image, max_shrink))) + + # Enlarged images are only used to detect small faces. + if bt > 1: + index = np.where( + np.minimum(det_b[:, 2] - det_b[:, 0] + 1, + det_b[:, 3] - det_b[:, 1] + 1) < 100)[0] + det_b = det_b[index, :] + # Shrinked images are only used to detect big faces. + else: + index = np.where( + np.maximum(det_b[:, 2] - det_b[:, 0] + 1, + det_b[:, 3] - det_b[:, 1] + 1) > 30)[0] + det_b = det_b[index, :] + return det_s, det_b + + +def multi_scale_test_pyramid(exe, compile_program, fetches, image, max_shrink): + # Use image pyramids to detect faces + det_b = detect_face(exe, compile_program, fetches, image, 0.25) + index = np.where( + np.maximum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1) + > 30)[0] + det_b = det_b[index, :] + + st = [0.75, 1.25, 1.5, 1.75] + for i in range(len(st)): + if st[i] <= max_shrink: + det_temp = detect_face(exe, compile_program, fetches, image, st[i]) + # Enlarged images are only used to detect small faces. + if st[i] > 1: + index = np.where( + np.minimum(det_temp[:, 2] - det_temp[:, 0] + 1, + det_temp[:, 3] - det_temp[:, 1] + 1) < 100)[0] + det_temp = det_temp[index, :] + # Shrinked images are only used to detect big faces. + else: + index = np.where( + np.maximum(det_temp[:, 2] - det_temp[:, 0] + 1, + det_temp[:, 3] - det_temp[:, 1] + 1) > 30)[0] + det_temp = det_temp[index, :] + det_b = np.row_stack((det_b, det_temp)) + return det_b + + +def main(): + """ + Main evaluate function + """ + cfg = load_config(FLAGS.config) + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + + if 'eval_feed' not in cfg: + eval_feed = create(main_arch + 'EvalFeed') + else: + eval_feed = create(cfg.eval_feed) + + # define executor + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + # build program + model = create(main_arch) + startup_prog = fluid.Program() + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + _, feed_vars = create_feed(eval_feed, use_pyreader=False) + fetches = model.eval(feed_vars) + + eval_prog = eval_prog.clone(True) + + # load model + exe.run(startup_prog) + if 'weights' in cfg: + checkpoint.load_params(exe, eval_prog, cfg.weights) + + assert cfg.metric in ['WIDERFACE'], \ + "unknown metric type {}".format(cfg.metric) + + annotation_file = getattr(eval_feed.dataset, 'annotation', None) + dataset_dir = FLAGS.dataset_dir if FLAGS.dataset_dir else \ + getattr(eval_feed.dataset, 'dataset_dir', None) + img_root_dir = dataset_dir + if FLAGS.eval_mode == "widerface": + image_dir = getattr(eval_feed.dataset, 'image_dir', None) + img_root_dir = os.path.join(dataset_dir, image_dir) + gt_file = os.path.join(dataset_dir, annotation_file) + pred_dir = FLAGS.output_eval if FLAGS.output_eval else 'output/pred' + face_eval_run( + exe, + eval_prog, + fetches, + img_root_dir, + gt_file, + pred_dir=pred_dir, + eval_mode=FLAGS.eval_mode) + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "-d", + "--dataset_dir", + default=None, + type=str, + help="Dataset path, same as DataFeed.dataset.dataset_dir") + parser.add_argument( + "-f", + "--output_eval", + default=None, + type=str, + help="Evaluation file directory, default is current directory.") + parser.add_argument( + "-e", + "--eval_mode", + default="widerface", + type=str, + help="Evaluation mode, include `widerface` and `fddb`, default is `widerface`." + ) + FLAGS = parser.parse_args() + main() diff --git a/tools/infer.py b/tools/infer.py index 608587000..9801cd6fc 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -186,12 +186,12 @@ def main(): save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) # parse infer fetches - assert cfg.metric in ['COCO', 'VOC'], \ + assert cfg.metric in ['COCO', 'VOC', 'WIDERFACE'], \ "unknown metric type {}".format(cfg.metric) extra_keys = [] if cfg['metric'] == 'COCO': extra_keys = ['im_info', 'im_id', 'im_shape'] - if cfg['metric'] == 'VOC': + if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE': extra_keys = ['im_id', 'im_shape'] keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) @@ -200,6 +200,8 @@ def main(): from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info if cfg.metric == "VOC": from ppdet.utils.voc_eval import bbox2out, get_category_info + if cfg.metric == "WIDERFACE": + from ppdet.utils.widerface_eval_utils import bbox2out, get_category_info anno_file = getattr(test_feed.dataset, 'annotation', None) with_background = getattr(test_feed, 'with_background', True) diff --git a/tools/train.py b/tools/train.py index e1d130aa2..0d4dd85ee 100644 --- a/tools/train.py +++ b/tools/train.py @@ -154,6 +154,8 @@ def main(): extra_keys = ['im_info', 'im_id', 'im_shape'] if cfg.metric == 'VOC': extra_keys = ['gt_box', 'gt_label', 'is_difficult'] + if cfg.metric == 'WIDERFACE': + extra_keys = ['im_id', 'im_shape', 'gt_box'] eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, extra_keys) -- GitLab