Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
fa67fb9f
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fa67fb9f
编写于
10月 17, 2022
作者:
S
shangliang Xu
提交者:
GitHub
10月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] fix export model bug in DETR (#7120)
上级
6d6573b1
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
52 addition
and
35 deletion
+52
-35
deploy/python/infer.py
deploy/python/infer.py
+12
-2
ppdet/engine/export_utils.py
ppdet/engine/export_utils.py
+1
-2
ppdet/modeling/architectures/detr.py
ppdet/modeling/architectures/detr.py
+12
-4
ppdet/modeling/post_process.py
ppdet/modeling/post_process.py
+3
-3
ppdet/modeling/transformers/detr_transformer.py
ppdet/modeling/transformers/detr_transformer.py
+17
-14
ppdet/modeling/transformers/position_encoding.py
ppdet/modeling/transformers/position_encoding.py
+3
-6
ppdet/modeling/transformers/utils.py
ppdet/modeling/transformers/utils.py
+4
-4
未找到文件。
deploy/python/infer.py
浏览文件 @
fa67fb9f
...
...
@@ -42,9 +42,11 @@ from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco
SUPPORT_MODELS
=
{
'YOLO'
,
'RCNN'
,
'SSD'
,
'Face'
,
'FCOS'
,
'SOLOv2'
,
'TTFNet'
,
'S2ANet'
,
'JDE'
,
'FairMOT'
,
'DeepSORT'
,
'GFL'
,
'PicoDet'
,
'CenterNet'
,
'TOOD'
,
'RetinaNet'
,
'StrongBaseline'
,
'STGCN'
,
'YOLOX'
,
'PPHGNet'
,
'PPLCNet'
'StrongBaseline'
,
'STGCN'
,
'YOLOX'
,
'PPHGNet'
,
'PPLCNet'
,
'DETR'
}
TUNED_TRT_DYNAMIC_MODELS
=
{
'DETR'
}
def
bench_log
(
detector
,
img_list
,
model_info
,
batch_size
=
1
,
name
=
None
):
mems
=
{
...
...
@@ -103,6 +105,7 @@ class Detector(object):
self
.
pred_config
=
self
.
set_config
(
model_dir
)
self
.
predictor
,
self
.
config
=
load_predictor
(
model_dir
,
self
.
pred_config
.
arch
,
run_mode
=
run_mode
,
batch_size
=
batch_size
,
min_subgraph_size
=
self
.
pred_config
.
min_subgraph_size
,
...
...
@@ -775,6 +778,7 @@ class PredictConfig():
def
load_predictor
(
model_dir
,
arch
,
run_mode
=
'paddle'
,
batch_size
=
1
,
device
=
'CPU'
,
...
...
@@ -787,7 +791,8 @@ def load_predictor(model_dir,
cpu_threads
=
1
,
enable_mkldnn
=
False
,
enable_mkldnn_bfloat16
=
False
,
delete_shuffle_pass
=
False
):
delete_shuffle_pass
=
False
,
tuned_trt_shape_file
=
"shape_range_info.pbtxt"
):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
...
...
@@ -854,6 +859,8 @@ def load_predictor(model_dir,
'trt_fp16'
:
Config
.
Precision
.
Half
}
if
run_mode
in
precision_map
.
keys
():
if
arch
in
TUNED_TRT_DYNAMIC_MODELS
:
config
.
collect_shape_range_info
(
tuned_trt_shape_file
)
config
.
enable_tensorrt_engine
(
workspace_size
=
(
1
<<
25
)
*
batch_size
,
max_batch_size
=
batch_size
,
...
...
@@ -861,6 +868,9 @@ def load_predictor(model_dir,
precision_mode
=
precision_map
[
run_mode
],
use_static
=
False
,
use_calib_mode
=
trt_calib_mode
)
if
arch
in
TUNED_TRT_DYNAMIC_MODELS
:
config
.
enable_tuned_tensorrt_dynamic_shape
(
tuned_trt_shape_file
,
True
)
if
use_dynamic_shape
:
min_input_shape
=
{
...
...
ppdet/engine/export_utils.py
浏览文件 @
fa67fb9f
...
...
@@ -50,6 +50,7 @@ TRT_MIN_SUBGRAPH = {
'TOOD'
:
5
,
'YOLOX'
:
8
,
'METRO_Body'
:
3
,
'DETR'
:
3
,
}
KEYPOINT_ARCH
=
[
'HigherHRNet'
,
'TopDownHRNet'
]
...
...
@@ -134,7 +135,6 @@ def _dump_infer_config(config, path, image_shape, model):
export_onnx
=
config
.
get
(
'export_onnx'
,
False
)
export_eb
=
config
.
get
(
'export_eb'
,
False
)
infer_arch
=
config
[
'architecture'
]
if
'RCNN'
in
infer_arch
and
export_onnx
:
logger
.
warning
(
...
...
@@ -142,7 +142,6 @@ def _dump_infer_config(config, path, image_shape, model):
infer_cfg
[
'export_onnx'
]
=
True
infer_cfg
[
'export_eb'
]
=
export_eb
if
infer_arch
in
MOT_ARCH
:
if
infer_arch
==
'DeepSORT'
:
tracker_cfg
=
config
[
'DeepSORTTracker'
]
...
...
ppdet/modeling/architectures/detr.py
浏览文件 @
fa67fb9f
...
...
@@ -27,17 +27,20 @@ __all__ = ['DETR']
class
DETR
(
BaseArch
):
__category__
=
'architecture'
__inject__
=
[
'post_process'
]
__shared__
=
[
'exclude_post_process'
]
def
__init__
(
self
,
backbone
,
transformer
,
detr_head
,
post_process
=
'DETRBBoxPostProcess'
):
post_process
=
'DETRBBoxPostProcess'
,
exclude_post_process
=
False
):
super
(
DETR
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
transformer
=
transformer
self
.
detr_head
=
detr_head
self
.
post_process
=
post_process
self
.
exclude_post_process
=
exclude_post_process
@
classmethod
def
from_config
(
cls
,
cfg
,
*
args
,
**
kwargs
):
...
...
@@ -65,15 +68,20 @@ class DETR(BaseArch):
body_feats
=
self
.
backbone
(
self
.
inputs
)
# Transformer
out_transformer
=
self
.
transformer
(
body_feats
,
self
.
inputs
[
'pad_mask'
])
pad_mask
=
self
.
inputs
[
'pad_mask'
]
if
self
.
training
else
None
out_transformer
=
self
.
transformer
(
body_feats
,
pad_mask
)
# DETR Head
if
self
.
training
:
return
self
.
detr_head
(
out_transformer
,
body_feats
,
self
.
inputs
)
else
:
preds
=
self
.
detr_head
(
out_transformer
,
body_feats
)
bbox
,
bbox_num
=
self
.
post_process
(
preds
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
if
self
.
exclude_post_process
:
bboxes
,
logits
,
masks
=
preds
return
bboxes
,
logits
else
:
bbox
,
bbox_num
=
self
.
post_process
(
preds
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
return
bbox
,
bbox_num
def
get_loss
(
self
,
):
...
...
ppdet/modeling/post_process.py
浏览文件 @
fa67fb9f
...
...
@@ -479,9 +479,9 @@ class DETRBBoxPostProcess(object):
bbox_pred
=
bbox_cxcywh_to_xyxy
(
bboxes
)
origin_shape
=
paddle
.
floor
(
im_shape
/
scale_factor
+
0.5
)
img_h
,
img_w
=
origin_shape
.
unbind
(
1
)
origin_shape
=
paddle
.
stack
(
[
img_w
,
img_h
,
img_w
,
img_h
],
axis
=-
1
).
unsqueeze
(
0
)
img_h
,
img_w
=
paddle
.
split
(
origin_shape
,
2
,
axis
=-
1
)
origin_shape
=
paddle
.
concat
(
[
img_w
,
img_h
,
img_w
,
img_h
],
axis
=-
1
).
reshape
([
-
1
,
1
,
4
]
)
bbox_pred
*=
origin_shape
scores
=
F
.
sigmoid
(
logits
)
if
self
.
use_focal_loss
else
F
.
softmax
(
...
...
ppdet/modeling/transformers/detr_transformer.py
浏览文件 @
fa67fb9f
...
...
@@ -69,8 +69,6 @@ class TransformerEncoderLayer(nn.Layer):
return
tensor
if
pos_embed
is
None
else
tensor
+
pos_embed
def
forward
(
self
,
src
,
src_mask
=
None
,
pos_embed
=
None
):
src_mask
=
_convert_attention_mask
(
src_mask
,
src
.
dtype
)
residual
=
src
if
self
.
normalize_before
:
src
=
self
.
norm1
(
src
)
...
...
@@ -99,8 +97,6 @@ class TransformerEncoder(nn.Layer):
self
.
norm
=
norm
def
forward
(
self
,
src
,
src_mask
=
None
,
pos_embed
=
None
):
src_mask
=
_convert_attention_mask
(
src_mask
,
src
.
dtype
)
output
=
src
for
layer
in
self
.
layers
:
output
=
layer
(
output
,
src_mask
=
src_mask
,
pos_embed
=
pos_embed
)
...
...
@@ -158,7 +154,6 @@ class TransformerDecoderLayer(nn.Layer):
pos_embed
=
None
,
query_pos_embed
=
None
):
tgt_mask
=
_convert_attention_mask
(
tgt_mask
,
tgt
.
dtype
)
memory_mask
=
_convert_attention_mask
(
memory_mask
,
memory
.
dtype
)
residual
=
tgt
if
self
.
normalize_before
:
...
...
@@ -209,7 +204,6 @@ class TransformerDecoder(nn.Layer):
pos_embed
=
None
,
query_pos_embed
=
None
):
tgt_mask
=
_convert_attention_mask
(
tgt_mask
,
tgt
.
dtype
)
memory_mask
=
_convert_attention_mask
(
memory_mask
,
memory
.
dtype
)
output
=
tgt
intermediate
=
[]
...
...
@@ -298,6 +292,9 @@ class DETRTransformer(nn.Layer):
'backbone_num_channels'
:
[
i
.
channels
for
i
in
input_shape
][
-
1
],
}
def
_convert_attention_mask
(
self
,
mask
):
return
(
mask
-
1.0
)
*
1e9
def
forward
(
self
,
src
,
src_mask
=
None
):
r
"""
Applies a Transformer model on the inputs.
...
...
@@ -321,20 +318,21 @@ class DETRTransformer(nn.Layer):
"""
# use last level feature map
src_proj
=
self
.
input_proj
(
src
[
-
1
])
bs
,
c
,
h
,
w
=
src_proj
.
shape
bs
,
c
,
h
,
w
=
paddle
.
shape
(
src_proj
)
# flatten [B, C, H, W] to [B, HxW, C]
src_flatten
=
src_proj
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
if
src_mask
is
not
None
:
src_mask
=
F
.
interpolate
(
src_mask
.
unsqueeze
(
0
).
astype
(
src_flatten
.
dtype
),
size
=
(
h
,
w
))[
0
].
astype
(
'bool'
)
src_mask
=
F
.
interpolate
(
src_mask
.
unsqueeze
(
0
),
size
=
(
h
,
w
))[
0
]
else
:
src_mask
=
paddle
.
ones
([
bs
,
h
,
w
]
,
dtype
=
'bool'
)
src_mask
=
paddle
.
ones
([
bs
,
h
,
w
])
pos_embed
=
self
.
position_embedding
(
src_mask
).
flatten
(
2
).
transpose
(
[
0
,
2
,
1
])
src_mask
=
_convert_attention_mask
(
src_mask
,
src_flatten
.
dtype
)
src_mask
=
src_mask
.
reshape
([
bs
,
1
,
1
,
-
1
])
if
self
.
training
:
src_mask
=
self
.
_convert_attention_mask
(
src_mask
)
src_mask
=
src_mask
.
reshape
([
bs
,
1
,
1
,
h
*
w
])
else
:
src_mask
=
None
memory
=
self
.
encoder
(
src_flatten
,
src_mask
=
src_mask
,
pos_embed
=
pos_embed
)
...
...
@@ -349,5 +347,10 @@ class DETRTransformer(nn.Layer):
pos_embed
=
pos_embed
,
query_pos_embed
=
query_pos_embed
)
if
self
.
training
:
src_mask
=
src_mask
.
reshape
([
bs
,
1
,
1
,
h
,
w
])
else
:
src_mask
=
None
return
(
output
,
memory
.
transpose
([
0
,
2
,
1
]).
reshape
([
bs
,
c
,
h
,
w
]),
src_proj
,
src_mask
.
reshape
([
bs
,
1
,
1
,
h
,
w
])
)
src_proj
,
src_mask
)
ppdet/modeling/transformers/position_encoding.py
浏览文件 @
fa67fb9f
...
...
@@ -65,11 +65,9 @@ class PositionEmbedding(nn.Layer):
Returns:
pos (Tensor): [B, C, H, W]
"""
assert
mask
.
dtype
==
paddle
.
bool
if
self
.
embed_type
==
'sine'
:
mask
=
mask
.
astype
(
'float32'
)
y_embed
=
mask
.
cumsum
(
1
,
dtype
=
'float32'
)
x_embed
=
mask
.
cumsum
(
2
,
dtype
=
'float32'
)
y_embed
=
mask
.
cumsum
(
1
)
x_embed
=
mask
.
cumsum
(
2
)
if
self
.
normalize
:
y_embed
=
(
y_embed
+
self
.
offset
)
/
(
y_embed
[:,
-
1
:,
:]
+
self
.
eps
)
*
self
.
scale
...
...
@@ -101,8 +99,7 @@ class PositionEmbedding(nn.Layer):
x_emb
.
unsqueeze
(
0
).
repeat
(
h
,
1
,
1
),
y_emb
.
unsqueeze
(
1
).
repeat
(
1
,
w
,
1
),
],
axis
=-
1
).
transpose
([
2
,
0
,
1
]).
unsqueeze
(
0
).
tile
(
mask
.
shape
[
0
],
1
,
1
,
1
)
axis
=-
1
).
transpose
([
2
,
0
,
1
]).
unsqueeze
(
0
)
return
pos
else
:
raise
ValueError
(
f
"not supported
{
self
.
embed_type
}
"
)
ppdet/modeling/transformers/utils.py
浏览文件 @
fa67fb9f
...
...
@@ -38,15 +38,15 @@ def _get_clones(module, N):
def
bbox_cxcywh_to_xyxy
(
x
):
x_c
,
y_c
,
w
,
h
=
x
.
unbind
(
-
1
)
x_c
,
y_c
,
w
,
h
=
x
.
split
(
4
,
axis
=
-
1
)
b
=
[(
x_c
-
0.5
*
w
),
(
y_c
-
0.5
*
h
),
(
x_c
+
0.5
*
w
),
(
y_c
+
0.5
*
h
)]
return
paddle
.
stack
(
b
,
axis
=-
1
)
return
paddle
.
concat
(
b
,
axis
=-
1
)
def
bbox_xyxy_to_cxcywh
(
x
):
x0
,
y0
,
x1
,
y1
=
x
.
unbind
(
-
1
)
x0
,
y0
,
x1
,
y1
=
x
.
split
(
4
,
axis
=
-
1
)
b
=
[(
x0
+
x1
)
/
2
,
(
y0
+
y1
)
/
2
,
(
x1
-
x0
),
(
y1
-
y0
)]
return
paddle
.
stack
(
b
,
axis
=-
1
)
return
paddle
.
concat
(
b
,
axis
=-
1
)
def
sigmoid_focal_loss
(
logit
,
label
,
normalizer
=
1.0
,
alpha
=
0.25
,
gamma
=
2.0
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录