Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
d4e34fe1
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d4e34fe1
编写于
8月 31, 2022
作者:
Z
zhiboniu
提交者:
GitHub
8月 31, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pose3d metro modeling (#6612)
* pose3d metro modeling * delete extra comments
上级
c9823094
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
872 additions
and
2 deletions
+872
-2
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+1
-0
ppdet/modeling/architectures/pose3d_metro.py
ppdet/modeling/architectures/pose3d_metro.py
+123
-0
ppdet/modeling/backbones/__init__.py
ppdet/modeling/backbones/__init__.py
+1
-0
ppdet/modeling/backbones/hrnet.py
ppdet/modeling/backbones/hrnet.py
+145
-2
ppdet/modeling/backbones/trans_encoder.py
ppdet/modeling/backbones/trans_encoder.py
+381
-0
ppdet/modeling/losses/__init__.py
ppdet/modeling/losses/__init__.py
+1
-0
ppdet/modeling/losses/pose3d_loss.py
ppdet/modeling/losses/pose3d_loss.py
+220
-0
未找到文件。
ppdet/modeling/architectures/__init__.py
浏览文件 @
d4e34fe1
...
...
@@ -62,3 +62,4 @@ from .tood import *
from
.retinanet
import
*
from
.bytetrack
import
*
from
.yolox
import
*
from
.pose3d_metro
import
*
ppdet/modeling/architectures/pose3d_metro.py
0 → 100644
浏览文件 @
d4e34fe1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
,
create
from
.meta_arch
import
BaseArch
from
..
import
layers
as
L
__all__
=
[
'METRO_Body'
]
def orthographic_projection(X, camera):
    """Project 3D points to 2D with a weak-perspective (orthographic) camera.

    Args:
        X: 3D points, size = [B, N, 3]
        camera: camera parameters [scale, tx, ty], size = [B, 3]
    Returns:
        Projected 2D points -- size = [B, N, 2]
    """
    cam = camera.reshape((-1, 1, 3))
    # shift the XY coordinates by the per-sample translation (tx, ty)
    shifted = X[:, :, :2] + cam[:, :, 1:]
    out_shape = paddle.shape(shifted)
    # scale every coordinate by the per-sample scale factor cam[:, :, 0]
    flat = shifted.reshape((out_shape[0], -1))
    return (cam[:, :, 0] * flat).reshape(out_shape)
@register
class METRO_Body(BaseArch):
    """End-to-end 3D human pose architecture in the METRO style.

    A CNN backbone produces a feature map, which is flattened into tokens, fed
    to a transformer encoder, and decoded into 3D joints plus weak-perspective
    camera parameters used to re-project the joints to 2D.
    """

    __category__ = 'architecture'
    __inject__ = ['loss']

    def __init__(
            self,
            num_joints,
            backbone='HRNet',
            trans_encoder='',
            loss='Pose3DLoss', ):
        """
        METRO network, see https://arxiv.org/abs/2012.09760

        Args:
            num_joints (int): number of 3D joints to predict.
            backbone (nn.Layer): backbone instance.
            trans_encoder (nn.Layer): transformer encoder mapping image tokens
                to per-joint outputs.
            loss (nn.Layer): 3D pose loss, injected by the workspace.
        """
        super(METRO_Body, self).__init__()
        self.num_joints = num_joints
        self.backbone = backbone
        self.loss = loss
        # export/inference flag, toggled externally
        self.deploy = False
        self.trans_encoder = trans_encoder
        # map the 49 flattened spatial positions (see the reshape in _forward)
        # to (10 + num_joints) tokens: num_joints joint tokens followed by
        # 10 tokens reserved for camera-parameter regression
        self.conv_learn_tokens = paddle.nn.Conv1D(49, 10 + num_joints, 1)
        self.cam_param_fc = paddle.nn.Linear(3, 1)
        self.cam_param_fc2 = paddle.nn.Linear(10, 250)
        self.cam_param_fc3 = paddle.nn.Linear(250, 3)

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Build backbone and trans_encoder sub-layers from a workspace config."""
        # backbone
        backbone = create(cfg['backbone'])
        trans_encoder = create(cfg['trans_encoder'])

        return {'backbone': backbone, 'trans_encoder': trans_encoder}

    def _forward(self):
        """Run backbone + encoder; return (pred_3d_joints, pred_2d_joints)."""
        batch_size = self.inputs['image'].shape[0]

        image_feat = self.backbone(self.inputs)
        # flatten the spatial map into 49 positions of 2048-d features
        # (assumes the backbone emits a [B, 2048, 7, 7]-sized map -- TODO confirm)
        image_feat_flatten = image_feat.reshape((batch_size, 2048, 49))
        image_feat_flatten = image_feat_flatten.transpose(perm=(0, 2, 1))
        # and apply a conv layer to learn image token for each 3d joint/vertex position
        features = self.conv_learn_tokens(image_feat_flatten)

        if self.training:
            # apply mask vertex/joint modeling
            # meta_masks is a tensor of all the masks, randomly generated in dataloader
            # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
            meta_masks = self.inputs['mjm_mask'].expand((-1, -1, 2048))
            constant_tensor = paddle.ones_like(features) * 0.01
            features = features * meta_masks + constant_tensor * (
                1 - meta_masks)
        pred_out = self.trans_encoder(features)

        # first num_joints tokens are joints, the rest feed camera regression
        pred_3d_joints = pred_out[:, :self.num_joints, :]
        cam_features = pred_out[:, self.num_joints:, :]

        # learn camera parameters
        x = self.cam_param_fc(cam_features)
        x = x.transpose(perm=(0, 2, 1))
        x = self.cam_param_fc2(x)
        x = self.cam_param_fc3(x)
        cam_param = x.transpose(perm=(0, 2, 1))
        pred_camera = cam_param.squeeze()
        # re-project the predicted 3D joints into the image plane
        pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)

        return pred_3d_joints, pred_2d_joints

    def get_loss(self):
        """Training entry point: forward pass + injected loss."""
        preds_3d, preds_2d = self._forward()
        loss = self.loss(preds_3d, preds_2d, self.inputs)
        output = {'loss': loss}
        return output

    def get_pred(self):
        """Inference entry point: returns 3D and re-projected 2D joints."""
        preds_3d, preds_2d = self._forward()
        outputs = {'pose3d': preds_3d, 'pose2d': preds_2d}
        return outputs
ppdet/modeling/backbones/__init__.py
浏览文件 @
d4e34fe1
...
...
@@ -58,3 +58,4 @@ from .convnext import *
from
.vision_transformer
import
*
from
.vision_transformer
import
*
from
.mobileone
import
*
from
.trans_encoder
import
*
ppdet/modeling/backbones/hrnet.py
浏览文件 @
d4e34fe1
...
...
@@ -37,6 +37,7 @@ class ConvNormLayer(nn.Layer):
norm_type
=
'bn'
,
norm_groups
=
32
,
use_dcn
=
False
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
False
,
act
=
None
,
...
...
@@ -66,6 +67,7 @@ class ConvNormLayer(nn.Layer):
if
norm_type
in
[
'bn'
,
'sync_bn'
]:
self
.
norm
=
nn
.
BatchNorm2D
(
ch_out
,
momentum
=
norm_momentum
,
weight_attr
=
param_attr
,
bias_attr
=
bias_attr
,
use_global_stats
=
global_stats
)
...
...
@@ -93,6 +95,7 @@ class Layer1(nn.Layer):
def
__init__
(
self
,
num_channels
,
has_se
=
False
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -109,6 +112,7 @@ class Layer1(nn.Layer):
has_se
=
has_se
,
stride
=
1
,
downsample
=
True
if
i
==
0
else
False
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
name
+
'_'
+
str
(
i
+
1
)))
...
...
@@ -125,6 +129,7 @@ class TransitionLayer(nn.Layer):
def
__init__
(
self
,
in_channels
,
out_channels
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -144,6 +149,7 @@ class TransitionLayer(nn.Layer):
ch_in
=
in_channels
[
i
],
ch_out
=
out_channels
[
i
],
filter_size
=
3
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
'relu'
,
...
...
@@ -156,6 +162,7 @@ class TransitionLayer(nn.Layer):
ch_out
=
out_channels
[
i
],
filter_size
=
3
,
stride
=
2
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
'relu'
,
...
...
@@ -181,6 +188,7 @@ class Branches(nn.Layer):
in_channels
,
out_channels
,
has_se
=
False
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -197,6 +205,7 @@ class Branches(nn.Layer):
num_channels
=
in_ch
,
num_filters
=
out_channels
[
i
],
has_se
=
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
name
+
'_branch_layer_'
+
str
(
i
+
1
)
+
'_'
+
...
...
@@ -221,6 +230,7 @@ class BottleneckBlock(nn.Layer):
has_se
,
stride
=
1
,
downsample
=
False
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -233,6 +243,7 @@ class BottleneckBlock(nn.Layer):
ch_in
=
num_channels
,
ch_out
=
num_filters
,
filter_size
=
1
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
"relu"
,
...
...
@@ -242,6 +253,7 @@ class BottleneckBlock(nn.Layer):
ch_out
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
"relu"
,
...
...
@@ -250,6 +262,7 @@ class BottleneckBlock(nn.Layer):
ch_in
=
num_filters
,
ch_out
=
num_filters
*
4
,
filter_size
=
1
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
None
,
...
...
@@ -260,6 +273,7 @@ class BottleneckBlock(nn.Layer):
ch_in
=
num_channels
,
ch_out
=
num_filters
*
4
,
filter_size
=
1
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
None
,
...
...
@@ -296,6 +310,7 @@ class BasicBlock(nn.Layer):
stride
=
1
,
has_se
=
False
,
downsample
=
False
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -307,6 +322,7 @@ class BasicBlock(nn.Layer):
ch_in
=
num_channels
,
ch_out
=
num_filters
,
filter_size
=
3
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
stride
=
stride
,
...
...
@@ -316,6 +332,7 @@ class BasicBlock(nn.Layer):
ch_in
=
num_filters
,
ch_out
=
num_filters
,
filter_size
=
3
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
stride
=
1
,
...
...
@@ -327,6 +344,7 @@ class BasicBlock(nn.Layer):
ch_in
=
num_channels
,
ch_out
=
num_filters
*
4
,
filter_size
=
1
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
None
,
...
...
@@ -394,6 +412,7 @@ class Stage(nn.Layer):
num_modules
,
num_filters
,
has_se
=
False
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
multi_scale_output
=
True
,
...
...
@@ -410,6 +429,7 @@ class Stage(nn.Layer):
num_channels
=
num_channels
,
num_filters
=
num_filters
,
has_se
=
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
multi_scale_output
=
False
,
...
...
@@ -421,6 +441,7 @@ class Stage(nn.Layer):
num_channels
=
num_channels
,
num_filters
=
num_filters
,
has_se
=
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
name
+
'_'
+
str
(
i
+
1
)))
...
...
@@ -440,6 +461,7 @@ class HighResolutionModule(nn.Layer):
num_filters
,
has_se
=
False
,
multi_scale_output
=
True
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -449,6 +471,7 @@ class HighResolutionModule(nn.Layer):
in_channels
=
num_channels
,
out_channels
=
num_filters
,
has_se
=
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
name
)
...
...
@@ -457,6 +480,7 @@ class HighResolutionModule(nn.Layer):
in_channels
=
num_filters
,
out_channels
=
num_filters
,
multi_scale_output
=
multi_scale_output
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
name
)
...
...
@@ -472,6 +496,7 @@ class FuseLayers(nn.Layer):
in_channels
,
out_channels
,
multi_scale_output
=
True
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
freeze_norm
=
True
,
name
=
None
):
...
...
@@ -493,6 +518,7 @@ class FuseLayers(nn.Layer):
filter_size
=
1
,
stride
=
1
,
act
=
None
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
...
...
@@ -510,6 +536,7 @@ class FuseLayers(nn.Layer):
ch_out
=
out_channels
[
i
],
filter_size
=
3
,
stride
=
2
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
None
,
...
...
@@ -525,6 +552,7 @@ class FuseLayers(nn.Layer):
ch_out
=
out_channels
[
j
],
filter_size
=
3
,
stride
=
2
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
"relu"
,
...
...
@@ -549,7 +577,6 @@ class FuseLayers(nn.Layer):
for
k
in
range
(
i
-
j
):
y
=
self
.
residual_func_list
[
residual_func_idx
](
y
)
residual_func_idx
+=
1
residual
=
paddle
.
add
(
x
=
residual
,
y
=
y
)
residual
=
F
.
relu
(
residual
)
outs
.
append
(
residual
)
...
...
@@ -567,6 +594,7 @@ class HRNet(nn.Layer):
has_se (bool): whether to add SE block for each stage
freeze_at (int): the stage to freeze
freeze_norm (bool): whether to freeze norm in HRNet
norm_momentum (float): momentum of BatchNorm
norm_decay (float): weight decay for normalization layer weights
return_idx (List): the stage to return
upsample (bool): whether to upsample and concat the backbone feats
...
...
@@ -577,9 +605,11 @@ class HRNet(nn.Layer):
has_se
=
False
,
freeze_at
=
0
,
freeze_norm
=
True
,
norm_momentum
=
0.9
,
norm_decay
=
0.
,
return_idx
=
[
0
,
1
,
2
,
3
],
upsample
=
False
):
upsample
=
False
,
downsample
=
False
):
super
(
HRNet
,
self
).
__init__
()
self
.
width
=
width
...
...
@@ -591,6 +621,7 @@ class HRNet(nn.Layer):
self
.
freeze_at
=
freeze_at
self
.
return_idx
=
return_idx
self
.
upsample
=
upsample
self
.
downsample
=
downsample
self
.
channels
=
{
18
:
[[
18
,
36
],
[
18
,
36
,
72
],
[
18
,
36
,
72
,
144
]],
...
...
@@ -613,6 +644,7 @@ class HRNet(nn.Layer):
ch_out
=
64
,
filter_size
=
3
,
stride
=
2
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
'relu'
,
...
...
@@ -623,6 +655,7 @@ class HRNet(nn.Layer):
ch_out
=
64
,
filter_size
=
3
,
stride
=
2
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
act
=
'relu'
,
...
...
@@ -631,6 +664,7 @@ class HRNet(nn.Layer):
self
.
la1
=
Layer1
(
num_channels
=
64
,
has_se
=
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
"layer2"
)
...
...
@@ -638,6 +672,7 @@ class HRNet(nn.Layer):
self
.
tr1
=
TransitionLayer
(
in_channels
=
[
256
],
out_channels
=
channels_2
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
"tr1"
)
...
...
@@ -647,6 +682,7 @@ class HRNet(nn.Layer):
num_modules
=
num_modules_2
,
num_filters
=
channels_2
,
has_se
=
self
.
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
"st2"
)
...
...
@@ -654,6 +690,7 @@ class HRNet(nn.Layer):
self
.
tr2
=
TransitionLayer
(
in_channels
=
channels_2
,
out_channels
=
channels_3
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
"tr2"
)
...
...
@@ -663,6 +700,7 @@ class HRNet(nn.Layer):
num_modules
=
num_modules_3
,
num_filters
=
channels_3
,
has_se
=
self
.
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
"st3"
)
...
...
@@ -670,6 +708,7 @@ class HRNet(nn.Layer):
self
.
tr3
=
TransitionLayer
(
in_channels
=
channels_3
,
out_channels
=
channels_4
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
name
=
"tr3"
)
...
...
@@ -678,11 +717,107 @@ class HRNet(nn.Layer):
num_modules
=
num_modules_4
,
num_filters
=
channels_4
,
has_se
=
self
.
has_se
,
norm_momentum
=
norm_momentum
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
multi_scale_output
=
len
(
return_idx
)
>
1
,
name
=
"st4"
)
self
.
incre_modules
,
self
.
downsamp_modules
,
\
self
.
final_layer
=
self
.
_make_head
(
channels_4
,
norm_momentum
=
norm_momentum
,
has_se
=
self
.
has_se
)
self
.
classifier
=
nn
.
Linear
(
2048
,
1000
)
def
_make_layer
(
self
,
block
,
inplanes
,
planes
,
blocks
,
stride
=
1
,
norm_momentum
=
0.9
,
has_se
=
False
,
name
=
None
):
downsample
=
None
if
stride
!=
1
or
inplanes
!=
planes
*
4
:
downsample
=
True
layers
=
[]
layers
.
append
(
block
(
inplanes
,
planes
,
has_se
,
stride
,
downsample
,
norm_momentum
=
norm_momentum
,
freeze_norm
=
False
,
name
=
name
+
"_s0"
))
inplanes
=
planes
*
4
for
i
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
inplanes
,
planes
,
has_se
,
norm_momentum
=
norm_momentum
,
freeze_norm
=
False
,
name
=
name
+
"_s"
+
str
(
i
)))
return
nn
.
Sequential
(
*
layers
)
def
_make_head
(
self
,
pre_stage_channels
,
norm_momentum
=
0.9
,
has_se
=
False
):
head_block
=
BottleneckBlock
head_channels
=
[
32
,
64
,
128
,
256
]
# Increasing the #channels on each resolution
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
incre_modules
=
[]
for
i
,
channels
in
enumerate
(
pre_stage_channels
):
incre_module
=
self
.
_make_layer
(
head_block
,
channels
,
head_channels
[
i
],
1
,
stride
=
1
,
norm_momentum
=
norm_momentum
,
has_se
=
has_se
,
name
=
'incre'
+
str
(
i
))
incre_modules
.
append
(
incre_module
)
incre_modules
=
nn
.
LayerList
(
incre_modules
)
# downsampling modules
downsamp_modules
=
[]
for
i
in
range
(
len
(
pre_stage_channels
)
-
1
):
in_channels
=
head_channels
[
i
]
*
4
out_channels
=
head_channels
[
i
+
1
]
*
4
downsamp_module
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
),
nn
.
BatchNorm2D
(
out_channels
,
momentum
=
norm_momentum
),
nn
.
ReLU
())
downsamp_modules
.
append
(
downsamp_module
)
downsamp_modules
=
nn
.
LayerList
(
downsamp_modules
)
final_layer
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
head_channels
[
3
]
*
4
,
out_channels
=
2048
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
),
nn
.
BatchNorm2D
(
2048
,
momentum
=
norm_momentum
),
nn
.
ReLU
())
return
incre_modules
,
downsamp_modules
,
final_layer
def
forward
(
self
,
inputs
):
x
=
inputs
[
'image'
]
conv1
=
self
.
conv_layer1_1
(
x
)
...
...
@@ -707,6 +842,14 @@ class HRNet(nn.Layer):
x
=
paddle
.
concat
([
st4
[
0
],
x1
,
x2
,
x3
],
1
)
return
x
if
self
.
downsample
:
y
=
self
.
incre_modules
[
0
](
st4
[
0
])
for
i
in
range
(
len
(
self
.
downsamp_modules
)):
y
=
self
.
incre_modules
[
i
+
1
](
st4
[
i
+
1
])
+
\
self
.
downsamp_modules
[
i
](
y
)
y
=
self
.
final_layer
(
y
)
return
y
res
=
[]
for
i
,
layer
in
enumerate
(
st4
):
if
i
==
self
.
freeze_at
:
...
...
ppdet/modeling/backbones/trans_encoder.py
0 → 100644
浏览文件 @
d4e34fe1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
ReLU
,
Swish
,
GELU
import
math
from
ppdet.core.workspace
import
register
from
..shape_spec
import
ShapeSpec
__all__
=
[
'TransEncoder'
]
class BertEmbeddings(nn.Layer):
    """BERT-style input embeddings.

    Sums word, position and token-type embeddings, then applies LayerNorm
    and dropout.
    """

    def __init__(self, word_size, position_embeddings_size, word_type_size,
                 hidden_size, dropout_prob):
        super(BertEmbeddings, self).__init__()
        # padding_idx=0: token id 0 embeds to a zero (non-trained) vector
        self.word_embeddings = nn.Embedding(
            word_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(position_embeddings_size,
                                                hidden_size)
        self.token_type_embeddings = nn.Embedding(word_type_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, token_type_ids=None, position_ids=None):
        """Embed token ids `x`; missing position/type ids default to 0..seq-1 / zeros."""
        seq_len = paddle.shape(x)[1]
        if position_ids is None:
            position_ids = paddle.arange(seq_len).unsqueeze(0).expand_as(x)
        if token_type_ids is None:
            token_type_ids = paddle.zeros(paddle.shape(x))

        word_embs = self.word_embeddings(x)
        position_embs = self.position_embeddings(position_ids)
        token_type_embs = self.token_type_embeddings(token_type_ids)
        embs_cmb = word_embs + position_embs + token_type_embs
        embs_out = self.layernorm(embs_cmb)
        embs_out = self.dropout(embs_out)
        return embs_out
class BertSelfAttention(nn.Layer):
    """Multi-head scaled dot-product self-attention (BERT style).

    Args:
        hidden_size (int): token feature dimension; must be divisible by
            num_attention_heads.
        num_attention_heads (int): number of attention heads.
        attention_probs_dropout_prob (float): dropout on attention weights.
        output_attentions (bool): if True, forward also returns the
            attention probabilities.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 output_attentions=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            # BUGFIX: original used "{} % {}" placeholders with the %-operator,
            # so the message was never interpolated (and raised TypeError).
            raise ValueError(
                "The hidden_size must be a multiple of the number of attention "
                "heads, but got {} % {} != 0".format(hidden_size,
                                                     num_attention_heads))

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)
        self.output_attentions = output_attentions

    def forward(self, x, attention_mask, head_mask=None):
        """Attend over `x` ([B, S, H]); attention_mask is added pre-softmax.

        Returns (context,) or (context, attention_probs) when
        output_attentions is set.
        """
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # split hidden dim into heads: [B, S, H] -> [B, heads, S, head_size]
        query_dim1, query_dim2 = paddle.shape(query)[:-1]
        new_shape = [
            query_dim1, query_dim2, self.num_attention_heads,
            self.attention_head_size
        ]
        query = query.reshape(new_shape).transpose(perm=(0, 2, 1, 3))
        # key is transposed to [B, heads, head_size, S] for the matmul below
        key = key.reshape(new_shape).transpose(perm=(0, 2, 3, 1))
        value = value.reshape(new_shape).transpose(perm=(0, 2, 1, 3))

        # scaled dot-product scores, plus (typically large-negative) mask
        attention = paddle.matmul(query,
                                  key) / math.sqrt(self.attention_head_size)
        attention = attention + attention_mask
        attention_value = F.softmax(attention, axis=-1)
        attention_value = self.dropout(attention_value)

        if head_mask is not None:
            attention_value = attention_value * head_mask

        # merge heads back: [B, heads, S, head_size] -> [B, S, all_head_size]
        context = paddle.matmul(attention_value, value).transpose(perm=(
            0, 2, 1, 3))
        ctx_dim1, ctx_dim2 = paddle.shape(context)[:-2]
        new_context_shape = [
            ctx_dim1,
            ctx_dim2,
            self.all_head_size,
        ]
        context = context.reshape(new_context_shape)

        if self.output_attentions:
            return (context, attention_value)
        else:
            return (context, )
class BertAttention(nn.Layer):
    """Self-attention sublayer: attention + projection + dropout + residual LayerNorm."""

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 output_attentions=False):
        super(BertAttention, self).__init__()
        self.bert_selfattention = BertSelfAttention(
            hidden_size, num_attention_heads, attention_probs_dropout_prob,
            output_attentions)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.dropout = nn.Dropout(fc_dropout_prob)

    def forward(self, x, attention_mask, head_mask=None):
        """Returns (features,) or (features, attention_probs), post-LN residual style."""
        # attention_feats[0] is the context; attention_feats[1] (optional)
        # holds the attention probabilities
        attention_feats = self.bert_selfattention(x, attention_mask, head_mask)
        features = self.fc(attention_feats[0])
        features = self.dropout(features)
        # residual connection followed by LayerNorm
        features = self.layernorm(features + x)
        if len(attention_feats) == 2:
            return (features, attention_feats[1])
        else:
            return (features, )
class BertFeedForward(nn.Layer):
    """Position-wise feed-forward sublayer: fc1 -> activation -> fc2 -> dropout,
    with a residual connection and LayerNorm.

    Args:
        hidden_size (int): input/output feature dimension.
        intermediate_size (int): inner (expanded) dimension.
        num_attention_heads: unused here; kept for signature symmetry with
            the attention sublayer.
        attention_probs_dropout_prob: unused here (see above).
        fc_dropout_prob (float): dropout after fc2.
        act_fn (str): name of the activation, resolved from this module's
            scope (e.g. 'ReLU', 'Swish', 'GELU', 'gelu').
        output_attentions (bool): unused here (signature symmetry).
    """

    def __init__(self,
                 hidden_size,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False):
        super(BertFeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        # NOTE(security): act_fn is resolved with eval(); only trusted config
        # values should ever reach this constructor.
        # BUGFIX: eval('ReLU') yields the layer *class*; the original stored it
        # directly and forward() then called the class on a tensor, constructing
        # a layer instead of applying an activation. Instantiate classes once so
        # both layer classes ('ReLU', 'Swish', 'GELU') and plain functions
        # ('gelu') act as callables on tensors.
        act = eval(act_fn)
        self.act_fn = act() if isinstance(act, type) else act
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.dropout = nn.Dropout(fc_dropout_prob)

    def forward(self, x):
        """Apply the feed-forward transform with residual + LayerNorm."""
        features = self.fc1(x)
        features = self.act_fn(features)
        features = self.fc2(features)
        features = self.dropout(features)
        features = self.layernorm(features + x)
        return features
class BertLayer(nn.Layer):
    """One transformer block: attention sublayer followed by feed-forward sublayer."""

    def __init__(self,
                 hidden_size,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(hidden_size, num_attention_heads,
                                       attention_probs_dropout_prob,
                                       fc_dropout_prob, output_attentions)
        self.feed_forward = BertFeedForward(
            hidden_size, intermediate_size, num_attention_heads,
            attention_probs_dropout_prob, fc_dropout_prob, act_fn,
            output_attentions)

    def forward(self, x, attention_mask, head_mask=None):
        """Returns (features,) or (features, attention_probs) like the sublayers."""
        # attention_feats is (context,) or (context, attention_probs)
        attention_feats = self.attention(x, attention_mask, head_mask)
        features = self.feed_forward(attention_feats[0])
        if len(attention_feats) == 2:
            return (features, attention_feats[1])
        else:
            return (features, )
class BertEncoder(nn.Layer):
    """Stack of BertLayer blocks; optionally collects hidden states / attentions."""

    def __init__(self,
                 num_hidden_layers,
                 hidden_size,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False,
                 output_hidden_feats=False):
        super(BertEncoder, self).__init__()
        self.output_attentions = output_attentions
        self.output_hidden_feats = output_hidden_feats
        self.layers = nn.LayerList([
            BertLayer(hidden_size, intermediate_size, num_attention_heads,
                      attention_probs_dropout_prob, fc_dropout_prob, act_fn,
                      output_attentions) for _ in range(num_hidden_layers)
        ])

    def forward(self, x, attention_mask, head_mask=None):
        """Run all layers; returns (x,) plus optional (all_features,) and (all_attentions,)."""
        all_features = (x, )
        all_attentions = ()
        for i, layer in enumerate(self.layers):
            # head_mask may be None or a per-layer list
            mask = head_mask[i] if head_mask is not None else None
            layer_out = layer(x, attention_mask, mask)
            if self.output_hidden_feats:
                # NOTE(review): the input hidden state is recorded both at
                # initialization and again on the first iteration, and the
                # final layer output is never appended — confirm this is the
                # intended bookkeeping before relying on all_features.
                all_features = all_features + (x, )
            x = layer_out[0]
            if self.output_attentions:
                all_attentions = all_attentions + (layer_out[1], )

        outputs = (x, )
        if self.output_hidden_feats:
            outputs += (all_features, )
        if self.output_attentions:
            outputs += (all_attentions, )
        return outputs
class BertPooler(nn.Layer):
    """Pool a token sequence by passing its first token through a dense+tanh head."""

    def __init__(self, hidden_size):
        super(BertPooler, self).__init__()
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.act = nn.Tanh()

    def forward(self, x):
        # only the first token's hidden state is used for pooling
        cls_state = x[:, 0]
        return self.act(self.fc(cls_state))
class METROEncoder(nn.Layer):
    """Transformer encoder over per-token image features (METRO style).

    Projects image features to hidden_size, adds learnable position
    embeddings, runs a BertEncoder stack, and regresses output_feature_dim
    values per token with a linear residual branch from the raw features.
    """

    def __init__(self,
                 vocab_size,
                 num_hidden_layers,
                 features_dims,
                 position_embeddings_size,
                 hidden_size,
                 intermediate_size,
                 output_feature_dim,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False,
                 output_hidden_feats=False,
                 use_img_layernorm=False):
        super(METROEncoder, self).__init__()
        self.img_dims = features_dims
        self.num_hidden_layers = num_hidden_layers
        self.use_img_layernorm = use_img_layernorm
        self.output_attentions = output_attentions
        # BUGFIX: forward() reads self.output_hidden_feats, but the original
        # __init__ never stored it, raising AttributeError whenever
        # output_attentions=True reached the final branch.
        self.output_hidden_feats = output_hidden_feats
        self.embedding = BertEmbeddings(vocab_size, position_embeddings_size,
                                        2, hidden_size, fc_dropout_prob)
        self.encoder = BertEncoder(num_hidden_layers, hidden_size,
                                   intermediate_size, num_attention_heads,
                                   attention_probs_dropout_prob,
                                   fc_dropout_prob, act_fn, output_attentions,
                                   output_hidden_feats)
        self.pooler = BertPooler(hidden_size)
        self.position_embeddings = nn.Embedding(position_embeddings_size,
                                                hidden_size)
        self.img_embedding = nn.Linear(
            features_dims, hidden_size, bias_attr=True)
        self.dropout = nn.Dropout(fc_dropout_prob)
        self.cls_head = nn.Linear(hidden_size, output_feature_dim)
        self.residual = nn.Linear(features_dims, output_feature_dim)
        if use_img_layernorm:
            # BUGFIX: forward() applies self.layernorm when use_img_layernorm
            # is set, but the original never defined it (AttributeError).
            # epsilon matches the other LayerNorms in this file — TODO confirm
            # against the reference METRO implementation.
            self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.apply(self.init_weights)

    def init_weights(self, module):
        """ Initialize the weights.

        Linear/Embedding weights ~ N(0, 0.02); LayerNorm to ones/zeros;
        Linear biases to zeros.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.set_value(
                paddle.normal(
                    mean=0.0, std=0.02, shape=module.weight.shape))
        elif isinstance(module, nn.LayerNorm):
            module.bias.set_value(paddle.zeros(shape=module.bias.shape))
            module.weight.set_value(
                paddle.full(
                    shape=module.weight.shape, fill_value=1.0))
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.set_value(paddle.zeros(shape=module.bias.shape))

    def forward(self, x):
        """Encode features x ([B, S, features_dims]) to [B, S, output_feature_dim]."""
        batchsize, seq_len = paddle.shape(x)[:2]
        input_ids = paddle.zeros((batchsize, seq_len), dtype="int64")
        position_ids = paddle.arange(
            seq_len, dtype="int64").unsqueeze(0).expand_as(input_ids)

        # all positions are attended to (mask of ones -> additive mask of zeros)
        attention_mask = paddle.ones_like(input_ids).unsqueeze(1).unsqueeze(2)
        head_mask = [None] * self.num_hidden_layers

        position_embs = self.position_embeddings(position_ids)
        attention_mask = (1.0 - attention_mask) * -10000.0
        img_features = self.img_embedding(x)

        # We empirically observe that adding an additional learnable position embedding leads to more stable training
        embeddings = position_embs + img_features
        if self.use_img_layernorm:
            embeddings = self.layernorm(embeddings)
        embeddings = self.dropout(embeddings)

        encoder_outputs = self.encoder(
            embeddings, attention_mask, head_mask=head_mask)
        pred_score = self.cls_head(encoder_outputs[0])
        # linear residual branch straight from the raw input features
        res_img_feats = self.residual(x)
        pred_score = pred_score + res_img_feats

        if self.output_attentions and self.output_hidden_feats:
            return pred_score, encoder_outputs[1], encoder_outputs[-1]
        else:
            return pred_score
def gelu(x):
    """Implementation of the gelu activation function (exact erf form).

    https://arxiv.org/abs/1606.08415
    """
    cdf = 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
    return x * cdf
@register
class TransEncoder(nn.Layer):
    """Cascade of METROEncoder stages with progressively smaller feature dims.

    Stage i maps input_feat_dim[i] -> output_feat_dim[i], where output dims
    are the remaining input dims followed by 3 (the final xyz coordinates);
    the stages are chained via nn.Sequential.
    """

    def __init__(self,
                 vocab_size=30522,
                 num_hidden_layers=4,
                 num_attention_heads=4,
                 position_embeddings_size=512,
                 intermediate_size=3072,
                 input_feat_dim=[2048, 512, 128],
                 hidden_feat_dim=[1024, 256, 128],
                 attention_probs_dropout_prob=0.1,
                 fc_dropout_prob=0.1,
                 act_fn='gelu',
                 output_attentions=False,
                 output_hidden_feats=False):
        super(TransEncoder, self).__init__()
        # each stage's output dim feeds the next stage; the last emits 3 (xyz)
        output_feat_dim = input_feat_dim[1:] + [3]
        trans_encoder = []
        for i in range(len(output_feat_dim)):
            features_dims = input_feat_dim[i]
            output_feature_dim = output_feat_dim[i]
            hidden_size = hidden_feat_dim[i]
            # init a transformer encoder and append it to a list
            assert hidden_size % num_attention_heads == 0
            model = METROEncoder(vocab_size, num_hidden_layers, features_dims,
                                 position_embeddings_size, hidden_size,
                                 intermediate_size, output_feature_dim,
                                 num_attention_heads,
                                 attention_probs_dropout_prob,
                                 fc_dropout_prob, act_fn, output_attentions,
                                 output_hidden_feats)
            trans_encoder.append(model)
        self.trans_encoder = paddle.nn.Sequential(*trans_encoder)

    def forward(self, x):
        """Run all encoder stages in sequence on x."""
        out = self.trans_encoder(x)
        return out
ppdet/modeling/losses/__init__.py
浏览文件 @
d4e34fe1
...
...
@@ -43,3 +43,4 @@ from .detr_loss import *
from
.sparsercnn_loss
import
*
from
.focal_loss
import
*
from
.smooth_l1_loss
import
*
from
.pose3d_loss
import
*
ppdet/modeling/losses/pose3d_loss.py
0 → 100644
浏览文件 @
d4e34fe1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
itertools
import
cycle
,
islice
from
collections
import
abc
import
paddle
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'Pose3DLoss'
]
@register
@serializable
class Pose3DLoss(nn.Layer):
    """Weighted combination of a 3D joint loss (mPJPE) and a 2D joint loss."""

    def __init__(self, weight_3d=1.0, weight_2d=0.0, reduction='none'):
        """
        KeyPointMSELoss layer

        Args:
            weight_3d (float): weight of 3d loss
            weight_2d (float): weight of 2d loss
            reduction (str): reduction mode passed to the MSE/SmoothL1
                criteria ('none', 'mean', or 'sum')
        """
        super(Pose3DLoss, self).__init__()
        self.weight_3d = weight_3d
        self.weight_2d = weight_2d
        self.criterion_2dpose = nn.MSELoss(reduction=reduction)
        # criterion_3dpose/criterion_smoothl1/criterion_vertices are built
        # here but not used in forward() below — presumably for alternative
        # loss variants elsewhere; confirm before removing.
        self.criterion_3dpose = nn.MSELoss(reduction=reduction)
        self.criterion_smoothl1 = nn.SmoothL1Loss(
            reduction=reduction, delta=1.0)
        self.criterion_vertices = nn.L1Loss()

    def forward(self, pred3d, pred2d, inputs):
        """
        mpjpe: mpjpe loss between 3d joints
        keypoint_2d_loss: 2d joints loss compute by criterion_2dpose

        `inputs` must provide 'joints_3d', 'joints_2d' and the per-sample
        validity flags 'has_3d_joints' / 'has_2d_joints'.
        """
        gt_3d_joints = inputs['joints_3d']
        gt_2d_joints = inputs['joints_2d']
        has_3d_joints = inputs['has_3d_joints']
        has_2d_joints = inputs['has_2d_joints']

        # mpjpe / keypoint_2d_loss are module-level helpers in this file
        loss_3d = mpjpe(pred3d, gt_3d_joints, has_3d_joints)
        loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d,
                                   gt_2d_joints, has_2d_joints)
        return self.weight_3d * loss_3d + self.weight_2d * loss_2d
def filter_3d_joints(pred, gt, has_3d_joints):
    """Select annotated samples and centre both joint sets on the pelvis.

    Keeps only batch items with ``has_3d_joints == 1``, drops any extra
    ground-truth channels beyond (x, y, z), and subtracts the pelvis —
    taken as the midpoint of joints 2 and 3 — from every joint.

    Args:
        pred: predicted joints, shape (N, num_joints, 3).
        gt: ground-truth joints, shape (N, num_joints, >=3).
        has_3d_joints: per-sample flag, 1 where a 3d annotation exists.

    Returns:
        Tuple ``(pred, gt)`` of filtered, pelvis-centred joints.
    """
    keep = has_3d_joints == 1
    gt = gt[keep][:, :, :3]
    pred = pred[keep]

    # pelvis = midpoint of joints 2 and 3; express all joints relative to it
    gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2
    gt = gt - gt_pelvis[:, None, :]
    pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2
    pred = pred - pred_pelvis[:, None, :]

    return pred, gt
@register
@serializable
def mpjpe(pred, gt, has_3d_joints):
    """Mean Per-Joint Position Error.

    Joints are first filtered and pelvis-centred by ``filter_3d_joints``;
    the result is the mean Euclidean distance over all joints and samples.
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    squared_diff = (pred - gt)**2
    per_joint_dist = paddle.sqrt(squared_diff.sum(axis=-1))
    return per_joint_dist.mean()
@register
@serializable
def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d):
    """mPJPE computed with a user-supplied elementwise criterion.

    ``criterion_pose3d`` must return an unreduced per-element tensor
    (e.g. ``nn.MSELoss(reduction='none')``) which is summed over the
    coordinate axis before the square root.
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    elementwise = criterion_pose3d(pred, gt)
    return paddle.sqrt(elementwise.sum(axis=-1)).mean()
@register
@serializable
def weighted_mpjpe(pred, gt, has_3d_joints):
    """Weighted mPJPE: per-joint Euclidean error scaled by fixed weights.

    Fix: the original first computed ``paddle.linalg.norm(pred, p=2,
    axis=-1)`` into ``weight`` and immediately overwrote it with the
    constant tensor below — that dead computation is removed.

    Args:
        pred: predicted 3d joints.
        gt: ground-truth 3d joints (extra channels ignored).
        has_3d_joints: per-sample flag, 1 where a 3d annotation exists.

    Returns:
        Scalar weighted mean per-joint position error.
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    # fixed per-joint weights; extremities get larger weights.
    # NOTE(review): assumes a 14-joint layout — confirm against the dataset.
    weight = paddle.to_tensor(
        [1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1., 1.])
    error = (weight * paddle.linalg.norm(pred - gt, p=2, axis=-1)).mean()
    return error
@register
@serializable
def normed_mpjpe(pred, gt, has_3d_joints):
    """Normalized MPJPE (scale only), adapted from:
    https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py

    Finds the per-sample scale that best aligns the prediction to the
    target in the least-squares sense, then measures mPJPE.

    Fixes vs. original:
      * ``filter_3d_joints`` returns rank-3 tensors (N, num_joints, 3),
        so reducing over ``axis=3`` / ``axis=2`` was out of range;
        reductions now use ``axis=-1`` (coords) and ``axis=-2`` (joints).
      * The original called ``mpjpe(scale * pred, gt)`` with two
        arguments, but ``mpjpe`` requires ``has_3d_joints`` and would
        re-filter already-filtered joints; the error is now computed
        directly on the filtered tensors.
    """
    # NOTE(review): gt may carry a trailing confidence channel, in which
    # case this shape assert would trip — confirm against the dataloader.
    assert pred.shape == gt.shape
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)

    # least-squares scale: <gt, pred> / <pred, pred>, one value per sample
    norm_predicted = paddle.mean(
        paddle.sum(pred**2, axis=-1, keepdim=True), axis=-2, keepdim=True)
    norm_target = paddle.mean(
        paddle.sum(gt * pred, axis=-1, keepdim=True), axis=-2, keepdim=True)
    scale = norm_target / norm_predicted

    # joints are already filtered/centred — compute the error directly
    return paddle.sqrt(((scale * pred - gt)**2).sum(axis=-1)).mean()
@register
@serializable
def mpjpe_np(pred, gt, has_3d_joints):
    """mPJPE evaluated with numpy reductions.

    Fix: this function references ``np``, which was never imported in
    this module (NameError at call time); ``import numpy as np`` is now
    part of the module imports.

    Args:
        pred: predicted 3d joints.
        gt: ground-truth 3d joints (extra channels ignored).
        has_3d_joints: per-sample flag, 1 where a 3d annotation exists.

    Returns:
        Scalar mean per-joint position error.
    """
    pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
    error = np.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
    return error
@register
@serializable
def mean_per_vertex_error(pred, gt, has_smpl):
    """Compute mPVE: mean Euclidean distance between mesh vertices.

    Only samples with ``has_smpl == 1`` contribute; evaluated without
    gradient tracking.
    """
    keep = has_smpl == 1
    pred = pred[keep]
    gt = gt[keep]
    with paddle.no_grad():
        per_vertex_dist = paddle.sqrt(((pred - gt)**2).sum(axis=-1))
        return per_vertex_dist.mean()
@register
@serializable
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d,
                     has_pose_2d):
    """Compute the 2D reprojection loss, weighted by keypoint visibility.

    The last channel of ``gt_keypoints_2d`` is a binary confidence that
    masks out missing keypoints.

    NOTE(review): ``has_pose_2d`` is accepted but unused here — weighting
    comes solely from the per-keypoint confidence channel; confirm intent.
    """
    visibility = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    gt_coords = gt_keypoints_2d[:, :, :-1]
    per_keypoint = criterion_keypoints(pred_keypoints_2d, gt_coords)
    return (visibility * per_keypoint).mean()
@register
@serializable
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d,
                     has_pose_3d):
    """Compute the 3D keypoint loss for samples with 3D annotations.

    The last channel of ``gt_keypoints_3d`` is a binary confidence used
    to weight the criterion. Both joint sets are centred on the pelvis
    (midpoint of joints 2 and 3) before comparison. Returns a zero
    tensor when no sample in the batch carries a 3D annotation.
    """
    visibility = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
    gt_joints = gt_keypoints_3d[:, :, :-1].clone()

    keep = has_pose_3d == 1
    gt_joints = gt_joints[keep]
    visibility = visibility[keep]
    pred_joints = pred_keypoints_3d[keep]

    if len(gt_joints) == 0:
        # no annotated samples: zero loss
        return paddle.to_tensor([1.]).fill_(0.)

    # centre both sets on the pelvis (midpoint of joints 2 and 3)
    gt_pelvis = (gt_joints[:, 2, :] + gt_joints[:, 3, :]) / 2
    gt_joints = gt_joints - gt_pelvis[:, None, :]
    pred_pelvis = (pred_joints[:, 2, :] + pred_joints[:, 3, :]) / 2
    pred_joints = pred_joints - pred_pelvis[:, None, :]

    return (visibility * criterion_keypoints(pred_joints, gt_joints)).mean()
@register
@serializable
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
    """Compute the per-vertex loss for samples with mesh annotations.

    Returns a zero tensor when no sample in the batch has
    ``has_smpl == 1``.
    """
    keep = has_smpl == 1
    pred_kept = pred_vertices[keep]
    gt_kept = gt_vertices[keep]

    if len(gt_kept) == 0:
        # no annotated meshes: zero loss
        return paddle.to_tensor([1.]).fill_(0.)
    return criterion_vertices(pred_kept, gt_kept)
@register
@serializable
def rectify_pose(pose):
    """Rotate the global (root) orientation of an axis-angle pose by pi
    around the x-axis, leaving all other joints untouched.

    NOTE(review): this function uses ``cv2`` and ``np`` but neither is
    imported anywhere in this module — calling it raises NameError as-is;
    add ``import cv2`` / ``import numpy as np`` before use.
    """
    # work on a copy so the caller's array is not mutated
    pose = pose.copy()
    # rotation of pi around the x-axis, as a 3x3 matrix
    R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
    # current root orientation (first 3 axis-angle values) as a matrix
    R_root = cv2.Rodrigues(pose[:3])[0]
    new_root = R_root.dot(R_mod)
    # convert the composed rotation back to axis-angle form
    pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
    return pose
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录