Unverified commit 9fafde8f authored by XYZ_916 and committed by GitHub

tinypose3d for medical dataset (#7696)

* tinypose3d for medical dataset

* modify tinypose-3d code according to comments

* the images in the dataset are named 'image'

* change model name to TinyPose3D

* annotations
Parent 0bf1c25c
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 1
weights: output/tinypose_3D_multi_frames/model_final
epoch: 420
num_joints: &num_joints 24
pixel_std: &pixel_std 200
metric: Pose3DEval
num_classes: 1
train_height: &train_height 128
train_width: &train_width 96
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [24, 32]
flip_perm: &flip_perm [[1, 2], [4, 5], [7, 8], [10, 11], [13, 14], [16, 17], [18, 19], [20, 21], [22, 23]]

#####model
architecture: TinyPose3DHRNet
pretrain_weights: medical_multi_frames_best_model.pdparams

TinyPose3DHRNet:
  backbone: LiteHRNet
  post_process: TinyPose3DPostProcess
  num_joints: *num_joints
  width: &width 40
  loss: KeyPointRegressionMSELoss

LiteHRNet:
  network_type: wider_naive
  freeze_at: -1
  freeze_norm: false
  return_idx: [0]

KeyPointRegressionMSELoss:
  reduction: 'mean'

#####optimizer
LearningRate:
  base_lr: 0.001
  schedulers:
    - !PiecewiseDecay
      milestones: [17, 21]
      gamma: 0.1
    - !LinearWarmup
      start_factor: 0.01
      steps: 1000

OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer:
    factor: 0.0
    type: L2

#####data
TrainDataset:
  !Keypoint3DMultiFramesDataset
  dataset_dir: "data/medical/multi_frames/train"
  image_dir: "images"
  p3d_dir: "joint_pc/player_0"
  json_path: "json_results/player_0/player_0.json"
  img_size: *trainsize  # w, h
  num_frames: 6

EvalDataset:
  !Keypoint3DMultiFramesDataset
  dataset_dir: "data/medical/multi_frames/val"
  image_dir: "images"
  p3d_dir: "joint_pc/player_0"
  json_path: "json_results/player_0/player_0.json"
  img_size: *trainsize  # w, h
  num_frames: 6

TestDataset:
  !Keypoint3DMultiFramesDataset
  dataset_dir: "data/medical/multi_frames/val"
  image_dir: "images"
  p3d_dir: "joint_pc/player_0"
  json_path: "json_results/player_0/player_0.json"
  img_size: *trainsize  # w, h
  num_frames: 6

worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]

TrainReader:
  sample_transforms:
    - CropAndFlipImages:
        crop_range: [556, 1366]
    - RandomFlipHalfBody3DTransformImages:
        scale: 0.25
        rot: 30
        num_joints_half_body: 9
        prob_half_body: 0.3
        pixel_std: *pixel_std
        trainsize: *trainsize
        upper_body_ids: [0, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
        flip_pairs: *flip_perm
        do_occlusion: true
    - Resize: {interp: 2, target_size: [*train_height, *train_width], keep_ratio: false}
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - PermuteImages: {}
  batch_size: 32
  shuffle: true
  drop_last: false

EvalReader:
  sample_transforms:
    - CropAndFlipImages:
        crop_range: [556, 1366]
    - Resize: {interp: 2, target_size: [*train_height, *train_width], keep_ratio: false}
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - PermuteImages: {}
  batch_size: 32

TestReader:
  inputs_def:
    image_shape: [3, *train_height, *train_width]
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [*train_height, *train_width]}
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 1
  fuse_normalize: false
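This config is consumed through PaddleDetection's workspace registry. A minimal sketch of loading it programmatically, assuming a standard PaddleDetection checkout and a hypothetical save path (the commit page does not show the config file names):

    from ppdet.core.workspace import load_config, create

    # hypothetical path; use wherever this YAML is saved in the repo
    cfg = load_config('configs/pose3d/tinypose3d_multi_frames.yml')
    print(cfg.architecture)           # 'TinyPose3DHRNet'
    model = create(cfg.architecture)  # TinyPose3DHRNet with its LiteHRNet backbone

A second config follows; it differs mainly in the architecture (TinyPose3DHRHeatmapNet, which recovers coordinates from 3D heatmaps via the soft_argmax defined later in this commit) and in its square 128x128 input.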
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 1
weights: output/tinypose3d_multi_frames_heatmap/model_final
epoch: 420
num_joints: &num_joints 24
pixel_std: &pixel_std 200
metric: Pose3DEval
num_classes: 1
train_height: &train_height 128
train_width: &train_width 128
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [24, 32]
flip_perm: &flip_perm [[1, 2], [4, 5], [7, 8], [10, 11], [13, 14], [16, 17], [18, 19], [20, 21], [22, 23]]

#####model
architecture: TinyPose3DHRHeatmapNet
pretrain_weights: medical_multi_frames_best_model.pdparams

TinyPose3DHRHeatmapNet:
  backbone: LiteHRNet
  post_process: TinyPosePostProcess
  num_joints: *num_joints
  width: &width 40
  loss: KeyPointRegressionMSELoss

LiteHRNet:
  network_type: wider_naive
  freeze_at: -1
  freeze_norm: false
  return_idx: [0]

KeyPointRegressionMSELoss:
  reduction: 'mean'

#####optimizer
LearningRate:
  base_lr: 0.001
  schedulers:
    - !PiecewiseDecay
      milestones: [17, 21]
      gamma: 0.1
    - !LinearWarmup
      start_factor: 0.01
      steps: 1000

OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer:
    factor: 0.0
    type: L2

#####data
TrainDataset:
  !Keypoint3DMultiFramesDataset
  dataset_dir: "data/medical/multi_frames/train"
  image_dir: "images"
  p3d_dir: "joint_pc/player_0"
  json_path: "json_results/player_0/player_0.json"
  img_size: *trainsize  # w, h
  num_frames: 6

EvalDataset:
  !Keypoint3DMultiFramesDataset
  dataset_dir: "data/medical/multi_frames/val"
  image_dir: "images"
  p3d_dir: "joint_pc/player_0"
  json_path: "json_results/player_0/player_0.json"
  img_size: *trainsize  # w, h
  num_frames: 6

TestDataset:
  !Keypoint3DMultiFramesDataset
  dataset_dir: "data/medical/multi_frames/val"
  image_dir: "images"
  p3d_dir: "joint_pc/player_0"
  json_path: "json_results/player_0/player_0.json"
  img_size: *trainsize  # w, h
  num_frames: 6

worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]

TrainReader:
  sample_transforms:
    - CropAndFlipImages:
        crop_range: [556, 1366]  # crop away the black padding on both sides of the original image while keeping the train_height/train_width ratio
    - RandomFlipHalfBody3DTransformImages:
        scale: 0.25
        rot: 30
        num_joints_half_body: 9
        prob_half_body: 0.3
        pixel_std: *pixel_std
        trainsize: *trainsize
        upper_body_ids: [0, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
        flip_pairs: *flip_perm
        do_occlusion: true
    - Resize: {interp: 2, target_size: [*train_height, *train_width], keep_ratio: false}
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - PermuteImages: {}
  batch_size: 1 #32
  shuffle: true
  drop_last: false

EvalReader:
  sample_transforms:
    - CropAndFlipImages:
        crop_range: [556, 1366]
    - Resize: {interp: 2, target_size: [*train_height, *train_width], keep_ratio: false}
    #- OriginPointTranslationImages: {}
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - PermuteImages: {}
  batch_size: 32

TestReader:
  inputs_def:
    image_shape: [3, *train_height, *train_width]
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [*train_height, *train_width]}
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 1
  fuse_normalize: false
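Both configs point at the same on-disk layout. From the way Keypoint3DMultiFramesDataset joins paths below (dataset_dir/<action>/{image_dir, p3d_dir, json_path}), the expected structure is roughly:

    data/medical/multi_frames/train/
        <action_name>/                        # one folder per action; names without '.' are scanned
            images/                           # frames, named by the JSON 'file_name' fields
            joint_pc/player_0/                # one .obj file of 3D keypoints per frame
            json_results/player_0/player_0.json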
@@ -28,4 +28,4 @@ from .keypoint_coco import *
 from .mot import *
 from .sniper_coco import SniperCOCODataSet
 from .dataset import ImageFolder
-from .pose3d_cmb import Pose3DDataset
+from .pose3d_cmb import *
@@ -23,6 +23,7 @@ import pycocotools
 from pycocotools.coco import COCO
 from .dataset import DetDataset
 from ppdet.core.workspace import register, serializable
+from paddle.io import Dataset

 @serializable
@@ -198,3 +199,184 @@ class Pose3DDataset(DetDataset):
             raise ValueError(
                 "Some dataset is not valid and cannot download automatically now, please prepare the dataset first"
             )
@register
@serializable
class Keypoint3DMultiFramesDataset(Dataset):
    """24-keypoint 3D dataset for pose estimation.

    Each item is a sequence of consecutive frames. The dataset loads raw
    features and applies the specified transforms to return a dict
    containing the image tensors and other information.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
    """

    def __init__(
            self,
            dataset_dir,  # root directory of the dataset
            image_dir,  # image folder
            p3d_dir,  # folder of per-frame 3D keypoint (.obj) files
            json_path,
            img_size,  # image resize target, (w, h)
            num_frames,  # length of each frame sequence
            anno_path=None, ):
        self.dataset_dir = dataset_dir
        self.image_dir = image_dir
        self.p3d_dir = p3d_dir
        self.json_path = json_path
        self.img_size = img_size
        self.num_frames = num_frames
        self.anno_path = anno_path

        self.data_labels, self.mf_inds = self._generate_multi_frames_list()

    def _generate_multi_frames_list(self):
        act_list = os.listdir(self.dataset_dir)  # list of actions
        count = 0
        mf_list = []
        annos_dict = {'images': [], 'annotations': [], 'act_inds': []}
        for act in act_list:  # generate frame sequences for each action
            if '.' in act:
                continue
            json_path = os.path.join(self.dataset_dir, act, self.json_path)
            with open(json_path, 'r') as j:
                annos = json.load(j)
            length = len(annos['images'])
            for k, v in annos.items():
                if k in annos_dict:
                    annos_dict[k].extend(v)
            annos_dict['act_inds'].extend([act] * length)

            mf = [[i + j + count for j in range(self.num_frames)]
                  for i in range(0, length - self.num_frames + 1)]
            mf_list.extend(mf)
            count += length

        print("total data number:", len(mf_list))
        return annos_dict, mf_list

    def __call__(self, *args, **kwargs):
        return self

    def __getitem__(self, index):  # fetch one consecutive frame sequence
        inds = self.mf_inds[
            index]  # e.g. [568, 569, 570, 571, 572, 573], length == num_frames

        images = self.data_labels['images']  # all images
        annots = self.data_labels['annotations']  # all annotations
        act = self.data_labels['act_inds'][inds[0]]  # action name (folder name)

        kps3d_list = []
        kps3d_vis_list = []
        names = []
        h, w = 0, 0

        for ind in inds:  # one image
            height = float(images[ind]['height'])
            width = float(images[ind]['width'])
            name = images[ind]['file_name']  # image file name, with extension

            kps3d_name = name.split('.')[0] + '.obj'
            kps3d_path = os.path.join(self.dataset_dir, act, self.p3d_dir,
                                      kps3d_name)

            joints, joints_vis = self.kps3d_process(kps3d_path)
            joints_vis = np.array(joints_vis, dtype=np.float32)

            kps3d_list.append(joints)
            kps3d_vis_list.append(joints_vis)
            names.append(name)

        kps3d = np.array(kps3d_list)  # (6, 24, 3), (num_frames, num_joints, 3)
        kps3d_vis = np.array(kps3d_vis_list)

        # read images
        imgs = []
        for name in names:
            img_path = os.path.join(self.dataset_dir, act, self.image_dir, name)

            image = cv2.imread(img_path, cv2.IMREAD_COLOR |
                               cv2.IMREAD_IGNORE_ORIENTATION)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            imgs.append(np.expand_dims(image, axis=0))

        imgs = np.concatenate(imgs, axis=0)
        imgs = imgs.astype(
            np.float32)  # (6, 1080, 1920, 3), (num_frames, h, w, c)

        # note: at this point the images and the annotations are mirrored
        records = {
            'kps3d': kps3d,
            'kps3d_vis': kps3d_vis,
            "image": imgs,
            'act': act,
            'names': names,
            'im_id': index
        }

        return self.transform(records)

    def kps3d_process(self, kps3d_path):
        count = 0
        kps = []
        kps_vis = []

        with open(kps3d_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                if line[0] == 'v':
                    kps.append([])
                    line = line.strip('\n').split(' ')[1:]
                    for kp in line:
                        kps[-1].append(float(kp))
                    count += 1
                    kps_vis.append([1, 1, 1])

        kps = np.array(kps)  # (52, 3)
        kps_vis = np.array(kps_vis)

        kps *= 10  # scale points
        kps -= kps[[0], :]  # make the root joint the origin

        kps = np.concatenate((kps[0:23], kps[[37]]), axis=0)  # (24, 3)
        kps *= 10  # scaled again: the overall scale factor is 100
        kps_vis = np.concatenate((kps_vis[0:23], kps_vis[[37]]), axis=0)  # (24, 3)

        return kps, kps_vis

    def __len__(self):
        return len(self.mf_inds)

    def get_anno(self):
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)

    def check_or_download_dataset(self):
        return

    def parse_dataset(self, ):
        return

    def set_transform(self, transform):
        self.transform = transform

    def set_epoch(self, epoch_id):
        self._epoch = epoch_id

    def set_kwargs(self, **kwargs):
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
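The list comprehension in _generate_multi_frames_list is the core of the loader: it builds every window of num_frames consecutive global indices within each action, never crossing an action boundary. A standalone sketch of the same logic, with made-up frame counts:

    # Sliding-window index generation, as in _generate_multi_frames_list.
    # Two actions with 8 and 5 frames respectively (numbers are made up).
    num_frames = 6
    count = 0
    mf_list = []
    for length in [8, 5]:  # frames per action
        mf = [[i + j + count for j in range(num_frames)]
              for i in range(0, length - num_frames + 1)]
        mf_list.extend(mf)
        count += length
    print(mf_list)
    # [[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6], [2, 3, 4, 5, 6, 7]]
    # The 5-frame action is shorter than num_frames and yields no windows.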
@@ -17,12 +17,14 @@ from . import batch_operators
 from . import keypoint_operators
 from . import mot_operators
 from . import rotated_operators
+from . import keypoints_3d_operators

 from .operators import *
 from .batch_operators import *
 from .keypoint_operators import *
 from .mot_operators import *
 from .rotated_operators import *
+from .keypoints_3d_operators import *

 __all__ = []
 __all__ += registered_ops
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

import cv2
import numpy as np
import math
import copy
import random
import uuid
from numbers import Number, Integral

from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform, get_warp_matrix
from ppdet.core.workspace import serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

registered_ops = []

__all__ = [
    'CropAndFlipImages', 'PermuteImages', 'RandomFlipHalfBody3DTransformImages'
]

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from mpl_toolkits.mplot3d import Axes3D


def register_keypointop(cls):
    return serializable(cls)


def register_op(cls):
    registered_ops.append(cls.__name__)
    if not hasattr(BaseOperator, cls.__name__):
        setattr(BaseOperator, cls.__name__, cls)
    else:
        raise KeyError("The {} class has been registered.".format(cls.__name__))
    return serializable(cls)


class BaseOperator(object):
    def __init__(self, name=None):
        if name is None:
            name = self.__class__.__name__
        self._id = name + '_' + str(uuid.uuid4())[-6:]

    def apply(self, sample, context=None):
        """ Process a sample.
        Args:
            sample (dict): a dict of sample, eg: {'image': xx, 'label': xxx}
            context (dict): info about this sample processing
        Returns:
            result (dict): a processed sample
        """
        return sample

    def __call__(self, sample, context=None):
        """ Process a sample.
        Args:
            sample (dict): a dict of sample, eg: {'image': xx, 'label': xxx}
            context (dict): info about this sample processing
        Returns:
            result (dict): a processed sample
        """
        if isinstance(sample, Sequence):  # for batch_size
            for i in range(len(sample)):
                sample[i] = self.apply(sample[i], context)
        else:
            # image.shape changed
            sample = self.apply(sample, context)
        return sample

    def __str__(self):
        return str(self._id)


@register_keypointop
class CropAndFlipImages(object):
    """Flip all images horizontally, then crop them to the given width range."""

    def __init__(self, crop_range, flip_pairs=None):
        super(CropAndFlipImages, self).__init__()
        self.crop_range = crop_range
        self.flip_pairs = flip_pairs

    def __call__(self, records):  # tuple
        images = records["image"]
        images = images[:, :, ::-1, :]  # horizontal flip
        images = images[:, :, self.crop_range[0]:self.crop_range[1]]
        records["image"] = images

        if "kps2d" in records.keys():
            kps2d = records["kps2d"]

            width, height = images.shape[2], images.shape[1]
            kps2d = np.array(kps2d)
            kps2d[:, :, 0] = kps2d[:, :, 0] - self.crop_range[0]

            for pair in self.flip_pairs:
                kps2d[:, pair[0], :], kps2d[:, pair[1], :] = \
                    kps2d[:, pair[1], :], kps2d[:, pair[0], :].copy()

            records["kps2d"] = kps2d

        return records


@register_op
class PermuteImages(BaseOperator):
    def __init__(self):
        """
        Permute images from (num_frames, H, W, C) to (num_frames, C, H, W),
        e.g. (6, 1080, 1920, 3) -> (6, 3, 1080, 1920)
        """
        super(PermuteImages, self).__init__()

    def apply(self, sample, context=None):
        images = sample["image"]
        images = images.transpose((0, 3, 1, 2))

        sample["image"] = images

        return sample


@register_keypointop
class RandomFlipHalfBody3DTransformImages(object):
    """Apply data augmentation to a frame sequence and its coordinates:
    flip, scale, rotation, and half-body transforms for the training images.

    Args:
        trainsize (list): [w, h], image target size
        upper_body_ids (list): the upper body joint ids
        flip_pairs (list): the left-right joints exchange order list
        pixel_std (int): the pixel std of the scale
        scale (float): the scale factor to transform the image
        rot (int): the rotation factor to transform the image
        num_joints_half_body (int): the joint-count threshold of the half-body transform
        prob_half_body (float): the probability of the half-body transform
        flip (bool): whether to flip the images

    Returns:
        records (dict): the image and coordinates after the transforms
    """

    def __init__(self,
                 trainsize,
                 upper_body_ids,
                 flip_pairs,
                 pixel_std,
                 scale=0.35,
                 rot=40,
                 num_joints_half_body=8,
                 prob_half_body=0.3,
                 flip=True,
                 rot_prob=0.6,
                 do_occlusion=False):
        super(RandomFlipHalfBody3DTransformImages, self).__init__()
        self.trainsize = trainsize
        self.upper_body_ids = upper_body_ids
        self.flip_pairs = flip_pairs
        self.pixel_std = pixel_std
        self.scale = scale
        self.rot = rot
        self.num_joints_half_body = num_joints_half_body
        self.prob_half_body = prob_half_body
        self.flip = flip
        self.aspect_ratio = trainsize[0] * 1.0 / trainsize[1]
        self.rot_prob = rot_prob
        self.do_occlusion = do_occlusion

    def halfbody_transform(self, joints, joints_vis):
        upper_joints = []
        lower_joints = []
        for joint_id in range(joints.shape[0]):
            if joints_vis[joint_id][0] > 0:
                if joint_id in self.upper_body_ids:
                    upper_joints.append(joints[joint_id])
                else:
                    lower_joints.append(joints[joint_id])
        if np.random.randn() < 0.5 and len(upper_joints) > 2:
            selected_joints = upper_joints
        else:
            selected_joints = lower_joints if len(
                lower_joints) > 2 else upper_joints
        if len(selected_joints) < 2:
            return None, None
        selected_joints = np.array(selected_joints, dtype=np.float32)
        center = selected_joints.mean(axis=0)[:2]
        left_top = np.amin(selected_joints, axis=0)
        right_bottom = np.amax(selected_joints, axis=0)
        w = right_bottom[0] - left_top[0]
        h = right_bottom[1] - left_top[1]
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32)
        scale = scale * 1.5

        return center, scale

    def flip_joints(self, joints, joints_vis, width, matched_parts, kps2d=None):
        # joints: (num_frames, num_joints, 3), e.g. (6, 24, 3)
        joints[:, :, 0] = width - joints[:, :, 0] - 1  # x
        if kps2d is not None:
            kps2d[:, :, 0] = width - kps2d[:, :, 0] - 1

        for pair in matched_parts:
            joints[:, pair[0], :], joints[:, pair[1], :] = \
                joints[:, pair[1], :], joints[:, pair[0], :].copy()

            joints_vis[:, pair[0], :], joints_vis[:, pair[1], :] = \
                joints_vis[:, pair[1], :], joints_vis[:, pair[0], :].copy()

            if kps2d is not None:
                kps2d[:, pair[0], :], kps2d[:, pair[1], :] = \
                    kps2d[:, pair[1], :], kps2d[:, pair[0], :].copy()

        # move the root joint to the origin
        joints -= joints[:, [0], :]  # (num_frames, 24, 3), numpy.ndarray

        return joints, joints_vis, kps2d

    def __call__(self, records):
        images = records[
            'image']  # images.shape: (num_frames, h, w, 3)
        joints = records['kps3d']
        joints_vis = records['kps3d_vis']

        kps2d = None
        if 'kps2d' in records.keys():
            kps2d = records['kps2d']

        if self.flip and np.random.random() <= 0.5:
            images = images[:, :, ::-1, :]  # horizontal flip, (6, 1080, 810, 3)
            joints, joints_vis, kps2d = self.flip_joints(
                joints, joints_vis, images.shape[2], self.flip_pairs,
                kps2d)  # left-right symmetric flip of the keypoints
        occlusion = False
        if self.do_occlusion and random.random() <= 0.5:  # random occlusion
            height = images[0].shape[0]
            width = images[0].shape[1]
            occlusion = True
            while True:
                area_min = 0.0
                area_max = 0.2
                synth_area = (random.random() *
                              (area_max - area_min) + area_min) * width * height

                ratio_min = 0.3
                ratio_max = 1 / 0.3
                synth_ratio = (random.random() *
                               (ratio_max - ratio_min) + ratio_min)

                synth_h = math.sqrt(synth_area * synth_ratio)
                synth_w = math.sqrt(synth_area / synth_ratio)
                synth_xmin = random.random() * (width - synth_w - 1)
                synth_ymin = random.random() * (height - synth_h - 1)

                if synth_xmin >= 0 and synth_ymin >= 0 and synth_xmin + synth_w < width and synth_ymin + synth_h < height:
                    xmin = int(synth_xmin)
                    ymin = int(synth_ymin)
                    w = int(synth_w)
                    h = int(synth_h)

                    mask = np.random.rand(h, w, 3) * 255
                    images[:, ymin:ymin + h, xmin:xmin + w, :] = mask[
                        None, :, :, :]
                    break

        records['image'] = images
        records['kps3d'] = joints
        records['kps3d_vis'] = joints_vis
        if kps2d is not None:
            records['kps2d'] = kps2d

        return records
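A toy check of the invariant behind flip_joints: mirror the x coordinate, then swap each left/right pair so joint semantics survive the flip. The pair list is a subset of flip_perm from the configs; the joint values are made up:

    import numpy as np

    width = 810
    flip_pairs = [[1, 2], [4, 5]]              # subset of flip_perm, for brevity
    joints = np.zeros((1, 6, 3), dtype=np.float32)
    joints[0, 1] = [100.0, 50.0, 0.0]          # a "left" joint
    joints[0, 2] = [700.0, 50.0, 0.0]          # its "right" counterpart

    joints[:, :, 0] = width - joints[:, :, 0] - 1    # mirror x, as in flip_joints
    for a, b in flip_pairs:
        joints[:, [a, b], :] = joints[:, [b, a], :]  # swap left/right labels

    print(joints[0, 1, 0], joints[0, 2, 0])  # 109.0 709.0 -- labels follow the mirror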
@@ -400,6 +400,7 @@ class NormalizeImage(BaseOperator):
         2.(optional) Each pixel minus mean and is divided by std
         """
         im = sample['image']
         im = im.astype(np.float32, copy=False)
+
         if self.is_scale:
             scale = 1.0 / 255.0
@@ -410,6 +411,7 @@ class NormalizeImage(BaseOperator):
             std = np.array(self.std)[np.newaxis, np.newaxis, :]
             im -= mean
             im /= std
+
         sample['image'] = im

         if 'pre_image' in sample:
@@ -425,6 +427,7 @@ class NormalizeImage(BaseOperator):
             pre_im -= mean
             pre_im /= std
             sample['pre_image'] = pre_im
+
         return sample
@@ -813,13 +816,14 @@ class Resize(BaseOperator):
         im = sample['image']
         if not isinstance(im, np.ndarray):
             raise TypeError("{}: image type is not numpy.".format(self))
-        if len(im.shape) != 3:
-            raise ImageError('{}: image is not 3-dimensional.'.format(self))

         # apply image
-        im_shape = im.shape
+        if len(im.shape) == 3:
+            im_shape = im.shape
+        else:
+            im_shape = im[0].shape

         if self.keep_ratio:
             im_size_min = np.min(im_shape[0:2])
             im_size_max = np.max(im_shape[0:2])
@@ -839,8 +843,25 @@ class Resize(BaseOperator):
         im_scale_y = resize_h / im_shape[0]
         im_scale_x = resize_w / im_shape[1]

-        im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
-        sample['image'] = im.astype(np.float32)
+        if len(im.shape) == 3:
+            im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
+            sample['image'] = im.astype(np.float32)
+        else:
+            resized_images = []
+            for one_im in im:
+                applied_im = self.apply_image(one_im, [im_scale_x, im_scale_y])
+                resized_images.append(applied_im)
+            sample['image'] = np.array(resized_images)
+
+        # resize 2d keypoints accordingly
+        if 'kps2d' in sample.keys():
+            kps2d = sample['kps2d']
+            kps2d[:, :, 0] = kps2d[:, :, 0] * im_scale_x
+            kps2d[:, :, 1] = kps2d[:, :, 1] * im_scale_y
+            sample['kps2d'] = kps2d

         sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32)
         if 'scale_factor' in sample:
             scale_factor = sample['scale_factor']
...
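The new else branch makes Resize transparent to stacked frame sequences. A toy shape check of the two branches, with a simplified stand-in for apply_image and made-up sizes:

    import numpy as np
    import cv2

    def resize_any(im, target_w, target_h):
        # single image: resize directly; frame sequence: resize frame by frame
        if len(im.shape) == 3:
            return cv2.resize(im, (target_w, target_h))
        return np.array([cv2.resize(f, (target_w, target_h)) for f in im])

    frames = np.zeros((6, 1080, 810, 3), dtype=np.float32)
    print(resize_any(frames, 96, 128).shape)  # (6, 128, 96, 3)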
@@ -24,8 +24,9 @@ from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
 from ..keypoint_utils import transform_preds
 from .. import layers as L
+from paddle.nn import functional as F

-__all__ = ['TopDownHRNet']
+__all__ = ['TopDownHRNet', 'TinyPose3DHRNet', 'TinyPose3DHRHeatmapNet']

 @register
@@ -265,3 +266,207 @@ class HRNetPostProcess(object):
             maxvals, axis=1)
         ]]
         return outputs
class TinyPose3DPostProcess(object):
    def __init__(self):
        pass

    def __call__(self, output, center, scale):
        """
        Args:
            output (numpy.ndarray): numpy.ndarray([batch_size, num_joints, 3]), keypoint coords
            center (numpy.ndarray): bbox center (unused here)
            scale (numpy.ndarray): the scale factor
        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 3]), keypoint coords
        """
        preds = output.numpy().copy()

        # transform back
        for i in range(output.shape[0]):  # batch_size
            preds[i][:, 0] = preds[i][:, 0] * scale[i][0]
            preds[i][:, 1] = preds[i][:, 1] * scale[i][1]

        return preds


def soft_argmax(heatmaps, joint_num):
    # assumes 32 bins along each of the x, y and z axes
    dims = heatmaps.shape
    depth_dim = (int)(dims[1] / joint_num)
    heatmaps = heatmaps.reshape((-1, joint_num, depth_dim * dims[2] * dims[3]))
    heatmaps = F.softmax(heatmaps, 2)
    heatmaps = heatmaps.reshape((-1, joint_num, depth_dim, dims[2], dims[3]))

    accu_x = heatmaps.sum(axis=(2, 3))
    accu_y = heatmaps.sum(axis=(2, 4))
    accu_z = heatmaps.sum(axis=(3, 4))

    accu_x = accu_x * paddle.arange(1, 33)
    accu_y = accu_y * paddle.arange(1, 33)
    accu_z = accu_z * paddle.arange(1, 33)

    accu_x = accu_x.sum(axis=2, keepdim=True) - 1
    accu_y = accu_y.sum(axis=2, keepdim=True) - 1
    accu_z = accu_z.sum(axis=2, keepdim=True) - 1

    coord_out = paddle.concat(
        (accu_x, accu_y, accu_z), axis=2)  # [batch_size, joint_num, 3]

    return coord_out


@register
class TinyPose3DHRHeatmapNet(BaseArch):
    __category__ = 'architecture'
    __inject__ = ['loss']

    def __init__(
            self,
            width,  # 40, number of channels output by the backbone
            num_joints,
            backbone='HRNet',
            loss='KeyPointRegressionMSELoss',
            post_process=TinyPose3DPostProcess):
        """
        Args:
            backbone (nn.Layer): backbone instance
            post_process (object): post-process instance
        """
        super(TinyPose3DHRHeatmapNet, self).__init__()

        self.backbone = backbone
        self.post_process = TinyPose3DPostProcess()
        self.loss = loss
        self.deploy = False
        self.num_joints = num_joints

        self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
        # for heatmap output
        self.final_conv_new = L.Conv2d(
            width, num_joints * 32, 1, 1, 0, bias=True)

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])
        return {'backbone': backbone, }

    def _forward(self):
        feats = self.backbone(self.inputs)  # feats: [[batch_size, 40, 32, 24]]

        hrnet_outputs = self.final_conv_new(feats[0])
        res = soft_argmax(hrnet_outputs, self.num_joints)

        if self.training:
            return self.loss(res, self.inputs)
        else:  # needed for model export
            return res

    def get_loss(self):
        return self._forward()

    def get_pred(self):
        res_lst = self._forward()
        outputs = {'keypoint': res_lst}
        return outputs

    def flip_back(self, output_flipped, matched_parts):
        assert output_flipped.ndim == 4,\
            'output_flipped should be [batch_size, num_joints, height, width]'

        output_flipped = output_flipped[:, :, :, ::-1]

        for pair in matched_parts:
            tmp = output_flipped[:, pair[0], :, :].copy()
            output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
            output_flipped[:, pair[1], :, :] = tmp

        return output_flipped


@register
class TinyPose3DHRNet(BaseArch):
    __category__ = 'architecture'
    __inject__ = ['loss']

    def __init__(self,
                 width,
                 num_joints,
                 backbone='HRNet',
                 loss='KeyPointRegressionMSELoss',
                 post_process=TinyPose3DPostProcess):
        """
        Args:
            backbone (nn.Layer): backbone instance
            post_process (object): post-process instance
        """
        super(TinyPose3DHRNet, self).__init__()

        self.backbone = backbone
        self.post_process = TinyPose3DPostProcess()
        self.loss = loss
        self.deploy = False
        self.num_joints = num_joints

        self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
        self.final_conv_new = L.Conv2d(
            width, num_joints * 32, 1, 1, 0, bias=True)

        self.flatten = paddle.nn.Flatten(start_axis=2, stop_axis=3)
        self.fc1 = paddle.nn.Linear(768, 256)
        self.act1 = paddle.nn.ReLU()
        self.fc2 = paddle.nn.Linear(256, 64)
        self.act2 = paddle.nn.ReLU()
        self.fc3 = paddle.nn.Linear(64, 3)

        # for Human3.6M
        self.fc1_1 = paddle.nn.Linear(3136, 1024)
        self.fc2_1 = paddle.nn.Linear(1024, 256)
        self.fc3_1 = paddle.nn.Linear(256, 3)

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])
        return {'backbone': backbone, }

    def _forward(self):
        feats = self.backbone(self.inputs)  # feats: [[batch_size, 40, 32, 24]]
        hrnet_outputs = self.final_conv(feats[0])

        flatten_res = self.flatten(
            hrnet_outputs)  # [batch_size, 24, (height/4)*(width/4)]
        res = self.fc1(flatten_res)
        res = self.act1(res)
        res = self.fc2(res)
        res = self.act2(res)
        res = self.fc3(res)  # [batch_size, 24, 3]

        if self.training:
            return self.loss(res, self.inputs)
        else:  # needed for model export
            return res

    def get_loss(self):
        return self._forward()

    def get_pred(self):
        res_lst = self._forward()
        outputs = {'keypoint': res_lst}
        return outputs

    def flip_back(self, output_flipped, matched_parts):
        assert output_flipped.ndim == 4,\
            'output_flipped should be [batch_size, num_joints, height, width]'

        output_flipped = output_flipped[:, :, :, ::-1]

        for pair in matched_parts:
            tmp = output_flipped[:, pair[0], :, :].copy()
            output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
            output_flipped[:, pair[1], :, :] = tmp

        return output_flipped
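soft_argmax turns per-joint 3D heatmaps into coordinates by taking a softmax-weighted expectation over bin indices, which keeps the whole head differentiable. The 1-based arange followed by the -1 shift mirrors the paddle code above; a minimal 1D illustration in numpy:

    import numpy as np

    def soft_argmax_1d(logits):
        # softmax-weighted expectation of the bin index: a differentiable argmax
        p = np.exp(logits - logits.max())
        p /= p.sum()
        bins = np.arange(1, len(logits) + 1)  # 1-based bins, as in soft_argmax
        return (p * bins).sum() - 1           # shift back to 0-based coordinates

    logits = np.array([0.0, 1.0, 8.0, 1.0, 0.0])
    print(soft_argmax_1d(logits))  # ~2.0: the peak bin, recovered smoothly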
@@ -854,6 +854,11 @@ class LiteHRNet(nn.Layer):
     def forward(self, inputs):
         x = inputs['image']
+        dims = x.shape
+        if len(dims) == 5:  # multi-frame input: fold frames into the batch dim
+            x = paddle.reshape(x, (dims[0] * dims[1], dims[2], dims[3],
+                                   dims[4]))  # e.g. [6, 3, 128, 96]
+
         x = self.stem(x)
         y_list = [x]
         for stage_idx in range(3):
...
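The reshape above folds the frame dimension into the batch so the unchanged 2D backbone processes every frame of every clip. A toy shape check in numpy, with made-up sizes:

    import numpy as np

    x = np.zeros((2, 6, 3, 128, 96), dtype=np.float32)  # 2 clips of 6 frames
    n, t, c, h, w = x.shape
    x = x.reshape((n * t, c, h, w))  # [N, T, C, H, W] -> [N*T, C, H, W]
    print(x.shape)  # (12, 3, 128, 96)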