diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index 9360a5b7b15596302ae54fc7b375e83718820da0..c0a6ec40d4059cbbf61b8ae4a62887a2b80b0a21 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -62,3 +62,4 @@ from .tood import *
 from .retinanet import *
 from .bytetrack import *
 from .yolox import *
+from .pose3d_metro import *
diff --git a/ppdet/modeling/architectures/pose3d_metro.py b/ppdet/modeling/architectures/pose3d_metro.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e66bd78f90895f13ca74aa35e02f8911b1b27da
--- /dev/null
+++ b/ppdet/modeling/architectures/pose3d_metro.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppdet.core.workspace import register, create
+from .meta_arch import BaseArch
+from .. import layers as L
+
+__all__ = ['METRO_Body']
+
+
+def orthographic_projection(X, camera):
+    """Perform orthographic projection of 3D points X using the camera parameters
+    Args:
+        X: size = [B, N, 3]
+        camera: size = [B, 3]
+    Returns:
+        Projected 2D points -- size = [B, N, 2]
+    """
+    camera = camera.reshape((-1, 1, 3))
+    X_trans = X[:, :, :2] + camera[:, :, 1:]
+    shape = paddle.shape(X_trans)
+    X_2d = (camera[:, :, 0] * X_trans.reshape((shape[0], -1))).reshape(shape)
+    return X_2d
+
+
+@register
+class METRO_Body(BaseArch):
+    __category__ = 'architecture'
+    __inject__ = ['loss']
+
+    def __init__(
+            self,
+            num_joints,
+            backbone='HRNet',
+            trans_encoder='',
+            loss='Pose3DLoss', ):
+        """
+        METRO network, see https://arxiv.org/abs/
+
+        Args:
+            num_joints (int): number of 3D joints to predict
+            backbone (nn.Layer): backbone instance
+            trans_encoder (nn.Layer): multi-stage transformer encoder instance
+            loss (nn.Layer): 3D/2D pose loss instance
+        """
+        super(METRO_Body, self).__init__()
+        self.num_joints = num_joints
+        self.backbone = backbone
+        self.loss = loss
+        self.deploy = False
+
+        self.trans_encoder = trans_encoder
+        self.conv_learn_tokens = paddle.nn.Conv1D(49, 10 + num_joints, 1)
+        self.cam_param_fc = paddle.nn.Linear(3, 1)
+        self.cam_param_fc2 = paddle.nn.Linear(10, 250)
+        self.cam_param_fc3 = paddle.nn.Linear(250, 3)
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        # backbone
+        backbone = create(cfg['backbone'])
+        trans_encoder = create(cfg['trans_encoder'])
+
+        return {'backbone': backbone, 'trans_encoder': trans_encoder}
+
+    def _forward(self):
+        batch_size = self.inputs['image'].shape[0]
+
+        image_feat = self.backbone(self.inputs)
+        image_feat_flatten = image_feat.reshape((batch_size, 2048, 49))
+        image_feat_flatten = image_feat_flatten.transpose(perm=(0, 2, 1))
+        # apply a conv layer to learn an image token for each 3D joint/vertex position
+        features = self.conv_learn_tokens(image_feat_flatten)
+
+        if self.training:
+            # apply masked vertex/joint modeling
+            # meta_masks is a tensor of all the masks, randomly generated in the dataloader
+            # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
+            meta_masks = self.inputs['mjm_mask'].expand((-1, -1, 2048))
+            constant_tensor = paddle.ones_like(features) * 0.01
+            features = features * meta_masks + constant_tensor * (1 - meta_masks)
+
+        pred_out = self.trans_encoder(features)
+        pred_3d_joints = pred_out[:, :self.num_joints, :]
+        cam_features = pred_out[:, self.num_joints:, :]
+
+        # learn camera parameters
+        x = self.cam_param_fc(cam_features)
+        x = x.transpose(perm=(0, 2, 1))
+        x = self.cam_param_fc2(x)
+        x = self.cam_param_fc3(x)
+        cam_param = x.transpose(perm=(0, 2, 1))
+        pred_camera = cam_param.squeeze()
+        pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
+
+        return pred_3d_joints, pred_2d_joints
+
+    def get_loss(self):
+        preds_3d, preds_2d = self._forward()
+        loss = self.loss(preds_3d, preds_2d, self.inputs)
+        output = {'loss': loss}
+        return output
+
+    def get_pred(self):
+        preds_3d, preds_2d = self._forward()
+        outputs = {'pose3d': preds_3d, 'pose2d': preds_2d}
+        return outputs
diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py
index 12e9354b744dc97e2de584915a8827d137a3f7c2..a8f5dad25cb0d5efcd71802f16fca40cb4ea0ce8 100644
--- a/ppdet/modeling/backbones/__init__.py
+++ b/ppdet/modeling/backbones/__init__.py
@@ -58,3 +58,4 @@ from .convnext import *
 from .vision_transformer import *
 from .vision_transformer import *
 from .mobileone import *
+from .trans_encoder import *
diff --git a/ppdet/modeling/backbones/hrnet.py b/ppdet/modeling/backbones/hrnet.py
index 0f09aedcaf7bc3552fd322ab670b25ebbd543dd4..17c92eb138a70011dc9da7c73eb11969968fa7fa 100644
--- a/ppdet/modeling/backbones/hrnet.py
+++ b/ppdet/modeling/backbones/hrnet.py
@@ -37,6 +37,7 @@ class ConvNormLayer(nn.Layer):
                  norm_type='bn',
                  norm_groups=32,
                  use_dcn=False,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=False,
                  act=None,
@@ -66,6 +67,7 @@ class ConvNormLayer(nn.Layer):
         if norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(
                 ch_out,
+                momentum=norm_momentum,
                 weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
@@ -93,6 +95,7 @@ class Layer1(nn.Layer):
     def __init__(self,
                  num_channels,
                  has_se=False,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -109,6 +112,7 @@ class Layer1(nn.Layer):
                     has_se=has_se,
                     stride=1,
                     downsample=True if i == 0 else False,
+                    norm_momentum=norm_momentum,
                     norm_decay=norm_decay,
                     freeze_norm=freeze_norm,
                     name=name + '_' + str(i + 1)))
@@ -125,6 +129,7 @@ class TransitionLayer(nn.Layer):
     def __init__(self,
                  in_channels,
                  out_channels,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -144,6 +149,7 @@ class TransitionLayer(nn.Layer):
                             ch_in=in_channels[i],
                             ch_out=out_channels[i],
                             filter_size=3,
+                            norm_momentum=norm_momentum,
                             norm_decay=norm_decay,
                             freeze_norm=freeze_norm,
                             act='relu',
@@ -156,6 +162,7 @@ class TransitionLayer(nn.Layer):
                             ch_out=out_channels[i],
                             filter_size=3,
                             stride=2,
+                            norm_momentum=norm_momentum,
                             norm_decay=norm_decay,
                             freeze_norm=freeze_norm,
                             act='relu',
@@ -181,6 +188,7 @@ class Branches(nn.Layer):
                  in_channels,
                  out_channels,
                  has_se=False,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -197,6 +205,7 @@ class Branches(nn.Layer):
                         num_channels=in_ch,
                         num_filters=out_channels[i],
                         has_se=has_se,
+                        norm_momentum=norm_momentum,
                         norm_decay=norm_decay,
                         freeze_norm=freeze_norm,
                         name=name + '_branch_layer_' + str(i + 1) + '_' +
@@ -221,6 +230,7 @@ class BottleneckBlock(nn.Layer):
                  has_se,
                  stride=1,
                  downsample=False,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -233,6 +243,7 @@ class BottleneckBlock(nn.Layer):
             ch_in=num_channels,
             ch_out=num_filters,
             filter_size=1,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             act="relu",
@@ -242,6 +253,7 @@ class BottleneckBlock(nn.Layer):
             ch_out=num_filters,
             filter_size=3,
             stride=stride,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             act="relu",
@@ -250,6 +262,7 @@ class BottleneckBlock(nn.Layer):
             ch_in=num_filters,
             ch_out=num_filters * 4,
             filter_size=1,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             act=None,
@@ -260,6 +273,7 @@ class BottleneckBlock(nn.Layer):
                 ch_in=num_channels,
                 ch_out=num_filters * 4,
                 filter_size=1,
+                norm_momentum=norm_momentum,
                 norm_decay=norm_decay,
                 freeze_norm=freeze_norm,
                 act=None,
@@ -296,6 +310,7 @@ class BasicBlock(nn.Layer):
                  stride=1,
                  has_se=False,
                  downsample=False,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -307,6 +322,7 @@ class BasicBlock(nn.Layer):
             ch_in=num_channels,
             ch_out=num_filters,
             filter_size=3,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             stride=stride,
@@ -316,6 +332,7 @@ class BasicBlock(nn.Layer):
             ch_in=num_filters,
             ch_out=num_filters,
             filter_size=3,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             stride=1,
@@ -327,6 +344,7 @@ class BasicBlock(nn.Layer):
                 ch_in=num_channels,
                 ch_out=num_filters * 4,
                 filter_size=1,
+                norm_momentum=norm_momentum,
                 norm_decay=norm_decay,
                 freeze_norm=freeze_norm,
                 act=None,
@@ -394,6 +412,7 @@ class Stage(nn.Layer):
                  num_modules,
                  num_filters,
                  has_se=False,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  multi_scale_output=True,
@@ -410,6 +429,7 @@ class Stage(nn.Layer):
                         num_channels=num_channels,
                         num_filters=num_filters,
                         has_se=has_se,
+                        norm_momentum=norm_momentum,
                         norm_decay=norm_decay,
                         freeze_norm=freeze_norm,
                         multi_scale_output=False,
@@ -421,6 +441,7 @@ class Stage(nn.Layer):
                         num_channels=num_channels,
                         num_filters=num_filters,
                         has_se=has_se,
+                        norm_momentum=norm_momentum,
                         norm_decay=norm_decay,
                         freeze_norm=freeze_norm,
                         name=name + '_' + str(i + 1)))
@@ -440,6 +461,7 @@ class HighResolutionModule(nn.Layer):
                  num_filters,
                  has_se=False,
                  multi_scale_output=True,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -449,6 +471,7 @@ class HighResolutionModule(nn.Layer):
             in_channels=num_channels,
             out_channels=num_filters,
             has_se=has_se,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name=name)
@@ -457,6 +480,7 @@ class HighResolutionModule(nn.Layer):
             in_channels=num_filters,
             out_channels=num_filters,
             multi_scale_output=multi_scale_output,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name=name)
@@ -472,6 +496,7 @@ class FuseLayers(nn.Layer):
                  in_channels,
                  out_channels,
                  multi_scale_output=True,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  freeze_norm=True,
                  name=None):
@@ -493,6 +518,7 @@ class FuseLayers(nn.Layer):
                             filter_size=1,
                             stride=1,
                             act=None,
+                            norm_momentum=norm_momentum,
                             norm_decay=norm_decay,
                             freeze_norm=freeze_norm,
                             name=name + '_layer_' + str(i + 1) + '_' +
@@ -510,6 +536,7 @@ class FuseLayers(nn.Layer):
                                     ch_out=out_channels[i],
                                     filter_size=3,
                                     stride=2,
+                                    norm_momentum=norm_momentum,
                                     norm_decay=norm_decay,
                                     freeze_norm=freeze_norm,
                                     act=None,
@@ -525,6 +552,7 @@ class FuseLayers(nn.Layer):
                                     ch_out=out_channels[j],
                                     filter_size=3,
                                     stride=2,
+                                    norm_momentum=norm_momentum,
                                     norm_decay=norm_decay,
                                     freeze_norm=freeze_norm,
                                     act="relu",
@@ -567,6 +594,8 @@ class HRNet(nn.Layer):
         has_se (bool): whether to add SE block for each stage
         freeze_at (int): the stage to freeze
         freeze_norm (bool): whether to freeze norm in HRNet
+        norm_momentum (float): momentum of BatchNorm
         norm_decay (float): weight decay for normalization layer weights
         return_idx (List): the stage to return
         upsample (bool): whether to upsample and concat the backbone feats
+        downsample (bool): whether to downsample the backbone feats into a single 2048-channel feature map for the pose3d head
@@ -577,9 +605,11 @@ class HRNet(nn.Layer):
                  has_se=False,
                  freeze_at=0,
                  freeze_norm=True,
+                 norm_momentum=0.9,
                  norm_decay=0.,
                  return_idx=[0, 1, 2, 3],
-                 upsample=False):
+                 upsample=False,
+                 downsample=False):
         super(HRNet, self).__init__()

         self.width = width
@@ -591,6 +621,7 @@ class HRNet(nn.Layer):
         self.freeze_at = freeze_at
         self.return_idx = return_idx
         self.upsample = upsample
+        self.downsample = downsample

         self.channels = {
             18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
@@ -613,6 +644,7 @@ class HRNet(nn.Layer):
             ch_out=64,
             filter_size=3,
             stride=2,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             act='relu',
@@ -623,6 +655,7 @@ class HRNet(nn.Layer):
             ch_out=64,
             filter_size=3,
             stride=2,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             act='relu',
@@ -631,6 +664,7 @@ class HRNet(nn.Layer):
         self.la1 = Layer1(
             num_channels=64,
             has_se=has_se,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name="layer2")
@@ -638,6 +672,7 @@ class HRNet(nn.Layer):
         self.tr1 = TransitionLayer(
             in_channels=[256],
             out_channels=channels_2,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name="tr1")
@@ -647,6 +682,7 @@ class HRNet(nn.Layer):
             num_modules=num_modules_2,
             num_filters=channels_2,
             has_se=self.has_se,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name="st2")
@@ -654,6 +690,7 @@ class HRNet(nn.Layer):
         self.tr2 = TransitionLayer(
             in_channels=channels_2,
             out_channels=channels_3,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name="tr2")
@@ -663,6 +700,7 @@ class HRNet(nn.Layer):
             num_modules=num_modules_3,
             num_filters=channels_3,
             has_se=self.has_se,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name="st3")
@@ -670,6 +708,7 @@ class HRNet(nn.Layer):
         self.tr3 = TransitionLayer(
             in_channels=channels_3,
             out_channels=channels_4,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             name="tr3")
@@ -678,11 +717,107 @@ class HRNet(nn.Layer):
             num_modules=num_modules_4,
             num_filters=channels_4,
             has_se=self.has_se,
+            norm_momentum=norm_momentum,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             multi_scale_output=len(return_idx) > 1,
             name="st4")

+        self.incre_modules, self.downsamp_modules, \
+            self.final_layer = self._make_head(channels_4, norm_momentum=norm_momentum, has_se=self.has_se)
+
+        self.classifier = nn.Linear(2048, 1000)
+
+    def _make_layer(self,
+                    block,
+                    inplanes,
+                    planes,
+                    blocks,
+                    stride=1,
+                    norm_momentum=0.9,
+                    has_se=False,
+                    name=None):
+        downsample = False
+        if stride != 1 or inplanes != planes * 4:
+            downsample = True
+
+        layers = []
+        layers.append(
+            block(
+                inplanes,
+                planes,
+                has_se,
+                stride,
+                downsample,
+                norm_momentum=norm_momentum,
+                freeze_norm=False,
+                name=name + "_s0"))
+        inplanes = planes * 4
+        for i in range(1, blocks):
+            layers.append(
+                block(
+                    inplanes,
+                    planes,
+                    has_se,
+                    norm_momentum=norm_momentum,
+                    freeze_norm=False,
+                    name=name + "_s" + str(i)))
+
+        return nn.Sequential(*layers)
+
+    def _make_head(self, pre_stage_channels, norm_momentum=0.9, has_se=False):
+        head_block = BottleneckBlock
+        head_channels = [32, 64, 128, 256]
+
+        # Increasing the #channels on each resolution
+        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
+        incre_modules = []
+        for i, channels in enumerate(pre_stage_channels):
+            incre_module = self._make_layer(
+                head_block,
+                channels,
+                head_channels[i],
+                1,
+                stride=1,
+                norm_momentum=norm_momentum,
+                has_se=has_se,
+                name='incre' + str(i))
+            incre_modules.append(incre_module)
+        incre_modules = nn.LayerList(incre_modules)
+
+        # downsampling modules
+        downsamp_modules = []
+        for i in range(len(pre_stage_channels) - 1):
+            in_channels = head_channels[i] * 4
+            out_channels = head_channels[i + 1] * 4
+
+            downsamp_module = nn.Sequential(
+                nn.Conv2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1),
+                nn.BatchNorm2D(
+                    out_channels, momentum=norm_momentum),
+                nn.ReLU())
+
+            downsamp_modules.append(downsamp_module)
+        downsamp_modules = nn.LayerList(downsamp_modules)
+
+        final_layer = nn.Sequential(
+            nn.Conv2D(
+                in_channels=head_channels[3] * 4,
+                out_channels=2048,
+                kernel_size=1,
+                stride=1,
+                padding=0),
+            nn.BatchNorm2D(
+                2048, momentum=norm_momentum),
+            nn.ReLU())
+
+        return incre_modules, downsamp_modules, final_layer
+
     def forward(self, inputs):
         x = inputs['image']
         conv1 = self.conv_layer1_1(x)
@@ -707,6 +842,14 @@ class HRNet(nn.Layer):
             x = paddle.concat([st4[0], x1, x2, x3], 1)
             return x

+        if self.downsample:
+            y = self.incre_modules[0](st4[0])
+            for i in range(len(self.downsamp_modules)):
+                y = self.incre_modules[i+1](st4[i+1]) + \
+                    self.downsamp_modules[i](y)
+            y = self.final_layer(y)
+            return y
+
         res = []
         for i, layer in enumerate(st4):
             if i == self.freeze_at:
diff --git a/ppdet/modeling/backbones/trans_encoder.py b/ppdet/modeling/backbones/trans_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a45e0f0e61567b153b2210fa59ea2bfe2bb8b16
--- /dev/null
+++ b/ppdet/modeling/backbones/trans_encoder.py
@@ -0,0 +1,389 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
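Before the encoder code, a note on the HRNet head wired up above: with `downsample=True`, the four `st4` branches are funneled through `incre_modules`/`downsamp_modules` into one 2048-channel map emitted by `final_layer`. A minimal shape walk-through, assuming `width=32` and a 224x224 input (both illustrative assumptions, not values fixed by this diff):

```python
# Shape walk-through of HRNet's downsample head (assumes width=32 and a
# 224x224 input, so st4[0] is (B, 32, 56, 56) at 1/4 scale).
def head_shapes(batch=1, hw=56):
    head_channels = [32, 64, 128, 256]         # BottleneckBlock planes per branch
    y = (batch, head_channels[0] * 4, hw, hw)  # incre_modules[0](st4[0])
    for i in range(3):
        # incre_modules[i+1](st4[i+1]) + downsamp_modules[i](y):
        # each stride-2 conv halves H and W while channels grow 4x per branch
        y = (batch, head_channels[i + 1] * 4, y[2] // 2, y[3] // 2)
    return (batch, 2048, y[2], y[3])           # final 1x1 conv lifts to 2048 channels


print(head_shapes())  # (1, 2048, 7, 7)
```

The 7x7 output is why `METRO_Body._forward` can hard-code `reshape((batch_size, 2048, 49))` and `Conv1D(49, 10 + num_joints, 1)`: the backbone grid flattens to exactly 49 tokens.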
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import ReLU, Swish, GELU
+import math
+
+from ppdet.core.workspace import register
+from ..shape_spec import ShapeSpec
+
+__all__ = ['TransEncoder']
+
+
+class BertEmbeddings(nn.Layer):
+    def __init__(self, word_size, position_embeddings_size, word_type_size,
+                 hidden_size, dropout_prob):
+        super(BertEmbeddings, self).__init__()
+        self.word_embeddings = nn.Embedding(
+            word_size, hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(position_embeddings_size,
+                                                hidden_size)
+        self.token_type_embeddings = nn.Embedding(word_type_size, hidden_size)
+        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
+        self.dropout = nn.Dropout(dropout_prob)
+
+    def forward(self, x, token_type_ids=None, position_ids=None):
+        seq_len = paddle.shape(x)[1]
+        if position_ids is None:
+            position_ids = paddle.arange(seq_len).unsqueeze(0).expand_as(x)
+        if token_type_ids is None:
+            # token type ids index an embedding table, so they must be integral
+            token_type_ids = paddle.zeros(paddle.shape(x), dtype="int64")
+
+        word_embs = self.word_embeddings(x)
+        position_embs = self.position_embeddings(position_ids)
+        token_type_embs = self.token_type_embeddings(token_type_ids)
+
+        embs_cmb = word_embs + position_embs + token_type_embs
+        embs_out = self.layernorm(embs_cmb)
+        embs_out = self.dropout(embs_out)
+        return embs_out
+
+
+class BertSelfAttention(nn.Layer):
+    def __init__(self,
+                 hidden_size,
+                 num_attention_heads,
+                 attention_probs_dropout_prob,
+                 output_attentions=False):
+        super(BertSelfAttention, self).__init__()
+        if hidden_size % num_attention_heads != 0:
+            raise ValueError(
+                "The hidden_size must be a multiple of the number of attention "
+                "heads, but got {} % {} != 0".format(hidden_size,
+                                                     num_attention_heads))
+
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = int(hidden_size / num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(hidden_size, self.all_head_size)
+        self.key = nn.Linear(hidden_size, self.all_head_size)
+        self.value = nn.Linear(hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(attention_probs_dropout_prob)
+        self.output_attentions = output_attentions
+
+    def forward(self, x, attention_mask, head_mask=None):
+        query = self.query(x)
+        key = self.key(x)
+        value = self.value(x)
+
+        query_dim1, query_dim2 = paddle.shape(query)[:-1]
+        new_shape = [
+            query_dim1, query_dim2, self.num_attention_heads,
+            self.attention_head_size
+        ]
+        query = query.reshape(new_shape).transpose(perm=(0, 2, 1, 3))
+        key = key.reshape(new_shape).transpose(perm=(0, 2, 3, 1))
+        value = value.reshape(new_shape).transpose(perm=(0, 2, 1, 3))
+
+        attention = paddle.matmul(query,
+                                  key) / math.sqrt(self.attention_head_size)
+        attention = attention + attention_mask
+        attention_value = F.softmax(attention, axis=-1)
+        attention_value = self.dropout(attention_value)
+
+        if head_mask is not None:
+            attention_value = attention_value * head_mask
+
+        context = paddle.matmul(attention_value,
+                                value).transpose(perm=(0, 2, 1, 3))
+        ctx_dim1, ctx_dim2 = paddle.shape(context)[:-2]
+        new_context_shape = [
+            ctx_dim1,
+            ctx_dim2,
+            self.all_head_size,
+        ]
+        context = context.reshape(new_context_shape)
+
+        if self.output_attentions:
+            return (context, attention_value)
+        else:
+            return (context, )
+
+
+class BertAttention(nn.Layer):
+    def __init__(self,
+                 hidden_size,
+                 num_attention_heads,
+                 attention_probs_dropout_prob,
+                 fc_dropout_prob,
+                 output_attentions=False):
+        super(BertAttention, self).__init__()
+        self.bert_selfattention = BertSelfAttention(
+            hidden_size, num_attention_heads, attention_probs_dropout_prob,
+            output_attentions)
+        self.fc = nn.Linear(hidden_size, hidden_size)
+        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
+        self.dropout = nn.Dropout(fc_dropout_prob)
+
+    def forward(self, x, attention_mask, head_mask=None):
+        attention_feats = self.bert_selfattention(x, attention_mask, head_mask)
+        features = self.fc(attention_feats[0])
+        features = self.dropout(features)
+        features = self.layernorm(features + x)
+        if len(attention_feats) == 2:
+            return (features, attention_feats[1])
+        else:
+            return (features, )
+
+
+class BertFeedForward(nn.Layer):
+    def __init__(self,
+                 hidden_size,
+                 intermediate_size,
+                 num_attention_heads,
+                 attention_probs_dropout_prob,
+                 fc_dropout_prob,
+                 act_fn='ReLU',
+                 output_attentions=False):
+        super(BertFeedForward, self).__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        # act_fn may name a module-level function (e.g. `gelu`) or an nn.Layer
+        # class (e.g. `ReLU`); instantiate the latter before use
+        act = eval(act_fn)
+        self.act_fn = act() if isinstance(act, type) else act
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
+        self.dropout = nn.Dropout(fc_dropout_prob)
+
+    def forward(self, x):
+        features = self.fc1(x)
+        features = self.act_fn(features)
+        features = self.fc2(features)
+        features = self.dropout(features)
+        features = self.layernorm(features + x)
+        return features
+
+
+class BertLayer(nn.Layer):
+    def __init__(self,
+                 hidden_size,
+                 intermediate_size,
+                 num_attention_heads,
+                 attention_probs_dropout_prob,
+                 fc_dropout_prob,
+                 act_fn='ReLU',
+                 output_attentions=False):
+        super(BertLayer, self).__init__()
+        self.attention = BertAttention(hidden_size, num_attention_heads,
+                                       attention_probs_dropout_prob,
+                                       fc_dropout_prob, output_attentions)
+        self.feed_forward = BertFeedForward(
+            hidden_size, intermediate_size, num_attention_heads,
+            attention_probs_dropout_prob, fc_dropout_prob, act_fn,
+            output_attentions)
+
+    def forward(self, x, attention_mask, head_mask=None):
+        attention_feats = self.attention(x, attention_mask, head_mask)
+        features = self.feed_forward(attention_feats[0])
+        if len(attention_feats) == 2:
+            return (features, attention_feats[1])
+        else:
+            return (features, )
+
+
+class BertEncoder(nn.Layer):
+    def __init__(self,
+                 num_hidden_layers,
+                 hidden_size,
+                 intermediate_size,
+                 num_attention_heads,
+                 attention_probs_dropout_prob,
+                 fc_dropout_prob,
+                 act_fn='ReLU',
+                 output_attentions=False,
+                 output_hidden_feats=False):
+        super(BertEncoder, self).__init__()
+        self.output_attentions = output_attentions
+        self.output_hidden_feats = output_hidden_feats
+        self.layers = nn.LayerList([
+            BertLayer(hidden_size, intermediate_size, num_attention_heads,
+                      attention_probs_dropout_prob, fc_dropout_prob, act_fn,
+                      output_attentions) for _ in range(num_hidden_layers)
+        ])
+
+    def forward(self, x, attention_mask, head_mask=None):
+        # collect the input plus every layer output when requested
+        all_features = (x, )
+        all_attentions = ()
+
+        for i, layer in enumerate(self.layers):
+            mask = head_mask[i] if head_mask is not None else None
+            layer_out = layer(x, attention_mask, mask)
+
+            x = layer_out[0]
+            if self.output_hidden_feats:
+                all_features = all_features + (x, )
+            if self.output_attentions:
+                all_attentions = all_attentions + (layer_out[1], )
+
+        outputs = (x, )
+        if self.output_hidden_feats:
+            outputs += (all_features, )
+        if self.output_attentions:
+            outputs += (all_attentions, )
+        return outputs
+
+
+class BertPooler(nn.Layer):
+    def __init__(self, hidden_size):
+        super(BertPooler, self).__init__()
+        self.fc = nn.Linear(hidden_size, hidden_size)
+        self.act = nn.Tanh()
+
+    def forward(self, x):
+        first_token = x[:, 0]
+        pooled_output = self.fc(first_token)
+        pooled_output = self.act(pooled_output)
+        return pooled_output
+
+
+class METROEncoder(nn.Layer):
+    def __init__(self,
+                 vocab_size,
+                 num_hidden_layers,
+                 features_dims,
+                 position_embeddings_size,
+                 hidden_size,
+                 intermediate_size,
+                 output_feature_dim,
+                 num_attention_heads,
+                 attention_probs_dropout_prob,
+                 fc_dropout_prob,
+                 act_fn='ReLU',
+                 output_attentions=False,
+                 output_hidden_feats=False,
+                 use_img_layernorm=False):
+        super(METROEncoder, self).__init__()
+        self.img_dims = features_dims
+        self.num_hidden_layers = num_hidden_layers
+        self.use_img_layernorm = use_img_layernorm
+        self.output_attentions = output_attentions
+        self.output_hidden_feats = output_hidden_feats
+        self.embedding = BertEmbeddings(vocab_size, position_embeddings_size, 2,
+                                        hidden_size, fc_dropout_prob)
+        self.encoder = BertEncoder(
+            num_hidden_layers, hidden_size, intermediate_size,
+            num_attention_heads, attention_probs_dropout_prob, fc_dropout_prob,
+            act_fn, output_attentions, output_hidden_feats)
+        self.pooler = BertPooler(hidden_size)
+        self.position_embeddings = nn.Embedding(position_embeddings_size,
+                                                hidden_size)
+        self.img_embedding = nn.Linear(
+            features_dims, hidden_size, bias_attr=True)
+        if self.use_img_layernorm:
+            self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
+        self.dropout = nn.Dropout(fc_dropout_prob)
+        self.cls_head = nn.Linear(hidden_size, output_feature_dim)
+        self.residual = nn.Linear(features_dims, output_feature_dim)
+
+        self.apply(self.init_weights)
+
+    def init_weights(self, module):
+        """ Initialize the weights.
+        """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.set_value(
+                paddle.normal(
+                    mean=0.0, std=0.02, shape=module.weight.shape))
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.set_value(paddle.zeros(shape=module.bias.shape))
+            module.weight.set_value(
+                paddle.full(
+                    shape=module.weight.shape, fill_value=1.0))
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.set_value(paddle.zeros(shape=module.bias.shape))
+
+    def forward(self, x):
+        batchsize, seq_len = paddle.shape(x)[:2]
+        input_ids = paddle.zeros((batchsize, seq_len), dtype="int64")
+        position_ids = paddle.arange(
+            seq_len, dtype="int64").unsqueeze(0).expand_as(input_ids)
+
+        attention_mask = paddle.ones_like(input_ids).unsqueeze(1).unsqueeze(2)
+        head_mask = [None] * self.num_hidden_layers
+
+        position_embs = self.position_embeddings(position_ids)
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
+        img_features = self.img_embedding(x)
+
+        # We empirically observe that adding an additional learnable position
+        # embedding leads to more stable training
+        embeddings = position_embs + img_features
+        if self.use_img_layernorm:
+            embeddings = self.layernorm(embeddings)
+        embeddings = self.dropout(embeddings)
+
+        encoder_outputs = self.encoder(
+            embeddings, attention_mask, head_mask=head_mask)
+
+        pred_score = self.cls_head(encoder_outputs[0])
+        res_img_feats = self.residual(x)
+        pred_score = pred_score + res_img_feats
+
+        if self.output_attentions and self.output_hidden_feats:
+            return pred_score, encoder_outputs[1], encoder_outputs[-1]
+        else:
+            return pred_score
+
+
+def gelu(x):
+    """Implementation of the gelu activation function.
+    https://arxiv.org/abs/1606.08415
+    """
+    return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
+
+
+@register
+class TransEncoder(nn.Layer):
+    def __init__(self,
+                 vocab_size=30522,
+                 num_hidden_layers=4,
+                 num_attention_heads=4,
+                 position_embeddings_size=512,
+                 intermediate_size=3072,
+                 input_feat_dim=[2048, 512, 128],
+                 hidden_feat_dim=[1024, 256, 128],
+                 attention_probs_dropout_prob=0.1,
+                 fc_dropout_prob=0.1,
+                 act_fn='gelu',
+                 output_attentions=False,
+                 output_hidden_feats=False):
+        super(TransEncoder, self).__init__()
+        output_feat_dim = input_feat_dim[1:] + [3]
+        trans_encoder = []
+        for i in range(len(output_feat_dim)):
+            features_dims = input_feat_dim[i]
+            output_feature_dim = output_feat_dim[i]
+            hidden_size = hidden_feat_dim[i]
+
+            # init a transformer encoder and append it to a list
+            assert hidden_size % num_attention_heads == 0
+            model = METROEncoder(vocab_size, num_hidden_layers, features_dims,
+                                 position_embeddings_size, hidden_size,
+                                 intermediate_size, output_feature_dim,
+                                 num_attention_heads,
+                                 attention_probs_dropout_prob, fc_dropout_prob,
+                                 act_fn, output_attentions, output_hidden_feats)
+            trans_encoder.append(model)
+        self.trans_encoder = paddle.nn.Sequential(*trans_encoder)
+
+    def forward(self, x):
+        out = self.trans_encoder(x)
+        return out
diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py
index 94eff1f175fdf6789dba811ce3023b41dfd16dd0..0e6ebe9069ea0671bce74ea4496863f0cb052803 100644
--- a/ppdet/modeling/losses/__init__.py
+++ b/ppdet/modeling/losses/__init__.py
@@ -43,3 +43,4 @@ from .detr_loss import *
 from .sparsercnn_loss import *
 from .focal_loss import *
 from .smooth_l1_loss import *
+from .pose3d_loss import *
diff --git a/ppdet/modeling/losses/pose3d_loss.py b/ppdet/modeling/losses/pose3d_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b98508f429b5ec5a77e9a8176883e3ac1afc061
--- /dev/null
+++ b/ppdet/modeling/losses/pose3d_loss.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
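Before the loss definitions, a quick sketch of how `TransEncoder`'s defaults chain the three `METROEncoder` stages above (illustration only; `K` stands for the architecture's `num_joints`, and `B` for the batch size):

```python
# How the three METROEncoder stages line up under TransEncoder's defaults.
input_feat_dim = [2048, 512, 128]
hidden_feat_dim = [1024, 256, 128]
output_feat_dim = input_feat_dim[1:] + [3]   # [512, 128, 3]

for i, (f_in, h, f_out) in enumerate(
        zip(input_feat_dim, hidden_feat_dim, output_feat_dim)):
    assert h % 4 == 0  # num_attention_heads=4 must divide every hidden size
    print("stage %d: (B, 10+K, %d) -> hidden %d -> (B, 10+K, %d)"
          % (i, f_in, h, f_out))
# The last stage emits 3 values per token: an (x, y, z) joint position for the
# first K tokens, camera features for the trailing 10.
```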
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import numpy as np
+import paddle
+import paddle.nn as nn
+
+from ppdet.core.workspace import register, serializable
+
+__all__ = ['Pose3DLoss']
+
+
+@register
+@serializable
+class Pose3DLoss(nn.Layer):
+    def __init__(self, weight_3d=1.0, weight_2d=0.0, reduction='none'):
+        """
+        Pose3DLoss layer
+
+        Args:
+            weight_3d (float): weight of 3d loss
+            weight_2d (float): weight of 2d loss
+            reduction (str): reduction used by the 2d/3d MSE criteria,
+                one of 'none', 'mean' or 'sum'
+        """
+        super(Pose3DLoss, self).__init__()
+        self.weight_3d = weight_3d
+        self.weight_2d = weight_2d
+        self.criterion_2dpose = nn.MSELoss(reduction=reduction)
+        self.criterion_3dpose = nn.MSELoss(reduction=reduction)
+        self.criterion_smoothl1 = nn.SmoothL1Loss(
+            reduction=reduction, delta=1.0)
+        self.criterion_vertices = nn.L1Loss()
+
+    def forward(self, pred3d, pred2d, inputs):
+        """
+        mpjpe: mpjpe loss between 3d joints
+        keypoint_2d_loss: 2d joints loss computed by criterion_2dpose
+        """
+        gt_3d_joints = inputs['joints_3d']
+        gt_2d_joints = inputs['joints_2d']
+        has_3d_joints = inputs['has_3d_joints']
+        has_2d_joints = inputs['has_2d_joints']
+
+        loss_3d = mpjpe(pred3d, gt_3d_joints, has_3d_joints)
+        loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d, gt_2d_joints,
+                                   has_2d_joints)
+        return self.weight_3d * loss_3d + self.weight_2d * loss_2d
+
+
+def filter_3d_joints(pred, gt, has_3d_joints):
+    """
+    filter 3d joints
+    """
+    gt = gt[has_3d_joints == 1]
+    gt = gt[:, :, :3]
+    pred = pred[has_3d_joints == 1]
+
+    gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2
+    gt = gt - gt_pelvis[:, None, :]
+    pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2
+    pred = pred - pred_pelvis[:, None, :]
+    return pred, gt
+
+
+@register
+@serializable
+def mpjpe(pred, gt, has_3d_joints):
+    """
+    mPJPE loss
+    """
+    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
+    error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
+    return error
+
+
+@register
+@serializable
+def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d):
+    """
+    mPJPE loss with a self-defined criterion
+    """
+    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
+    error = paddle.sqrt(criterion_pose3d(pred, gt).sum(axis=-1)).mean()
+    return error
+
+
+@register
+@serializable
+def weighted_mpjpe(pred, gt, has_3d_joints):
+    """
+    Weighted_mPJPE
+    """
+    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
+    # fixed per-joint weights (14 joints)
+    weight = paddle.to_tensor(
+        [1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1., 1.])
+    error = (weight * paddle.linalg.norm(pred - gt, p=2, axis=-1)).mean()
+    return error
+
+
+@register
+@serializable
+def normed_mpjpe(pred, gt, has_3d_joints):
+    """
+    Normalized MPJPE (scale only), adapted from:
+    https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
+    """
+    assert pred.shape == gt.shape
+    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
+
+    norm_predicted = paddle.mean(
+        paddle.sum(pred**2, axis=-1, keepdim=True), axis=-2, keepdim=True)
+    norm_target = paddle.mean(
+        paddle.sum(gt * pred, axis=-1, keepdim=True), axis=-2, keepdim=True)
+    scale = norm_target / norm_predicted
+    error = paddle.sqrt(((scale * pred - gt)**2).sum(axis=-1)).mean()
+    return error
+
+
+@register
+@serializable
+def mpjpe_np(pred, gt, has_3d_joints):
+    """
+    mPJPE_NP
+    """
+    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
+    error = np.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
+    return error
+
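A quick usage sketch for `mpjpe` above (assumes this PR's module path, and the joint order used throughout the file, where indices 2 and 3 are the hips defining the pelvis):

```python
import paddle
from ppdet.modeling.losses.pose3d_loss import mpjpe

pred = paddle.rand([4, 14, 3])                    # (B, K, 3) predicted joints
gt = paddle.concat(
    [paddle.rand([4, 14, 3]), paddle.ones([4, 14, 1])], axis=-1)  # extra conf column
has_3d = paddle.ones([4])                         # every sample has 3D labels

err = mpjpe(pred, gt, has_3d)
err_shifted = mpjpe(pred + 5.0, gt, has_3d)       # globally translated prediction
# pelvis centering makes the metric invariant to global translation
assert paddle.allclose(err, err_shifted)
```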
+
+@register
+@serializable
+def mean_per_vertex_error(pred, gt, has_smpl):
+    """
+    Compute mPVE
+    """
+    pred = pred[has_smpl == 1]
+    gt = gt[has_smpl == 1]
+    with paddle.no_grad():
+        error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
+    return error
+
+
+@register
+@serializable
+def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d,
+                     has_pose_2d):
+    """
+    Compute 2D reprojection loss if 2D keypoint annotations are available.
+    The confidence (conf) is binary and indicates whether the keypoints exist or not.
+    """
+    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
+    loss = (conf * criterion_keypoints(pred_keypoints_2d,
+                                       gt_keypoints_2d[:, :, :-1])).mean()
+    return loss
+
+
+@register
+@serializable
+def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d,
+                     has_pose_3d):
+    """
+    Compute 3D keypoint loss if 3D keypoint annotations are available.
+    """
+    conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
+    gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
+    gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
+    conf = conf[has_pose_3d == 1]
+    pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
+    if len(gt_keypoints_3d) > 0:
+        gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2
+        gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
+        pred_pelvis = (
+            pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2
+        pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
+        return (conf * criterion_keypoints(pred_keypoints_3d,
+                                           gt_keypoints_3d)).mean()
+    else:
+        return paddle.zeros([1])
+
+
+@register
+@serializable
+def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
+    """
+    Compute per-vertex loss if vertex annotations are available.
+    """
+    pred_vertices_with_shape = pred_vertices[has_smpl == 1]
+    gt_vertices_with_shape = gt_vertices[has_smpl == 1]
+    if len(gt_vertices_with_shape) > 0:
+        return criterion_vertices(pred_vertices_with_shape,
+                                  gt_vertices_with_shape)
+    else:
+        return paddle.zeros([1])
+
+
+@register
+@serializable
+def rectify_pose(pose):
+    """
+    Rotate the root joint of an axis-angle SMPL pose by 180 degrees around
+    the x-axis (flips the body upright).
+    """
+    pose = pose.copy()
+    R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
+    R_root = cv2.Rodrigues(pose[:3])[0]
+    new_root = R_root.dot(R_mod)
+    pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
+    return pose
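Finally, a standalone sketch of the projection step at the heart of `METRO_Body._forward` (illustrative: `num_joints=14` is an assumed value, random tensors stand in for the backbone and transformer outputs, and `orthographic_projection` is copied verbatim from the architecture file above):

```python
import paddle

num_joints, batch = 14, 2
pred_out = paddle.rand([batch, 10 + num_joints, 3])  # trans_encoder output
pred_3d_joints = pred_out[:, :num_joints, :]         # (B, K, 3)
# the cam head collapses the 10 trailing tokens to (B, 3): scale s, translation (tx, ty)
pred_camera = paddle.rand([batch, 3])


def orthographic_projection(X, camera):
    camera = camera.reshape((-1, 1, 3))
    X_trans = X[:, :, :2] + camera[:, :, 1:]         # translate in the image plane
    shape = paddle.shape(X_trans)
    return (camera[:, :, 0] * X_trans.reshape((shape[0], -1))).reshape(shape)


pred_2d = orthographic_projection(pred_3d_joints, pred_camera)
print(pred_2d.shape)  # [2, 14, 2], i.e. s * (x + tx), s * (y + ty) per joint
```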