Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
e8aeb802
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e8aeb802
编写于
7月 19, 2021
作者:
S
shangliang Xu
提交者:
GitHub
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[transformer] add Deformable DETR base code (#3718)
上级
283f5ac7
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
667 addition
and
12 deletion
+667
-12
ppdet/modeling/heads/detr_head.py
ppdet/modeling/heads/detr_head.py
+77
-2
ppdet/modeling/post_process.py
ppdet/modeling/post_process.py
+16
-5
ppdet/modeling/transformers/__init__.py
ppdet/modeling/transformers/__init__.py
+2
-0
ppdet/modeling/transformers/deformable_transformer.py
ppdet/modeling/transformers/deformable_transformer.py
+514
-0
ppdet/modeling/transformers/position_encoding.py
ppdet/modeling/transformers/position_encoding.py
+8
-4
ppdet/modeling/transformers/utils.py
ppdet/modeling/transformers/utils.py
+50
-1
未找到文件。
ppdet/modeling/heads/detr_head.py
浏览文件 @
e8aeb802
...
...
@@ -21,9 +21,10 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
import
pycocotools.mask
as
mask_util
from
..initializer
import
linear_init_
from
..initializer
import
linear_init_
,
constant_
from
..transformers.utils
import
inverse_sigmoid
__all__
=
[
'DETRHead'
]
__all__
=
[
'DETRHead'
,
'DeformableDETRHead'
]
class
MLP
(
nn
.
Layer
):
...
...
@@ -275,3 +276,77 @@ class DETRHead(nn.Layer):
gt_mask
=
gt_mask
)
else
:
return
(
outputs_bbox
[
-
1
],
outputs_logit
[
-
1
],
outputs_seg
)
@
register
class
DeformableDETRHead
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'hidden_dim'
]
__inject__
=
[
'loss'
]
def
__init__
(
self
,
num_classes
=
80
,
hidden_dim
=
512
,
nhead
=
8
,
num_mlp_layers
=
3
,
loss
=
'DETRLoss'
):
super
(
DeformableDETRHead
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
hidden_dim
=
hidden_dim
self
.
nhead
=
nhead
self
.
loss
=
loss
self
.
score_head
=
nn
.
Linear
(
hidden_dim
,
self
.
num_classes
)
self
.
bbox_head
=
MLP
(
hidden_dim
,
hidden_dim
,
output_dim
=
4
,
num_layers
=
num_mlp_layers
)
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
linear_init_
(
self
.
score_head
)
constant_
(
self
.
score_head
.
bias
,
-
4.595
)
constant_
(
self
.
bbox_head
.
layers
[
-
1
].
weight
)
bias
=
paddle
.
zeros_like
(
self
.
bbox_head
.
layers
[
-
1
].
bias
)
bias
[
2
:]
=
-
2.0
self
.
bbox_head
.
layers
[
-
1
].
bias
.
set_value
(
bias
)
@
classmethod
def
from_config
(
cls
,
cfg
,
hidden_dim
,
nhead
,
input_shape
):
return
{
'hidden_dim'
:
hidden_dim
,
'nhead'
:
nhead
}
def
forward
(
self
,
out_transformer
,
body_feats
,
inputs
=
None
):
r
"""
Args:
out_transformer (Tuple): (feats: [num_levels, batch_size,
num_queries, hidden_dim],
memory: [batch_size,
\sum_{l=0}^{L-1} H_l \cdot W_l, hidden_dim],
reference_points: [batch_size, num_queries, 2])
body_feats (List(Tensor)): list[[B, C, H, W]]
inputs (dict): dict(inputs)
"""
feats
,
memory
,
reference_points
=
out_transformer
reference_points
=
inverse_sigmoid
(
reference_points
.
unsqueeze
(
0
))
outputs_bbox
=
self
.
bbox_head
(
feats
)
# It's equivalent to "outputs_bbox[:, :, :, :2] += reference_points",
# but the gradient is wrong in paddle.
outputs_bbox
=
paddle
.
concat
(
[
outputs_bbox
[:,
:,
:,
:
2
]
+
reference_points
,
outputs_bbox
[:,
:,
:,
2
:]
],
axis
=-
1
)
outputs_bbox
=
F
.
sigmoid
(
outputs_bbox
)
outputs_logit
=
self
.
score_head
(
feats
)
if
self
.
training
:
assert
inputs
is
not
None
assert
'gt_bbox'
in
inputs
and
'gt_class'
in
inputs
return
self
.
loss
(
outputs_bbox
,
outputs_logit
,
inputs
[
'gt_bbox'
],
inputs
[
'gt_class'
])
else
:
return
(
outputs_bbox
[
-
1
],
outputs_logit
[
-
1
],
None
)
ppdet/modeling/post_process.py
浏览文件 @
e8aeb802
...
...
@@ -532,12 +532,23 @@ class DETRBBoxPostProcess(object):
scores
=
F
.
sigmoid
(
logits
)
if
self
.
use_focal_loss
else
F
.
softmax
(
logits
)[:,
:,
:
-
1
]
scores
,
labels
=
scores
.
max
(
-
1
),
scores
.
argmax
(
-
1
)
if
scores
.
shape
[
1
]
>
self
.
num_top_queries
:
scores
,
index
=
paddle
.
topk
(
scores
,
self
.
num_top_queries
,
axis
=-
1
)
labels
=
paddle
.
stack
(
[
paddle
.
gather
(
l
,
i
)
for
l
,
i
in
zip
(
labels
,
index
)])
if
not
self
.
use_focal_loss
:
scores
,
labels
=
scores
.
max
(
-
1
),
scores
.
argmax
(
-
1
)
if
scores
.
shape
[
1
]
>
self
.
num_top_queries
:
scores
,
index
=
paddle
.
topk
(
scores
,
self
.
num_top_queries
,
axis
=-
1
)
labels
=
paddle
.
stack
(
[
paddle
.
gather
(
l
,
i
)
for
l
,
i
in
zip
(
labels
,
index
)])
bbox_pred
=
paddle
.
stack
(
[
paddle
.
gather
(
b
,
i
)
for
b
,
i
in
zip
(
bbox_pred
,
index
)])
else
:
scores
,
index
=
paddle
.
topk
(
scores
.
reshape
([
logits
.
shape
[
0
],
-
1
]),
self
.
num_top_queries
,
axis
=-
1
)
labels
=
index
%
logits
.
shape
[
2
]
index
=
index
//
logits
.
shape
[
2
]
bbox_pred
=
paddle
.
stack
(
[
paddle
.
gather
(
b
,
i
)
for
b
,
i
in
zip
(
bbox_pred
,
index
)])
...
...
ppdet/modeling/transformers/__init__.py
浏览文件 @
e8aeb802
...
...
@@ -16,8 +16,10 @@ from . import detr_transformer
from
.
import
utils
from
.
import
matchers
from
.
import
position_encoding
from
.
import
deformable_transformer
from
.detr_transformer
import
*
from
.utils
import
*
from
.matchers
import
*
from
.position_encoding
import
*
from
.deformable_transformer
import
*
ppdet/modeling/transformers/deformable_transformer.py
0 → 100644
浏览文件 @
e8aeb802
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
ppdet.core.workspace
import
register
from
..layers
import
MultiHeadAttention
from
.position_encoding
import
PositionEmbedding
from
.utils
import
_get_clones
,
deformable_attention_core_func
from
..initializer
import
linear_init_
,
constant_
,
xavier_uniform_
,
normal_
__all__
=
[
'DeformableTransformer'
]
class
MSDeformableAttention
(
nn
.
Layer
):
def
__init__
(
self
,
embed_dim
=
256
,
num_heads
=
8
,
num_levels
=
4
,
num_points
=
4
,
lr_mult
=
0.1
):
"""
Multi-Scale Deformable Attention Module
"""
super
(
MSDeformableAttention
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
num_levels
=
num_levels
self
.
num_points
=
num_points
self
.
total_points
=
num_heads
*
num_levels
*
num_points
self
.
head_dim
=
embed_dim
//
num_heads
assert
self
.
head_dim
*
num_heads
==
self
.
embed_dim
,
"embed_dim must be divisible by num_heads"
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dim
,
self
.
total_points
*
2
,
weight_attr
=
ParamAttr
(
learning_rate
=
lr_mult
),
bias_attr
=
ParamAttr
(
learning_rate
=
lr_mult
))
self
.
attention_weights
=
nn
.
Linear
(
embed_dim
,
self
.
total_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
output_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
# sampling_offsets
constant_
(
self
.
sampling_offsets
.
weight
)
thetas
=
paddle
.
arange
(
self
.
num_heads
,
dtype
=
paddle
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
paddle
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)
grid_init
=
grid_init
.
reshape
([
self
.
num_heads
,
1
,
1
,
2
]).
tile
(
[
1
,
self
.
num_levels
,
self
.
num_points
,
1
])
scaling
=
paddle
.
arange
(
1
,
self
.
num_points
+
1
,
dtype
=
paddle
.
float32
).
reshape
([
1
,
1
,
-
1
,
1
])
grid_init
*=
scaling
self
.
sampling_offsets
.
bias
.
set_value
(
grid_init
.
flatten
())
# attention_weights
constant_
(
self
.
attention_weights
.
weight
)
constant_
(
self
.
attention_weights
.
bias
)
# proj
xavier_uniform_
(
self
.
value_proj
.
weight
)
constant_
(
self
.
value_proj
.
bias
)
xavier_uniform_
(
self
.
output_proj
.
weight
)
constant_
(
self
.
output_proj
.
bias
)
def
forward
(
self
,
query
,
reference_points
,
value
,
value_spatial_shapes
,
value_mask
=
None
):
"""
Args:
query (Tensor): [bs, query_length, C]
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C]
value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs
,
Len_q
=
query
.
shape
[:
2
]
Len_v
=
value
.
shape
[
1
]
assert
int
(
value_spatial_shapes
.
prod
(
1
).
sum
())
==
Len_v
value
=
self
.
value_proj
(
value
)
if
value_mask
is
not
None
:
value_mask
=
value_mask
.
astype
(
value
.
dtype
).
unsqueeze
(
-
1
)
value
*=
value_mask
value
=
value
.
reshape
([
bs
,
Len_v
,
self
.
num_heads
,
self
.
head_dim
])
sampling_offsets
=
self
.
sampling_offsets
(
query
).
reshape
(
[
bs
,
Len_q
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
])
attention_weights
=
self
.
attention_weights
(
query
).
reshape
(
[
bs
,
Len_q
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
])
attention_weights
=
F
.
softmax
(
attention_weights
,
-
1
).
reshape
(
[
bs
,
Len_q
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
])
offset_normalizer
=
value_spatial_shapes
.
flip
([
1
]).
reshape
(
[
1
,
1
,
1
,
self
.
num_levels
,
1
,
2
])
sampling_locations
=
reference_points
.
reshape
([
bs
,
Len_q
,
1
,
self
.
num_levels
,
1
,
2
])
+
sampling_offsets
/
offset_normalizer
output
=
deformable_attention_core_func
(
value
,
value_spatial_shapes
,
sampling_locations
,
attention_weights
)
output
=
self
.
output_proj
(
output
)
return
output
class
DeformableTransformerEncoderLayer
(
nn
.
Layer
):
def
__init__
(
self
,
d_model
=
256
,
n_head
=
8
,
dim_feedforward
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
n_levels
=
4
,
n_points
=
4
,
weight_attr
=
None
,
bias_attr
=
None
):
super
(
DeformableTransformerEncoderLayer
,
self
).
__init__
()
# self attention
self
.
self_attn
=
MSDeformableAttention
(
d_model
,
n_head
,
n_levels
,
n_points
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
# ffn
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
,
weight_attr
,
bias_attr
)
self
.
activation
=
getattr
(
F
,
activation
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
,
weight_attr
,
bias_attr
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
linear_init_
(
self
.
linear1
)
linear_init_
(
self
.
linear2
)
xavier_uniform_
(
self
.
linear1
.
weight
)
xavier_uniform_
(
self
.
linear2
.
weight
)
def
with_pos_embed
(
self
,
tensor
,
pos
):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward_ffn
(
self
,
src
):
src2
=
self
.
linear2
(
self
.
dropout2
(
self
.
activation
(
self
.
linear1
(
src
))))
src
=
src
+
self
.
dropout3
(
src2
)
src
=
self
.
norm2
(
src
)
return
src
def
forward
(
self
,
src
,
reference_points
,
spatial_shapes
,
src_mask
=
None
,
pos_embed
=
None
):
# self attention
src2
=
self
.
self_attn
(
self
.
with_pos_embed
(
src
,
pos_embed
),
reference_points
,
src
,
spatial_shapes
,
src_mask
)
src
=
src
+
self
.
dropout1
(
src2
)
src
=
self
.
norm1
(
src
)
# ffn
src
=
self
.
forward_ffn
(
src
)
return
src
class
DeformableTransformerEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
encoder_layer
,
num_layers
):
super
(
DeformableTransformerEncoder
,
self
).
__init__
()
self
.
layers
=
_get_clones
(
encoder_layer
,
num_layers
)
self
.
num_layers
=
num_layers
@
staticmethod
def
get_reference_points
(
spatial_shapes
,
valid_ratios
):
valid_ratios
=
valid_ratios
.
unsqueeze
(
1
)
reference_points
=
[]
for
i
,
(
H
,
W
)
in
enumerate
(
spatial_shapes
.
tolist
()):
ref_y
,
ref_x
=
paddle
.
meshgrid
(
paddle
.
linspace
(
0.5
,
H
-
0.5
,
H
),
paddle
.
linspace
(
0.5
,
W
-
0.5
,
W
))
ref_y
=
ref_y
.
flatten
().
unsqueeze
(
0
)
/
(
valid_ratios
[:,
:,
i
,
1
]
*
H
)
ref_x
=
ref_x
.
flatten
().
unsqueeze
(
0
)
/
(
valid_ratios
[:,
:,
i
,
0
]
*
W
)
reference_points
.
append
(
paddle
.
stack
((
ref_x
,
ref_y
),
axis
=-
1
))
reference_points
=
paddle
.
concat
(
reference_points
,
1
).
unsqueeze
(
2
)
reference_points
=
reference_points
*
valid_ratios
return
reference_points
def
forward
(
self
,
src
,
spatial_shapes
,
src_mask
=
None
,
pos_embed
=
None
,
valid_ratios
=
None
):
output
=
src
if
valid_ratios
is
None
:
valid_ratios
=
paddle
.
ones
(
[
src
.
shape
[
0
],
spatial_shapes
.
shape
[
0
],
2
])
reference_points
=
self
.
get_reference_points
(
spatial_shapes
,
valid_ratios
)
for
layer
in
self
.
layers
:
output
=
layer
(
output
,
reference_points
,
spatial_shapes
,
src_mask
,
pos_embed
)
return
output
class
DeformableTransformerDecoderLayer
(
nn
.
Layer
):
def
__init__
(
self
,
d_model
=
256
,
n_head
=
8
,
dim_feedforward
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
n_levels
=
4
,
n_points
=
4
,
weight_attr
=
None
,
bias_attr
=
None
):
super
(
DeformableTransformerDecoderLayer
,
self
).
__init__
()
# self attention
self
.
self_attn
=
MultiHeadAttention
(
d_model
,
n_head
,
dropout
=
dropout
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
# cross attention
self
.
cross_attn
=
MSDeformableAttention
(
d_model
,
n_head
,
n_levels
,
n_points
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
# ffn
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
,
weight_attr
,
bias_attr
)
self
.
activation
=
getattr
(
F
,
activation
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
,
weight_attr
,
bias_attr
)
self
.
dropout4
=
nn
.
Dropout
(
dropout
)
self
.
norm3
=
nn
.
LayerNorm
(
d_model
)
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
linear_init_
(
self
.
linear1
)
linear_init_
(
self
.
linear2
)
xavier_uniform_
(
self
.
linear1
.
weight
)
xavier_uniform_
(
self
.
linear2
.
weight
)
def
with_pos_embed
(
self
,
tensor
,
pos
):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward_ffn
(
self
,
tgt
):
tgt2
=
self
.
linear2
(
self
.
dropout3
(
self
.
activation
(
self
.
linear1
(
tgt
))))
tgt
=
tgt
+
self
.
dropout4
(
tgt2
)
tgt
=
self
.
norm3
(
tgt
)
return
tgt
def
forward
(
self
,
tgt
,
reference_points
,
memory
,
memory_spatial_shapes
,
memory_mask
=
None
,
query_pos_embed
=
None
):
# self attention
q
=
k
=
self
.
with_pos_embed
(
tgt
,
query_pos_embed
)
tgt2
=
self
.
self_attn
(
q
,
k
,
value
=
tgt
)
tgt
=
tgt
+
self
.
dropout1
(
tgt2
)
tgt
=
self
.
norm1
(
tgt
)
# cross attention
tgt2
=
self
.
cross_attn
(
self
.
with_pos_embed
(
tgt
,
query_pos_embed
),
reference_points
,
memory
,
memory_spatial_shapes
,
memory_mask
)
tgt
=
tgt
+
self
.
dropout2
(
tgt2
)
tgt
=
self
.
norm2
(
tgt
)
# ffn
tgt
=
self
.
forward_ffn
(
tgt
)
return
tgt
class
DeformableTransformerDecoder
(
nn
.
Layer
):
def
__init__
(
self
,
decoder_layer
,
num_layers
,
return_intermediate
=
False
):
super
(
DeformableTransformerDecoder
,
self
).
__init__
()
self
.
layers
=
_get_clones
(
decoder_layer
,
num_layers
)
self
.
num_layers
=
num_layers
self
.
return_intermediate
=
return_intermediate
def
forward
(
self
,
tgt
,
reference_points
,
memory
,
memory_spatial_shapes
,
memory_mask
=
None
,
query_pos_embed
=
None
):
output
=
tgt
intermediate
=
[]
for
lid
,
layer
in
enumerate
(
self
.
layers
):
output
=
layer
(
output
,
reference_points
,
memory
,
memory_spatial_shapes
,
memory_mask
,
query_pos_embed
)
if
self
.
return_intermediate
:
intermediate
.
append
(
output
)
if
self
.
return_intermediate
:
return
paddle
.
stack
(
intermediate
)
return
output
.
unsqueeze
(
0
)
@
register
class
DeformableTransformer
(
nn
.
Layer
):
__shared__
=
[
'hidden_dim'
]
def
__init__
(
self
,
num_queries
=
300
,
position_embed_type
=
'sine'
,
return_intermediate_dec
=
True
,
backbone_num_channels
=
[
512
,
1024
,
2048
],
num_feature_levels
=
4
,
num_encoder_points
=
4
,
num_decoder_points
=
4
,
hidden_dim
=
256
,
nhead
=
8
,
num_encoder_layers
=
6
,
num_decoder_layers
=
6
,
dim_feedforward
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
lr_mult
=
0.1
,
weight_attr
=
None
,
bias_attr
=
None
):
super
(
DeformableTransformer
,
self
).
__init__
()
assert
position_embed_type
in
[
'sine'
,
'learned'
],
\
f
'ValueError: position_embed_type not supported
{
position_embed_type
}
!'
assert
len
(
backbone_num_channels
)
<=
num_feature_levels
self
.
hidden_dim
=
hidden_dim
self
.
nhead
=
nhead
self
.
num_feature_levels
=
num_feature_levels
encoder_layer
=
DeformableTransformerEncoderLayer
(
hidden_dim
,
nhead
,
dim_feedforward
,
dropout
,
activation
,
num_feature_levels
,
num_encoder_points
,
weight_attr
,
bias_attr
)
self
.
encoder
=
DeformableTransformerEncoder
(
encoder_layer
,
num_encoder_layers
)
decoder_layer
=
DeformableTransformerDecoderLayer
(
hidden_dim
,
nhead
,
dim_feedforward
,
dropout
,
activation
,
num_feature_levels
,
num_decoder_points
,
weight_attr
,
bias_attr
)
self
.
decoder
=
DeformableTransformerDecoder
(
decoder_layer
,
num_decoder_layers
,
return_intermediate_dec
)
self
.
level_embed
=
nn
.
Embedding
(
num_feature_levels
,
hidden_dim
)
self
.
tgt_embed
=
nn
.
Embedding
(
num_queries
,
hidden_dim
)
self
.
query_pos_embed
=
nn
.
Embedding
(
num_queries
,
hidden_dim
)
self
.
reference_points
=
nn
.
Linear
(
hidden_dim
,
2
,
weight_attr
=
ParamAttr
(
learning_rate
=
lr_mult
),
bias_attr
=
ParamAttr
(
learning_rate
=
lr_mult
))
self
.
input_proj
=
nn
.
LayerList
()
for
in_channels
in
backbone_num_channels
:
self
.
input_proj
.
append
(
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
,
hidden_dim
,
kernel_size
=
1
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
nn
.
GroupNorm
(
32
,
hidden_dim
)))
in_channels
=
backbone_num_channels
[
-
1
]
for
_
in
range
(
num_feature_levels
-
len
(
backbone_num_channels
)):
self
.
input_proj
.
append
(
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
,
hidden_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
nn
.
GroupNorm
(
32
,
hidden_dim
)))
in_channels
=
hidden_dim
self
.
position_embedding
=
PositionEmbedding
(
hidden_dim
//
2
,
normalize
=
True
if
position_embed_type
==
'sine'
else
False
,
embed_type
=
position_embed_type
,
offset
=-
0.5
)
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
normal_
(
self
.
level_embed
.
weight
)
normal_
(
self
.
tgt_embed
.
weight
)
normal_
(
self
.
query_pos_embed
.
weight
)
xavier_uniform_
(
self
.
reference_points
.
weight
)
constant_
(
self
.
reference_points
.
bias
)
for
l
in
self
.
input_proj
:
xavier_uniform_
(
l
[
0
].
weight
)
constant_
(
l
[
0
].
bias
)
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
return
{
'backbone_num_channels'
:
[
i
.
channels
for
i
in
input_shape
],
}
def
_get_valid_ratio
(
self
,
mask
):
mask
=
mask
.
astype
(
paddle
.
float32
)
_
,
H
,
W
=
mask
.
shape
valid_ratio_h
=
paddle
.
sum
(
mask
[:,
:,
0
],
1
)
/
H
valid_ratio_w
=
paddle
.
sum
(
mask
[:,
0
,
:],
1
)
/
W
valid_ratio
=
paddle
.
stack
([
valid_ratio_w
,
valid_ratio_h
],
-
1
)
return
valid_ratio
def
forward
(
self
,
src_feats
,
src_mask
=
None
):
srcs
=
[]
for
i
in
range
(
len
(
src_feats
)):
srcs
.
append
(
self
.
input_proj
[
i
](
src_feats
[
i
]))
if
self
.
num_feature_levels
>
len
(
srcs
):
len_srcs
=
len
(
srcs
)
for
i
in
range
(
len_srcs
,
self
.
num_feature_levels
):
if
i
==
len_srcs
:
srcs
.
append
(
self
.
input_proj
[
i
](
src_feats
[
-
1
]))
else
:
srcs
.
append
(
self
.
input_proj
[
i
](
srcs
[
-
1
]))
src_flatten
=
[]
mask_flatten
=
[]
lvl_pos_embed_flatten
=
[]
spatial_shapes
=
[]
valid_ratios
=
[]
for
level
,
src
in
enumerate
(
srcs
):
bs
,
c
,
h
,
w
=
src
.
shape
spatial_shapes
.
append
([
h
,
w
])
src
=
src
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
src_flatten
.
append
(
src
)
if
src_mask
is
not
None
:
mask
=
F
.
interpolate
(
src_mask
.
unsqueeze
(
0
).
astype
(
src
.
dtype
),
size
=
(
h
,
w
))[
0
].
astype
(
'bool'
)
else
:
mask
=
paddle
.
ones
([
bs
,
h
,
w
],
dtype
=
'bool'
)
valid_ratios
.
append
(
self
.
_get_valid_ratio
(
mask
))
pos_embed
=
self
.
position_embedding
(
mask
).
flatten
(
2
).
transpose
(
[
0
,
2
,
1
])
lvl_pos_embed
=
pos_embed
+
self
.
level_embed
.
weight
[
level
].
reshape
(
[
1
,
1
,
-
1
])
lvl_pos_embed_flatten
.
append
(
lvl_pos_embed
)
mask
=
mask
.
astype
(
src
.
dtype
).
flatten
(
1
)
mask_flatten
.
append
(
mask
)
src_flatten
=
paddle
.
concat
(
src_flatten
,
1
)
mask_flatten
=
paddle
.
concat
(
mask_flatten
,
1
)
lvl_pos_embed_flatten
=
paddle
.
concat
(
lvl_pos_embed_flatten
,
1
)
# [l, 2]
spatial_shapes
=
paddle
.
to_tensor
(
spatial_shapes
,
dtype
=
'int64'
)
# [b, l, 2]
valid_ratios
=
paddle
.
stack
(
valid_ratios
,
1
)
# encoder
memory
=
self
.
encoder
(
src_flatten
,
spatial_shapes
,
mask_flatten
,
lvl_pos_embed_flatten
,
valid_ratios
)
# prepare input for decoder
bs
,
_
,
c
=
memory
.
shape
query_embed
=
self
.
query_pos_embed
.
weight
.
unsqueeze
(
0
).
tile
([
bs
,
1
,
1
])
tgt
=
self
.
tgt_embed
.
weight
.
unsqueeze
(
0
).
tile
([
bs
,
1
,
1
])
reference_points
=
F
.
sigmoid
(
self
.
reference_points
(
query_embed
))
reference_points_input
=
reference_points
.
unsqueeze
(
2
)
*
valid_ratios
.
unsqueeze
(
1
)
# decoder
hs
=
self
.
decoder
(
tgt
,
reference_points_input
,
memory
,
spatial_shapes
,
mask_flatten
,
query_embed
)
return
(
hs
,
memory
,
reference_points
)
ppdet/modeling/transformers/position_encoding.py
浏览文件 @
e8aeb802
...
...
@@ -32,11 +32,14 @@ class PositionEmbedding(nn.Layer):
normalize
=
True
,
scale
=
None
,
embed_type
=
'sine'
,
num_embeddings
=
50
):
num_embeddings
=
50
,
offset
=
0.
):
super
(
PositionEmbedding
,
self
).
__init__
()
assert
embed_type
in
[
'sine'
,
'learned'
]
self
.
embed_type
=
embed_type
self
.
offset
=
offset
self
.
eps
=
1e-6
if
self
.
embed_type
==
'sine'
:
self
.
num_pos_feats
=
num_pos_feats
self
.
temperature
=
temperature
...
...
@@ -65,9 +68,10 @@ class PositionEmbedding(nn.Layer):
y_embed
=
mask
.
cumsum
(
1
,
dtype
=
'float32'
)
x_embed
=
mask
.
cumsum
(
2
,
dtype
=
'float32'
)
if
self
.
normalize
:
eps
=
1e-6
y_embed
=
y_embed
/
(
y_embed
[:,
-
1
:,
:]
+
eps
)
*
self
.
scale
x_embed
=
x_embed
/
(
x_embed
[:,
:,
-
1
:]
+
eps
)
*
self
.
scale
y_embed
=
(
y_embed
+
self
.
offset
)
/
(
y_embed
[:,
-
1
:,
:]
+
self
.
eps
)
*
self
.
scale
x_embed
=
(
x_embed
+
self
.
offset
)
/
(
x_embed
[:,
:,
-
1
:]
+
self
.
eps
)
*
self
.
scale
dim_t
=
2
*
(
paddle
.
arange
(
self
.
num_pos_feats
)
//
2
).
astype
(
'float32'
)
...
...
ppdet/modeling/transformers/utils.py
浏览文件 @
e8aeb802
...
...
@@ -25,7 +25,8 @@ from ..bbox_utils import bbox_overlaps
__all__
=
[
'_get_clones'
,
'bbox_overlaps'
,
'bbox_cxcywh_to_xyxy'
,
'bbox_xyxy_to_cxcywh'
,
'sigmoid_focal_loss'
'bbox_xyxy_to_cxcywh'
,
'sigmoid_focal_loss'
,
'inverse_sigmoid'
,
'deformable_attention_core_func'
]
...
...
@@ -55,3 +56,51 @@ def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
alpha_t
=
alpha
*
label
+
(
1
-
alpha
)
*
(
1
-
label
)
loss
=
alpha_t
*
loss
return
loss
.
mean
(
1
).
sum
()
/
normalizer
def
inverse_sigmoid
(
x
,
eps
=
1e-6
):
x
=
x
.
clip
(
min
=
0.
,
max
=
1.
)
return
paddle
.
log
(
x
/
(
1
-
x
+
eps
)
+
eps
)
def
deformable_attention_core_func
(
value
,
value_spatial_shapes
,
sampling_locations
,
attention_weights
):
"""
Args:
value (Tensor): [bs, value_length, n_head, c]
value_spatial_shapes (Tensor): [n_levels, 2]
sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs
,
Len_v
,
n_head
,
c
=
value
.
shape
_
,
Len_q
,
n_head
,
n_levels
,
n_points
,
_
=
sampling_locations
.
shape
value_list
=
value
.
split
(
value_spatial_shapes
.
prod
(
1
).
tolist
(),
axis
=
1
)
sampling_grids
=
2
*
sampling_locations
-
1
sampling_value_list
=
[]
for
level
,
(
h
,
w
)
in
enumerate
(
value_spatial_shapes
.
tolist
()):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_
=
value_list
[
level
].
flatten
(
2
).
transpose
(
[
0
,
2
,
1
]).
reshape
([
bs
*
n_head
,
c
,
h
,
w
])
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_
=
sampling_grids
[:,
:,
:,
level
].
transpose
(
[
0
,
2
,
1
,
3
,
4
]).
flatten
(
0
,
1
)
# N_*M_, D_, Lq_, P_
sampling_value_l_
=
F
.
grid_sample
(
value_l_
,
sampling_grid_l_
,
mode
=
'bilinear'
,
padding_mode
=
'zeros'
,
align_corners
=
False
)
sampling_value_list
.
append
(
sampling_value_l_
)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
attention_weights
=
attention_weights
.
transpose
([
0
,
2
,
1
,
3
,
4
]).
reshape
(
[
bs
*
n_head
,
1
,
Len_q
,
n_levels
*
n_points
])
output
=
(
paddle
.
stack
(
sampling_value_list
,
axis
=-
2
).
flatten
(
-
2
)
*
attention_weights
).
sum
(
-
1
).
reshape
([
bs
,
n_head
*
c
,
Len_q
])
return
output
.
transpose
([
0
,
2
,
1
])
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录