未验证 提交 d4e34fe1 编写于 作者: Z zhiboniu 提交者: GitHub

pose3d metro modeling (#6612)

* pose3d metro modeling

* delete extra comments
上级 c9823094
......@@ -62,3 +62,4 @@ from .tood import *
from .retinanet import *
from .bytetrack import *
from .yolox import *
from .pose3d_metro import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
from .. import layers as L
__all__ = ['METRO_Body']
def orthographic_projection(X, camera):
"""Perform orthographic projection of 3D points X using the camera parameters
Args:
X: size = [B, N, 3]
camera: size = [B, 3]
Returns:
Projected 2D points -- size = [B, N, 2]
"""
camera = camera.reshape((-1, 1, 3))
X_trans = X[:, :, :2] + camera[:, :, 1:]
shape = paddle.shape(X_trans)
X_2d = (camera[:, :, 0] * X_trans.reshape((shape[0], -1))).reshape(shape)
return X_2d
@register
class METRO_Body(BaseArch):
__category__ = 'architecture'
__inject__ = ['loss']
def __init__(
self,
num_joints,
backbone='HRNet',
trans_encoder='',
loss='Pose3DLoss', ):
"""
METRO network, see https://arxiv.org/abs/
Args:
backbone (nn.Layer): backbone instance
"""
super(METRO_Body, self).__init__()
self.num_joints = num_joints
self.backbone = backbone
self.loss = loss
self.deploy = False
self.trans_encoder = trans_encoder
self.conv_learn_tokens = paddle.nn.Conv1D(49, 10 + num_joints, 1)
self.cam_param_fc = paddle.nn.Linear(3, 1)
self.cam_param_fc2 = paddle.nn.Linear(10, 250)
self.cam_param_fc3 = paddle.nn.Linear(250, 3)
@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
trans_encoder = create(cfg['trans_encoder'])
return {'backbone': backbone, 'trans_encoder': trans_encoder}
def _forward(self):
batch_size = self.inputs['image'].shape[0]
image_feat = self.backbone(self.inputs)
image_feat_flatten = image_feat.reshape((batch_size, 2048, 49))
image_feat_flatten = image_feat_flatten.transpose(perm=(0, 2, 1))
# and apply a conv layer to learn image token for each 3d joint/vertex position
features = self.conv_learn_tokens(image_feat_flatten)
if self.training:
# apply mask vertex/joint modeling
# meta_masks is a tensor of all the masks, randomly generated in dataloader
# we pre-define a [MASK] token, which is a floating-value vector with 0.01s
meta_masks = self.inputs['mjm_mask'].expand((-1, -1, 2048))
constant_tensor = paddle.ones_like(features) * 0.01
features = features * meta_masks + constant_tensor * (1 - meta_masks
)
pred_out = self.trans_encoder(features)
pred_3d_joints = pred_out[:, :self.num_joints, :]
cam_features = pred_out[:, self.num_joints:, :]
# learn camera parameters
x = self.cam_param_fc(cam_features)
x = x.transpose(perm=(0, 2, 1))
x = self.cam_param_fc2(x)
x = self.cam_param_fc3(x)
cam_param = x.transpose(perm=(0, 2, 1))
pred_camera = cam_param.squeeze()
pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
return pred_3d_joints, pred_2d_joints
def get_loss(self):
preds_3d, preds_2d = self._forward()
loss = self.loss(preds_3d, preds_2d, self.inputs)
output = {'loss': loss}
return output
def get_pred(self):
preds_3d, preds_2d = self._forward()
outputs = {'pose3d': preds_3d, 'pose2d': preds_2d}
return outputs
......@@ -58,3 +58,4 @@ from .convnext import *
from .vision_transformer import *
from .vision_transformer import *
from .mobileone import *
from .trans_encoder import *
......@@ -37,6 +37,7 @@ class ConvNormLayer(nn.Layer):
norm_type='bn',
norm_groups=32,
use_dcn=False,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=False,
act=None,
......@@ -66,6 +67,7 @@ class ConvNormLayer(nn.Layer):
if norm_type in ['bn', 'sync_bn']:
self.norm = nn.BatchNorm2D(
ch_out,
momentum=norm_momentum,
weight_attr=param_attr,
bias_attr=bias_attr,
use_global_stats=global_stats)
......@@ -93,6 +95,7 @@ class Layer1(nn.Layer):
def __init__(self,
num_channels,
has_se=False,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -109,6 +112,7 @@ class Layer1(nn.Layer):
has_se=has_se,
stride=1,
downsample=True if i == 0 else False,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + '_' + str(i + 1)))
......@@ -125,6 +129,7 @@ class TransitionLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -144,6 +149,7 @@ class TransitionLayer(nn.Layer):
ch_in=in_channels[i],
ch_out=out_channels[i],
filter_size=3,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act='relu',
......@@ -156,6 +162,7 @@ class TransitionLayer(nn.Layer):
ch_out=out_channels[i],
filter_size=3,
stride=2,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act='relu',
......@@ -181,6 +188,7 @@ class Branches(nn.Layer):
in_channels,
out_channels,
has_se=False,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -197,6 +205,7 @@ class Branches(nn.Layer):
num_channels=in_ch,
num_filters=out_channels[i],
has_se=has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + '_branch_layer_' + str(i + 1) + '_' +
......@@ -221,6 +230,7 @@ class BottleneckBlock(nn.Layer):
has_se,
stride=1,
downsample=False,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -233,6 +243,7 @@ class BottleneckBlock(nn.Layer):
ch_in=num_channels,
ch_out=num_filters,
filter_size=1,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act="relu",
......@@ -242,6 +253,7 @@ class BottleneckBlock(nn.Layer):
ch_out=num_filters,
filter_size=3,
stride=stride,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act="relu",
......@@ -250,6 +262,7 @@ class BottleneckBlock(nn.Layer):
ch_in=num_filters,
ch_out=num_filters * 4,
filter_size=1,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act=None,
......@@ -260,6 +273,7 @@ class BottleneckBlock(nn.Layer):
ch_in=num_channels,
ch_out=num_filters * 4,
filter_size=1,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act=None,
......@@ -296,6 +310,7 @@ class BasicBlock(nn.Layer):
stride=1,
has_se=False,
downsample=False,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -307,6 +322,7 @@ class BasicBlock(nn.Layer):
ch_in=num_channels,
ch_out=num_filters,
filter_size=3,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
stride=stride,
......@@ -316,6 +332,7 @@ class BasicBlock(nn.Layer):
ch_in=num_filters,
ch_out=num_filters,
filter_size=3,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
stride=1,
......@@ -327,6 +344,7 @@ class BasicBlock(nn.Layer):
ch_in=num_channels,
ch_out=num_filters * 4,
filter_size=1,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act=None,
......@@ -394,6 +412,7 @@ class Stage(nn.Layer):
num_modules,
num_filters,
has_se=False,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
multi_scale_output=True,
......@@ -410,6 +429,7 @@ class Stage(nn.Layer):
num_channels=num_channels,
num_filters=num_filters,
has_se=has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
multi_scale_output=False,
......@@ -421,6 +441,7 @@ class Stage(nn.Layer):
num_channels=num_channels,
num_filters=num_filters,
has_se=has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + '_' + str(i + 1)))
......@@ -440,6 +461,7 @@ class HighResolutionModule(nn.Layer):
num_filters,
has_se=False,
multi_scale_output=True,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -449,6 +471,7 @@ class HighResolutionModule(nn.Layer):
in_channels=num_channels,
out_channels=num_filters,
has_se=has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name)
......@@ -457,6 +480,7 @@ class HighResolutionModule(nn.Layer):
in_channels=num_filters,
out_channels=num_filters,
multi_scale_output=multi_scale_output,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name)
......@@ -472,6 +496,7 @@ class FuseLayers(nn.Layer):
in_channels,
out_channels,
multi_scale_output=True,
norm_momentum=0.9,
norm_decay=0.,
freeze_norm=True,
name=None):
......@@ -493,6 +518,7 @@ class FuseLayers(nn.Layer):
filter_size=1,
stride=1,
act=None,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + '_layer_' + str(i + 1) + '_' +
......@@ -510,6 +536,7 @@ class FuseLayers(nn.Layer):
ch_out=out_channels[i],
filter_size=3,
stride=2,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act=None,
......@@ -525,6 +552,7 @@ class FuseLayers(nn.Layer):
ch_out=out_channels[j],
filter_size=3,
stride=2,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act="relu",
......@@ -549,7 +577,6 @@ class FuseLayers(nn.Layer):
for k in range(i - j):
y = self.residual_func_list[residual_func_idx](y)
residual_func_idx += 1
residual = paddle.add(x=residual, y=y)
residual = F.relu(residual)
outs.append(residual)
......@@ -567,6 +594,7 @@ class HRNet(nn.Layer):
has_se (bool): whether to add SE block for each stage
freeze_at (int): the stage to freeze
freeze_norm (bool): whether to freeze norm in HRNet
norm_momentum (float): momentum of BatchNorm
norm_decay (float): weight decay for normalization layer weights
return_idx (List): the stage to return
upsample (bool): whether to upsample and concat the backbone feats
......@@ -577,9 +605,11 @@ class HRNet(nn.Layer):
has_se=False,
freeze_at=0,
freeze_norm=True,
norm_momentum=0.9,
norm_decay=0.,
return_idx=[0, 1, 2, 3],
upsample=False):
upsample=False,
downsample=False):
super(HRNet, self).__init__()
self.width = width
......@@ -591,6 +621,7 @@ class HRNet(nn.Layer):
self.freeze_at = freeze_at
self.return_idx = return_idx
self.upsample = upsample
self.downsample = downsample
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
......@@ -613,6 +644,7 @@ class HRNet(nn.Layer):
ch_out=64,
filter_size=3,
stride=2,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act='relu',
......@@ -623,6 +655,7 @@ class HRNet(nn.Layer):
ch_out=64,
filter_size=3,
stride=2,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
act='relu',
......@@ -631,6 +664,7 @@ class HRNet(nn.Layer):
self.la1 = Layer1(
num_channels=64,
has_se=has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="layer2")
......@@ -638,6 +672,7 @@ class HRNet(nn.Layer):
self.tr1 = TransitionLayer(
in_channels=[256],
out_channels=channels_2,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="tr1")
......@@ -647,6 +682,7 @@ class HRNet(nn.Layer):
num_modules=num_modules_2,
num_filters=channels_2,
has_se=self.has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="st2")
......@@ -654,6 +690,7 @@ class HRNet(nn.Layer):
self.tr2 = TransitionLayer(
in_channels=channels_2,
out_channels=channels_3,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="tr2")
......@@ -663,6 +700,7 @@ class HRNet(nn.Layer):
num_modules=num_modules_3,
num_filters=channels_3,
has_se=self.has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="st3")
......@@ -670,6 +708,7 @@ class HRNet(nn.Layer):
self.tr3 = TransitionLayer(
in_channels=channels_3,
out_channels=channels_4,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="tr3")
......@@ -678,11 +717,107 @@ class HRNet(nn.Layer):
num_modules=num_modules_4,
num_filters=channels_4,
has_se=self.has_se,
norm_momentum=norm_momentum,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
multi_scale_output=len(return_idx) > 1,
name="st4")
self.incre_modules, self.downsamp_modules, \
self.final_layer = self._make_head(channels_4, norm_momentum=norm_momentum, has_se=self.has_se)
self.classifier = nn.Linear(2048, 1000)
def _make_layer(self,
block,
inplanes,
planes,
blocks,
stride=1,
norm_momentum=0.9,
has_se=False,
name=None):
downsample = None
if stride != 1 or inplanes != planes * 4:
downsample = True
layers = []
layers.append(
block(
inplanes,
planes,
has_se,
stride,
downsample,
norm_momentum=norm_momentum,
freeze_norm=False,
name=name + "_s0"))
inplanes = planes * 4
for i in range(1, blocks):
layers.append(
block(
inplanes,
planes,
has_se,
norm_momentum=norm_momentum,
freeze_norm=False,
name=name + "_s" + str(i)))
return nn.Sequential(*layers)
def _make_head(self, pre_stage_channels, norm_momentum=0.9, has_se=False):
head_block = BottleneckBlock
head_channels = [32, 64, 128, 256]
# Increasing the #channels on each resolution
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
incre_modules = []
for i, channels in enumerate(pre_stage_channels):
incre_module = self._make_layer(
head_block,
channels,
head_channels[i],
1,
stride=1,
norm_momentum=norm_momentum,
has_se=has_se,
name='incre' + str(i))
incre_modules.append(incre_module)
incre_modules = nn.LayerList(incre_modules)
# downsampling modules
downsamp_modules = []
for i in range(len(pre_stage_channels) - 1):
in_channels = head_channels[i] * 4
out_channels = head_channels[i + 1] * 4
downsamp_module = nn.Sequential(
nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1),
nn.BatchNorm2D(
out_channels, momentum=norm_momentum),
nn.ReLU())
downsamp_modules.append(downsamp_module)
downsamp_modules = nn.LayerList(downsamp_modules)
final_layer = nn.Sequential(
nn.Conv2D(
in_channels=head_channels[3] * 4,
out_channels=2048,
kernel_size=1,
stride=1,
padding=0),
nn.BatchNorm2D(
2048, momentum=norm_momentum),
nn.ReLU())
return incre_modules, downsamp_modules, final_layer
def forward(self, inputs):
x = inputs['image']
conv1 = self.conv_layer1_1(x)
......@@ -707,6 +842,14 @@ class HRNet(nn.Layer):
x = paddle.concat([st4[0], x1, x2, x3], 1)
return x
if self.downsample:
y = self.incre_modules[0](st4[0])
for i in range(len(self.downsamp_modules)):
y = self.incre_modules[i+1](st4[i+1]) + \
self.downsamp_modules[i](y)
y = self.final_layer(y)
return y
res = []
for i, layer in enumerate(st4):
if i == self.freeze_at:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import ReLU, Swish, GELU
import math
from ppdet.core.workspace import register
from ..shape_spec import ShapeSpec
__all__ = ['TransEncoder']
class BertEmbeddings(nn.Layer):
def __init__(self, word_size, position_embeddings_size, word_type_size,
hidden_size, dropout_prob):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(
word_size, hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(position_embeddings_size,
hidden_size)
self.token_type_embeddings = nn.Embedding(word_type_size, hidden_size)
self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
self.dropout = nn.Dropout(dropout_prob)
def forward(self, x, token_type_ids=None, position_ids=None):
seq_len = paddle.shape(x)[1]
if position_ids is None:
position_ids = paddle.arange(seq_len).unsqueeze(0).expand_as(x)
if token_type_ids is None:
token_type_ids = paddle.zeros(paddle.shape(x))
word_embs = self.word_embeddings(x)
position_embs = self.position_embeddings(position_ids)
token_type_embs = self.token_type_embeddings(token_type_ids)
embs_cmb = word_embs + position_embs + token_type_embs
embs_out = self.layernorm(embs_cmb)
embs_out = self.dropout(embs_out)
return embs_out
class BertSelfAttention(nn.Layer):
def __init__(self,
hidden_size,
num_attention_heads,
attention_probs_dropout_prob,
output_attentions=False):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden_size must be a multiple of the number of attention "
"heads, but got {} % {} != 0" %
(hidden_size, num_attention_heads))
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.output_attentions = output_attentions
def forward(self, x, attention_mask, head_mask=None):
query = self.query(x)
key = self.key(x)
value = self.value(x)
query_dim1, query_dim2 = paddle.shape(query)[:-1]
new_shape = [
query_dim1, query_dim2, self.num_attention_heads,
self.attention_head_size
]
query = query.reshape(new_shape).transpose(perm=(0, 2, 1, 3))
key = key.reshape(new_shape).transpose(perm=(0, 2, 3, 1))
value = value.reshape(new_shape).transpose(perm=(0, 2, 1, 3))
attention = paddle.matmul(query,
key) / math.sqrt(self.attention_head_size)
attention = attention + attention_mask
attention_value = F.softmax(attention, axis=-1)
attention_value = self.dropout(attention_value)
if head_mask is not None:
attention_value = attention_value * head_mask
context = paddle.matmul(attention_value, value).transpose(perm=(0, 2, 1,
3))
ctx_dim1, ctx_dim2 = paddle.shape(context)[:-2]
new_context_shape = [
ctx_dim1,
ctx_dim2,
self.all_head_size,
]
context = context.reshape(new_context_shape)
if self.output_attentions:
return (context, attention_value)
else:
return (context, )
class BertAttention(nn.Layer):
def __init__(self,
hidden_size,
num_attention_heads,
attention_probs_dropout_prob,
fc_dropout_prob,
output_attentions=False):
super(BertAttention, self).__init__()
self.bert_selfattention = BertSelfAttention(
hidden_size, num_attention_heads, attention_probs_dropout_prob,
output_attentions)
self.fc = nn.Linear(hidden_size, hidden_size)
self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
self.dropout = nn.Dropout(fc_dropout_prob)
def forward(self, x, attention_mask, head_mask=None):
attention_feats = self.bert_selfattention(x, attention_mask, head_mask)
features = self.fc(attention_feats[0])
features = self.dropout(features)
features = self.layernorm(features + x)
if len(attention_feats) == 2:
return (features, attention_feats[1])
else:
return (features, )
class BertFeedForward(nn.Layer):
def __init__(self,
hidden_size,
intermediate_size,
num_attention_heads,
attention_probs_dropout_prob,
fc_dropout_prob,
act_fn='ReLU',
output_attentions=False):
super(BertFeedForward, self).__init__()
self.fc1 = nn.Linear(hidden_size, intermediate_size)
self.act_fn = eval(act_fn)
self.fc2 = nn.Linear(intermediate_size, hidden_size)
self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
self.dropout = nn.Dropout(fc_dropout_prob)
def forward(self, x):
features = self.fc1(x)
features = self.act_fn(features)
features = self.fc2(features)
features = self.dropout(features)
features = self.layernorm(features + x)
return features
class BertLayer(nn.Layer):
def __init__(self,
hidden_size,
intermediate_size,
num_attention_heads,
attention_probs_dropout_prob,
fc_dropout_prob,
act_fn='ReLU',
output_attentions=False):
super(BertLayer, self).__init__()
self.attention = BertAttention(hidden_size, num_attention_heads,
attention_probs_dropout_prob,
output_attentions)
self.feed_forward = BertFeedForward(
hidden_size, intermediate_size, num_attention_heads,
attention_probs_dropout_prob, fc_dropout_prob, act_fn,
output_attentions)
def forward(self, x, attention_mask, head_mask=None):
attention_feats = self.attention(x, attention_mask, head_mask)
features = self.feed_forward(attention_feats[0])
if len(attention_feats) == 2:
return (features, attention_feats[1])
else:
return (features, )
class BertEncoder(nn.Layer):
def __init__(self,
num_hidden_layers,
hidden_size,
intermediate_size,
num_attention_heads,
attention_probs_dropout_prob,
fc_dropout_prob,
act_fn='ReLU',
output_attentions=False,
output_hidden_feats=False):
super(BertEncoder, self).__init__()
self.output_attentions = output_attentions
self.output_hidden_feats = output_hidden_feats
self.layers = nn.LayerList([
BertLayer(hidden_size, intermediate_size, num_attention_heads,
attention_probs_dropout_prob, fc_dropout_prob, act_fn,
output_attentions) for _ in range(num_hidden_layers)
])
def forward(self, x, attention_mask, head_mask=None):
all_features = (x, )
all_attentions = ()
for i, layer in enumerate(self.layers):
mask = head_mask[i] if head_mask is not None else None
layer_out = layer(x, attention_mask, mask)
if self.output_hidden_feats:
all_features = all_features + (x, )
x = layer_out[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_out[1], )
outputs = (x, )
if self.output_hidden_feats:
outputs += (all_features, )
if self.output_attentions:
outputs += (all_attentions, )
return outputs
class BertPooler(nn.Layer):
def __init__(self, hidden_size):
super(BertPooler, self).__init__()
self.fc = nn.Linear(hidden_size, hidden_size)
self.act = nn.Tanh()
def forward(self, x):
first_token = x[:, 0]
pooled_output = self.fc(first_token)
pooled_output = self.act(pooled_output)
return pooled_output
class METROEncoder(nn.Layer):
def __init__(self,
vocab_size,
num_hidden_layers,
features_dims,
position_embeddings_size,
hidden_size,
intermediate_size,
output_feature_dim,
num_attention_heads,
attention_probs_dropout_prob,
fc_dropout_prob,
act_fn='ReLU',
output_attentions=False,
output_hidden_feats=False,
use_img_layernorm=False):
super(METROEncoder, self).__init__()
self.img_dims = features_dims
self.num_hidden_layers = num_hidden_layers
self.use_img_layernorm = use_img_layernorm
self.output_attentions = output_attentions
self.embedding = BertEmbeddings(vocab_size, position_embeddings_size, 2,
hidden_size, fc_dropout_prob)
self.encoder = BertEncoder(
num_hidden_layers, hidden_size, intermediate_size,
num_attention_heads, attention_probs_dropout_prob, fc_dropout_prob,
act_fn, output_attentions, output_hidden_feats)
self.pooler = BertPooler(hidden_size)
self.position_embeddings = nn.Embedding(position_embeddings_size,
hidden_size)
self.img_embedding = nn.Linear(
features_dims, hidden_size, bias_attr=True)
self.dropout = nn.Dropout(fc_dropout_prob)
self.cls_head = nn.Linear(hidden_size, output_feature_dim)
self.residual = nn.Linear(features_dims, output_feature_dim)
self.apply(self.init_weights)
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.set_value(
paddle.normal(
mean=0.0, std=0.02, shape=module.weight.shape))
elif isinstance(module, nn.LayerNorm):
module.bias.set_value(paddle.zeros(shape=module.bias.shape))
module.weight.set_value(
paddle.full(
shape=module.weight.shape, fill_value=1.0))
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.set_value(paddle.zeros(shape=module.bias.shape))
def forward(self, x):
batchsize, seq_len = paddle.shape(x)[:2]
input_ids = paddle.zeros((batchsize, seq_len), dtype="int64")
position_ids = paddle.arange(
seq_len, dtype="int64").unsqueeze(0).expand_as(input_ids)
attention_mask = paddle.ones_like(input_ids).unsqueeze(1).unsqueeze(2)
head_mask = [None] * self.num_hidden_layers
position_embs = self.position_embeddings(position_ids)
attention_mask = (1.0 - attention_mask) * -10000.0
img_features = self.img_embedding(x)
# We empirically observe that adding an additional learnable position embedding leads to more stable training
embeddings = position_embs + img_features
if self.use_img_layernorm:
embeddings = self.layernorm(embeddings)
embeddings = self.dropout(embeddings)
encoder_outputs = self.encoder(
embeddings, attention_mask, head_mask=head_mask)
pred_score = self.cls_head(encoder_outputs[0])
res_img_feats = self.residual(x)
pred_score = pred_score + res_img_feats
if self.output_attentions and self.output_hidden_feats:
return pred_score, encoder_outputs[1], encoder_outputs[-1]
else:
return pred_score
def gelu(x):
"""Implementation of the gelu activation function.
https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
@register
class TransEncoder(nn.Layer):
def __init__(self,
vocab_size=30522,
num_hidden_layers=4,
num_attention_heads=4,
position_embeddings_size=512,
intermediate_size=3072,
input_feat_dim=[2048, 512, 128],
hidden_feat_dim=[1024, 256, 128],
attention_probs_dropout_prob=0.1,
fc_dropout_prob=0.1,
act_fn='gelu',
output_attentions=False,
output_hidden_feats=False):
super(TransEncoder, self).__init__()
output_feat_dim = input_feat_dim[1:] + [3]
trans_encoder = []
for i in range(len(output_feat_dim)):
features_dims = input_feat_dim[i]
output_feature_dim = output_feat_dim[i]
hidden_size = hidden_feat_dim[i]
# init a transformer encoder and append it to a list
assert hidden_size % num_attention_heads == 0
model = METROEncoder(vocab_size, num_hidden_layers, features_dims,
position_embeddings_size, hidden_size,
intermediate_size, output_feature_dim,
num_attention_heads,
attention_probs_dropout_prob, fc_dropout_prob,
act_fn, output_attentions, output_hidden_feats)
trans_encoder.append(model)
self.trans_encoder = paddle.nn.Sequential(*trans_encoder)
def forward(self, x):
out = self.trans_encoder(x)
return out
......@@ -43,3 +43,4 @@ from .detr_loss import *
from .sparsercnn_loss import *
from .focal_loss import *
from .smooth_l1_loss import *
from .pose3d_loss import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from itertools import cycle, islice
from collections import abc
import paddle
import paddle.nn as nn
from ppdet.core.workspace import register, serializable
__all__ = ['Pose3DLoss']
@register
@serializable
class Pose3DLoss(nn.Layer):
def __init__(self, weight_3d=1.0, weight_2d=0.0, reduction='none'):
"""
KeyPointMSELoss layer
Args:
weight_3d (float): weight of 3d loss
weight_2d (float): weight of 2d loss
reduction (bool): whether use reduction to loss
"""
super(Pose3DLoss, self).__init__()
self.weight_3d = weight_3d
self.weight_2d = weight_2d
self.criterion_2dpose = nn.MSELoss(reduction=reduction)
self.criterion_3dpose = nn.MSELoss(reduction=reduction)
self.criterion_smoothl1 = nn.SmoothL1Loss(
reduction=reduction, delta=1.0)
self.criterion_vertices = nn.L1Loss()
def forward(self, pred3d, pred2d, inputs):
"""
mpjpe: mpjpe loss between 3d joints
keypoint_2d_loss: 2d joints loss compute by criterion_2dpose
"""
gt_3d_joints = inputs['joints_3d']
gt_2d_joints = inputs['joints_2d']
has_3d_joints = inputs['has_3d_joints']
has_2d_joints = inputs['has_2d_joints']
loss_3d = mpjpe(pred3d, gt_3d_joints, has_3d_joints)
loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d, gt_2d_joints,
has_2d_joints)
return self.weight_3d * loss_3d + self.weight_2d * loss_2d
def filter_3d_joints(pred, gt, has_3d_joints):
"""
filter 3d joints
"""
gt = gt[has_3d_joints == 1]
gt = gt[:, :, :3]
pred = pred[has_3d_joints == 1]
gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2
gt = gt - gt_pelvis[:, None, :]
pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2
pred = pred - pred_pelvis[:, None, :]
return pred, gt
@register
@serializable
def mpjpe(pred, gt, has_3d_joints):
"""
mPJPE loss
"""
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
return error
@register
@serializable
def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d):
"""
mPJPE loss of self define criterion
"""
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
error = paddle.sqrt(criterion_pose3d(pred, gt).sum(axis=-1)).mean()
return error
@register
@serializable
def weighted_mpjpe(pred, gt, has_3d_joints):
"""
Weighted_mPJPE
"""
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
weight = paddle.linalg.norm(pred, p=2, axis=-1)
weight = paddle.to_tensor(
[1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1., 1.])
error = (weight * paddle.linalg.norm(pred - gt, p=2, axis=-1)).mean()
return error
@register
@serializable
def normed_mpjpe(pred, gt, has_3d_joints):
"""
Normalized MPJPE (scale only), adapted from:
https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
"""
assert pred.shape == gt.shape
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
norm_predicted = paddle.mean(
paddle.sum(pred**2, axis=3, keepdim=True), axis=2, keepdim=True)
norm_target = paddle.mean(
paddle.sum(gt * pred, axis=3, keepdim=True), axis=2, keepdim=True)
scale = norm_target / norm_predicted
return mpjpe(scale * pred, gt)
@register
@serializable
def mpjpe_np(pred, gt, has_3d_joints):
"""
mPJPE_NP
"""
pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
error = np.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
return error
@register
@serializable
def mean_per_vertex_error(pred, gt, has_smpl):
"""
Compute mPVE
"""
pred = pred[has_smpl == 1]
gt = gt[has_smpl == 1]
with paddle.no_grad():
error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
return error
@register
@serializable
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d,
has_pose_2d):
"""
Compute 2D reprojection loss if 2D keypoint annotations are available.
The confidence (conf) is binary and indicates whether the keypoints exist or not.
"""
conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
loss = (conf * criterion_keypoints(pred_keypoints_2d,
gt_keypoints_2d[:, :, :-1])).mean()
return loss
@register
@serializable
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d,
has_pose_3d):
"""
Compute 3D keypoint loss if 3D keypoint annotations are available.
"""
conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
conf = conf[has_pose_3d == 1]
pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
if len(gt_keypoints_3d) > 0:
gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2
gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
pred_pelvis = (
pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2
pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
return (conf * criterion_keypoints(pred_keypoints_3d,
gt_keypoints_3d)).mean()
else:
return paddle.to_tensor([1.]).fill_(0.)
@register
@serializable
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
"""
Compute per-vertex loss if vertex annotations are available.
"""
pred_vertices_with_shape = pred_vertices[has_smpl == 1]
gt_vertices_with_shape = gt_vertices[has_smpl == 1]
if len(gt_vertices_with_shape) > 0:
return criterion_vertices(pred_vertices_with_shape,
gt_vertices_with_shape)
else:
return paddle.to_tensor([1.]).fill_(0.)
@register
@serializable
def rectify_pose(pose):
pose = pose.copy()
R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
R_root = cv2.Rodrigues(pose[:3])[0]
new_root = R_root.dot(R_mod)
pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
return pose
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册