Commit 41d8be66: support YOLOF (#7336)

Authored by Feng Ni on Nov 15, 2022; committed via GitHub on Nov 15, 2022. Parent commit: fb2d5549.

Showing 23 changed files with 1032 additions and 94 deletions (+1032, -94).
Changed files:

| File | Changes |
| :--- | :--- |
| README_cn.md | +1 -0 |
| README_en.md | +1 -0 |
| configs/runtime.yml | +1 -0 |
| configs/yolof/README.md | +22 -0 |
| configs/yolof/_base_/optimizer_1x.yml | +19 -0 |
| configs/yolof/_base_/yolof_r50_c5.yml | +54 -0 |
| configs/yolof/_base_/yolof_reader.yml | +38 -0 |
| configs/yolof/yolof_r50_c5_1x_coco.yml | +10 -0 |
| deploy/python/infer.py | +1 -1 |
| docs/MODEL_ZOO_cn.md | +4 -0 |
| docs/MODEL_ZOO_en.md | +4 -0 |
| ppdet/data/transform/operators.py | +61 -0 |
| ppdet/engine/export_utils.py | +2 -1 |
| ppdet/engine/trainer.py | +9 -0 |
| ppdet/modeling/architectures/__init__.py | +2 -0 |
| ppdet/modeling/architectures/yolof.py | +88 -0 |
| ppdet/modeling/assigners/__init__.py | +2 -0 |
| ppdet/modeling/assigners/uniform_assigner.py | +93 -0 |
| ppdet/modeling/bbox_utils.py | +98 -92 |
| ppdet/modeling/heads/__init__.py | +2 -0 |
| ppdet/modeling/heads/yolof_head.py | +368 -0 |
| ppdet/modeling/necks/__init__.py | +2 -0 |
| ppdet/modeling/necks/dilated_encoder.py | +150 -0 |
**README_cn.md** (+1 -0)

```diff
@@ -129,6 +129,7 @@
         <li>PP-YOLOE</li>
         <li>PP-YOLOE+</li>
         <li>YOLOX</li>
+        <li>YOLOF</li>
         <li>SSD</li>
         <li>CenterNet</li>
         <li>FCOS</li>
```
**README_en.md** (+1 -0)

```diff
@@ -114,6 +114,7 @@
         <li>PP-YOLOE</li>
         <li>PP-YOLOE+</li>
         <li>YOLOX</li>
+        <li>YOLOF</li>
         <li>SSD</li>
         <li>CenterNet</li>
         <li>FCOS</li>
```
**configs/runtime.yml** (+1 -0)

```diff
@@ -5,6 +5,7 @@ log_iter: 20
 save_dir: output
 snapshot_epoch: 1
 print_flops: false
+print_params: false

 # Exporting the model
 export:
```
**configs/yolof/README.md** (new file, +22 -0)

# YOLOF (You Only Look One-level Feature)

## ModelZOO

| Network | Input size | Images/GPU | Epochs | Inference time (ms) | mAP<sup>val<br>0.5:0.95</sup> | Params(M) | FLOPs(G) | Download | Config |
| :--------------------- | :------- | :-------: | :----: | :----------: | :---------------------: | :----------------: | :---------: | :------: | :---------------: |
| YOLOF-R_50_C5 (paper) | 800x1333 | 4 | 12 | - | 37.7 | - | - | - | - |
| YOLOF-R_50_C5 | 800x1333 | 4 | 12 | - | 38.1 | 44.16 | 241.64 | [download](https://paddledet.bj.bcebos.com/models/yolof_r50_c5_1x_coco.pdparams) | [config](./yolof_r50_c5_1x_coco.yml) |

**Note:**

- By default, YOLOF is trained with mixed precision on 8 GPUs, with a total batch_size of 32.

## Citations

```
@inproceedings{chen2021you,
  title={You Only Look One-level Feature},
  author={Chen, Qiang and Wang, Yingming and Yang, Tong and Zhang, Xiangyu and Cheng, Jian and Sun, Jian},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  year={2021}
}
```
**configs/yolof/_base_/optimizer_1x.yml** (new file, +19 -0)

```yaml
epoch: 12

LearningRate:
  base_lr: 0.06
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [8, 11]
  - !LinearWarmup
    start_factor: 0.00066
    steps: 1500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2
```
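To make the schedule concrete, here is a minimal pure-Python sketch of the learning-rate curve the YAML above describes: a linear warmup from `start_factor * base_lr` over 1500 iterations, then a piecewise decay by `gamma` at epoch milestones 8 and 11. The exact interpolation used by ppdet's `LinearWarmup` is an assumption here; this is for intuition only.

```python
# Hedged sketch of the optimizer_1x.yml schedule (not ppdet's scheduler classes).
def yolof_lr(iter_id, epoch_id, base_lr=0.06, start_factor=0.00066,
             warmup_steps=1500, milestones=(8, 11), gamma=0.1):
    # piecewise decay driven by the current epoch
    lr = base_lr
    for m in milestones:
        if epoch_id >= m:
            lr *= gamma
    # linear warmup from start_factor to 1.0 over the first warmup_steps iterations
    if iter_id < warmup_steps:
        alpha = iter_id / float(warmup_steps)
        lr *= start_factor * (1 - alpha) + alpha
    return lr

print(yolof_lr(0, 0))       # 3.96e-05 at the very first iteration
print(yolof_lr(1500, 0))    # 0.06 once warmup has finished
print(yolof_lr(10000, 8))   # 0.006 after the first milestone
```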
**configs/yolof/_base_/yolof_r50_c5.yml** (new file, +54 -0)

```yaml
architecture: YOLOF
find_unused_parameters: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams

YOLOF:
  backbone: ResNet
  neck: DilatedEncoder
  head: YOLOFHead

ResNet:
  depth: 50
  variant: b  # resnet-va in paper
  freeze_at: 0  # res2
  return_idx: [3]  # only res5 feature
  lr_mult_list: [0.3333, 0.3333, 0.3333, 0.3333]

DilatedEncoder:
  in_channels: [2048]
  out_channels: [512]
  block_mid_channels: 128
  num_residual_blocks: 4
  block_dilations: [2, 4, 6, 8]

YOLOFHead:
  conv_feat:
    name: YOLOFFeat
    feat_in: 512
    feat_out: 512
    num_cls_convs: 2
    num_reg_convs: 4
    norm_type: bn
  anchor_generator:
    name: AnchorGenerator
    anchor_sizes: [[32, 64, 128, 256, 512]]
    aspect_ratios: [1.0]
    strides: [32]
  bbox_assigner:
    name: UniformAssigner
    pos_ignore_thr: 0.15
    neg_ignore_thr: 0.7
    match_times: 4
  loss_class:
    name: FocalLoss
    gamma: 2.0
    alpha: 0.25
  loss_bbox:
    name: GIoULoss
  nms:
    name: MultiClassNMS
    nms_top_k: 1000
    keep_top_k: 100
    score_threshold: 0.05
    nms_threshold: 0.6
```
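The anchor settings above put five square anchors (ratio 1.0, sizes 32 to 512) at every cell of the single stride-32 level, which is why `YOLOFHead` ends up with `num_anchors = 5` and predicts `5 * num_classes` class logits per cell. A rough numpy sketch of what that means (not ppdet's `AnchorGenerator`; the exact center-offset convention is an assumption):

```python
import numpy as np

def square_anchors_at(cx, cy, sizes=(32, 64, 128, 256, 512)):
    """Return a [len(sizes), 4] array of xyxy square anchors centered at (cx, cy)."""
    boxes = []
    for s in sizes:
        half = s / 2.0
        boxes.append([cx - half, cy - half, cx + half, cy + half])
    return np.array(boxes)

# five anchors for one feature cell whose center falls near image coords (16, 16)
print(square_anchors_at(16, 16))
```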
**configs/yolof/_base_/yolof_reader.yml** (new file, +38 -0)

```yaml
worker_num: 4

TrainReader:
  sample_transforms:
  - Decode: {}
  - RandomShift: {prob: 0.5, max_shift: 32}
  - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
  - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
  - RandomFlip: {}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 4
  shuffle: True
  drop_last: True
  collate_batch: False

EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
  - NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1

TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
  - NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1
  fuse_normalize: True
```
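For readers unfamiliar with the `NormalizeImage` settings above: `is_scale: True` first scales pixels to [0, 1] and then standardizes each channel with the ImageNet mean/std. A small numpy illustration of that arithmetic (this mirrors the transform's intent; it is not the ppdet operator itself):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

img = np.full((800, 1333, 3), 128, dtype=np.uint8)   # dummy HWC image
x = img.astype(np.float32) / 255.0                   # is_scale: True
x = (x - mean) / std                                 # per-channel standardization
print(x[0, 0])   # roughly [0.07, 0.21, 0.43] for a mid-gray pixel
```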
**configs/yolof/yolof_r50_c5_1x_coco.yml** (new file, +10 -0)

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_1x.yml',
  './_base_/yolof_r50_c5.yml',
  './_base_/yolof_reader.yml',
]

log_iter: 50
snapshot_epoch: 1
weights: output/yolof_r50_c5_1x_coco/model_final
```
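As a quick sanity check, the merged config above can be loaded and turned into a model with PaddleDetection's workspace helpers. This is a hedged sketch based on the repo layout at the time of the commit, not part of the commit itself; the relative config path is an assumption about the working directory.

```python
# Sketch: build the YOLOF architecture from the config via ppdet's workspace.
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/yolof/yolof_r50_c5_1x_coco.yml')
model = create(cfg.architecture)   # instantiates YOLOF -> ResNet + DilatedEncoder + YOLOFHead
print(type(model).__name__)        # expected: 'YOLOF'
```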
**deploy/python/infer.py** (+1 -1)

```diff
@@ -42,7 +42,7 @@ from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco
 SUPPORT_MODELS = {
     'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
     'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
-    'StrongBaseline', 'STGCN', 'YOLOX', 'PPHGNet', 'PPLCNet', 'DETR'
+    'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet', 'PPLCNet', 'DETR'
 }

 TUNED_TRT_DYNAMIC_MODELS = {'DETR'}
```
**docs/MODEL_ZOO_cn.md** (+4 -0)

```diff
@@ -95,6 +95,10 @@ Paddle provides ImageNet-pretrained backbone models. All the pretrained models
 Please refer to [YOLOX](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolox)

+### YOLOF
+
+Please refer to [YOLOF](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolof)
+
 ### YOLOv5
 Please refer to [YOLOv5](https://github.com/nemonameless/PaddleDetection_YOLOSeries/tree/develop/configs/yolov5)
```
**docs/MODEL_ZOO_en.md** (+4 -0)

```diff
@@ -94,6 +94,10 @@ Please refer to[PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/d
 Please refer to [YOLOX](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolox)

+### YOLOF
+
+Please refer to [YOLOF](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolof)
+
 ### YOLOv5
 Please refer to [YOLOv5](https://github.com/nemonameless/PaddleDetection_YOLOSeries/tree/develop/configs/yolov5)
```
**ppdet/data/transform/operators.py** (+61 -0, new `RandomShift` operator appended after `PadResize`)

```python
# @@ -3396,3 +3396,64 @@ class PadResize(BaseOperator):
        sample['gt_bbox'] = bboxes
        sample['gt_class'] = labels
        return sample


@register_op
class RandomShift(BaseOperator):
    """
    Randomly shift image

    Args:
        prob (float): probability to do random shift.
        max_shift (int): max shift pixels
        filter_thr (int): filter gt bboxes if one side is smaller than this
    """

    def __init__(self, prob=0.5, max_shift=32, filter_thr=1):
        super(RandomShift, self).__init__()
        self.prob = prob
        self.max_shift = max_shift
        self.filter_thr = filter_thr

    def calc_shift_coor(self, im_h, im_w, shift_h, shift_w):
        return [
            max(0, shift_w), max(0, shift_h), min(im_w, im_w + shift_w),
            min(im_h, im_h + shift_h)
        ]

    def apply(self, sample, context=None):
        if random.random() > self.prob:
            return sample

        im = sample['image']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
        im_h, im_w = im.shape[:2]
        shift_h = random.randint(-self.max_shift, self.max_shift)
        shift_w = random.randint(-self.max_shift, self.max_shift)

        gt_bbox[:, 0::2] += shift_w
        gt_bbox[:, 1::2] += shift_h
        gt_bbox[:, 0::2] = np.clip(gt_bbox[:, 0::2], 0, im_w)
        gt_bbox[:, 1::2] = np.clip(gt_bbox[:, 1::2], 0, im_h)
        gt_bbox_h = gt_bbox[:, 2] - gt_bbox[:, 0]
        gt_bbox_w = gt_bbox[:, 3] - gt_bbox[:, 1]
        keep = (gt_bbox_w > self.filter_thr) & (gt_bbox_h > self.filter_thr)
        if not keep.any():
            return sample

        gt_bbox = gt_bbox[keep]
        gt_class = gt_class[keep]

        # shift image
        coor_new = self.calc_shift_coor(im_h, im_w, shift_h, shift_w)
        # shift frame to the opposite direction
        coor_old = self.calc_shift_coor(im_h, im_w, -shift_h, -shift_w)
        canvas = np.zeros_like(im)
        canvas[coor_new[1]:coor_new[3], coor_new[0]:coor_new[2]] \
            = im[coor_old[1]:coor_old[3], coor_old[0]:coor_old[2]]

        sample['image'] = canvas
        sample['gt_bbox'] = gt_bbox
        sample['gt_class'] = gt_class
        return sample
```
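The paste logic above is easier to see on a toy array: the operator shifts the image by `(shift_w, shift_h)` by copying a window of the original into an all-zeros canvas, and shifts the GT boxes by the same offset. A small numpy walk-through with fixed (non-random) values for illustration:

```python
import numpy as np

def shift_coor(im_h, im_w, shift_h, shift_w):
    return [max(0, shift_w), max(0, shift_h),
            min(im_w, im_w + shift_w), min(im_h, im_h + shift_h)]

im = np.arange(16).reshape(4, 4)
shift_h, shift_w = 1, 2                      # shift down by 1, right by 2
new = shift_coor(4, 4, shift_h, shift_w)     # destination window in the canvas
old = shift_coor(4, 4, -shift_h, -shift_w)   # source window in the original image
canvas = np.zeros_like(im)
canvas[new[1]:new[3], new[0]:new[2]] = im[old[1]:old[3], old[0]:old[2]]
print(canvas)   # top row and left two columns end up as zero padding
```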
**ppdet/engine/export_utils.py** (+2 -1)

```diff
@@ -49,6 +49,7 @@ TRT_MIN_SUBGRAPH = {
     'CenterNet': 5,
     'TOOD': 5,
     'YOLOX': 8,
+    'YOLOF': 40,
     'METRO_Body': 3,
     'DETR': 3,
 }
@@ -156,7 +157,7 @@ def _dump_infer_config(config, path, image_shape, model):
             arch_state = True
             break

-    if infer_arch == 'YOLOX':
+    if infer_arch in ['YOLOX', 'YOLOF']:
         infer_cfg['arch'] = infer_arch
         infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
         arch_state = True
```
**ppdet/engine/trainer.py** (+9 -0, new `print_params` block inside `Trainer`)

```python
# @@ -150,6 +150,15 @@ class Trainer(object):
                self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # get Params
        print_params = self.cfg.get('print_params', False)
        if print_params:
            params = sum([
                p.numel() for n, p in self.model.named_parameters()
                if all([x not in n for x in ['_mean', '_variance']])
            ])  # exclude BatchNorm running status
            logger.info('Params: ', params / 1e6)

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
```
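The new block counts learnable parameters while skipping BatchNorm running statistics (names containing `_mean` or `_variance`). A standalone sketch of the same expression on a tiny paddle model, assuming (as the filter above implies) that paddle registers those moving statistics under `named_parameters()`:

```python
import paddle
import paddle.nn as nn

model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8))
params = sum([
    p.numel() for n, p in model.named_parameters()
    if all([x not in n for x in ['_mean', '_variance']])
])
# 224 conv weights/bias + 16 BN scale/shift = 240, with the BN moving stats excluded
print(float(params) / 1e6, 'M parameters')
```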
**ppdet/modeling/architectures/__init__.py** (+2 -0)

```diff
@@ -36,6 +36,7 @@ from . import tood
 from . import retinanet
 from . import bytetrack
 from . import yolox
+from . import yolof
 from . import pose3d_metro

 from .meta_arch import *
@@ -63,4 +64,5 @@ from .tood import *
 from .retinanet import *
 from .bytetrack import *
 from .yolox import *
+from .yolof import *
 from .pose3d_metro import *
```
**ppdet/modeling/architectures/yolof.py** (new file, +88 -0)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register, create
from .meta_arch import BaseArch

__all__ = ['YOLOF']


@register
class YOLOF(BaseArch):
    __category__ = 'architecture'

    def __init__(self,
                 backbone='ResNet',
                 neck='DilatedEncoder',
                 head='YOLOFHead',
                 for_mot=False):
        """
        YOLOF network, see https://arxiv.org/abs/2103.09460

        Args:
            backbone (nn.Layer): backbone instance
            neck (nn.Layer): DilatedEncoder instance
            head (nn.Layer): YOLOFHead instance
            for_mot (bool): whether return other features for multi-object tracking
                models, default False in pure object detection models.
        """
        super(YOLOF, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.for_mot = for_mot

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])

        # fpn
        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)

        # head
        kwargs = {'input_shape': neck.out_shape}
        head = create(cfg['head'], **kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "head": head,
        }

    def _forward(self):
        body_feats = self.backbone(self.inputs)
        neck_feats = self.neck(body_feats, self.for_mot)

        if self.training:
            yolo_losses = self.head(neck_feats, self.inputs)
            return yolo_losses
        else:
            yolo_head_outs = self.head(neck_feats)
            bbox, bbox_num = self.head.post_process(
                yolo_head_outs, self.inputs['im_shape'],
                self.inputs['scale_factor'])
            output = {'bbox': bbox, 'bbox_num': bbox_num}
            return output

    def get_loss(self):
        return self._forward()

    def get_pred(self):
        return self._forward()
```
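The `from_config` hook above wires the components by shape: the backbone's `out_shape` becomes the neck's `input_shape`, and the neck's `out_shape` becomes the head's. A minimal standalone mock of that hand-off (not ppdet's `create()`; the dict-based shapes are a stand-in for `ShapeSpec`):

```python
class FakeBackbone:
    out_shape = [{'channels': 2048, 'stride': 32}]   # res5 only (return_idx: [3])

class FakeNeck:
    def __init__(self, input_shape):
        # the DilatedEncoder keeps the stride and projects channels to 512
        self.out_shape = [{'channels': 512, 'stride': input_shape[0]['stride']}]

neck = FakeNeck(FakeBackbone.out_shape)
print(neck.out_shape)   # [{'channels': 512, 'stride': 32}] -> what YOLOFHead receives
```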
**ppdet/modeling/assigners/__init__.py** (+2 -0)

```diff
@@ -20,6 +20,7 @@ from . import max_iou_assigner
 from . import fcosr_assigner
 from . import rotated_task_aligned_assigner
 from . import task_aligned_assigner_cr
+from . import uniform_assigner

 from .utils import *
 from .task_aligned_assigner import *
@@ -29,3 +30,4 @@ from .max_iou_assigner import *
 from .fcosr_assigner import *
 from .rotated_task_aligned_assigner import *
 from .task_aligned_assigner_cr import *
+from .uniform_assigner import *
```
**ppdet/modeling/assigners/uniform_assigner.py** (new file, +93 -0)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# (Apache-2.0 license header, identical to the one in yolof.py above.)

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import batch_bbox_overlaps
from ppdet.modeling.transformers import bbox_xyxy_to_cxcywh

__all__ = ['UniformAssigner']


def batch_p_dist(x, y, p=2):
    """
    calculate pairwise p_dist, the first index of x and y are batch
    return [x.shape[0], y.shape[0]]
    """
    x = x.unsqueeze(1)
    diff = x - y
    return paddle.norm(diff, p=p, axis=list(range(2, diff.dim())))


@register
class UniformAssigner(nn.Layer):
    def __init__(self, pos_ignore_thr, neg_ignore_thr, match_times=4):
        super(UniformAssigner, self).__init__()
        self.pos_ignore_thr = pos_ignore_thr
        self.neg_ignore_thr = neg_ignore_thr
        self.match_times = match_times

    def forward(self, bbox_pred, anchor, gt_bboxes, gt_labels=None):
        num_bboxes = bbox_pred.shape[0]
        num_gts = gt_bboxes.shape[0]
        match_labels = paddle.full([num_bboxes], -1, dtype=paddle.int32)

        pred_ious = batch_bbox_overlaps(bbox_pred, gt_bboxes)
        pred_max_iou = pred_ious.max(axis=1)
        neg_ignore = pred_max_iou > self.neg_ignore_thr
        # exclude potential ignored neg samples first, deal with pos samples later
        # match_labels: -2(ignore), -1(neg) or >=0(pos_inds)
        match_labels = paddle.where(neg_ignore,
                                    paddle.full_like(match_labels, -2),
                                    match_labels)

        bbox_pred_c = bbox_xyxy_to_cxcywh(bbox_pred)
        anchor_c = bbox_xyxy_to_cxcywh(anchor)
        gt_bboxes_c = bbox_xyxy_to_cxcywh(gt_bboxes)
        bbox_pred_dist = batch_p_dist(bbox_pred_c, gt_bboxes_c, p=1)
        anchor_dist = batch_p_dist(anchor_c, gt_bboxes_c, p=1)

        top_pred = bbox_pred_dist.topk(
            k=self.match_times, axis=0, largest=False)[1]
        top_anchor = anchor_dist.topk(
            k=self.match_times, axis=0, largest=False)[1]

        tar_pred = paddle.arange(num_gts).expand([self.match_times, num_gts])
        tar_anchor = paddle.arange(num_gts).expand([self.match_times, num_gts])

        pos_places = paddle.concat([top_pred, top_anchor]).reshape([-1])
        pos_inds = paddle.concat([tar_pred, tar_anchor]).reshape([-1])

        pos_anchor = anchor[pos_places]
        pos_tar_bbox = gt_bboxes[pos_inds]
        pos_ious = batch_bbox_overlaps(
            pos_anchor, pos_tar_bbox, is_aligned=True)
        pos_ignore = pos_ious < self.pos_ignore_thr
        pos_inds = paddle.where(pos_ignore,
                                paddle.full_like(pos_inds, -2), pos_inds)
        match_labels[pos_places] = pos_inds
        match_labels.stop_gradient = True

        pos_keep = ~pos_ignore

        if pos_keep.sum() > 0:
            pos_places_keep = pos_places[pos_keep]
            pos_bbox_pred = bbox_pred[pos_places_keep].reshape([-1, 4])
            pos_bbox_tar = pos_tar_bbox[pos_keep].reshape([-1, 4]).detach()
        else:
            pos_bbox_pred = None
            pos_bbox_tar = None

        return match_labels, pos_bbox_pred, pos_bbox_tar
```
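To see the "uniform matching" idea in isolation: for each GT box, the assigner takes the `match_times` predictions (and anchors) whose centers are closest in L1 distance over (cx, cy, w, h), then drops matches whose anchor IoU falls below `pos_ignore_thr`. A simplified numpy illustration of the distance/top-k step (not the paddle implementation above):

```python
import numpy as np

def l1_center_dist(boxes_cxcywh, gts_cxcywh):
    # pairwise [num_boxes, num_gts] L1 distance over (cx, cy, w, h)
    return np.abs(boxes_cxcywh[:, None, :] - gts_cxcywh[None, :, :]).sum(-1)

preds = np.array([[10., 10., 20., 20.], [50., 50., 20., 20.], [12., 9., 18., 22.]])
gts = np.array([[11., 11., 20., 20.]])
dist = l1_center_dist(preds, gts)            # shape [3, 1]
k = 2                                        # plays the role of match_times
topk = np.argsort(dist[:, 0])[:k]            # indices of the 2 closest predictions
print(topk)                                  # -> [0 2]
```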
**ppdet/modeling/bbox_utils.py** (+98 -92: `bbox2delta`/`delta2bbox` gain default weights and docstrings, `bbox2delta_v2`/`delta2bbox_v2` are rewritten around `delta_mean`/`delta_std` and moved next to them, and the old `means`/`stds` variants are removed)

```diff
@@ -17,7 +17,9 @@ import paddle
 import numpy as np


-def bbox2delta(src_boxes, tgt_boxes, weights):
+def bbox2delta(src_boxes, tgt_boxes, weights=[1.0, 1.0, 1.0, 1.0]):
+    """Encode bboxes to deltas.
+    """
     src_w = src_boxes[:, 2] - src_boxes[:, 0]
     src_h = src_boxes[:, 3] - src_boxes[:, 1]
     src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
@@ -38,7 +40,11 @@ def bbox2delta(src_boxes, tgt_boxes, weights):
     return deltas


-def delta2bbox(deltas, boxes, weights):
+def delta2bbox(deltas, boxes, weights=[1.0, 1.0, 1.0, 1.0], max_shape=None):
+    """Decode deltas to boxes. Used in RCNNBox,CascadeHead,RCNNHead,RetinaHead.
+    Note: return tensor shape [n,1,4]
+        If you want to add a reshape, please add after the calling code instead of here.
+    """
     clip_scale = math.log(1000.0 / 16)

     widths = boxes[:, 2] - boxes[:, 0]
@@ -67,6 +73,96 @@ def delta2bbox(deltas, boxes, weights):
     pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
     pred_boxes = paddle.stack(pred_boxes, axis=-1)

+    if max_shape is not None:
+        pred_boxes[..., 0::2] = pred_boxes[..., 0::2].clip(min=0, max=max_shape[1])
+        pred_boxes[..., 1::2] = pred_boxes[..., 1::2].clip(min=0, max=max_shape[0])
+
     return pred_boxes


+def bbox2delta_v2(src_boxes,
+                  tgt_boxes,
+                  delta_mean=[0.0, 0.0, 0.0, 0.0],
+                  delta_std=[1.0, 1.0, 1.0, 1.0]):
+    """Encode bboxes to deltas.
+    Modified from bbox2delta() which just use weight parameters to multiply deltas.
+    """
+    src_w = src_boxes[:, 2] - src_boxes[:, 0]
+    src_h = src_boxes[:, 3] - src_boxes[:, 1]
+    src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
+    src_ctr_y = src_boxes[:, 1] + 0.5 * src_h
+    tgt_w = tgt_boxes[:, 2] - tgt_boxes[:, 0]
+    tgt_h = tgt_boxes[:, 3] - tgt_boxes[:, 1]
+    tgt_ctr_x = tgt_boxes[:, 0] + 0.5 * tgt_w
+    tgt_ctr_y = tgt_boxes[:, 1] + 0.5 * tgt_h
+    dx = (tgt_ctr_x - src_ctr_x) / src_w
+    dy = (tgt_ctr_y - src_ctr_y) / src_h
+    dw = paddle.log(tgt_w / src_w)
+    dh = paddle.log(tgt_h / src_h)
+    deltas = paddle.stack((dx, dy, dw, dh), axis=1)
+    deltas = (deltas - paddle.to_tensor(delta_mean)) / paddle.to_tensor(delta_std)
+    return deltas
+
+
+def delta2bbox_v2(deltas,
+                  boxes,
+                  delta_mean=[0.0, 0.0, 0.0, 0.0],
+                  delta_std=[1.0, 1.0, 1.0, 1.0],
+                  max_shape=None,
+                  ctr_clip=32.0):
+    """Decode deltas to bboxes.
+    Modified from delta2bbox() which just use weight parameters to be divided by deltas.
+    Used in YOLOFHead.
+    Note: return tensor shape [n,1,4]
+        If you want to add a reshape, please add after the calling code instead of here.
+    """
+    clip_scale = math.log(1000.0 / 16)
+    widths = boxes[:, 2] - boxes[:, 0]
+    heights = boxes[:, 3] - boxes[:, 1]
+    ctr_x = boxes[:, 0] + 0.5 * widths
+    ctr_y = boxes[:, 1] + 0.5 * heights
+
+    deltas = deltas * paddle.to_tensor(delta_std) + paddle.to_tensor(delta_mean)
+    dx = deltas[:, 0::4]
+    dy = deltas[:, 1::4]
+    dw = deltas[:, 2::4]
+    dh = deltas[:, 3::4]
+    # Prevent sending too large values into paddle.exp()
+    dx = dx * widths.unsqueeze(1)
+    dy = dy * heights.unsqueeze(1)
+    if ctr_clip is not None:
+        dx = paddle.clip(dx, max=ctr_clip, min=-ctr_clip)
+        dy = paddle.clip(dy, max=ctr_clip, min=-ctr_clip)
+        dw = paddle.clip(dw, max=clip_scale)
+        dh = paddle.clip(dh, max=clip_scale)
+    else:
+        dw = dw.clip(min=-ctr_clip, max=ctr_clip)
+        dh = dh.clip(min=-ctr_clip, max=ctr_clip)
+
+    pred_ctr_x = dx + ctr_x.unsqueeze(1)
+    pred_ctr_y = dy + ctr_y.unsqueeze(1)
+    pred_w = paddle.exp(dw) * widths.unsqueeze(1)
+    pred_h = paddle.exp(dh) * heights.unsqueeze(1)
+
+    pred_boxes = []
+    pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
+    pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
+    pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
+    pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
+    pred_boxes = paddle.stack(pred_boxes, axis=-1)
+
+    if max_shape is not None:
+        pred_boxes[..., 0::2] = pred_boxes[..., 0::2].clip(min=0, max=max_shape[1])
+        pred_boxes[..., 1::2] = pred_boxes[..., 1::2].clip(min=0, max=max_shape[0])
+
+    return pred_boxes
@@ -489,96 +585,6 @@ def batch_distance2bbox(points, distance, max_shapes=None):
     return out_bbox


-def delta2bbox_v2(rois,
-                  deltas,
-                  means=(0.0, 0.0, 0.0, 0.0),
-                  stds=(1.0, 1.0, 1.0, 1.0),
-                  max_shape=None,
-                  wh_ratio_clip=16.0 / 1000.0,
-                  ctr_clip=None):
-    """Transform network output(delta) to bboxes.
-    Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
-        bbox/coder/delta_xywh_bbox_coder.py
-    Args:
-        rois (Tensor): shape [..., 4], base bboxes, typical examples include
-            anchor and rois
-        deltas (Tensor): shape [..., 4], offset relative to base bboxes
-        means (list[float]): the mean that was used to normalize deltas,
-            must be of size 4
-        stds (list[float]): the std that was used to normalize deltas,
-            must be of size 4
-        max_shape (list[float] or None): height and width of image, will be
-            used to clip bboxes if not None
-        wh_ratio_clip (float): to clip delta wh of decoded bboxes
-        ctr_clip (float or None): whether to clip delta xy of decoded bboxes
-    """
-    if rois.size == 0:
-        return paddle.empty_like(rois)
-    means = paddle.to_tensor(means)
-    stds = paddle.to_tensor(stds)
-    deltas = deltas * stds + means
-    dxy = deltas[..., :2]
-    dwh = deltas[..., 2:]
-    pxy = (rois[..., :2] + rois[..., 2:]) * 0.5
-    pwh = rois[..., 2:] - rois[..., :2]
-    dxy_wh = pwh * dxy
-    max_ratio = np.abs(np.log(wh_ratio_clip))
-    if ctr_clip is not None:
-        dxy_wh = paddle.clip(dxy_wh, max=ctr_clip, min=-ctr_clip)
-        dwh = paddle.clip(dwh, max=max_ratio)
-    else:
-        dwh = dwh.clip(min=-max_ratio, max=max_ratio)
-    gxy = pxy + dxy_wh
-    gwh = pwh * dwh.exp()
-    x1y1 = gxy - (gwh * 0.5)
-    x2y2 = gxy + (gwh * 0.5)
-    bboxes = paddle.concat([x1y1, x2y2], axis=-1)
-    if max_shape is not None:
-        bboxes[..., 0::2] = bboxes[..., 0::2].clip(min=0, max=max_shape[1])
-        bboxes[..., 1::2] = bboxes[..., 1::2].clip(min=0, max=max_shape[0])
-    return bboxes
-
-
-def bbox2delta_v2(src_boxes,
-                  tgt_boxes,
-                  means=(0.0, 0.0, 0.0, 0.0),
-                  stds=(1.0, 1.0, 1.0, 1.0)):
-    """Encode bboxes to deltas.
-    Modified from ppdet.modeling.bbox_utils.bbox2delta.
-    Args:
-        src_boxes (Tensor[..., 4]): base bboxes
-        tgt_boxes (Tensor[..., 4]): target bboxes
-        means (list[float]): the mean that will be used to normalize delta
-        stds (list[float]): the std that will be used to normalize delta
-    """
-    if src_boxes.size == 0:
-        return paddle.empty_like(src_boxes)
-    src_w = src_boxes[..., 2] - src_boxes[..., 0]
-    src_h = src_boxes[..., 3] - src_boxes[..., 1]
-    src_ctr_x = src_boxes[..., 0] + 0.5 * src_w
-    src_ctr_y = src_boxes[..., 1] + 0.5 * src_h
-    tgt_w = tgt_boxes[..., 2] - tgt_boxes[..., 0]
-    tgt_h = tgt_boxes[..., 3] - tgt_boxes[..., 1]
-    tgt_ctr_x = tgt_boxes[..., 0] + 0.5 * tgt_w
-    tgt_ctr_y = tgt_boxes[..., 1] + 0.5 * tgt_h
-    dx = (tgt_ctr_x - src_ctr_x) / src_w
-    dy = (tgt_ctr_y - src_ctr_y) / src_h
-    dw = paddle.log(tgt_w / src_w)
-    dh = paddle.log(tgt_h / src_h)
-    deltas = paddle.stack((dx, dy, dw, dh), axis=1)  # [n, 4]
-    means = paddle.to_tensor(means, place=src_boxes.place)
-    stds = paddle.to_tensor(stds, place=src_boxes.place)
-    deltas = (deltas - means) / stds
-    return deltas
-
-
 def iou_similarity(box1, box2, eps=1e-10):
     """Calculate iou of box1 and box2
```
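With the default `delta_mean = 0` and `delta_std = 1` used by YOLOF, the deltas are plain `(dx, dy, log dw, log dh)` relative to the anchor. A quick numpy round-trip of that encoding/decoding (sketch only; the functions above operate on paddle tensors and additionally apply `ctr_clip`/`max_shape`):

```python
import numpy as np

anchor = np.array([0., 0., 32., 32.])        # xyxy, w = h = 32, center (16, 16)
gt = np.array([8., 8., 56., 40.])            # w = 48, h = 32, center (32, 24)

aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
acx, acy = anchor[0] + 0.5 * aw, anchor[1] + 0.5 * ah
gw, gh = gt[2] - gt[0], gt[3] - gt[1]
gcx, gcy = gt[0] + 0.5 * gw, gt[1] + 0.5 * gh

# encode, as in bbox2delta_v2 with zero mean / unit std
deltas = np.array([(gcx - acx) / aw, (gcy - acy) / ah, np.log(gw / aw), np.log(gh / ah)])
print(deltas)   # [0.5, 0.25, log(1.5), 0.0]

# decode, as in delta2bbox_v2 (ignoring ctr_clip / max_shape here)
pcx, pcy = deltas[0] * aw + acx, deltas[1] * ah + acy
pw, ph = np.exp(deltas[2]) * aw, np.exp(deltas[3]) * ah
print([pcx - 0.5 * pw, pcy - 0.5 * ph, pcx + 0.5 * pw, pcy + 0.5 * ph])   # back to the GT box
```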
**ppdet/modeling/heads/__init__.py** (+2 -0)

```diff
@@ -36,6 +36,7 @@ from . import ppyoloe_head
 from . import fcosr_head
 from . import ppyoloe_r_head
 from . import ld_gfl_head
+from . import yolof_head

 from .bbox_head import *
 from .mask_head import *
@@ -61,3 +62,4 @@ from .ppyoloe_head import *
 from .fcosr_head import *
 from .ld_gfl_head import *
 from .ppyoloe_r_head import *
+from .yolof_head import *
```
**ppdet/modeling/heads/yolof_head.py** (new file, +368 -0)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# (Apache-2.0 license header, identical to the one in yolof.py above.)

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.layers import MultiClassNMS
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import delta2bbox_v2

__all__ = ['YOLOFHead']

INF = 1e8


def reduce_mean(tensor):
    world_size = paddle.distributed.get_world_size()
    if world_size == 1:
        return tensor
    paddle.distributed.all_reduce(tensor)
    return tensor / world_size


def find_inside_anchor(feat_size, stride, num_anchors, im_shape):
    feat_h, feat_w = feat_size[:2]
    im_h, im_w = im_shape[:2]
    inside_h = min(int(np.ceil(im_h / stride)), feat_h)
    inside_w = min(int(np.ceil(im_w / stride)), feat_w)
    inside_mask = paddle.zeros([feat_h, feat_w], dtype=paddle.bool)
    inside_mask[:inside_h, :inside_w] = True
    inside_mask = inside_mask.unsqueeze(-1).expand(
        [feat_h, feat_w, num_anchors])
    return inside_mask.reshape([-1])


@register
class YOLOFFeat(nn.Layer):
    def __init__(self,
                 feat_in=256,
                 feat_out=256,
                 num_cls_convs=2,
                 num_reg_convs=4,
                 norm_type='bn'):
        super(YOLOFFeat, self).__init__()
        assert norm_type == 'bn', "YOLOFFeat only support BN now."
        self.feat_in = feat_in
        self.feat_out = feat_out
        self.num_cls_convs = num_cls_convs
        self.num_reg_convs = num_reg_convs
        self.norm_type = norm_type

        cls_subnet, reg_subnet = [], []
        for i in range(self.num_cls_convs):
            feat_in = self.feat_in if i == 0 else self.feat_out
            cls_subnet.append(
                nn.Conv2D(
                    feat_in,
                    self.feat_out,
                    3,
                    stride=1,
                    padding=1,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0.0, std=0.01)),
                    bias_attr=ParamAttr(initializer=Constant(value=0.0))))
            cls_subnet.append(
                nn.BatchNorm2D(
                    self.feat_out,
                    weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                    bias_attr=ParamAttr(regularizer=L2Decay(0.0))))
            cls_subnet.append(nn.ReLU())

        for i in range(self.num_reg_convs):
            feat_in = self.feat_in if i == 0 else self.feat_out
            reg_subnet.append(
                nn.Conv2D(
                    feat_in,
                    self.feat_out,
                    3,
                    stride=1,
                    padding=1,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0.0, std=0.01)),
                    bias_attr=ParamAttr(initializer=Constant(value=0.0))))
            reg_subnet.append(
                nn.BatchNorm2D(
                    self.feat_out,
                    weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                    bias_attr=ParamAttr(regularizer=L2Decay(0.0))))
            reg_subnet.append(nn.ReLU())

        self.cls_subnet = nn.Sequential(*cls_subnet)
        self.reg_subnet = nn.Sequential(*reg_subnet)

    def forward(self, fpn_feat):
        cls_feat = self.cls_subnet(fpn_feat)
        reg_feat = self.reg_subnet(fpn_feat)
        return cls_feat, reg_feat


@register
class YOLOFHead(nn.Layer):
    __shared__ = ['num_classes', 'trt', 'exclude_nms']
    __inject__ = [
        'conv_feat', 'anchor_generator', 'bbox_assigner', 'loss_class',
        'loss_bbox', 'nms'
    ]

    def __init__(self,
                 num_classes=80,
                 conv_feat='YOLOFFeat',
                 anchor_generator='AnchorGenerator',
                 bbox_assigner='UniformAssigner',
                 loss_class='FocalLoss',
                 loss_bbox='GIoULoss',
                 ctr_clip=32.0,
                 delta_mean=[0.0, 0.0, 0.0, 0.0],
                 delta_std=[1.0, 1.0, 1.0, 1.0],
                 nms='MultiClassNMS',
                 prior_prob=0.01,
                 nms_pre=1000,
                 use_inside_anchor=False,
                 trt=False,
                 exclude_nms=False):
        super(YOLOFHead, self).__init__()
        self.num_classes = num_classes
        self.conv_feat = conv_feat
        self.anchor_generator = anchor_generator
        self.na = self.anchor_generator.num_anchors
        self.bbox_assigner = bbox_assigner
        self.loss_class = loss_class
        self.loss_bbox = loss_bbox
        self.ctr_clip = ctr_clip
        self.delta_mean = delta_mean
        self.delta_std = delta_std
        self.nms = nms
        self.nms_pre = nms_pre
        self.use_inside_anchor = use_inside_anchor
        if isinstance(self.nms, MultiClassNMS) and trt:
            self.nms.trt = trt
        self.exclude_nms = exclude_nms

        bias_init_value = -math.log((1 - prior_prob) / prior_prob)
        self.cls_score = self.add_sublayer(
            'cls_score',
            nn.Conv2D(
                in_channels=conv_feat.feat_out,
                out_channels=self.num_classes * self.na,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0.0, std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(
                    value=bias_init_value))))

        self.bbox_pred = self.add_sublayer(
            'bbox_pred',
            nn.Conv2D(
                in_channels=conv_feat.feat_out,
                out_channels=4 * self.na,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0.0, std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))

        self.object_pred = self.add_sublayer(
            'object_pred',
            nn.Conv2D(
                in_channels=conv_feat.feat_out,
                out_channels=self.na,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0.0, std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))

    def forward(self, feats, targets=None):
        assert len(feats) == 1, "YOLOF only has one level feature."
        conv_cls_feat, conv_reg_feat = self.conv_feat(feats[0])
        cls_logits = self.cls_score(conv_cls_feat)
        objectness = self.object_pred(conv_reg_feat)
        bboxes_reg = self.bbox_pred(conv_reg_feat)

        N, C, H, W = paddle.shape(cls_logits)[:]
        cls_logits = cls_logits.reshape((N, self.na, self.num_classes, H, W))
        objectness = objectness.reshape((N, self.na, 1, H, W))
        norm_cls_logits = cls_logits + objectness - paddle.log(
            1.0 + paddle.clip(
                cls_logits.exp(), max=INF) + paddle.clip(
                    objectness.exp(), max=INF))
        norm_cls_logits = norm_cls_logits.reshape((N, C, H, W))

        anchors = self.anchor_generator([norm_cls_logits])

        if self.training:
            yolof_losses = self.get_loss(
                [anchors[0], norm_cls_logits, bboxes_reg], targets)
            return yolof_losses
        else:
            return anchors[0], norm_cls_logits, bboxes_reg

    def get_loss(self, head_outs, targets):
        anchors, cls_logits, bbox_preds = head_outs
        feat_size = cls_logits.shape[-2:]
        cls_logits = cls_logits.transpose([0, 2, 3, 1])
        cls_logits = cls_logits.reshape([0, -1, self.num_classes])
        bbox_preds = bbox_preds.transpose([0, 2, 3, 1])
        bbox_preds = bbox_preds.reshape([0, -1, 4])

        num_pos_list = []
        cls_pred_list, cls_tar_list = [], []
        reg_pred_list, reg_tar_list = [], []
        # find and gather preds and targets in each image
        for cls_logit, bbox_pred, gt_bbox, gt_class, im_shape in zip(
                cls_logits, bbox_preds, targets['gt_bbox'],
                targets['gt_class'], targets['im_shape']):
            if self.use_inside_anchor:
                inside_mask = find_inside_anchor(
                    feat_size,
                    self.anchor_generator.strides[0], self.na,
                    im_shape.tolist())
                cls_logit = cls_logit[inside_mask]
                bbox_pred = bbox_pred[inside_mask]
                anchors = anchors[inside_mask]

            bbox_pred = delta2bbox_v2(
                bbox_pred,
                anchors,
                self.delta_mean,
                self.delta_std,
                ctr_clip=self.ctr_clip)
            bbox_pred = bbox_pred.reshape([-1, bbox_pred.shape[-1]])

            # -2:ignore, -1:neg, >=0:pos
            match_labels, pos_bbox_pred, pos_bbox_tar = self.bbox_assigner(
                bbox_pred, anchors, gt_bbox)

            pos_mask = (match_labels >= 0)
            neg_mask = (match_labels == -1)
            chosen_mask = paddle.logical_or(pos_mask, neg_mask)

            gt_class = gt_class.reshape([-1])
            bg_class = paddle.to_tensor(
                [self.num_classes], dtype=gt_class.dtype)
            # a trick to assign num_classes to negative targets
            gt_class = paddle.concat([gt_class, bg_class], axis=-1)
            match_labels = paddle.where(
                neg_mask,
                paddle.full_like(match_labels, gt_class.size - 1),
                match_labels)

            num_pos_list.append(max(1.0, pos_mask.sum().item()))

            cls_pred_list.append(cls_logit[chosen_mask])
            cls_tar_list.append(gt_class[match_labels[chosen_mask]])
            reg_pred_list.append(pos_bbox_pred)
            reg_tar_list.append(pos_bbox_tar)

        num_tot_pos = paddle.to_tensor(sum(num_pos_list))
        num_tot_pos = reduce_mean(num_tot_pos).item()
        num_tot_pos = max(1.0, num_tot_pos)

        cls_pred = paddle.concat(cls_pred_list)
        cls_tar = paddle.concat(cls_tar_list)
        cls_loss = self.loss_class(
            cls_pred, cls_tar, reduction='sum') / num_tot_pos

        reg_pred_list = [_ for _ in reg_pred_list if _ is not None]
        reg_tar_list = [_ for _ in reg_tar_list if _ is not None]
        if len(reg_pred_list) == 0:
            reg_loss = bbox_preds.sum() * 0.0
        else:
            reg_pred = paddle.concat(reg_pred_list)
            reg_tar = paddle.concat(reg_tar_list)
            reg_loss = self.loss_bbox(reg_pred, reg_tar).sum() / num_tot_pos

        yolof_losses = {
            'loss': cls_loss + reg_loss,
            'loss_cls': cls_loss,
            'loss_reg': reg_loss,
        }
        return yolof_losses

    def get_bboxes_single(self,
                          anchors,
                          cls_scores,
                          bbox_preds,
                          im_shape,
                          scale_factor,
                          rescale=True):
        assert len(cls_scores) == len(bbox_preds)
        mlvl_bboxes = []
        mlvl_scores = []
        for anchor, cls_score, bbox_pred in zip(anchors, cls_scores,
                                                bbox_preds):
            cls_score = cls_score.reshape([-1, self.num_classes])
            bbox_pred = bbox_pred.reshape([-1, 4])
            if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
                max_score = cls_score.max(axis=1)
                _, topk_inds = max_score.topk(self.nms_pre)
                bbox_pred = bbox_pred.gather(topk_inds)
                anchor = anchor.gather(topk_inds)
                cls_score = cls_score.gather(topk_inds)

            bbox_pred = delta2bbox_v2(
                bbox_pred,
                anchor,
                self.delta_mean,
                self.delta_std,
                max_shape=im_shape,
                ctr_clip=self.ctr_clip).squeeze()
            mlvl_bboxes.append(bbox_pred)
            mlvl_scores.append(F.sigmoid(cls_score))

        mlvl_bboxes = paddle.concat(mlvl_bboxes)
        mlvl_bboxes = paddle.squeeze(mlvl_bboxes)
        if rescale:
            mlvl_bboxes = mlvl_bboxes / paddle.concat(
                [scale_factor[::-1], scale_factor[::-1]])
        mlvl_scores = paddle.concat(mlvl_scores)
        mlvl_scores = mlvl_scores.transpose([1, 0])
        return mlvl_bboxes, mlvl_scores

    def decode(self, anchors, cls_scores, bbox_preds, im_shape, scale_factor):
        batch_bboxes = []
        batch_scores = []
        for img_id in range(cls_scores[0].shape[0]):
            num_lvls = len(cls_scores)
            cls_score_list = [cls_scores[i][img_id] for i in range(num_lvls)]
            bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_lvls)]
            bboxes, scores = self.get_bboxes_single(
                anchors, cls_score_list, bbox_pred_list, im_shape[img_id],
                scale_factor[img_id])
            batch_bboxes.append(bboxes)
            batch_scores.append(scores)
        batch_bboxes = paddle.stack(batch_bboxes, 0)
        batch_scores = paddle.stack(batch_scores, 0)
        return batch_bboxes, batch_scores

    def post_process(self, head_outs, im_shape, scale_factor):
        anchors, cls_scores, bbox_preds = head_outs
        cls_scores = cls_scores.transpose([0, 2, 3, 1])
        bbox_preds = bbox_preds.transpose([0, 2, 3, 1])
        pred_bboxes, pred_scores = self.decode(
            [anchors], [cls_scores], [bbox_preds], im_shape, scale_factor)

        if self.exclude_nms:
            # `exclude_nms=True` just use in benchmark
            return pred_bboxes.sum(), pred_scores.sum()
        else:
            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num
```
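The normalized logits computed in `forward()` implement YOLOF's implicit objectness: `sigmoid(cls + obj - log(1 + exp(cls) + exp(obj)))` equals `sigmoid(cls) * sigmoid(obj)`, so classification and objectness are fused into a single logit before the focal loss and the sigmoid at inference. A quick numpy check of that identity (the `clip(..., max=INF)` in the head only guards against overflow):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

cls, obj = 1.3, -0.7
fused = cls + obj - np.log(1.0 + np.exp(cls) + np.exp(obj))
print(sigmoid(fused))               # ~0.2607
print(sigmoid(cls) * sigmoid(obj))  # same value
```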
**ppdet/modeling/necks/__init__.py** (+2 -0)

```diff
@@ -22,6 +22,7 @@ from . import csp_pan
 from . import es_pan
 from . import lc_pan
 from . import custom_pan
+from . import dilated_encoder

 from .fpn import *
 from .yolo_fpn import *
@@ -34,3 +35,4 @@ from .csp_pan import *
 from .es_pan import *
 from .lc_pan import *
 from .custom_pan import *
+from .dilated_encoder import *
```
**ppdet/modeling/necks/dilated_encoder.py** (new file, +150 -0)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# (Apache-2.0 license header, identical to the one in yolof.py above.)

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingUniform, Constant, Normal

from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['DilatedEncoder']


class Bottleneck(nn.Layer):
    def __init__(self, in_channels, mid_channels, dilation):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Sequential(*[
            nn.Conv2D(
                in_channels,
                mid_channels,
                1,
                padding=0,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0, std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(0.0))),
            nn.BatchNorm2D(
                mid_channels,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0))),
            nn.ReLU(),
        ])
        self.conv2 = nn.Sequential(*[
            nn.Conv2D(
                mid_channels,
                mid_channels,
                3,
                padding=dilation,
                dilation=dilation,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0, std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(0.0))),
            nn.BatchNorm2D(
                mid_channels,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0))),
            nn.ReLU(),
        ])
        self.conv3 = nn.Sequential(*[
            nn.Conv2D(
                mid_channels,
                in_channels,
                1,
                padding=0,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0, std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(0.0))),
            nn.BatchNorm2D(
                in_channels,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0))),
            nn.ReLU(),
        ])

    def forward(self, x):
        identity = x
        y = self.conv3(self.conv2(self.conv1(x)))
        return y + identity


@register
class DilatedEncoder(nn.Layer):
    """
    DilatedEncoder used in YOLOF
    """

    def __init__(self,
                 in_channels=[2048],
                 out_channels=[512],
                 block_mid_channels=128,
                 num_residual_blocks=4,
                 block_dilations=[2, 4, 6, 8]):
        super(DilatedEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert len(self.in_channels) == 1, "YOLOF only has one level feature."
        assert len(self.out_channels) == 1, "YOLOF only has one level feature."
        self.block_mid_channels = block_mid_channels
        self.num_residual_blocks = num_residual_blocks
        self.block_dilations = block_dilations

        out_ch = self.out_channels[0]
        self.lateral_conv = nn.Conv2D(
            self.in_channels[0],
            out_ch,
            1,
            weight_attr=ParamAttr(initializer=KaimingUniform(
                negative_slope=1, nonlinearity='leaky_relu')),
            bias_attr=ParamAttr(initializer=Constant(value=0.0)))
        self.lateral_norm = nn.BatchNorm2D(
            out_ch,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        self.fpn_conv = nn.Conv2D(
            out_ch,
            out_ch,
            3,
            padding=1,
            weight_attr=ParamAttr(initializer=KaimingUniform(
                negative_slope=1, nonlinearity='leaky_relu')))
        self.fpn_norm = nn.BatchNorm2D(
            out_ch,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        encoder_blocks = []
        for i in range(self.num_residual_blocks):
            encoder_blocks.append(
                Bottleneck(
                    out_ch,
                    self.block_mid_channels,
                    dilation=block_dilations[i]))
        self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks)

    def forward(self, inputs, for_mot=False):
        out = self.lateral_norm(self.lateral_conv(inputs[0]))
        out = self.fpn_norm(self.fpn_conv(out))
        out = self.dilated_encoder_blocks(out)
        return [out]

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self.out_channels]
```
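A back-of-the-envelope note on why the dilations [2, 4, 6, 8] matter: each dilated 3x3 conv grows the receptive field on the C5 map by roughly 2*d, which is how a single-level feature can cover objects of very different sizes. Rough arithmetic only, ignoring the 1x1 convs inside each Bottleneck:

```python
dilations = [2, 4, 6, 8]
rf = 3                      # the 3x3 projection conv (fpn_conv) on the stride-32 map
for d in dilations:
    rf += 2 * d             # a k=3 conv with dilation d adds (k - 1) * d = 2d
print(rf)                   # 43 feature-map cells, i.e. roughly 43 * 32 input pixels
```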