From 9fafde8fbcac8f56f98a2da649ac416e91b472fa Mon Sep 17 00:00:00 2001 From: XYZ <1290573099@qq.com> Date: Wed, 8 Feb 2023 16:57:44 +0800 Subject: [PATCH] tinypose3d for medical dataset (#7696) * tinypose3d for medical dataset * modify tinypose-3d codes according to comments * the images in dataset is named 'image' * change model name to TinyPose3D * annotations --- .../tinypose3d_medical_multi_frames.yml | 138 ++++++++ .../tinypose3d_multi_frames_heatmap.yml | 138 ++++++++ ppdet/data/source/__init__.py | 2 +- ppdet/data/source/pose3d_cmb.py | 182 +++++++++++ ppdet/data/transform/__init__.py | 2 + .../data/transform/keypoints_3d_operators.py | 296 ++++++++++++++++++ ppdet/data/transform/operators.py | 33 +- .../modeling/architectures/keypoint_hrnet.py | 207 +++++++++++- ppdet/modeling/backbones/lite_hrnet.py | 5 + 9 files changed, 995 insertions(+), 8 deletions(-) create mode 100644 configs/keypoint/tiny_pose/tinypose3d_medical_multi_frames.yml create mode 100644 configs/keypoint/tiny_pose/tinypose3d_multi_frames_heatmap.yml create mode 100644 ppdet/data/transform/keypoints_3d_operators.py diff --git a/configs/keypoint/tiny_pose/tinypose3d_medical_multi_frames.yml b/configs/keypoint/tiny_pose/tinypose3d_medical_multi_frames.yml new file mode 100644 index 000000000..aad7a4055 --- /dev/null +++ b/configs/keypoint/tiny_pose/tinypose3d_medical_multi_frames.yml @@ -0,0 +1,138 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 1 +weights: output/tinypose_3D_multi_frames/model_final +epoch: 420 +num_joints: &num_joints 24 +pixel_std: &pixel_std 200 +metric: Pose3DEval +num_classes: 1 +train_height: &train_height 128 +train_width: &train_width 96 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [24, 32] +flip_perm: &flip_perm [[1, 2], [4, 5], [7, 8], [10, 11], [13, 14], [16, 17], [18, 19], [20, 21], [22, 23]] + + +#####model +architecture: TinyPose3DHRNet +pretrain_weights: medical_multi_frames_best_model.pdparams + +TinyPose3DHRNet: + backbone: LiteHRNet + post_process: TinyPose3DPostProcess + num_joints: *num_joints + width: &width 40 + loss: KeyPointRegressionMSELoss + +LiteHRNet: + network_type: wider_naive + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointRegressionMSELoss: + reduction: 'mean' + +#####optimizer +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + milestones: [17, 21] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.01 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + +#####data +TrainDataset: + !Keypoint3DMultiFramesDataset + dataset_dir: "data/medical/multi_frames/train" + image_dir: "images" + p3d_dir: "joint_pc/player_0" + json_path: "json_results/player_0/player_0.json" + img_size: *trainsize # w,h + num_frames: 6 + + +EvalDataset: + !Keypoint3DMultiFramesDataset + dataset_dir: "data/medical/multi_frames/val" + image_dir: "images" + p3d_dir: "joint_pc/player_0" + json_path: "json_results/player_0/player_0.json" + img_size: *trainsize # w,h + num_frames: 6 + +TestDataset: + !Keypoint3DMultiFramesDataset + dataset_dir: "data/medical/multi_frames/val" + image_dir: "images" + p3d_dir: "joint_pc/player_0" + json_path: "json_results/player_0/player_0.json" + img_size: *trainsize # w,h + num_frames: 6 + +worker_num: 4 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - CropAndFlipImages: + crop_range: [556, 1366] + - RandomFlipHalfBody3DTransformImages: + scale: 0.25 + rot: 30 + 
num_joints_half_body: 9 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [0, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + flip_pairs: *flip_perm + do_occlusion: true + - Resize: {interp: 2, target_size: [*train_height,*train_width], keep_ratio: false} + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - PermuteImages: {} + batch_size: 32 + shuffle: true + drop_last: false + +EvalReader: + sample_transforms: + - CropAndFlipImages: + crop_range: [556, 1366] + - Resize: {interp: 2, target_size: [*train_height,*train_width], keep_ratio: false} + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - PermuteImages: {} + batch_size: 32 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - LetterBoxResize: { target_size: [*train_height,*train_width]} + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: false diff --git a/configs/keypoint/tiny_pose/tinypose3d_multi_frames_heatmap.yml b/configs/keypoint/tiny_pose/tinypose3d_multi_frames_heatmap.yml new file mode 100644 index 000000000..a5893ec9b --- /dev/null +++ b/configs/keypoint/tiny_pose/tinypose3d_multi_frames_heatmap.yml @@ -0,0 +1,138 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 1 +weights: output/tinypose3d_multi_frames_heatmap/model_final +epoch: 420 +num_joints: &num_joints 24 +pixel_std: &pixel_std 200 +metric: Pose3DEval +num_classes: 1 +train_height: &train_height 128 +train_width: &train_width 128 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [24, 32] +flip_perm: &flip_perm [[1, 2], [4, 5], [7, 8], [10, 11], [13, 14], [16, 17], [18, 19], [20, 21], [22, 23]] + +#####model +architecture: TinyPose3DHRHeatmapNet +pretrain_weights: medical_multi_frames_best_model.pdparams + +TinyPose3DHRHeatmapNet: + backbone: LiteHRNet + post_process: TinyPosePostProcess + num_joints: *num_joints + width: &width 40 + loss: KeyPointRegressionMSELoss + +LiteHRNet: + network_type: wider_naive + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointRegressionMSELoss: + reduction: 'mean' + +#####optimizer +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + milestones: [17, 21] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.01 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + +#####data +TrainDataset: + !Keypoint3DMultiFramesDataset + dataset_dir: "data/medical/multi_frames/train" + image_dir: "images" + p3d_dir: "joint_pc/player_0" + json_path: "json_results/player_0/player_0.json" + img_size: *trainsize # w,h + num_frames: 6 + + +EvalDataset: + !Keypoint3DMultiFramesDataset + dataset_dir: "data/medical/multi_frames/val" + image_dir: "images" + p3d_dir: "joint_pc/player_0" + json_path: "json_results/player_0/player_0.json" + img_size: *trainsize # w,h + num_frames: 6 + +TestDataset: + !Keypoint3DMultiFramesDataset + dataset_dir: "data/medical/multi_frames/val" + image_dir: "images" + p3d_dir: "joint_pc/player_0" + json_path: "json_results/player_0/player_0.json" + img_size: *trainsize # w,h + num_frames: 6 + +worker_num: 4 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - CropAndFlipImages: + crop_range: [556, 1366] # 保留train_height/train_width比例的情况下,裁剪原图左右两个的黑色填充 + - 
RandomFlipHalfBody3DTransformImages:
+        scale: 0.25
+        rot: 30
+        num_joints_half_body: 9
+        prob_half_body: 0.3
+        pixel_std: *pixel_std
+        trainsize: *trainsize
+        upper_body_ids: [0, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
+        flip_pairs: *flip_perm
+        do_occlusion: true
+    - Resize: {interp: 2, target_size: [*train_height,*train_width], keep_ratio: false}
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - PermuteImages: {}
+  batch_size: 1 #32
+  shuffle: true
+  drop_last: false
+
+EvalReader:
+  sample_transforms:
+    - CropAndFlipImages:
+        crop_range: [556, 1366]
+    - Resize: {interp: 2, target_size: [*train_height,*train_width], keep_ratio: false}
+    #- OriginPointTranslationImages: {}
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - PermuteImages: {}
+  batch_size: 32
+
+TestReader:
+  inputs_def:
+    image_shape: [3, *train_height, *train_width]
+  sample_transforms:
+    - Decode: {}
+    - LetterBoxResize: { target_size: [*train_height,*train_width]}
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 1
+  fuse_normalize: false
diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py
index 766ce6de7..f4fef334e 100644
--- a/ppdet/data/source/__init__.py
+++ b/ppdet/data/source/__init__.py
@@ -28,4 +28,4 @@ from .keypoint_coco import *
 from .mot import *
 from .sniper_coco import SniperCOCODataSet
 from .dataset import ImageFolder
-from .pose3d_cmb import Pose3DDataset
+from .pose3d_cmb import *
diff --git a/ppdet/data/source/pose3d_cmb.py b/ppdet/data/source/pose3d_cmb.py
index ea89daf01..3c465a325 100644
--- a/ppdet/data/source/pose3d_cmb.py
+++ b/ppdet/data/source/pose3d_cmb.py
@@ -23,6 +23,7 @@ import pycocotools
 from pycocotools.coco import COCO
 from .dataset import DetDataset
 from ppdet.core.workspace import register, serializable
+from paddle.io import Dataset
 
 
 @serializable
@@ -198,3 +199,184 @@ class Pose3DDataset(DetDataset):
             raise ValueError(
                 "Some dataset is not valid and cannot download automatically now, please prepare the dataset first"
             )
+
+
+@register
+@serializable
+class Keypoint3DMultiFramesDataset(Dataset):
+    """24-keypoint 3D dataset for pose estimation.
+
+    Each item is a sequence of consecutive frames (a list of images).
+
+    The dataset loads raw features and applies the specified transforms
+    to return a dict containing the image tensors and other information.
+
+    Args:
+        dataset_dir (str): Root path to the dataset.
+        image_dir (str): Path to a directory where images are held.
+    """
+
+    def __init__(
+            self,
+            dataset_dir,  # root directory of the dataset
+            image_dir,  # folder containing the images
+            p3d_dir,  # folder containing the 3D keypoints
+            json_path,
+            img_size,  # target size the images are resized to
+            num_frames,  # length of each frame sequence
+            anno_path=None, ):
+
+        self.dataset_dir = dataset_dir
+        self.image_dir = image_dir
+        self.p3d_dir = p3d_dir
+        self.json_path = json_path
+        self.img_size = img_size
+        self.num_frames = num_frames
+        self.anno_path = anno_path
+
+        self.data_labels, self.mf_inds = self._generate_multi_frames_list()
+
+    def _generate_multi_frames_list(self):
+        act_list = os.listdir(self.dataset_dir)  # list of action folders
+        count = 0
+        mf_list = []
+        annos_dict = {'images': [], 'annotations': [], 'act_inds': []}
+        for act in act_list:  # build the frame sequences for every action
+            if '.' in act:
+                continue
+
+            json_path = os.path.join(self.dataset_dir, act, self.json_path)
+            with open(json_path, 'r') as j:
+                annos = json.load(j)
+            length = len(annos['images'])
+            for k, v in annos.items():
+                if k in annos_dict:
+                    annos_dict[k].extend(v)
+            annos_dict['act_inds'].extend([act] * length)
+
+            mf = [[i + j + count for j in range(self.num_frames)]
+                  for i in range(0, length - self.num_frames + 1)]
+            mf_list.extend(mf)
+            count += length
+
+        print("total data number:", len(mf_list))
+        return annos_dict, mf_list
+
+    def __call__(self, *args, **kwargs):
+        return self
+
+    def __getitem__(self, index):  # fetch one consecutive frame sequence
+        inds = self.mf_inds[
+            index]  # e.g. [568, 569, 570, 571, 572, 573], length == num_frames
+
+        images = self.data_labels['images']  # all images
+        annots = self.data_labels['annotations']  # all annots
+
+        act = self.data_labels['act_inds'][inds[0]]  # action name (folder name)
+
+        kps3d_list = []
+        kps3d_vis_list = []
+        names = []
+
+        h, w = 0, 0
+        for ind in inds:  # one image
+            height = float(images[ind]['height'])
+            width = float(images[ind]['width'])
+            name = images[ind]['file_name']  # image file name, including suffix
+
+            kps3d_name = name.split('.')[0] + '.obj'
+            kps3d_path = os.path.join(self.dataset_dir, act, self.p3d_dir,
+                                      kps3d_name)
+
+            joints, joints_vis = self.kps3d_process(kps3d_path)
+            joints_vis = np.array(joints_vis, dtype=np.float32)
+
+            kps3d_list.append(joints)
+            kps3d_vis_list.append(joints_vis)
+            names.append(name)
+
+        kps3d = np.array(kps3d_list)  # (6, 24, 3), (num_frames, joints_num, 3)
+        kps3d_vis = np.array(kps3d_vis_list)
+
+        # read image
+        imgs = []
+        for name in names:
+            img_path = os.path.join(self.dataset_dir, act, self.image_dir, name)
+
+            image = cv2.imread(img_path, cv2.IMREAD_COLOR |
+                               cv2.IMREAD_IGNORE_ORIENTATION)
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+            imgs.append(np.expand_dims(image, axis=0))
+
+        imgs = np.concatenate(imgs, axis=0)
+        imgs = imgs.astype(
+            np.float32)  # (6, 1080, 1920, 3), (num_frames, h, w, c)
+
+        # attention: the images and the annotations are mirrored at this point
+        records = {
+            'kps3d': kps3d,
+            'kps3d_vis': kps3d_vis,
+            "image": imgs,
+            'act': act,
+            'names': names,
+            'im_id': index
+        }
+
+        return self.transform(records)
+
+    def kps3d_process(self, kps3d_path):
+        count = 0
+        kps = []
+        kps_vis = []
+
+        with open(kps3d_path, 'r') as f:
+            lines = f.readlines()
+            for line in lines:
+                if line[0] == 'v':
+                    kps.append([])
+                    line = line.strip('\n').split(' ')[1:]
+                    for kp in line:
+                        kps[-1].append(float(kp))
+                    count += 1
+
+                    kps_vis.append([1, 1, 1])
+
+        kps = np.array(kps)  # 52, 3
+        kps_vis = np.array(kps_vis)
+
+        kps *= 10  # scale points
+        kps -= kps[[0], :]  # set root point to zero
+
+        kps = np.concatenate((kps[0:23], kps[[37]]), axis=0)  # 24, 3
+
+        kps *= 10  # scaled a second time here, overall factor of 100
+
+        kps_vis = np.concatenate((kps_vis[0:23], kps_vis[[37]]), axis=0)  # 24, 3
+
+        return kps, kps_vis
+
+    def __len__(self):
+        return len(self.mf_inds)
+
+    def get_anno(self):
+        if self.anno_path is None:
+            return
+        return os.path.join(self.dataset_dir, self.anno_path)
+
+    def check_or_download_dataset(self):
+        return
+
+    def parse_dataset(self, ):
+        return
+
+    def set_transform(self, transform):
+        self.transform = transform
+
+    def set_epoch(self, epoch_id):
+        self._epoch = epoch_id
+
+    def set_kwargs(self, **kwargs):
+        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
+        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
+        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
diff --git a/ppdet/data/transform/__init__.py b/ppdet/data/transform/__init__.py
index a9bf0004a..08d7f64d9 100644
--- a/ppdet/data/transform/__init__.py
+++
b/ppdet/data/transform/__init__.py @@ -17,12 +17,14 @@ from . import batch_operators from . import keypoint_operators from . import mot_operators from . import rotated_operators +from . import keypoints_3d_operators from .operators import * from .batch_operators import * from .keypoint_operators import * from .mot_operators import * from .rotated_operators import * +from .keypoints_3d_operators import * __all__ = [] __all__ += registered_ops diff --git a/ppdet/data/transform/keypoints_3d_operators.py b/ppdet/data/transform/keypoints_3d_operators.py new file mode 100644 index 000000000..13337bc32 --- /dev/null +++ b/ppdet/data/transform/keypoints_3d_operators.py @@ -0,0 +1,296 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence +import cv2 +import numpy as np +import math +import copy +import random +import uuid +from numbers import Number, Integral + +from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform, get_warp_matrix +from ppdet.core.workspace import serializable +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +registered_ops = [] + +__all__ = [ + 'CropAndFlipImages', 'PermuteImages', 'RandomFlipHalfBody3DTransformImages' +] + +import matplotlib.pyplot as plt +from PIL import Image, ImageDraw +from mpl_toolkits.mplot3d import Axes3D + + +def register_keypointop(cls): + return serializable(cls) + + +def register_op(cls): + registered_ops.append(cls.__name__) + if not hasattr(BaseOperator, cls.__name__): + setattr(BaseOperator, cls.__name__, cls) + else: + raise KeyError("The {} class has been registered.".format(cls.__name__)) + return serializable(cls) + + +class BaseOperator(object): + def __init__(self, name=None): + if name is None: + name = self.__class__.__name__ + self._id = name + '_' + str(uuid.uuid4())[-6:] + + def apply(self, sample, context=None): + """ Process a sample. + Args: + sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx} + context (dict): info about this sample processing + Returns: + result (dict): a processed sample + """ + return sample + + def __call__(self, sample, context=None): + """ Process a sample. 
+ Args: + sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx} + context (dict): info about this sample processing + Returns: + result (dict): a processed sample + """ + if isinstance(sample, Sequence): # for batch_size + for i in range(len(sample)): + sample[i] = self.apply(sample[i], context) + else: + # image.shape changed + sample = self.apply(sample, context) + return sample + + def __str__(self): + return str(self._id) + + +@register_keypointop +class CropAndFlipImages(object): + """Crop all images""" + + def __init__(self, crop_range, flip_pairs=None): + super(CropAndFlipImages, self).__init__() + self.crop_range = crop_range + self.flip_pairs = flip_pairs + + def __call__(self, records): # tuple + images = records["image"] + images = images[:, :, ::-1, :] + images = images[:, :, self.crop_range[0]:self.crop_range[1]] + records["image"] = images + + if "kps2d" in records.keys(): + kps2d = records["kps2d"] + + width, height = images.shape[2], images.shape[1] + kps2d = np.array(kps2d) + kps2d[:, :, 0] = kps2d[:, :, 0] - self.crop_range[0] + + for pair in self.flip_pairs: + kps2d[:, pair[0], :], kps2d[:,pair[1], :] = \ + kps2d[:,pair[1], :], kps2d[:,pair[0], :].copy() + + records["kps2d"] = kps2d + + return records + + +@register_op +class PermuteImages(BaseOperator): + def __init__(self): + """ + Change the channel to be (batch_size, C, H, W) #(6, 3, 1080, 1920) + """ + super(PermuteImages, self).__init__() + + def apply(self, sample, context=None): + images = sample["image"] + images = images.transpose((0, 3, 1, 2)) + + sample["image"] = images + + return sample + + +@register_keypointop +class RandomFlipHalfBody3DTransformImages(object): + """apply data augment to images and coords + to achieve the flip, scale, rotate and half body transform effect for training image + Args: + trainsize (list):[w, h], Image target size + upper_body_ids (list): The upper body joint ids + flip_pairs (list): The left-right joints exchange order list + pixel_std (int): The pixel std of the scale + scale (float): The scale factor to transform the image + rot (int): The rotate factor to transform the image + num_joints_half_body (int): The joints threshold of the half body transform + prob_half_body (float): The threshold of the half body transform + flip (bool): Whether to flip the image + Returns: + records(dict): contain the image and coords after tranformed + """ + + def __init__(self, + trainsize, + upper_body_ids, + flip_pairs, + pixel_std, + scale=0.35, + rot=40, + num_joints_half_body=8, + prob_half_body=0.3, + flip=True, + rot_prob=0.6, + do_occlusion=False): + super(RandomFlipHalfBody3DTransformImages, self).__init__() + self.trainsize = trainsize + self.upper_body_ids = upper_body_ids + self.flip_pairs = flip_pairs + self.pixel_std = pixel_std + self.scale = scale + self.rot = rot + self.num_joints_half_body = num_joints_half_body + self.prob_half_body = prob_half_body + self.flip = flip + self.aspect_ratio = trainsize[0] * 1.0 / trainsize[1] + self.rot_prob = rot_prob + self.do_occlusion = do_occlusion + + def halfbody_transform(self, joints, joints_vis): + upper_joints = [] + lower_joints = [] + for joint_id in range(joints.shape[0]): + if joints_vis[joint_id][0] > 0: + if joint_id in self.upper_body_ids: + upper_joints.append(joints[joint_id]) + else: + lower_joints.append(joints[joint_id]) + if np.random.randn() < 0.5 and len(upper_joints) > 2: + selected_joints = upper_joints + else: + selected_joints = lower_joints if len( + lower_joints) > 2 else upper_joints + if 
len(selected_joints) < 2:
+            return None, None
+        selected_joints = np.array(selected_joints, dtype=np.float32)
+        center = selected_joints.mean(axis=0)[:2]
+        left_top = np.amin(selected_joints, axis=0)
+        right_bottom = np.amax(selected_joints, axis=0)
+        w = right_bottom[0] - left_top[0]
+        h = right_bottom[1] - left_top[1]
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+        scale = np.array(
+            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
+            dtype=np.float32)
+        scale = scale * 1.5
+
+        return center, scale
+
+    def flip_joints(self, joints, joints_vis, width, matched_parts, kps2d=None):
+        # joints: (6, 24, 3), (num_frames, num_joints, 3)
+
+        joints[:, :, 0] = width - joints[:, :, 0] - 1  # x
+        if kps2d is not None:
+            kps2d[:, :, 0] = width - kps2d[:, :, 0] - 1
+
+        for pair in matched_parts:
+            joints[:, pair[0], :], joints[:,pair[1], :] = \
+                joints[:,pair[1], :], joints[:,pair[0], :].copy()
+
+            joints_vis[:,pair[0], :], joints_vis[:,pair[1], :] = \
+                joints_vis[:,pair[1], :], joints_vis[:,pair[0], :].copy()
+
+            if kps2d is not None:
+                kps2d[:, pair[0], :], kps2d[:,pair[1], :] = \
+                    kps2d[:,pair[1], :], kps2d[:,pair[0], :].copy()
+
+        # move to zero
+        joints -= joints[:, [0], :]  # (batch_size, 24, 3), numpy.ndarray
+
+        return joints, joints_vis, kps2d
+
+    def __call__(self, records):
+        images = records[
+            'image']  # kps3d, kps3d_vis, images. images.shape (num_frames, width, height, 3)
+
+        joints = records['kps3d']
+        joints_vis = records['kps3d_vis']
+
+        kps2d = None
+        if 'kps2d' in records.keys():
+            kps2d = records['kps2d']
+
+        if self.flip and np.random.random() <= 0.5:
+            images = images[:, :, ::-1, :]  # flip the images horizontally, e.g. (6, 1080, 810, 3)
+            joints, joints_vis, kps2d = self.flip_joints(
+                joints, joints_vis, images.shape[2], self.flip_pairs,
+                kps2d)  # mirror the keypoints left-right
+        occlusion = False
+        if self.do_occlusion and random.random() <= 0.5:  # random occlusion
+            height = images[0].shape[0]
+            width = images[0].shape[1]
+            occlusion = True
+            while True:
+                area_min = 0.0
+                area_max = 0.2
+                synth_area = (random.random() *
+                              (area_max - area_min) + area_min) * width * height
+
+                ratio_min = 0.3
+                ratio_max = 1 / 0.3
+                synth_ratio = (random.random() *
+                               (ratio_max - ratio_min) + ratio_min)
+
+                synth_h = math.sqrt(synth_area * synth_ratio)
+                synth_w = math.sqrt(synth_area / synth_ratio)
+                synth_xmin = random.random() * (width - synth_w - 1)
+                synth_ymin = random.random() * (height - synth_h - 1)
+
+                if synth_xmin >= 0 and synth_ymin >= 0 and synth_xmin + synth_w < width and synth_ymin + synth_h < height:
+                    xmin = int(synth_xmin)
+                    ymin = int(synth_ymin)
+                    w = int(synth_w)
+                    h = int(synth_h)
+
+                    mask = np.random.rand(h, w, 3) * 255
+                    images[:, ymin:ymin + h, xmin:xmin + w, :] = mask[
+                        None, :, :, :]
+                    break
+
+        records['image'] = images
+        records['kps3d'] = joints
+        records['kps3d_vis'] = joints_vis
+        if kps2d is not None:
+            records['kps2d'] = kps2d
+
+        return records
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index 9b390f018..2f57cdfe3 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -400,6 +400,7 @@ class NormalizeImage(BaseOperator):
             2.(optional) Each pixel minus mean and is divided by std
         """
         im = sample['image']
+        im = im.astype(np.float32, copy=False)
         if self.is_scale:
             scale = 1.0 / 255.0
@@ -410,6 +411,7 @@ class NormalizeImage(BaseOperator):
             std = np.array(self.std)[np.newaxis, np.newaxis, :]
             im -= mean
             im /= std
+
         sample['image'] = im
 
         if 'pre_image' in sample:
@@
-425,6 +427,7 @@ class NormalizeImage(BaseOperator): pre_im -= mean pre_im /= std sample['pre_image'] = pre_im + return sample @@ -813,13 +816,14 @@ class Resize(BaseOperator): im = sample['image'] if not isinstance(im, np.ndarray): raise TypeError("{}: image type is not numpy.".format(self)) - if len(im.shape) != 3: - raise ImageError('{}: image is not 3-dimensional.'.format(self)) # apply image - im_shape = im.shape - if self.keep_ratio: + if len(im.shape) == 3: + im_shape = im.shape + else: + im_shape = im[0].shape + if self.keep_ratio: im_size_min = np.min(im_shape[0:2]) im_size_max = np.max(im_shape[0:2]) @@ -839,8 +843,25 @@ class Resize(BaseOperator): im_scale_y = resize_h / im_shape[0] im_scale_x = resize_w / im_shape[1] - im = self.apply_image(sample['image'], [im_scale_x, im_scale_y]) - sample['image'] = im.astype(np.float32) + if len(im.shape) == 3: + im = self.apply_image(sample['image'], [im_scale_x, im_scale_y]) + sample['image'] = im.astype(np.float32) + else: + resized_images = [] + for one_im in im: + applied_im = self.apply_image(one_im, [im_scale_x, im_scale_y]) + resized_images.append(applied_im) + + sample['image'] = np.array(resized_images) + + # 2d keypoints resize + if 'kps2d' in sample.keys(): + kps2d = sample['kps2d'] + kps2d[:, :, 0] = kps2d[:, :, 0] * im_scale_x + kps2d[:, :, 1] = kps2d[:, :, 1] * im_scale_y + + sample['kps2d'] = kps2d + sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32) if 'scale_factor' in sample: scale_factor = sample['scale_factor'] diff --git a/ppdet/modeling/architectures/keypoint_hrnet.py b/ppdet/modeling/architectures/keypoint_hrnet.py index 914bd043c..fa3541d7d 100644 --- a/ppdet/modeling/architectures/keypoint_hrnet.py +++ b/ppdet/modeling/architectures/keypoint_hrnet.py @@ -24,8 +24,9 @@ from ppdet.core.workspace import register, create from .meta_arch import BaseArch from ..keypoint_utils import transform_preds from .. 
import layers as L +from paddle.nn import functional as F -__all__ = ['TopDownHRNet'] +__all__ = ['TopDownHRNet', 'TinyPose3DHRNet', 'TinyPose3DHRHeatmapNet'] @register @@ -265,3 +266,207 @@ class HRNetPostProcess(object): maxvals, axis=1) ]] return outputs + + +class TinyPose3DPostProcess(object): + def __init__(self): + pass + + def __call__(self, output, center, scale): + """ + Args: + output (numpy.ndarray): numpy.ndarray([batch_size, num_joints, 3]), keypoints coords + scale (numpy.ndarray): The scale factor + Returns: + preds: numpy.ndarray([batch_size, num_joints, 3]), keypoints coords + """ + + preds = output.numpy().copy() + + # Transform back + for i in range(output.shape[0]): # batch_size + preds[i][:, 0] = preds[i][:, 0] * scale[i][0] + preds[i][:, 1] = preds[i][:, 1] * scale[i][1] + + return preds + + +def soft_argmax(heatmaps, joint_num): + dims = heatmaps.shape + depth_dim = (int)(dims[1] / joint_num) + heatmaps = heatmaps.reshape((-1, joint_num, depth_dim * dims[2] * dims[3])) + heatmaps = F.softmax(heatmaps, 2) + heatmaps = heatmaps.reshape((-1, joint_num, depth_dim, dims[2], dims[3])) + + accu_x = heatmaps.sum(axis=(2, 3)) + accu_y = heatmaps.sum(axis=(2, 4)) + accu_z = heatmaps.sum(axis=(3, 4)) + + accu_x = accu_x * paddle.arange(1, 33) + accu_y = accu_y * paddle.arange(1, 33) + accu_z = accu_z * paddle.arange(1, 33) + + accu_x = accu_x.sum(axis=2, keepdim=True) - 1 + accu_y = accu_y.sum(axis=2, keepdim=True) - 1 + accu_z = accu_z.sum(axis=2, keepdim=True) - 1 + + coord_out = paddle.concat( + (accu_x, accu_y, accu_z), axis=2) # [batch_size, joint_num, 3] + + return coord_out + + +@register +class TinyPose3DHRHeatmapNet(BaseArch): + __category__ = 'architecture' + __inject__ = ['loss'] + + def __init__( + self, + width, # 40, backbone输出的channel数目 + num_joints, + backbone='HRNet', + loss='KeyPointRegressionMSELoss', + post_process=TinyPose3DPostProcess): + """ + Args: + backbone (nn.Layer): backbone instance + post_process (object): post process instance + """ + super(TinyPose3DHRHeatmapNet, self).__init__() + + self.backbone = backbone + self.post_process = TinyPose3DPostProcess() + self.loss = loss + self.deploy = False + self.num_joints = num_joints + + self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True) + # for heatmap output + self.final_conv_new = L.Conv2d( + width, num_joints * 32, 1, 1, 0, bias=True) + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + # backbone + backbone = create(cfg['backbone']) + + return {'backbone': backbone, } + + def _forward(self): + feats = self.backbone(self.inputs) # feats:[[batch_size, 40, 32, 24]] + + hrnet_outputs = self.final_conv_new(feats[0]) + res = soft_argmax(hrnet_outputs, self.num_joints) + + if self.training: + return self.loss(res, self.inputs) + else: # export model need + return res + + def get_loss(self): + return self._forward() + + def get_pred(self): + res_lst = self._forward() + outputs = {'keypoint': res_lst} + return outputs + + def flip_back(self, output_flipped, matched_parts): + assert output_flipped.ndim == 4,\ + 'output_flipped should be [batch_size, num_joints, height, width]' + + output_flipped = output_flipped[:, :, :, ::-1] + + for pair in matched_parts: + tmp = output_flipped[:, pair[0], :, :].copy() + output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] + output_flipped[:, pair[1], :, :] = tmp + + return output_flipped + + +@register +class TinyPose3DHRNet(BaseArch): + __category__ = 'architecture' + __inject__ = ['loss'] + + def __init__(self, + width, + 
num_joints, + backbone='HRNet', + loss='KeyPointRegressionMSELoss', + post_process=TinyPose3DPostProcess): + """ + Args: + backbone (nn.Layer): backbone instance + post_process (object): post process instance + """ + super(TinyPose3DHRNet, self).__init__() + self.backbone = backbone + self.post_process = TinyPose3DPostProcess() + self.loss = loss + self.deploy = False + self.num_joints = num_joints + + self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True) + + self.final_conv_new = L.Conv2d( + width, num_joints * 32, 1, 1, 0, bias=True) + + self.flatten = paddle.nn.Flatten(start_axis=2, stop_axis=3) + self.fc1 = paddle.nn.Linear(768, 256) + self.act1 = paddle.nn.ReLU() + self.fc2 = paddle.nn.Linear(256, 64) + self.act2 = paddle.nn.ReLU() + self.fc3 = paddle.nn.Linear(64, 3) + + # for human3.6M + self.fc1_1 = paddle.nn.Linear(3136, 1024) + self.fc2_1 = paddle.nn.Linear(1024, 256) + self.fc3_1 = paddle.nn.Linear(256, 3) + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + # backbone + backbone = create(cfg['backbone']) + + return {'backbone': backbone, } + + def _forward(self): + feats = self.backbone(self.inputs) # feats:[[batch_size, 40, 32, 24]] + + hrnet_outputs = self.final_conv(feats[0]) + flatten_res = self.flatten( + hrnet_outputs) # [batch_size, 24, (height/4)*(width/4)] + res = self.fc1(flatten_res) + res = self.act1(res) + res = self.fc2(res) + res = self.act2(res) + res = self.fc3(res) # [batch_size, 24, 3] + + if self.training: + return self.loss(res, self.inputs) + else: # export model need + return res + + def get_loss(self): + return self._forward() + + def get_pred(self): + res_lst = self._forward() + outputs = {'keypoint': res_lst} + return outputs + + def flip_back(self, output_flipped, matched_parts): + assert output_flipped.ndim == 4,\ + 'output_flipped should be [batch_size, num_joints, height, width]' + + output_flipped = output_flipped[:, :, :, ::-1] + + for pair in matched_parts: + tmp = output_flipped[:, pair[0], :, :].copy() + output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] + output_flipped[:, pair[1], :, :] = tmp + + return output_flipped diff --git a/ppdet/modeling/backbones/lite_hrnet.py b/ppdet/modeling/backbones/lite_hrnet.py index d6832c509..95e3a2630 100644 --- a/ppdet/modeling/backbones/lite_hrnet.py +++ b/ppdet/modeling/backbones/lite_hrnet.py @@ -854,6 +854,11 @@ class LiteHRNet(nn.Layer): def forward(self, inputs): x = inputs['image'] + dims = x.shape + if len(dims) == 5: + x = paddle.reshape(x, (dims[0] * dims[1], dims[2], dims[3], + dims[4])) # [6, 3, 128, 96] + x = self.stem(x) y_list = [x] for stage_idx in range(3): -- GitLab
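Usage reference (a minimal sketch, not part of the patch): the new Keypoint3DMultiFramesDataset and the 3D transform ops above can be exercised on their own, assuming the medical multi-frames data is laid out as in tinypose3d_medical_multi_frames.yml (the paths below are illustrative). In normal use the reader pipeline is driven by the config, e.g. via PaddleDetection's standard entry point: python tools/train.py -c configs/keypoint/tiny_pose/tinypose3d_medical_multi_frames.yml --eval

# Sketch only: assumes data/medical/multi_frames/train exists with the layout
# described in the config; the two-op pipeline stands in for the full reader.
from ppdet.data.source.pose3d_cmb import Keypoint3DMultiFramesDataset
from ppdet.data.transform.keypoints_3d_operators import (CropAndFlipImages,
                                                         PermuteImages)

dataset = Keypoint3DMultiFramesDataset(
    dataset_dir='data/medical/multi_frames/train',  # as in the yml above
    image_dir='images',
    p3d_dir='joint_pc/player_0',
    json_path='json_results/player_0/player_0.json',
    img_size=[96, 128],  # w, h
    num_frames=6)

# __getitem__ ends with self.transform(records), so a transform must be set
# before indexing the dataset.
def pipeline(records):
    records = CropAndFlipImages(crop_range=[556, 1366])(records)
    records = PermuteImages()(records)  # (num_frames, H, W, C) -> (num_frames, C, H, W)
    return records

dataset.set_transform(pipeline)
sample = dataset[0]  # dict with 'image', 'kps3d', 'kps3d_vis', 'act', 'names', 'im_id'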