From 37962dcb40acd4dc613ee64d84a903253b4ec054 Mon Sep 17 00:00:00 2001
From: zhiboniu <31800336+zhiboniu@users.noreply.github.com>
Date: Thu, 10 Jun 2021 21:15:22 +0800
Subject: [PATCH] add DarkPose support (#3341)

* add DarkPose support

* modify Top-Down bbox_file string to bbox.json
---
 configs/keypoint/README.md                    |   1 +
 .../keypoint/hrnet/dark_hrnet_w32_256x192.yml | 143 ++++++++++++++++++
 .../keypoint/hrnet/dark_hrnet_w48_256x192.yml | 143 ++++++++++++++++++
 configs/keypoint/hrnet/hrnet_w32_256x192.yml  |   2 +-
 configs/keypoint/hrnet/hrnet_w32_384x288.yml  |   2 +-
 deploy/python/keypoint_det_unite_infer.py     |   6 +-
 ppdet/data/transform/keypoint_operators.py    |  70 ++++++++-
 .../modeling/architectures/keypoint_hrnet.py  |  78 ++++++++--
 8 files changed, 425 insertions(+), 20 deletions(-)
 create mode 100644 configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml
 create mode 100644 configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml

diff --git a/configs/keypoint/README.md b/configs/keypoint/README.md
index b4285bc9b..246a3bf54 100644
--- a/configs/keypoint/README.md
+++ b/configs/keypoint/README.md
@@ -35,6 +35,7 @@
 Currently the KeyPoint models are developed on the COCO dataset; other datasets have not been verified yet.
 Please refer to the PaddleDetection [data preparation docs](https://github.com/PaddlePaddle/PaddleDetection/blob/f0a30f3ba6095ebfdc8fffb6d02766406afc438a/docs/tutorials/PrepareDataSet.md) to prepare the COCO dataset.
+Note: when testing the Top-Down pipeline with detected boxes, you need a bbox.json file generated by a detection model, or you can download it from [this link](https://paddledet.bj.bcebos.com/data/bbox.json) and place it in the PaddleDetection root directory; it takes effect after setting use_gt_bbox: False in the config file. Then run the test command as usual.
 
 ### 3. Training and Testing
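For reference, once bbox.json is in place and use_gt_bbox is set to False, an evaluation run with the new DarkPose config would look roughly like this (a sketch: tools/eval.py is PaddleDetection's standard evaluation entry point, while the weights path below is illustrative and assumes a locally trained checkpoint):

    python tools/eval.py -c configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml \
        -o weights=output/dark_hrnet_w32_256x192/model_final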
diff --git a/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml b/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml
new file mode 100644
index 000000000..5cc59fb2c
--- /dev/null
+++ b/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml
@@ -0,0 +1,143 @@
+use_gpu: true
+log_iter: 5
+save_dir: output
+snapshot_epoch: 10
+weights: output/hrnet_w32_256x192/model_final
+epoch: 210
+num_joints: &num_joints 17
+pixel_std: &pixel_std 200
+metric: KeyPointTopDownCOCOEval
+num_classes: 1
+train_height: &train_height 256
+train_width: &train_width 192
+trainsize: &trainsize [*train_width, *train_height]
+hmsize: &hmsize [48, 64]
+flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
+
+
+#####model
+architecture: TopDownHRNet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams
+
+TopDownHRNet:
+  backbone: HRNet
+  post_process: HRNetPostProcess
+  flip_perm: *flip_perm
+  num_joints: *num_joints
+  width: &width 32
+  loss: KeyPointMSELoss
+
+HRNet:
+  width: *width
+  freeze_at: -1
+  freeze_norm: false
+  return_idx: [0]
+
+KeyPointMSELoss:
+  use_target_weight: true
+
+
+#####optimizer
+LearningRate:
+  base_lr: 0.0005
+  schedulers:
+  - !PiecewiseDecay
+    milestones: [170, 200]
+    gamma: 0.1
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 1000
+
+OptimizerBuilder:
+  optimizer:
+    type: Adam
+  regularizer:
+    factor: 0.0
+    type: L2
+
+
+#####data
+TrainDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: train2017
+    anno_path: annotations/person_keypoints_train2017.json
+    dataset_dir: dataset/coco
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+
+
+EvalDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: val2017
+    anno_path: annotations/person_keypoints_val2017.json
+    dataset_dir: dataset/coco
+    bbox_file: bbox.json
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+    image_thre: 0.0
+
+
+TestDataset:
+  !ImageFolder
+    anno_path: dataset/coco/keypoint_imagelist.txt
+
+worker_num: 2
+global_mean: &global_mean [0.485, 0.456, 0.406]
+global_std: &global_std [0.229, 0.224, 0.225]
+TrainReader:
+  sample_transforms:
+    - RandomFlipHalfBodyTransform:
+        scale: 0.5
+        rot: 40
+        num_joints_half_body: 8
+        prob_half_body: 0.3
+        pixel_std: *pixel_std
+        trainsize: *trainsize
+        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        flip_pairs: *flip_perm
+    - TopDownAffine:
+        trainsize: *trainsize
+    - ToHeatmapsTopDown_DARK:
+        hmsize: *hmsize
+        sigma: 2
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 64
+  shuffle: true
+  drop_last: false
+
+EvalReader:
+  sample_transforms:
+    - TopDownAffine:
+        trainsize: *trainsize
+    - ToHeatmapsTopDown_DARK:
+        hmsize: *hmsize
+        sigma: 2
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 16
+  drop_empty: false
+
+TestReader:
+  sample_transforms:
+    - Decode: {}
+    - TopDownEvalAffine:
+        trainsize: *trainsize
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 1
diff --git a/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml b/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml
new file mode 100644
index 000000000..901c0ad83
--- /dev/null
+++ b/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml
@@ -0,0 +1,143 @@
+use_gpu: true
+log_iter: 5
+save_dir: output
+snapshot_epoch: 10
+weights: output/hrnet_w48_256x192/model_final
+epoch: 210
+num_joints: &num_joints 17
+pixel_std: &pixel_std 200
+metric: KeyPointTopDownCOCOEval
+num_classes: 1
+train_height: &train_height 256
+train_width: &train_width 192
+trainsize: &trainsize [*train_width, *train_height]
+hmsize: &hmsize [48, 64]
+flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
+
+
+#####model
+architecture: TopDownHRNet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W48_C_pretrained.pdparams
+
+TopDownHRNet:
+  backbone: HRNet
+  post_process: HRNetPostProcess
+  flip_perm: *flip_perm
+  num_joints: *num_joints
+  width: &width 48
+  loss: KeyPointMSELoss
+
+HRNet:
+  width: *width
+  freeze_at: -1
+  freeze_norm: false
+  return_idx: [0]
+
+KeyPointMSELoss:
+  use_target_weight: true
+
+
+#####optimizer
+LearningRate:
+  base_lr: 0.0005
+  schedulers:
+  - !PiecewiseDecay
+    milestones: [170, 200]
+    gamma: 0.1
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 1000
+
+OptimizerBuilder:
+  optimizer:
+    type: Adam
+  regularizer:
+    factor: 0.0
+    type: L2
+
+
+#####data
+TrainDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: train2017
+    anno_path: annotations/person_keypoints_train2017.json
+    dataset_dir: dataset/coco
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+
+
+EvalDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: val2017
+    anno_path: annotations/person_keypoints_val2017.json
+    dataset_dir: dataset/coco
+    bbox_file: bbox.json
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+    image_thre: 0.0
+
+
+TestDataset:
+  !ImageFolder
+    anno_path: dataset/coco/keypoint_imagelist.txt
+
+worker_num: 2
+global_mean: &global_mean [0.485, 0.456, 0.406]
+global_std: &global_std [0.229, 0.224, 0.225]
+TrainReader:
+  sample_transforms:
+    - RandomFlipHalfBodyTransform:
+        scale: 0.5
+        rot: 40
+        num_joints_half_body: 8
+        prob_half_body: 0.3
+        pixel_std: *pixel_std
+        trainsize: *trainsize
+        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        flip_pairs: *flip_perm
+    - TopDownAffine:
+        trainsize: *trainsize
+    - ToHeatmapsTopDown_DARK:
+        hmsize: *hmsize
+        sigma: 2
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 64
+  shuffle: true
+  drop_last: false
+
+EvalReader:
+  sample_transforms:
+    - TopDownAffine:
+        trainsize: *trainsize
+    - ToHeatmapsTopDown_DARK:
+        hmsize: *hmsize
+        sigma: 2
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 16
+  drop_empty: false
+
+TestReader:
+  sample_transforms:
+    - Decode: {}
+    - TopDownEvalAffine:
+        trainsize: *trainsize
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 1
diff --git a/configs/keypoint/hrnet/hrnet_w32_256x192.yml b/configs/keypoint/hrnet/hrnet_w32_256x192.yml
index 4a63617de..50a838b9a 100644
--- a/configs/keypoint/hrnet/hrnet_w32_256x192.yml
+++ b/configs/keypoint/hrnet/hrnet_w32_256x192.yml
@@ -73,7 +73,7 @@ EvalDataset:
     image_dir: val2017
     anno_path: annotations/person_keypoints_val2017.json
     dataset_dir: dataset/coco
-    bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
+    bbox_file: bbox.json
     num_joints: *num_joints
     trainsize: *trainsize
     pixel_std: *pixel_std
diff --git a/configs/keypoint/hrnet/hrnet_w32_384x288.yml b/configs/keypoint/hrnet/hrnet_w32_384x288.yml
index b7240ee65..72dc8120f 100644
--- a/configs/keypoint/hrnet/hrnet_w32_384x288.yml
+++ b/configs/keypoint/hrnet/hrnet_w32_384x288.yml
@@ -74,7 +74,7 @@ EvalDataset:
     image_dir: val2017
     anno_path: annotations/person_keypoints_val2017.json
     dataset_dir: dataset/coco
-    bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
+    bbox_file: bbox.json
     num_joints: *num_joints
     trainsize: *trainsize
     pixel_std: *pixel_std
diff --git a/deploy/python/keypoint_det_unite_infer.py b/deploy/python/keypoint_det_unite_infer.py
index d0321873d..8dcd92c81 100644
--- a/deploy/python/keypoint_det_unite_infer.py
+++ b/deploy/python/keypoint_det_unite_infer.py
@@ -68,7 +68,9 @@ def affine_backto_orgimages(keypoint_result, batch_records):
 
 def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
     for i, img_file in enumerate(image_list):
         image, _ = decode_image(img_file, {})
-        results = detector.predict(image, FLAGS.det_threshold)
+        results = detector.predict([image], FLAGS.det_threshold)
+        if results['boxes_num'] == 0:
+            continue
         batchs_images, det_rects = get_person_from_rect(image, results)
         keypoint_vector = []
         score_vector = []
@@ -121,7 +123,7 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector,
                                 camera_id):
         print('detect frame:%d' % (index))
         frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        results = detector.predict(frame2, FLAGS.det_threshold)
+        results = detector.predict([frame2], FLAGS.det_threshold)
         batchs_images, rect_vecotr = get_person_from_rect(frame2, results)
         keypoint_vector = []
         score_vector = []
diff --git a/ppdet/data/transform/keypoint_operators.py b/ppdet/data/transform/keypoint_operators.py
index 405ba52fe..f122e118f 100644
--- a/ppdet/data/transform/keypoint_operators.py
+++ b/ppdet/data/transform/keypoint_operators.py
@@ -39,7 +39,8 @@ registered_ops = []
 __all__ = [
     'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
     'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
-    'TopDownAffine', 'ToHeatmapsTopDown', 'TopDownEvalAffine'
+    'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK',
+    'TopDownEvalAffine'
 ]
 
 
@@ -393,6 +394,9 @@ class ToHeatmaps(object):
         dul = np.clip(ul, 0, hmsize - 1)
         dbr = np.clip(br, 0, hmsize)
         for i in range(len(visible)):
+            if visible[i][0] < 0 or visible[i][1] < 0 or visible[i][
+                    0] >= hmsize or visible[i][1] >= hmsize:
+                continue
             dx1, dy1 = dul[i]
             dx2, dy2 = dbr[i]
             sx1, sy1 = sul[i]
@@ -551,13 +555,16 @@ class TopDownAffine(object):
         rot = records['rotate'] if "rotate" in records else 0
         trans = get_affine_transform(records['center'], records['scale'] * 200,
                                      rot, self.trainsize)
+        trans_joint = get_affine_transform(
+            records['center'], records['scale'] * 200, rot,
+            [self.trainsize[0] / 4, self.trainsize[1] / 4])
         image = cv2.warpAffine(
             image,
             trans, (int(self.trainsize[0]), int(self.trainsize[1])),
             flags=cv2.INTER_LINEAR)
         for i in range(joints.shape[0]):
             if joints_vis[i, 0] > 0.0:
-                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
+                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans_joint)
 
         records['image'] = image
         records['joints'] = joints
@@ -628,8 +635,8 @@ class ToHeatmapsTopDown(object):
         tmp_size = self.sigma * 3
         for joint_id in range(num_joints):
             feat_stride = image_size / self.hmsize
-            mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
-            mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
+            mu_x = int(joints[joint_id][0] + 0.5)
+            mu_y = int(joints[joint_id][1] + 0.5)
             # Check that any part of the gaussian is in-bounds
             ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
             br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
@@ -662,3 +669,58 @@ class ToHeatmapsTopDown(object):
 
         del records['joints'], records['joints_vis']
         return records
+
+
+@register_keypointop
+class ToHeatmapsTopDown_DARK(object):
+    """Generate the gaussian heatmaps of keypoints for heatmap loss.
+
+    Args:
+        hmsize (list): [w, h] output heatmap's size
+        sigma (float): the std of the generated gaussian kernel
+        records (dict): the dict containing the image and coords
+
+    Returns:
+        records (dict): contains the heatmaps used for heatmap loss
+
+    """
+
+    def __init__(self, hmsize, sigma):
+        super(ToHeatmapsTopDown_DARK, self).__init__()
+        self.hmsize = np.array(hmsize)
+        self.sigma = sigma
+
+    def __call__(self, records):
+        joints = records['joints']
+        joints_vis = records['joints_vis']
+        num_joints = joints.shape[0]
+        target_weight = np.ones((num_joints, 1), dtype=np.float32)
+        target_weight[:, 0] = joints_vis[:, 0]
+        target = np.zeros(
+            (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
+        tmp_size = self.sigma * 3
+        for joint_id in range(num_joints):
+            mu_x = joints[joint_id][0]
+            mu_y = joints[joint_id][1]
+            # Check that any part of the gaussian is in-bounds
+            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+            if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
+                    0] < 0 or br[1] < 0:
+                # If not, just return the image as is
+                target_weight[joint_id] = 0
+                continue
+
+            x = np.arange(0, self.hmsize[0], 1, np.float32)
+            y = np.arange(0, self.hmsize[1], 1, np.float32)
+            y = y[:, np.newaxis]
+
+            v = target_weight[joint_id]
+            if v > 0.5:
+                target[joint_id] = np.exp(-(
+                    (x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2))
+        records['target'] = target
+        records['target_weight'] = target_weight
+        del records['joints'], records['joints_vis']
+
+        return records
diff --git a/ppdet/modeling/architectures/keypoint_hrnet.py b/ppdet/modeling/architectures/keypoint_hrnet.py
index 1e309e9b4..3cacf3a53 100644
--- a/ppdet/modeling/architectures/keypoint_hrnet.py
+++ b/ppdet/modeling/architectures/keypoint_hrnet.py
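To make the change in heatmap encoding concrete: ToHeatmapsTopDown rounds each joint to the nearest heatmap cell before rendering the gaussian, while ToHeatmapsTopDown_DARK keeps the continuous coordinate, which is what the Taylor-based DARK decoding in the next file relies on. A minimal standalone sketch of the difference (NumPy only; the function name and shapes are illustrative, not PaddleDetection API):

    import numpy as np

    def render_gaussian(mu_x, mu_y, hmsize=(48, 64), sigma=2, quantize=False):
        # quantize=True mimics ToHeatmapsTopDown, False mimics ToHeatmapsTopDown_DARK
        if quantize:
            mu_x, mu_y = int(mu_x + 0.5), int(mu_y + 0.5)
        x = np.arange(hmsize[0], dtype=np.float32)           # (W,)
        y = np.arange(hmsize[1], dtype=np.float32)[:, None]  # (H, 1)
        return np.exp(-((x - mu_x) ** 2 + (y - mu_y) ** 2) / (2 * sigma ** 2))

    # the two training targets differ by a sub-pixel shift of the gaussian peak
    hm_quantized = render_gaussian(10.3, 20.7, quantize=True)
    hm_dark = render_gaussian(10.3, 20.7, quantize=False)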
@@ -19,6 +19,7 @@ from __future__ import print_function
 import paddle
 import numpy as np
 import math
+import cv2
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
 from ..keypoint_utils import transform_preds
@@ -118,6 +119,9 @@ class TopDownHRNet(BaseArch):
 
 
 class HRNetPostProcess(object):
+    def __init__(self, use_dark=True):
+        self.use_dark = use_dark
+
     def get_max_preds(self, heatmaps):
         '''get predictions from score maps
 
@@ -154,7 +158,54 @@ class HRNetPostProcess(object):
 
         return preds, maxvals
 
-    def get_final_preds(self, heatmaps, center, scale):
+    def gaussian_blur(self, heatmap, kernel):
+        border = (kernel - 1) // 2
+        batch_size = heatmap.shape[0]
+        num_joints = heatmap.shape[1]
+        height = heatmap.shape[2]
+        width = heatmap.shape[3]
+        for i in range(batch_size):
+            for j in range(num_joints):
+                origin_max = np.max(heatmap[i, j])
+                dr = np.zeros((height + 2 * border, width + 2 * border))
+                dr[border:-border, border:-border] = heatmap[i, j].copy()
+                dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
+                heatmap[i, j] = dr[border:-border, border:-border].copy()
+                heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
+        return heatmap
+
+    def dark_parse(self, hm, coord):
+        heatmap_height = hm.shape[0]
+        heatmap_width = hm.shape[1]
+        px = int(coord[0])
+        py = int(coord[1])
+        if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
+            dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
+            dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
+            dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
+            dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \
+                + hm[py-1][px-1])
+            dyy = 0.25 * (
+                hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px])
+            derivative = np.matrix([[dx], [dy]])
+            hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
+            if dxx * dyy - dxy**2 != 0:
+                hessianinv = hessian.I
+                offset = -hessianinv * derivative
+                offset = np.squeeze(np.array(offset.T), axis=0)
+                coord += offset
+        return coord
+
+    def dark_postprocess(self, hm, coords, kernelsize):
+        hm = self.gaussian_blur(hm, kernelsize)
+        hm = np.maximum(hm, 1e-10)
+        hm = np.log(hm)
+        for n in range(coords.shape[0]):
+            for p in range(coords.shape[1]):
+                coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
+        return coords
+
+    def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
         """the highest heatvalue location with a quarter offset in the
         direction from the highest response to the second highest response.
 
@@ -173,17 +224,20 @@ class HRNetPostProcess(object):
         heatmap_height = heatmaps.shape[2]
         heatmap_width = heatmaps.shape[3]
 
-        for n in range(coords.shape[0]):
-            for p in range(coords.shape[1]):
-                hm = heatmaps[n][p]
-                px = int(math.floor(coords[n][p][0] + 0.5))
-                py = int(math.floor(coords[n][p][1] + 0.5))
-                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
-                    diff = np.array([
-                        hm[py][px + 1] - hm[py][px - 1],
-                        hm[py + 1][px] - hm[py - 1][px]
-                    ])
-                    coords[n][p] += np.sign(diff) * .25
+        if self.use_dark:
+            coords = self.dark_postprocess(heatmaps, coords, kernelsize)
+        else:
+            for n in range(coords.shape[0]):
+                for p in range(coords.shape[1]):
+                    hm = heatmaps[n][p]
+                    px = int(math.floor(coords[n][p][0] + 0.5))
+                    py = int(math.floor(coords[n][p][1] + 0.5))
+                    if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
+                        diff = np.array([
+                            hm[py][px + 1] - hm[py][px - 1],
+                            hm[py + 1][px] - hm[py - 1][px]
+                        ])
+                        coords[n][p] += np.sign(diff) * .25
 
         preds = coords.copy()
 
         # Transform back
--
GitLab
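The dark_parse method above is effectively one Newton step on the log of the blurred heatmap: log h is modelled as a quadratic around the integer argmax and the refined prediction is its stationary point, mu' = mu - H^{-1} g. A self-contained sketch of the same computation (NumPy only; dark_refine and the synthetic check are illustrative, not part of the patch):

    import numpy as np

    def dark_refine(log_hm, px, py):
        # finite-difference gradient and Hessian of log h at the integer argmax
        dx = 0.5 * (log_hm[py, px + 1] - log_hm[py, px - 1])
        dy = 0.5 * (log_hm[py + 1, px] - log_hm[py - 1, px])
        dxx = 0.25 * (log_hm[py, px + 2] - 2 * log_hm[py, px] + log_hm[py, px - 2])
        dyy = 0.25 * (log_hm[py + 2, px] - 2 * log_hm[py, px] + log_hm[py - 2, px])
        dxy = 0.25 * (log_hm[py + 1, px + 1] - log_hm[py - 1, px + 1]
                      - log_hm[py + 1, px - 1] + log_hm[py - 1, px - 1])
        grad = np.array([dx, dy])
        hess = np.array([[dxx, dxy], [dxy, dyy]])
        if np.linalg.det(hess) == 0:
            return np.array([px, py], dtype=np.float64)
        # Newton step towards the stationary point of the quadratic model
        return np.array([px, py]) - np.linalg.solve(hess, grad)

    # sanity check: for an exact gaussian the log-heatmap is exactly quadratic,
    # so a single step recovers the true sub-pixel centre from the integer argmax
    mu = np.array([10.3, 20.7])  # illustrative (x, y) ground truth
    xs = np.arange(48, dtype=np.float64)
    ys = np.arange(64, dtype=np.float64)[:, None]
    hm = np.exp(-((xs - mu[0]) ** 2 + (ys - mu[1]) ** 2) / (2 * 2.0 ** 2))
    py, px = np.unravel_index(np.argmax(hm), hm.shape)
    print(dark_refine(np.log(np.maximum(hm, 1e-10)), px, py))  # ~ [10.3, 20.7]

This also shows why the patch clamps the heatmap with np.maximum(hm, 1e-10) before taking the log, and why gaussian_blur is applied first on real network output: the Taylor step assumes a smooth, strictly positive surface around the peak.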