提交 e743c29c 编写于 作者: G Guanghua Yu 提交者: qingqing01

[Face Detection] add facedetection config and eval code (#3466)

* Add face detection configs and main code.
* Integration multi-scale evaluation.
上级 ebe078cf
architecture: BlazeFace
max_iters: 320000
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
pretrain_weights:
use_gpu: true
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: WIDERFACE
save_dir: output
weights: output/blazeface/model_final/
# 1(label_class) + 1(background)
num_classes: 2
BlazeFace:
backbone: BlazeNet
output_decoder:
keep_top_k: 750
nms_threshold: 0.3
nms_top_k: 5000
score_threshold: 0.01
min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]]
use_density_prior_box: false
BlazeNet:
with_extra_blocks: true
lite_edition: false
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 300000]
OptimizerBuilder:
optimizer:
momentum: 0.0
type: RMSPropOptimizer
regularizer:
factor: 0.0005
type: L2
SSDTrainFeed:
batch_size: 8
use_process: True
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_train_bbx_gt.txt
image_dir: WIDER_train/images
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !RandomDistort
brightness_lower: 0.875
brightness_upper: 1.125
is_order: true
- !ExpandImage
max_ratio: 4
prob: 0.5
- !CropImageWithDataAchorSampling
anchor_sampler:
- [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]
batch_sampler:
- [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
target_size: 640
- !RandomInterpImage
target_size: 640
- !RandomFlipImage
is_normalized: true
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDEvalFeed:
batch_size: 1
use_process: false
fields: ['image', 'im_id', 'gt_box']
dataset:
dataset_dir: dataset/wider_face
annotation: annotFile.txt #wider_face_split/wider_face_val_bbx_gt.txt
image_dir: WIDER_val/images
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDTestFeed:
batch_size: 1
use_process: false
dataset:
use_default_label: true
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
architecture: BlazeFace
max_iters: 320000
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
pretrain_weights:
use_gpu: true
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: WIDERFACE
save_dir: output
weights: output/blazeface_nas/model_final/
# 1(label_class) + 1(background)
num_classes: 2
BlazeFace:
backbone: BlazeNet
output_decoder:
keep_top_k: 750
nms_threshold: 0.3
nms_top_k: 5000
score_threshold: 0.01
min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]]
use_density_prior_box: false
BlazeNet:
blaze_filters: [[12, 12], [12, 12, 2], [12, 12]]
double_blaze_filters: [[12, 16, 24, 2], [24, 12, 24], [24, 16, 72, 2], [72, 12, 72]]
with_extra_blocks: true
lite_edition: false
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 300000]
OptimizerBuilder:
optimizer:
momentum: 0.0
type: RMSPropOptimizer
regularizer:
factor: 0.0005
type: L2
SSDTrainFeed:
batch_size: 8
use_process: True
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_train_bbx_gt.txt
image_dir: WIDER_train/images
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !RandomDistort
brightness_lower: 0.875
brightness_upper: 1.125
is_order: true
- !ExpandImage
max_ratio: 4
prob: 0.5
- !CropImageWithDataAchorSampling
anchor_sampler:
- [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]
batch_sampler:
- [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
target_size: 640
- !RandomInterpImage
target_size: 640
- !RandomFlipImage
is_normalized: true
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDEvalFeed:
batch_size: 1
use_process: false
fields: ['image', 'im_id', 'gt_box']
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_val_bbx_gt.txt
image_dir: WIDER_val/images
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDTestFeed:
batch_size: 1
use_process: false
dataset:
use_default_label: true
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
architecture: FaceBoxes
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
pretrain_weights:
use_gpu: true
max_iters: 320000
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: WIDERFACE
save_dir: output
weights: output/faceboxes/model_final/
# 1(label_class) + 1(background)
num_classes: 2
FaceBoxes:
backbone: FaceBoxNet
densities: [[4, 2, 1], [1], [1]]
fixed_sizes: [[32., 64., 128.], [256.], [512.]]
output_decoder:
keep_top_k: 750
nms_threshold: 0.3
nms_top_k: 5000
score_threshold: 0.01
FaceBoxNet:
with_extra_blocks: true
lite_edition: false
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 300000]
OptimizerBuilder:
optimizer:
momentum: 0.0
type: RMSPropOptimizer
regularizer:
factor: 0.0005
type: L2
SSDTrainFeed:
batch_size: 8
use_process: True
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_train_bbx_gt.txt
image_dir: WIDER_train/images
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !RandomDistort
brightness_lower: 0.875
brightness_upper: 1.125
is_order: true
- !ExpandImage
max_ratio: 4
prob: 0.5
- !CropImageWithDataAchorSampling
anchor_sampler:
- [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]
batch_sampler:
- [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
target_size: 640
- !RandomInterpImage
target_size: 640
- !RandomFlipImage
is_normalized: true
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDEvalFeed:
batch_size: 1
use_process: false
fields: ['image', 'im_id', 'gt_box']
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_val_bbx_gt.txt
image_dir: WIDER_val/images
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDTestFeed:
batch_size: 1
use_process: false
dataset:
use_default_label: true
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
architecture: FaceBoxes
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
pretrain_weights:
use_gpu: true
max_iters: 320000
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: WIDERFACE
save_dir: output
weights: output/faceboxes_lite/model_final/
# 1(label_class) + 1(background)
num_classes: 2
FaceBoxes:
backbone: FaceBoxNet
densities: [[2, 1, 1], [1, 1]]
fixed_sizes: [[16., 32., 64.], [96., 128.]]
output_decoder:
keep_top_k: 750
nms_threshold: 0.3
nms_top_k: 5000
score_threshold: 0.01
FaceBoxNet:
with_extra_blocks: true
lite_edition: true
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 300000]
OptimizerBuilder:
optimizer:
momentum: 0.0
type: RMSPropOptimizer
regularizer:
factor: 0.0005
type: L2
SSDTrainFeed:
batch_size: 8
use_process: True
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_train_bbx_gt.txt
image_dir: WIDER_train/images
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !RandomDistort
brightness_lower: 0.875
brightness_upper: 1.125
is_order: true
- !ExpandImage
max_ratio: 4
prob: 0.5
- !CropImageWithDataAchorSampling
anchor_sampler:
- [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]
batch_sampler:
- [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
target_size: 640
- !RandomInterpImage
target_size: 640
- !RandomFlipImage
is_normalized: true
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDEvalFeed:
batch_size: 1
use_process: false
fields: ['image', 'im_id', 'gt_box']
dataset:
dataset_dir: dataset/wider_face
annotation: wider_face_split/wider_face_val_bbx_gt.txt
image_dir: WIDER_val/images
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeBox {}
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
SSDTestFeed:
batch_size: 1
use_process: false
dataset:
use_default_label: true
drop_last: false
image_shape: [3, 640, 640]
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !ResizeImage
interp: 1
target_size: 640
use_cv2: false
- !Permute {}
- !NormalizeImage
is_scale: false
mean: [104, 117, 123]
std: [127.502231, 127.502231, 127.502231]
......@@ -781,7 +781,7 @@ class SSDEvalFeed(DataFeed):
bufsize=10,
use_process=False,
memsize=None):
sample_transforms.append(ArrangeEvalSSD())
sample_transforms.append(ArrangeEvalSSD(fields))
super(SSDEvalFeed, self).__init__(
dataset,
fields,
......
......@@ -200,8 +200,9 @@ class ArrangeEvalSSD(BaseOperator):
Transform dict to tuple format needed for training.
"""
def __init__(self):
def __init__(self, fields):
super(ArrangeEvalSSD, self).__init__()
self.fields = fields
def __call__(self, sample, context=None):
"""
......@@ -212,17 +213,25 @@ class ArrangeEvalSSD(BaseOperator):
Returns:
sample: a tuple containing the following items: (image)
"""
im = sample['image']
outs = []
if len(sample['gt_bbox']) != len(sample['gt_class']):
raise ValueError("gt num mismatch: bbox and class.")
im_id = sample['im_id']
for field in self.fields:
if field == 'im_shape':
h = sample['h']
w = sample['w']
im_shape = np.array((h, w))
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
difficult = sample['difficult']
outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult)
outs.append(im_shape)
elif field == 'is_difficult':
outs.append(sample['difficult'])
elif field == 'gt_box':
outs.append(sample['gt_bbox'])
elif field == 'gt_label':
outs.append(sample['gt_class'])
else:
outs.append(sample[field])
outs = tuple(outs)
return outs
......
......@@ -102,6 +102,7 @@ def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
else:
new_bboxes.append(bbox)
new_labels.append(labels[i])
if scores is not None and scores.size != 0:
new_scores.append(scores[i])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
......
......@@ -640,7 +640,7 @@ class CropImageWithDataAchorSampling(BaseOperator):
self.sampling_prob = sampling_prob
self.min_size = min_size
self.avoid_no_bbox = avoid_no_bbox
self.scale_array = np.array(das_anchor_scales)
self.das_anchor_scales = np.array(das_anchor_scales)
def __call__(self, sample, context):
"""
......@@ -674,8 +674,8 @@ class CropImageWithDataAchorSampling(BaseOperator):
if found >= sampler[0]:
break
sample_bbox = data_anchor_sampling(
gt_bbox, image_width, image_height, self.scale_array,
self.target_size)
gt_bbox, image_width, image_height,
self.das_anchor_scales, self.target_size)
if sample_bbox == 0:
break
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
......
......@@ -108,9 +108,7 @@ class BlazeFace(object):
use_density_prior_box=False):
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
compile_shape = [
trans.shape[0], np.prod(trans.shape[1:]) // last_dim, last_dim
]
compile_shape = [0, -1, last_dim]
return fluid.layers.reshape(trans, shape=compile_shape)
def _is_list_or_tuple_(data):
......
......@@ -93,9 +93,7 @@ class FaceBoxes(object):
def _multi_box_head(self, inputs, image, num_classes=2):
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
compile_shape = [
trans.shape[0], np.prod(trans.shape[1:]) // last_dim, last_dim
]
compile_shape = [0, -1, last_dim]
return fluid.layers.reshape(trans, shape=compile_shape)
def _is_list_or_tuple_(data):
......
......@@ -238,7 +238,6 @@ class FaceBoxNet(object):
use_cudnn=use_cudnn,
param_attr=parameter_attr,
bias_attr=False)
print("{}:{}".format(name, conv.shape))
return fluid.layers.batch_norm(input=conv, act=act)
def _conv_norm_crelu(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from ppdet.data.source.widerface_loader import widerface_label
from ppdet.utils.coco_eval import bbox2out
import logging
logger = logging.getLogger(__name__)
__all__ = [
'get_shrink', 'bbox_vote', 'save_widerface_bboxes', 'save_fddb_bboxes',
'to_chw_bgr', 'bbox2out', 'get_category_info'
]
def to_chw_bgr(image):
"""
Transpose image from HWC to CHW and from RBG to BGR.
Args:
image (np.array): an image with HWC and RBG layout.
"""
# HWC to CHW
if len(image.shape) == 3:
image = np.swapaxes(image, 1, 2)
image = np.swapaxes(image, 1, 0)
# RBG to BGR
image = image[[2, 1, 0], :, :]
return image
def bbox_vote(det):
order = det[:, 4].ravel().argsort()[::-1]
det = det[order, :]
if det.shape[0] == 0:
dets = np.array([[10, 10, 20, 20, 0.002]])
det = np.empty(shape=[0, 5])
while det.shape[0] > 0:
# IOU
area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
xx1 = np.maximum(det[0, 0], det[:, 0])
yy1 = np.maximum(det[0, 1], det[:, 1])
xx2 = np.minimum(det[0, 2], det[:, 2])
yy2 = np.minimum(det[0, 3], det[:, 3])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
o = inter / (area[0] + area[:] - inter)
# nms
merge_index = np.where(o >= 0.3)[0]
det_accu = det[merge_index, :]
det = np.delete(det, merge_index, 0)
if merge_index.shape[0] <= 1:
if det.shape[0] == 0:
try:
dets = np.row_stack((dets, det_accu))
except:
dets = det_accu
continue
det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
max_score = np.max(det_accu[:, 4])
det_accu_sum = np.zeros((1, 5))
det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4],
axis=0) / np.sum(det_accu[:, -1:])
det_accu_sum[:, 4] = max_score
try:
dets = np.row_stack((dets, det_accu_sum))
except:
dets = det_accu_sum
dets = dets[0:750, :]
# Only keep 0.3 or more
keep_index = np.where(dets[:, 4] >= 0.01)[0]
dets = dets[keep_index, :]
return dets
def get_shrink(height, width):
"""
Args:
height (int): image height.
width (int): image width.
"""
# avoid out of memory
max_shrink_v1 = (0x7fffffff / 577.0 / (height * width))**0.5
max_shrink_v2 = ((678 * 1024 * 2.0 * 2.0) / (height * width))**0.5
def get_round(x, loc):
str_x = str(x)
if '.' in str_x:
str_before, str_after = str_x.split('.')
len_after = len(str_after)
if len_after >= 3:
str_final = str_before + '.' + str_after[0:loc]
return float(str_final)
else:
return x
max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
if max_shrink >= 1.5 and max_shrink < 2:
max_shrink = max_shrink - 0.1
elif max_shrink >= 2 and max_shrink < 3:
max_shrink = max_shrink - 0.2
elif max_shrink >= 3 and max_shrink < 4:
max_shrink = max_shrink - 0.3
elif max_shrink >= 4 and max_shrink < 5:
max_shrink = max_shrink - 0.4
elif max_shrink >= 5:
max_shrink = max_shrink - 0.5
shrink = max_shrink if max_shrink < 1 else 1
return shrink, max_shrink
def save_widerface_bboxes(image_path, bboxes_scores, output_dir):
image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2]
odir = os.path.join(output_dir, image_class)
if not os.path.exists(odir):
os.makedirs(odir)
ofname = os.path.join(odir, '%s.txt' % (image_name[:-4]))
f = open(ofname, 'w')
f.write('{:s}\n'.format(image_class + '/' + image_name))
f.write('{:d}\n'.format(bboxes_scores.shape[0]))
for box_score in bboxes_scores:
xmin, ymin, xmax, ymax, score = box_score
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, (
xmax - xmin + 1), (ymax - ymin + 1), score))
f.close()
logger.info("The predicted result is saved as {}".format(ofname))
def save_fddb_bboxes(bboxes_scores,
output_dir,
output_fname='pred_fddb_res.txt'):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
predict_file = os.path.join(output_dir, output_fname)
f = open(predict_file, 'w')
for image_path, dets in bboxes_scores.iteritems():
f.write('{:s}\n'.format(image_path))
f.write('{:d}\n'.format(dets.shape[0]))
for box_score in dets:
xmin, ymin, xmax, ymax, score = box_score
width, height = xmax - xmin, ymax - ymin
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'
.format(xmin, ymin, width, height, score))
logger.info("The predicted result is saved as {}".format(predict_file))
return predict_file
def get_category_info(anno_file=None,
with_background=True,
use_default_label=False):
if use_default_label or anno_file is None \
or not os.path.exists(anno_file):
logger.info("Not found annotation file {}, load "
"wider-face categories.".format(anno_file))
return widerfaceall_category_info(with_background)
else:
logger.info("Load categories from {}".format(anno_file))
return get_category_info_from_anno(anno_file, with_background)
def get_category_info_from_anno(anno_file, with_background=True):
"""
Get class id to category id map and category id
to category name map from annotation file.
Args:
anno_file (str): annotation file path
with_background (bool, default True):
whether load background as class 0.
"""
cats = []
with open(anno_file) as f:
for line in f.readlines():
cats.append(line.strip())
if cats[0] != 'background' and with_background:
cats.insert(0, 'background')
if cats[0] == 'background' and not with_background:
cats = cats[1:]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
def widerfaceall_category_info(with_background=True):
"""
Get class id to category id map and category id
to category name map of mixup wider_face dataset
Args:
with_background (bool, default True):
whether load background as class 0.
"""
label_map = widerface_label(with_background)
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
if with_background:
cats.insert(0, 'background')
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
import numpy as np
from PIL import Image
from collections import OrderedDict
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.utils.widerface_eval_utils import get_shrink, bbox_vote, \
save_widerface_bboxes, save_fddb_bboxes, to_chw_bgr
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feed
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def face_img_process(image,
mean=[104., 117., 123.],
std=[127.502231, 127.502231, 127.502231]):
img = np.array(image)
img = to_chw_bgr(img)
img = img.astype('float32')
img -= np.array(mean)[:, np.newaxis, np.newaxis].astype('float32')
img /= np.array(std)[:, np.newaxis, np.newaxis].astype('float32')
img = [img]
img = np.array(img)
return img
def face_eval_run(exe,
compile_program,
fetches,
img_root_dir,
gt_file,
pred_dir='output/pred',
eval_mode='widerface'):
# load ground truth files
with open(gt_file, 'r') as f:
gt_lines = f.readlines()
imid2path = []
pos_gt = 0
while pos_gt < len(gt_lines):
name_gt = gt_lines[pos_gt].strip('\n\t').split()[0]
imid2path.append(name_gt)
pos_gt += 1
n_gt = int(gt_lines[pos_gt].strip('\n\t').split()[0])
pos_gt += 1 + n_gt
logger.info('The ground truth file load {} images'.format(len(imid2path)))
dets_dist = OrderedDict()
for iter_id, im_path in enumerate(imid2path):
image_path = os.path.join(img_root_dir, im_path)
if eval_mode == 'fddb':
image_path += '.jpg'
image = Image.open(image_path).convert('RGB')
shrink, max_shrink = get_shrink(image.size[1], image.size[0])
det0 = detect_face(exe, compile_program, fetches, image, shrink)
det1 = flip_test(exe, compile_program, fetches, image, shrink)
[det2, det3] = multi_scale_test(exe, compile_program, fetches, image,
max_shrink)
det4 = multi_scale_test_pyramid(exe, compile_program, fetches, image,
max_shrink)
det = np.row_stack((det0, det1, det2, det3, det4))
dets = bbox_vote(det)
if eval_mode == 'widerface':
save_widerface_bboxes(image_path, dets, pred_dir)
else:
dets_dist[im_path] = dets
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
if eval_mode == 'fddb':
save_fddb_bboxes(dets_dist, pred_dir)
logger.info("Finish evaluation.")
def detect_face(exe, compile_program, fetches, image, shrink):
image_shape = [3, image.size[1], image.size[0]]
if shrink != 1:
h, w = int(image_shape[1] * shrink), int(image_shape[2] * shrink)
image = image.resize((w, h), Image.ANTIALIAS)
image_shape = [3, h, w]
img = face_img_process(image)
detection, = exe.run(compile_program,
feed={'image': img},
fetch_list=[fetches['bbox']],
return_numpy=False)
detection = np.array(detection)
# layout: xmin, ymin, xmax. ymax, score
if np.prod(detection.shape) == 1:
logger.info("No face detected")
return np.array([[0, 0, 0, 0, 0]])
det_conf = detection[:, 1]
det_xmin = image_shape[2] * detection[:, 2] / shrink
det_ymin = image_shape[1] * detection[:, 3] / shrink
det_xmax = image_shape[2] * detection[:, 4] / shrink
det_ymax = image_shape[1] * detection[:, 5] / shrink
det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
return det
def flip_test(exe, compile_program, fetches, image, shrink):
img = image.transpose(Image.FLIP_LEFT_RIGHT)
det_f = detect_face(exe, compile_program, fetches, img, shrink)
det_t = np.zeros(det_f.shape)
# image.size: [width, height]
det_t[:, 0] = image.size[0] - det_f[:, 2]
det_t[:, 1] = det_f[:, 1]
det_t[:, 2] = image.size[0] - det_f[:, 0]
det_t[:, 3] = det_f[:, 3]
det_t[:, 4] = det_f[:, 4]
return det_t
def multi_scale_test(exe, compile_program, fetches, image, max_shrink):
# Shrink detecting is only used to detect big faces
st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
det_s = detect_face(exe, compile_program, fetches, image, st)
index = np.where(
np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1)
> 30)[0]
det_s = det_s[index, :]
# Enlarge one times
bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
det_b = detect_face(exe, compile_program, fetches, image, bt)
# Enlarge small image x times for small faces
if max_shrink > 2:
bt *= 2
while bt < max_shrink:
det_b = np.row_stack((det_b, detect_face(exe, compile_program,
fetches, image, bt)))
bt *= 2
det_b = np.row_stack((det_b, detect_face(exe, compile_program, fetches,
image, max_shrink)))
# Enlarged images are only used to detect small faces.
if bt > 1:
index = np.where(
np.minimum(det_b[:, 2] - det_b[:, 0] + 1,
det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
det_b = det_b[index, :]
# Shrinked images are only used to detect big faces.
else:
index = np.where(
np.maximum(det_b[:, 2] - det_b[:, 0] + 1,
det_b[:, 3] - det_b[:, 1] + 1) > 30)[0]
det_b = det_b[index, :]
return det_s, det_b
def multi_scale_test_pyramid(exe, compile_program, fetches, image, max_shrink):
# Use image pyramids to detect faces
det_b = detect_face(exe, compile_program, fetches, image, 0.25)
index = np.where(
np.maximum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1)
> 30)[0]
det_b = det_b[index, :]
st = [0.75, 1.25, 1.5, 1.75]
for i in range(len(st)):
if st[i] <= max_shrink:
det_temp = detect_face(exe, compile_program, fetches, image, st[i])
# Enlarged images are only used to detect small faces.
if st[i] > 1:
index = np.where(
np.minimum(det_temp[:, 2] - det_temp[:, 0] + 1,
det_temp[:, 3] - det_temp[:, 1] + 1) < 100)[0]
det_temp = det_temp[index, :]
# Shrinked images are only used to detect big faces.
else:
index = np.where(
np.maximum(det_temp[:, 2] - det_temp[:, 0] + 1,
det_temp[:, 3] - det_temp[:, 1] + 1) > 30)[0]
det_temp = det_temp[index, :]
det_b = np.row_stack((det_b, det_temp))
return det_b
def main():
"""
Main evaluate function
"""
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
else:
raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed')
else:
eval_feed = create(cfg.eval_feed)
# define executor
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# build program
model = create(main_arch)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
_, feed_vars = create_feed(eval_feed, use_pyreader=False)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
# load model
exe.run(startup_prog)
if 'weights' in cfg:
checkpoint.load_params(exe, eval_prog, cfg.weights)
assert cfg.metric in ['WIDERFACE'], \
"unknown metric type {}".format(cfg.metric)
annotation_file = getattr(eval_feed.dataset, 'annotation', None)
dataset_dir = FLAGS.dataset_dir if FLAGS.dataset_dir else \
getattr(eval_feed.dataset, 'dataset_dir', None)
img_root_dir = dataset_dir
if FLAGS.eval_mode == "widerface":
image_dir = getattr(eval_feed.dataset, 'image_dir', None)
img_root_dir = os.path.join(dataset_dir, image_dir)
gt_file = os.path.join(dataset_dir, annotation_file)
pred_dir = FLAGS.output_eval if FLAGS.output_eval else 'output/pred'
face_eval_run(
exe,
eval_prog,
fetches,
img_root_dir,
gt_file,
pred_dir=pred_dir,
eval_mode=FLAGS.eval_mode)
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-d",
"--dataset_dir",
default=None,
type=str,
help="Dataset path, same as DataFeed.dataset.dataset_dir")
parser.add_argument(
"-f",
"--output_eval",
default=None,
type=str,
help="Evaluation file directory, default is current directory.")
parser.add_argument(
"-e",
"--eval_mode",
default="widerface",
type=str,
help="Evaluation mode, include `widerface` and `fddb`, default is `widerface`."
)
FLAGS = parser.parse_args()
main()
......@@ -186,12 +186,12 @@ def main():
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
# parse infer fetches
assert cfg.metric in ['COCO', 'VOC'], \
assert cfg.metric in ['COCO', 'VOC', 'WIDERFACE'], \
"unknown metric type {}".format(cfg.metric)
extra_keys = []
if cfg['metric'] == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg['metric'] == 'VOC':
if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE':
extra_keys = ['im_id', 'im_shape']
keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
......@@ -200,6 +200,8 @@ def main():
from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
if cfg.metric == "VOC":
from ppdet.utils.voc_eval import bbox2out, get_category_info
if cfg.metric == "WIDERFACE":
from ppdet.utils.widerface_eval_utils import bbox2out, get_category_info
anno_file = getattr(test_feed.dataset, 'annotation', None)
with_background = getattr(test_feed, 'with_background', True)
......
......@@ -154,6 +154,8 @@ def main():
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC':
extra_keys = ['gt_box', 'gt_label', 'is_difficult']
if cfg.metric == 'WIDERFACE':
extra_keys = ['im_id', 'im_shape', 'gt_box']
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册