Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
dfb8ea1e
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dfb8ea1e
编写于
11月 30, 2021
作者:
S
shangliang Xu
提交者:
GitHub
11月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[TOOD] fix dy2st (#4751)
上级
a5409751
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
113 additions
and
98 deletions
+113
-98
configs/tood/README.md
configs/tood/README.md
+1
-1
configs/tood/_base_/tood_reader.yml
configs/tood/_base_/tood_reader.yml
+3
-3
deploy/python/infer.py
deploy/python/infer.py
+2
-1
ppdet/engine/export_utils.py
ppdet/engine/export_utils.py
+1
-0
ppdet/modeling/assigners/utils.py
ppdet/modeling/assigners/utils.py
+45
-0
ppdet/modeling/bbox_utils.py
ppdet/modeling/bbox_utils.py
+24
-2
ppdet/modeling/heads/tood_head.py
ppdet/modeling/heads/tood_head.py
+31
-91
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+6
-0
未找到文件。
configs/tood/README.md
浏览文件 @
dfb8ea1e
...
...
@@ -11,7 +11,7 @@ TOOD is an object detection model. We reproduced the model of the paper.
| Backbone | Model | Images/GPU | Inf time (fps) | Box AP | Config | Download |
|:------:|:--------:|:--------:|:--------------:|:------:|:------:|:--------:|
| R-50 | TOOD | 4 | --- | 42.
8
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/tood/tood_r50_fpn_1x_coco.yml
)
|
[
model
](
https://paddledet.bj.bcebos.com/models/tood_r50_fpn_1x_coco.pdparams
)
|
| R-50 | TOOD | 4 | --- | 42.
5
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/tood/tood_r50_fpn_1x_coco.yml
)
|
[
model
](
https://paddledet.bj.bcebos.com/models/tood_r50_fpn_1x_coco.pdparams
)
|
**Notes:**
...
...
configs/tood/_base_/tood_reader.yml
浏览文件 @
dfb8ea1e
...
...
@@ -3,7 +3,7 @@ TrainReader:
sample_transforms
:
-
Decode
:
{}
-
RandomFlip
:
{
prob
:
0.5
}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
true
,
interp
:
1
}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
true
}
-
NormalizeImage
:
{
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
Permute
:
{}
batch_transforms
:
...
...
@@ -18,7 +18,7 @@ TrainReader:
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
1
,
target_size
:
[
800
,
1333
],
keep_ratio
:
True
}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
True
}
-
NormalizeImage
:
{
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
Permute
:
{}
batch_transforms
:
...
...
@@ -30,7 +30,7 @@ EvalReader:
TestReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
1
,
target_size
:
[
800
,
1333
],
keep_ratio
:
True
}
-
Resize
:
{
target_size
:
[
800
,
1333
],
keep_ratio
:
True
}
-
NormalizeImage
:
{
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
Permute
:
{}
batch_transforms
:
...
...
deploy/python/infer.py
浏览文件 @
dfb8ea1e
...
...
@@ -46,6 +46,7 @@ SUPPORT_MODELS = {
'GFL'
,
'PicoDet'
,
'CenterNet'
,
'TOOD'
,
}
...
...
@@ -680,7 +681,7 @@ def predict_video(detector, camera_id):
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_out_name
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
index
=
1
while
(
1
):
...
...
ppdet/engine/export_utils.py
浏览文件 @
dfb8ea1e
...
...
@@ -46,6 +46,7 @@ TRT_MIN_SUBGRAPH = {
'GFL'
:
16
,
'PicoDet'
:
3
,
'CenterNet'
:
5
,
'TOOD'
:
5
,
}
KEYPOINT_ARCH
=
[
'HigherHRNet'
,
'TopDownHRNet'
]
...
...
ppdet/modeling/assigners/utils.py
浏览文件 @
dfb8ea1e
...
...
@@ -19,6 +19,12 @@ from __future__ import print_function
import
paddle
import
paddle.nn.functional
as
F
__all__
=
[
'pad_gt'
,
'gather_topk_anchors'
,
'check_points_inside_bboxes'
,
'compute_max_iou_anchor'
,
'compute_max_iou_gt'
,
'generate_anchors_for_grid_cell'
]
def
pad_gt
(
gt_labels
,
gt_bboxes
,
gt_scores
=
None
):
r
""" Pad 0 in gt_labels and gt_bboxes.
...
...
@@ -147,3 +153,42 @@ def compute_max_iou_gt(ious):
max_iou_index
=
ious
.
argmax
(
axis
=-
1
)
is_max_iou
=
F
.
one_hot
(
max_iou_index
,
num_anchors
)
return
is_max_iou
.
astype
(
ious
.
dtype
)
def generate_anchors_for_grid_cell(feats,
                                   fpn_strides,
                                   grid_cell_size=5.0,
                                   grid_cell_offset=0.5):
    r"""
    Build ATSS-style square anchors, one per grid cell of every feature level.

    Args:
        feats (List[Tensor]): shape[s, (b, c, h, w)]
        fpn_strides (tuple|list): shape[s], stride for each scale feature
        grid_cell_size (float): anchor size, in units of the level stride
        grid_cell_offset (float): The range is between 0 and 1.
    Returns:
        anchors (List[Tensor]): shape[s, (l, 4)]
        num_anchors_list (List[int]): shape[s]
        stride_tensor_list (List[Tensor]): shape[s, (l, 1)]
    """
    assert len(feats) == len(fpn_strides)
    anchors, num_anchors_list, stride_tensor_list = [], [], []

    for level_feat, level_stride in zip(feats, fpn_strides):
        _, _, rows, cols = level_feat.shape
        half_extent = grid_cell_size * level_stride * 0.5

        # Cell reference coordinates along each axis, scaled by the stride.
        xs = (paddle.arange(end=cols) + grid_cell_offset) * level_stride
        ys = (paddle.arange(end=rows) + grid_cell_offset) * level_stride
        grid_y, grid_x = paddle.meshgrid(ys, xs)

        # (x1, y1, x2, y2) boxes centred on every grid position.
        corners = [
            grid_x - half_extent,
            grid_y - half_extent,
            grid_x + half_extent,
            grid_y + half_extent,
        ]
        level_anchors = paddle.stack(corners, axis=-1).astype(level_feat.dtype)
        level_anchors = level_anchors.reshape([-1, 4])

        anchor_count = len(level_anchors)
        anchors.append(level_anchors)
        num_anchors_list.append(anchor_count)
        # One stride value per anchor of this level.
        stride_tensor_list.append(paddle.full([anchor_count, 1], level_stride))

    return anchors, num_anchors_list, stride_tensor_list
ppdet/modeling/bbox_utils.py
浏览文件 @
dfb8ea1e
...
...
@@ -748,6 +748,28 @@ def bbox_center(boxes):
Returns:
Tensor: boxes centers with shape (N, 2), "cx, cy" format.
"""
boxes_cx
=
(
boxes
[
:,
0
]
+
boxes
[:
,
2
])
/
2
boxes_cy
=
(
boxes
[
:,
1
]
+
boxes
[:
,
3
])
/
2
boxes_cx
=
(
boxes
[
...,
0
]
+
boxes
[...
,
2
])
/
2
boxes_cy
=
(
boxes
[
...,
1
]
+
boxes
[...
,
3
])
/
2
return
paddle
.
stack
([
boxes_cx
,
boxes_cy
],
axis
=-
1
)
def batch_distance2bbox(points, distance, max_shapes=None):
    """Decode batched distance predictions into bounding boxes.

    Args:
        points (Tensor): [B, ..., 2]
        distance (Tensor): [B, ..., 4]
        max_shapes (tuple): [B, 2], "h,w" format, Shape of the image.
            When given, each sample's coordinates are clipped to
            [0, w] for x and [0, h] for y.
    Returns:
        Tensor: Decoded bboxes, coordinates stacked on the last axis
            as (x1, y1, x2, y2).
    """
    point_x = points[..., 0]
    point_y = points[..., 1]
    left = point_x - distance[..., 0]
    top = point_y - distance[..., 1]
    right = point_x + distance[..., 2]
    bottom = point_y + distance[..., 3]

    if max_shapes is None:
        return paddle.stack([left, top, right, bottom], -1)

    for sample_idx, image_shape in enumerate(max_shapes):
        bound_h = image_shape[0]
        bound_w = image_shape[1]
        # Clip each sample in place against its own image size.
        left[sample_idx] = left[sample_idx].clip(min=0, max=bound_w)
        top[sample_idx] = top[sample_idx].clip(min=0, max=bound_h)
        right[sample_idx] = right[sample_idx].clip(min=0, max=bound_w)
        bottom[sample_idx] = bottom[sample_idx].clip(min=0, max=bound_h)

    return paddle.stack([left, top, right, bottom], -1)
ppdet/modeling/heads/tood_head.py
浏览文件 @
dfb8ea1e
...
...
@@ -24,10 +24,11 @@ from paddle.nn.initializer import Constant
from
ppdet.core.workspace
import
register
from
..initializer
import
normal_
,
constant_
,
bias_init_with_prob
from
ppdet.modeling.bbox_utils
import
bbox_center
from
ppdet.modeling.bbox_utils
import
bbox_center
,
batch_distance2bbox
from
..losses
import
GIoULoss
from
paddle.vision.ops
import
deform_conv2d
from
ppdet.modeling.layers
import
ConvNormLayer
from
ppdet.modeling.ops
import
get_static_shape
from
ppdet.modeling.assigners.utils
import
generate_anchors_for_grid_cell
class
ScaleReg
(
nn
.
Layer
):
...
...
@@ -84,25 +85,13 @@ class TaskDecomposition(nn.Layer):
normal_
(
self
.
la_conv1
.
weight
,
std
=
0.001
)
normal_
(
self
.
la_conv2
.
weight
,
std
=
0.001
)
def
forward
(
self
,
feat
,
avg_feat
=
None
):
b
,
_
,
h
,
w
=
feat
.
shape
if
avg_feat
is
None
:
avg_feat
=
F
.
adaptive_avg_pool2d
(
feat
,
(
1
,
1
))
def
forward
(
self
,
feat
,
avg_feat
):
b
,
_
,
h
,
w
=
get_static_shape
(
feat
)
weight
=
F
.
relu
(
self
.
la_conv1
(
avg_feat
))
weight
=
F
.
sigmoid
(
self
.
la_conv2
(
weight
))
# here new_conv_weight = layer_attention_weight * conv_weight
# in order to save memory and FLOPs.
conv_weight
=
weight
.
reshape
([
b
,
1
,
self
.
stacked_convs
,
1
])
*
\
self
.
reduction_conv
.
conv
.
weight
.
reshape
(
[
1
,
self
.
feat_channels
,
self
.
stacked_convs
,
self
.
feat_channels
])
conv_weight
=
conv_weight
.
reshape
(
[
b
,
self
.
feat_channels
,
self
.
in_channels
])
feat
=
feat
.
reshape
([
b
,
self
.
in_channels
,
h
*
w
])
feat
=
paddle
.
bmm
(
conv_weight
,
feat
).
reshape
(
[
b
,
self
.
feat_channels
,
h
,
w
])
if
self
.
norm_type
is
not
None
:
feat
=
self
.
reduction_conv
.
norm
(
feat
)
weight
=
F
.
sigmoid
(
self
.
la_conv2
(
weight
)).
unsqueeze
(
-
1
)
feat
=
paddle
.
reshape
(
feat
,
[
b
,
self
.
stacked_convs
,
self
.
feat_channels
,
h
,
w
])
*
weight
feat
=
self
.
reduction_conv
(
feat
.
flatten
(
1
,
2
))
feat
=
F
.
relu
(
feat
)
return
feat
...
...
@@ -211,81 +200,32 @@ class TOODHead(nn.Layer):
normal_
(
self
.
cls_prob_conv2
.
weight
,
std
=
0.01
)
constant_
(
self
.
cls_prob_conv2
.
bias
,
bias_cls
)
normal_
(
self
.
reg_offset_conv1
.
weight
,
std
=
0.001
)
normal_
(
self
.
reg_offset_conv2
.
weight
,
std
=
0.001
)
constant_
(
self
.
reg_offset_conv2
.
weight
)
constant_
(
self
.
reg_offset_conv2
.
bias
)
def _generate_anchors(self, feats):
    """Build one square anchor per grid cell for every feature level.

    Args:
        feats (List[Tensor]): per-level feature maps; each must unpack
            into four shape dimensions, the last two being h and w.
    Returns:
        anchors (List[Tensor]): per level, boxes reshaped to (l, 4)
        num_anchors_list (List[int]): number of anchors per level
        stride_tensor_list (List[Tensor]): per level, an (l, 1) tensor
            filled with that level's stride
    """
    anchors = []
    num_anchors_list = []
    stride_tensor_list = []

    for level_feat, level_stride in zip(feats, self.fpn_strides):
        _, _, rows, cols = level_feat.shape
        half_extent = self.grid_cell_scale * level_stride * 0.5

        xs = (paddle.arange(end=cols) + self.grid_cell_offset) * level_stride
        ys = (paddle.arange(end=rows) + self.grid_cell_offset) * level_stride
        grid_y, grid_x = paddle.meshgrid(ys, xs)

        # (x1, y1, x2, y2) boxes centred on every grid position.
        corners = [
            grid_x - half_extent,
            grid_y - half_extent,
            grid_x + half_extent,
            grid_y + half_extent,
        ]
        level_anchors = paddle.stack(corners, axis=-1).reshape([-1, 4])

        anchor_count = len(level_anchors)
        anchors.append(level_anchors)
        num_anchors_list.append(anchor_count)
        stride_tensor_list.append(paddle.full([anchor_count, 1], level_stride))

    return anchors, num_anchors_list, stride_tensor_list
@staticmethod
def _batch_distance2bbox(points, distance, max_shapes=None):
    """Decode distance predictions into bounding boxes.

    Args:
        points (Tensor): [B, l, 2]
        distance (Tensor): [B, l, 4]
        max_shapes (tuple): [B, 2], "h w" format, Shape of the image.
            When given, each sample's boxes are clipped to that size.
    Returns:
        Tensor: Decoded bboxes.
    """
    left = points[:, :, 0] - distance[:, :, 0]
    top = points[:, :, 1] - distance[:, :, 1]
    right = points[:, :, 0] + distance[:, :, 2]
    bottom = points[:, :, 1] + distance[:, :, 3]
    decoded = paddle.stack([left, top, right, bottom], -1)

    if max_shapes is None:
        return decoded

    clipped = []
    for sample_boxes, image_shape in zip(decoded, max_shapes):
        # x columns (0, 2) clip against width, y columns (1, 3) against height.
        column_bounds = (
            (0, image_shape[1]),
            (1, image_shape[0]),
            (2, image_shape[1]),
            (3, image_shape[0]),
        )
        for column, upper in column_bounds:
            sample_boxes[:, column] = sample_boxes[:, column].clip(
                min=0, max=upper)
        clipped.append(sample_boxes)

    return paddle.stack(clipped)
@staticmethod
def _deform_sampling(feat, offset):
    """Sample the feature map at positions shifted by ``offset``.

    Args:
        feat (Tensor): Feature
        offset (Tensor): Spatial offset for feature sampling
    """
    # A per-channel deformable convolution with an all-ones 1x1 kernel is
    # an equivalent implementation of bilinear interpolation;
    # F.grid_sample could be used instead.
    channels = feat.shape[1]
    ones_kernel = paddle.ones([channels, 1, 1, 1])
    return deform_conv2d(
        feat,
        offset,
        ones_kernel,
        deformable_groups=channels,
        groups=channels)
def _reg_grid_sample(self, feat, offset, anchor_points):
    """Resample ``feat`` at ``anchor_points + offset`` with F.grid_sample.

    Args:
        feat (Tensor): feature map whose shape unpacks to (b, _, h, w);
            it is reshaped to [-1, 1, h, w] before sampling.
        offset (Tensor): sampling offsets, reshaped to [-1, 2, h, w] and
            moved to channel-last layout.
        anchor_points (Tensor): base sampling positions added to the
            offsets; must broadcast against the channel-last offsets.
    Returns:
        Tensor: sampled feature reshaped to [b, -1, h, w].
    """
    batch, _, height, width = get_static_shape(feat)

    per_channel_feat = paddle.reshape(feat, [-1, 1, height, width])
    per_channel_offset = paddle.reshape(offset, [-1, 2, height, width])
    per_channel_offset = per_channel_offset.transpose([0, 2, 3, 1])

    # Divide by (w, h) to scale positions, clamp to [0, 1], then map to
    # [-1, 1] before handing the grid to F.grid_sample.
    extent = paddle.concat([width, height]).astype('float32')
    unit_grid = (per_channel_offset + anchor_points) / extent
    sample_grid = 2 * unit_grid.clip(0., 1.) - 1

    sampled = F.grid_sample(per_channel_feat, sample_grid)
    return paddle.reshape(sampled, [batch, -1, height, width])
def
forward
(
self
,
feats
):
assert
len
(
feats
)
==
len
(
self
.
fpn_strides
),
\
"The size of feats is not equal to size of fpn_strides"
anchors
,
num_anchors_list
,
stride_tensor_list
=
self
.
_generate_anchors
(
feats
)
anchors
,
num_anchors_list
,
stride_tensor_list
=
generate_anchors_for_grid_cell
(
feats
,
self
.
fpn_strides
,
self
.
grid_cell_scale
,
self
.
grid_cell_offset
)
cls_score_list
,
bbox_pred_list
=
[],
[]
for
feat
,
scale_reg
,
anchor
,
stride
in
zip
(
feats
,
self
.
scales_regs
,
anchors
,
self
.
fpn_strides
):
b
,
_
,
h
,
w
=
feat
.
shape
b
,
_
,
h
,
w
=
get_static_shape
(
feat
)
inter_feats
=
[]
for
inter_conv
in
self
.
inter_convs
:
feat
=
F
.
relu
(
inter_conv
(
feat
))
...
...
@@ -309,16 +249,16 @@ class TOODHead(nn.Layer):
# reg prediction and alignment
reg_dist
=
scale_reg
(
self
.
tood_reg
(
reg_feat
).
exp
())
reg_dist
=
reg_dist
.
transpose
([
0
,
2
,
3
,
1
]).
reshape
([
b
,
-
1
,
4
])
reg_dist
=
reg_dist
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
anchor_centers
=
bbox_center
(
anchor
).
unsqueeze
(
0
)
/
stride
reg_bbox
=
self
.
_batch_distance2bbox
(
anchor_centers
.
tile
([
b
,
1
,
1
]),
reg_dist
)
reg_bbox
=
batch_distance2bbox
(
anchor_centers
,
reg_dist
)
if
self
.
use_align_head
:
reg_bbox
=
reg_bbox
.
reshape
([
b
,
h
,
w
,
4
]).
transpose
(
[
0
,
3
,
1
,
2
])
reg_offset
=
F
.
relu
(
self
.
reg_offset_conv1
(
feat
))
reg_offset
=
self
.
reg_offset_conv2
(
reg_offset
)
bbox_pred
=
self
.
_deform_sampling
(
reg_bbox
,
reg_offset
)
reg_bbox
=
reg_bbox
.
transpose
([
0
,
2
,
1
]).
reshape
([
b
,
4
,
h
,
w
])
anchor_centers
=
anchor_centers
.
reshape
([
1
,
h
,
w
,
2
])
bbox_pred
=
self
.
_reg_grid_sample
(
reg_bbox
,
reg_offset
,
anchor_centers
)
bbox_pred
=
bbox_pred
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
else
:
bbox_pred
=
reg_bbox
...
...
ppdet/modeling/ops.py
浏览文件 @
dfb8ea1e
...
...
@@ -1600,3 +1600,9 @@ def channel_shuffle(x, groups):
x
=
paddle
.
transpose
(
x
=
x
,
perm
=
[
0
,
2
,
1
,
3
,
4
])
x
=
paddle
.
reshape
(
x
=
x
,
shape
=
[
batch_size
,
num_channels
,
height
,
width
])
return
x
def get_static_shape(tensor):
    """Return ``paddle.shape(tensor)`` with gradients disabled on the result.

    Args:
        tensor (Tensor): tensor whose shape is queried.
    Returns:
        Tensor: the shape tensor, with ``stop_gradient`` set to True.
    """
    dims = paddle.shape(tensor)
    # Mark the shape tensor so it does not take part in backpropagation.
    dims.stop_gradient = True
    return dims
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录