Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
41d8be66
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41d8be66
编写于
11月 15, 2022
作者:
F
Feng Ni
提交者:
GitHub
11月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support YOLOF (#7336)
上级
fb2d5549
变更
23
显示空白变更内容
内联
并排
Showing
23 changed file
with
1032 addition
and
94 deletion
+1032
-94
README_cn.md
README_cn.md
+1
-0
README_en.md
README_en.md
+1
-0
configs/runtime.yml
configs/runtime.yml
+1
-0
configs/yolof/README.md
configs/yolof/README.md
+22
-0
configs/yolof/_base_/optimizer_1x.yml
configs/yolof/_base_/optimizer_1x.yml
+19
-0
configs/yolof/_base_/yolof_r50_c5.yml
configs/yolof/_base_/yolof_r50_c5.yml
+54
-0
configs/yolof/_base_/yolof_reader.yml
configs/yolof/_base_/yolof_reader.yml
+38
-0
configs/yolof/yolof_r50_c5_1x_coco.yml
configs/yolof/yolof_r50_c5_1x_coco.yml
+10
-0
deploy/python/infer.py
deploy/python/infer.py
+1
-1
docs/MODEL_ZOO_cn.md
docs/MODEL_ZOO_cn.md
+4
-0
docs/MODEL_ZOO_en.md
docs/MODEL_ZOO_en.md
+4
-0
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+61
-0
ppdet/engine/export_utils.py
ppdet/engine/export_utils.py
+2
-1
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+9
-0
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+2
-0
ppdet/modeling/architectures/yolof.py
ppdet/modeling/architectures/yolof.py
+88
-0
ppdet/modeling/assigners/__init__.py
ppdet/modeling/assigners/__init__.py
+2
-0
ppdet/modeling/assigners/uniform_assigner.py
ppdet/modeling/assigners/uniform_assigner.py
+93
-0
ppdet/modeling/bbox_utils.py
ppdet/modeling/bbox_utils.py
+98
-92
ppdet/modeling/heads/__init__.py
ppdet/modeling/heads/__init__.py
+2
-0
ppdet/modeling/heads/yolof_head.py
ppdet/modeling/heads/yolof_head.py
+368
-0
ppdet/modeling/necks/__init__.py
ppdet/modeling/necks/__init__.py
+2
-0
ppdet/modeling/necks/dilated_encoder.py
ppdet/modeling/necks/dilated_encoder.py
+150
-0
未找到文件。
README_cn.md
浏览文件 @
41d8be66
...
...
@@ -129,6 +129,7 @@
<li>
PP-YOLOE
</li>
<li>
PP-YOLOE+
</li>
<li>
YOLOX
</li>
<li>
YOLOF
</li>
<li>
SSD
</li>
<li>
CenterNet
</li>
<li>
FCOS
</li>
...
...
README_en.md
浏览文件 @
41d8be66
...
...
@@ -114,6 +114,7 @@
<li>
PP-YOLOE
</li>
<li>
PP-YOLOE+
</li>
<li>
YOLOX
</li>
<li>
YOLOF
</li>
<li>
SSD
</li>
<li>
CenterNet
</li>
<li>
FCOS
</li>
...
...
configs/runtime.yml
浏览文件 @
41d8be66
...
...
@@ -5,6 +5,7 @@ log_iter: 20
save_dir
:
output
snapshot_epoch
:
1
print_flops
:
false
print_params
:
false
# Exporting the model
export
:
...
...
configs/yolof/README.md
0 → 100644
浏览文件 @
41d8be66
# YOLOF (You Only Look One-level Feature)
## ModelZOO
| 网络网络 | 输入尺寸 | 图片数/GPU | Epochs | 模型推理耗时(ms) | mAP
<sup>
val
<br>
0.5:0.95 | Params(M) | FLOPs(G) | 下载链接 | 配置文件 |
| :--------------------- | :------- | :-------: | :----: | :----------: | :---------------------: | :----------------: |:---------: | :------: |:---------------: |
| YOLOF-R_50_C5 (paper) | 800x1333 | 4 | 12 | - | 37.7 | - | - | - | - |
| YOLOF-R_50_C5 | 800x1333 | 4 | 12 | - | 38.1 | 44.16 | 241.64 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolof_r50_c5_1x_coco.pdparams
)
|
[
配置文件
](
./yolof_r50_c5_1x_coco.yml
)
|
**注意:**
-
YOLOF模型训练过程中默认使用8 GPUs进行混合精度训练,总batch_size默认为32。
## Citations
```
@inproceedings{chen2021you,
title={You Only Look One-level Feature},
author={Chen, Qiang and Wang, Yingming and Yang, Tong and Zhang, Xiangyu and Cheng, Jian and Sun, Jian},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2021}
}
```
configs/yolof/_base_/optimizer_1x.yml
0 → 100644
浏览文件 @
41d8be66
epoch
:
12
LearningRate
:
base_lr
:
0.06
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
8
,
11
]
-
!LinearWarmup
start_factor
:
0.00066
steps
:
1500
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0001
type
:
L2
configs/yolof/_base_/yolof_r50_c5.yml
0 → 100644
浏览文件 @
41d8be66
architecture
:
YOLOF
find_unused_parameters
:
True
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
YOLOF
:
backbone
:
ResNet
neck
:
DilatedEncoder
head
:
YOLOFHead
ResNet
:
depth
:
50
variant
:
b
# resnet-va in paper
freeze_at
:
0
# res2
return_idx
:
[
3
]
# only res5 feature
lr_mult_list
:
[
0.3333
,
0.3333
,
0.3333
,
0.3333
]
DilatedEncoder
:
in_channels
:
[
2048
]
out_channels
:
[
512
]
block_mid_channels
:
128
num_residual_blocks
:
4
block_dilations
:
[
2
,
4
,
6
,
8
]
YOLOFHead
:
conv_feat
:
name
:
YOLOFFeat
feat_in
:
512
feat_out
:
512
num_cls_convs
:
2
num_reg_convs
:
4
norm_type
:
bn
anchor_generator
:
name
:
AnchorGenerator
anchor_sizes
:
[[
32
,
64
,
128
,
256
,
512
]]
aspect_ratios
:
[
1.0
]
strides
:
[
32
]
bbox_assigner
:
name
:
UniformAssigner
pos_ignore_thr
:
0.15
neg_ignore_thr
:
0.7
match_times
:
4
loss_class
:
name
:
FocalLoss
gamma
:
2.0
alpha
:
0.25
loss_bbox
:
name
:
GIoULoss
nms
:
name
:
MultiClassNMS
nms_top_k
:
1000
keep_top_k
:
100
score_threshold
:
0.05
nms_threshold
:
0.6
configs/yolof/_base_/yolof_reader.yml
0 → 100644
浏览文件 @
41d8be66
worker_num
:
4
TrainReader
:
sample_transforms
:
-
Decode
:
{}
-
RandomShift
:
{
prob
:
0.5
,
max_shift
:
32
}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
True
,
interp
:
1
}
-
NormalizeImage
:
{
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
],
is_scale
:
True
}
-
RandomFlip
:
{}
-
Permute
:
{}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
4
shuffle
:
True
drop_last
:
True
collate_batch
:
False
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
True
,
interp
:
1
}
-
NormalizeImage
:
{
is_scale
:
True
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
Permute
:
{}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
1
TestReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
True
,
interp
:
1
}
-
NormalizeImage
:
{
is_scale
:
True
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
Permute
:
{}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
1
fuse_normalize
:
True
configs/yolof/yolof_r50_c5_1x_coco.yml
0 → 100644
浏览文件 @
41d8be66
_BASE_
:
[
'
../datasets/coco_detection.yml'
,
'
../runtime.yml'
,
'
./_base_/optimizer_1x.yml'
,
'
./_base_/yolof_r50_c5.yml'
,
'
./_base_/yolof_reader.yml'
]
log_iter
:
50
snapshot_epoch
:
1
weights
:
output/yolof_r50_c5_1x_coco/model_final
deploy/python/infer.py
浏览文件 @
41d8be66
...
...
@@ -42,7 +42,7 @@ from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco
SUPPORT_MODELS
=
{
'YOLO'
,
'RCNN'
,
'SSD'
,
'Face'
,
'FCOS'
,
'SOLOv2'
,
'TTFNet'
,
'S2ANet'
,
'JDE'
,
'FairMOT'
,
'DeepSORT'
,
'GFL'
,
'PicoDet'
,
'CenterNet'
,
'TOOD'
,
'RetinaNet'
,
'StrongBaseline'
,
'STGCN'
,
'YOLOX'
,
'PPHGNet'
,
'PPLCNet'
,
'DETR'
'StrongBaseline'
,
'STGCN'
,
'YOLOX'
,
'
YOLOF'
,
'
PPHGNet'
,
'PPLCNet'
,
'DETR'
}
TUNED_TRT_DYNAMIC_MODELS
=
{
'DETR'
}
...
...
docs/MODEL_ZOO_cn.md
浏览文件 @
41d8be66
...
...
@@ -95,6 +95,10 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
请参考
[
YOLOX
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolox
)
### YOLOF
请参考
[
YOLOF
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolof
)
### YOLOv5
请参考
[
YOLOv5
](
https://github.com/nemonameless/PaddleDetection_YOLOSeries/tree/develop/configs/yolov5
)
...
...
docs/MODEL_ZOO_en.md
浏览文件 @
41d8be66
...
...
@@ -94,6 +94,10 @@ Please refer to[PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/d
Please refer to
[
YOLOX
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolox
)
### YOLOF
Please refer to
[
YOLOF
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolof
)
### YOLOv5
Please refer to
[
YOLOv5
](
https://github.com/nemonameless/PaddleDetection_YOLOSeries/tree/develop/configs/yolov5
)
...
...
ppdet/data/transform/operators.py
浏览文件 @
41d8be66
...
...
@@ -3396,3 +3396,64 @@ class PadResize(BaseOperator):
sample
[
'gt_bbox'
]
=
bboxes
sample
[
'gt_class'
]
=
labels
return
sample
@
register_op
class
RandomShift
(
BaseOperator
):
"""
Randomly shift image
Args:
prob (float): probability to do random shift.
max_shift (int): max shift pixels
filter_thr (int): filter gt bboxes if one side is smaller than this
"""
def
__init__
(
self
,
prob
=
0.5
,
max_shift
=
32
,
filter_thr
=
1
):
super
(
RandomShift
,
self
).
__init__
()
self
.
prob
=
prob
self
.
max_shift
=
max_shift
self
.
filter_thr
=
filter_thr
def
calc_shift_coor
(
self
,
im_h
,
im_w
,
shift_h
,
shift_w
):
return
[
max
(
0
,
shift_w
),
max
(
0
,
shift_h
),
min
(
im_w
,
im_w
+
shift_w
),
min
(
im_h
,
im_h
+
shift_h
)
]
def
apply
(
self
,
sample
,
context
=
None
):
if
random
.
random
()
>
self
.
prob
:
return
sample
im
=
sample
[
'image'
]
gt_bbox
=
sample
[
'gt_bbox'
]
gt_class
=
sample
[
'gt_class'
]
im_h
,
im_w
=
im
.
shape
[:
2
]
shift_h
=
random
.
randint
(
-
self
.
max_shift
,
self
.
max_shift
)
shift_w
=
random
.
randint
(
-
self
.
max_shift
,
self
.
max_shift
)
gt_bbox
[:,
0
::
2
]
+=
shift_w
gt_bbox
[:,
1
::
2
]
+=
shift_h
gt_bbox
[:,
0
::
2
]
=
np
.
clip
(
gt_bbox
[:,
0
::
2
],
0
,
im_w
)
gt_bbox
[:,
1
::
2
]
=
np
.
clip
(
gt_bbox
[:,
1
::
2
],
0
,
im_h
)
gt_bbox_h
=
gt_bbox
[:,
2
]
-
gt_bbox
[:,
0
]
gt_bbox_w
=
gt_bbox
[:,
3
]
-
gt_bbox
[:,
1
]
keep
=
(
gt_bbox_w
>
self
.
filter_thr
)
&
(
gt_bbox_h
>
self
.
filter_thr
)
if
not
keep
.
any
():
return
sample
gt_bbox
=
gt_bbox
[
keep
]
gt_class
=
gt_class
[
keep
]
# shift image
coor_new
=
self
.
calc_shift_coor
(
im_h
,
im_w
,
shift_h
,
shift_w
)
# shift frame to the opposite direction
coor_old
=
self
.
calc_shift_coor
(
im_h
,
im_w
,
-
shift_h
,
-
shift_w
)
canvas
=
np
.
zeros_like
(
im
)
canvas
[
coor_new
[
1
]:
coor_new
[
3
],
coor_new
[
0
]:
coor_new
[
2
]]
\
=
im
[
coor_old
[
1
]:
coor_old
[
3
],
coor_old
[
0
]:
coor_old
[
2
]]
sample
[
'image'
]
=
canvas
sample
[
'gt_bbox'
]
=
gt_bbox
sample
[
'gt_class'
]
=
gt_class
return
sample
ppdet/engine/export_utils.py
浏览文件 @
41d8be66
...
...
@@ -49,6 +49,7 @@ TRT_MIN_SUBGRAPH = {
'CenterNet'
:
5
,
'TOOD'
:
5
,
'YOLOX'
:
8
,
'YOLOF'
:
40
,
'METRO_Body'
:
3
,
'DETR'
:
3
,
}
...
...
@@ -156,7 +157,7 @@ def _dump_infer_config(config, path, image_shape, model):
arch_state
=
True
break
if
infer_arch
==
'YOLOX'
:
if
infer_arch
in
[
'YOLOX'
,
'YOLOF'
]
:
infer_cfg
[
'arch'
]
=
infer_arch
infer_cfg
[
'min_subgraph_size'
]
=
TRT_MIN_SUBGRAPH
[
infer_arch
]
arch_state
=
True
...
...
ppdet/engine/trainer.py
浏览文件 @
41d8be66
...
...
@@ -150,6 +150,15 @@ class Trainer(object):
self
.
_eval_batch_sampler
)
# TestDataset build after user set images, skip loader creation here
# get Params
print_params
=
self
.
cfg
.
get
(
'print_params'
,
False
)
if
print_params
:
params
=
sum
([
p
.
numel
()
for
n
,
p
in
self
.
model
.
named_parameters
()
if
all
([
x
not
in
n
for
x
in
[
'_mean'
,
'_variance'
]])
])
# exclude BatchNorm running status
logger
.
info
(
'Params: '
,
params
/
1e6
)
# build optimizer in train mode
if
self
.
mode
==
'train'
:
steps_per_epoch
=
len
(
self
.
loader
)
...
...
ppdet/modeling/architectures/__init__.py
浏览文件 @
41d8be66
...
...
@@ -36,6 +36,7 @@ from . import tood
from
.
import
retinanet
from
.
import
bytetrack
from
.
import
yolox
from
.
import
yolof
from
.
import
pose3d_metro
from
.meta_arch
import
*
...
...
@@ -63,4 +64,5 @@ from .tood import *
from
.retinanet
import
*
from
.bytetrack
import
*
from
.yolox
import
*
from
.yolof
import
*
from
.pose3d_metro
import
*
ppdet/modeling/architectures/yolof.py
0 → 100644
浏览文件 @
41d8be66
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
ppdet.core.workspace
import
register
,
create
from
.meta_arch
import
BaseArch
__all__
=
[
'YOLOF'
]
@
register
class
YOLOF
(
BaseArch
):
__category__
=
'architecture'
def
__init__
(
self
,
backbone
=
'ResNet'
,
neck
=
'DilatedEncoder'
,
head
=
'YOLOFHead'
,
for_mot
=
False
):
"""
YOLOF network, see https://arxiv.org/abs/2103.09460
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): DilatedEncoder instance
head (nn.Layer): YOLOFHead instance
for_mot (bool): whether return other features for multi-object tracking
models, default False in pure object detection models.
"""
super
(
YOLOF
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
neck
=
neck
self
.
head
=
head
self
.
for_mot
=
for_mot
@
classmethod
def
from_config
(
cls
,
cfg
,
*
args
,
**
kwargs
):
# backbone
backbone
=
create
(
cfg
[
'backbone'
])
# fpn
kwargs
=
{
'input_shape'
:
backbone
.
out_shape
}
neck
=
create
(
cfg
[
'neck'
],
**
kwargs
)
# head
kwargs
=
{
'input_shape'
:
neck
.
out_shape
}
head
=
create
(
cfg
[
'head'
],
**
kwargs
)
return
{
'backbone'
:
backbone
,
'neck'
:
neck
,
"head"
:
head
,
}
def
_forward
(
self
):
body_feats
=
self
.
backbone
(
self
.
inputs
)
neck_feats
=
self
.
neck
(
body_feats
,
self
.
for_mot
)
if
self
.
training
:
yolo_losses
=
self
.
head
(
neck_feats
,
self
.
inputs
)
return
yolo_losses
else
:
yolo_head_outs
=
self
.
head
(
neck_feats
)
bbox
,
bbox_num
=
self
.
head
.
post_process
(
yolo_head_outs
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
output
=
{
'bbox'
:
bbox
,
'bbox_num'
:
bbox_num
}
return
output
def
get_loss
(
self
):
return
self
.
_forward
()
def
get_pred
(
self
):
return
self
.
_forward
()
ppdet/modeling/assigners/__init__.py
浏览文件 @
41d8be66
...
...
@@ -20,6 +20,7 @@ from . import max_iou_assigner
from
.
import
fcosr_assigner
from
.
import
rotated_task_aligned_assigner
from
.
import
task_aligned_assigner_cr
from
.
import
uniform_assigner
from
.utils
import
*
from
.task_aligned_assigner
import
*
...
...
@@ -29,3 +30,4 @@ from .max_iou_assigner import *
from
.fcosr_assigner
import
*
from
.rotated_task_aligned_assigner
import
*
from
.task_aligned_assigner_cr
import
*
from
.uniform_assigner
import
*
ppdet/modeling/assigners/uniform_assigner.py
0 → 100644
浏览文件 @
41d8be66
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
from
ppdet.modeling.bbox_utils
import
batch_bbox_overlaps
from
ppdet.modeling.transformers
import
bbox_xyxy_to_cxcywh
__all__
=
[
'UniformAssigner'
]
def
batch_p_dist
(
x
,
y
,
p
=
2
):
"""
calculate pairwise p_dist, the first index of x and y are batch
return [x.shape[0], y.shape[0]]
"""
x
=
x
.
unsqueeze
(
1
)
diff
=
x
-
y
return
paddle
.
norm
(
diff
,
p
=
p
,
axis
=
list
(
range
(
2
,
diff
.
dim
())))
@
register
class
UniformAssigner
(
nn
.
Layer
):
def
__init__
(
self
,
pos_ignore_thr
,
neg_ignore_thr
,
match_times
=
4
):
super
(
UniformAssigner
,
self
).
__init__
()
self
.
pos_ignore_thr
=
pos_ignore_thr
self
.
neg_ignore_thr
=
neg_ignore_thr
self
.
match_times
=
match_times
def
forward
(
self
,
bbox_pred
,
anchor
,
gt_bboxes
,
gt_labels
=
None
):
num_bboxes
=
bbox_pred
.
shape
[
0
]
num_gts
=
gt_bboxes
.
shape
[
0
]
match_labels
=
paddle
.
full
([
num_bboxes
],
-
1
,
dtype
=
paddle
.
int32
)
pred_ious
=
batch_bbox_overlaps
(
bbox_pred
,
gt_bboxes
)
pred_max_iou
=
pred_ious
.
max
(
axis
=
1
)
neg_ignore
=
pred_max_iou
>
self
.
neg_ignore_thr
# exclude potential ignored neg samples first, deal with pos samples later
#match_labels: -2(ignore), -1(neg) or >=0(pos_inds)
match_labels
=
paddle
.
where
(
neg_ignore
,
paddle
.
full_like
(
match_labels
,
-
2
),
match_labels
)
bbox_pred_c
=
bbox_xyxy_to_cxcywh
(
bbox_pred
)
anchor_c
=
bbox_xyxy_to_cxcywh
(
anchor
)
gt_bboxes_c
=
bbox_xyxy_to_cxcywh
(
gt_bboxes
)
bbox_pred_dist
=
batch_p_dist
(
bbox_pred_c
,
gt_bboxes_c
,
p
=
1
)
anchor_dist
=
batch_p_dist
(
anchor_c
,
gt_bboxes_c
,
p
=
1
)
top_pred
=
bbox_pred_dist
.
topk
(
k
=
self
.
match_times
,
axis
=
0
,
largest
=
False
)[
1
]
top_anchor
=
anchor_dist
.
topk
(
k
=
self
.
match_times
,
axis
=
0
,
largest
=
False
)[
1
]
tar_pred
=
paddle
.
arange
(
num_gts
).
expand
([
self
.
match_times
,
num_gts
])
tar_anchor
=
paddle
.
arange
(
num_gts
).
expand
([
self
.
match_times
,
num_gts
])
pos_places
=
paddle
.
concat
([
top_pred
,
top_anchor
]).
reshape
([
-
1
])
pos_inds
=
paddle
.
concat
([
tar_pred
,
tar_anchor
]).
reshape
([
-
1
])
pos_anchor
=
anchor
[
pos_places
]
pos_tar_bbox
=
gt_bboxes
[
pos_inds
]
pos_ious
=
batch_bbox_overlaps
(
pos_anchor
,
pos_tar_bbox
,
is_aligned
=
True
)
pos_ignore
=
pos_ious
<
self
.
pos_ignore_thr
pos_inds
=
paddle
.
where
(
pos_ignore
,
paddle
.
full_like
(
pos_inds
,
-
2
),
pos_inds
)
match_labels
[
pos_places
]
=
pos_inds
match_labels
.
stop_gradient
=
True
pos_keep
=
~
pos_ignore
if
pos_keep
.
sum
()
>
0
:
pos_places_keep
=
pos_places
[
pos_keep
]
pos_bbox_pred
=
bbox_pred
[
pos_places_keep
].
reshape
([
-
1
,
4
])
pos_bbox_tar
=
pos_tar_bbox
[
pos_keep
].
reshape
([
-
1
,
4
]).
detach
()
else
:
pos_bbox_pred
=
None
pos_bbox_tar
=
None
return
match_labels
,
pos_bbox_pred
,
pos_bbox_tar
ppdet/modeling/bbox_utils.py
浏览文件 @
41d8be66
...
...
@@ -17,7 +17,9 @@ import paddle
import
numpy
as
np
def
bbox2delta
(
src_boxes
,
tgt_boxes
,
weights
):
def
bbox2delta
(
src_boxes
,
tgt_boxes
,
weights
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
"""Encode bboxes to deltas.
"""
src_w
=
src_boxes
[:,
2
]
-
src_boxes
[:,
0
]
src_h
=
src_boxes
[:,
3
]
-
src_boxes
[:,
1
]
src_ctr_x
=
src_boxes
[:,
0
]
+
0.5
*
src_w
...
...
@@ -38,7 +40,11 @@ def bbox2delta(src_boxes, tgt_boxes, weights):
return
deltas
def
delta2bbox
(
deltas
,
boxes
,
weights
):
def
delta2bbox
(
deltas
,
boxes
,
weights
=
[
1.0
,
1.0
,
1.0
,
1.0
],
max_shape
=
None
):
"""Decode deltas to boxes. Used in RCNNBox,CascadeHead,RCNNHead,RetinaHead.
Note: return tensor shape [n,1,4]
If you want to add a reshape, please add after the calling code instead of here.
"""
clip_scale
=
math
.
log
(
1000.0
/
16
)
widths
=
boxes
[:,
2
]
-
boxes
[:,
0
]
...
...
@@ -67,6 +73,96 @@ def delta2bbox(deltas, boxes, weights):
pred_boxes
.
append
(
pred_ctr_y
+
0.5
*
pred_h
)
pred_boxes
=
paddle
.
stack
(
pred_boxes
,
axis
=-
1
)
if
max_shape
is
not
None
:
pred_boxes
[...,
0
::
2
]
=
pred_boxes
[...,
0
::
2
].
clip
(
min
=
0
,
max
=
max_shape
[
1
])
pred_boxes
[...,
1
::
2
]
=
pred_boxes
[...,
1
::
2
].
clip
(
min
=
0
,
max
=
max_shape
[
0
])
return
pred_boxes
def
bbox2delta_v2
(
src_boxes
,
tgt_boxes
,
delta_mean
=
[
0.0
,
0.0
,
0.0
,
0.0
],
delta_std
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
"""Encode bboxes to deltas.
Modified from bbox2delta() which just use weight parameters to multiply deltas.
"""
src_w
=
src_boxes
[:,
2
]
-
src_boxes
[:,
0
]
src_h
=
src_boxes
[:,
3
]
-
src_boxes
[:,
1
]
src_ctr_x
=
src_boxes
[:,
0
]
+
0.5
*
src_w
src_ctr_y
=
src_boxes
[:,
1
]
+
0.5
*
src_h
tgt_w
=
tgt_boxes
[:,
2
]
-
tgt_boxes
[:,
0
]
tgt_h
=
tgt_boxes
[:,
3
]
-
tgt_boxes
[:,
1
]
tgt_ctr_x
=
tgt_boxes
[:,
0
]
+
0.5
*
tgt_w
tgt_ctr_y
=
tgt_boxes
[:,
1
]
+
0.5
*
tgt_h
dx
=
(
tgt_ctr_x
-
src_ctr_x
)
/
src_w
dy
=
(
tgt_ctr_y
-
src_ctr_y
)
/
src_h
dw
=
paddle
.
log
(
tgt_w
/
src_w
)
dh
=
paddle
.
log
(
tgt_h
/
src_h
)
deltas
=
paddle
.
stack
((
dx
,
dy
,
dw
,
dh
),
axis
=
1
)
deltas
=
(
deltas
-
paddle
.
to_tensor
(
delta_mean
))
/
paddle
.
to_tensor
(
delta_std
)
return
deltas
def
delta2bbox_v2
(
deltas
,
boxes
,
delta_mean
=
[
0.0
,
0.0
,
0.0
,
0.0
],
delta_std
=
[
1.0
,
1.0
,
1.0
,
1.0
],
max_shape
=
None
,
ctr_clip
=
32.0
):
"""Decode deltas to bboxes.
Modified from delta2bbox() which just use weight parameters to be divided by deltas.
Used in YOLOFHead.
Note: return tensor shape [n,1,4]
If you want to add a reshape, please add after the calling code instead of here.
"""
clip_scale
=
math
.
log
(
1000.0
/
16
)
widths
=
boxes
[:,
2
]
-
boxes
[:,
0
]
heights
=
boxes
[:,
3
]
-
boxes
[:,
1
]
ctr_x
=
boxes
[:,
0
]
+
0.5
*
widths
ctr_y
=
boxes
[:,
1
]
+
0.5
*
heights
deltas
=
deltas
*
paddle
.
to_tensor
(
delta_std
)
+
paddle
.
to_tensor
(
delta_mean
)
dx
=
deltas
[:,
0
::
4
]
dy
=
deltas
[:,
1
::
4
]
dw
=
deltas
[:,
2
::
4
]
dh
=
deltas
[:,
3
::
4
]
# Prevent sending too large values into paddle.exp()
dx
=
dx
*
widths
.
unsqueeze
(
1
)
dy
=
dy
*
heights
.
unsqueeze
(
1
)
if
ctr_clip
is
not
None
:
dx
=
paddle
.
clip
(
dx
,
max
=
ctr_clip
,
min
=-
ctr_clip
)
dy
=
paddle
.
clip
(
dy
,
max
=
ctr_clip
,
min
=-
ctr_clip
)
dw
=
paddle
.
clip
(
dw
,
max
=
clip_scale
)
dh
=
paddle
.
clip
(
dh
,
max
=
clip_scale
)
else
:
dw
=
dw
.
clip
(
min
=-
ctr_clip
,
max
=
ctr_clip
)
dh
=
dh
.
clip
(
min
=-
ctr_clip
,
max
=
ctr_clip
)
pred_ctr_x
=
dx
+
ctr_x
.
unsqueeze
(
1
)
pred_ctr_y
=
dy
+
ctr_y
.
unsqueeze
(
1
)
pred_w
=
paddle
.
exp
(
dw
)
*
widths
.
unsqueeze
(
1
)
pred_h
=
paddle
.
exp
(
dh
)
*
heights
.
unsqueeze
(
1
)
pred_boxes
=
[]
pred_boxes
.
append
(
pred_ctr_x
-
0.5
*
pred_w
)
pred_boxes
.
append
(
pred_ctr_y
-
0.5
*
pred_h
)
pred_boxes
.
append
(
pred_ctr_x
+
0.5
*
pred_w
)
pred_boxes
.
append
(
pred_ctr_y
+
0.5
*
pred_h
)
pred_boxes
=
paddle
.
stack
(
pred_boxes
,
axis
=-
1
)
if
max_shape
is
not
None
:
pred_boxes
[...,
0
::
2
]
=
pred_boxes
[...,
0
::
2
].
clip
(
min
=
0
,
max
=
max_shape
[
1
])
pred_boxes
[...,
1
::
2
]
=
pred_boxes
[...,
1
::
2
].
clip
(
min
=
0
,
max
=
max_shape
[
0
])
return
pred_boxes
...
...
@@ -489,96 +585,6 @@ def batch_distance2bbox(points, distance, max_shapes=None):
return
out_bbox
def
delta2bbox_v2
(
rois
,
deltas
,
means
=
(
0.0
,
0.0
,
0.0
,
0.0
),
stds
=
(
1.0
,
1.0
,
1.0
,
1.0
),
max_shape
=
None
,
wh_ratio_clip
=
16.0
/
1000.0
,
ctr_clip
=
None
):
"""Transform network output(delta) to bboxes.
Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
bbox/coder/delta_xywh_bbox_coder.py
Args:
rois (Tensor): shape [..., 4], base bboxes, typical examples include
anchor and rois
deltas (Tensor): shape [..., 4], offset relative to base bboxes
means (list[float]): the mean that was used to normalize deltas,
must be of size 4
stds (list[float]): the std that was used to normalize deltas,
must be of size 4
max_shape (list[float] or None): height and width of image, will be
used to clip bboxes if not None
wh_ratio_clip (float): to clip delta wh of decoded bboxes
ctr_clip (float or None): whether to clip delta xy of decoded bboxes
"""
if
rois
.
size
==
0
:
return
paddle
.
empty_like
(
rois
)
means
=
paddle
.
to_tensor
(
means
)
stds
=
paddle
.
to_tensor
(
stds
)
deltas
=
deltas
*
stds
+
means
dxy
=
deltas
[...,
:
2
]
dwh
=
deltas
[...,
2
:]
pxy
=
(
rois
[...,
:
2
]
+
rois
[...,
2
:])
*
0.5
pwh
=
rois
[...,
2
:]
-
rois
[...,
:
2
]
dxy_wh
=
pwh
*
dxy
max_ratio
=
np
.
abs
(
np
.
log
(
wh_ratio_clip
))
if
ctr_clip
is
not
None
:
dxy_wh
=
paddle
.
clip
(
dxy_wh
,
max
=
ctr_clip
,
min
=-
ctr_clip
)
dwh
=
paddle
.
clip
(
dwh
,
max
=
max_ratio
)
else
:
dwh
=
dwh
.
clip
(
min
=-
max_ratio
,
max
=
max_ratio
)
gxy
=
pxy
+
dxy_wh
gwh
=
pwh
*
dwh
.
exp
()
x1y1
=
gxy
-
(
gwh
*
0.5
)
x2y2
=
gxy
+
(
gwh
*
0.5
)
bboxes
=
paddle
.
concat
([
x1y1
,
x2y2
],
axis
=-
1
)
if
max_shape
is
not
None
:
bboxes
[...,
0
::
2
]
=
bboxes
[...,
0
::
2
].
clip
(
min
=
0
,
max
=
max_shape
[
1
])
bboxes
[...,
1
::
2
]
=
bboxes
[...,
1
::
2
].
clip
(
min
=
0
,
max
=
max_shape
[
0
])
return
bboxes
def
bbox2delta_v2
(
src_boxes
,
tgt_boxes
,
means
=
(
0.0
,
0.0
,
0.0
,
0.0
),
stds
=
(
1.0
,
1.0
,
1.0
,
1.0
)):
"""Encode bboxes to deltas.
Modified from ppdet.modeling.bbox_utils.bbox2delta.
Args:
src_boxes (Tensor[..., 4]): base bboxes
tgt_boxes (Tensor[..., 4]): target bboxes
means (list[float]): the mean that will be used to normalize delta
stds (list[float]): the std that will be used to normalize delta
"""
if
src_boxes
.
size
==
0
:
return
paddle
.
empty_like
(
src_boxes
)
src_w
=
src_boxes
[...,
2
]
-
src_boxes
[...,
0
]
src_h
=
src_boxes
[...,
3
]
-
src_boxes
[...,
1
]
src_ctr_x
=
src_boxes
[...,
0
]
+
0.5
*
src_w
src_ctr_y
=
src_boxes
[...,
1
]
+
0.5
*
src_h
tgt_w
=
tgt_boxes
[...,
2
]
-
tgt_boxes
[...,
0
]
tgt_h
=
tgt_boxes
[...,
3
]
-
tgt_boxes
[...,
1
]
tgt_ctr_x
=
tgt_boxes
[...,
0
]
+
0.5
*
tgt_w
tgt_ctr_y
=
tgt_boxes
[...,
1
]
+
0.5
*
tgt_h
dx
=
(
tgt_ctr_x
-
src_ctr_x
)
/
src_w
dy
=
(
tgt_ctr_y
-
src_ctr_y
)
/
src_h
dw
=
paddle
.
log
(
tgt_w
/
src_w
)
dh
=
paddle
.
log
(
tgt_h
/
src_h
)
deltas
=
paddle
.
stack
((
dx
,
dy
,
dw
,
dh
),
axis
=
1
)
# [n, 4]
means
=
paddle
.
to_tensor
(
means
,
place
=
src_boxes
.
place
)
stds
=
paddle
.
to_tensor
(
stds
,
place
=
src_boxes
.
place
)
deltas
=
(
deltas
-
means
)
/
stds
return
deltas
def
iou_similarity
(
box1
,
box2
,
eps
=
1e-10
):
"""Calculate iou of box1 and box2
...
...
ppdet/modeling/heads/__init__.py
浏览文件 @
41d8be66
...
...
@@ -36,6 +36,7 @@ from . import ppyoloe_head
from
.
import
fcosr_head
from
.
import
ppyoloe_r_head
from
.
import
ld_gfl_head
from
.
import
yolof_head
from
.bbox_head
import
*
from
.mask_head
import
*
...
...
@@ -61,3 +62,4 @@ from .ppyoloe_head import *
from
.fcosr_head
import
*
from
.ld_gfl_head
import
*
from
.ppyoloe_r_head
import
*
from
.yolof_head
import
*
ppdet/modeling/heads/yolof_head.py
0 → 100644
浏览文件 @
41d8be66
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
Normal
,
Constant
from
ppdet.modeling.layers
import
MultiClassNMS
from
ppdet.core.workspace
import
register
from
ppdet.modeling.bbox_utils
import
delta2bbox_v2
__all__
=
[
'YOLOFHead'
]
INF
=
1e8
def
reduce_mean
(
tensor
):
world_size
=
paddle
.
distributed
.
get_world_size
()
if
world_size
==
1
:
return
tensor
paddle
.
distributed
.
all_reduce
(
tensor
)
return
tensor
/
world_size
def
find_inside_anchor
(
feat_size
,
stride
,
num_anchors
,
im_shape
):
feat_h
,
feat_w
=
feat_size
[:
2
]
im_h
,
im_w
=
im_shape
[:
2
]
inside_h
=
min
(
int
(
np
.
ceil
(
im_h
/
stride
)),
feat_h
)
inside_w
=
min
(
int
(
np
.
ceil
(
im_w
/
stride
)),
feat_w
)
inside_mask
=
paddle
.
zeros
([
feat_h
,
feat_w
],
dtype
=
paddle
.
bool
)
inside_mask
[:
inside_h
,
:
inside_w
]
=
True
inside_mask
=
inside_mask
.
unsqueeze
(
-
1
).
expand
(
[
feat_h
,
feat_w
,
num_anchors
])
return
inside_mask
.
reshape
([
-
1
])
@
register
class
YOLOFFeat
(
nn
.
Layer
):
def
__init__
(
self
,
feat_in
=
256
,
feat_out
=
256
,
num_cls_convs
=
2
,
num_reg_convs
=
4
,
norm_type
=
'bn'
):
super
(
YOLOFFeat
,
self
).
__init__
()
assert
norm_type
==
'bn'
,
"YOLOFFeat only support BN now."
self
.
feat_in
=
feat_in
self
.
feat_out
=
feat_out
self
.
num_cls_convs
=
num_cls_convs
self
.
num_reg_convs
=
num_reg_convs
self
.
norm_type
=
norm_type
cls_subnet
,
reg_subnet
=
[],
[]
for
i
in
range
(
self
.
num_cls_convs
):
feat_in
=
self
.
feat_in
if
i
==
0
else
self
.
feat_out
cls_subnet
.
append
(
nn
.
Conv2D
(
feat_in
,
self
.
feat_out
,
3
,
stride
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0.0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
value
=
0.0
))))
cls_subnet
.
append
(
nn
.
BatchNorm2D
(
self
.
feat_out
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))))
cls_subnet
.
append
(
nn
.
ReLU
())
for
i
in
range
(
self
.
num_reg_convs
):
feat_in
=
self
.
feat_in
if
i
==
0
else
self
.
feat_out
reg_subnet
.
append
(
nn
.
Conv2D
(
feat_in
,
self
.
feat_out
,
3
,
stride
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0.0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
value
=
0.0
))))
reg_subnet
.
append
(
nn
.
BatchNorm2D
(
self
.
feat_out
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))))
reg_subnet
.
append
(
nn
.
ReLU
())
self
.
cls_subnet
=
nn
.
Sequential
(
*
cls_subnet
)
self
.
reg_subnet
=
nn
.
Sequential
(
*
reg_subnet
)
def
forward
(
self
,
fpn_feat
):
cls_feat
=
self
.
cls_subnet
(
fpn_feat
)
reg_feat
=
self
.
reg_subnet
(
fpn_feat
)
return
cls_feat
,
reg_feat
@
register
class
YOLOFHead
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'trt'
,
'exclude_nms'
]
__inject__
=
[
'conv_feat'
,
'anchor_generator'
,
'bbox_assigner'
,
'loss_class'
,
'loss_bbox'
,
'nms'
]
def
__init__
(
self
,
num_classes
=
80
,
conv_feat
=
'YOLOFFeat'
,
anchor_generator
=
'AnchorGenerator'
,
bbox_assigner
=
'UniformAssigner'
,
loss_class
=
'FocalLoss'
,
loss_bbox
=
'GIoULoss'
,
ctr_clip
=
32.0
,
delta_mean
=
[
0.0
,
0.0
,
0.0
,
0.0
],
delta_std
=
[
1.0
,
1.0
,
1.0
,
1.0
],
nms
=
'MultiClassNMS'
,
prior_prob
=
0.01
,
nms_pre
=
1000
,
use_inside_anchor
=
False
,
trt
=
False
,
exclude_nms
=
False
):
super
(
YOLOFHead
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
conv_feat
=
conv_feat
self
.
anchor_generator
=
anchor_generator
self
.
na
=
self
.
anchor_generator
.
num_anchors
self
.
bbox_assigner
=
bbox_assigner
self
.
loss_class
=
loss_class
self
.
loss_bbox
=
loss_bbox
self
.
ctr_clip
=
ctr_clip
self
.
delta_mean
=
delta_mean
self
.
delta_std
=
delta_std
self
.
nms
=
nms
self
.
nms_pre
=
nms_pre
self
.
use_inside_anchor
=
use_inside_anchor
if
isinstance
(
self
.
nms
,
MultiClassNMS
)
and
trt
:
self
.
nms
.
trt
=
trt
self
.
exclude_nms
=
exclude_nms
bias_init_value
=
-
math
.
log
((
1
-
prior_prob
)
/
prior_prob
)
self
.
cls_score
=
self
.
add_sublayer
(
'cls_score'
,
nn
.
Conv2D
(
in_channels
=
conv_feat
.
feat_out
,
out_channels
=
self
.
num_classes
*
self
.
na
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0.0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
value
=
bias_init_value
))))
self
.
bbox_pred
=
self
.
add_sublayer
(
'bbox_pred'
,
nn
.
Conv2D
(
in_channels
=
conv_feat
.
feat_out
,
out_channels
=
4
*
self
.
na
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0.0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
value
=
0
))))
self
.
object_pred
=
self
.
add_sublayer
(
'object_pred'
,
nn
.
Conv2D
(
in_channels
=
conv_feat
.
feat_out
,
out_channels
=
self
.
na
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0.0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
value
=
0
))))
def
forward
(
self
,
feats
,
targets
=
None
):
assert
len
(
feats
)
==
1
,
"YOLOF only has one level feature."
conv_cls_feat
,
conv_reg_feat
=
self
.
conv_feat
(
feats
[
0
])
cls_logits
=
self
.
cls_score
(
conv_cls_feat
)
objectness
=
self
.
object_pred
(
conv_reg_feat
)
bboxes_reg
=
self
.
bbox_pred
(
conv_reg_feat
)
N
,
C
,
H
,
W
=
paddle
.
shape
(
cls_logits
)[:]
cls_logits
=
cls_logits
.
reshape
((
N
,
self
.
na
,
self
.
num_classes
,
H
,
W
))
objectness
=
objectness
.
reshape
((
N
,
self
.
na
,
1
,
H
,
W
))
norm_cls_logits
=
cls_logits
+
objectness
-
paddle
.
log
(
1.0
+
paddle
.
clip
(
cls_logits
.
exp
(),
max
=
INF
)
+
paddle
.
clip
(
objectness
.
exp
(),
max
=
INF
))
norm_cls_logits
=
norm_cls_logits
.
reshape
((
N
,
C
,
H
,
W
))
anchors
=
self
.
anchor_generator
([
norm_cls_logits
])
if
self
.
training
:
yolof_losses
=
self
.
get_loss
(
[
anchors
[
0
],
norm_cls_logits
,
bboxes_reg
],
targets
)
return
yolof_losses
else
:
return
anchors
[
0
],
norm_cls_logits
,
bboxes_reg
def
get_loss
(
self
,
head_outs
,
targets
):
anchors
,
cls_logits
,
bbox_preds
=
head_outs
feat_size
=
cls_logits
.
shape
[
-
2
:]
cls_logits
=
cls_logits
.
transpose
([
0
,
2
,
3
,
1
])
cls_logits
=
cls_logits
.
reshape
([
0
,
-
1
,
self
.
num_classes
])
bbox_preds
=
bbox_preds
.
transpose
([
0
,
2
,
3
,
1
])
bbox_preds
=
bbox_preds
.
reshape
([
0
,
-
1
,
4
])
num_pos_list
=
[]
cls_pred_list
,
cls_tar_list
=
[],
[]
reg_pred_list
,
reg_tar_list
=
[],
[]
# find and gather preds and targets in each image
for
cls_logit
,
bbox_pred
,
gt_bbox
,
gt_class
,
im_shape
in
zip
(
cls_logits
,
bbox_preds
,
targets
[
'gt_bbox'
],
targets
[
'gt_class'
],
targets
[
'im_shape'
]):
if
self
.
use_inside_anchor
:
inside_mask
=
find_inside_anchor
(
feat_size
,
self
.
anchor_generator
.
strides
[
0
],
self
.
na
,
im_shape
.
tolist
())
cls_logit
=
cls_logit
[
inside_mask
]
bbox_pred
=
bbox_pred
[
inside_mask
]
anchors
=
anchors
[
inside_mask
]
bbox_pred
=
delta2bbox_v2
(
bbox_pred
,
anchors
,
self
.
delta_mean
,
self
.
delta_std
,
ctr_clip
=
self
.
ctr_clip
)
bbox_pred
=
bbox_pred
.
reshape
([
-
1
,
bbox_pred
.
shape
[
-
1
]])
# -2:ignore, -1:neg, >=0:pos
match_labels
,
pos_bbox_pred
,
pos_bbox_tar
=
self
.
bbox_assigner
(
bbox_pred
,
anchors
,
gt_bbox
)
pos_mask
=
(
match_labels
>=
0
)
neg_mask
=
(
match_labels
==
-
1
)
chosen_mask
=
paddle
.
logical_or
(
pos_mask
,
neg_mask
)
gt_class
=
gt_class
.
reshape
([
-
1
])
bg_class
=
paddle
.
to_tensor
(
[
self
.
num_classes
],
dtype
=
gt_class
.
dtype
)
# a trick to assign num_classes to negative targets
gt_class
=
paddle
.
concat
([
gt_class
,
bg_class
],
axis
=-
1
)
match_labels
=
paddle
.
where
(
neg_mask
,
paddle
.
full_like
(
match_labels
,
gt_class
.
size
-
1
),
match_labels
)
num_pos_list
.
append
(
max
(
1.0
,
pos_mask
.
sum
().
item
()))
cls_pred_list
.
append
(
cls_logit
[
chosen_mask
])
cls_tar_list
.
append
(
gt_class
[
match_labels
[
chosen_mask
]])
reg_pred_list
.
append
(
pos_bbox_pred
)
reg_tar_list
.
append
(
pos_bbox_tar
)
num_tot_pos
=
paddle
.
to_tensor
(
sum
(
num_pos_list
))
num_tot_pos
=
reduce_mean
(
num_tot_pos
).
item
()
num_tot_pos
=
max
(
1.0
,
num_tot_pos
)
cls_pred
=
paddle
.
concat
(
cls_pred_list
)
cls_tar
=
paddle
.
concat
(
cls_tar_list
)
cls_loss
=
self
.
loss_class
(
cls_pred
,
cls_tar
,
reduction
=
'sum'
)
/
num_tot_pos
reg_pred_list
=
[
_
for
_
in
reg_pred_list
if
_
is
not
None
]
reg_tar_list
=
[
_
for
_
in
reg_tar_list
if
_
is
not
None
]
if
len
(
reg_pred_list
)
==
0
:
reg_loss
=
bbox_preds
.
sum
()
*
0.0
else
:
reg_pred
=
paddle
.
concat
(
reg_pred_list
)
reg_tar
=
paddle
.
concat
(
reg_tar_list
)
reg_loss
=
self
.
loss_bbox
(
reg_pred
,
reg_tar
).
sum
()
/
num_tot_pos
yolof_losses
=
{
'loss'
:
cls_loss
+
reg_loss
,
'loss_cls'
:
cls_loss
,
'loss_reg'
:
reg_loss
,
}
return
yolof_losses
def
get_bboxes_single
(
self
,
anchors
,
cls_scores
,
bbox_preds
,
im_shape
,
scale_factor
,
rescale
=
True
):
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
mlvl_bboxes
=
[]
mlvl_scores
=
[]
for
anchor
,
cls_score
,
bbox_pred
in
zip
(
anchors
,
cls_scores
,
bbox_preds
):
cls_score
=
cls_score
.
reshape
([
-
1
,
self
.
num_classes
])
bbox_pred
=
bbox_pred
.
reshape
([
-
1
,
4
])
if
self
.
nms_pre
is
not
None
and
cls_score
.
shape
[
0
]
>
self
.
nms_pre
:
max_score
=
cls_score
.
max
(
axis
=
1
)
_
,
topk_inds
=
max_score
.
topk
(
self
.
nms_pre
)
bbox_pred
=
bbox_pred
.
gather
(
topk_inds
)
anchor
=
anchor
.
gather
(
topk_inds
)
cls_score
=
cls_score
.
gather
(
topk_inds
)
bbox_pred
=
delta2bbox_v2
(
bbox_pred
,
anchor
,
self
.
delta_mean
,
self
.
delta_std
,
max_shape
=
im_shape
,
ctr_clip
=
self
.
ctr_clip
).
squeeze
()
mlvl_bboxes
.
append
(
bbox_pred
)
mlvl_scores
.
append
(
F
.
sigmoid
(
cls_score
))
mlvl_bboxes
=
paddle
.
concat
(
mlvl_bboxes
)
mlvl_bboxes
=
paddle
.
squeeze
(
mlvl_bboxes
)
if
rescale
:
mlvl_bboxes
=
mlvl_bboxes
/
paddle
.
concat
(
[
scale_factor
[::
-
1
],
scale_factor
[::
-
1
]])
mlvl_scores
=
paddle
.
concat
(
mlvl_scores
)
mlvl_scores
=
mlvl_scores
.
transpose
([
1
,
0
])
return
mlvl_bboxes
,
mlvl_scores
def
decode
(
self
,
anchors
,
cls_scores
,
bbox_preds
,
im_shape
,
scale_factor
):
batch_bboxes
=
[]
batch_scores
=
[]
for
img_id
in
range
(
cls_scores
[
0
].
shape
[
0
]):
num_lvls
=
len
(
cls_scores
)
cls_score_list
=
[
cls_scores
[
i
][
img_id
]
for
i
in
range
(
num_lvls
)]
bbox_pred_list
=
[
bbox_preds
[
i
][
img_id
]
for
i
in
range
(
num_lvls
)]
bboxes
,
scores
=
self
.
get_bboxes_single
(
anchors
,
cls_score_list
,
bbox_pred_list
,
im_shape
[
img_id
],
scale_factor
[
img_id
])
batch_bboxes
.
append
(
bboxes
)
batch_scores
.
append
(
scores
)
batch_bboxes
=
paddle
.
stack
(
batch_bboxes
,
0
)
batch_scores
=
paddle
.
stack
(
batch_scores
,
0
)
return
batch_bboxes
,
batch_scores
def
post_process
(
self
,
head_outs
,
im_shape
,
scale_factor
):
anchors
,
cls_scores
,
bbox_preds
=
head_outs
cls_scores
=
cls_scores
.
transpose
([
0
,
2
,
3
,
1
])
bbox_preds
=
bbox_preds
.
transpose
([
0
,
2
,
3
,
1
])
pred_bboxes
,
pred_scores
=
self
.
decode
(
[
anchors
],
[
cls_scores
],
[
bbox_preds
],
im_shape
,
scale_factor
)
if
self
.
exclude_nms
:
# `exclude_nms=True` just use in benchmark
return
pred_bboxes
.
sum
(),
pred_scores
.
sum
()
else
:
bbox_pred
,
bbox_num
,
_
=
self
.
nms
(
pred_bboxes
,
pred_scores
)
return
bbox_pred
,
bbox_num
ppdet/modeling/necks/__init__.py
浏览文件 @
41d8be66
...
...
@@ -22,6 +22,7 @@ from . import csp_pan
from
.
import
es_pan
from
.
import
lc_pan
from
.
import
custom_pan
from
.
import
dilated_encoder
from
.fpn
import
*
from
.yolo_fpn
import
*
...
...
@@ -34,3 +35,4 @@ from .csp_pan import *
from
.es_pan
import
*
from
.lc_pan
import
*
from
.custom_pan
import
*
from
.dilated_encoder
import
*
ppdet/modeling/necks/dilated_encoder.py
0 → 100644
浏览文件 @
41d8be66
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingUniform
,
Constant
,
Normal
from
ppdet.core.workspace
import
register
,
serializable
from
..shape_spec
import
ShapeSpec
__all__
=
[
'DilatedEncoder'
]
class
Bottleneck
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
mid_channels
,
dilation
):
super
(
Bottleneck
,
self
).
__init__
()
self
.
conv1
=
nn
.
Sequential
(
*
[
nn
.
Conv2D
(
in_channels
,
mid_channels
,
1
,
padding
=
0
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
))),
nn
.
BatchNorm2D
(
mid_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))),
nn
.
ReLU
(),
])
self
.
conv2
=
nn
.
Sequential
(
*
[
nn
.
Conv2D
(
mid_channels
,
mid_channels
,
3
,
padding
=
dilation
,
dilation
=
dilation
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
))),
nn
.
BatchNorm2D
(
mid_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))),
nn
.
ReLU
(),
])
self
.
conv3
=
nn
.
Sequential
(
*
[
nn
.
Conv2D
(
mid_channels
,
in_channels
,
1
,
padding
=
0
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
mean
=
0
,
std
=
0.01
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
))),
nn
.
BatchNorm2D
(
in_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))),
nn
.
ReLU
(),
])
def
forward
(
self
,
x
):
identity
=
x
y
=
self
.
conv3
(
self
.
conv2
(
self
.
conv1
(
x
)))
return
y
+
identity
@
register
class
DilatedEncoder
(
nn
.
Layer
):
"""
DilatedEncoder used in YOLOF
"""
def
__init__
(
self
,
in_channels
=
[
2048
],
out_channels
=
[
512
],
block_mid_channels
=
128
,
num_residual_blocks
=
4
,
block_dilations
=
[
2
,
4
,
6
,
8
]):
super
(
DilatedEncoder
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
assert
len
(
self
.
in_channels
)
==
1
,
"YOLOF only has one level feature."
assert
len
(
self
.
out_channels
)
==
1
,
"YOLOF only has one level feature."
self
.
block_mid_channels
=
block_mid_channels
self
.
num_residual_blocks
=
num_residual_blocks
self
.
block_dilations
=
block_dilations
out_ch
=
self
.
out_channels
[
0
]
self
.
lateral_conv
=
nn
.
Conv2D
(
self
.
in_channels
[
0
],
out_ch
,
1
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingUniform
(
negative_slope
=
1
,
nonlinearity
=
'leaky_relu'
)),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
value
=
0.0
)))
self
.
lateral_norm
=
nn
.
BatchNorm2D
(
out_ch
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
self
.
fpn_conv
=
nn
.
Conv2D
(
out_ch
,
out_ch
,
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingUniform
(
negative_slope
=
1
,
nonlinearity
=
'leaky_relu'
)))
self
.
fpn_norm
=
nn
.
BatchNorm2D
(
out_ch
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
encoder_blocks
=
[]
for
i
in
range
(
self
.
num_residual_blocks
):
encoder_blocks
.
append
(
Bottleneck
(
out_ch
,
self
.
block_mid_channels
,
dilation
=
block_dilations
[
i
]))
self
.
dilated_encoder_blocks
=
nn
.
Sequential
(
*
encoder_blocks
)
def
forward
(
self
,
inputs
,
for_mot
=
False
):
out
=
self
.
lateral_norm
(
self
.
lateral_conv
(
inputs
[
0
]))
out
=
self
.
fpn_norm
(
self
.
fpn_conv
(
out
))
out
=
self
.
dilated_encoder_blocks
(
out
)
return
[
out
]
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
return
{
'in_channels'
:
[
i
.
channels
for
i
in
input_shape
],
}
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
c
)
for
c
in
self
.
out_channels
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录