Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
10bf8de7
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
10bf8de7
编写于
4月 19, 2022
作者:
F
Feng Ni
提交者:
GitHub
4月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add YOLOX codes (#5740)
上级
fecae1ee
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
1574 addition
and
87 deletion
+1574
-87
deploy/python/infer.py
deploy/python/infer.py
+20
-4
deploy/python/preprocess.py
deploy/python/preprocess.py
+75
-0
ppdet/data/source/dataset.py
ppdet/data/source/dataset.py
+2
-2
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+416
-20
ppdet/engine/export_utils.py
ppdet/engine/export_utils.py
+7
-0
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+9
-1
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+9
-0
ppdet/modeling/architectures/yolox.py
ppdet/modeling/architectures/yolox.py
+144
-0
ppdet/modeling/assigners/utils.py
ppdet/modeling/assigners/utils.py
+19
-16
ppdet/modeling/backbones/__init__.py
ppdet/modeling/backbones/__init__.py
+14
-12
ppdet/modeling/backbones/csp_darknet.py
ppdet/modeling/backbones/csp_darknet.py
+403
-0
ppdet/modeling/heads/yolo_head.py
ppdet/modeling/heads/yolo_head.py
+280
-0
ppdet/modeling/initializer.py
ppdet/modeling/initializer.py
+2
-1
ppdet/modeling/necks/yolo_fpn.py
ppdet/modeling/necks/yolo_fpn.py
+113
-13
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+25
-13
ppdet/optimizer.py
ppdet/optimizer.py
+36
-5
未找到文件。
deploy/python/infer.py
浏览文件 @
10bf8de7
...
...
@@ -31,16 +31,32 @@ sys.path.insert(0, parent_path)
from
benchmark_utils
import
PaddleInferBenchmark
from
picodet_postprocess
import
PicoDetPostProcess
from
preprocess
import
preprocess
,
Resize
,
NormalizeImage
,
Permute
,
PadStride
,
LetterBoxResize
,
WarpAffine
,
decode_image
from
preprocess
import
preprocess
,
Resize
,
NormalizeImage
,
Permute
,
PadStride
,
LetterBoxResize
,
WarpAffine
,
Pad
,
decode_image
from
keypoint_preprocess
import
EvalAffine
,
TopDownEvalAffine
,
expand_crop
from
visualize
import
visualize_box_mask
from
utils
import
argsparser
,
Timer
,
get_current_memory_mb
# Global dictionary
SUPPORT_MODELS
=
{
'YOLO'
,
'RCNN'
,
'SSD'
,
'Face'
,
'FCOS'
,
'SOLOv2'
,
'TTFNet'
,
'S2ANet'
,
'JDE'
,
'FairMOT'
,
'DeepSORT'
,
'GFL'
,
'PicoDet'
,
'CenterNet'
,
'TOOD'
,
'RetinaNet'
,
'StrongBaseline'
,
'STGCN'
'YOLO'
,
'RCNN'
,
'SSD'
,
'Face'
,
'FCOS'
,
'SOLOv2'
,
'TTFNet'
,
'S2ANet'
,
'JDE'
,
'FairMOT'
,
'DeepSORT'
,
'GFL'
,
'PicoDet'
,
'CenterNet'
,
'TOOD'
,
'RetinaNet'
,
'StrongBaseline'
,
'STGCN'
,
'YOLOX'
,
}
...
...
deploy/python/preprocess.py
浏览文件 @
10bf8de7
...
...
@@ -246,6 +246,81 @@ class LetterBoxResize(object):
return
im
,
im_info
class
Pad
(
object
):
def
__init__
(
self
,
size
=
None
,
size_divisor
=
32
,
pad_mode
=
0
,
offsets
=
None
,
fill_value
=
(
127.5
,
127.5
,
127.5
)):
"""
Pad image to a specified size or multiple of size_divisor.
Args:
size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
size_divisor (int): size divisor, default 32
pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets
if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top
offsets (list): [offset_x, offset_y], specify offset while padding, only supported pad_mode=-1
fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5)
"""
super
(
Pad
,
self
).
__init__
()
if
isinstance
(
size
,
int
):
size
=
[
size
,
size
]
assert
pad_mode
in
[
-
1
,
0
,
1
,
2
],
'currently only supports four modes [-1, 0, 1, 2]'
if
pad_mode
==
-
1
:
assert
offsets
,
'if pad_mode is -1, offsets should not be None'
self
.
size
=
size
self
.
size_divisor
=
size_divisor
self
.
pad_mode
=
pad_mode
self
.
fill_value
=
fill_value
self
.
offsets
=
offsets
def
apply_image
(
self
,
image
,
offsets
,
im_size
,
size
):
x
,
y
=
offsets
im_h
,
im_w
=
im_size
h
,
w
=
size
canvas
=
np
.
ones
((
h
,
w
,
3
),
dtype
=
np
.
float32
)
canvas
*=
np
.
array
(
self
.
fill_value
,
dtype
=
np
.
float32
)
canvas
[
y
:
y
+
im_h
,
x
:
x
+
im_w
,
:]
=
image
.
astype
(
np
.
float32
)
return
canvas
def
__call__
(
self
,
im
,
im_info
):
im_h
,
im_w
=
im
.
shape
[:
2
]
if
self
.
size
:
h
,
w
=
self
.
size
assert
(
im_h
<=
h
and
im_w
<=
w
),
'(h, w) of target size should be greater than (im_h, im_w)'
else
:
h
=
int
(
np
.
ceil
(
im_h
/
self
.
size_divisor
)
*
self
.
size_divisor
)
w
=
int
(
np
.
ceil
(
im_w
/
self
.
size_divisor
)
*
self
.
size_divisor
)
if
h
==
im_h
and
w
==
im_w
:
im
=
im
.
astype
(
np
.
float32
)
return
im
,
im_info
if
self
.
pad_mode
==
-
1
:
offset_x
,
offset_y
=
self
.
offsets
elif
self
.
pad_mode
==
0
:
offset_y
,
offset_x
=
0
,
0
elif
self
.
pad_mode
==
1
:
offset_y
,
offset_x
=
(
h
-
im_h
)
//
2
,
(
w
-
im_w
)
//
2
else
:
offset_y
,
offset_x
=
h
-
im_h
,
w
-
im_w
offsets
,
im_size
,
size
=
[
offset_x
,
offset_y
],
[
im_h
,
im_w
],
[
h
,
w
]
im
=
self
.
apply_image
(
im
,
offsets
,
im_size
,
size
)
if
self
.
pad_mode
==
0
:
return
im
,
im_info
return
im
,
im_info
class
WarpAffine
(
object
):
"""Warp affine the image
"""
...
...
ppdet/data/source/dataset.py
浏览文件 @
10bf8de7
...
...
@@ -5,7 +5,7 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -75,7 +75,7 @@ class DetDataset(Dataset):
n
=
len
(
self
.
roidbs
)
roidb
=
[
roidb
,
]
+
[
copy
.
deepcopy
(
self
.
roidbs
[
np
.
random
.
randint
(
n
)])
for
_
in
range
(
3
)
for
_
in
range
(
4
)
]
if
isinstance
(
roidb
,
Sequence
):
for
r
in
roidb
:
...
...
ppdet/data/transform/operators.py
浏览文件 @
10bf8de7
...
...
@@ -2034,13 +2034,14 @@ class Pad(BaseOperator):
if
self
.
size
:
h
,
w
=
self
.
size
assert
(
im_h
<
h
and
im_w
<
w
im_h
<
=
h
and
im_w
<=
w
),
'(h, w) of target size should be greater than (im_h, im_w)'
else
:
h
=
int
(
np
.
ceil
(
im_h
/
self
.
size_divisor
)
*
self
.
size_divisor
)
w
=
int
(
np
.
ceil
(
im_w
/
self
.
size_divisor
)
*
self
.
size_divisor
)
if
h
==
im_h
and
w
==
im_w
:
sample
[
'image'
]
=
im
.
astype
(
np
.
float32
)
return
sample
if
self
.
pad_mode
==
-
1
:
...
...
@@ -2139,16 +2140,29 @@ class Rbox2Poly(BaseOperator):
@
register_op
class
AugmentHSV
(
BaseOperator
):
def
__init__
(
self
,
fraction
=
0.50
,
is_bgr
=
True
):
"""
Augment the SV channel of image data.
Args:
fraction (float): the fraction for augment. Default: 0.5.
is_bgr (bool): whether the image is BGR mode. Default: True.
"""
"""
Augment the SV channel of image data.
Args:
fraction (float): the fraction for augment. Default: 0.5.
is_bgr (bool): whether the image is BGR mode. Default: True.
hgain (float): H channel gains
sgain (float): S channel gains
vgain (float): V channel gains
"""
def
__init__
(
self
,
fraction
=
0.50
,
is_bgr
=
True
,
hgain
=
None
,
sgain
=
None
,
vgain
=
None
):
super
(
AugmentHSV
,
self
).
__init__
()
self
.
fraction
=
fraction
self
.
is_bgr
=
is_bgr
self
.
hgain
=
hgain
self
.
sgain
=
sgain
self
.
vgain
=
vgain
self
.
use_hsvgain
=
False
if
hgain
is
None
else
True
def
apply
(
self
,
sample
,
context
=
None
):
img
=
sample
[
'image'
]
...
...
@@ -2156,21 +2170,33 @@ class AugmentHSV(BaseOperator):
img_hsv
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2HSV
)
else
:
img_hsv
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_RGB2HSV
)
S
=
img_hsv
[:,
:,
1
].
astype
(
np
.
float32
)
V
=
img_hsv
[:,
:,
2
].
astype
(
np
.
float32
)
a
=
(
random
.
random
()
*
2
-
1
)
*
self
.
fraction
+
1
S
*=
a
if
a
>
1
:
np
.
clip
(
S
,
a_min
=
0
,
a_max
=
255
,
out
=
S
)
if
self
.
use_hsvgain
:
hsv_augs
=
np
.
random
.
uniform
(
-
1
,
1
,
3
)
*
[
self
.
hgain
,
self
.
sgain
,
self
.
vgain
]
# random selection of h, s, v
hsv_augs
*=
np
.
random
.
randint
(
0
,
2
,
3
)
img_hsv
[...,
0
]
=
(
img_hsv
[...,
0
]
+
hsv_augs
[
0
])
%
180
img_hsv
[...,
1
]
=
np
.
clip
(
img_hsv
[...,
1
]
+
hsv_augs
[
1
],
0
,
255
)
img_hsv
[...,
2
]
=
np
.
clip
(
img_hsv
[...,
2
]
+
hsv_augs
[
2
],
0
,
255
)
a
=
(
random
.
random
()
*
2
-
1
)
*
self
.
fraction
+
1
V
*=
a
if
a
>
1
:
np
.
clip
(
V
,
a_min
=
0
,
a_max
=
255
,
out
=
V
)
else
:
S
=
img_hsv
[:,
:,
1
].
astype
(
np
.
float32
)
V
=
img_hsv
[:,
:,
2
].
astype
(
np
.
float32
)
a
=
(
random
.
random
()
*
2
-
1
)
*
self
.
fraction
+
1
S
*=
a
if
a
>
1
:
np
.
clip
(
S
,
a_min
=
0
,
a_max
=
255
,
out
=
S
)
a
=
(
random
.
random
()
*
2
-
1
)
*
self
.
fraction
+
1
V
*=
a
if
a
>
1
:
np
.
clip
(
V
,
a_min
=
0
,
a_max
=
255
,
out
=
V
)
img_hsv
[:,
:,
1
]
=
S
.
astype
(
np
.
uint8
)
img_hsv
[:,
:,
2
]
=
V
.
astype
(
np
.
uint8
)
img_hsv
[:,
:,
1
]
=
S
.
astype
(
np
.
uint8
)
img_hsv
[:,
:,
2
]
=
V
.
astype
(
np
.
uint8
)
if
self
.
is_bgr
:
cv2
.
cvtColor
(
img_hsv
,
cv2
.
COLOR_HSV2BGR
,
dst
=
img
)
else
:
...
...
@@ -3018,3 +3044,373 @@ class CenterRandColor(BaseOperator):
img
=
func
(
img
,
img_gray
)
sample
[
'image'
]
=
img
return
sample
@
register_op
class
Mosaic
(
BaseOperator
):
""" Mosaic operator for image and gt_bboxes
The code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/mosaicdetection.py
1. get mosaic coords
2. clip bbox and get mosaic_labels
3. random_affine augment
4. Mixup augment as copypaste (optinal), not used in tiny/nano
Args:
prob (float): probability of using Mosaic, 1.0 as default
input_dim (list[int]): input shape
degrees (list[2]): the rotate range to apply, transform range is [min, max]
translate (list[2]): the translate range to apply, transform range is [min, max]
scale (list[2]): the scale range to apply, transform range is [min, max]
shear (list[2]): the shear range to apply, transform range is [min, max]
enable_mixup (bool): whether to enable Mixup or not
mixup_prob (float): probability of using Mixup, 1.0 as default
mixup_scale (list[int]): scale range of Mixup
remove_outside_box (bool): whether remove outside boxes, False as
default in COCO dataset, True in MOT dataset
"""
def
__init__
(
self
,
prob
=
1.0
,
input_dim
=
[
640
,
640
],
degrees
=
[
-
10
,
10
],
translate
=
[
-
0.1
,
0.1
],
scale
=
[
0.1
,
2
],
shear
=
[
-
2
,
2
],
enable_mixup
=
True
,
mixup_prob
=
1.0
,
mixup_scale
=
[
0.5
,
1.5
],
remove_outside_box
=
False
):
super
(
Mosaic
,
self
).
__init__
()
self
.
prob
=
prob
self
.
input_dim
=
input_dim
self
.
degrees
=
degrees
self
.
translate
=
translate
self
.
scale
=
scale
self
.
shear
=
shear
self
.
enable_mixup
=
enable_mixup
self
.
mixup_prob
=
mixup_prob
self
.
mixup_scale
=
mixup_scale
self
.
remove_outside_box
=
remove_outside_box
def
get_mosaic_coords
(
self
,
mosaic_idx
,
xc
,
yc
,
w
,
h
,
input_h
,
input_w
):
# (x1, y1, x2, y2) means coords in large image,
# small_coords means coords in small image in mosaic aug.
if
mosaic_idx
==
0
:
# top left
x1
,
y1
,
x2
,
y2
=
max
(
xc
-
w
,
0
),
max
(
yc
-
h
,
0
),
xc
,
yc
small_coords
=
w
-
(
x2
-
x1
),
h
-
(
y2
-
y1
),
w
,
h
elif
mosaic_idx
==
1
:
# top right
x1
,
y1
,
x2
,
y2
=
xc
,
max
(
yc
-
h
,
0
),
min
(
xc
+
w
,
input_w
*
2
),
yc
small_coords
=
0
,
h
-
(
y2
-
y1
),
min
(
w
,
x2
-
x1
),
h
elif
mosaic_idx
==
2
:
# bottom left
x1
,
y1
,
x2
,
y2
=
max
(
xc
-
w
,
0
),
yc
,
xc
,
min
(
input_h
*
2
,
yc
+
h
)
small_coords
=
w
-
(
x2
-
x1
),
0
,
w
,
min
(
y2
-
y1
,
h
)
elif
mosaic_idx
==
3
:
# bottom right
x1
,
y1
,
x2
,
y2
=
xc
,
yc
,
min
(
xc
+
w
,
input_w
*
2
),
min
(
input_h
*
2
,
yc
+
h
)
small_coords
=
0
,
0
,
min
(
w
,
x2
-
x1
),
min
(
y2
-
y1
,
h
)
return
(
x1
,
y1
,
x2
,
y2
),
small_coords
def
random_affine_augment
(
self
,
img
,
labels
=
[],
input_dim
=
[
640
,
640
],
degrees
=
[
-
10
,
10
],
scales
=
[
0.1
,
2
],
shears
=
[
-
2
,
2
],
translates
=
[
-
0.1
,
0.1
]):
# random rotation and scale
degree
=
random
.
uniform
(
degrees
[
0
],
degrees
[
1
])
scale
=
random
.
uniform
(
scales
[
0
],
scales
[
1
])
assert
scale
>
0
,
"Argument scale should be positive."
R
=
cv2
.
getRotationMatrix2D
(
angle
=
degree
,
center
=
(
0
,
0
),
scale
=
scale
)
M
=
np
.
ones
([
2
,
3
])
# random shear
shear
=
random
.
uniform
(
shears
[
0
],
shears
[
1
])
shear_x
=
math
.
tan
(
shear
*
math
.
pi
/
180
)
shear_y
=
math
.
tan
(
shear
*
math
.
pi
/
180
)
M
[
0
]
=
R
[
0
]
+
shear_y
*
R
[
1
]
M
[
1
]
=
R
[
1
]
+
shear_x
*
R
[
0
]
# random translation
translate
=
random
.
uniform
(
translates
[
0
],
translates
[
1
])
translation_x
=
translate
*
input_dim
[
0
]
translation_y
=
translate
*
input_dim
[
1
]
M
[
0
,
2
]
=
translation_x
M
[
1
,
2
]
=
translation_y
# warpAffine
img
=
cv2
.
warpAffine
(
img
,
M
,
dsize
=
input_dim
,
borderValue
=
(
114
,
114
,
114
))
num_gts
=
len
(
labels
)
if
num_gts
>
0
:
# warp corner points
corner_points
=
np
.
ones
((
4
*
num_gts
,
3
))
corner_points
[:,
:
2
]
=
labels
[:,
[
0
,
1
,
2
,
3
,
0
,
3
,
2
,
1
]].
reshape
(
4
*
num_gts
,
2
)
# x1y1, x2y2, x1y2, x2y1
# apply affine transform
corner_points
=
corner_points
@
M
.
T
corner_points
=
corner_points
.
reshape
(
num_gts
,
8
)
# create new boxes
corner_xs
=
corner_points
[:,
0
::
2
]
corner_ys
=
corner_points
[:,
1
::
2
]
new_bboxes
=
np
.
concatenate
((
corner_xs
.
min
(
1
),
corner_ys
.
min
(
1
),
corner_xs
.
max
(
1
),
corner_ys
.
max
(
1
)))
new_bboxes
=
new_bboxes
.
reshape
(
4
,
num_gts
).
T
# clip boxes
new_bboxes
[:,
0
::
2
]
=
np
.
clip
(
new_bboxes
[:,
0
::
2
],
0
,
input_dim
[
0
])
new_bboxes
[:,
1
::
2
]
=
np
.
clip
(
new_bboxes
[:,
1
::
2
],
0
,
input_dim
[
1
])
labels
[:,
:
4
]
=
new_bboxes
return
img
,
labels
def
__call__
(
self
,
sample
,
context
=
None
):
if
not
isinstance
(
sample
,
Sequence
):
return
sample
assert
len
(
sample
)
==
5
,
"Mosaic needs 5 samples, 4 for mosaic and 1 for mixup."
if
np
.
random
.
uniform
(
0.
,
1.
)
>
self
.
prob
:
return
sample
[
0
]
mosaic_gt_bbox
,
mosaic_gt_class
,
mosaic_is_crowd
=
[],
[],
[]
input_h
,
input_w
=
self
.
input_dim
yc
=
int
(
random
.
uniform
(
0.5
*
input_h
,
1.5
*
input_h
))
xc
=
int
(
random
.
uniform
(
0.5
*
input_w
,
1.5
*
input_w
))
mosaic_img
=
np
.
full
((
input_h
*
2
,
input_w
*
2
,
3
),
114
,
dtype
=
np
.
uint8
)
# 1. get mosaic coords
for
mosaic_idx
,
sp
in
enumerate
(
sample
[:
4
]):
img
=
sp
[
'image'
]
gt_bbox
=
sp
[
'gt_bbox'
]
h0
,
w0
=
img
.
shape
[:
2
]
scale
=
min
(
1.
*
input_h
/
h0
,
1.
*
input_w
/
w0
)
img
=
cv2
.
resize
(
img
,
(
int
(
w0
*
scale
),
int
(
h0
*
scale
)),
interpolation
=
cv2
.
INTER_LINEAR
)
(
h
,
w
,
c
)
=
img
.
shape
[:
3
]
# suffix l means large image, while s means small image in mosaic aug.
(
l_x1
,
l_y1
,
l_x2
,
l_y2
),
(
s_x1
,
s_y1
,
s_x2
,
s_y2
)
=
self
.
get_mosaic_coords
(
mosaic_idx
,
xc
,
yc
,
w
,
h
,
input_h
,
input_w
)
mosaic_img
[
l_y1
:
l_y2
,
l_x1
:
l_x2
]
=
img
[
s_y1
:
s_y2
,
s_x1
:
s_x2
]
padw
,
padh
=
l_x1
-
s_x1
,
l_y1
-
s_y1
# Normalized xywh to pixel xyxy format
_gt_bbox
=
gt_bbox
.
copy
()
if
len
(
gt_bbox
)
>
0
:
_gt_bbox
[:,
0
]
=
scale
*
gt_bbox
[:,
0
]
+
padw
_gt_bbox
[:,
1
]
=
scale
*
gt_bbox
[:,
1
]
+
padh
_gt_bbox
[:,
2
]
=
scale
*
gt_bbox
[:,
2
]
+
padw
_gt_bbox
[:,
3
]
=
scale
*
gt_bbox
[:,
3
]
+
padh
mosaic_gt_bbox
.
append
(
_gt_bbox
)
mosaic_gt_class
.
append
(
sp
[
'gt_class'
])
mosaic_is_crowd
.
append
(
sp
[
'is_crowd'
])
# 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd])
if
len
(
mosaic_gt_bbox
):
mosaic_gt_bbox
=
np
.
concatenate
(
mosaic_gt_bbox
,
0
)
mosaic_gt_class
=
np
.
concatenate
(
mosaic_gt_class
,
0
)
mosaic_is_crowd
=
np
.
concatenate
(
mosaic_is_crowd
,
0
)
mosaic_labels
=
np
.
concatenate
([
mosaic_gt_bbox
,
mosaic_gt_class
.
astype
(
mosaic_gt_bbox
.
dtype
),
mosaic_is_crowd
.
astype
(
mosaic_gt_bbox
.
dtype
)
],
1
)
if
self
.
remove_outside_box
:
# for MOT dataset
flag1
=
mosaic_gt_bbox
[:,
0
]
<
2
*
input_w
flag2
=
mosaic_gt_bbox
[:,
2
]
>
0
flag3
=
mosaic_gt_bbox
[:,
1
]
<
2
*
input_h
flag4
=
mosaic_gt_bbox
[:,
3
]
>
0
flag_all
=
flag1
*
flag2
*
flag3
*
flag4
mosaic_labels
=
mosaic_labels
[
flag_all
]
else
:
mosaic_labels
[:,
0
]
=
np
.
clip
(
mosaic_labels
[:,
0
],
0
,
2
*
input_w
)
mosaic_labels
[:,
1
]
=
np
.
clip
(
mosaic_labels
[:,
1
],
0
,
2
*
input_h
)
mosaic_labels
[:,
2
]
=
np
.
clip
(
mosaic_labels
[:,
2
],
0
,
2
*
input_w
)
mosaic_labels
[:,
3
]
=
np
.
clip
(
mosaic_labels
[:,
3
],
0
,
2
*
input_h
)
else
:
mosaic_labels
=
np
.
zeros
((
1
,
6
))
# 3. random_affine augment
mosaic_img
,
mosaic_labels
=
self
.
random_affine_augment
(
mosaic_img
,
mosaic_labels
,
input_dim
=
self
.
input_dim
,
degrees
=
self
.
degrees
,
translates
=
self
.
translate
,
scales
=
self
.
scale
,
shears
=
self
.
shear
)
# 4. Mixup augment as copypaste, https://arxiv.org/abs/2012.07177
# optinal, not used(enable_mixup=False) in tiny/nano
if
(
self
.
enable_mixup
and
not
len
(
mosaic_labels
)
==
0
and
random
.
random
()
<
self
.
mixup_prob
):
sample_mixup
=
sample
[
4
]
mixup_img
=
sample_mixup
[
'image'
]
cp_labels
=
np
.
concatenate
([
sample_mixup
[
'gt_bbox'
],
sample_mixup
[
'gt_class'
].
astype
(
mosaic_labels
.
dtype
),
sample_mixup
[
'is_crowd'
].
astype
(
mosaic_labels
.
dtype
)
],
1
)
mosaic_img
,
mosaic_labels
=
self
.
mixup_augment
(
mosaic_img
,
mosaic_labels
,
self
.
input_dim
,
cp_labels
,
mixup_img
)
sample0
=
sample
[
0
]
sample0
[
'image'
]
=
mosaic_img
.
astype
(
np
.
uint8
)
# can not be float32
sample0
[
'h'
]
=
float
(
mosaic_img
.
shape
[
0
])
sample0
[
'w'
]
=
float
(
mosaic_img
.
shape
[
1
])
sample0
[
'im_shape'
][
0
]
=
sample0
[
'h'
]
sample0
[
'im_shape'
][
1
]
=
sample0
[
'w'
]
sample0
[
'gt_bbox'
]
=
mosaic_labels
[:,
:
4
].
astype
(
np
.
float32
)
sample0
[
'gt_class'
]
=
mosaic_labels
[:,
4
:
5
].
astype
(
np
.
float32
)
sample0
[
'is_crowd'
]
=
mosaic_labels
[:,
5
:
6
].
astype
(
np
.
float32
)
return
sample0
def
mixup_augment
(
self
,
origin_img
,
origin_labels
,
input_dim
,
cp_labels
,
img
):
jit_factor
=
random
.
uniform
(
*
self
.
mixup_scale
)
FLIP
=
random
.
uniform
(
0
,
1
)
>
0.5
if
len
(
img
.
shape
)
==
3
:
cp_img
=
np
.
ones
(
(
input_dim
[
0
],
input_dim
[
1
],
3
),
dtype
=
np
.
uint8
)
*
114
else
:
cp_img
=
np
.
ones
(
input_dim
,
dtype
=
np
.
uint8
)
*
114
cp_scale_ratio
=
min
(
input_dim
[
0
]
/
img
.
shape
[
0
],
input_dim
[
1
]
/
img
.
shape
[
1
])
resized_img
=
cv2
.
resize
(
img
,
(
int
(
img
.
shape
[
1
]
*
cp_scale_ratio
),
int
(
img
.
shape
[
0
]
*
cp_scale_ratio
)),
interpolation
=
cv2
.
INTER_LINEAR
)
cp_img
[:
int
(
img
.
shape
[
0
]
*
cp_scale_ratio
),
:
int
(
img
.
shape
[
1
]
*
cp_scale_ratio
)]
=
resized_img
cp_img
=
cv2
.
resize
(
cp_img
,
(
int
(
cp_img
.
shape
[
1
]
*
jit_factor
),
int
(
cp_img
.
shape
[
0
]
*
jit_factor
)))
cp_scale_ratio
*=
jit_factor
if
FLIP
:
cp_img
=
cp_img
[:,
::
-
1
,
:]
origin_h
,
origin_w
=
cp_img
.
shape
[:
2
]
target_h
,
target_w
=
origin_img
.
shape
[:
2
]
padded_img
=
np
.
zeros
(
(
max
(
origin_h
,
target_h
),
max
(
origin_w
,
target_w
),
3
),
dtype
=
np
.
uint8
)
padded_img
[:
origin_h
,
:
origin_w
]
=
cp_img
x_offset
,
y_offset
=
0
,
0
if
padded_img
.
shape
[
0
]
>
target_h
:
y_offset
=
random
.
randint
(
0
,
padded_img
.
shape
[
0
]
-
target_h
-
1
)
if
padded_img
.
shape
[
1
]
>
target_w
:
x_offset
=
random
.
randint
(
0
,
padded_img
.
shape
[
1
]
-
target_w
-
1
)
padded_cropped_img
=
padded_img
[
y_offset
:
y_offset
+
target_h
,
x_offset
:
x_offset
+
target_w
]
# adjust boxes
cp_bboxes_origin_np
=
cp_labels
[:,
:
4
].
copy
()
cp_bboxes_origin_np
[:,
0
::
2
]
=
np
.
clip
(
cp_bboxes_origin_np
[:,
0
::
2
]
*
cp_scale_ratio
,
0
,
origin_w
)
cp_bboxes_origin_np
[:,
1
::
2
]
=
np
.
clip
(
cp_bboxes_origin_np
[:,
1
::
2
]
*
cp_scale_ratio
,
0
,
origin_h
)
if
FLIP
:
cp_bboxes_origin_np
[:,
0
::
2
]
=
(
origin_w
-
cp_bboxes_origin_np
[:,
0
::
2
][:,
::
-
1
])
cp_bboxes_transformed_np
=
cp_bboxes_origin_np
.
copy
()
if
self
.
remove_outside_box
:
# for MOT dataset
cp_bboxes_transformed_np
[:,
0
::
2
]
-=
x_offset
cp_bboxes_transformed_np
[:,
1
::
2
]
-=
y_offset
else
:
cp_bboxes_transformed_np
[:,
0
::
2
]
=
np
.
clip
(
cp_bboxes_transformed_np
[:,
0
::
2
]
-
x_offset
,
0
,
target_w
)
cp_bboxes_transformed_np
[:,
1
::
2
]
=
np
.
clip
(
cp_bboxes_transformed_np
[:,
1
::
2
]
-
y_offset
,
0
,
target_h
)
cls_labels
=
cp_labels
[:,
4
:
5
].
copy
()
crd_labels
=
cp_labels
[:,
5
:
6
].
copy
()
box_labels
=
cp_bboxes_transformed_np
labels
=
np
.
hstack
((
box_labels
,
cls_labels
,
crd_labels
))
if
self
.
remove_outside_box
:
labels
=
labels
[
labels
[:,
0
]
<
target_w
]
labels
=
labels
[
labels
[:,
2
]
>
0
]
labels
=
labels
[
labels
[:,
1
]
<
target_h
]
labels
=
labels
[
labels
[:,
3
]
>
0
]
origin_labels
=
np
.
vstack
((
origin_labels
,
labels
))
origin_img
=
origin_img
.
astype
(
np
.
float32
)
origin_img
=
0.5
*
origin_img
+
0.5
*
padded_cropped_img
.
astype
(
np
.
float32
)
return
origin_img
.
astype
(
np
.
uint8
),
origin_labels
@
register_op
class
PadResize
(
BaseOperator
):
""" PadResize for image and gt_bbbox
Args:
target_size (list[int]): input shape
fill_value (float): pixel value of padded image
"""
def
__init__
(
self
,
target_size
,
fill_value
=
114
):
super
(
PadResize
,
self
).
__init__
()
if
isinstance
(
target_size
,
Integral
):
target_size
=
[
target_size
,
target_size
]
self
.
target_size
=
target_size
self
.
fill_value
=
fill_value
def
_resize
(
self
,
img
,
bboxes
,
labels
):
ratio
=
min
(
self
.
target_size
[
0
]
/
img
.
shape
[
0
],
self
.
target_size
[
1
]
/
img
.
shape
[
1
])
w
,
h
=
int
(
img
.
shape
[
1
]
*
ratio
),
int
(
img
.
shape
[
0
]
*
ratio
)
resized_img
=
cv2
.
resize
(
img
,
(
w
,
h
),
interpolation
=
cv2
.
INTER_LINEAR
)
if
len
(
bboxes
)
>
0
:
bboxes
*=
ratio
mask
=
np
.
minimum
(
bboxes
[:,
2
]
-
bboxes
[:,
0
],
bboxes
[:,
3
]
-
bboxes
[:,
1
])
>
1
bboxes
=
bboxes
[
mask
]
labels
=
labels
[
mask
]
return
resized_img
,
bboxes
,
labels
def
_pad
(
self
,
img
):
h
,
w
,
_
=
img
.
shape
if
h
==
self
.
target_size
[
0
]
and
w
==
self
.
target_size
[
1
]:
return
img
padded_img
=
np
.
full
(
(
self
.
target_size
[
0
],
self
.
target_size
[
1
],
3
),
self
.
fill_value
,
dtype
=
np
.
uint8
)
padded_img
[:
h
,
:
w
]
=
img
return
padded_img
def
apply
(
self
,
sample
,
context
=
None
):
image
=
sample
[
'image'
]
bboxes
=
sample
[
'gt_bbox'
]
labels
=
sample
[
'gt_class'
]
image
,
bboxes
,
labels
=
self
.
_resize
(
image
,
bboxes
,
labels
)
sample
[
'image'
]
=
self
.
_pad
(
image
).
astype
(
np
.
float32
)
sample
[
'gt_bbox'
]
=
bboxes
sample
[
'gt_class'
]
=
labels
return
sample
ppdet/engine/export_utils.py
浏览文件 @
10bf8de7
...
...
@@ -48,6 +48,7 @@ TRT_MIN_SUBGRAPH = {
'PicoDet'
:
3
,
'CenterNet'
:
5
,
'TOOD'
:
5
,
'YOLOX'
:
8
,
}
KEYPOINT_ARCH
=
[
'HigherHRNet'
,
'TopDownHRNet'
]
...
...
@@ -147,6 +148,12 @@ def _dump_infer_config(config, path, image_shape, model):
infer_cfg
[
'min_subgraph_size'
]
=
min_subgraph_size
arch_state
=
True
break
if
infer_arch
==
'YOLOX'
:
infer_cfg
[
'arch'
]
=
infer_arch
infer_cfg
[
'min_subgraph_size'
]
=
TRT_MIN_SUBGRAPH
[
infer_arch
]
arch_state
=
True
if
not
arch_state
:
logger
.
error
(
'Architecture: {} is not supported for exporting model now.
\n
'
.
...
...
ppdet/engine/trainer.py
浏览文件 @
10bf8de7
...
...
@@ -28,6 +28,7 @@ from PIL import Image, ImageOps, ImageFile
ImageFile
.
LOAD_TRUNCATED_IMAGES
=
True
import
paddle
import
paddle.nn
as
nn
import
paddle.distributed
as
dist
from
paddle.distributed
import
fleet
from
paddle
import
amp
...
...
@@ -99,6 +100,12 @@ class Trainer(object):
self
.
model
=
self
.
cfg
.
model
self
.
is_loaded_weights
=
True
if
cfg
.
architecture
==
'YOLOX'
:
for
k
,
m
in
self
.
model
.
named_sublayers
():
if
isinstance
(
m
,
nn
.
BatchNorm2D
):
m
.
epsilon
=
1e-3
# for amp(fp16)
m
.
momentum
=
0.97
# 0.03 in pytorch
#normalize params for deploy
if
'slim'
in
cfg
and
cfg
[
'slim_type'
]
==
'OFA'
:
self
.
model
.
model
.
load_meanstd
(
cfg
[
'TestReader'
][
...
...
@@ -117,10 +124,11 @@ class Trainer(object):
if
self
.
use_ema
:
ema_decay
=
self
.
cfg
.
get
(
'ema_decay'
,
0.9998
)
cycle_epoch
=
self
.
cfg
.
get
(
'cycle_epoch'
,
-
1
)
ema_decay_type
=
self
.
cfg
.
get
(
'ema_decay_type'
,
'threshold'
)
self
.
ema
=
ModelEMA
(
self
.
model
,
decay
=
ema_decay
,
use_thres_step
=
Tru
e
,
ema_decay_type
=
ema_decay_typ
e
,
cycle_epoch
=
cycle_epoch
)
# EvalDataset build with BatchSampler to evaluate in single device
...
...
ppdet/modeling/architectures/__init__.py
浏览文件 @
10bf8de7
...
...
@@ -5,6 +5,13 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
meta_arch
from
.
import
faster_rcnn
from
.
import
mask_rcnn
...
...
@@ -28,6 +35,7 @@ from . import sparse_rcnn
from
.
import
tood
from
.
import
retinanet
from
.
import
bytetrack
from
.
import
yolox
from
.meta_arch
import
*
from
.faster_rcnn
import
*
...
...
@@ -53,3 +61,4 @@ from .sparse_rcnn import *
from
.tood
import
*
from
.retinanet
import
*
from
.bytetrack
import
*
from
.yolox
import
*
ppdet/modeling/architectures/yolox.py
0 → 100644
浏览文件 @
10bf8de7
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
ppdet.core.workspace
import
register
,
create
from
.meta_arch
import
BaseArch
import
random
import
paddle
import
paddle.nn.functional
as
F
import
paddle.distributed
as
dist
from
ppdet.modeling.ops
import
paddle_distributed_is_initialized
__all__
=
[
'YOLOX'
]
@
register
class
YOLOX
(
BaseArch
):
"""
YOLOX network, see https://arxiv.org/abs/2107.08430
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
head (nn.Layer): head instance
for_mot (bool): whether used for MOT or not
input_size (list[int]): initial scale, will be reset by self._preprocess()
size_stride (int): stride of the size range
size_range (list[int]): multi-scale range for training
random_interval (int): interval of iter to change self._input_size
"""
__category__
=
'architecture'
def
__init__
(
self
,
backbone
=
'CSPDarkNet'
,
neck
=
'YOLOCSPPAN'
,
head
=
'YOLOXHead'
,
for_mot
=
False
,
input_size
=
[
640
,
640
],
size_stride
=
32
,
size_range
=
[
15
,
25
],
random_interval
=
10
):
super
(
YOLOX
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
neck
=
neck
self
.
head
=
head
self
.
for_mot
=
for_mot
self
.
input_size
=
input_size
self
.
_input_size
=
paddle
.
to_tensor
(
input_size
)
self
.
size_stride
=
size_stride
self
.
size_range
=
size_range
self
.
random_interval
=
random_interval
self
.
_step
=
0
@
classmethod
def
from_config
(
cls
,
cfg
,
*
args
,
**
kwargs
):
# backbone
backbone
=
create
(
cfg
[
'backbone'
])
# fpn
kwargs
=
{
'input_shape'
:
backbone
.
out_shape
}
neck
=
create
(
cfg
[
'neck'
],
**
kwargs
)
# head
kwargs
=
{
'input_shape'
:
neck
.
out_shape
}
head
=
create
(
cfg
[
'head'
],
**
kwargs
)
return
{
'backbone'
:
backbone
,
'neck'
:
neck
,
"head"
:
head
,
}
def
_forward
(
self
):
if
self
.
training
:
self
.
_preprocess
()
body_feats
=
self
.
backbone
(
self
.
inputs
)
neck_feats
=
self
.
neck
(
body_feats
,
self
.
for_mot
)
if
self
.
training
:
yolox_losses
=
self
.
head
(
neck_feats
,
self
.
inputs
)
yolox_losses
.
update
({
'size'
:
self
.
_input_size
[
0
]})
return
yolox_losses
else
:
head_outs
=
self
.
head
(
neck_feats
)
bbox
,
bbox_num
=
self
.
head
.
post_process
(
head_outs
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
return
{
'bbox'
:
bbox
,
'bbox_num'
:
bbox_num
}
def
get_loss
(
self
):
return
self
.
_forward
()
def
get_pred
(
self
):
return
self
.
_forward
()
def
_preprocess
(
self
):
# YOLOX multi-scale training, interpolate resize before inputs of the network.
self
.
_get_size
()
scale_y
=
self
.
_input_size
[
0
]
/
self
.
input_size
[
0
]
scale_x
=
self
.
_input_size
[
1
]
/
self
.
input_size
[
1
]
if
scale_x
!=
1
or
scale_y
!=
1
:
self
.
inputs
[
'image'
]
=
F
.
interpolate
(
self
.
inputs
[
'image'
],
size
=
self
.
_input_size
,
mode
=
'bilinear'
,
align_corners
=
False
)
gt_bboxes
=
self
.
inputs
[
'gt_bbox'
]
for
i
in
range
(
len
(
gt_bboxes
)):
if
len
(
gt_bboxes
[
i
])
>
0
:
gt_bboxes
[
i
][:,
0
::
2
]
=
gt_bboxes
[
i
][:,
0
::
2
]
*
scale_x
gt_bboxes
[
i
][:,
1
::
2
]
=
gt_bboxes
[
i
][:,
1
::
2
]
*
scale_y
self
.
inputs
[
'gt_bbox'
]
=
gt_bboxes
def
_get_size
(
self
):
# random_interval = 10 as default, every 10 iters to change self._input_size
image_ratio
=
self
.
input_size
[
1
]
*
1.0
/
self
.
input_size
[
0
]
if
self
.
_step
%
self
.
random_interval
==
0
:
size_factor
=
random
.
randint
(
*
self
.
size_range
)
size
=
[
self
.
size_stride
*
size_factor
,
self
.
size_stride
*
int
(
size_factor
*
image_ratio
)
]
size
=
paddle
.
to_tensor
(
size
)
if
dist
.
get_world_size
()
>
1
and
paddle_distributed_is_initialized
(
):
dist
.
barrier
()
dist
.
broadcast
(
size
,
0
)
self
.
_input_size
=
size
self
.
_step
+=
1
ppdet/modeling/assigners/utils.py
浏览文件 @
10bf8de7
...
...
@@ -115,7 +115,7 @@ def check_points_inside_bboxes(points,
Args:
points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format
center_radius_tensor (Tensor, float32): shape [L, 1] Default: None.
center_radius_tensor (Tensor, float32): shape [L, 1]
.
Default: None.
eps (float): Default: 1e-9
Returns:
is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected
...
...
@@ -123,25 +123,28 @@ def check_points_inside_bboxes(points,
points
=
points
.
unsqueeze
([
0
,
1
])
x
,
y
=
points
.
chunk
(
2
,
axis
=-
1
)
xmin
,
ymin
,
xmax
,
ymax
=
bboxes
.
unsqueeze
(
2
).
chunk
(
4
,
axis
=-
1
)
if
center_radius_tensor
is
not
None
:
center_radius_tensor
=
center_radius_tensor
.
unsqueeze
([
0
,
1
])
bboxes_cx
=
(
xmin
+
xmax
)
/
2.
bboxes_cy
=
(
ymin
+
ymax
)
/
2.
xmin_sampling
=
bboxes_cx
-
center_radius_tensor
ymin_sampling
=
bboxes_cy
-
center_radius_tensor
xmax_sampling
=
bboxes_cx
+
center_radius_tensor
ymax_sampling
=
bboxes_cy
+
center_radius_tensor
xmin
=
paddle
.
maximum
(
xmin
,
xmin_sampling
)
ymin
=
paddle
.
maximum
(
ymin
,
ymin_sampling
)
xmax
=
paddle
.
minimum
(
xmax
,
xmax_sampling
)
ymax
=
paddle
.
minimum
(
ymax
,
ymax_sampling
)
# check whether `points` is in `bboxes`
l
=
x
-
xmin
t
=
y
-
ymin
r
=
xmax
-
x
b
=
ymax
-
y
bbox_ltrb
=
paddle
.
concat
([
l
,
t
,
r
,
b
],
axis
=-
1
)
return
(
bbox_ltrb
.
min
(
axis
=-
1
)
>
eps
).
astype
(
bboxes
.
dtype
)
delta_ltrb
=
paddle
.
concat
([
l
,
t
,
r
,
b
],
axis
=-
1
)
is_in_bboxes
=
(
delta_ltrb
.
min
(
axis
=-
1
)
>
eps
)
if
center_radius_tensor
is
not
None
:
# check whether `points` is in `center_radius`
center_radius_tensor
=
center_radius_tensor
.
unsqueeze
([
0
,
1
])
cx
=
(
xmin
+
xmax
)
*
0.5
cy
=
(
ymin
+
ymax
)
*
0.5
l
=
x
-
(
cx
-
center_radius_tensor
)
t
=
y
-
(
cy
-
center_radius_tensor
)
r
=
(
cx
+
center_radius_tensor
)
-
x
b
=
(
cy
+
center_radius_tensor
)
-
y
delta_ltrb_c
=
paddle
.
concat
([
l
,
t
,
r
,
b
],
axis
=-
1
)
is_in_center
=
(
delta_ltrb_c
.
min
(
axis
=-
1
)
>
eps
)
return
(
paddle
.
logical_and
(
is_in_bboxes
,
is_in_center
),
paddle
.
logical_or
(
is_in_bboxes
,
is_in_center
))
return
is_in_bboxes
.
astype
(
bboxes
.
dtype
)
def
compute_max_iou_anchor
(
ious
):
...
...
ppdet/modeling/backbones/__init__.py
浏览文件 @
10bf8de7
#
copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
#
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
vgg
...
...
@@ -30,6 +30,7 @@ from . import lcnet
from
.
import
hardnet
from
.
import
esnet
from
.
import
cspresnet
from
.
import
csp_darknet
from
.vgg
import
*
from
.resnet
import
*
...
...
@@ -49,3 +50,4 @@ from .lcnet import *
from
.hardnet
import
*
from
.esnet
import
*
from
.cspresnet
import
*
from
.csp_darknet
import
*
ppdet/modeling/backbones/csp_darknet.py
0 → 100644
浏览文件 @
10bf8de7
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.modeling.ops
import
get_activation
from
ppdet.modeling.initializer
import
conv_init_
from
..shape_spec
import
ShapeSpec
__all__
=
[
'CSPDarkNet'
,
'BaseConv'
,
'DWConv'
,
'BottleNeck'
,
'SPPLayer'
,
'SPPFLayer'
]
class
BaseConv
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
ksize
,
stride
,
groups
=
1
,
bias
=
False
,
act
=
"silu"
):
super
(
BaseConv
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2D
(
in_channels
,
out_channels
,
kernel_size
=
ksize
,
stride
=
stride
,
padding
=
(
ksize
-
1
)
//
2
,
groups
=
groups
,
bias_attr
=
bias
)
self
.
bn
=
nn
.
BatchNorm2D
(
out_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
self
.
act
=
get_activation
(
act
)
self
.
_init_weights
()
def
_init_weights
(
self
):
conv_init_
(
self
.
conv
)
def
forward
(
self
,
x
):
return
self
.
act
(
self
.
bn
(
self
.
conv
(
x
)))
class
DWConv
(
nn
.
Layer
):
"""Depthwise Conv"""
def
__init__
(
self
,
in_channels
,
out_channels
,
ksize
,
stride
=
1
,
bias
=
False
,
act
=
"silu"
):
super
(
DWConv
,
self
).
__init__
()
self
.
dw_conv
=
BaseConv
(
in_channels
,
in_channels
,
ksize
=
ksize
,
stride
=
stride
,
groups
=
in_channels
,
bias
=
bias
,
act
=
act
,
)
self
.
pw_conv
=
BaseConv
(
in_channels
,
out_channels
,
ksize
=
1
,
stride
=
1
,
groups
=
1
,
bias
=
bias
,
act
=
act
)
def
forward
(
self
,
x
):
return
self
.
pw_conv
(
self
.
dw_conv
(
x
))
class
Focus
(
nn
.
Layer
):
"""Focus width and height information into channel space, used in YOLOX."""
def
__init__
(
self
,
in_channels
,
out_channels
,
ksize
=
3
,
stride
=
1
,
bias
=
False
,
act
=
"silu"
):
super
(
Focus
,
self
).
__init__
()
self
.
conv
=
BaseConv
(
in_channels
*
4
,
out_channels
,
ksize
=
ksize
,
stride
=
stride
,
bias
=
bias
,
act
=
act
)
def
forward
(
self
,
inputs
):
# inputs [bs, C, H, W] -> outputs [bs, 4C, W/2, H/2]
top_left
=
inputs
[:,
:,
0
::
2
,
0
::
2
]
top_right
=
inputs
[:,
:,
0
::
2
,
1
::
2
]
bottom_left
=
inputs
[:,
:,
1
::
2
,
0
::
2
]
bottom_right
=
inputs
[:,
:,
1
::
2
,
1
::
2
]
outputs
=
paddle
.
concat
(
[
top_left
,
bottom_left
,
top_right
,
bottom_right
],
1
)
return
self
.
conv
(
outputs
)
class
BottleNeck
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
shortcut
=
True
,
expansion
=
0.5
,
depthwise
=
False
,
bias
=
False
,
act
=
"silu"
):
super
(
BottleNeck
,
self
).
__init__
()
hidden_channels
=
int
(
out_channels
*
expansion
)
Conv
=
DWConv
if
depthwise
else
BaseConv
self
.
conv1
=
BaseConv
(
in_channels
,
hidden_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
self
.
conv2
=
Conv
(
hidden_channels
,
out_channels
,
ksize
=
3
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
self
.
add_shortcut
=
shortcut
and
in_channels
==
out_channels
def
forward
(
self
,
x
):
y
=
self
.
conv2
(
self
.
conv1
(
x
))
if
self
.
add_shortcut
:
y
=
y
+
x
return
y
class
SPPLayer
(
nn
.
Layer
):
"""Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_sizes
=
(
5
,
9
,
13
),
bias
=
False
,
act
=
"silu"
):
super
(
SPPLayer
,
self
).
__init__
()
hidden_channels
=
in_channels
//
2
self
.
conv1
=
BaseConv
(
in_channels
,
hidden_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
self
.
maxpoolings
=
nn
.
LayerList
([
nn
.
MaxPool2D
(
kernel_size
=
ks
,
stride
=
1
,
padding
=
ks
//
2
)
for
ks
in
kernel_sizes
])
conv2_channels
=
hidden_channels
*
(
len
(
kernel_sizes
)
+
1
)
self
.
conv2
=
BaseConv
(
conv2_channels
,
out_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
paddle
.
concat
([
x
]
+
[
mp
(
x
)
for
mp
in
self
.
maxpoolings
],
axis
=
1
)
x
=
self
.
conv2
(
x
)
return
x
class
SPPFLayer
(
nn
.
Layer
):
""" Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher,
equivalent to SPP(k=(5, 9, 13))
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
ksize
=
5
,
bias
=
False
,
act
=
'silu'
):
super
(
SPPFLayer
,
self
).
__init__
()
hidden_channels
=
in_channels
//
2
self
.
conv1
=
BaseConv
(
in_channels
,
hidden_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
self
.
maxpooling
=
nn
.
MaxPool2D
(
kernel_size
=
ksize
,
stride
=
1
,
padding
=
ksize
//
2
)
conv2_channels
=
hidden_channels
*
4
self
.
conv2
=
BaseConv
(
conv2_channels
,
out_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
y1
=
self
.
maxpooling
(
x
)
y2
=
self
.
maxpooling
(
y1
)
y3
=
self
.
maxpooling
(
y2
)
concats
=
paddle
.
concat
([
x
,
y1
,
y2
,
y3
],
axis
=
1
)
out
=
self
.
conv2
(
concats
)
return
out
class
CSPLayer
(
nn
.
Layer
):
"""CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5"""
def
__init__
(
self
,
in_channels
,
out_channels
,
num_blocks
=
1
,
shortcut
=
True
,
expansion
=
0.5
,
depthwise
=
False
,
bias
=
False
,
act
=
"silu"
):
super
(
CSPLayer
,
self
).
__init__
()
hidden_channels
=
int
(
out_channels
*
expansion
)
self
.
conv1
=
BaseConv
(
in_channels
,
hidden_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
self
.
conv2
=
BaseConv
(
in_channels
,
hidden_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
self
.
bottlenecks
=
nn
.
Sequential
(
*
[
BottleNeck
(
hidden_channels
,
hidden_channels
,
shortcut
=
shortcut
,
expansion
=
1.0
,
depthwise
=
depthwise
,
bias
=
bias
,
act
=
act
)
for
_
in
range
(
num_blocks
)
])
self
.
conv3
=
BaseConv
(
hidden_channels
*
2
,
out_channels
,
ksize
=
1
,
stride
=
1
,
bias
=
bias
,
act
=
act
)
def
forward
(
self
,
x
):
x_1
=
self
.
conv1
(
x
)
x_1
=
self
.
bottlenecks
(
x_1
)
x_2
=
self
.
conv2
(
x
)
x
=
paddle
.
concat
([
x_1
,
x_2
],
axis
=
1
)
x
=
self
.
conv3
(
x
)
return
x
@
register
@
serializable
class
CSPDarkNet
(
nn
.
Layer
):
"""
CSPDarkNet backbone.
Args:
arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X,
and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5.
depth_mult (float): Depth multiplier, multiply number of channels in
each layer, default as 1.0.
width_mult (float): Width multiplier, multiply number of blocks in
CSPLayer, default as 1.0.
depthwise (bool): Whether to use depth-wise conv layer.
act (str): Activation function type, default as 'silu'.
return_idx (list): Index of stages whose feature maps are returned.
"""
__shared__
=
[
'depth_mult'
,
'width_mult'
,
'act'
]
# in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
# 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
arch_settings
=
{
'X'
:
[[
64
,
128
,
3
,
True
,
False
],
[
128
,
256
,
9
,
True
,
False
],
[
256
,
512
,
9
,
True
,
False
],
[
512
,
1024
,
3
,
False
,
True
]],
'P5'
:
[[
64
,
128
,
3
,
True
,
False
],
[
128
,
256
,
6
,
True
,
False
],
[
256
,
512
,
9
,
True
,
False
],
[
512
,
1024
,
3
,
True
,
True
]],
'P6'
:
[[
64
,
128
,
3
,
True
,
False
],
[
128
,
256
,
6
,
True
,
False
],
[
256
,
512
,
9
,
True
,
False
],
[
512
,
768
,
3
,
True
,
False
],
[
768
,
1024
,
3
,
True
,
True
]],
}
def
__init__
(
self
,
arch
=
'X'
,
depth_mult
=
1.0
,
width_mult
=
1.0
,
depthwise
=
False
,
act
=
'silu'
,
return_idx
=
[
2
,
3
,
4
]):
super
(
CSPDarkNet
,
self
).
__init__
()
self
.
arch
=
arch
self
.
return_idx
=
return_idx
Conv
=
DWConv
if
depthwise
else
BaseConv
arch_setting
=
self
.
arch_settings
[
arch
]
base_channels
=
int
(
arch_setting
[
0
][
0
]
*
width_mult
)
# Note: differences between the latest YOLOv5 and the original YOLOX
# 1. self.stem, use SPPF(in YOLOv5) or SPP(in YOLOX)
# 2. use SPPF(in YOLOv5) or SPP(in YOLOX)
# 3. put SPPF before(YOLOv5) or SPP after(YOLOX) the last cspdark block's CSPLayer
# 4. whether SPPF(SPP)'CSPLayer add shortcut, True in YOLOv5, False in YOLOX
if
arch
in
[
'P5'
,
'P6'
]:
# in the latest YOLOv5, use Conv stem, and SPPF (fast, only single spp kernal size)
self
.
stem
=
Conv
(
3
,
base_channels
,
ksize
=
6
,
stride
=
2
,
bias
=
False
,
act
=
act
)
spp_kernal_sizes
=
5
elif
arch
in
[
'X'
]:
# in the original YOLOX, use Focus stem, and SPP (three spp kernal sizes)
self
.
stem
=
Focus
(
3
,
base_channels
,
ksize
=
3
,
stride
=
1
,
bias
=
False
,
act
=
act
)
spp_kernal_sizes
=
(
5
,
9
,
13
)
else
:
raise
AttributeError
(
"Unsupported arch type: {}"
.
format
(
arch
))
_out_channels
=
[
base_channels
]
layers_num
=
1
self
.
csp_dark_blocks
=
[]
for
i
,
(
in_channels
,
out_channels
,
num_blocks
,
shortcut
,
use_spp
)
in
enumerate
(
arch_setting
):
in_channels
=
int
(
in_channels
*
width_mult
)
out_channels
=
int
(
out_channels
*
width_mult
)
_out_channels
.
append
(
out_channels
)
num_blocks
=
max
(
round
(
num_blocks
*
depth_mult
),
1
)
stage
=
[]
conv_layer
=
self
.
add_sublayer
(
'layers{}.stage{}.conv_layer'
.
format
(
layers_num
,
i
+
1
),
Conv
(
in_channels
,
out_channels
,
3
,
2
,
bias
=
False
,
act
=
act
))
stage
.
append
(
conv_layer
)
layers_num
+=
1
if
use_spp
and
arch
in
[
'X'
]:
# in YOLOX use SPPLayer
spp_layer
=
self
.
add_sublayer
(
'layers{}.stage{}.spp_layer'
.
format
(
layers_num
,
i
+
1
),
SPPLayer
(
out_channels
,
out_channels
,
kernel_sizes
=
spp_kernal_sizes
,
bias
=
False
,
act
=
act
))
stage
.
append
(
spp_layer
)
layers_num
+=
1
csp_layer
=
self
.
add_sublayer
(
'layers{}.stage{}.csp_layer'
.
format
(
layers_num
,
i
+
1
),
CSPLayer
(
out_channels
,
out_channels
,
num_blocks
=
num_blocks
,
shortcut
=
shortcut
,
depthwise
=
depthwise
,
bias
=
False
,
act
=
act
))
stage
.
append
(
csp_layer
)
layers_num
+=
1
if
use_spp
and
arch
in
[
'P5'
,
'P6'
]:
# in latest YOLOv5 use SPPFLayer instead of SPPLayer
sppf_layer
=
self
.
add_sublayer
(
'layers{}.stage{}.sppf_layer'
.
format
(
layers_num
,
i
+
1
),
SPPFLayer
(
out_channels
,
out_channels
,
ksize
=
5
,
bias
=
False
,
act
=
act
))
stage
.
append
(
sppf_layer
)
layers_num
+=
1
self
.
csp_dark_blocks
.
append
(
nn
.
Sequential
(
*
stage
))
self
.
_out_channels
=
[
_out_channels
[
i
]
for
i
in
self
.
return_idx
]
self
.
strides
=
[[
2
,
4
,
8
,
16
,
32
,
64
][
i
]
for
i
in
self
.
return_idx
]
def
forward
(
self
,
inputs
):
x
=
inputs
[
'image'
]
outputs
=
[]
x
=
self
.
stem
(
x
)
for
i
,
layer
in
enumerate
(
self
.
csp_dark_blocks
):
x
=
layer
(
x
)
if
i
+
1
in
self
.
return_idx
:
outputs
.
append
(
x
)
return
outputs
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
c
,
stride
=
s
)
for
c
,
s
in
zip
(
self
.
_out_channels
,
self
.
strides
)
]
ppdet/modeling/heads/yolo_head.py
浏览文件 @
10bf8de7
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
...
...
@@ -5,6 +19,16 @@ from paddle import ParamAttr
from
paddle.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
import
math
import
numpy
as
np
from
..initializer
import
bias_init_with_prob
,
constant_
from
..backbones.csp_darknet
import
BaseConv
,
DWConv
from
..losses
import
IouLoss
from
ppdet.modeling.assigners.simota_assigner
import
SimOTAAssigner
from
ppdet.modeling.bbox_utils
import
bbox_overlaps
__all__
=
[
'YOLOv3Head'
,
'YOLOXHead'
]
def
_de_sigmoid
(
x
,
eps
=
1e-7
):
x
=
paddle
.
clip
(
x
,
eps
,
1.
/
eps
)
...
...
@@ -122,3 +146,259 @@ class YOLOv3Head(nn.Layer):
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
return
{
'in_channels'
:
[
i
.
channels
for
i
in
input_shape
],
}
@
register
class
YOLOXHead
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'width_mult'
,
'act'
]
__inject__
=
[
'assigner'
,
'nms'
]
def
__init__
(
self
,
num_classes
=
80
,
width_mult
=
1.0
,
depthwise
=
False
,
in_channels
=
[
256
,
512
,
1024
],
feat_channels
=
256
,
fpn_strides
=
(
8
,
16
,
32
),
l1_epoch
=
285
,
act
=
'silu'
,
assigner
=
SimOTAAssigner
(
use_vfl
=
False
),
nms
=
'MultiClassNMS'
,
loss_weight
=
{
'cls'
:
1.0
,
'obj'
:
1.0
,
'iou'
:
5.0
,
'l1'
:
1.0
}):
super
(
YOLOXHead
,
self
).
__init__
()
self
.
_dtype
=
paddle
.
framework
.
get_default_dtype
()
self
.
num_classes
=
num_classes
assert
len
(
in_channels
)
>
0
,
"in_channels length should > 0"
self
.
in_channels
=
in_channels
feat_channels
=
int
(
feat_channels
*
width_mult
)
self
.
fpn_strides
=
fpn_strides
self
.
l1_epoch
=
l1_epoch
self
.
assigner
=
assigner
self
.
nms
=
nms
self
.
loss_weight
=
loss_weight
self
.
iou_loss
=
IouLoss
(
loss_weight
=
1.0
)
# default loss_weight 2.5
ConvBlock
=
DWConv
if
depthwise
else
BaseConv
self
.
stem_conv
=
nn
.
LayerList
()
self
.
conv_cls
=
nn
.
LayerList
()
self
.
conv_reg
=
nn
.
LayerList
()
# reg [x,y,w,h] + obj
for
in_c
in
self
.
in_channels
:
self
.
stem_conv
.
append
(
BaseConv
(
in_c
,
feat_channels
,
1
,
1
,
act
=
act
))
self
.
conv_cls
.
append
(
nn
.
Sequential
(
*
[
ConvBlock
(
feat_channels
,
feat_channels
,
3
,
1
,
act
=
act
),
ConvBlock
(
feat_channels
,
feat_channels
,
3
,
1
,
act
=
act
),
nn
.
Conv2D
(
feat_channels
,
self
.
num_classes
,
1
,
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
]))
self
.
conv_reg
.
append
(
nn
.
Sequential
(
*
[
ConvBlock
(
feat_channels
,
feat_channels
,
3
,
1
,
act
=
act
),
ConvBlock
(
feat_channels
,
feat_channels
,
3
,
1
,
act
=
act
),
nn
.
Conv2D
(
feat_channels
,
4
+
1
,
# reg [x,y,w,h] + obj
1
,
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
]))
self
.
_init_weights
()
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
return
{
'in_channels'
:
[
i
.
channels
for
i
in
input_shape
],
}
def
_init_weights
(
self
):
bias_cls
=
bias_init_with_prob
(
0.01
)
bias_reg
=
paddle
.
full
([
5
],
math
.
log
(
5.
),
dtype
=
self
.
_dtype
)
bias_reg
[:
2
]
=
0.
bias_reg
[
-
1
]
=
bias_cls
for
cls_
,
reg_
in
zip
(
self
.
conv_cls
,
self
.
conv_reg
):
constant_
(
cls_
[
-
1
].
weight
)
constant_
(
cls_
[
-
1
].
bias
,
bias_cls
)
constant_
(
reg_
[
-
1
].
weight
)
reg_
[
-
1
].
bias
.
set_value
(
bias_reg
)
def
_generate_anchor_point
(
self
,
feat_sizes
,
strides
,
offset
=
0.
):
anchor_points
,
stride_tensor
=
[],
[]
num_anchors_list
=
[]
for
feat_size
,
stride
in
zip
(
feat_sizes
,
strides
):
h
,
w
=
feat_size
x
=
(
paddle
.
arange
(
w
)
+
offset
)
*
stride
y
=
(
paddle
.
arange
(
h
)
+
offset
)
*
stride
y
,
x
=
paddle
.
meshgrid
(
y
,
x
)
anchor_points
.
append
(
paddle
.
stack
([
x
,
y
],
axis
=-
1
).
reshape
([
-
1
,
2
]))
stride_tensor
.
append
(
paddle
.
full
(
[
len
(
anchor_points
[
-
1
]),
1
],
stride
,
dtype
=
self
.
_dtype
))
num_anchors_list
.
append
(
len
(
anchor_points
[
-
1
]))
anchor_points
=
paddle
.
concat
(
anchor_points
).
astype
(
self
.
_dtype
)
anchor_points
.
stop_gradient
=
True
stride_tensor
=
paddle
.
concat
(
stride_tensor
)
stride_tensor
.
stop_gradient
=
True
return
anchor_points
,
stride_tensor
,
num_anchors_list
def
forward
(
self
,
feats
,
targets
=
None
):
assert
len
(
feats
)
==
len
(
self
.
fpn_strides
),
\
"The size of feats is not equal to size of fpn_strides"
feat_sizes
=
[[
f
.
shape
[
-
2
],
f
.
shape
[
-
1
]]
for
f
in
feats
]
cls_score_list
,
reg_pred_list
=
[],
[]
obj_score_list
=
[]
for
i
,
feat
in
enumerate
(
feats
):
feat
=
self
.
stem_conv
[
i
](
feat
)
cls_logit
=
self
.
conv_cls
[
i
](
feat
)
reg_pred
=
self
.
conv_reg
[
i
](
feat
)
# cls prediction
cls_score
=
F
.
sigmoid
(
cls_logit
)
cls_score_list
.
append
(
cls_score
.
flatten
(
2
).
transpose
([
0
,
2
,
1
]))
# reg prediction
reg_xywh
,
obj_logit
=
paddle
.
split
(
reg_pred
,
[
4
,
1
],
axis
=
1
)
reg_xywh
=
reg_xywh
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
reg_pred_list
.
append
(
reg_xywh
)
# obj prediction
obj_score
=
F
.
sigmoid
(
obj_logit
)
obj_score_list
.
append
(
obj_score
.
flatten
(
2
).
transpose
([
0
,
2
,
1
]))
cls_score_list
=
paddle
.
concat
(
cls_score_list
,
axis
=
1
)
reg_pred_list
=
paddle
.
concat
(
reg_pred_list
,
axis
=
1
)
obj_score_list
=
paddle
.
concat
(
obj_score_list
,
axis
=
1
)
# bbox decode
anchor_points
,
stride_tensor
,
_
=
\
self
.
_generate_anchor_point
(
feat_sizes
,
self
.
fpn_strides
)
reg_xy
,
reg_wh
=
paddle
.
split
(
reg_pred_list
,
2
,
axis
=-
1
)
reg_xy
+=
(
anchor_points
/
stride_tensor
)
reg_wh
=
paddle
.
exp
(
reg_wh
)
*
0.5
bbox_pred_list
=
paddle
.
concat
(
[
reg_xy
-
reg_wh
,
reg_xy
+
reg_wh
],
axis
=-
1
)
if
self
.
training
:
anchor_points
,
stride_tensor
,
num_anchors_list
=
\
self
.
_generate_anchor_point
(
feat_sizes
,
self
.
fpn_strides
,
0.5
)
yolox_losses
=
self
.
get_loss
([
cls_score_list
,
bbox_pred_list
,
obj_score_list
,
anchor_points
,
stride_tensor
,
num_anchors_list
],
targets
)
return
yolox_losses
else
:
pred_scores
=
(
cls_score_list
*
obj_score_list
).
sqrt
()
return
pred_scores
,
bbox_pred_list
,
stride_tensor
def
get_loss
(
self
,
head_outs
,
targets
):
pred_cls
,
pred_bboxes
,
pred_obj
,
\
anchor_points
,
stride_tensor
,
num_anchors_list
=
head_outs
gt_labels
=
targets
[
'gt_class'
]
gt_bboxes
=
targets
[
'gt_bbox'
]
pred_scores
=
(
pred_cls
*
pred_obj
).
sqrt
()
# label assignment
center_and_strides
=
paddle
.
concat
(
[
anchor_points
,
stride_tensor
,
stride_tensor
],
axis
=-
1
)
pos_num_list
,
label_list
,
bbox_target_list
=
[],
[],
[]
for
pred_score
,
pred_bbox
,
gt_box
,
gt_label
in
zip
(
pred_scores
.
detach
(),
pred_bboxes
.
detach
()
*
stride_tensor
,
gt_bboxes
,
gt_labels
):
pos_num
,
label
,
_
,
bbox_target
=
self
.
assigner
(
pred_score
,
center_and_strides
,
pred_bbox
,
gt_box
,
gt_label
)
pos_num_list
.
append
(
pos_num
)
label_list
.
append
(
label
)
bbox_target_list
.
append
(
bbox_target
)
labels
=
paddle
.
to_tensor
(
np
.
stack
(
label_list
,
axis
=
0
))
bbox_targets
=
paddle
.
to_tensor
(
np
.
stack
(
bbox_target_list
,
axis
=
0
))
bbox_targets
/=
stride_tensor
# rescale bbox
# 1. obj score loss
mask_positive
=
(
labels
!=
self
.
num_classes
)
loss_obj
=
F
.
binary_cross_entropy
(
pred_obj
,
mask_positive
.
astype
(
pred_obj
.
dtype
).
unsqueeze
(
-
1
),
reduction
=
'sum'
)
num_pos
=
sum
(
pos_num_list
)
if
num_pos
>
0
:
num_pos
=
paddle
.
to_tensor
(
num_pos
,
dtype
=
self
.
_dtype
).
clip
(
min
=
1
)
loss_obj
/=
num_pos
# 2. iou loss
bbox_mask
=
mask_positive
.
unsqueeze
(
-
1
).
tile
([
1
,
1
,
4
])
pred_bboxes_pos
=
paddle
.
masked_select
(
pred_bboxes
,
bbox_mask
).
reshape
([
-
1
,
4
])
assigned_bboxes_pos
=
paddle
.
masked_select
(
bbox_targets
,
bbox_mask
).
reshape
([
-
1
,
4
])
bbox_iou
=
bbox_overlaps
(
pred_bboxes_pos
,
assigned_bboxes_pos
)
bbox_iou
=
paddle
.
diag
(
bbox_iou
)
loss_iou
=
self
.
iou_loss
(
pred_bboxes_pos
.
split
(
4
,
axis
=-
1
),
assigned_bboxes_pos
.
split
(
4
,
axis
=-
1
))
loss_iou
=
loss_iou
.
sum
()
/
num_pos
# 3. cls loss
cls_mask
=
mask_positive
.
unsqueeze
(
-
1
).
tile
(
[
1
,
1
,
self
.
num_classes
])
pred_cls_pos
=
paddle
.
masked_select
(
pred_cls
,
cls_mask
).
reshape
([
-
1
,
self
.
num_classes
])
assigned_cls_pos
=
paddle
.
masked_select
(
labels
,
mask_positive
)
assigned_cls_pos
=
F
.
one_hot
(
assigned_cls_pos
,
self
.
num_classes
+
1
)[...,
:
-
1
]
assigned_cls_pos
*=
bbox_iou
.
unsqueeze
(
-
1
)
loss_cls
=
F
.
binary_cross_entropy
(
pred_cls_pos
,
assigned_cls_pos
,
reduction
=
'sum'
)
loss_cls
/=
num_pos
# 4. l1 loss
if
targets
[
'epoch_id'
]
>=
self
.
l1_epoch
:
loss_l1
=
F
.
l1_loss
(
pred_bboxes_pos
,
assigned_bboxes_pos
,
reduction
=
'sum'
)
loss_l1
/=
num_pos
else
:
loss_l1
=
paddle
.
zeros
([
1
])
loss_l1
.
stop_gradient
=
False
else
:
loss_cls
=
paddle
.
zeros
([
1
])
loss_iou
=
paddle
.
zeros
([
1
])
loss_l1
=
paddle
.
zeros
([
1
])
loss_cls
.
stop_gradient
=
False
loss_iou
.
stop_gradient
=
False
loss_l1
.
stop_gradient
=
False
loss
=
self
.
loss_weight
[
'obj'
]
*
loss_obj
+
\
self
.
loss_weight
[
'cls'
]
*
loss_cls
+
\
self
.
loss_weight
[
'iou'
]
*
loss_iou
if
targets
[
'epoch_id'
]
>=
self
.
l1_epoch
:
loss
+=
(
self
.
loss_weight
[
'l1'
]
*
loss_l1
)
yolox_losses
=
{
'loss'
:
loss
,
'loss_cls'
:
loss_cls
,
'loss_obj'
:
loss_obj
,
'loss_iou'
:
loss_iou
,
'loss_l1'
:
loss_l1
,
}
return
yolox_losses
def
post_process
(
self
,
head_outs
,
img_shape
,
scale_factor
):
pred_scores
,
pred_bboxes
,
stride_tensor
=
head_outs
pred_scores
=
pred_scores
.
transpose
([
0
,
2
,
1
])
pred_bboxes
*=
stride_tensor
# scale bbox to origin image
scale_factor
=
scale_factor
.
flip
(
-
1
).
tile
([
1
,
2
]).
unsqueeze
(
1
)
pred_bboxes
/=
scale_factor
bbox_pred
,
bbox_num
,
_
=
self
.
nms
(
pred_bboxes
,
pred_scores
)
return
bbox_pred
,
bbox_num
ppdet/modeling/initializer.py
浏览文件 @
10bf8de7
...
...
@@ -273,7 +273,8 @@ def linear_init_(module):
def
conv_init_
(
module
):
bound
=
1
/
np
.
sqrt
(
np
.
prod
(
module
.
weight
.
shape
[
1
:]))
uniform_
(
module
.
weight
,
-
bound
,
bound
)
uniform_
(
module
.
bias
,
-
bound
,
bound
)
if
module
.
bias
is
not
None
:
uniform_
(
module
.
bias
,
-
bound
,
bound
)
def
bias_init_with_prob
(
prior_prob
=
0.01
):
...
...
ppdet/modeling/necks/yolo_fpn.py
浏览文件 @
10bf8de7
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
...
...
@@ -19,8 +19,9 @@ from ppdet.core.workspace import register, serializable
from
ppdet.modeling.layers
import
DropBlock
from
..backbones.darknet
import
ConvBNLayer
from
..shape_spec
import
ShapeSpec
from
..backbones.csp_darknet
import
BaseConv
,
DWConv
,
CSPLayer
__all__
=
[
'YOLOv3FPN'
,
'PPYOLOFPN'
,
'PPYOLOTinyFPN'
,
'PPYOLOPAN'
]
__all__
=
[
'YOLOv3FPN'
,
'PPYOLOFPN'
,
'PPYOLOTinyFPN'
,
'PPYOLOPAN'
,
'YOLOCSPPAN'
]
def
add_coord
(
x
,
data_format
):
...
...
@@ -986,3 +987,102 @@ class PPYOLOPAN(nn.Layer):
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
c
)
for
c
in
self
.
_out_channels
]
@
register
@
serializable
class
YOLOCSPPAN
(
nn
.
Layer
):
"""
YOLO CSP-PAN, used in YOLOv5 and YOLOX.
"""
__shared__
=
[
'depth_mult'
,
'act'
]
def
__init__
(
self
,
depth_mult
=
1.0
,
in_channels
=
[
256
,
512
,
1024
],
depthwise
=
False
,
act
=
'silu'
):
super
(
YOLOCSPPAN
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
_out_channels
=
in_channels
Conv
=
DWConv
if
depthwise
else
BaseConv
self
.
upsample
=
nn
.
Upsample
(
scale_factor
=
2
,
mode
=
"nearest"
)
# top-down fpn
self
.
lateral_convs
=
nn
.
LayerList
()
self
.
fpn_blocks
=
nn
.
LayerList
()
for
idx
in
range
(
len
(
in_channels
)
-
1
,
0
,
-
1
):
self
.
lateral_convs
.
append
(
BaseConv
(
int
(
in_channels
[
idx
]),
int
(
in_channels
[
idx
-
1
]),
1
,
1
,
act
=
act
))
self
.
fpn_blocks
.
append
(
CSPLayer
(
int
(
in_channels
[
idx
-
1
]
*
2
),
int
(
in_channels
[
idx
-
1
]),
round
(
3
*
depth_mult
),
shortcut
=
False
,
depthwise
=
depthwise
,
act
=
act
))
# bottom-up pan
self
.
downsample_convs
=
nn
.
LayerList
()
self
.
pan_blocks
=
nn
.
LayerList
()
for
idx
in
range
(
len
(
in_channels
)
-
1
):
self
.
downsample_convs
.
append
(
Conv
(
int
(
in_channels
[
idx
]),
int
(
in_channels
[
idx
]),
3
,
stride
=
2
,
act
=
act
))
self
.
pan_blocks
.
append
(
CSPLayer
(
int
(
in_channels
[
idx
]
*
2
),
int
(
in_channels
[
idx
+
1
]),
round
(
3
*
depth_mult
),
shortcut
=
False
,
depthwise
=
depthwise
,
act
=
act
))
def
forward
(
self
,
feats
,
for_mot
=
False
):
assert
len
(
feats
)
==
len
(
self
.
in_channels
)
# top-down fpn
inner_outs
=
[
feats
[
-
1
]]
for
idx
in
range
(
len
(
self
.
in_channels
)
-
1
,
0
,
-
1
):
feat_heigh
=
inner_outs
[
0
]
feat_low
=
feats
[
idx
-
1
]
feat_heigh
=
self
.
lateral_convs
[
len
(
self
.
in_channels
)
-
1
-
idx
](
feat_heigh
)
inner_outs
[
0
]
=
feat_heigh
upsample_feat
=
self
.
upsample
(
feat_heigh
)
inner_out
=
self
.
fpn_blocks
[
len
(
self
.
in_channels
)
-
1
-
idx
](
paddle
.
concat
(
[
upsample_feat
,
feat_low
],
axis
=
1
))
inner_outs
.
insert
(
0
,
inner_out
)
# bottom-up pan
outs
=
[
inner_outs
[
0
]]
for
idx
in
range
(
len
(
self
.
in_channels
)
-
1
):
feat_low
=
outs
[
-
1
]
feat_height
=
inner_outs
[
idx
+
1
]
downsample_feat
=
self
.
downsample_convs
[
idx
](
feat_low
)
out
=
self
.
pan_blocks
[
idx
](
paddle
.
concat
(
[
downsample_feat
,
feat_height
],
axis
=
1
))
outs
.
append
(
out
)
return
outs
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
return
{
'in_channels'
:
[
i
.
channels
for
i
in
input_shape
],
}
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
c
)
for
c
in
self
.
_out_channels
]
ppdet/modeling/ops.py
浏览文件 @
10bf8de7
#
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
#
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
...
...
@@ -28,7 +28,7 @@ __all__ = [
'roi_pool'
,
'roi_align'
,
'prior_box'
,
'generate_proposals'
,
'iou_similarity'
,
'box_coder'
,
'yolo_box'
,
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'collect_fpn_proposals'
,
'matrix_nms'
,
'batch_norm'
,
'mish'
,
'swish'
,
'identity'
'batch_norm'
,
'
get_activation'
,
'
mish'
,
'swish'
,
'identity'
]
...
...
@@ -106,6 +106,18 @@ def batch_norm(ch,
return
norm_layer
def
get_activation
(
name
=
"silu"
):
if
name
==
"silu"
:
module
=
nn
.
Silu
()
elif
name
==
"relu"
:
module
=
nn
.
ReLU
()
elif
name
==
"leakyrelu"
:
module
=
nn
.
LeakyReLU
(
0.1
)
else
:
raise
AttributeError
(
"Unsupported act type: {}"
.
format
(
name
))
return
module
@
paddle
.
jit
.
not_to_static
def
roi_pool
(
input
,
rois
,
...
...
ppdet/optimizer.py
浏览文件 @
10bf8de7
...
...
@@ -209,6 +209,33 @@ class BurninWarmup(object):
return
boundary
,
value
@
serializable
class
ExpWarmup
(
object
):
"""
Warm up learning rate in exponential mode
Args:
steps (int): warm up steps.
epochs (int|None): use epochs as warm up steps, the priority
of `epochs` is higher than `steps`. Default: None.
"""
def
__init__
(
self
,
steps
=
5
,
epochs
=
None
):
super
(
ExpWarmup
,
self
).
__init__
()
self
.
steps
=
steps
self
.
epochs
=
epochs
def
__call__
(
self
,
base_lr
,
step_per_epoch
):
boundary
=
[]
value
=
[]
warmup_steps
=
self
.
epochs
*
step_per_epoch
if
self
.
epochs
is
not
None
else
self
.
steps
for
i
in
range
(
warmup_steps
+
1
):
factor
=
(
i
/
float
(
warmup_steps
))
**
2
value
.
append
(
base_lr
*
factor
)
if
i
>
0
:
boundary
.
append
(
i
)
return
boundary
,
value
@
register
class
LearningRate
(
object
):
"""
...
...
@@ -331,7 +358,8 @@ class ModelEMA(object):
Ema's parameter are updated with the formula:
`ema_param = decay * ema_param + (1 - decay) * cur_param`.
Defaults is 0.9998.
use_thres_step (bool): Whether set decay by thres_step or not
ema_decay_type (str): type in ['threshold', 'normal', 'exponential'],
'threshold' as default.
cycle_epoch (int): The epoch of interval to reset ema_param and
step. Defaults is -1, which means not reset. Its function is to
add a regular effect to ema, which is set according to experience
...
...
@@ -341,7 +369,7 @@ class ModelEMA(object):
def
__init__
(
self
,
model
,
decay
=
0.9998
,
use_thres_step
=
False
,
ema_decay_type
=
'threshold'
,
cycle_epoch
=-
1
):
self
.
step
=
0
self
.
epoch
=
0
...
...
@@ -349,7 +377,7 @@ class ModelEMA(object):
self
.
state_dict
=
dict
()
for
k
,
v
in
model
.
state_dict
().
items
():
self
.
state_dict
[
k
]
=
paddle
.
zeros_like
(
v
)
self
.
use_thres_step
=
use_thres_step
self
.
ema_decay_type
=
ema_decay_type
self
.
cycle_epoch
=
cycle_epoch
self
.
_model_state
=
{
...
...
@@ -370,8 +398,10 @@ class ModelEMA(object):
self
.
step
=
step
def
update
(
self
,
model
=
None
):
if
self
.
use_thres_step
:
if
self
.
ema_decay_type
==
'threshold'
:
decay
=
min
(
self
.
decay
,
(
1
+
self
.
step
)
/
(
10
+
self
.
step
))
elif
self
.
ema_decay_type
==
'exponential'
:
decay
=
self
.
decay
*
(
1
-
math
.
exp
(
-
(
self
.
step
+
1
)
/
2000
))
else
:
decay
=
self
.
decay
self
.
_decay
=
decay
...
...
@@ -394,7 +424,8 @@ class ModelEMA(object):
return
self
.
state_dict
state_dict
=
dict
()
for
k
,
v
in
self
.
state_dict
.
items
():
v
=
v
/
(
1
-
self
.
_decay
**
self
.
step
)
if
self
.
ema_decay_type
!=
'exponential'
:
v
=
v
/
(
1
-
self
.
_decay
**
self
.
step
)
v
.
stop_gradient
=
True
state_dict
[
k
]
=
v
self
.
epoch
+=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录