未验证 提交 0bf1c25c 编写于 作者: Z zhiboniu 提交者: GitHub

add metro3d config file (#7703)

* best mpjpe 55

* rename configfile

* replace api
上级 a65904d8
use_gpu: True
log_iter: 20
save_dir: output
snapshot_epoch: 3
weights: output/metro_modified/model_final
epoch: 50
metric: Pose3DEval
num_classes: 1
train_height: &train_height 224
train_width: &train_width 224
trainsize: &trainsize [*train_width, *train_height]
num_joints: &num_joints 24
#####model
architecture: METRO_Body
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams
METRO_Body:
backbone: HRNet
trans_encoder: TransEncoder
num_joints: *num_joints
loss: Pose3DLoss
HRNet:
width: 32
freeze_at: -1
freeze_norm: False
norm_momentum: 0.1
downsample: True
TransEncoder:
vocab_size: 30522
num_hidden_layers: 4
num_attention_heads: 4
position_embeddings_size: 512
intermediate_size: 3072
input_feat_dim: [2048, 512, 128]
hidden_feat_dim: [1024, 256, 128]
attention_probs_dropout_prob: 0.1
fc_dropout_prob: 0.1
act_fn: 'gelu'
output_attentions: False
output_hidden_feats: False
Pose3DLoss:
weight_3d: 1.0
weight_2d: 0.0
#####optimizer
LearningRate:
base_lr: 0.0001
schedulers:
- !CosineDecay
max_epochs: 52
- !LinearWarmup
start_factor: 0.01
steps: 2000
OptimizerBuilder:
clip_grad_by_norm: 0.2
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!Pose3DDataset
dataset_dir: dataset/traindata/
image_dirs: ["human3.6m", "posetrack3d", "hr-lspet", "hr-lspet", "mpii/images", "coco/train2017"]
anno_list: ["pose3d/Human3.6m_train.json", "pose3d/PoseTrack_ver01.json", "pose3d/LSPet_train_ver10.json", "pose3d/LSPet_test_ver10.json", "pose3d/MPII_ver01.json", "pose3d/COCO2014-All-ver01.json"]
num_joints: *num_joints
test_mode: False
EvalDataset:
!Pose3DDataset
dataset_dir: dataset/traindata/
image_dirs: ["human3.6m"]
anno_list: ["pose3d/Human3.6m_valid.json"]
num_joints: *num_joints
test_mode: True
TestDataset:
!ImageFolder
anno_path: dataset/traindata/coco/keypoint_imagelist.txt
worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- SinglePoseAffine:
trainsize: *trainsize
rotate: [1.0, 30] #[prob, rotate range]
scale: [1.0, 0.25] #[prob, scale range]
- FlipPose:
flip_prob: 0.5
img_res: *train_width
num_joints: *num_joints
- NoiseJitter:
noise_factor: 0.4
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- SinglePoseAffine:
trainsize: *trainsize
rotate: [0., 30]
scale: [0., 0.25]
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
shuffle: false
drop_last: false
TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
fuse_normalize: false #whether to fuse nomalize layer into model while export model
...@@ -77,10 +77,12 @@ class Pose3DDataset(DetDataset): ...@@ -77,10 +77,12 @@ class Pose3DDataset(DetDataset):
indices = np.random.choice( indices = np.random.choice(
np.arange(num_joints), replace=False, size=masked_num) np.arange(num_joints), replace=False, size=masked_num)
mjm_mask[indices, :] = 0.0 mjm_mask[indices, :] = 0.0
# return mjm_mask
mvm_mask = np.ones((10, 1)).astype(np.float32) num_joints = 1
mvm_mask = np.ones((num_joints, 1)).astype(np.float)
if self.test_mode == False: if self.test_mode == False:
num_vertices = 10 num_vertices = num_joints
pb = np.random.random_sample() pb = np.random.random_sample()
masked_num = int( masked_num = int(
pb * mvm_percent * pb * mvm_percent *
...@@ -108,6 +110,7 @@ class Pose3DDataset(DetDataset): ...@@ -108,6 +110,7 @@ class Pose3DDataset(DetDataset):
print("Loading annotations..., please wait") print("Loading annotations..., please wait")
self.annos = [] self.annos = []
im_id = 0 im_id = 0
self.human36m_num = 0
for idx, annof in enumerate(self.anno_list): for idx, annof in enumerate(self.anno_list):
img_prefix = os.path.join(self.dataset_dir, self.image_dirs[idx]) img_prefix = os.path.join(self.dataset_dir, self.image_dirs[idx])
dataf = os.path.join(self.dataset_dir, annof) dataf = os.path.join(self.dataset_dir, annof)
...@@ -138,6 +141,8 @@ class Pose3DDataset(DetDataset): ...@@ -138,6 +141,8 @@ class Pose3DDataset(DetDataset):
print("cannot find imagepath:{}".format(imagename)) print("cannot find imagepath:{}".format(imagename))
continue continue
new_anno['imageName'] = imagename new_anno['imageName'] = imagename
if 'human3.6m' in imagename:
self.human36m_num += 1
new_anno['bbox_center'] = anno['bbox_center'] new_anno['bbox_center'] = anno['bbox_center']
new_anno['bbox_scale'] = anno['bbox_scale'] new_anno['bbox_scale'] = anno['bbox_scale']
new_anno['joints_2d'] = np.array(anno[ new_anno['joints_2d'] = np.array(anno[
...@@ -160,6 +165,10 @@ class Pose3DDataset(DetDataset): ...@@ -160,6 +165,10 @@ class Pose3DDataset(DetDataset):
self.annos.append(new_anno) self.annos.append(new_anno)
del annos del annos
def get_temp_num(self):
"""get temporal data number, like human3.6m"""
return self.human36m_num
def __len__(self): def __len__(self):
"""Get dataset length.""" """Get dataset length."""
return len(self.annos) return len(self.annos)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle.distributed import ParallelEnv
import os import os
import json import json
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
...@@ -161,8 +162,10 @@ class Pose3DEval(object): ...@@ -161,8 +162,10 @@ class Pose3DEval(object):
return paddle.index_select(input, J24_TO_J14, axis=1) return paddle.index_select(input, J24_TO_J14, axis=1)
def update(self, inputs, outputs): def update(self, inputs, outputs):
gt_3d_joints = all_gather(inputs['joints_3d']) gt_3d_joints = all_gather(inputs['joints_3d'].cuda(ParallelEnv()
has_3d_joints = all_gather(inputs['has_3d_joints']) .local_rank))
has_3d_joints = all_gather(inputs['has_3d_joints'].cuda(ParallelEnv()
.local_rank))
pred_3d_joints = all_gather(outputs['pose3d']) pred_3d_joints = all_gather(outputs['pose3d'])
if gt_3d_joints.shape[1] == 24: if gt_3d_joints.shape[1] == 24:
gt_3d_joints = self.get_human36m_joints(gt_3d_joints) gt_3d_joints = self.get_human36m_joints(gt_3d_joints)
......
...@@ -65,10 +65,8 @@ class METRO_Body(BaseArch): ...@@ -65,10 +65,8 @@ class METRO_Body(BaseArch):
self.deploy = False self.deploy = False
self.trans_encoder = trans_encoder self.trans_encoder = trans_encoder
self.conv_learn_tokens = paddle.nn.Conv1D(49, 10 + num_joints, 1) self.conv_learn_tokens = paddle.nn.Conv1D(49, num_joints + 1, 1)
self.cam_param_fc = paddle.nn.Linear(3, 1) self.cam_param_fc = paddle.nn.Linear(3, 2)
self.cam_param_fc2 = paddle.nn.Linear(10, 250)
self.cam_param_fc3 = paddle.nn.Linear(250, 3)
@classmethod @classmethod
def from_config(cls, cfg, *args, **kwargs): def from_config(cls, cfg, *args, **kwargs):
...@@ -85,7 +83,7 @@ class METRO_Body(BaseArch): ...@@ -85,7 +83,7 @@ class METRO_Body(BaseArch):
image_feat_flatten = image_feat.reshape((batch_size, 2048, 49)) image_feat_flatten = image_feat.reshape((batch_size, 2048, 49))
image_feat_flatten = image_feat_flatten.transpose(perm=(0, 2, 1)) image_feat_flatten = image_feat_flatten.transpose(perm=(0, 2, 1))
# and apply a conv layer to learn image token for each 3d joint/vertex position # and apply a conv layer to learn image token for each 3d joint/vertex position
features = self.conv_learn_tokens(image_feat_flatten) features = self.conv_learn_tokens(image_feat_flatten) # (B, J, C)
if self.training: if self.training:
# apply mask vertex/joint modeling # apply mask vertex/joint modeling
...@@ -95,20 +93,13 @@ class METRO_Body(BaseArch): ...@@ -95,20 +93,13 @@ class METRO_Body(BaseArch):
constant_tensor = paddle.ones_like(features) * 0.01 constant_tensor = paddle.ones_like(features) * 0.01
features = features * meta_masks + constant_tensor * (1 - meta_masks features = features * meta_masks + constant_tensor * (1 - meta_masks
) )
pred_out = self.trans_encoder(features) pred_out = self.trans_encoder(features)
pred_3d_joints = pred_out[:, :self.num_joints, :] pred_3d_joints = pred_out[:, :self.num_joints, :]
cam_features = pred_out[:, self.num_joints:, :] cam_features = pred_out[:, self.num_joints:, :]
# learn camera parameters # learn camera parameters
x = self.cam_param_fc(cam_features) pred_2d_joints = self.cam_param_fc(cam_features)
x = x.transpose(perm=(0, 2, 1))
x = self.cam_param_fc2(x)
x = self.cam_param_fc3(x)
cam_param = x.transpose(perm=(0, 2, 1))
pred_camera = cam_param.squeeze()
pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
return pred_3d_joints, pred_2d_joints return pred_3d_joints, pred_2d_joints
def get_loss(self): def get_loss(self):
......
...@@ -20,8 +20,11 @@ from itertools import cycle, islice ...@@ -20,8 +20,11 @@ from itertools import cycle, islice
from collections import abc from collections import abc
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')
__all__ = ['Pose3DLoss'] __all__ = ['Pose3DLoss']
...@@ -42,7 +45,7 @@ class Pose3DLoss(nn.Layer): ...@@ -42,7 +45,7 @@ class Pose3DLoss(nn.Layer):
self.weight_3d = weight_3d self.weight_3d = weight_3d
self.weight_2d = weight_2d self.weight_2d = weight_2d
self.criterion_2dpose = nn.MSELoss(reduction=reduction) self.criterion_2dpose = nn.MSELoss(reduction=reduction)
self.criterion_3dpose = nn.MSELoss(reduction=reduction) self.criterion_3dpose = nn.L1Loss(reduction=reduction)
self.criterion_smoothl1 = nn.SmoothL1Loss( self.criterion_smoothl1 = nn.SmoothL1Loss(
reduction=reduction, delta=1.0) reduction=reduction, delta=1.0)
self.criterion_vertices = nn.L1Loss() self.criterion_vertices = nn.L1Loss()
...@@ -57,10 +60,17 @@ class Pose3DLoss(nn.Layer): ...@@ -57,10 +60,17 @@ class Pose3DLoss(nn.Layer):
has_3d_joints = inputs['has_3d_joints'] has_3d_joints = inputs['has_3d_joints']
has_2d_joints = inputs['has_2d_joints'] has_2d_joints = inputs['has_2d_joints']
loss_3d = mpjpe(pred3d, gt_3d_joints, has_3d_joints) loss_3d = mpjpe_focal(pred3d, gt_3d_joints, has_3d_joints)
loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d, gt_2d_joints, loss = self.weight_3d * loss_3d
has_2d_joints) epoch = inputs['epoch_id']
return self.weight_3d * loss_3d + self.weight_2d * loss_2d if self.weight_2d > 0:
weight = self.weight_2d * pow(0.1, (epoch // 8))
if epoch > 8:
weight = 0
loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d,
gt_2d_joints, has_2d_joints)
loss += weight * loss_2d
return loss
def filter_3d_joints(pred, gt, has_3d_joints): def filter_3d_joints(pred, gt, has_3d_joints):
...@@ -78,25 +88,45 @@ def filter_3d_joints(pred, gt, has_3d_joints): ...@@ -78,25 +88,45 @@ def filter_3d_joints(pred, gt, has_3d_joints):
return pred, gt return pred, gt
@register
@serializable
def mpjpe(pred, gt, has_3d_joints): def mpjpe(pred, gt, has_3d_joints):
""" """
mPJPE loss mPJPE loss
""" """
pred, gt = filter_3d_joints(pred, gt, has_3d_joints) pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean() error = paddle.sqrt((paddle.minimum((pred - gt), paddle.to_tensor(1.2))**2
).sum(axis=-1)).mean()
return error
def mpjpe_focal(pred, gt, has_3d_joints):
"""
mPJPE loss
"""
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
mse_error = ((pred - gt)**2).sum(axis=-1)
mpjpe_error = paddle.sqrt(mse_error)
mean = mpjpe_error.mean()
std = mpjpe_error.std()
atte = 2 * F.sigmoid(6 * (mpjpe_error - mean) / std)
mse_error *= atte
return mse_error.mean()
def mpjpe_mse(pred, gt, has_3d_joints, weight=1.):
"""
mPJPE loss
"""
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
error = (((pred - gt)**2).sum(axis=-1)).mean()
return error return error
@register
@serializable
def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d): def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d):
""" """
mPJPE loss of self define criterion mPJPE loss of self define criterion
""" """
pred, gt = filter_3d_joints(pred, gt, has_3d_joints) pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
error = paddle.sqrt(criterion_pose3d(pred, gt).sum(axis=-1)).mean() error = paddle.sqrt(criterion_pose3d(pred, gt)).mean()
return error return error
...@@ -165,8 +195,8 @@ def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, ...@@ -165,8 +195,8 @@ def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d,
The confidence (conf) is binary and indicates whether the keypoints exist or not. The confidence (conf) is binary and indicates whether the keypoints exist or not.
""" """
conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
loss = (conf * criterion_keypoints(pred_keypoints_2d, loss = (conf * criterion_keypoints(
gt_keypoints_2d[:, :, :-1])).mean() pred_keypoints_2d, gt_keypoints_2d[:, :, :-1] * 0.001)).mean()
return loss return loss
......
...@@ -50,7 +50,8 @@ def visualize_results(image, ...@@ -50,7 +50,8 @@ def visualize_results(image,
if keypoint_res is not None: if keypoint_res is not None:
image = draw_pose(image, keypoint_res, threshold) image = draw_pose(image, keypoint_res, threshold)
if pose3d_res is not None: if pose3d_res is not None:
image = draw_pose3d(image, pose3d_res, threshold) pose3d = np.array(pose3d_res[0]['pose3d']) * 1000
image = draw_pose3d(image, pose3d, visual_thread=threshold)
return image return image
...@@ -325,12 +326,11 @@ def draw_pose(image, ...@@ -325,12 +326,11 @@ def draw_pose(image,
def draw_pose3d(image, def draw_pose3d(image,
results, pose3d,
pose2d=None,
visual_thread=0.6, visual_thread=0.6,
save_name='pose3d.jpg', save_name='pose3d.jpg',
save_dir='output', returnimg=True):
returnimg=False,
ids=None):
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib import matplotlib
...@@ -339,12 +339,11 @@ def draw_pose3d(image, ...@@ -339,12 +339,11 @@ def draw_pose3d(image,
logger.error('Matplotlib not found, please install matplotlib.' logger.error('Matplotlib not found, please install matplotlib.'
'for example: `pip install matplotlib`.') 'for example: `pip install matplotlib`.')
raise e raise e
pose3d = np.array(results[0]['pose3d']) * 1000
if pose3d.shape[0] == 24: if pose3d.shape[0] == 24:
joints_connectivity_dict = [ joints_connectivity_dict = [
[0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 14, 1], [0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 14, 1],
[3, 14, 1], [14, 15, 1], [15, 16, 1], [16, 12, 1], [6, 7, 0], [3, 14, 1], [14, 16, 1], [15, 16, 1], [15, 12, 1], [6, 7, 0],
[7, 8, 0], [11, 10, 1], [10, 9, 1], [8, 12, 0], [9, 12, 1], [7, 8, 0], [11, 10, 1], [10, 9, 1], [8, 12, 0], [9, 12, 1],
[12, 19, 1], [19, 18, 1], [19, 20, 0], [19, 21, 1], [22, 20, 0], [12, 19, 1], [19, 18, 1], [19, 20, 0], [19, 21, 1], [22, 20, 0],
[23, 21, 1] [23, 21, 1]
...@@ -450,6 +449,9 @@ def draw_pose3d(image, ...@@ -450,6 +449,9 @@ def draw_pose3d(image,
image = Image.frombytes("RGBA", (w, h), buf.tostring()) image = Image.frombytes("RGBA", (w, h), buf.tostring())
return image.convert("RGB") return image.convert("RGB")
fig = draw_img_pose(pose3d, frame=image) fig = draw_img_pose(pose3d, pose2d, frame=image)
data = fig2data(fig) data = fig2data(fig)
if returnimg is False:
data.save(save_name)
else:
return data return data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册