Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
80542055
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
80542055
编写于
9月 08, 2022
作者:
S
shangliang Xu
提交者:
GitHub
9月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PPYOLOE] fix proj_conv in ptq bug (#6908)
上级
bbc2edfa
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
64 addition
and
35 deletion
+64
-35
configs/ppyoloe/_base_/ppyoloe_crn.yml
configs/ppyoloe/_base_/ppyoloe_crn.yml
+1
-0
configs/ppyoloe/_base_/ppyoloe_plus_crn.yml
configs/ppyoloe/_base_/ppyoloe_plus_crn.yml
+1
-0
configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml
configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml
+2
-0
deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml
deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml
+1
-0
deploy/auto_compression/configs/ppyoloe_plus_reader.yml
deploy/auto_compression/configs/ppyoloe_plus_reader.yml
+1
-3
deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml
deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml
+1
-0
deploy/python/utils.py
deploy/python/utils.py
+1
-1
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+4
-2
ppdet/modeling/assigners/atss_assigner.py
ppdet/modeling/assigners/atss_assigner.py
+1
-1
ppdet/modeling/assigners/task_aligned_assigner.py
ppdet/modeling/assigners/task_aligned_assigner.py
+1
-1
ppdet/modeling/heads/ppyoloe_head.py
ppdet/modeling/heads/ppyoloe_head.py
+12
-15
ppdet/optimizer/ema.py
ppdet/optimizer/ema.py
+38
-12
未找到文件。
configs/ppyoloe/_base_/ppyoloe_crn.yml
浏览文件 @
80542055
...
@@ -2,6 +2,7 @@ architecture: YOLOv3
...
@@ -2,6 +2,7 @@ architecture: YOLOv3
norm_type
:
sync_bn
norm_type
:
sync_bn
use_ema
:
true
use_ema
:
true
ema_decay
:
0.9998
ema_decay
:
0.9998
ema_black_list
:
[
'
proj_conv.weight'
]
custom_black_list
:
[
'
reduce_mean'
]
custom_black_list
:
[
'
reduce_mean'
]
YOLOv3
:
YOLOv3
:
...
...
configs/ppyoloe/_base_/ppyoloe_plus_crn.yml
浏览文件 @
80542055
...
@@ -2,6 +2,7 @@ architecture: YOLOv3
...
@@ -2,6 +2,7 @@ architecture: YOLOv3
norm_type
:
sync_bn
norm_type
:
sync_bn
use_ema
:
true
use_ema
:
true
ema_decay
:
0.9998
ema_decay
:
0.9998
ema_black_list
:
[
'
proj_conv.weight'
]
custom_black_list
:
[
'
reduce_mean'
]
custom_black_list
:
[
'
reduce_mean'
]
YOLOv3
:
YOLOv3
:
...
...
configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml
浏览文件 @
80542055
...
@@ -26,6 +26,8 @@ architecture: YOLOv3
...
@@ -26,6 +26,8 @@ architecture: YOLOv3
norm_type
:
sync_bn
norm_type
:
sync_bn
use_ema
:
true
use_ema
:
true
ema_decay
:
0.9998
ema_decay
:
0.9998
ema_black_list
:
[
'
proj_conv.weight'
]
custom_black_list
:
[
'
reduce_mean'
]
YOLOv3
:
YOLOv3
:
backbone
:
CSPResNet
backbone
:
CSPResNet
...
...
deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml
浏览文件 @
80542055
...
@@ -14,6 +14,7 @@ Distillation:
...
@@ -14,6 +14,7 @@ Distillation:
Quantization
:
Quantization
:
use_pact
:
true
use_pact
:
true
onnx_format
:
True
activation_quantize_type
:
'
moving_average_abs_max'
activation_quantize_type
:
'
moving_average_abs_max'
quantize_op_types
:
quantize_op_types
:
-
conv2d
-
conv2d
...
...
deploy/auto_compression/configs/ppyoloe_plus_reader.yml
浏览文件 @
80542055
metric
:
COCO
metric
:
COCO
num_classes
:
80
num_classes
:
80
...
@@ -23,6 +21,6 @@ EvalReader:
...
@@ -23,6 +21,6 @@ EvalReader:
sample_transforms
:
sample_transforms
:
-
Decode
:
{}
-
Decode
:
{}
-
Resize
:
{
target_size
:
[
640
,
640
],
keep_ratio
:
False
,
interp
:
2
}
-
Resize
:
{
target_size
:
[
640
,
640
],
keep_ratio
:
False
,
interp
:
2
}
-
NormalizeImage
:
{
mean
:
[
0.
,
0.
,
0.
],
std
:
[
1.
,
1.
,
1.
],
is_scale
:
Tru
e
}
-
NormalizeImage
:
{
mean
:
[
0.
,
0.
,
0.
],
std
:
[
1.
,
1.
,
1.
],
norm_type
:
non
e
}
-
Permute
:
{}
-
Permute
:
{}
batch_size
:
4
batch_size
:
4
deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml
浏览文件 @
80542055
...
@@ -14,6 +14,7 @@ Distillation:
...
@@ -14,6 +14,7 @@ Distillation:
Quantization
:
Quantization
:
use_pact
:
true
use_pact
:
true
onnx_format
:
True
activation_quantize_type
:
'
moving_average_abs_max'
activation_quantize_type
:
'
moving_average_abs_max'
quantize_op_types
:
quantize_op_types
:
-
conv2d
-
conv2d
...
...
deploy/python/utils.py
浏览文件 @
80542055
...
@@ -108,7 +108,7 @@ def argsparser():
...
@@ -108,7 +108,7 @@ def argsparser():
"calibration, trt_calib_mode need to set True."
)
"calibration, trt_calib_mode need to set True."
)
parser
.
add_argument
(
parser
.
add_argument
(
'--save_images'
,
'--save_images'
,
type
=
boo
l
,
type
=
ast
.
literal_eva
l
,
default
=
True
,
default
=
True
,
help
=
'Save visualization image results.'
)
help
=
'Save visualization image results.'
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
ppdet/engine/trainer.py
浏览文件 @
80542055
...
@@ -169,13 +169,15 @@ class Trainer(object):
...
@@ -169,13 +169,15 @@ class Trainer(object):
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
if
self
.
use_ema
:
if
self
.
use_ema
:
ema_decay
=
self
.
cfg
.
get
(
'ema_decay'
,
0.9998
)
ema_decay
=
self
.
cfg
.
get
(
'ema_decay'
,
0.9998
)
cycle_epoch
=
self
.
cfg
.
get
(
'cycle_epoch'
,
-
1
)
ema_decay_type
=
self
.
cfg
.
get
(
'ema_decay_type'
,
'threshold'
)
ema_decay_type
=
self
.
cfg
.
get
(
'ema_decay_type'
,
'threshold'
)
cycle_epoch
=
self
.
cfg
.
get
(
'cycle_epoch'
,
-
1
)
ema_black_list
=
self
.
cfg
.
get
(
'ema_black_list'
,
None
)
self
.
ema
=
ModelEMA
(
self
.
ema
=
ModelEMA
(
self
.
model
,
self
.
model
,
decay
=
ema_decay
,
decay
=
ema_decay
,
ema_decay_type
=
ema_decay_type
,
ema_decay_type
=
ema_decay_type
,
cycle_epoch
=
cycle_epoch
)
cycle_epoch
=
cycle_epoch
,
ema_black_list
=
ema_black_list
)
self
.
_nranks
=
dist
.
get_world_size
()
self
.
_nranks
=
dist
.
get_world_size
()
self
.
_local_rank
=
dist
.
get_rank
()
self
.
_local_rank
=
dist
.
get_rank
()
...
...
ppdet/modeling/assigners/atss_assigner.py
浏览文件 @
80542055
...
@@ -120,7 +120,7 @@ class ATSSAssigner(nn.Layer):
...
@@ -120,7 +120,7 @@ class ATSSAssigner(nn.Layer):
# negative batch
# negative batch
if
num_max_boxes
==
0
:
if
num_max_boxes
==
0
:
assigned_labels
=
paddle
.
full
(
assigned_labels
=
paddle
.
full
(
[
batch_size
,
num_anchors
],
bg_index
,
dtype
=
gt_labels
.
dtype
)
[
batch_size
,
num_anchors
],
bg_index
,
dtype
=
'int32'
)
assigned_bboxes
=
paddle
.
zeros
([
batch_size
,
num_anchors
,
4
])
assigned_bboxes
=
paddle
.
zeros
([
batch_size
,
num_anchors
,
4
])
assigned_scores
=
paddle
.
zeros
(
assigned_scores
=
paddle
.
zeros
(
[
batch_size
,
num_anchors
,
self
.
num_classes
])
[
batch_size
,
num_anchors
,
self
.
num_classes
])
...
...
ppdet/modeling/assigners/task_aligned_assigner.py
浏览文件 @
80542055
...
@@ -86,7 +86,7 @@ class TaskAlignedAssigner(nn.Layer):
...
@@ -86,7 +86,7 @@ class TaskAlignedAssigner(nn.Layer):
# negative batch
# negative batch
if
num_max_boxes
==
0
:
if
num_max_boxes
==
0
:
assigned_labels
=
paddle
.
full
(
assigned_labels
=
paddle
.
full
(
[
batch_size
,
num_anchors
],
bg_index
,
dtype
=
gt_labels
.
dtype
)
[
batch_size
,
num_anchors
],
bg_index
,
dtype
=
'int32'
)
assigned_bboxes
=
paddle
.
zeros
([
batch_size
,
num_anchors
,
4
])
assigned_bboxes
=
paddle
.
zeros
([
batch_size
,
num_anchors
,
4
])
assigned_scores
=
paddle
.
zeros
(
assigned_scores
=
paddle
.
zeros
(
[
batch_size
,
num_anchors
,
num_classes
])
[
batch_size
,
num_anchors
,
num_classes
])
...
...
ppdet/modeling/heads/ppyoloe_head.py
浏览文件 @
80542055
...
@@ -130,11 +130,10 @@ class PPYOLOEHead(nn.Layer):
...
@@ -130,11 +130,10 @@ class PPYOLOEHead(nn.Layer):
constant_
(
reg_
.
weight
)
constant_
(
reg_
.
weight
)
constant_
(
reg_
.
bias
,
1.0
)
constant_
(
reg_
.
bias
,
1.0
)
self
.
proj
=
paddle
.
linspace
(
0
,
self
.
reg_max
,
self
.
reg_max
+
1
)
proj
=
paddle
.
linspace
(
0
,
self
.
reg_max
,
self
.
reg_max
+
1
).
reshape
(
self
.
proj_conv
.
weight
.
set_value
(
[
1
,
self
.
reg_max
+
1
,
1
,
1
])
self
.
proj
.
reshape
([
1
,
self
.
reg_max
+
1
,
1
,
1
])
)
self
.
proj_conv
.
weight
.
set_value
(
proj
)
self
.
proj_conv
.
weight
.
stop_gradient
=
True
self
.
proj_conv
.
weight
.
stop_gradient
=
True
if
self
.
eval_size
:
if
self
.
eval_size
:
anchor_points
,
stride_tensor
=
self
.
_generate_anchors
()
anchor_points
,
stride_tensor
=
self
.
_generate_anchors
()
self
.
anchor_points
=
anchor_points
self
.
anchor_points
=
anchor_points
...
@@ -200,15 +199,15 @@ class PPYOLOEHead(nn.Layer):
...
@@ -200,15 +199,15 @@ class PPYOLOEHead(nn.Layer):
feat
)
feat
)
reg_dist
=
self
.
pred_reg
[
i
](
self
.
stem_reg
[
i
](
feat
,
avg_feat
))
reg_dist
=
self
.
pred_reg
[
i
](
self
.
stem_reg
[
i
](
feat
,
avg_feat
))
reg_dist
=
reg_dist
.
reshape
([
-
1
,
4
,
self
.
reg_max
+
1
,
l
]).
transpose
(
reg_dist
=
reg_dist
.
reshape
([
-
1
,
4
,
self
.
reg_max
+
1
,
l
]).
transpose
(
[
0
,
2
,
1
,
3
])
[
0
,
2
,
3
,
1
])
reg_dist
=
self
.
proj_conv
(
F
.
softmax
(
reg_dist
,
axis
=
1
))
reg_dist
=
self
.
proj_conv
(
F
.
softmax
(
reg_dist
,
axis
=
1
))
.
squeeze
(
1
)
# cls and reg
# cls and reg
cls_score
=
F
.
sigmoid
(
cls_logit
)
cls_score
=
F
.
sigmoid
(
cls_logit
)
cls_score_list
.
append
(
cls_score
.
reshape
([
b
,
self
.
num_classes
,
l
]))
cls_score_list
.
append
(
cls_score
.
reshape
([
b
,
self
.
num_classes
,
l
]))
reg_dist_list
.
append
(
reg_dist
.
reshape
([
b
,
4
,
l
])
)
reg_dist_list
.
append
(
reg_dist
)
cls_score_list
=
paddle
.
concat
(
cls_score_list
,
axis
=-
1
)
cls_score_list
=
paddle
.
concat
(
cls_score_list
,
axis
=-
1
)
reg_dist_list
=
paddle
.
concat
(
reg_dist_list
,
axis
=
-
1
)
reg_dist_list
=
paddle
.
concat
(
reg_dist_list
,
axis
=
1
)
return
cls_score_list
,
reg_dist_list
,
anchor_points
,
stride_tensor
return
cls_score_list
,
reg_dist_list
,
anchor_points
,
stride_tensor
...
@@ -240,8 +239,8 @@ class PPYOLOEHead(nn.Layer):
...
@@ -240,8 +239,8 @@ class PPYOLOEHead(nn.Layer):
def
_bbox_decode
(
self
,
anchor_points
,
pred_dist
):
def
_bbox_decode
(
self
,
anchor_points
,
pred_dist
):
b
,
l
,
_
=
get_static_shape
(
pred_dist
)
b
,
l
,
_
=
get_static_shape
(
pred_dist
)
pred_dist
=
F
.
softmax
(
pred_dist
.
reshape
([
b
,
l
,
4
,
self
.
reg_max
+
1
pred_dist
=
F
.
softmax
(
pred_dist
.
reshape
([
b
,
l
,
4
,
self
.
reg_max
+
1
]))
])).
matmul
(
self
.
proj
)
pred_dist
=
self
.
proj_conv
(
pred_dist
.
transpose
([
0
,
3
,
1
,
2
])).
squeeze
(
1
)
return
batch_distance2bbox
(
anchor_points
,
pred_dist
)
return
batch_distance2bbox
(
anchor_points
,
pred_dist
)
def
_bbox2distance
(
self
,
points
,
bbox
):
def
_bbox2distance
(
self
,
points
,
bbox
):
...
@@ -347,9 +346,8 @@ class PPYOLOEHead(nn.Layer):
...
@@ -347,9 +346,8 @@ class PPYOLOEHead(nn.Layer):
assigned_scores_sum
=
assigned_scores
.
sum
()
assigned_scores_sum
=
assigned_scores
.
sum
()
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
paddle
.
distributed
.
get_world_size
()
>
1
:
paddle
.
distributed
.
all_reduce
(
assigned_scores_sum
)
paddle
.
distributed
.
all_reduce
(
assigned_scores_sum
)
assigned_scores_sum
=
paddle
.
clip
(
assigned_scores_sum
/=
paddle
.
distributed
.
get_world_size
()
assigned_scores_sum
/
paddle
.
distributed
.
get_world_size
(),
assigned_scores_sum
=
paddle
.
clip
(
assigned_scores_sum
,
min
=
1.
)
min
=
1
)
loss_cls
/=
assigned_scores_sum
loss_cls
/=
assigned_scores_sum
loss_l1
,
loss_iou
,
loss_dfl
=
\
loss_l1
,
loss_iou
,
loss_dfl
=
\
...
@@ -370,8 +368,7 @@ class PPYOLOEHead(nn.Layer):
...
@@ -370,8 +368,7 @@ class PPYOLOEHead(nn.Layer):
def
post_process
(
self
,
head_outs
,
scale_factor
):
def
post_process
(
self
,
head_outs
,
scale_factor
):
pred_scores
,
pred_dist
,
anchor_points
,
stride_tensor
=
head_outs
pred_scores
,
pred_dist
,
anchor_points
,
stride_tensor
=
head_outs
pred_bboxes
=
batch_distance2bbox
(
anchor_points
,
pred_bboxes
=
batch_distance2bbox
(
anchor_points
,
pred_dist
)
pred_dist
.
transpose
([
0
,
2
,
1
]))
pred_bboxes
*=
stride_tensor
pred_bboxes
*=
stride_tensor
if
self
.
exclude_post_process
:
if
self
.
exclude_post_process
:
return
paddle
.
concat
(
return
paddle
.
concat
(
...
...
ppdet/optimizer/ema.py
浏览文件 @
80542055
...
@@ -36,21 +36,30 @@ class ModelEMA(object):
...
@@ -36,21 +36,30 @@ class ModelEMA(object):
step. Defaults is -1, which means not reset. Its function is to
step. Defaults is -1, which means not reset. Its function is to
add a regular effect to ema, which is set according to experience
add a regular effect to ema, which is set according to experience
and is effective when the total training epoch is large.
and is effective when the total training epoch is large.
ema_black_list (set|list|tuple, optional): The custom EMA black_list.
Blacklist of weight names that will not participate in EMA
calculation. Default: None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
model
,
model
,
decay
=
0.9998
,
decay
=
0.9998
,
ema_decay_type
=
'threshold'
,
ema_decay_type
=
'threshold'
,
cycle_epoch
=-
1
):
cycle_epoch
=-
1
,
ema_black_list
=
None
):
self
.
step
=
0
self
.
step
=
0
self
.
epoch
=
0
self
.
epoch
=
0
self
.
decay
=
decay
self
.
decay
=
decay
self
.
state_dict
=
dict
()
for
k
,
v
in
model
.
state_dict
().
items
():
self
.
state_dict
[
k
]
=
paddle
.
zeros_like
(
v
)
self
.
ema_decay_type
=
ema_decay_type
self
.
ema_decay_type
=
ema_decay_type
self
.
cycle_epoch
=
cycle_epoch
self
.
cycle_epoch
=
cycle_epoch
self
.
ema_black_list
=
self
.
_match_ema_black_list
(
model
.
state_dict
().
keys
(),
ema_black_list
)
self
.
state_dict
=
dict
()
for
k
,
v
in
model
.
state_dict
().
items
():
if
k
in
self
.
ema_black_list
:
self
.
state_dict
[
k
]
=
v
else
:
self
.
state_dict
[
k
]
=
paddle
.
zeros_like
(
v
)
self
.
_model_state
=
{
self
.
_model_state
=
{
k
:
weakref
.
ref
(
p
)
k
:
weakref
.
ref
(
p
)
...
@@ -61,7 +70,10 @@ class ModelEMA(object):
...
@@ -61,7 +70,10 @@ class ModelEMA(object):
self
.
step
=
0
self
.
step
=
0
self
.
epoch
=
0
self
.
epoch
=
0
for
k
,
v
in
self
.
state_dict
.
items
():
for
k
,
v
in
self
.
state_dict
.
items
():
self
.
state_dict
[
k
]
=
paddle
.
zeros_like
(
v
)
if
k
in
self
.
ema_black_list
:
self
.
state_dict
[
k
]
=
v
else
:
self
.
state_dict
[
k
]
=
paddle
.
zeros_like
(
v
)
def
resume
(
self
,
state_dict
,
step
=
0
):
def
resume
(
self
,
state_dict
,
step
=
0
):
for
k
,
v
in
state_dict
.
items
():
for
k
,
v
in
state_dict
.
items
():
...
@@ -89,9 +101,10 @@ class ModelEMA(object):
...
@@ -89,9 +101,10 @@ class ModelEMA(object):
[
v
is
not
None
for
_
,
v
in
model_dict
.
items
()]),
'python gc.'
[
v
is
not
None
for
_
,
v
in
model_dict
.
items
()]),
'python gc.'
for
k
,
v
in
self
.
state_dict
.
items
():
for
k
,
v
in
self
.
state_dict
.
items
():
v
=
decay
*
v
+
(
1
-
decay
)
*
model_dict
[
k
]
if
k
not
in
self
.
ema_black_list
:
v
.
stop_gradient
=
True
v
=
decay
*
v
+
(
1
-
decay
)
*
model_dict
[
k
]
self
.
state_dict
[
k
]
=
v
v
.
stop_gradient
=
True
self
.
state_dict
[
k
]
=
v
self
.
step
+=
1
self
.
step
+=
1
def
apply
(
self
):
def
apply
(
self
):
...
@@ -99,12 +112,25 @@ class ModelEMA(object):
...
@@ -99,12 +112,25 @@ class ModelEMA(object):
return
self
.
state_dict
return
self
.
state_dict
state_dict
=
dict
()
state_dict
=
dict
()
for
k
,
v
in
self
.
state_dict
.
items
():
for
k
,
v
in
self
.
state_dict
.
items
():
if
self
.
ema_decay_type
!=
'exponential'
:
if
k
in
self
.
ema_black_list
:
v
=
v
/
(
1
-
self
.
_decay
**
self
.
step
)
v
.
stop_gradient
=
True
v
.
stop_gradient
=
True
state_dict
[
k
]
=
v
state_dict
[
k
]
=
v
else
:
if
self
.
ema_decay_type
!=
'exponential'
:
v
=
v
/
(
1
-
self
.
_decay
**
self
.
step
)
v
.
stop_gradient
=
True
state_dict
[
k
]
=
v
self
.
epoch
+=
1
self
.
epoch
+=
1
if
self
.
cycle_epoch
>
0
and
self
.
epoch
==
self
.
cycle_epoch
:
if
self
.
cycle_epoch
>
0
and
self
.
epoch
==
self
.
cycle_epoch
:
self
.
reset
()
self
.
reset
()
return
state_dict
return
state_dict
def
_match_ema_black_list
(
self
,
weight_name
,
ema_black_list
=
None
):
out_list
=
set
()
if
ema_black_list
:
for
name
in
weight_name
:
for
key
in
ema_black_list
:
if
key
in
name
:
out_list
.
add
(
name
)
return
out_list
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录