Unverified commit 37962dcb, authored by zhiboniu, committed by GitHub

add DarkPose support (#3341)

* add DarkPose support

* modify Top-Down bbox_file str to bbox.json
Parent: 0b763b3d
@@ -35,6 +35,7 @@
 The KeyPoint models are currently developed on the COCO dataset; other datasets have not been verified yet.
 Please refer to the PaddleDetection [data preparation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/f0a30f3ba6095ebfdc8fffb6d02766406afc438a/docs/tutorials/PrepareDataSet.md) to prepare the COCO dataset.
+Note that when evaluating a Top-Down model with detection boxes, you must first generate a bbox.json file with a detection model, or download one from [this link](https://paddledet.bj.bcebos.com/data/bbox.json) and place it in the PaddleDetection root directory. Then set use_gt_bbox: False in the config file for it to take effect, and run the evaluation command as usual.
 ### 3. Training and Testing
...
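For reference, the bbox.json consumed here follows the standard COCO detection-results layout (a JSON list of per-box records); one hypothetical entry, sketched as a Python literal:

```python
# One hypothetical entry of bbox.json (standard COCO detection-results
# format; all field values below are made up for illustration).
bbox_entry = {
    "image_id": 397133,                    # COCO val2017 image id
    "category_id": 1,                      # 1 = person
    "bbox": [388.7, 69.9, 109.4, 277.8],   # [x, y, width, height]
    "score": 0.99,                         # detector confidence
}
```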
##### new config file (W32 variant): TopDownHRNet with DARK target heatmaps, input 256x192
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w32_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]

#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams

TopDownHRNet:
  backbone: HRNet
  post_process: HRNetPostProcess
  flip_perm: *flip_perm
  num_joints: *num_joints
  width: &width 32
  loss: KeyPointMSELoss

HRNet:
  width: *width
  freeze_at: -1
  freeze_norm: false
  return_idx: [0]

KeyPointMSELoss:
  use_target_weight: true

#####optimizer
LearningRate:
  base_lr: 0.0005
  schedulers:
  - !PiecewiseDecay
    milestones: [170, 200]
    gamma: 0.1
  - !LinearWarmup
    start_factor: 0.001
    steps: 1000

OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer:
    factor: 0.0
    type: L2

#####data
TrainDataset:
  !KeypointTopDownCocoDataset
    image_dir: train2017
    anno_path: annotations/person_keypoints_train2017.json
    dataset_dir: dataset/coco
    num_joints: *num_joints
    trainsize: *trainsize
    pixel_std: *pixel_std
    use_gt_bbox: True

EvalDataset:
  !KeypointTopDownCocoDataset
    image_dir: val2017
    anno_path: annotations/person_keypoints_val2017.json
    dataset_dir: dataset/coco
    bbox_file: bbox.json
    num_joints: *num_joints
    trainsize: *trainsize
    pixel_std: *pixel_std
    use_gt_bbox: True
    image_thre: 0.0

TestDataset:
  !ImageFolder
    anno_path: dataset/coco/keypoint_imagelist.txt

worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]

TrainReader:
  sample_transforms:
    - RandomFlipHalfBodyTransform:
        scale: 0.5
        rot: 40
        num_joints_half_body: 8
        prob_half_body: 0.3
        pixel_std: *pixel_std
        trainsize: *trainsize
        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        flip_pairs: *flip_perm
    - TopDownAffine:
        trainsize: *trainsize
    - ToHeatmapsTopDown_DARK:
        hmsize: *hmsize
        sigma: 2
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 64
  shuffle: true
  drop_last: false

EvalReader:
  sample_transforms:
    - TopDownAffine:
        trainsize: *trainsize
    - ToHeatmapsTopDown_DARK:
        hmsize: *hmsize
        sigma: 2
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 16
  drop_empty: false

TestReader:
  sample_transforms:
    - Decode: {}
    - TopDownEvalAffine:
        trainsize: *trainsize
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 1
##### new config file (W48 variant): identical to the W32 config except for width and pretrained weights
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w48_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 192
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [48, 64]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]

#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W48_C_pretrained.pdparams

TopDownHRNet:
  backbone: HRNet
  post_process: HRNetPostProcess
  flip_perm: *flip_perm
  num_joints: *num_joints
  width: &width 48
  loss: KeyPointMSELoss

HRNet:
  width: *width
  freeze_at: -1
  freeze_norm: false
  return_idx: [0]

KeyPointMSELoss:
  use_target_weight: true

#####optimizer
LearningRate:
  base_lr: 0.0005
  schedulers:
  - !PiecewiseDecay
    milestones: [170, 200]
    gamma: 0.1
  - !LinearWarmup
    start_factor: 0.001
    steps: 1000

OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer:
    factor: 0.0
    type: L2

#####data
TrainDataset:
  !KeypointTopDownCocoDataset
    image_dir: train2017
    anno_path: annotations/person_keypoints_train2017.json
    dataset_dir: dataset/coco
    num_joints: *num_joints
    trainsize: *trainsize
    pixel_std: *pixel_std
    use_gt_bbox: True

EvalDataset:
  !KeypointTopDownCocoDataset
    image_dir: val2017
    anno_path: annotations/person_keypoints_val2017.json
    dataset_dir: dataset/coco
    bbox_file: bbox.json
    num_joints: *num_joints
    trainsize: *trainsize
    pixel_std: *pixel_std
    use_gt_bbox: True
    image_thre: 0.0

TestDataset:
  !ImageFolder
    anno_path: dataset/coco/keypoint_imagelist.txt

worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]

TrainReader:
  sample_transforms:
    - RandomFlipHalfBodyTransform:
        scale: 0.5
        rot: 40
        num_joints_half_body: 8
        prob_half_body: 0.3
        pixel_std: *pixel_std
        trainsize: *trainsize
        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        flip_pairs: *flip_perm
    - TopDownAffine:
        trainsize: *trainsize
    - ToHeatmapsTopDown_DARK:
        hmsize: *hmsize
        sigma: 2
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 64
  shuffle: true
  drop_last: false

EvalReader:
  sample_transforms:
    - TopDownAffine:
        trainsize: *trainsize
    - ToHeatmapsTopDown_DARK:
        hmsize: *hmsize
        sigma: 2
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 16
  drop_empty: false

TestReader:
  sample_transforms:
    - Decode: {}
    - TopDownEvalAffine:
        trainsize: *trainsize
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 1
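A note on flip_perm in both configs: it lists the left/right COCO keypoint index pairs that must be swapped whenever an image is mirrored (by RandomFlipHalfBodyTransform during training, or by flip testing in TopDownHRNet). A minimal sketch of that bookkeeping; flip_joints is a hypothetical helper for illustration, not ppdet's implementation:

```python
import numpy as np

# flip_perm from the configs above: COCO left/right keypoint pairs
# (indices into the 17 joints; 0 = nose has no mirror partner).
flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
             [11, 12], [13, 14], [15, 16]]

def flip_joints(joints, image_width):
    """Mirror joints horizontally and swap left/right channels."""
    flipped = joints.copy()
    flipped[:, 0] = image_width - 1 - flipped[:, 0]  # mirror x
    for a, b in flip_perm:
        flipped[[a, b]] = flipped[[b, a]]            # swap the pair
    return flipped

joints = np.zeros((17, 2))
joints[1] = [100.0, 40.0]            # e.g. the left eye
print(flip_joints(joints, 192)[2])   # -> [91. 40.]: now the right eye
```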
@@ -73,7 +73,7 @@ EvalDataset:
     image_dir: val2017
     anno_path: annotations/person_keypoints_val2017.json
     dataset_dir: dataset/coco
-    bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
+    bbox_file: bbox.json
     num_joints: *num_joints
     trainsize: *trainsize
     pixel_std: *pixel_std
...
@@ -74,7 +74,7 @@ EvalDataset:
     image_dir: val2017
     anno_path: annotations/person_keypoints_val2017.json
     dataset_dir: dataset/coco
-    bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
+    bbox_file: bbox.json
     num_joints: *num_joints
     trainsize: *trainsize
     pixel_std: *pixel_std
...
@@ -68,7 +68,9 @@ def affine_backto_orgimages(keypoint_result, batch_records):

 def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
     for i, img_file in enumerate(image_list):
         image, _ = decode_image(img_file, {})
-        results = detector.predict(image, FLAGS.det_threshold)
+        results = detector.predict([image], FLAGS.det_threshold)
+        if results['boxes_num'] == 0:
+            continue
         batchs_images, det_rects = get_person_from_rect(image, results)
         keypoint_vector = []
         score_vector = []
@@ -121,7 +123,7 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
         print('detect frame:%d' % (index))
         frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        results = detector.predict(frame2, FLAGS.det_threshold)
+        results = detector.predict([frame2], FLAGS.det_threshold)
         batchs_images, rect_vecotr = get_person_from_rect(frame2, results)
         keypoint_vector = []
         score_vector = []
...
@@ -39,7 +39,8 @@ registered_ops = []

 __all__ = [
     'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
     'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
-    'TopDownAffine', 'ToHeatmapsTopDown', 'TopDownEvalAffine'
+    'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK',
+    'TopDownEvalAffine'
 ]
@@ -393,6 +394,9 @@ class ToHeatmaps(object):
         dul = np.clip(ul, 0, hmsize - 1)
         dbr = np.clip(br, 0, hmsize)
         for i in range(len(visible)):
+            if visible[i][0] < 0 or visible[i][1] < 0 or visible[i][
+                    0] >= hmsize or visible[i][1] >= hmsize:
+                continue
             dx1, dy1 = dul[i]
             dx2, dy2 = dbr[i]
             sx1, sy1 = sul[i]
@@ -551,13 +555,16 @@ class TopDownAffine(object):
         rot = records['rotate'] if "rotate" in records else 0
         trans = get_affine_transform(records['center'], records['scale'] * 200,
                                      rot, self.trainsize)
+        trans_joint = get_affine_transform(
+            records['center'], records['scale'] * 200, rot,
+            [self.trainsize[0] / 4, self.trainsize[1] / 4])
         image = cv2.warpAffine(
             image,
             trans, (int(self.trainsize[0]), int(self.trainsize[1])),
             flags=cv2.INTER_LINEAR)
         for i in range(joints.shape[0]):
             if joints_vis[i, 0] > 0.0:
-                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
+                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans_joint)
         records['image'] = image
         records['joints'] = joints
@@ -628,8 +635,8 @@ class ToHeatmapsTopDown(object):
         tmp_size = self.sigma * 3
         for joint_id in range(num_joints):
             feat_stride = image_size / self.hmsize
-            mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
-            mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
+            mu_x = int(joints[joint_id][0] + 0.5)
+            mu_y = int(joints[joint_id][1] + 0.5)
             # Check that any part of the gaussian is in-bounds
             ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
             br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
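The two hunks above are one coordinated change: TopDownAffine now warps the joint coordinates with a second transform targeting a quarter of the input size, i.e. directly into heatmap space, so ToHeatmapsTopDown no longer divides by feat_stride. A quick sanity check of that relationship, using the values from the configs above:

```python
# trainsize and hmsize as set in the configs above, both as [w, h].
trainsize = [192, 256]
hmsize = [48, 64]
# The joint affine transform targets trainsize/4, which is exactly the
# heatmap grid, so joints land directly in heatmap coordinates and the
# Gaussian center can keep its sub-pixel position.
assert [trainsize[0] / 4, trainsize[1] / 4] == hmsize
```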
@@ -662,3 +669,58 @@ class ToHeatmapsTopDown(object):
         del records['joints'], records['joints_vis']

         return records
+
+
+@register_keypointop
+class ToHeatmapsTopDown_DARK(object):
+    """to generate the gaussian heatmaps of keypoint for heatmap loss
+
+    Args:
+        hmsize (list): [w, h] output heatmap's size
+        sigma (float): the std of the generated gaussian kernel
+        records (dict): the dict containing the image and coords
+
+    Returns:
+        records (dict): contain the heatmaps used for heatmap loss
+    """
+
+    def __init__(self, hmsize, sigma):
+        super(ToHeatmapsTopDown_DARK, self).__init__()
+        self.hmsize = np.array(hmsize)
+        self.sigma = sigma
+
+    def __call__(self, records):
+        joints = records['joints']
+        joints_vis = records['joints_vis']
+        num_joints = joints.shape[0]
+        target_weight = np.ones((num_joints, 1), dtype=np.float32)
+        target_weight[:, 0] = joints_vis[:, 0]
+        target = np.zeros(
+            (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
+        tmp_size = self.sigma * 3
+        for joint_id in range(num_joints):
+            mu_x = joints[joint_id][0]
+            mu_y = joints[joint_id][1]
+            # Check that any part of the gaussian is in-bounds
+            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+            if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
+                    0] < 0 or br[1] < 0:
+                # If not, just return the image as is
+                target_weight[joint_id] = 0
+                continue
+
+            x = np.arange(0, self.hmsize[0], 1, np.float32)
+            y = np.arange(0, self.hmsize[1], 1, np.float32)
+            y = y[:, np.newaxis]
+
+            v = target_weight[joint_id]
+            if v > 0.5:
+                target[joint_id] = np.exp(-(
+                    (x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2))
+        records['target'] = target
+        records['target_weight'] = target_weight
+        del records['joints'], records['joints_vis']
+
+        return records
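To make the new op concrete, here is a standalone sketch of the heatmap it generates, with the ppdet registration and records plumbing stripped away (the joint position is made up; shapes follow the class above):

```python
import numpy as np

# Hypothetical inputs: 17 COCO joints already mapped into the 48x64
# heatmap grid by TopDownAffine, plus per-joint visibility flags.
hmsize = np.array([48, 64])  # [w, h], as in the configs above
sigma = 2
joints = np.zeros((17, 2), dtype=np.float32)
joints[0] = [10.4, 25.6]     # one visible joint at a sub-pixel position
joints_vis = np.zeros((17, 1), dtype=np.float32)
joints_vis[0] = 1

target = np.zeros((17, hmsize[1], hmsize[0]), dtype=np.float32)
x = np.arange(hmsize[0], dtype=np.float32)
y = np.arange(hmsize[1], dtype=np.float32)[:, np.newaxis]
for j in range(17):
    if joints_vis[j, 0] > 0.5:
        # Unlike ToHeatmapsTopDown, the Gaussian center is not rounded
        # to an integer pixel, so the sub-pixel peak that DARK's
        # post-processing later recovers is preserved in the target.
        target[j] = np.exp(-((x - joints[j, 0])**2 + (y - joints[j, 1])**2)
                           / (2 * sigma**2))
print(target[0].max())  # close to 1.0, peaked near (10.4, 25.6)
```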
@@ -19,6 +19,7 @@ from __future__ import print_function
 import paddle
 import numpy as np
 import math
+import cv2
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
 from ..keypoint_utils import transform_preds
@@ -118,6 +119,9 @@ class TopDownHRNet(BaseArch):

 class HRNetPostProcess(object):
+    def __init__(self, use_dark=True):
+        self.use_dark = use_dark
+
     def get_max_preds(self, heatmaps):
         '''get predictions from score maps
@@ -154,7 +158,54 @@ class HRNetPostProcess(object):

         return preds, maxvals

-    def get_final_preds(self, heatmaps, center, scale):
+    def gaussian_blur(self, heatmap, kernel):
+        border = (kernel - 1) // 2
+        batch_size = heatmap.shape[0]
+        num_joints = heatmap.shape[1]
+        height = heatmap.shape[2]
+        width = heatmap.shape[3]
+        for i in range(batch_size):
+            for j in range(num_joints):
+                origin_max = np.max(heatmap[i, j])
+                dr = np.zeros((height + 2 * border, width + 2 * border))
+                dr[border:-border, border:-border] = heatmap[i, j].copy()
+                dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
+                heatmap[i, j] = dr[border:-border, border:-border].copy()
+                heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
+        return heatmap
+
+    def dark_parse(self, hm, coord):
+        heatmap_height = hm.shape[0]
+        heatmap_width = hm.shape[1]
+        px = int(coord[0])
+        py = int(coord[1])
+        if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
+            dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
+            dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
+            dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
+            dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \
+                + hm[py-1][px-1])
+            dyy = 0.25 * (
+                hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px])
+            derivative = np.matrix([[dx], [dy]])
+            hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
+            if dxx * dyy - dxy**2 != 0:
+                hessianinv = hessian.I
+                offset = -hessianinv * derivative
+                offset = np.squeeze(np.array(offset.T), axis=0)
+                coord += offset
+        return coord
+
+    def dark_postprocess(self, hm, coords, kernelsize):
+        hm = self.gaussian_blur(hm, kernelsize)
+        hm = np.maximum(hm, 1e-10)
+        hm = np.log(hm)
+        for n in range(coords.shape[0]):
+            for p in range(coords.shape[1]):
+                coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
+        return coords
+
+    def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
         """the highest heatvalue location with a quarter offset in the
         direction from the highest response to the second highest response.
@@ -173,17 +224,20 @@ class HRNetPostProcess(object):
         heatmap_height = heatmaps.shape[2]
         heatmap_width = heatmaps.shape[3]

-        for n in range(coords.shape[0]):
-            for p in range(coords.shape[1]):
-                hm = heatmaps[n][p]
-                px = int(math.floor(coords[n][p][0] + 0.5))
-                py = int(math.floor(coords[n][p][1] + 0.5))
-                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
-                    diff = np.array([
-                        hm[py][px + 1] - hm[py][px - 1],
-                        hm[py + 1][px] - hm[py - 1][px]
-                    ])
-                    coords[n][p] += np.sign(diff) * .25
+        if self.use_dark:
+            coords = self.dark_postprocess(heatmaps, coords, kernelsize)
+        else:
+            for n in range(coords.shape[0]):
+                for p in range(coords.shape[1]):
+                    hm = heatmaps[n][p]
+                    px = int(math.floor(coords[n][p][0] + 0.5))
+                    py = int(math.floor(coords[n][p][1] + 0.5))
+                    if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
+                        diff = np.array([
+                            hm[py][px + 1] - hm[py][px - 1],
+                            hm[py + 1][px] - hm[py - 1][px]
+                        ])
+                        coords[n][p] += np.sign(diff) * .25

         preds = coords.copy()

         # Transform back
...
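To see why the added methods recover sub-pixel peaks: dark_postprocess blurs each heatmap and takes its log, and dark_parse then applies one Newton step, coord ← coord − H⁻¹∇, with finite-difference estimates of the gradient ∇ and Hessian H of the log-heatmap (the DARK paper's Taylor-expansion argument). A self-contained numpy check on a synthetic Gaussian, where the log is exactly quadratic and the step is exact (the blur is skipped because the synthetic map is already smooth):

```python
import numpy as np

# Synthetic 64x48 heatmap with a sub-pixel peak; for an ideal Gaussian
# the log-heatmap is exactly quadratic, so one Newton step recovers the
# true peak from the integer argmax.
sigma, h, w = 2.0, 64, 48
true_xy = np.array([20.3, 33.7])                 # made-up ground truth
ys, xs = np.mgrid[0:h, 0:w]
hm = np.exp(-((xs - true_xy[0])**2 + (ys - true_xy[1])**2) / (2 * sigma**2))

# Integer argmax in (x, y) order, as get_max_preds would return it.
py, px = np.unravel_index(hm.argmax(), hm.shape)
coord = np.array([px, py], dtype=np.float64)

# Finite-difference gradient and Hessian of log(hm), as in dark_parse.
g = np.log(np.maximum(hm, 1e-10))
dx = 0.5 * (g[py, px + 1] - g[py, px - 1])
dy = 0.5 * (g[py + 1, px] - g[py - 1, px])
dxx = 0.25 * (g[py, px + 2] - 2 * g[py, px] + g[py, px - 2])
dyy = 0.25 * (g[py + 2, px] - 2 * g[py, px] + g[py - 2, px])
dxy = 0.25 * (g[py + 1, px + 1] - g[py - 1, px + 1]
              - g[py + 1, px - 1] + g[py - 1, px - 1])

grad = np.array([dx, dy])
hess = np.array([[dxx, dxy], [dxy, dyy]])
if dxx * dyy - dxy**2 != 0:                  # Hessian invertible
    coord -= np.linalg.solve(hess, grad)     # Newton step: -H^-1 * grad
print(coord)  # ~[20.3, 33.7]: sub-pixel accurate, vs [20, 34] from argmax
```

With use_dark=True this replaces the old quarter-pixel shift toward the gradient sign, which could only correct the argmax by a fixed ±0.25 of a heatmap pixel; the Newton step yields a continuous offset instead.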