Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
841f2f4e
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
841f2f4e
编写于
7月 07, 2021
作者:
F
FL77N
提交者:
GitHub
7月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sparsercnn (#3623)
* add sparsercnn * update sparsercnn
上级
bb846096
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
1010 addition
and
2 deletion
+1010
-2
ppdet/data/transform/batch_operators.py
ppdet/data/transform/batch_operators.py
+26
-1
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+2
-0
ppdet/modeling/architectures/sparse_rcnn.py
ppdet/modeling/architectures/sparse_rcnn.py
+99
-0
ppdet/modeling/heads/__init__.py
ppdet/modeling/heads/__init__.py
+2
-0
ppdet/modeling/heads/sparsercnn_head.py
ppdet/modeling/heads/sparsercnn_head.py
+371
-0
ppdet/modeling/losses/__init__.py
ppdet/modeling/losses/__init__.py
+2
-0
ppdet/modeling/losses/sparsercnn_loss.py
ppdet/modeling/losses/sparsercnn_loss.py
+420
-0
ppdet/modeling/post_process.py
ppdet/modeling/post_process.py
+88
-1
未找到文件。
ppdet/data/transform/batch_operators.py
浏览文件 @
841f2f4e
...
...
@@ -33,7 +33,7 @@ logger = setup_logger(__name__)
__all__
=
[
'PadBatch'
,
'BatchRandomResize'
,
'Gt2YoloTarget'
,
'Gt2FCOSTarget'
,
'Gt2TTFTarget'
,
'Gt2Solov2Target'
'Gt2TTFTarget'
,
'Gt2Solov2Target'
,
'Gt2SparseRCNNTarget'
]
...
...
@@ -746,3 +746,28 @@ class Gt2Solov2Target(BaseOperator):
data
[
'grid_order{}'
.
format
(
idx
)]
=
gt_grid_order
return
samples
@
register_op
class
Gt2SparseRCNNTarget
(
BaseOperator
):
'''
Generate SparseRCNN targets by groud truth data
'''
def
__init__
(
self
):
super
(
Gt2SparseRCNNTarget
,
self
).
__init__
()
def
__call__
(
self
,
samples
,
context
=
None
):
for
sample
in
samples
:
im
=
sample
[
"image"
]
h
,
w
=
im
.
shape
[
1
:
3
]
img_whwh
=
np
.
array
([
w
,
h
,
w
,
h
],
dtype
=
np
.
int32
)
sample
[
"img_whwh"
]
=
img_whwh
if
"scale_factor"
in
sample
:
sample
[
"scale_factor_wh"
]
=
np
.
array
([
sample
[
"scale_factor"
][
1
],
sample
[
"scale_factor"
][
0
]],
dtype
=
np
.
float32
)
sample
.
pop
(
"scale_factor"
)
else
:
sample
[
"scale_factor_wh"
]
=
np
.
array
([
1.0
,
1.0
],
dtype
=
np
.
float32
)
return
samples
ppdet/modeling/architectures/__init__.py
浏览文件 @
841f2f4e
...
...
@@ -22,6 +22,7 @@ from . import deepsort
from
.
import
fairmot
from
.
import
centernet
from
.
import
detr
from
.
import
sparse_rcnn
from
.meta_arch
import
*
from
.faster_rcnn
import
*
...
...
@@ -41,3 +42,4 @@ from .fairmot import *
from
.centernet
import
*
from
.blazeface
import
*
from
.detr
import
*
from
.sparse_rcnn
import
*
ppdet/modeling/architectures/sparse_rcnn.py
0 → 100644
浏览文件 @
841f2f4e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
ppdet.core.workspace
import
register
,
create
from
.meta_arch
import
BaseArch
__all__
=
[
"SparseRCNN"
]
@
register
class
SparseRCNN
(
BaseArch
):
__category__
=
'architecture'
__inject__
=
[
"postprocess"
]
def
__init__
(
self
,
backbone
,
neck
,
head
=
"SparsercnnHead"
,
postprocess
=
"SparsePostProcess"
):
super
(
SparseRCNN
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
neck
=
neck
self
.
head
=
head
self
.
postprocess
=
postprocess
@
classmethod
def
from_config
(
cls
,
cfg
,
*
args
,
**
kwargs
):
backbone
=
create
(
cfg
[
'backbone'
])
kwargs
=
{
'input_shape'
:
backbone
.
out_shape
}
neck
=
create
(
cfg
[
'neck'
],
**
kwargs
)
kwargs
=
{
'roi_input_shape'
:
neck
.
out_shape
}
head
=
create
(
cfg
[
'head'
],
**
kwargs
)
return
{
'backbone'
:
backbone
,
'neck'
:
neck
,
"head"
:
head
,
}
def
_forward
(
self
):
body_feats
=
self
.
backbone
(
self
.
inputs
)
fpn_feats
=
self
.
neck
(
body_feats
)
head_outs
=
self
.
head
(
fpn_feats
,
self
.
inputs
[
"img_whwh"
])
if
not
self
.
training
:
bboxes
=
self
.
postprocess
(
head_outs
[
"pred_logits"
],
head_outs
[
"pred_boxes"
],
self
.
inputs
[
"scale_factor_wh"
],
self
.
inputs
[
"img_whwh"
])
return
bboxes
else
:
return
head_outs
def
get_loss
(
self
):
batch_gt_class
=
self
.
inputs
[
"gt_class"
]
batch_gt_box
=
self
.
inputs
[
"gt_bbox"
]
batch_whwh
=
self
.
inputs
[
"img_whwh"
]
targets
=
[]
for
i
in
range
(
len
(
batch_gt_class
)):
boxes
=
batch_gt_box
[
i
]
labels
=
batch_gt_class
[
i
].
squeeze
(
-
1
)
img_whwh
=
batch_whwh
[
i
]
img_whwh_tgt
=
img_whwh
.
unsqueeze
(
0
).
tile
([
int
(
boxes
.
shape
[
0
]),
1
])
targets
.
append
({
"boxes"
:
boxes
,
"labels"
:
labels
,
"img_whwh"
:
img_whwh
,
"img_whwh_tgt"
:
img_whwh_tgt
})
outputs
=
self
.
_forward
()
loss_dict
=
self
.
head
.
get_loss
(
outputs
,
targets
)
acc
=
loss_dict
[
"acc"
]
loss_dict
.
pop
(
"acc"
)
total_loss
=
sum
(
loss_dict
.
values
())
loss_dict
.
update
({
"loss"
:
total_loss
,
"acc"
:
acc
})
return
loss_dict
def
get_pred
(
self
):
bbox_pred
,
bbox_num
=
self
.
_forward
()
output
=
{
'bbox'
:
bbox_pred
,
'bbox_num'
:
bbox_num
}
return
output
ppdet/modeling/heads/__init__.py
浏览文件 @
841f2f4e
...
...
@@ -26,6 +26,7 @@ from . import s2anet_head
from
.
import
keypoint_hrhrnet_head
from
.
import
centernet_head
from
.
import
detr_head
from
.
import
sparsercnn_head
from
.bbox_head
import
*
from
.mask_head
import
*
...
...
@@ -41,3 +42,4 @@ from .s2anet_head import *
from
.keypoint_hrhrnet_head
import
*
from
.centernet_head
import
*
from
.detr_head
import
*
from
.sparsercnn_head
import
*
ppdet/modeling/heads/sparsercnn_head.py
0 → 100644
浏览文件 @
841f2f4e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
copy
import
paddle
import
paddle.nn
as
nn
import
ppdet.modeling.initializer
as
init
from
ppdet.core.workspace
import
register
from
ppdet.modeling.heads.roi_extractor
import
RoIAlign
from
ppdet.modeling.bbox_utils
import
delta2bbox
_DEFAULT_SCALE_CLAMP
=
math
.
log
(
100000.
/
16
)
class
DynamicConv
(
nn
.
Layer
):
def
__init__
(
self
,
head_hidden_dim
,
head_dim_dynamic
,
head_num_dynamic
,
):
super
().
__init__
()
self
.
hidden_dim
=
head_hidden_dim
self
.
dim_dynamic
=
head_dim_dynamic
self
.
num_dynamic
=
head_num_dynamic
self
.
num_params
=
self
.
hidden_dim
*
self
.
dim_dynamic
self
.
dynamic_layer
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
num_dynamic
*
self
.
num_params
)
self
.
norm1
=
nn
.
LayerNorm
(
self
.
dim_dynamic
)
self
.
norm2
=
nn
.
LayerNorm
(
self
.
hidden_dim
)
self
.
activation
=
nn
.
ReLU
()
pooler_resolution
=
7
num_output
=
self
.
hidden_dim
*
pooler_resolution
**
2
self
.
out_layer
=
nn
.
Linear
(
num_output
,
self
.
hidden_dim
)
self
.
norm3
=
nn
.
LayerNorm
(
self
.
hidden_dim
)
def
forward
(
self
,
pro_features
,
roi_features
):
'''
pro_features: (1, N * nr_boxes, self.d_model)
roi_features: (49, N * nr_boxes, self.d_model)
'''
features
=
roi_features
.
transpose
(
perm
=
[
1
,
0
,
2
])
parameters
=
self
.
dynamic_layer
(
pro_features
).
transpose
(
perm
=
[
1
,
0
,
2
])
param1
=
parameters
[:,
:,
:
self
.
num_params
].
reshape
(
[
-
1
,
self
.
hidden_dim
,
self
.
dim_dynamic
])
param2
=
parameters
[:,
:,
self
.
num_params
:].
reshape
(
[
-
1
,
self
.
dim_dynamic
,
self
.
hidden_dim
])
features
=
paddle
.
bmm
(
features
,
param1
)
features
=
self
.
norm1
(
features
)
features
=
self
.
activation
(
features
)
features
=
paddle
.
bmm
(
features
,
param2
)
features
=
self
.
norm2
(
features
)
features
=
self
.
activation
(
features
)
features
=
features
.
flatten
(
1
)
features
=
self
.
out_layer
(
features
)
features
=
self
.
norm3
(
features
)
features
=
self
.
activation
(
features
)
return
features
class
RCNNHead
(
nn
.
Layer
):
def
__init__
(
self
,
d_model
,
num_classes
,
dim_feedforward
,
nhead
,
dropout
,
head_cls
,
head_reg
,
head_dim_dynamic
,
head_num_dynamic
,
scale_clamp
:
float
=
_DEFAULT_SCALE_CLAMP
,
bbox_weights
=
(
2.0
,
2.0
,
1.0
,
1.0
),
):
super
().
__init__
()
self
.
d_model
=
d_model
# dynamic.
self
.
self_attn
=
nn
.
MultiHeadAttention
(
d_model
,
nhead
,
dropout
=
dropout
)
self
.
inst_interact
=
DynamicConv
(
d_model
,
head_dim_dynamic
,
head_num_dynamic
)
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
self
.
norm3
=
nn
.
LayerNorm
(
d_model
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
self
.
activation
=
nn
.
ReLU
()
# cls.
num_cls
=
head_cls
cls_module
=
list
()
for
_
in
range
(
num_cls
):
cls_module
.
append
(
nn
.
Linear
(
d_model
,
d_model
,
bias_attr
=
False
))
cls_module
.
append
(
nn
.
LayerNorm
(
d_model
))
cls_module
.
append
(
nn
.
ReLU
())
self
.
cls_module
=
nn
.
LayerList
(
cls_module
)
# reg.
num_reg
=
head_reg
reg_module
=
list
()
for
_
in
range
(
num_reg
):
reg_module
.
append
(
nn
.
Linear
(
d_model
,
d_model
,
bias_attr
=
False
))
reg_module
.
append
(
nn
.
LayerNorm
(
d_model
))
reg_module
.
append
(
nn
.
ReLU
())
self
.
reg_module
=
nn
.
LayerList
(
reg_module
)
# pred.
self
.
class_logits
=
nn
.
Linear
(
d_model
,
num_classes
)
self
.
bboxes_delta
=
nn
.
Linear
(
d_model
,
4
)
self
.
scale_clamp
=
scale_clamp
self
.
bbox_weights
=
bbox_weights
def
forward
(
self
,
features
,
bboxes
,
pro_features
,
pooler
):
"""
:param bboxes: (N, nr_boxes, 4)
:param pro_features: (N, nr_boxes, d_model)
"""
N
,
nr_boxes
=
bboxes
.
shape
[:
2
]
proposal_boxes
=
list
()
for
b
in
range
(
N
):
proposal_boxes
.
append
(
bboxes
[
b
])
roi_num
=
paddle
.
full
([
N
],
nr_boxes
).
astype
(
"int32"
)
roi_features
=
pooler
(
features
,
proposal_boxes
,
roi_num
)
roi_features
=
roi_features
.
reshape
(
[
N
*
nr_boxes
,
self
.
d_model
,
-
1
]).
transpose
(
perm
=
[
2
,
0
,
1
])
# self_att.
pro_features
=
pro_features
.
reshape
([
N
,
nr_boxes
,
self
.
d_model
])
pro_features2
=
self
.
self_attn
(
pro_features
,
pro_features
,
value
=
pro_features
)
pro_features
=
pro_features
.
transpose
(
perm
=
[
1
,
0
,
2
])
+
self
.
dropout1
(
pro_features2
.
transpose
(
perm
=
[
1
,
0
,
2
]))
pro_features
=
self
.
norm1
(
pro_features
)
# inst_interact.
pro_features
=
pro_features
.
reshape
(
[
nr_boxes
,
N
,
self
.
d_model
]).
transpose
(
perm
=
[
1
,
0
,
2
]).
reshape
(
[
1
,
N
*
nr_boxes
,
self
.
d_model
])
pro_features2
=
self
.
inst_interact
(
pro_features
,
roi_features
)
pro_features
=
pro_features
+
self
.
dropout2
(
pro_features2
)
obj_features
=
self
.
norm2
(
pro_features
)
# obj_feature.
obj_features2
=
self
.
linear2
(
self
.
dropout
(
self
.
activation
(
self
.
linear1
(
obj_features
))))
obj_features
=
obj_features
+
self
.
dropout3
(
obj_features2
)
obj_features
=
self
.
norm3
(
obj_features
)
fc_feature
=
obj_features
.
transpose
(
perm
=
[
1
,
0
,
2
]).
reshape
(
[
N
*
nr_boxes
,
-
1
])
cls_feature
=
fc_feature
.
clone
()
reg_feature
=
fc_feature
.
clone
()
for
cls_layer
in
self
.
cls_module
:
cls_feature
=
cls_layer
(
cls_feature
)
for
reg_layer
in
self
.
reg_module
:
reg_feature
=
reg_layer
(
reg_feature
)
class_logits
=
self
.
class_logits
(
cls_feature
)
bboxes_deltas
=
self
.
bboxes_delta
(
reg_feature
)
pred_bboxes
=
delta2bbox
(
bboxes_deltas
,
bboxes
.
reshape
([
-
1
,
4
]),
self
.
bbox_weights
)
return
class_logits
.
reshape
([
N
,
nr_boxes
,
-
1
]),
pred_bboxes
.
reshape
(
[
N
,
nr_boxes
,
-
1
]),
obj_features
@
register
class
SparseRCNNHead
(
nn
.
Layer
):
'''
SparsercnnHead
Args:
roi_input_shape (list[ShapeSpec]): The output shape of fpn
num_classes (int): Number of classes,
head_hidden_dim (int): The param of MultiHeadAttention,
head_dim_feedforward (int): The param of MultiHeadAttention,
nhead (int): The param of MultiHeadAttention,
head_dropout (float): The p of dropout,
head_cls (int): The number of class head,
head_reg (int): The number of regressionhead,
head_num_dynamic (int): The number of DynamicConv's param,
head_num_heads (int): The number of RCNNHead,
deep_supervision (int): wheather supervise the intermediate results,
num_proposals (int): the number of proposals boxes and features
'''
__inject__
=
[
'loss_func'
]
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
head_hidden_dim
,
head_dim_feedforward
,
nhead
,
head_dropout
,
head_cls
,
head_reg
,
head_dim_dynamic
,
head_num_dynamic
,
head_num_heads
,
deep_supervision
,
num_proposals
,
num_classes
=
80
,
loss_func
=
"SparseRCNNLoss"
,
roi_input_shape
=
None
,
):
super
().
__init__
()
# Build RoI.
box_pooler
=
self
.
_init_box_pooler
(
roi_input_shape
)
self
.
box_pooler
=
box_pooler
# Build heads.
rcnn_head
=
RCNNHead
(
head_hidden_dim
,
num_classes
,
head_dim_feedforward
,
nhead
,
head_dropout
,
head_cls
,
head_reg
,
head_dim_dynamic
,
head_num_dynamic
,
)
self
.
head_series
=
nn
.
LayerList
(
[
copy
.
deepcopy
(
rcnn_head
)
for
i
in
range
(
head_num_heads
)])
self
.
return_intermediate
=
deep_supervision
self
.
num_classes
=
num_classes
# build init proposal
self
.
init_proposal_features
=
nn
.
Embedding
(
num_proposals
,
head_hidden_dim
)
self
.
init_proposal_boxes
=
nn
.
Embedding
(
num_proposals
,
4
)
self
.
lossfunc
=
loss_func
# Init parameters.
init
.
reset_initialized_parameter
(
self
)
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
# init all parameters.
prior_prob
=
0.01
bias_value
=
-
math
.
log
((
1
-
prior_prob
)
/
prior_prob
)
for
m
in
self
.
sublayers
():
if
isinstance
(
m
,
nn
.
Linear
):
init
.
xavier_normal_
(
m
.
weight
,
reverse
=
True
)
elif
not
isinstance
(
m
,
nn
.
Embedding
)
and
hasattr
(
m
,
"weight"
)
and
m
.
weight
.
dim
()
>
1
:
init
.
xavier_normal_
(
m
.
weight
,
reverse
=
False
)
if
hasattr
(
m
,
"bias"
)
and
m
.
bias
is
not
None
and
m
.
bias
.
shape
[
-
1
]
==
self
.
num_classes
:
init
.
constant_
(
m
.
bias
,
bias_value
)
init_bboxes
=
paddle
.
empty_like
(
self
.
init_proposal_boxes
.
weight
)
init_bboxes
[:,
:
2
]
=
0.5
init_bboxes
[:,
2
:]
=
1.0
self
.
init_proposal_boxes
.
weight
.
set_value
(
init_bboxes
)
@
staticmethod
def
_init_box_pooler
(
input_shape
):
pooler_resolution
=
7
sampling_ratio
=
2
if
input_shape
is
not
None
:
pooler_scales
=
tuple
(
1.0
/
input_shape
[
k
].
stride
for
k
in
range
(
len
(
input_shape
)))
in_channels
=
[
input_shape
[
f
].
channels
for
f
in
range
(
len
(
input_shape
))
]
end_level
=
len
(
input_shape
)
-
1
# Check all channel counts are equal
assert
len
(
set
(
in_channels
))
==
1
,
in_channels
else
:
pooler_scales
=
[
1.0
/
4.0
,
1.0
/
8.0
,
1.0
/
16.0
,
1.0
/
32.0
]
end_level
=
3
box_pooler
=
RoIAlign
(
resolution
=
pooler_resolution
,
spatial_scale
=
pooler_scales
,
sampling_ratio
=
sampling_ratio
,
end_level
=
end_level
,
aligned
=
True
)
return
box_pooler
def
forward
(
self
,
features
,
input_whwh
):
bs
=
len
(
features
[
0
])
bboxes
=
box_cxcywh_to_xyxy
(
self
.
init_proposal_boxes
.
weight
.
clone
(
)).
unsqueeze
(
0
)
bboxes
=
bboxes
*
input_whwh
.
unsqueeze
(
-
2
)
init_features
=
self
.
init_proposal_features
.
weight
.
unsqueeze
(
0
).
tile
(
[
1
,
bs
,
1
])
proposal_features
=
init_features
.
clone
()
inter_class_logits
=
[]
inter_pred_bboxes
=
[]
for
rcnn_head
in
self
.
head_series
:
class_logits
,
pred_bboxes
,
proposal_features
=
rcnn_head
(
features
,
bboxes
,
proposal_features
,
self
.
box_pooler
)
if
self
.
return_intermediate
:
inter_class_logits
.
append
(
class_logits
)
inter_pred_bboxes
.
append
(
pred_bboxes
)
bboxes
=
pred_bboxes
.
detach
()
output
=
{
'pred_logits'
:
inter_class_logits
[
-
1
],
'pred_boxes'
:
inter_pred_bboxes
[
-
1
]
}
if
self
.
return_intermediate
:
output
[
'aux_outputs'
]
=
[{
'pred_logits'
:
a
,
'pred_boxes'
:
b
}
for
a
,
b
in
zip
(
inter_class_logits
[:
-
1
],
inter_pred_bboxes
[:
-
1
])]
return
output
def
get_loss
(
self
,
outputs
,
targets
):
losses
=
self
.
lossfunc
(
outputs
,
targets
)
weight_dict
=
self
.
lossfunc
.
weight_dict
for
k
in
losses
.
keys
():
if
k
in
weight_dict
:
losses
[
k
]
*=
weight_dict
[
k
]
return
losses
def
box_cxcywh_to_xyxy
(
x
):
x_c
,
y_c
,
w
,
h
=
x
.
unbind
(
-
1
)
b
=
[(
x_c
-
0.5
*
w
),
(
y_c
-
0.5
*
h
),
(
x_c
+
0.5
*
w
),
(
y_c
+
0.5
*
h
)]
return
paddle
.
stack
(
b
,
axis
=-
1
)
\ No newline at end of file
ppdet/modeling/losses/__init__.py
浏览文件 @
841f2f4e
...
...
@@ -23,6 +23,7 @@ from . import keypoint_loss
from
.
import
jde_loss
from
.
import
fairmot_loss
from
.
import
detr_loss
from
.
import
sparsercnn_loss
from
.yolo_loss
import
*
from
.iou_aware_loss
import
*
...
...
@@ -35,3 +36,4 @@ from .keypoint_loss import *
from
.jde_loss
import
*
from
.fairmot_loss
import
*
from
.detr_loss
import
*
from
.sparsercnn_loss
import
*
ppdet/modeling/losses/sparsercnn_loss.py
0 → 100644
浏览文件 @
841f2f4e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
scipy.optimize
import
linear_sum_assignment
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.metric
import
accuracy
from
ppdet.core.workspace
import
register
from
ppdet.modeling.losses.iou_loss
import
GIoULoss
__all__
=
[
"SparseRCNNLoss"
]
@
register
class
SparseRCNNLoss
(
nn
.
Layer
):
""" This class computes the loss for SparseRCNN.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
losses
,
focal_loss_alpha
,
focal_loss_gamma
,
num_classes
=
80
,
class_weight
=
2.
,
l1_weight
=
5.
,
giou_weight
=
2.
):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
weight_dict: dict containing as key the names of the losses and as values their relative weight.
losses: list of all the losses to be applied. See get_loss for list of available losses.
matcher: module able to compute a matching between targets and proposals
"""
super
().
__init__
()
self
.
num_classes
=
num_classes
weight_dict
=
{
"loss_ce"
:
class_weight
,
"loss_bbox"
:
l1_weight
,
"loss_giou"
:
giou_weight
}
self
.
weight_dict
=
weight_dict
self
.
losses
=
losses
self
.
giou_loss
=
GIoULoss
(
reduction
=
"sum"
)
self
.
focal_loss_alpha
=
focal_loss_alpha
self
.
focal_loss_gamma
=
focal_loss_gamma
self
.
matcher
=
HungarianMatcher
(
focal_loss_alpha
,
focal_loss_gamma
,
class_weight
,
l1_weight
,
giou_weight
)
def
loss_labels
(
self
,
outputs
,
targets
,
indices
,
num_boxes
,
log
=
True
):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert
'pred_logits'
in
outputs
src_logits
=
outputs
[
'pred_logits'
]
idx
=
self
.
_get_src_permutation_idx
(
indices
)
target_classes_o
=
paddle
.
concat
([
paddle
.
gather
(
t
[
"labels"
],
J
,
axis
=
0
)
for
t
,
(
_
,
J
)
in
zip
(
targets
,
indices
)
])
target_classes
=
paddle
.
full
(
src_logits
.
shape
[:
2
],
self
.
num_classes
,
dtype
=
"int32"
)
for
i
,
ind
in
enumerate
(
zip
(
idx
[
0
],
idx
[
1
])):
target_classes
[
int
(
ind
[
0
]),
int
(
ind
[
1
])]
=
target_classes_o
[
i
]
target_classes
.
stop_gradient
=
True
src_logits
=
src_logits
.
flatten
(
start_axis
=
0
,
stop_axis
=
1
)
# prepare one_hot target.
target_classes
=
target_classes
.
flatten
(
start_axis
=
0
,
stop_axis
=
1
)
class_ids
=
paddle
.
arange
(
0
,
self
.
num_classes
)
labels
=
(
target_classes
.
unsqueeze
(
-
1
)
==
class_ids
).
astype
(
"float32"
)
labels
.
stop_gradient
=
True
# comp focal loss.
class_loss
=
sigmoid_focal_loss
(
src_logits
,
labels
,
alpha
=
self
.
focal_loss_alpha
,
gamma
=
self
.
focal_loss_gamma
,
reduction
=
"sum"
,
)
/
num_boxes
losses
=
{
'loss_ce'
:
class_loss
}
if
log
:
label_acc
=
target_classes_o
.
unsqueeze
(
-
1
)
src_idx
=
[
src
for
(
src
,
_
)
in
indices
]
pred_list
=
[]
for
i
in
range
(
outputs
[
"pred_logits"
].
shape
[
0
]):
pred_list
.
append
(
paddle
.
gather
(
outputs
[
"pred_logits"
][
i
],
src_idx
[
i
],
axis
=
0
))
pred
=
F
.
sigmoid
(
paddle
.
concat
(
pred_list
,
axis
=
0
))
acc
=
accuracy
(
pred
,
label_acc
.
astype
(
"int64"
))
losses
[
"acc"
]
=
acc
return
losses
def
loss_boxes
(
self
,
outputs
,
targets
,
indices
,
num_boxes
):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
assert
'pred_boxes'
in
outputs
# [batch_size, num_proposals, 4]
src_idx
=
[
src
for
(
src
,
_
)
in
indices
]
src_boxes_list
=
[]
for
i
in
range
(
outputs
[
"pred_boxes"
].
shape
[
0
]):
src_boxes_list
.
append
(
paddle
.
gather
(
outputs
[
"pred_boxes"
][
i
],
src_idx
[
i
],
axis
=
0
))
src_boxes
=
paddle
.
concat
(
src_boxes_list
,
axis
=
0
)
target_boxes
=
paddle
.
concat
(
[
paddle
.
gather
(
t
[
'boxes'
],
I
,
axis
=
0
)
for
t
,
(
_
,
I
)
in
zip
(
targets
,
indices
)
],
axis
=
0
)
target_boxes
.
stop_gradient
=
True
losses
=
{}
losses
[
'loss_giou'
]
=
self
.
giou_loss
(
src_boxes
,
target_boxes
)
/
num_boxes
image_size
=
paddle
.
concat
([
v
[
"img_whwh_tgt"
]
for
v
in
targets
])
src_boxes_
=
src_boxes
/
image_size
target_boxes_
=
target_boxes
/
image_size
loss_bbox
=
F
.
l1_loss
(
src_boxes_
,
target_boxes_
,
reduction
=
'sum'
)
losses
[
'loss_bbox'
]
=
loss_bbox
/
num_boxes
return
losses
def
_get_src_permutation_idx
(
self
,
indices
):
# permute predictions following indices
batch_idx
=
paddle
.
concat
(
[
paddle
.
full_like
(
src
,
i
)
for
i
,
(
src
,
_
)
in
enumerate
(
indices
)])
src_idx
=
paddle
.
concat
([
src
for
(
src
,
_
)
in
indices
])
return
batch_idx
,
src_idx
def
_get_tgt_permutation_idx
(
self
,
indices
):
# permute targets following indices
batch_idx
=
paddle
.
concat
(
[
paddle
.
full_like
(
tgt
,
i
)
for
i
,
(
_
,
tgt
)
in
enumerate
(
indices
)])
tgt_idx
=
paddle
.
concat
([
tgt
for
(
_
,
tgt
)
in
indices
])
return
batch_idx
,
tgt_idx
def
get_loss
(
self
,
loss
,
outputs
,
targets
,
indices
,
num_boxes
,
**
kwargs
):
loss_map
=
{
'labels'
:
self
.
loss_labels
,
'boxes'
:
self
.
loss_boxes
,
}
assert
loss
in
loss_map
,
f
'do you really want to compute
{
loss
}
loss?'
return
loss_map
[
loss
](
outputs
,
targets
,
indices
,
num_boxes
,
**
kwargs
)
def
forward
(
self
,
outputs
,
targets
):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux
=
{
k
:
v
for
k
,
v
in
outputs
.
items
()
if
k
!=
'aux_outputs'
}
# Retrieve the matching between the outputs of the last layer and the targets
indices
=
self
.
matcher
(
outputs_without_aux
,
targets
)
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes
=
sum
(
len
(
t
[
"labels"
])
for
t
in
targets
)
num_boxes
=
paddle
.
to_tensor
(
[
num_boxes
],
dtype
=
"float32"
,
place
=
next
(
iter
(
outputs
.
values
())).
place
)
# Compute all the requested losses
losses
=
{}
for
loss
in
self
.
losses
:
losses
.
update
(
self
.
get_loss
(
loss
,
outputs
,
targets
,
indices
,
num_boxes
))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if
'aux_outputs'
in
outputs
:
for
i
,
aux_outputs
in
enumerate
(
outputs
[
'aux_outputs'
]):
indices
=
self
.
matcher
(
aux_outputs
,
targets
)
for
loss
in
self
.
losses
:
kwargs
=
{}
if
loss
==
'labels'
:
# Logging is enabled only for the last layer
kwargs
=
{
'log'
:
False
}
l_dict
=
self
.
get_loss
(
loss
,
aux_outputs
,
targets
,
indices
,
num_boxes
,
**
kwargs
)
w_dict
=
{}
for
k
in
l_dict
.
keys
():
if
k
in
self
.
weight_dict
:
w_dict
[
k
+
f
'_
{
i
}
'
]
=
l_dict
[
k
]
*
self
.
weight_dict
[
k
]
else
:
w_dict
[
k
+
f
'_
{
i
}
'
]
=
l_dict
[
k
]
losses
.
update
(
w_dict
)
return
losses
class
HungarianMatcher
(
nn
.
Layer
):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
def
__init__
(
self
,
focal_loss_alpha
,
focal_loss_gamma
,
cost_class
:
float
=
1
,
cost_bbox
:
float
=
1
,
cost_giou
:
float
=
1
):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super
().
__init__
()
self
.
cost_class
=
cost_class
self
.
cost_bbox
=
cost_bbox
self
.
cost_giou
=
cost_giou
self
.
focal_loss_alpha
=
focal_loss_alpha
self
.
focal_loss_gamma
=
focal_loss_gamma
assert
cost_class
!=
0
or
cost_bbox
!=
0
or
cost_giou
!=
0
,
"all costs cant be 0"
@
paddle
.
no_grad
()
def
forward
(
self
,
outputs
,
targets
):
""" Performs the matching
Args:
outputs: This is a dict that contains at least these entries:
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
eg. outputs = {"pred_logits": pred_logits, "pred_boxes": pred_boxes}
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
objects in the target) containing the class labels
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
eg. targets = [{"labels":labels, "boxes": boxes}, ...,{"labels":labels, "boxes": boxes}]
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs
,
num_queries
=
outputs
[
"pred_logits"
].
shape
[:
2
]
# We flatten to compute the cost matrices in a batch
out_prob
=
F
.
sigmoid
(
outputs
[
"pred_logits"
].
flatten
(
start_axis
=
0
,
stop_axis
=
1
))
out_bbox
=
outputs
[
"pred_boxes"
].
flatten
(
start_axis
=
0
,
stop_axis
=
1
)
# Also concat the target labels and boxes
tgt_ids
=
paddle
.
concat
([
v
[
"labels"
]
for
v
in
targets
])
assert
(
tgt_ids
>
-
1
).
all
()
tgt_bbox
=
paddle
.
concat
([
v
[
"boxes"
]
for
v
in
targets
])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
# Compute the classification cost.
alpha
=
self
.
focal_loss_alpha
gamma
=
self
.
focal_loss_gamma
neg_cost_class
=
(
1
-
alpha
)
*
(
out_prob
**
gamma
)
*
(
-
(
1
-
out_prob
+
1e-8
).
log
())
pos_cost_class
=
alpha
*
((
1
-
out_prob
)
**
gamma
)
*
(
-
(
out_prob
+
1e-8
).
log
())
cost_class
=
paddle
.
gather
(
pos_cost_class
,
tgt_ids
,
axis
=
1
)
-
paddle
.
gather
(
neg_cost_class
,
tgt_ids
,
axis
=
1
)
# Compute the L1 cost between boxes
image_size_out
=
paddle
.
concat
(
[
v
[
"img_whwh"
].
unsqueeze
(
0
)
for
v
in
targets
])
image_size_out
=
image_size_out
.
unsqueeze
(
1
).
tile
(
[
1
,
num_queries
,
1
]).
flatten
(
start_axis
=
0
,
stop_axis
=
1
)
image_size_tgt
=
paddle
.
concat
([
v
[
"img_whwh_tgt"
]
for
v
in
targets
])
out_bbox_
=
out_bbox
/
image_size_out
tgt_bbox_
=
tgt_bbox
/
image_size_tgt
cost_bbox
=
F
.
l1_loss
(
out_bbox_
.
unsqueeze
(
-
2
),
tgt_bbox_
,
reduction
=
'none'
).
sum
(
-
1
)
# [batch_size * num_queries, num_tgts]
# Compute the giou cost betwen boxes
cost_giou
=
-
get_bboxes_giou
(
out_bbox
,
tgt_bbox
)
# Final cost matrix
C
=
self
.
cost_bbox
*
cost_bbox
+
self
.
cost_class
*
cost_class
+
self
.
cost_giou
*
cost_giou
C
=
C
.
reshape
([
bs
,
num_queries
,
-
1
])
sizes
=
[
len
(
v
[
"boxes"
])
for
v
in
targets
]
indices
=
[
linear_sum_assignment
(
c
[
i
].
numpy
())
for
i
,
c
in
enumerate
(
C
.
split
(
sizes
,
-
1
))
]
return
[(
paddle
.
to_tensor
(
i
,
dtype
=
"int32"
),
paddle
.
to_tensor
(
j
,
dtype
=
"int32"
))
for
i
,
j
in
indices
]
def
box_area
(
boxes
):
assert
(
boxes
[:,
2
:]
>=
boxes
[:,
:
2
]).
all
()
wh
=
boxes
[:,
2
:]
-
boxes
[:,
:
2
]
return
wh
[:,
0
]
*
wh
[:,
1
]
def
boxes_iou
(
boxes1
,
boxes2
):
'''
Compute iou
Args:
boxes1 (paddle.tensor) shape (N, 4)
boxes2 (paddle.tensor) shape (M, 4)
Return:
(paddle.tensor) shape (N, M)
'''
area1
=
box_area
(
boxes1
)
area2
=
box_area
(
boxes2
)
lt
=
paddle
.
maximum
(
boxes1
.
unsqueeze
(
-
2
)[:,
:,
:
2
],
boxes2
[:,
:
2
])
rb
=
paddle
.
minimum
(
boxes1
.
unsqueeze
(
-
2
)[:,
:,
2
:],
boxes2
[:,
2
:])
wh
=
(
rb
-
lt
).
astype
(
"float32"
).
clip
(
min
=
1e-9
)
inter
=
wh
[:,
:,
0
]
*
wh
[:,
:,
1
]
union
=
area1
.
unsqueeze
(
-
1
)
+
area2
-
inter
+
1e-9
iou
=
inter
/
union
return
iou
,
union
def
get_bboxes_giou
(
boxes1
,
boxes2
,
eps
=
1e-9
):
"""calculate the ious of boxes1 and boxes2
Args:
boxes1 (Tensor): shape [N, 4]
boxes2 (Tensor): shape [M, 4]
eps (float): epsilon to avoid divide by zero
Return:
ious (Tensor): ious of boxes1 and boxes2, with the shape [N, M]
"""
assert
(
boxes1
[:,
2
:]
>=
boxes1
[:,
:
2
]).
all
()
assert
(
boxes2
[:,
2
:]
>=
boxes2
[:,
:
2
]).
all
()
iou
,
union
=
boxes_iou
(
boxes1
,
boxes2
)
lt
=
paddle
.
minimum
(
boxes1
.
unsqueeze
(
-
2
)[:,
:,
:
2
],
boxes2
[:,
:
2
])
rb
=
paddle
.
maximum
(
boxes1
.
unsqueeze
(
-
2
)[:,
:,
2
:],
boxes2
[:,
2
:])
wh
=
(
rb
-
lt
).
astype
(
"float32"
).
clip
(
min
=
eps
)
enclose_area
=
wh
[:,
:,
0
]
*
wh
[:,
:,
1
]
giou
=
iou
-
(
enclose_area
-
union
)
/
enclose_area
return
giou
def
sigmoid_focal_loss
(
inputs
,
targets
,
alpha
,
gamma
,
reduction
=
"sum"
):
assert
reduction
in
[
"sum"
,
"mean"
],
f
'do not support this
{
reduction
}
reduction?'
p
=
F
.
sigmoid
(
inputs
)
ce_loss
=
F
.
binary_cross_entropy_with_logits
(
inputs
,
targets
,
reduction
=
"none"
)
p_t
=
p
*
targets
+
(
1
-
p
)
*
(
1
-
targets
)
loss
=
ce_loss
*
((
1
-
p_t
)
**
gamma
)
if
alpha
>=
0
:
alpha_t
=
alpha
*
targets
+
(
1
-
alpha
)
*
(
1
-
targets
)
loss
=
alpha_t
*
loss
if
reduction
==
"mean"
:
loss
=
loss
.
mean
()
elif
reduction
==
"sum"
:
loss
=
loss
.
sum
()
return
loss
ppdet/modeling/post_process.py
浏览文件 @
841f2f4e
...
...
@@ -28,7 +28,7 @@ except Exception:
__all__
=
[
'BBoxPostProcess'
,
'MaskPostProcess'
,
'FCOSPostProcess'
,
'S2ANetBBoxPostProcess'
,
'JDEBBoxPostProcess'
,
'CenterNetPostProcess'
,
'DETRBBoxPostProcess'
'DETRBBoxPostProcess'
,
'SparsePostProcess'
]
...
...
@@ -551,3 +551,90 @@ class DETRBBoxPostProcess(object):
bbox_pred
.
shape
[
1
],
dtype
=
'int32'
).
tile
([
bbox_pred
.
shape
[
0
]])
bbox_pred
=
bbox_pred
.
reshape
([
-
1
,
6
])
return
bbox_pred
,
bbox_num
@
register
class
SparsePostProcess
(
object
):
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
num_proposals
,
num_classes
=
80
):
super
(
SparsePostProcess
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
num_proposals
=
num_proposals
def
__call__
(
self
,
box_cls
,
box_pred
,
scale_factor_wh
,
img_whwh
):
"""
Arguments:
box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
The tensor predicts the classification probability for each proposal.
box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
The tensor predicts 4-vector (x,y,w,h) box
regression values for every proposal
scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img
img_whwh (Tensor): tensors of shape [batch_size, 4]
Returns:
bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
[label, confidence, xmin, ymin, xmax, ymax]
bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image.
"""
assert
len
(
box_cls
)
==
len
(
scale_factor_wh
)
==
len
(
img_whwh
)
img_wh
=
img_whwh
[:,
:
2
]
scores
=
F
.
sigmoid
(
box_cls
)
labels
=
paddle
.
arange
(
0
,
self
.
num_classes
).
\
unsqueeze
(
0
).
tile
([
self
.
num_proposals
,
1
]).
flatten
(
start_axis
=
0
,
stop_axis
=
1
)
classes_all
=
[]
scores_all
=
[]
boxes_all
=
[]
for
i
,
(
scores_per_image
,
box_pred_per_image
)
in
enumerate
(
zip
(
scores
,
box_pred
)):
scores_per_image
,
topk_indices
=
scores_per_image
.
flatten
(
0
,
1
).
topk
(
self
.
num_proposals
,
sorted
=
False
)
labels_per_image
=
paddle
.
gather
(
labels
,
topk_indices
,
axis
=
0
)
box_pred_per_image
=
box_pred_per_image
.
reshape
([
-
1
,
1
,
4
]).
tile
(
[
1
,
self
.
num_classes
,
1
]).
reshape
([
-
1
,
4
])
box_pred_per_image
=
paddle
.
gather
(
box_pred_per_image
,
topk_indices
,
axis
=
0
)
classes_all
.
append
(
labels_per_image
)
scores_all
.
append
(
scores_per_image
)
boxes_all
.
append
(
box_pred_per_image
)
bbox_num
=
paddle
.
zeros
([
len
(
scale_factor_wh
)],
dtype
=
"int32"
)
boxes_final
=
[]
for
i
in
range
(
len
(
scale_factor_wh
)):
classes
=
classes_all
[
i
]
boxes
=
boxes_all
[
i
]
scores
=
scores_all
[
i
]
boxes
[:,
0
::
2
]
=
paddle
.
clip
(
boxes
[:,
0
::
2
],
min
=
0
,
max
=
img_wh
[
i
][
0
])
/
scale_factor_wh
[
i
][
0
]
boxes
[:,
1
::
2
]
=
paddle
.
clip
(
boxes
[:,
1
::
2
],
min
=
0
,
max
=
img_wh
[
i
][
1
])
/
scale_factor_wh
[
i
][
1
]
boxes_w
,
boxes_h
=
(
boxes
[:,
2
]
-
boxes
[:,
0
]).
numpy
(),
(
boxes
[:,
3
]
-
boxes
[:,
1
]).
numpy
()
keep
=
(
boxes_w
>
1.
)
&
(
boxes_h
>
1.
)
if
(
keep
.
sum
()
==
0
):
bboxes
=
paddle
.
zeros
([
1
,
6
]).
astype
(
"float32"
)
else
:
boxes
=
paddle
.
to_tensor
(
boxes
.
numpy
()[
keep
]).
astype
(
"float32"
)
classes
=
paddle
.
to_tensor
(
classes
.
numpy
()[
keep
]).
astype
(
"float32"
).
unsqueeze
(
-
1
)
scores
=
paddle
.
to_tensor
(
scores
.
numpy
()[
keep
]).
astype
(
"float32"
).
unsqueeze
(
-
1
)
bboxes
=
paddle
.
concat
([
classes
,
scores
,
boxes
],
axis
=-
1
)
boxes_final
.
append
(
bboxes
)
bbox_num
[
i
]
=
bboxes
.
shape
[
0
]
bbox_pred
=
paddle
.
concat
(
boxes_final
)
return
bbox_pred
,
bbox_num
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录