diff --git a/configs/keypoint/tiny_pose/keypoint/tinypose_128x96.yml b/configs/keypoint/tiny_pose/keypoint/tinypose_128x96.yml new file mode 100644 index 0000000000000000000000000000000000000000..a9ee77e4e256080ca0c6b93591f830977f853333 --- /dev/null +++ b/configs/keypoint/tiny_pose/keypoint/tinypose_128x96.yml @@ -0,0 +1,147 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 10 +weights: output/tinypose_128x96/model_final +epoch: 420 +num_joints: &num_joints 17 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownCOCOEval +num_classes: 1 +train_height: &train_height 128 +train_width: &train_width 96 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [24, 32] +flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] + + +#####model +architecture: TopDownHRNet + +TopDownHRNet: + backbone: LiteHRNet + post_process: HRNetPostProcess + flip_perm: *flip_perm + num_joints: *num_joints + width: &width 40 + loss: KeyPointMSELoss + use_dark: true + +LiteHRNet: + network_type: wider_naive + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointMSELoss: + use_target_weight: true + loss_scale: 1.0 + +#####optimizer +LearningRate: + base_lr: 0.008 + schedulers: + - !PiecewiseDecay + milestones: [380, 410] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.001 + steps: 500 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + + +#####data +TrainDataset: + !KeypointTopDownCocoDataset + image_dir: "" + anno_path: aic_coco_train_cocoformat.json + dataset_dir: dataset + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + use_gt_bbox: True + + +EvalDataset: + !KeypointTopDownCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + use_gt_bbox: True + image_thre: 0.0 + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 2 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomFlipHalfBodyTransform: + scale: 0.25 + rot: 30 + num_joints_half_body: 8 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + flip_pairs: *flip_perm + - AugmentationbyInformantionDropping: + prob_cutout: 0.5 + offset_factor: 0.05 + num_patch: 1 + trainsize: *trainsize + - TopDownAffine: + trainsize: *trainsize + use_udp: true + - ToHeatmapsTopDown_DARK: + hmsize: *hmsize + sigma: 1 + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 512 + shuffle: true + drop_last: false + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + use_udp: true + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 16 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: true diff --git a/configs/keypoint/tiny_pose/keypoint/tinypose_256x192.yml b/configs/keypoint/tiny_pose/keypoint/tinypose_256x192.yml new file mode 100644 index 
0000000000000000000000000000000000000000..01c57212feb5d8ff81cfd91f5d936fd28b69b028 --- /dev/null +++ b/configs/keypoint/tiny_pose/keypoint/tinypose_256x192.yml @@ -0,0 +1,147 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 10 +weights: output/tinypose_256x192/model_final +epoch: 420 +num_joints: &num_joints 17 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownCOCOEval +num_classes: 1 +train_height: &train_height 256 +train_width: &train_width 192 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [48, 64] +flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] + + +#####model +architecture: TopDownHRNet + +TopDownHRNet: + backbone: LiteHRNet + post_process: HRNetPostProcess + flip_perm: *flip_perm + num_joints: *num_joints + width: &width 40 + loss: KeyPointMSELoss + use_dark: true + +LiteHRNet: + network_type: wider_naive + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointMSELoss: + use_target_weight: true + loss_scale: 1.0 + +#####optimizer +LearningRate: + base_lr: 0.002 + schedulers: + - !PiecewiseDecay + milestones: [380, 410] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.001 + steps: 500 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + + +#####data +TrainDataset: + !KeypointTopDownCocoDataset + image_dir: "" + anno_path: aic_coco_train_cocoformat.json + dataset_dir: dataset + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + use_gt_bbox: True + + +EvalDataset: + !KeypointTopDownCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + use_gt_bbox: True + image_thre: 0.0 + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 2 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomFlipHalfBodyTransform: + scale: 0.25 + rot: 30 + num_joints_half_body: 8 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + flip_pairs: *flip_perm + - AugmentationbyInformantionDropping: + prob_cutout: 0.5 + offset_factor: 0.05 + num_patch: 1 + trainsize: *trainsize + - TopDownAffine: + trainsize: *trainsize + use_udp: true + - ToHeatmapsTopDown_DARK: + hmsize: *hmsize + sigma: 2 + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 128 + shuffle: true + drop_last: false + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + use_udp: true + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 16 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: true diff --git a/configs/keypoint/tiny_pose/pedestrian_detection/picodet_s_320_pedestrian.yml b/configs/keypoint/tiny_pose/pedestrian_detection/picodet_s_320_pedestrian.yml new file mode 100644 index 0000000000000000000000000000000000000000..45b8e10db70811b947e0f2f2be3d1dc7ea3beceb --- /dev/null +++ b/configs/keypoint/tiny_pose/pedestrian_detection/picodet_s_320_pedestrian.yml @@ 
-0,0 +1,143 @@ +use_gpu: true +log_iter: 20 +save_dir: output +snapshot_epoch: 1 +print_flops: false +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams +weights: output/picodet_s_320_pedestrian/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 +epoch: 300 +metric: COCO +num_classes: 1 + +architecture: PicoDet + +PicoDet: + backbone: ESNet + neck: CSPPAN + head: PicoHead + +ESNet: + scale: 0.75 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + +CSPPAN: + out_channels: 96 + use_depthwise: True + num_csp_blocks: 1 + num_features: 4 + +PicoHead: + conv_feat: + name: PicoFeat + feat_in: 96 + feat_out: 96 + num_convs: 2 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: True + fpn_stride: [8, 16, 32, 64] + feat_in_chan: 96 + prior_prob: 0.01 + reg_max: 7 + cell_offset: 0.5 + loss_class: + name: VarifocalLoss + use_sigmoid: True + iou_weighted: True + loss_weight: 1.0 + loss_dfl: + name: DistributionFocalLoss + loss_weight: 0.25 + loss_bbox: + name: GIoULoss + loss_weight: 2.0 + assigner: + name: SimOTAAssigner + candidate_topk: 10 + iou_weight: 6 + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 100 + score_threshold: 0.025 + nms_threshold: 0.6 + +LearningRate: + base_lr: 0.4 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.00004 + type: L2 + +TrainDataset: + !COCODataSet + image_dir: "" + anno_path: aic_coco_train_cocoformat.json + dataset_dir: dataset + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + +TestDataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + +worker_num: 8 +TrainReader: + sample_transforms: + - Decode: {} + - RandomCrop: {} + - RandomFlip: {prob: 0.5} + - RandomDistort: {} + batch_transforms: + - BatchRandomResize: {target_size: [256, 288, 320, 352, 384], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_size: 128 + shuffle: true + drop_last: true + collate_batch: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 8 + shuffle: false + +TestReader: + inputs_def: + image_shape: [1, 3, 320, 320] + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false diff --git a/deploy/python/keypoint_preprocess.py b/deploy/python/keypoint_preprocess.py index 6619c7db59d75256d1806644c6b81820c85fe67f..97a747c4ca0522d691dad5998e81088647f57ae1 100644 --- a/deploy/python/keypoint_preprocess.py +++ b/deploy/python/keypoint_preprocess.py @@ -108,6 +108,37 @@ def get_affine_transform(center, return trans +def get_warp_matrix(theta, size_input, size_dst, size_target): + """Calculate the transformation matrix 
under the constraint of unbiased. + Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased + Data Processing for Human Pose Estimation (CVPR 2020). + + Args: + theta (float): Rotation angle in degrees. + size_input (np.ndarray): Size of input image [w, h]. + size_dst (np.ndarray): Size of output image [w, h]. + size_target (np.ndarray): Size of ROI in input plane [w, h]. + + Returns: + matrix (np.ndarray): A matrix for transformation. + """ + theta = np.deg2rad(theta) + matrix = np.zeros((2, 3), dtype=np.float32) + scale_x = size_dst[0] / size_target[0] + scale_y = size_dst[1] / size_target[1] + matrix[0, 0] = np.cos(theta) * scale_x + matrix[0, 1] = -np.sin(theta) * scale_x + matrix[0, 2] = scale_x * ( + -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] * + np.sin(theta) + 0.5 * size_target[0]) + matrix[1, 0] = np.sin(theta) * scale_y + matrix[1, 1] = np.cos(theta) * scale_y + matrix[1, 2] = scale_y * ( + -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] * + np.cos(theta) + 0.5 * size_target[1]) + return matrix + + def rotate_point(pt, angle_rad): """Rotate a point by an angle. @@ -154,6 +185,7 @@ class TopDownEvalAffine(object): Args: trainsize (list): [w, h], the standard size used to train + use_udp (bool): whether to use Unbiased Data Processing. records(dict): the dict contained the image and coords Returns: @@ -161,19 +193,29 @@ class TopDownEvalAffine(object): """ - def __init__(self, trainsize): + def __init__(self, trainsize, use_udp=False): self.trainsize = trainsize + self.use_udp = use_udp def __call__(self, image, im_info): rot = 0 imshape = im_info['im_shape'][::-1] center = im_info['center'] if 'center' in im_info else imshape / 2. scale = im_info['scale'] if 'scale' in im_info else imshape - trans = get_affine_transform(center, scale, rot, self.trainsize) - image = cv2.warpAffine( - image, - trans, (int(self.trainsize[0]), int(self.trainsize[1])), - flags=cv2.INTER_LINEAR) + if self.use_udp: + trans = get_warp_matrix( + rot, center * 2.0, + [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + else: + trans = get_affine_transform(center, scale, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) return image, im_info diff --git a/ppdet/data/transform/keypoint_operators.py b/ppdet/data/transform/keypoint_operators.py index a578f0854939efb197cf9b7fca97cf7aed55e779..81770b63efd76e62435af52b4ab1f5318ebcf988 100644 --- a/ppdet/data/transform/keypoint_operators.py +++ b/ppdet/data/transform/keypoint_operators.py @@ -28,7 +28,7 @@ import numpy as np import math import copy -from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform +from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform, get_warp_matrix from ppdet.core.workspace import serializable from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) @@ -36,10 +36,19 @@ logger = setup_logger(__name__) registered_ops = [] __all__ = [ - 'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps', - 'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform', - 'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK', - 'TopDownEvalAffine' + 'RandomAffine', + 'KeyPointFlip', + 'TagGenerate', + 'ToHeatmaps', + 
'NormalizePermute', + 'EvalAffine', + 'RandomFlipHalfBodyTransform', + 'TopDownAffine', + 'ToHeatmapsTopDown', + 'ToHeatmapsTopDown_DARK', + 'ToHeatmapsTopDown_UDP', + 'TopDownEvalAffine', + 'AugmentationbyInformantionDropping', ] @@ -96,37 +105,6 @@ class KeyPointFlip(object): return records -def get_warp_matrix(theta, size_input, size_dst, size_target): - """Calculate the transformation matrix under the constraint of unbiased. - Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased - Data Processing for Human Pose Estimation (CVPR 2020). - - Args: - theta (float): Rotation angle in degrees. - size_input (np.ndarray): Size of input image [w, h]. - size_dst (np.ndarray): Size of output image [w, h]. - size_target (np.ndarray): Size of ROI in input plane [w, h]. - - Returns: - matrix (np.ndarray): A matrix for transformation. - """ - theta = np.deg2rad(theta) - matrix = np.zeros((2, 3), dtype=np.float32) - scale_x = size_dst[0] / size_target[0] - scale_y = size_dst[1] / size_target[1] - matrix[0, 0] = math.cos(theta) * scale_x - matrix[0, 1] = -math.sin(theta) * scale_x - matrix[0, 2] = scale_x * ( - -0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] * - math.sin(theta) + 0.5 * size_target[0]) - matrix[1, 0] = math.sin(theta) * scale_y - matrix[1, 1] = math.cos(theta) * scale_y - matrix[1, 2] = scale_y * ( - -0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] * - math.cos(theta) + 0.5 * size_target[1]) - return matrix - - @register_keypointop class RandomAffine(object): """apply affine transform to image, mask and coords @@ -531,12 +509,72 @@ class RandomFlipHalfBodyTransform(object): return records +@register_keypointop +class AugmentationbyInformantionDropping(object): + """AID: Augmentation by Informantion Dropping. Please refer + to https://arxiv.org/abs/2008.07139 + + Args: + prob_cutout (float): The probability of the Cutout augmentation. + offset_factor (float): Offset factor of cutout center. + num_patch (int): Number of patches to be cutout. 
+ records(dict): the dict contained the image and coords + + Returns: + records (dict): contain the image and coords after tranformed + + """ + + def __init__(self, + trainsize, + prob_cutout=0.0, + offset_factor=0.2, + num_patch=1): + self.prob_cutout = prob_cutout + self.offset_factor = offset_factor + self.num_patch = num_patch + self.trainsize = trainsize + + def _cutout(self, img, joints, joints_vis): + height, width, _ = img.shape + img = img.reshape((height * width, -1)) + feat_x_int = np.arange(0, width) + feat_y_int = np.arange(0, height) + feat_x_int, feat_y_int = np.meshgrid(feat_x_int, feat_y_int) + feat_x_int = feat_x_int.reshape((-1, )) + feat_y_int = feat_y_int.reshape((-1, )) + for _ in range(self.num_patch): + vis_idx, _ = np.where(joints_vis > 0) + occlusion_joint_id = np.random.choice(vis_idx) + center = joints[occlusion_joint_id, 0:2] + offset = np.random.randn(2) * self.trainsize[0] * self.offset_factor + center = center + offset + radius = np.random.uniform(0.1, 0.2) * self.trainsize[0] + x_offset = (center[0] - feat_x_int) / radius + y_offset = (center[1] - feat_y_int) / radius + dis = x_offset**2 + y_offset**2 + keep_pos = np.where((dis <= 1) & (dis >= 0))[0] + img[keep_pos, :] = 0 + img = img.reshape((height, width, -1)) + return img + + def __call__(self, records): + img = records['image'] + joints = records['joints'] + joints_vis = records['joints_vis'] + if np.random.rand() < self.prob_cutout: + img = self._cutout(img, joints, joints_vis) + records['image'] = img + return records + + @register_keypointop class TopDownAffine(object): """apply affine transform to image and coords Args: trainsize (list): [w, h], the standard size used to train + use_udp (bool): whether to use Unbiased Data Processing. records(dict): the dict contained the image and coords Returns: @@ -544,26 +582,36 @@ class TopDownAffine(object): """ - def __init__(self, trainsize): + def __init__(self, trainsize, use_udp=False): self.trainsize = trainsize + self.use_udp = use_udp def __call__(self, records): image = records['image'] joints = records['joints'] joints_vis = records['joints_vis'] rot = records['rotate'] if "rotate" in records else 0 - trans = get_affine_transform(records['center'], records['scale'] * 200, - rot, self.trainsize) - trans_joint = get_affine_transform( - records['center'], records['scale'] * 200, rot, - [self.trainsize[0] / 4, self.trainsize[1] / 4]) - image = cv2.warpAffine( - image, - trans, (int(self.trainsize[0]), int(self.trainsize[1])), - flags=cv2.INTER_LINEAR) - for i in range(joints.shape[0]): - if joints_vis[i, 0] > 0.0: - joints[i, 0:2] = affine_transform(joints[i, 0:2], trans_joint) + if self.use_udp: + trans = get_warp_matrix( + rot, records['center'] * 2.0, + [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], + records['scale'] * 200.0) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + joints[:, 0:2] = warp_affine_joints(joints[:, 0:2].copy(), trans) + else: + trans = get_affine_transform(records['center'], records['scale'] * + 200, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + for i in range(joints.shape[0]): + if joints_vis[i, 0] > 0.0: + joints[i, 0:2] = affine_transform(joints[i, 0:2], trans) + records['image'] = image records['joints'] = joints @@ -576,6 +624,7 @@ class TopDownEvalAffine(object): Args: trainsize (list): [w, h], the standard size used to train + use_udp (bool): whether 
to use Unbiased Data Processing. records(dict): the dict contained the image and coords Returns: @@ -583,8 +632,9 @@ class TopDownEvalAffine(object): """ - def __init__(self, trainsize): + def __init__(self, trainsize, use_udp=False): self.trainsize = trainsize + self.use_udp = use_udp def __call__(self, records): image = records['image'] @@ -592,11 +642,21 @@ class TopDownEvalAffine(object): imshape = records['im_shape'][::-1] center = imshape / 2. scale = imshape - trans = get_affine_transform(center, scale, rot, self.trainsize) - image = cv2.warpAffine( - image, - trans, (int(self.trainsize[0]), int(self.trainsize[1])), - flags=cv2.INTER_LINEAR) + + if self.use_udp: + trans = get_warp_matrix( + rot, center * 2.0, + [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + else: + trans = get_affine_transform(center, scale, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) records['image'] = image return records @@ -632,10 +692,10 @@ class ToHeatmapsTopDown(object): target = np.zeros( (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32) tmp_size = self.sigma * 3 + feat_stride = image_size / self.hmsize for joint_id in range(num_joints): - feat_stride = image_size / self.hmsize - mu_x = int(joints[joint_id][0] + 0.5) - mu_y = int(joints[joint_id][1] + 0.5) + mu_x = int(joints[joint_id][0] + 0.5) / feat_stride[0] + mu_y = int(joints[joint_id][1] + 0.5) / feat_stride[1] # Check that any part of the gaussian is in-bounds ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] @@ -693,14 +753,17 @@ class ToHeatmapsTopDown_DARK(object): joints = records['joints'] joints_vis = records['joints_vis'] num_joints = joints.shape[0] + image_size = np.array( + [records['image'].shape[1], records['image'].shape[0]]) target_weight = np.ones((num_joints, 1), dtype=np.float32) target_weight[:, 0] = joints_vis[:, 0] target = np.zeros( (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32) tmp_size = self.sigma * 3 + feat_stride = image_size / self.hmsize for joint_id in range(num_joints): - mu_x = joints[joint_id][0] - mu_y = joints[joint_id][1] + mu_x = joints[joint_id][0] / feat_stride[0] + mu_y = joints[joint_id][1] / feat_stride[1] # Check that any part of the gaussian is in-bounds ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] @@ -723,3 +786,74 @@ class ToHeatmapsTopDown_DARK(object): del records['joints'], records['joints_vis'] return records + + +@register_keypointop +class ToHeatmapsTopDown_UDP(object): + """to generate the gaussian heatmaps of keypoint for heatmap loss. + ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing + for Human Pose Estimation (CVPR 2020). 
+ + Args: + hmsize (list): [w, h] output heatmap's size + sigma (float): the std of the generated gaussian kernel + records(dict): the dict containing the image and coords + + Returns: + records (dict): contains the heatmaps used for the heatmap loss + """ + + def __init__(self, hmsize, sigma): + super(ToHeatmapsTopDown_UDP, self).__init__() + self.hmsize = np.array(hmsize) + self.sigma = sigma + + def __call__(self, records): + joints = records['joints'] + joints_vis = records['joints_vis'] + num_joints = joints.shape[0] + image_size = np.array( + [records['image'].shape[1], records['image'].shape[0]]) + target_weight = np.ones((num_joints, 1), dtype=np.float32) + target_weight[:, 0] = joints_vis[:, 0] + target = np.zeros( + (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32) + tmp_size = self.sigma * 3 + size = 2 * tmp_size + 1 + x = np.arange(0, size, 1, np.float32) + y = x[:, None] + feat_stride = (image_size - 1.0) / (self.hmsize - 1.0) + for joint_id in range(num_joints): + mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5) + mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5) + # Check that any part of the gaussian is in-bounds + ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] + br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] + if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[ + 0] < 0 or br[1] < 0: + # If not, just return the image as is + target_weight[joint_id] = 0 + continue + + mu_x_ac = joints[joint_id][0] / feat_stride[0] + mu_y_ac = joints[joint_id][1] / feat_stride[1] + x0 = y0 = size // 2 + x0 += mu_x_ac - mu_x + y0 += mu_y_ac - mu_y + g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * self.sigma**2)) + # Usable gaussian range + g_x = max(0, -ul[0]), min(br[0], self.hmsize[0]) - ul[0] + g_y = max(0, -ul[1]), min(br[1], self.hmsize[1]) - ul[1] + # Image range + img_x = max(0, ul[0]), min(br[0], self.hmsize[0]) + img_y = max(0, ul[1]), min(br[1], self.hmsize[1]) + + v = target_weight[joint_id] + if v > 0.5: + target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[ + 0]:g_y[1], g_x[0]:g_x[1]] + records['target'] = target + records['target_weight'] = target_weight + del records['joints'], records['joints_vis'] + + return records diff --git a/ppdet/modeling/keypoint_utils.py b/ppdet/modeling/keypoint_utils.py index d930997cc91aeab8ba65414eb49a49845c2d02af..b3f84da7dcff694a6c91b67db3f534559412d458 100644 --- a/ppdet/modeling/keypoint_utils.py +++ b/ppdet/modeling/keypoint_utils.py @@ -95,6 +95,37 @@ def get_affine_transform(center, return trans + +def get_warp_matrix(theta, size_input, size_dst, size_target): + """Calculate the transformation matrix under the constraint of unbiased. + Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased + Data Processing for Human Pose Estimation (CVPR 2020). + + Args: + theta (float): Rotation angle in degrees. + size_input (np.ndarray): Size of input image [w, h]. + size_dst (np.ndarray): Size of output image [w, h]. + size_target (np.ndarray): Size of ROI in input plane [w, h]. + + Returns: + matrix (np.ndarray): A matrix for transformation. 
+ """ + theta = np.deg2rad(theta) + matrix = np.zeros((2, 3), dtype=np.float32) + scale_x = size_dst[0] / size_target[0] + scale_y = size_dst[1] / size_target[1] + matrix[0, 0] = np.cos(theta) * scale_x + matrix[0, 1] = -np.sin(theta) * scale_x + matrix[0, 2] = scale_x * ( + -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] * + np.sin(theta) + 0.5 * size_target[0]) + matrix[1, 0] = np.sin(theta) * scale_y + matrix[1, 1] = np.cos(theta) * scale_y + matrix[1, 2] = scale_y * ( + -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] * + np.cos(theta) + 0.5 * size_target[1]) + return matrix + + def _get_3rd_point(a, b): """To calculate the affine matrix, three pairs of points are required. This function is used to get the 3rd point, given 2D points a & b.