未验证 提交 c9823094 编写于 作者: Z zhiboniu 提交者: GitHub

pose3d metro datasets part (#6611)

* pose3d metro datasets

* delete extra comment lines
上级 7b6bdf91
......@@ -28,3 +28,4 @@ from .keypoint_coco import *
from .mot import *
from .sniper_coco import SniperCOCODataSet
from .dataset import ImageFolder
from .pose3d_cmb import Pose3DDataset
......@@ -118,6 +118,9 @@ def get_categories(metric_type, anno_file=None, arch=None):
) == 'keypointtopdownmpiieval':
return (None, {'id': 'keypoint'})
elif metric_type.lower() == 'pose3deval':
return (None, {'id': 'pose3d'})
elif metric_type.lower() in ['mot', 'motdet', 'reid']:
if anno_file and os.path.isfile(anno_file):
cats = []
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/open-mmlab/mmpose
"""
import os
import cv2
import numpy as np
import json
import copy
import pycocotools
from pycocotools.coco import COCO
from .dataset import DetDataset
from ppdet.core.workspace import register, serializable
@serializable
class Pose3DDataset(DetDataset):
    """Pose3D Dataset class.

    Merges several JSON annotation files (each paired with an image
    directory) into one flat list of per-image records.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dirs (list of str): each path is a relative path (under
            ``dataset_dir``) where images are held. Entry ``i`` is used
            for the annotations of ``anno_list[i]``.
        anno_list (list of str): each element is a relative path (under
            ``dataset_dir``) to a JSON annotation file.
        transform (composed(operators)): A sequence of data transforms,
            called with the record dict in ``__getitem__``.
        num_joints (int): number of 2D joints kept per sample; one of
            24, 17 or 14. Default: 24.
        test_mode (bool): Store True when building test or
            validation dataset. Random masking in ``get_mask`` is only
            applied when this is False. Default: False.

    24 joints order:
        0-2: 'R_Ankle', 'R_Knee', 'R_Hip',
        3-5: 'L_Hip', 'L_Knee', 'L_Ankle',
        6-8: 'R_Wrist', 'R_Elbow', 'R_Shoulder',
        9-11: 'L_Shoulder', 'L_Elbow', 'L_Wrist',
        12-14: 'Neck', 'Top_of_Head', 'Pelvis',
        15-18: 'Thorax', 'Spine', 'Jaw', 'Head',
        19-23: 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear'
    """

    # NOTE(review): the ``transform=[]`` default is kept as-is for interface
    # compatibility; it is never mutated here, only stored and later called.
    def __init__(self,
                 dataset_dir,
                 image_dirs,
                 anno_list,
                 transform=[],
                 num_joints=24,
                 test_mode=False):
        super().__init__(dataset_dir, image_dirs, anno_list)
        self.image_info = {}
        self.ann_info = {}
        self.num_joints = num_joints
        self.transform = transform
        self.test_mode = test_mode
        self.img_ids = []
        self.dataset_dir = dataset_dir
        self.image_dirs = image_dirs
        self.anno_list = anno_list

    def _random_mask(self, length, percent):
        """Return a (length, 1) float mask of ones with random rows zeroed.

        Outside test mode, ``int(u * percent * length)`` rows are set to 0,
        where ``u`` is a fresh uniform sample from [0, 1); in test mode the
        mask is all ones.
        """
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        mask = np.ones((length, 1)).astype(np.float64)
        if not self.test_mode:
            pb = np.random.random_sample()
            # at most `percent` of the entries could be masked
            masked_num = int(pb * percent * length)
            indices = np.random.choice(
                np.arange(length), replace=False, size=masked_num)
            mask[indices, :] = 0.0
        return mask

    def get_mask(self, mvm_percent=0.3):
        """Build the random mask stored as ``mjm_mask`` in each record.

        Args:
            mvm_percent (float): upper bound on the masked fraction, used
                for both the joint part and the vertex part.

        Returns:
            np.ndarray: array of shape ``(num_joints + 10, 1)``; the first
            ``num_joints`` rows are the joint mask, the last 10 rows the
            vertex mask.
        """
        num_vertices = 10
        mjm_mask = self._random_mask(self.num_joints, mvm_percent)
        mvm_mask = self._random_mask(num_vertices, mvm_percent)
        return np.concatenate([mjm_mask, mvm_mask], axis=0)

    def filterjoints(self, x):
        """Select the rows of ``x`` that belong to the configured joint set.

        Args:
            x (np.ndarray): joints in the 24-joint order, one joint per row.

        Returns:
            np.ndarray: ``x`` itself for 24 joints, otherwise the selected
            14 or 17 rows.

        Raises:
            ValueError: if ``num_joints`` is not 24, 17 or 14.
        """
        if self.num_joints == 24:
            return x
        elif self.num_joints == 14:
            return x[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18], :]
        elif self.num_joints == 17:
            return x[
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19], :]
        else:
            raise ValueError(
                "unsupported joint numbers, only [24 or 17 or 14] is supported!")

    @staticmethod
    def _resolve_image_path(img_prefix, imagename):
        """Return the on-disk path for ``imagename`` or None if not found.

        Strips the COCO 2014 file-name prefixes, joins with ``img_prefix``
        and, for paths containing "train2017", falls back to the same path
        with "val2017" when the first candidate does not exist.
        """
        if imagename.startswith("COCO_train2014_"):
            imagename = imagename[len("COCO_train2014_"):]
        elif imagename.startswith("COCO_val2014_"):
            imagename = imagename[len("COCO_val2014_"):]
        imagename = os.path.join(img_prefix, imagename)
        if os.path.exists(imagename):
            return imagename
        if "train2017" in imagename:
            imagename = imagename.replace("train2017", "val2017")
            if os.path.exists(imagename):
                return imagename
        print("cannot find imagepath:{}".format(imagename))
        return None

    def parse_dataset(self):
        """Load every annotation file and fill ``self.annos``.

        Each kept record holds: ``im_id``, ``imageName`` (resolved path),
        ``bbox_center``, ``bbox_scale``, ``joints_2d``, ``joints_3d``,
        ``mjm_mask``, ``has_3d_joints`` and ``has_2d_joints``. Annotations
        whose image cannot be found are reported and skipped (their
        ``im_id`` is still consumed).
        """
        print("Loading annotations..., please wait")
        self.annos = []
        im_id = 0
        for idx, annof in enumerate(self.anno_list):
            img_prefix = os.path.join(self.dataset_dir, self.image_dirs[idx])
            dataf = os.path.join(self.dataset_dir, annof)
            with open(dataf, 'r') as rf:
                anno_data = json.load(rf)
                annos = anno_data['data']
                print("{} has annos numbers: {}".format(dataf, len(annos)))
                for anno in annos:
                    new_anno = {}
                    new_anno['im_id'] = im_id
                    im_id += 1
                    imagename = self._resolve_image_path(img_prefix,
                                                         anno['imageName'])
                    if imagename is None:
                        continue
                    new_anno['imageName'] = imagename
                    new_anno['bbox_center'] = anno['bbox_center']
                    new_anno['bbox_scale'] = anno['bbox_scale']
                    new_anno['joints_2d'] = np.array(anno[
                        'gt_keypoint_2d']).astype(np.float32)
                    if new_anno['joints_2d'].shape[0] == 49:
                        # if the joints_2d is in SPIN format (which is generated by eft),
                        # choose the last 24 public joints; for detail please refer:
                        # https://github.com/nkolot/SPIN/blob/master/constants.py
                        new_anno['joints_2d'] = new_anno['joints_2d'][25:]
                    new_anno['joints_3d'] = np.array(anno[
                        'pose3d'])[:, :3].astype(np.float32)
                    new_anno['mjm_mask'] = self.get_mask()
                    if 'has_3d_joints' not in anno:
                        # files without the flags are treated as fully annotated
                        new_anno['has_3d_joints'] = int(1)
                        new_anno['has_2d_joints'] = int(1)
                    else:
                        new_anno['has_3d_joints'] = int(anno['has_3d_joints'])
                        new_anno['has_2d_joints'] = int(anno['has_2d_joints'])
                    new_anno['joints_2d'] = self.filterjoints(new_anno[
                        'joints_2d'])
                    self.annos.append(new_anno)
                del annos

    def __len__(self):
        """Get dataset length."""
        return len(self.annos)

    def _get_imganno(self, idx):
        """Get anno for a single image."""
        return self.annos[idx]

    def __getitem__(self, idx):
        """Prepare image for training given the index.

        Works on a deep copy of the stored record, loads the image as RGB
        into ``records['image']`` and returns the result of
        ``self.transform(records)``.

        Raises:
            ValueError: if the image file exists but cannot be decoded.
        """
        records = copy.deepcopy(self._get_imganno(idx))
        imgpath = records['imageName']
        assert os.path.exists(imgpath), "cannot find image {}".format(imgpath)
        image = cv2.imread(imgpath)
        if image is None:
            # cv2.imread signals a decode failure by returning None
            raise ValueError("cannot read image {}".format(imgpath))
        records['image'] = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        records = self.transform(records)
        return records

    def check_or_download_dataset(self):
        """Verify that every image directory exists under ``dataset_dir``.

        Raises:
            ValueError: if at least one image directory is missing; every
            missing directory is printed first.
        """
        alldatafind = True
        for image_dir in self.image_dirs:
            image_dir = os.path.join(self.dataset_dir, image_dir)
            if not os.path.isdir(image_dir):
                print("dataset [{}] is not found".format(image_dir))
                alldatafind = False
        if not alldatafind:
            raise ValueError(
                "Some dataset is not valid and cannot download automatically now, please prepare the dataset first"
            )
......@@ -36,19 +36,12 @@ logger = setup_logger(__name__)
registered_ops = []

# Public keypoint operators exported by this module. Each name is listed
# exactly once (the previous listing repeated every pre-existing entry).
__all__ = [
    'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
    'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
    'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK',
    'ToHeatmapsTopDown_UDP', 'TopDownEvalAffine',
    'AugmentationbyInformantionDropping', 'SinglePoseAffine', 'NoiseJitter',
    'FlipPose'
]
......@@ -618,6 +611,169 @@ class TopDownAffine(object):
return records
@register_keypointop
class SinglePoseAffine(object):
    """Warp the image to the training size and move the joints with it.

    A random rotation and a random scale factor are drawn per call, the
    image is warped with the resulting affine matrix, the 2D joints are
    mapped through the same matrix and the 3D joints are rotated in the
    x-y plane by the same angle.

    Args:
        trainsize (list): [w, h], the standard size used to train
        rotate (list): [probability, range]; with the given probability the
            angle (degrees) is ``randn() * range`` clipped to ``+-2 * range``.
        scale (list): [probability, ratio]; with the given probability the
            scale factor is ``randn() * ratio + 1`` clipped to
            ``[1 - ratio, 1 + ratio]``.
        use_udp (bool): whether to use Unbiased Data Processing.
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after being transformed
    """

    def __init__(self,
                 trainsize,
                 rotate=[1.0, 30],
                 scale=[1.0, 0.25],
                 use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp
        self.rot_prob, self.rot_range = rotate[0], rotate[1]
        self.scale_prob, self.scale_ratio = scale[0], scale[1]

    def __call__(self, records):
        has_joints_2d = 'joints_2d' in records
        src_image = records['image']
        if has_joints_2d:
            joints = records['joints_2d']
            if 'joints_vis' in records:
                joints_vis = records['joints_vis']
            else:
                # without visibility info every joint is treated as visible
                joints_vis = np.ones((len(joints), 1))

        # Draw the augmentation parameters; the order of the random calls
        # (rotation first, then scale) is part of the reproducible behavior.
        rot = 0
        s = 1.
        if np.random.random() < self.rot_prob:
            rot_limit = self.rot_range * 2
            rot = np.clip(np.random.randn() * self.rot_range, -rot_limit,
                          rot_limit)
        if np.random.random() < self.scale_prob:
            s = np.clip(np.random.randn() * self.scale_ratio + 1,
                        1 - self.scale_ratio, 1 + self.scale_ratio)

        # Build the 2x3 warp matrix with the selected processing scheme.
        if self.use_udp:
            trans = get_warp_matrix(
                rot,
                np.array(records['bbox_center']) * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0],
                records['bbox_scale'] * 200.0 * s)
        else:
            trans = get_affine_transform(
                np.array(records['bbox_center']),
                records['bbox_scale'] * s * 200, rot, self.trainsize)

        out_size = (int(self.trainsize[0]), int(self.trainsize[1]))
        warped = cv2.warpAffine(
            src_image, trans, out_size, flags=cv2.INTER_LINEAR)

        if has_joints_2d:
            if self.use_udp:
                joints[:, 0:2] = warp_affine_joints(joints[:, 0:2].copy(),
                                                    trans)
            else:
                # only joints flagged visible are moved in this branch
                for idx in range(len(joints)):
                    if joints_vis[idx, 0] > 0.0:
                        joints[idx, 0:2] = affine_transform(joints[idx, 0:2],
                                                            trans)

        if 'joints_3d' in records:
            pose3d = records['joints_3d']
            if rot != 0:
                # rotate x/y by -rot degrees; the third axis is left untouched
                angle = -rot * np.pi / 180
                sin_a, cos_a = np.sin(angle), np.cos(angle)
                rot_mat = np.eye(3)
                rot_mat[0, :2] = [cos_a, -sin_a]
                rot_mat[1, :2] = [sin_a, cos_a]
                pose3d[:, :3] = np.einsum('ij,kj->ki', rot_mat,
                                          pose3d[:, :3])
            records['joints_3d'] = pose3d

        records['image'] = warped
        if has_joints_2d:
            records['joints_2d'] = joints
        return records
@register_keypointop
class NoiseJitter(object):
    """Scale each of the three image channels by its own random gain.

    Args:
        noise_factor (float): half-width of the gain interval; every call
            draws three gains uniformly from
            ``[1 - noise_factor, 1 + noise_factor)``.

    Returns:
        records (dict): the same dict, with ``records['image']`` jittered
            in place and limited to the range [0, 255].
    """

    def __init__(self, noise_factor=0.4):
        self.noise_factor = noise_factor

    def __call__(self, records):
        low = 1 - self.noise_factor
        high = 1 + self.noise_factor
        # the gains of the latest call stay available on the instance
        self.pn = np.random.uniform(low, high, 3)
        picture = records['image']
        for channel, gain in enumerate(self.pn):
            # written back into the same array, so the image dtype is kept
            picture[:, :, channel] = np.clip(picture[:, :, channel] * gain,
                                             0.0, 255.0)
        records['image'] = picture
        return records
@register_keypointop
class FlipPose(object):
    """Randomly flip the image and its joints horizontally.

    When a flip is applied, the image is mirrored left-right, the joints
    are re-ordered so that left/right joints swap places, the 2D x
    coordinate becomes ``img_res - x`` and the 3D x coordinate is negated.

    Args:
        flip_prob (float): probability of applying the flip on each call.
        img_res (int): image width used to mirror the 2D x coordinate.
        num_joints (int): number of joints per sample; 24, 17 and 14 are
            supported (same joint subsets as ``Pose3DDataset.filterjoints``).

    Returns:
        records (dict): contains the image and coords after being transformed
    """

    # Left/right swap tables, indexed by joint count. In every layout
    # rows 0-5 are the legs and 6-11 the arms (right side first, then the
    # mirrored left side); the remaining rows are mid-line joints, except
    # the eye/ear pairs (20<->21, 22<->23) of the 24-joint layout.
    _FLIP_PERMS = {
        24: [
            5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17,
            18, 19, 21, 20, 23, 22
        ],
        17: [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16],
        14: [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13],
    }

    def __init__(self, flip_prob=0.5, img_res=224, num_joints=14):
        # attribute name (including its spelling) kept for compatibility
        self.flip_pob = flip_prob
        self.img_res = img_res
        self.num_joints = num_joints
        perm = self._FLIP_PERMS.get(num_joints)
        # None marks an unsupported joint count; image-only flips still work
        self.perm = list(perm) if perm is not None else None
        if self.perm is None:
            print("error num_joints in flip :{}".format(num_joints))

    def _joint_perm(self):
        """Return the swap table or raise if the joint count is unsupported."""
        if self.perm is None:
            raise ValueError(
                "FlipPose cannot flip joints: unsupported num_joints {}, "
                "only [24 or 17 or 14] is supported!".format(self.num_joints))
        return self.perm

    def __call__(self, records):
        if np.random.random() < self.flip_pob:
            img = np.fliplr(records['image'])
            if 'joints_2d' in records:
                # fancy indexing copies, so the input array is not modified
                joints_2d = records['joints_2d'][self._joint_perm()]
                joints_2d[:, 0] = self.img_res - joints_2d[:, 0]
                records['joints_2d'] = joints_2d
            if 'joints_3d' in records:
                joints_3d = records['joints_3d'][self._joint_perm()]
                joints_3d[:, 0] = -joints_3d[:, 0]
                records['joints_3d'] = joints_3d
            records['image'] = img
        return records
@register_keypointop
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册