Unverified · Commit 776ca93e authored by Guanghua Yu, committed by GitHub

add face key-point detection (#1001)

Parent ce5ab172
architecture: BlazeFace
max_iters: 160000
pretrain_weights:
use_gpu: true
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: WIDERFACE
save_dir: output
weights: output/blazeface_keypoint/model_final.pdparams
# 1(label_class) + 1(background)
num_classes: 2

BlazeFace:
  backbone: BlazeNet
  output_decoder:
    keep_top_k: 750
    nms_threshold: 0.3
    nms_top_k: 5000
    score_threshold: 0.01
  min_sizes: [[16., 24.], [32., 48., 64., 80., 96., 128.]]
  use_density_prior_box: false
  with_lmk: true
  lmk_loss:
    overlap_threshold: 0.35
    neg_overlap: 0.35

BlazeNet:
  with_extra_blocks: true
  lite_edition: false

LearningRate:
  base_lr: 0.002
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [120000, 150000]

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

TrainReader:
  inputs_def:
    image_shape: [3, 640, 640]
    fields: ['image', 'gt_bbox', 'gt_class', 'gt_keypoint', 'keypoint_ignore']
  dataset:
    !WIDERFaceDataSet
    dataset_dir: dataset/wider_face
    anno_path: wider_face_split/wider_face_train_bbx_lmk_gt.txt
    image_dir: WIDER_train/images
    with_lmk: true
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !NormalizeBox {}
  - !RandomDistort
    brightness_lower: 0.875
    brightness_upper: 1.125
    is_order: true
  - !ExpandImage
    max_ratio: 4
    prob: 0.5
  - !CropImageWithDataAchorSampling
    anchor_sampler:
    - [1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]
    batch_sampler:
    - [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
    - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
    - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
    - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
    - [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
    target_size: 640
  - !ResizeImage
    target_size: 640
    interp: 1
  - !RandomInterpImage
    target_size: 640
  - !RandomFlipImage
    is_normalized: true
  - !Permute {}
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [127.502231, 127.502231, 127.502231]
  batch_size: 16
  use_process: true
  worker_num: 8
  shuffle: true

EvalReader:
  inputs_def:
    fields: ['image', 'im_id']
  dataset:
    !WIDERFaceDataSet
    dataset_dir: dataset/wider_face
    anno_path: wider_face_split/wider_face_val_bbx_gt.txt
    image_dir: WIDER_val/images
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !NormalizeBox {}
  - !Permute {}
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [127.502231, 127.502231, 127.502231]
  batch_size: 1

TestReader:
  inputs_def:
    fields: ['image', 'im_id', 'im_shape']
  dataset:
    !ImageFolder
    use_default_label: true
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !Permute {}
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [127.502231, 127.502231, 127.502231]
  batch_size: 1
...
@@ -16,6 +16,6 @@ wget https://dataset.bj.bcebos.com/wider_face/WIDER_val.zip
 wget https://dataset.bj.bcebos.com/wider_face/wider_face_split.zip
 # Extract the data.
 echo "Extracting..."
-unzip WIDER_train.zip
-unzip WIDER_val.zip
-unzip wider_face_split.zip
+unzip -q WIDER_train.zip
+unzip -q WIDER_val.zip
+unzip -q wider_face_split.zip
...
@@ -8,6 +8,7 @@
 - [Data preparation](#数据准备)
 - [Training and inference](#训练与推理)
 - [Evaluation](#评估)
+- [Face keypoint detection](#人脸关键点检测)
 - [Algorithm details](#算法细节)
 - [How to contribute code](#如何贡献代码)
@@ -142,7 +143,7 @@ cd dataset/wider_face && ./download.sh
 The training and inference workflow is the same as for other algorithms; please refer to [GETTING_STARTED_cn.md](../tutorials/GETTING_STARTED_cn.md).
 **Note:**
 - `BlazeFace` and `FaceBoxes` are trained on 4 GPUs with `batch_size=8` per card (total `batch_size` of 32) for 320000 iterations
-  (if you have fewer than 4 GPUs, please refer to the [learning-rate calculation rules](../tutorials/GETTING_STARTED_cn.html#faq)).
+  (if you have fewer than 4 GPUs, please refer to the [learning-rate calculation rules](../FAQ.md)).
 - Face detection models currently do not support evaluation during training.
@@ -241,6 +242,20 @@ cd dataset/fddb/evaluation
 (2) `OUTPUT_DIR` is the prefix of the FDDB evaluation output files; two files, `{OUTPUT_DIR}ContROC.txt` and `{OUTPUT_DIR}DiscROC.txt`, will be generated.
 (3) Argument usage and notes can be obtained by running `./evaluate --help`.

+## Face keypoint detection
+
+(1) Download the face keypoint annotation file for the WIDER FACE dataset released by PaddleDetection ([link](https://dataset.bj.bcebos.com/wider_face/wider_face_train_bbx_lmk_gt.txt)) and copy it into the `wider_face/wider_face_split` folder:
+
+```shell
+cd dataset/wider_face/wider_face_split/
+wget https://dataset.bj.bcebos.com/wider_face/wider_face_train_bbx_lmk_gt.txt
+```
+
+(2) Train and evaluate with the `configs/face_detection/blazeface_keypoint.yml` configuration file; usage is the same as in the previous section.
+
+![](../images/12_Group_Group_12_Group_Group_12_84.jpg)
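
For reference, a minimal sketch of the corresponding commands, assuming the standard PaddleDetection entry points (`tools/train.py`, and `tools/face_eval.py` for WIDER FACE evaluation as described in the linked guide):

```shell
# minimal sketch, assuming the standard PaddleDetection entry points
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u tools/train.py -c configs/face_detection/blazeface_keypoint.yml
python -u tools/face_eval.py -c configs/face_detection/blazeface_keypoint.yml
```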
 ## Algorithm details
 ### BlazeFace
@@ -257,7 +272,7 @@ cd dataset/fddb/evaluation
 - Original version: reproduced from the original paper;
 - Lite version: replaces 5x5 convolutions with 3x3 convolutions, using fewer layers and channels;
 - NAS version: builds the network structure with a neural architecture search algorithm; compared with the `Lite` version, the NAS version needs even fewer layers and channels.
-- NAS_V2 version 1: a structure searched with the SANAS algorithm in PaddleSlim on top of blazeface-NAS; compared with the `NAS` version, NAS_V2 is on average 3 points more accurate, with only about 5% more hardware latency on an 855 chip.
+- NAS_V2 version: a structure searched with the SANAS algorithm in PaddleSlim on top of blazeface-NAS; compared with the `NAS` version, NAS_V2 is on average 3 points more accurate, with only about 5% more hardware latency on an 855 chip.
 ### FaceBoxes
 **Introduction:**
...
...
@@ -8,6 +8,7 @@ English | [简体中文](FACE_DETECTION.md)
 - [Data Pipline](#Data-Pipline)
 - [Training and Inference](#Training-and-Inference)
 - [Evaluation](#Evaluation)
+- [Face key-point detection](#Face-key-point-detection)
 - [Algorithm Description](#Algorithm-Description)
 - [Contributing](#Contributing)
@@ -155,7 +156,7 @@ Please refer to [READER.md](../advanced_tutorials/READER.md) for details.
 **NOTES:**
 - `BlazeFace` and `FaceBoxes` are trained on 4 GPUs with `batch_size=8` per GPU (total batch size 32)
-  and trained for 320000 iterations. (If your GPU count is not 4, please refer to the rule of training
-  parameters in the table of [calculation rules](../tutorials/GETTING_STARTED.html#faq)).
+  and trained for 320000 iterations. (If your GPU count is not 4, please refer to the rule of training
+  parameters in the table of [calculation rules](../FAQ.md)).
 - Currently we do not support evaluation in training.

 ### Evaluation
@@ -258,6 +259,20 @@ cd dataset/fddb/evaluation
 which will generate two files, `{OUTPUT_DIR}ContROC.txt` and `{OUTPUT_DIR}DiscROC.txt`;
 (3) The interpretation of the arguments can be obtained by running `./evaluate --help`.

+## Face key-point detection
+
+(1) Download the face key-point annotation file for the WIDER FACE dataset ([Link](https://dataset.bj.bcebos.com/wider_face/wider_face_train_bbx_lmk_gt.txt)) and copy it into the folder `wider_face/wider_face_split`:
+
+```shell
+cd dataset/wider_face/wider_face_split/
+wget https://dataset.bj.bcebos.com/wider_face/wider_face_train_bbx_lmk_gt.txt
+```
+
+(2) Use the `configs/face_detection/blazeface_keypoint.yml` configuration file for training and evaluation; usage is the same as in the previous section.
+
+![](../images/12_Group_Group_12_Group_Group_12_84.jpg)
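
A minimal inference sketch, assuming the standard `tools/infer.py` entry point (the image path below is illustrative):

```shell
# minimal sketch; the --infer_img path is illustrative
python -u tools/infer.py -c configs/face_detection/blazeface_keypoint.yml \
       --infer_img=demo/12_Group_Group_12_Group_Group_12_84.jpg \
       --output_dir=infer_output/
```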
 ## Algorithm Description
 ### BlazeFace
...
...
@@ -41,7 +41,8 @@ class WIDERFaceDataSet(DataSet):
                  image_dir=None,
                  anno_path=None,
                  sample_num=-1,
-                 with_background=True):
+                 with_background=True,
+                 with_lmk=False):
         super(WIDERFaceDataSet, self).__init__(
             image_dir=image_dir,
             anno_path=anno_path,
@@ -53,6 +54,7 @@ class WIDERFaceDataSet(DataSet):
         self.with_background = with_background
         self.roidbs = None
         self.cname2cid = None
+        self.with_lmk = with_lmk

     def load_roidb_and_cname2cid(self):
         anno_path = os.path.join(self.dataset_dir, self.anno_path)
@@ -62,33 +64,23 @@ class WIDERFaceDataSet(DataSet):
         records = []
         ct = 0
-        file_lists = _load_file_list(txt_file)
+        file_lists = self._load_file_list(txt_file)
         cname2cid = widerface_label(self.with_background)

         for item in file_lists:
             im_fname = item[0]
             im_id = np.array([ct])
-            gt_bbox = np.zeros((len(item) - 2, 4), dtype=np.float32)
-            gt_class = np.ones((len(item) - 2, 1), dtype=np.int32)
+            gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32)
+            gt_class = np.ones((len(item) - 1, 1), dtype=np.int32)
+            gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32)
+            lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32)
             for index_box in range(len(item)):
-                if index_box >= 2:
-                    temp_info_box = item[index_box].split(' ')
-                    xmin = float(temp_info_box[0])
-                    ymin = float(temp_info_box[1])
-                    w = float(temp_info_box[2])
-                    h = float(temp_info_box[3])
-                    # Filter out wrong labels
-                    if w < 0 or h < 0:
-                        logger.warn('Illegal box with w: {}, h: {} in '
-                                    'img: {}, and it will be ignored'.format(
-                                        w, h, im_fname))
-                        continue
-                    xmin = max(0, xmin)
-                    ymin = max(0, ymin)
-                    xmax = xmin + w
-                    ymax = ymin + h
-                    gt_bbox[index_box - 2] = [xmin, ymin, xmax, ymax]
+                if index_box < 1:
+                    continue
+                gt_bbox[index_box - 1] = item[index_box][0]
+                if self.with_lmk:
+                    gt_lmk_labels[index_box - 1] = item[index_box][1]
+                    lmk_ignore_flag[index_box - 1] = item[index_box][2]
             im_fname = os.path.join(image_dir,
                                     im_fname) if image_dir else im_fname
             widerface_rec = {
@@ -97,7 +89,10 @@ class WIDERFaceDataSet(DataSet):
                 'gt_bbox': gt_bbox,
                 'gt_class': gt_class,
             }
-            # logger.debug
+            if self.with_lmk:
+                widerface_rec['gt_keypoint'] = gt_lmk_labels
+                widerface_rec['keypoint_ignore'] = lmk_ignore_flag
+
             if len(item) != 0:
                 records.append(widerface_rec)
@@ -108,34 +103,64 @@ class WIDERFaceDataSet(DataSet):
         logger.debug('{} samples in file {}'.format(ct, anno_path))
         self.roidbs, self.cname2cid = records, cname2cid

-
-def _load_file_list(input_txt):
-    with open(input_txt, 'r') as f_dir:
-        lines_input_txt = f_dir.readlines()
-
-    file_dict = {}
-    num_class = 0
-    for i in range(len(lines_input_txt)):
-        line_txt = lines_input_txt[i].strip('\n\t\r')
-        if '.jpg' in line_txt:
-            if i != 0:
-                num_class += 1
-            file_dict[num_class] = []
-            file_dict[num_class].append(line_txt)
-        if '.jpg' not in line_txt:
-            if len(line_txt) > 6:
-                split_str = line_txt.split(' ')
-                x1_min = float(split_str[0])
-                y1_min = float(split_str[1])
-                x2_max = float(split_str[2])
-                y2_max = float(split_str[3])
-                line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
-                    x2_max) + ' ' + str(y2_max)
-                file_dict[num_class].append(line_txt)
-            else:
-                file_dict[num_class].append(line_txt)
-
-    return list(file_dict.values())
+    def _load_file_list(self, input_txt):
+        with open(input_txt, 'r') as f_dir:
+            lines_input_txt = f_dir.readlines()
+
+        file_dict = {}
+        num_class = 0
+        for i in range(len(lines_input_txt)):
+            line_txt = lines_input_txt[i].strip('\n\t\r')
+            if '.jpg' in line_txt:
+                if i != 0:
+                    num_class += 1
+                file_dict[num_class] = []
+                file_dict[num_class].append(line_txt)
+            if '.jpg' not in line_txt:
+                if len(line_txt) <= 6:
+                    continue
+                result_boxs = []
+                split_str = line_txt.split(' ')
+                xmin = float(split_str[0])
+                ymin = float(split_str[1])
+                w = float(split_str[2])
+                h = float(split_str[3])
+                # Filter out wrong labels
+                if w < 0 or h < 0:
+                    logger.warn('Illegal box with w: {}, h: {} in '
+                                'img: {}, and it will be ignored'.format(
+                                    w, h, file_dict[num_class][0]))
+                    continue
+                xmin = max(0, xmin)
+                ymin = max(0, ymin)
+                xmax = xmin + w
+                ymax = ymin + h
+                gt_bbox = [xmin, ymin, xmax, ymax]
+                result_boxs.append(gt_bbox)
+                if self.with_lmk:
+                    assert len(split_str) > 18, 'When `with_lmk=True`, the ' \
+                        'number of fields per annotation line should ' \
+                        'exceed 18.'
+                    lmk0_x = float(split_str[5])
+                    lmk0_y = float(split_str[6])
+                    lmk1_x = float(split_str[8])
+                    lmk1_y = float(split_str[9])
+                    lmk2_x = float(split_str[11])
+                    lmk2_y = float(split_str[12])
+                    lmk3_x = float(split_str[14])
+                    lmk3_y = float(split_str[15])
+                    lmk4_x = float(split_str[17])
+                    lmk4_y = float(split_str[18])
+                    lmk_ignore_flag = 0 if lmk0_x == -1 else 1
+                    gt_lmk_label = [
+                        lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y,
+                        lmk3_x, lmk3_y, lmk4_x, lmk4_y
+                    ]
+                    result_boxs.append(gt_lmk_label)
+                    result_boxs.append(lmk_ignore_flag)
+                file_dict[num_class].append(result_boxs)

+        return list(file_dict.values())

 def widerface_label(with_background=True):
...
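
To make the parsing above concrete: a hypothetical annotation line laid out to match the indices `_load_file_list` reads (bbox in fields 0-3; landmark k at fields 5+3k and 6+3k; the skipped fields are not interpreted by this loader):

```python
# Hypothetical annotation line; only the field positions matter here.
line = ("10.0 20.0 30.0 40.0 0.0 "
        "12.0 25.0 0.0 22.0 25.0 0.0 17.0 30.0 0.0 "
        "13.0 35.0 0.0 21.0 35.0 0.0 1.0")
split_str = line.split(' ')
bbox = [float(v) for v in split_str[:4]]   # [10.0, 20.0, 30.0, 40.0] = x, y, w, h
lmk = [float(split_str[5 + 3 * k + o]) for k in range(5) for o in (0, 1)]
ignore = 0 if lmk[0] == -1 else 1          # -1 marks faces without landmarks
assert len(split_str) > 18 and len(lmk) == 10
```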
...
@@ -61,10 +61,13 @@ def is_overlap(object_bbox, sample_bbox):
     return True


-def filter_and_process(sample_bbox, bboxes, labels, scores=None):
+def filter_and_process(sample_bbox, bboxes, labels, scores=None,
+                       keypoints=None):
     new_bboxes = []
     new_labels = []
     new_scores = []
+    new_keypoints = []
+    new_kp_ignore = []
     for i in range(len(bboxes)):
         new_bbox = [0, 0, 0, 0]
         obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
@@ -84,9 +87,24 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None):
         new_labels.append([labels[i][0]])
         if scores is not None:
             new_scores.append([scores[i][0]])
+        if keypoints is not None:
+            sample_keypoint = keypoints[0][i]
+            for j in range(len(sample_keypoint)):
+                kp_len = sample_height if j % 2 else sample_width
+                sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0]
+                sample_keypoint[j] = (
+                    sample_keypoint[j] - sample_coord) / kp_len
+                sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0)
+            new_keypoints.append(sample_keypoint)
+            new_kp_ignore.append(keypoints[1][i])
+
     bboxes = np.array(new_bboxes)
     labels = np.array(new_labels)
     scores = np.array(new_scores)
+    if keypoints is not None:
+        keypoints = np.array(new_keypoints)
+        new_kp_ignore = np.array(new_kp_ignore)
+        return bboxes, labels, scores, (keypoints, new_kp_ignore)
     return bboxes, labels, scores
...
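
To make the keypoint re-normalization above concrete, a small plain-Python sketch with made-up numbers: each coordinate is shifted to the crop origin, rescaled by the crop extent, and clipped to [0, 1], exactly as in the loop above:

```python
# Hypothetical crop window in normalized image coords: [xmin, ymin, xmax, ymax]
sample_bbox = [0.25, 0.25, 0.75, 0.75]
sample_width = sample_bbox[2] - sample_bbox[0]   # 0.5
sample_height = sample_bbox[3] - sample_bbox[1]  # 0.5

keypoint = [0.50, 0.40, 0.10, 0.90]  # x0, y0, x1, y1 in image coordinates
for j in range(len(keypoint)):
    kp_len = sample_height if j % 2 else sample_width
    origin = sample_bbox[1] if j % 2 else sample_bbox[0]
    keypoint[j] = max(min((keypoint[j] - origin) / kp_len, 1.0), 0.0)

# x0 -> (0.50 - 0.25) / 0.5 = 0.5      y0 -> (0.40 - 0.25) / 0.5 = 0.3
# x1 -> -0.3, clipped to 0.0           y1 -> 1.3, clipped to 1.0
```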
...
@@ -32,9 +32,10 @@ import logging
 import random
 import math
 import numpy as np
+import os

 import cv2
-from PIL import Image, ImageEnhance
+from PIL import Image, ImageEnhance, ImageDraw

 from ppdet.core.workspace import serializable
 from ppdet.modeling.ops import AnchorGrid
@@ -400,6 +401,16 @@ class RandomFlipImage(BaseOperator):
                 flipped_segms.append(_flip_rle(segm, height, width))
         return flipped_segms

+    def flip_keypoint(self, gt_keypoint, width):
+        for i in range(gt_keypoint.shape[1]):
+            if i % 2 == 0:
+                old_x = gt_keypoint[:, i].copy()
+                if self.is_normalized:
+                    gt_keypoint[:, i] = 1 - old_x
+                else:
+                    gt_keypoint[:, i] = width - old_x - 1
+        return gt_keypoint
+
     def __call__(self, sample, context=None):
         """Flip the image and bounding box.
         Operators:
@@ -447,6 +458,9 @@ class RandomFlipImage(BaseOperator):
                 if self.is_mask_flip and len(sample['gt_poly']) != 0:
                     sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
                                                         height, width)
+                if 'gt_keypoint' in sample.keys():
+                    sample['gt_keypoint'] = self.flip_keypoint(
+                        sample['gt_keypoint'], width)
                 sample['flipped'] = True
                 sample['image'] = im
         sample = samples if batch_input else samples[0]
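
A quick sanity check of `flip_keypoint` with illustrative numbers: in normalized coordinates the x-values (even indices) reflect to `1 - x` while y-values are untouched. Note the operator mirrors coordinates but does not reorder the five landmark slots (whatever face part each index denotes keeps its index):

```python
import numpy as np

gt_keypoint = np.array([[0.2, 0.4, 0.9, 0.1]], dtype=np.float32)  # x0,y0,x1,y1
for i in range(gt_keypoint.shape[1]):
    if i % 2 == 0:              # only x-coordinates are mirrored
        gt_keypoint[:, i] = 1 - gt_keypoint[:, i]
print(gt_keypoint)              # [[0.8 0.4 0.1 0.1]]
```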
@@ -741,8 +755,17 @@ class ExpandImage(BaseOperator):
             im = Image.fromarray(im)
             expand_im.paste(im, (int(w_off), int(h_off)))
             expand_im = np.asarray(expand_im)
-            gt_bbox, gt_class, _ = filter_and_process(expand_bbox, gt_bbox,
-                                                      gt_class)
+            if 'gt_keypoint' in sample.keys(
+            ) and 'keypoint_ignore' in sample.keys():
+                keypoints = (sample['gt_keypoint'],
+                             sample['keypoint_ignore'])
+                gt_bbox, gt_class, _, gt_keypoints = filter_and_process(
+                    expand_bbox, gt_bbox, gt_class, keypoints=keypoints)
+                sample['gt_keypoint'] = gt_keypoints[0]
+                sample['keypoint_ignore'] = gt_keypoints[1]
+            else:
+                gt_bbox, gt_class, _ = filter_and_process(expand_bbox,
+                                                          gt_bbox, gt_class)
             sample['image'] = expand_im
             sample['gt_bbox'] = gt_bbox
             sample['gt_class'] = gt_class
@@ -816,7 +839,7 @@ class CropImage(BaseOperator):
             sample_bbox = sampled_bbox.pop(idx)
             sample_bbox = clip_bbox(sample_bbox)
-            crop_bbox, crop_class, crop_score = \
-                filter_and_process(sample_bbox, gt_bbox, gt_class, gt_score)
+            crop_bbox, crop_class, crop_score = \
+                filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score)
             if self.avoid_no_bbox:
                 if len(crop_bbox) < 1:
                     continue
@@ -919,8 +942,16 @@ class CropImageWithDataAchorSampling(BaseOperator):
                 idx = int(np.random.uniform(0, len(sampled_bbox)))
                 sample_bbox = sampled_bbox.pop(idx)
-                crop_bbox, crop_class, crop_score = filter_and_process(
-                    sample_bbox, gt_bbox, gt_class, gt_score)
+                if 'gt_keypoint' in sample.keys():
+                    keypoints = (sample['gt_keypoint'],
+                                 sample['keypoint_ignore'])
+                    crop_bbox, crop_class, crop_score, gt_keypoints = \
+                        filter_and_process(sample_bbox, gt_bbox, gt_class,
+                                           scores=gt_score,
+                                           keypoints=keypoints)
+                else:
+                    crop_bbox, crop_class, crop_score = filter_and_process(
+                        sample_bbox, gt_bbox, gt_class, scores=gt_score)
                 crop_bbox, crop_class, crop_score = bbox_area_sampling(
                     crop_bbox, crop_class, crop_score, self.target_size,
                     self.min_size)
@@ -934,6 +965,9 @@ class CropImageWithDataAchorSampling(BaseOperator):
                     sample['gt_bbox'] = crop_bbox
                     sample['gt_class'] = crop_class
                     sample['gt_score'] = crop_score
+                    if 'gt_keypoint' in sample.keys():
+                        sample['gt_keypoint'] = gt_keypoints[0]
+                        sample['keypoint_ignore'] = gt_keypoints[1]
                     return sample
             return sample
@@ -955,8 +989,16 @@ class CropImageWithDataAchorSampling(BaseOperator):
                 sample_bbox = sampled_bbox.pop(idx)
                 sample_bbox = clip_bbox(sample_bbox)
-                crop_bbox, crop_class, crop_score = filter_and_process(
-                    sample_bbox, gt_bbox, gt_class, gt_score)
+                if 'gt_keypoint' in sample.keys():
+                    keypoints = (sample['gt_keypoint'],
+                                 sample['keypoint_ignore'])
+                    crop_bbox, crop_class, crop_score, gt_keypoints = \
+                        filter_and_process(sample_bbox, gt_bbox, gt_class,
+                                           scores=gt_score,
+                                           keypoints=keypoints)
+                else:
+                    crop_bbox, crop_class, crop_score = filter_and_process(
+                        sample_bbox, gt_bbox, gt_class, scores=gt_score)
                 # sampling bbox according the bbox area
                 crop_bbox, crop_class, crop_score = bbox_area_sampling(
                     crop_bbox, crop_class, crop_score, self.target_size,
@@ -974,6 +1016,9 @@ class CropImageWithDataAchorSampling(BaseOperator):
                     sample['gt_bbox'] = crop_bbox
                     sample['gt_class'] = crop_class
                     sample['gt_score'] = crop_score
+                    if 'gt_keypoint' in sample.keys():
+                        sample['gt_keypoint'] = gt_keypoints[0]
+                        sample['keypoint_ignore'] = gt_keypoints[1]
                     return sample
             return sample
@@ -995,6 +1040,17 @@ class NormalizeBox(BaseOperator):
             gt_bbox[i][2] = gt_bbox[i][2] / width
             gt_bbox[i][3] = gt_bbox[i][3] / height
         sample['gt_bbox'] = gt_bbox
+
+        if 'gt_keypoint' in sample.keys():
+            gt_keypoint = sample['gt_keypoint']
+
+            for i in range(gt_keypoint.shape[1]):
+                if i % 2:
+                    gt_keypoint[:, i] = gt_keypoint[:, i] / height
+                else:
+                    gt_keypoint[:, i] = gt_keypoint[:, i] / width
+            sample['gt_keypoint'] = gt_keypoint
+
         return sample
@@ -1837,7 +1893,6 @@ class BboxXYXY2XYWH(BaseOperator):
         return sample


-@register_op
 class Lighting(BaseOperator):
     """
     Lighting the image by eigenvalues and eigenvectors
@@ -2270,3 +2325,69 @@ class TargetAssign(BaseOperator):
         targets[matched_indices] = matched_targets
         sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32)
         return sample
+
+
+@register_op
+class DebugVisibleImage(BaseOperator):
+    """
+    In debug mode, visualize images according to `gt_box`.
+    (Currently only supported when not cropping or flipping the image.)
+    """
+
+    def __init__(self, output_dir='output/debug', is_normalized=False):
+        super(DebugVisibleImage, self).__init__()
+        self.is_normalized = is_normalized
+        self.output_dir = output_dir
+        if not os.path.isdir(output_dir):
+            os.makedirs(output_dir)
+        if not isinstance(self.is_normalized, bool):
+            raise TypeError("{}: input type is invalid.".format(self))
+
+    def __call__(self, sample, context=None):
+        image = Image.open(sample['im_file']).convert('RGB')
+        out_file_name = sample['im_file'].split('/')[-1]
+        width = sample['w']
+        height = sample['h']
+        gt_bbox = sample['gt_bbox']
+        gt_class = sample['gt_class']
+        draw = ImageDraw.Draw(image)
+        for i in range(gt_bbox.shape[0]):
+            if self.is_normalized:
+                gt_bbox[i][0] = gt_bbox[i][0] * width
+                gt_bbox[i][1] = gt_bbox[i][1] * height
+                gt_bbox[i][2] = gt_bbox[i][2] * width
+                gt_bbox[i][3] = gt_bbox[i][3] * height
+
+            xmin, ymin, xmax, ymax = gt_bbox[i]
+            draw.line(
+                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+                 (xmin, ymin)],
+                width=2,
+                fill='green')
+            # draw label
+            text = str(gt_class[i][0])
+            tw, th = draw.textsize(text)
+            draw.rectangle(
+                [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green')
+            draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+
+        if 'gt_keypoint' in sample.keys():
+            gt_keypoint = sample['gt_keypoint']
+            if self.is_normalized:
+                for i in range(gt_keypoint.shape[1]):
+                    if i % 2:
+                        gt_keypoint[:, i] = gt_keypoint[:, i] * height
+                    else:
+                        gt_keypoint[:, i] = gt_keypoint[:, i] * width
+            for i in range(gt_keypoint.shape[0]):
+                keypoint = gt_keypoint[i]
+                for j in range(int(keypoint.shape[0] / 2)):
+                    x1 = round(keypoint[2 * j]).astype(np.int32)
+                    y1 = round(keypoint[2 * j + 1]).astype(np.int32)
+                    draw.ellipse(
+                        (x1, y1, x1 + 5, y1 + 5),
+                        fill='green',
+                        outline='green')
+        save_path = os.path.join(self.output_dir, out_file_name)
+        image.save(save_path, quality=95)
+        return sample
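
Since `DebugVisibleImage` is registered with `@register_op`, it can presumably be dropped into a reader's `sample_transforms` like any other operator; a hypothetical snippet in the style of the config above (the op reads `im_file`, `w`, and `h`, which `!DecodeImage` fills in, and its docstring notes it does not account for cropped or flipped samples):

```yaml
# Hypothetical debugging pipeline: dump annotated images to output/debug.
sample_transforms:
- !DecodeImage
  to_rgb: true
- !NormalizeBox {}
- !DebugVisibleImage
  output_dir: output/debug
  is_normalized: true   # boxes/keypoints were normalized by !NormalizeBox
```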
...
@@ -25,6 +25,7 @@ from paddle.fluid.regularizer import L2Decay
 from ppdet.core.workspace import register
 from ppdet.modeling.ops import SSDOutputDecoder
+from ppdet.modeling.losses import SSDWithLmkLoss

 __all__ = ['BlazeFace']
@@ -59,24 +60,29 @@ class BlazeFace(object):
                  steps=[8., 16.],
                  num_classes=2,
                  use_density_prior_box=False,
-                 densities=[[2, 2], [2, 1, 1, 1, 1, 1]]):
+                 densities=[[2, 2], [2, 1, 1, 1, 1, 1]],
+                 with_lmk=False,
+                 lmk_loss=SSDWithLmkLoss().__dict__):
         super(BlazeFace, self).__init__()
         self.backbone = backbone
         self.num_classes = num_classes
+        self.with_lmk = with_lmk
         self.output_decoder = output_decoder
         if isinstance(output_decoder, dict):
+            if self.with_lmk:
+                output_decoder['return_index'] = True
             self.output_decoder = SSDOutputDecoder(**output_decoder)
         self.min_sizes = min_sizes
         self.max_sizes = max_sizes
         self.steps = steps
         self.use_density_prior_box = use_density_prior_box
         self.densities = densities
+        self.landmark = None
+        if self.with_lmk and isinstance(lmk_loss, dict):
+            self.lmk_loss = SSDWithLmkLoss(**lmk_loss)

     def build(self, feed_vars, mode='train'):
         im = feed_vars['image']
-        if mode == 'train':
-            gt_bbox = feed_vars['gt_bbox']
-            gt_class = feed_vars['gt_class']

         body_feats = self.backbone(im)
         locs, confs, box, box_var = self._multi_box_head(
@@ -86,20 +92,40 @@ class BlazeFace(object):
             use_density_prior_box=self.use_density_prior_box)

         if mode == 'train':
-            loss = fluid.layers.ssd_loss(
-                locs,
-                confs,
-                gt_bbox,
-                gt_class,
-                box,
-                box_var,
-                overlap_threshold=0.35,
-                neg_overlap=0.35)
+            gt_bbox = feed_vars['gt_bbox']
+            gt_class = feed_vars['gt_class']
+            if self.with_lmk:
+                lmk_labels = feed_vars['gt_keypoint']
+                lmk_ignore_flag = feed_vars["keypoint_ignore"]
+                loss = self.lmk_loss(locs, confs, gt_bbox, gt_class,
+                                     self.landmark, lmk_labels,
+                                     lmk_ignore_flag, box, box_var)
+            else:
+                loss = fluid.layers.ssd_loss(
+                    locs,
+                    confs,
+                    gt_bbox,
+                    gt_class,
+                    box,
+                    box_var,
+                    overlap_threshold=0.35,
+                    neg_overlap=0.35)
             loss = fluid.layers.reduce_sum(loss)
             return {'loss': loss}
         else:
-            pred = self.output_decoder(locs, confs, box, box_var)
-            return {'bbox': pred}
+            if self.with_lmk:
+                pred, face_index = self.output_decoder(locs, confs, box,
+                                                       box_var)
+                return {
+                    'bbox': pred,
+                    'face_index': face_index,
+                    'prior_boxes': box,
+                    'landmark': self.landmark
+                }
+            else:
+                pred = self.output_decoder(locs, confs, box, box_var)
+                return {'bbox': pred}

     def _multi_box_head(self,
                         inputs,
@@ -111,11 +137,9 @@ class BlazeFace(object):
             compile_shape = [0, -1, last_dim]
             return fluid.layers.reshape(trans, shape=compile_shape)

-        def _is_list_or_tuple_(data):
-            return (isinstance(data, list) or isinstance(data, tuple))
-
         locs, confs = [], []
         boxes, vars = [], []
+        lmk_locs = []
         b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.))

         for i, input in enumerate(inputs):
@@ -157,7 +181,21 @@ class BlazeFace(object):
             # get conf
             mbox_conf = fluid.layers.conv2d(
                 input, num_conf_output, 3, 1, 1, bias_attr=b_attr)
-            conf = permute_and_reshape(mbox_conf, 2)
+            conf = permute_and_reshape(mbox_conf, num_classes)
+
+            if self.with_lmk:
+                # get landmark
+                lmk_loc_output = num_boxes * 10
+                lmk_box_loc = fluid.layers.conv2d(
+                    input,
+                    lmk_loc_output,
+                    3,
+                    1,
+                    1,
+                    param_attr=ParamAttr(name='lmk' + str(i) + '_weights'),
+                    bias_attr=False)
+                lmk_loc = permute_and_reshape(lmk_box_loc, 10)
+                lmk_locs.append(lmk_loc)

             locs.append(loc)
             confs.append(conf)
@@ -168,6 +206,8 @@ class BlazeFace(object):
         face_mbox_conf = fluid.layers.concat(confs, axis=1)
         prior_boxes = fluid.layers.concat(boxes)
         box_vars = fluid.layers.concat(vars)
+        if self.with_lmk:
+            self.landmark = fluid.layers.concat(lmk_locs, axis=1)
         return face_mbox_loc, face_mbox_conf, prior_boxes, box_vars

     def _inputs_def(self, image_shape):
@@ -179,6 +219,8 @@ class BlazeFace(object):
             'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
             'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
             'im_shape': {'shape': [None, 3], 'dtype': 'int32', 'lod_level': 0},
+            'gt_keypoint': {'shape': [None, 10], 'dtype': 'float32', 'lod_level': 1},
+            'keypoint_ignore': {'shape': [None, 1], 'dtype': 'float32', 'lod_level': 1},
         }
         # yapf: enable
         return inputs_def
...
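
To make the landmark-head bookkeeping concrete: with `min_sizes: [[16., 24.], [32., 48., 64., 80., 96., 128.]]` and `steps: [8., 16.]` from the config above, a sketch of the shapes involved, assuming a 640x640 input:

```python
# Sketch of the landmark head arithmetic for a 640x640 input, assuming the
# two BlazeFace feature levels sit at strides 8 and 16 (steps=[8., 16.]).
feat_sizes = [(80, 80), (40, 40)]           # 640/8 and 640/16
num_boxes = [2, 6]                          # len(min_sizes[i]) per level
lmk_channels = [n * 10 for n in num_boxes]  # conv outputs: [20, 60] channels
num_priors = sum(h * w * n for (h, w), n in zip(feat_sizes, num_boxes))
print(num_priors)  # 22400: after permute_and_reshape + concat, the landmark
                   # tensor self.landmark has shape [N, 22400, 10]
```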
...
@@ -23,6 +23,7 @@ from . import balanced_l1_loss
 from . import fcos_loss
 from . import diou_loss_yolo
 from . import iou_aware_loss
+from . import ssd_with_lmk_loss

 from .iou_aware_loss import *
 from .yolo_loss import *
@@ -33,3 +34,4 @@ from .iou_loss import *
 from .balanced_l1_loss import *
 from .fcos_loss import *
 from .diou_loss_yolo import *
+from .ssd_with_lmk_loss import *
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
import paddle.fluid.layers as layers
from paddle.fluid.layers import (tensor, iou_similarity, bipartite_match,
                                 target_assign, box_coder)
from ppdet.core.workspace import register, serializable

__all__ = ['SSDWithLmkLoss']


@register
@serializable
class SSDWithLmkLoss(object):
    """
    ssd_with_lmk_loss function.

    Args:
        background_label (int): The index of background label, 0 by default.
        overlap_threshold (float): If match_type is `per_prediction`,
            use `overlap_threshold` to determine the extra matching bboxes
            when finding matched boxes. 0.5 by default.
        neg_pos_ratio (float): The ratio of the negative boxes to the positive
            boxes, used only when mining_type is `max_negative`, 3.0 by default.
        neg_overlap (float): The negative overlap upper bound for the unmatched
            predictions. Used only when mining_type is `max_negative`,
            0.5 by default.
        loc_loss_weight (float): Weight for localization loss, 1.0 by default.
        conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
        match_type (str): The type of matching method during training, should be
            `bipartite` or `per_prediction`, `per_prediction` by default.
        normalize (bool): Whether to normalize the loss by the total number of
            output locations, True by default.
    """

    def __init__(self,
                 background_label=0,
                 overlap_threshold=0.5,
                 neg_pos_ratio=3.0,
                 neg_overlap=0.5,
                 loc_loss_weight=1.0,
                 conf_loss_weight=1.0,
                 match_type='per_prediction',
                 normalize=True):
        super(SSDWithLmkLoss, self).__init__()
        self.background_label = background_label
        self.overlap_threshold = overlap_threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.neg_overlap = neg_overlap
        self.loc_loss_weight = loc_loss_weight
        self.conf_loss_weight = conf_loss_weight
        self.match_type = match_type
        self.normalize = normalize

    def __call__(self,
                 location,
                 confidence,
                 gt_box,
                 gt_label,
                 landmark_predict,
                 lmk_label,
                 lmk_ignore_flag,
                 prior_box,
                 prior_box_var=None):
        def _reshape_to_2d(var):
            return layers.flatten(x=var, axis=2)

        helper = LayerHelper('ssd_loss')
        # Only support mining_type == 'max_negative' now.
        mining_type = 'max_negative'
        # The max `sample_size` of negative box, used only
        # when mining_type is `hard_example`.
        sample_size = None
        num, num_prior, num_class = confidence.shape
        conf_shape = layers.shape(confidence)

        # 1. Find matched bounding box by prior box.
        #   1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
        iou = iou_similarity(x=gt_box, y=prior_box)
        #   1.2 Compute matched bounding box by bipartite matching algorithm.
        matched_indices, matched_dist = bipartite_match(iou, self.match_type,
                                                        self.overlap_threshold)

        # 2. Compute confidence for mining hard examples
        # 2.1. Get the target label based on matched indices
        gt_label = layers.reshape(
            x=gt_label, shape=(len(gt_label.shape) - 1) * (0, ) + (-1, 1))
        gt_label.stop_gradient = True
        target_label, _ = target_assign(
            gt_label, matched_indices, mismatch_value=self.background_label)
        # 2.2. Compute confidence loss.
        # Reshape confidence to 2D tensor.
        confidence = _reshape_to_2d(confidence)
        target_label = tensor.cast(x=target_label, dtype='int64')
        target_label = _reshape_to_2d(target_label)
        target_label.stop_gradient = True
        conf_loss = layers.softmax_with_cross_entropy(confidence, target_label)

        # 3. Mining hard examples
        actual_shape = layers.slice(conf_shape, axes=[0], starts=[0], ends=[2])
        actual_shape.stop_gradient = True
        conf_loss = layers.reshape(
            x=conf_loss, shape=(-1, 0), actual_shape=actual_shape)
        conf_loss.stop_gradient = True
        neg_indices = helper.create_variable_for_type_inference(dtype='int32')
        updated_matched_indices = helper.create_variable_for_type_inference(
            dtype=matched_indices.dtype)
        helper.append_op(
            type='mine_hard_examples',
            inputs={
                'ClsLoss': conf_loss,
                'LocLoss': None,
                'MatchIndices': matched_indices,
                'MatchDist': matched_dist,
            },
            outputs={
                'NegIndices': neg_indices,
                'UpdatedMatchIndices': updated_matched_indices
            },
            attrs={
                'neg_pos_ratio': self.neg_pos_ratio,
                'neg_dist_threshold': self.neg_overlap,
                'mining_type': mining_type,
                'sample_size': sample_size,
            })

        # 4. Assign classification and regression targets
        # 4.1. Encode bbox according to the prior boxes.
        encoded_bbox = box_coder(
            prior_box=prior_box,
            prior_box_var=prior_box_var,
            target_box=gt_box,
            code_type='encode_center_size')
        # 4.2. Assign regression targets
        target_bbox, target_loc_weight = target_assign(
            encoded_bbox,
            updated_matched_indices,
            mismatch_value=self.background_label)
        # 4.3. Assign classification targets
        target_label, target_conf_weight = target_assign(
            gt_label,
            updated_matched_indices,
            negative_indices=neg_indices,
            mismatch_value=self.background_label)

        target_loc_weight = target_loc_weight * target_label
        encoded_lmk_label = self.decode_lmk(lmk_label, prior_box,
                                            prior_box_var)

        target_lmk, target_lmk_weight = target_assign(
            encoded_lmk_label,
            updated_matched_indices,
            mismatch_value=self.background_label)
        lmk_ignore_flag = layers.reshape(
            x=lmk_ignore_flag,
            shape=(len(lmk_ignore_flag.shape) - 1) * (0, ) + (-1, 1))
        target_ignore, nouse = target_assign(
            lmk_ignore_flag,
            updated_matched_indices,
            mismatch_value=self.background_label)

        target_lmk_weight = target_lmk_weight * target_ignore
        landmark_predict = _reshape_to_2d(landmark_predict)
        target_lmk = _reshape_to_2d(target_lmk)
        target_lmk_weight = _reshape_to_2d(target_lmk_weight)
        lmk_loss = layers.smooth_l1(landmark_predict, target_lmk)
        lmk_loss = lmk_loss * target_lmk_weight
        target_lmk.stop_gradient = True
        target_lmk_weight.stop_gradient = True
        target_ignore.stop_gradient = True
        nouse.stop_gradient = True

        # 5. Compute loss.
        # 5.1 Compute confidence loss.
        target_label = _reshape_to_2d(target_label)
        target_label = tensor.cast(x=target_label, dtype='int64')
        conf_loss = layers.softmax_with_cross_entropy(confidence, target_label)
        target_conf_weight = _reshape_to_2d(target_conf_weight)
        conf_loss = conf_loss * target_conf_weight
        # the target_label and target_conf_weight do not have gradient.
        target_label.stop_gradient = True
        target_conf_weight.stop_gradient = True

        # 5.2 Compute regression loss.
        location = _reshape_to_2d(location)
        target_bbox = _reshape_to_2d(target_bbox)
        loc_loss = layers.smooth_l1(location, target_bbox)
        target_loc_weight = _reshape_to_2d(target_loc_weight)
        loc_loss = loc_loss * target_loc_weight
        # the target_bbox and target_loc_weight do not have gradient.
        target_bbox.stop_gradient = True
        target_loc_weight.stop_gradient = True

        # 5.3 Compute overall weighted loss.
        loss = self.conf_loss_weight * conf_loss + \
            self.loc_loss_weight * loc_loss + 0.4 * lmk_loss
        # reshape to [N, Np], N is the batch size and Np is the prior box number.
        loss = layers.reshape(x=loss, shape=(-1, 0), actual_shape=actual_shape)
        loss = layers.reduce_sum(loss, dim=1, keep_dim=True)
        if self.normalize:
            normalizer = layers.reduce_sum(target_loc_weight) + 1
            loss = loss / normalizer

        return loss

    def decode_lmk(self, lmk_label, prior_box, prior_box_var):
        # Note: despite its name, this helper *encodes* the ground-truth
        # landmark points into prior-relative offsets: each point is
        # duplicated into a degenerate box so box_coder's center-size
        # encoding can be reused, then the duplicate half is split off.
        label0, label1, label2, label3, label4 = fluid.layers.split(
            lmk_label, num_or_sections=5, dim=1)
        lmk_labels_list = [label0, label1, label2, label3, label4]
        encoded_lmk_list = []
        for label in lmk_labels_list:
            concat_label = fluid.layers.concat([label, label], axis=1)
            encoded_label = box_coder(
                prior_box=prior_box,
                prior_box_var=prior_box_var,
                target_box=concat_label,
                code_type='encode_center_size')
            encoded_lmk_label, _ = fluid.layers.split(
                encoded_label, num_or_sections=2, dim=2)
            encoded_lmk_list.append(encoded_lmk_label)

        encoded_lmk_concat = fluid.layers.concat(
            [
                encoded_lmk_list[0], encoded_lmk_list[1], encoded_lmk_list[2],
                encoded_lmk_list[3], encoded_lmk_list[4]
            ],
            axis=2)
        return encoded_lmk_concat
...
@@ -1478,7 +1478,8 @@ class SSDOutputDecoder(object):
                  keep_top_k=200,
                  score_threshold=0.01,
                  nms_eta=1.0,
-                 background_label=0):
+                 background_label=0,
+                 return_index=False):
         super(SSDOutputDecoder, self).__init__()
         self.nms_threshold = nms_threshold
         self.background_label = background_label
@@ -1486,6 +1487,7 @@ class SSDOutputDecoder(object):
         self.keep_top_k = keep_top_k
         self.score_threshold = score_threshold
         self.nms_eta = nms_eta
+        self.return_index = return_index


 @register
...
...
@@ -115,7 +115,8 @@ def load_params(exe, prog, path, ignore_params=[]):
         path = _get_weight_path(path)

     path = _strip_postfix(path)
-    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
+    if not (os.path.isdir(path) or os.path.isfile(path) or
+            os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
...
...
@@ -30,7 +30,8 @@ def visualize_results(image,
                       catid2name,
                       threshold=0.5,
                       bbox_results=None,
-                      mask_results=None):
+                      mask_results=None,
+                      lmk_results=None):
     """
     Visualize bbox and mask results
     """
@@ -38,6 +39,8 @@ def visualize_results(image,
         image = draw_mask(image, im_id, mask_results, threshold)
     if bbox_results:
         image = draw_bbox(image, im_id, catid2name, bbox_results, threshold)
+    if lmk_results:
+        image = draw_lmk(image, im_id, lmk_results, threshold)
     return image
@@ -106,3 +109,21 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold):
         draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))

     return image
+
+
+def draw_lmk(image, im_id, lmk_results, threshold):
+    draw = ImageDraw.Draw(image)
+    catid2color = {}
+    color_list = colormap(rgb=True)[:40]
+    for dt in np.array(lmk_results):
+        lmk_decode, score = dt['landmark'], dt['score']
+        if im_id != dt['image_id']:
+            continue
+        if score < threshold:
+            continue
+        for j in range(5):
+            x1 = int(round(lmk_decode[2 * j]))
+            y1 = int(round(lmk_decode[2 * j + 1]))
+            draw.ellipse(
+                (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green')
+    return image
...
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 __all__ = [
     'get_shrink', 'bbox_vote', 'save_widerface_bboxes', 'save_fddb_bboxes',
-    'to_chw_bgr', 'bbox2out', 'get_category_info'
+    'to_chw_bgr', 'bbox2out', 'get_category_info', 'lmk2out'
 ]
@@ -227,3 +227,58 @@ def widerfaceall_category_info(with_background=True):
     catid2name = {i: name for i, name in enumerate(cats)}
     return clsid2catid, catid2name
+
+
+def lmk2out(results, is_bbox_normalized=False):
+    """
+    Args:
+        results: a list of result dicts; each should include `landmark` and
+            `im_id`, and, if is_bbox_normalized=True, also `im_shape`.
+        is_bbox_normalized: whether or not the landmarks are normalized.
+    """
+    xywh_res = []
+    for t in results:
+        bboxes = t['bbox'][0]
+        lengths = t['bbox'][1][0]
+        im_ids = np.array(t['im_id'][0]).flatten()
+        if bboxes.shape == (1, 1) or bboxes is None:
+            continue
+        face_index = t['face_index'][0]
+        prior_box = t['prior_boxes'][0]
+        predict_lmk = t['landmark'][0]
+        prior = np.reshape(prior_box, (-1, 4))
+        predictlmk = np.reshape(predict_lmk, (-1, 10))
+
+        k = 0
+        for a in range(len(lengths)):
+            num = lengths[a]
+            im_id = int(im_ids[a])
+            for i in range(num):
+                score = bboxes[k][1]
+                theindex = face_index[i][0]
+                me_prior = prior[theindex, :]
+                lmk_pred = predictlmk[theindex, :]
+                prior_w = me_prior[2] - me_prior[0]
+                prior_h = me_prior[3] - me_prior[1]
+                prior_w_center = (me_prior[2] + me_prior[0]) / 2
+                prior_h_center = (me_prior[3] + me_prior[1]) / 2
+                lmk_decode = np.zeros((10))
+                for j in [0, 2, 4, 6, 8]:
+                    lmk_decode[j] = lmk_pred[j] * 0.1 * prior_w + prior_w_center
+                for j in [1, 3, 5, 7, 9]:
+                    lmk_decode[j] = lmk_pred[j] * 0.1 * prior_h + prior_h_center
+                im_shape = t['im_shape'][0][a].tolist()
+                image_h, image_w = int(im_shape[0]), int(im_shape[1])
+                if is_bbox_normalized:
+                    lmk_decode = lmk_decode * np.array([
+                        image_w, image_h, image_w, image_h, image_w, image_h,
+                        image_w, image_h, image_w, image_h
+                    ])
+                lmk_res = {
+                    'image_id': im_id,
+                    'landmark': lmk_decode,
+                    'score': score,
+                }
+                xywh_res.append(lmk_res)
+                k += 1
+    return xywh_res
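
The decode arithmetic above inverts the center-size encoding used by `SSDWithLmkLoss.decode_lmk`: each predicted offset is scaled by 0.1 (apparently the prior-box variance) times the matched prior's width or height, then shifted by the prior's center. A numpy round-trip sketch with made-up numbers:

```python
import numpy as np

prior = np.array([0.30, 0.40, 0.50, 0.60])         # hypothetical prior, xyxy
pw, ph = prior[2] - prior[0], prior[3] - prior[1]  # 0.2, 0.2
cx, cy = (prior[0] + prior[2]) / 2, (prior[1] + prior[3]) / 2

lmk_gt = np.array([0.42, 0.55])                    # one ground-truth point
t = (lmk_gt - [cx, cy]) / [pw, ph] / 0.1           # encode -> [1.0, 2.5]
lmk_decode = t * 0.1 * [pw, ph] + [cx, cy]         # decode, as in lmk2out
assert np.allclose(lmk_decode, lmk_gt)
```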
...
@@ -144,7 +144,7 @@ def main():
     if cfg.metric == "VOC":
         from ppdet.utils.voc_eval import bbox2out, get_category_info
     if cfg.metric == "WIDERFACE":
-        from ppdet.utils.widerface_eval_utils import bbox2out, get_category_info
+        from ppdet.utils.widerface_eval_utils import bbox2out, lmk2out, get_category_info

     anno_file = dataset.get_anno()
     with_background = dataset.with_background
@@ -181,11 +181,14 @@ def main():

         bbox_results = None
         mask_results = None
+        lmk_results = None
         if 'bbox' in res:
             bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
         if 'mask' in res:
             mask_results = mask2out([res], clsid2catid,
                                     model.mask_head.resolution)
+        if 'landmark' in res:
+            lmk_results = lmk2out([res], is_bbox_normalized)

         # visualize result
         im_ids = res['im_id'][0]
@@ -203,7 +206,7 @@ def main():
                 image = visualize_results(image,
                                           int(im_id), catid2name,
                                           FLAGS.draw_threshold, bbox_results,
-                                          mask_results)
+                                          mask_results, lmk_results)

                 # use VisualDL to log image with bbox
                 if FLAGS.use_vdl:
@@ -253,4 +256,4 @@ if __name__ == '__main__':
         default="vdl_log_dir/image",
         help='VisualDL logging directory for image.')
     FLAGS = parser.parse_args()
-    main()
\ No newline at end of file
+    main()