未验证 提交 37962dcb 编写于 作者: Z zhiboniu 提交者: GitHub

add DarkPsoe support (#3341)

* add DarkPsoe support

* modify Top-Down bbox_file str to bbox.json
上级 0b763b3d
......@@ -35,6 +35,7 @@
​ 目前KeyPoint模型基于coco数据集开发,其他数据集尚未验证
​ 请参考PaddleDetection[数据准备部分](https://github.com/PaddlePaddle/PaddleDetection/blob/f0a30f3ba6095ebfdc8fffb6d02766406afc438a/docs/tutorials/PrepareDataSet.md)部署准备COCO数据集即可
请注意,Top-Down方案使用检测框测试时,需要给予检测模型生成bbox.json文件,或者从网上[下载地址](https://paddledet.bj.bcebos.com/data/bbox.json)下载后放在根目录(PaddleDetection)下,然后修改config配置文件中use_gt_bbox: False后生效。然后正常执行测试命令即可。
### 3、训练与测试
......
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w32_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams
TopDownHRNet:
backbone: HRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 32
loss: KeyPointMSELoss
HRNet:
width: *width
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
#####optimizer
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
milestones: [170, 200]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!KeypointTopDownCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
EvalDataset:
!KeypointTopDownCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
bbox_file: bbox.json
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
image_thre: 0.0
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomFlipHalfBodyTransform:
scale: 0.5
rot: 40
num_joints_half_body: 8
prob_half_body: 0.3
pixel_std: *pixel_std
trainsize: *trainsize
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
flip_pairs: *flip_perm
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown_DARK:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: false
EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown_DARK:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
drop_empty: false
TestReader:
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w48_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W48_C_pretrained.pdparams
TopDownHRNet:
backbone: HRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 48
loss: KeyPointMSELoss
HRNet:
width: *width
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
#####optimizer
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
milestones: [170, 200]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!KeypointTopDownCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
EvalDataset:
!KeypointTopDownCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
bbox_file: bbox.json
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
image_thre: 0.0
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomFlipHalfBodyTransform:
scale: 0.5
rot: 40
num_joints_half_body: 8
prob_half_body: 0.3
pixel_std: *pixel_std
trainsize: *trainsize
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
flip_pairs: *flip_perm
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown_DARK:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: false
EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown_DARK:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
drop_empty: false
TestReader:
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
......@@ -73,7 +73,7 @@ EvalDataset:
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
bbox_file: bbox.json
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
......
......@@ -74,7 +74,7 @@ EvalDataset:
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
bbox_file: bbox.json
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
......
......@@ -68,7 +68,9 @@ def affine_backto_orgimages(keypoint_result, batch_records):
def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
for i, img_file in enumerate(image_list):
image, _ = decode_image(img_file, {})
results = detector.predict(image, FLAGS.det_threshold)
results = detector.predict([image], FLAGS.det_threshold)
if results['boxes_num'] == 0:
continue
batchs_images, det_rects = get_person_from_rect(image, results)
keypoint_vector = []
score_vector = []
......@@ -121,7 +123,7 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
print('detect frame:%d' % (index))
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict(frame2, FLAGS.det_threshold)
results = detector.predict([frame2], FLAGS.det_threshold)
batchs_images, rect_vecotr = get_person_from_rect(frame2, results)
keypoint_vector = []
score_vector = []
......
......@@ -39,7 +39,8 @@ registered_ops = []
__all__ = [
'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
'TopDownAffine', 'ToHeatmapsTopDown', 'TopDownEvalAffine'
'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK',
'TopDownEvalAffine'
]
......@@ -393,6 +394,9 @@ class ToHeatmaps(object):
dul = np.clip(ul, 0, hmsize - 1)
dbr = np.clip(br, 0, hmsize)
for i in range(len(visible)):
if visible[i][0] < 0 or visible[i][1] < 0 or visible[i][
0] >= hmsize or visible[i][1] >= hmsize:
continue
dx1, dy1 = dul[i]
dx2, dy2 = dbr[i]
sx1, sy1 = sul[i]
......@@ -551,13 +555,16 @@ class TopDownAffine(object):
rot = records['rotate'] if "rotate" in records else 0
trans = get_affine_transform(records['center'], records['scale'] * 200,
rot, self.trainsize)
trans_joint = get_affine_transform(
records['center'], records['scale'] * 200, rot,
[self.trainsize[0] / 4, self.trainsize[1] / 4])
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
for i in range(joints.shape[0]):
if joints_vis[i, 0] > 0.0:
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans_joint)
records['image'] = image
records['joints'] = joints
......@@ -628,8 +635,8 @@ class ToHeatmapsTopDown(object):
tmp_size = self.sigma * 3
for joint_id in range(num_joints):
feat_stride = image_size / self.hmsize
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
mu_x = int(joints[joint_id][0] + 0.5)
mu_y = int(joints[joint_id][1] + 0.5)
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
......@@ -662,3 +669,58 @@ class ToHeatmapsTopDown(object):
del records['joints'], records['joints_vis']
return records
@register_keypointop
class ToHeatmapsTopDown_DARK(object):
"""to generate the gaussin heatmaps of keypoint for heatmap loss
Args:
hmsize (list): [w, h] output heatmap's size
sigma (float): the std of gaussin kernel genereted
records(dict): the dict contained the image and coords
Returns:
records (dict): contain the heatmaps used to heatmaploss
"""
def __init__(self, hmsize, sigma):
super(ToHeatmapsTopDown_DARK, self).__init__()
self.hmsize = np.array(hmsize)
self.sigma = sigma
def __call__(self, records):
joints = records['joints']
joints_vis = records['joints_vis']
num_joints = joints.shape[0]
target_weight = np.ones((num_joints, 1), dtype=np.float32)
target_weight[:, 0] = joints_vis[:, 0]
target = np.zeros(
(num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
tmp_size = self.sigma * 3
for joint_id in range(num_joints):
mu_x = joints[joint_id][0]
mu_y = joints[joint_id][1]
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
0] < 0 or br[1] < 0:
# If not, just return the image as is
target_weight[joint_id] = 0
continue
x = np.arange(0, self.hmsize[0], 1, np.float32)
y = np.arange(0, self.hmsize[1], 1, np.float32)
y = y[:, np.newaxis]
v = target_weight[joint_id]
if v > 0.5:
target[joint_id] = np.exp(-(
(x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2))
records['target'] = target
records['target_weight'] = target_weight
del records['joints'], records['joints_vis']
return records
......@@ -19,6 +19,7 @@ from __future__ import print_function
import paddle
import numpy as np
import math
import cv2
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
from ..keypoint_utils import transform_preds
......@@ -118,6 +119,9 @@ class TopDownHRNet(BaseArch):
class HRNetPostProcess(object):
def __init__(self, use_dark=True):
self.use_dark = use_dark
def get_max_preds(self, heatmaps):
'''get predictions from score maps
......@@ -154,7 +158,54 @@ class HRNetPostProcess(object):
return preds, maxvals
def get_final_preds(self, heatmaps, center, scale):
def gaussian_blur(self, heatmap, kernel):
border = (kernel - 1) // 2
batch_size = heatmap.shape[0]
num_joints = heatmap.shape[1]
height = heatmap.shape[2]
width = heatmap.shape[3]
for i in range(batch_size):
for j in range(num_joints):
origin_max = np.max(heatmap[i, j])
dr = np.zeros((height + 2 * border, width + 2 * border))
dr[border:-border, border:-border] = heatmap[i, j].copy()
dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
heatmap[i, j] = dr[border:-border, border:-border].copy()
heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
return heatmap
def dark_parse(self, hm, coord):
heatmap_height = hm.shape[0]
heatmap_width = hm.shape[1]
px = int(coord[0])
py = int(coord[1])
if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \
+ hm[py-1][px-1])
dyy = 0.25 * (
hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px])
derivative = np.matrix([[dx], [dy]])
hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
if dxx * dyy - dxy**2 != 0:
hessianinv = hessian.I
offset = -hessianinv * derivative
offset = np.squeeze(np.array(offset.T), axis=0)
coord += offset
return coord
def dark_postprocess(self, hm, coords, kernelsize):
hm = self.gaussian_blur(hm, kernelsize)
hm = np.maximum(hm, 1e-10)
hm = np.log(hm)
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
return coords
def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
"""the highest heatvalue location with a quarter offset in the
direction from the highest response to the second highest response.
......@@ -173,17 +224,20 @@ class HRNetPostProcess(object):
heatmap_height = heatmaps.shape[2]
heatmap_width = heatmaps.shape[3]
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
diff = np.array([
hm[py][px + 1] - hm[py][px - 1],
hm[py + 1][px] - hm[py - 1][px]
])
coords[n][p] += np.sign(diff) * .25
if self.use_dark:
coords = self.dark_postprocess(heatmaps, coords, kernelsize)
else:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
diff = np.array([
hm[py][px + 1] - hm[py][px - 1],
hm[py + 1][px] - hm[py - 1][px]
])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()
# Transform back
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册