Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
a1446709
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a1446709
编写于
12月 11, 2020
作者:
G
Guanghua Yu
提交者:
GitHub
12月 11, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dygraph] update dygraph export model (#1857)
* update dygraph export model * delete get_export * adapt faster and cascade
上级
9b279ee3
变更
18
显示空白变更内容
内联
并排
Showing
18 changed file
with
448 addition
and
403 deletion
+448
-403
deploy/python/infer.py
deploy/python/infer.py
+1
-1
deploy/python/preprocess.py
deploy/python/preprocess.py
+1
-1
ppdet/modeling/architecture/cascade_rcnn.py
ppdet/modeling/architecture/cascade_rcnn.py
+4
-9
ppdet/modeling/architecture/faster_rcnn.py
ppdet/modeling/architecture/faster_rcnn.py
+3
-5
ppdet/modeling/architecture/mask_rcnn.py
ppdet/modeling/architecture/mask_rcnn.py
+4
-8
ppdet/modeling/architecture/meta_arch.py
ppdet/modeling/architecture/meta_arch.py
+8
-5
ppdet/modeling/architecture/yolo.py
ppdet/modeling/architecture/yolo.py
+5
-9
ppdet/modeling/head/mask_head.py
ppdet/modeling/head/mask_head.py
+11
-4
ppdet/modeling/head/roi_extractor.py
ppdet/modeling/head/roi_extractor.py
+23
-24
ppdet/modeling/neck/fpn.py
ppdet/modeling/neck/fpn.py
+2
-1
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+293
-268
ppdet/py_op/post_process.py
ppdet/py_op/post_process.py
+2
-3
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+18
-8
tools/eval.py
tools/eval.py
+17
-3
tools/export_model.py
tools/export_model.py
+26
-46
tools/export_utils.py
tools/export_utils.py
+2
-1
tools/infer.py
tools/infer.py
+21
-4
tools/train.py
tools/train.py
+7
-3
未找到文件。
deploy/python/infer.py
浏览文件 @
a1446709
...
@@ -133,7 +133,7 @@ class Detector(object):
...
@@ -133,7 +133,7 @@ class Detector(object):
boxes_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
0
])
boxes_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
0
])
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
if
self
.
pred_config
.
mask_resolution
is
not
None
:
if
self
.
pred_config
.
mask_resolution
is
not
None
:
masks_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
])
masks_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
])
np_masks
=
masks_tensor
.
copy_to_cpu
()
np_masks
=
masks_tensor
.
copy_to_cpu
()
t2
=
time
.
time
()
t2
=
time
.
time
()
ms
=
(
t2
-
t1
)
*
1000.0
/
repeats
ms
=
(
t2
-
t1
)
*
1000.0
/
repeats
...
...
deploy/python/preprocess.py
浏览文件 @
a1446709
...
@@ -79,7 +79,7 @@ class ResizeOp(object):
...
@@ -79,7 +79,7 @@ class ResizeOp(object):
im_info
[
'scale_factor'
]
=
np
.
array
(
im_info
[
'scale_factor'
]
=
np
.
array
(
[
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
[
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
# padding im when image_shape fixed by infer_cfg.yml
# padding im when image_shape fixed by infer_cfg.yml
if
self
.
keep_ratio
:
if
self
.
keep_ratio
and
im_info
[
'input_shape'
][
1
]
is
not
None
:
max_size
=
im_info
[
'input_shape'
][
1
]
max_size
=
im_info
[
'input_shape'
][
1
]
padding_im
=
np
.
zeros
(
padding_im
=
np
.
zeros
(
(
max_size
,
max_size
,
im_channel
),
dtype
=
np
.
float32
)
(
max_size
,
max_size
,
im_channel
),
dtype
=
np
.
float32
)
...
...
ppdet/modeling/architecture/cascade_rcnn.py
浏览文件 @
a1446709
...
@@ -158,17 +158,12 @@ class CascadeRCNN(BaseArch):
...
@@ -158,17 +158,12 @@ class CascadeRCNN(BaseArch):
loss
.
update
({
'loss'
:
total_loss
})
loss
.
update
({
'loss'
:
total_loss
})
return
loss
return
loss
def
get_pred
(
self
,
return_numpy
=
True
):
def
get_pred
(
self
):
bbox
,
bbox_num
=
self
.
bboxes
bbox
,
bbox_num
=
self
.
bboxes
output
=
{
output
=
{
'bbox'
:
bbox
.
numpy
(),
'bbox'
:
bbox
,
'bbox_num'
:
bbox_num
.
numpy
(),
'bbox_num'
:
bbox_num
,
'im_id'
:
self
.
inputs
[
'im_id'
].
numpy
(),
}
}
if
self
.
with_mask
:
if
self
.
with_mask
:
mask
=
self
.
mask_post_process
(
self
.
bboxes
,
self
.
mask_head_out
,
output
.
update
(
self
.
mask_head_out
)
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
output
.
update
(
mask
)
return
output
return
output
ppdet/modeling/architecture/faster_rcnn.py
浏览文件 @
a1446709
...
@@ -92,12 +92,10 @@ class FasterRCNN(BaseArch):
...
@@ -92,12 +92,10 @@ class FasterRCNN(BaseArch):
loss
.
update
({
'loss'
:
total_loss
})
loss
.
update
({
'loss'
:
total_loss
})
return
loss
return
loss
def
get_pred
(
self
,
return_numpy
=
True
):
def
get_pred
(
self
):
bbox
,
bbox_num
=
self
.
bboxes
bbox
,
bbox_num
=
self
.
bboxes
output
=
{
output
=
{
'bbox'
:
bbox
.
numpy
(),
'bbox'
:
bbox
,
'bbox_num'
:
bbox_num
.
numpy
(),
'bbox_num'
:
bbox_num
,
'im_id'
:
self
.
inputs
[
'im_id'
].
numpy
()
}
}
return
output
return
output
ppdet/modeling/architecture/mask_rcnn.py
浏览文件 @
a1446709
...
@@ -133,15 +133,11 @@ class MaskRCNN(BaseArch):
...
@@ -133,15 +133,11 @@ class MaskRCNN(BaseArch):
loss
.
update
({
'loss'
:
total_loss
})
loss
.
update
({
'loss'
:
total_loss
})
return
loss
return
loss
def
get_pred
(
self
,
return_numpy
=
True
):
def
get_pred
(
self
):
mask
=
self
.
mask_post_process
(
self
.
bboxes
,
self
.
mask_head_out
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
bbox
,
bbox_num
=
self
.
bboxes
bbox
,
bbox_num
=
self
.
bboxes
output
=
{
output
=
{
'bbox'
:
bbox
.
numpy
()
,
'bbox'
:
bbox
,
'bbox_num'
:
bbox_num
.
numpy
()
,
'bbox_num'
:
bbox_num
,
'
im_id'
:
self
.
inputs
[
'im_id'
].
numpy
()
'
mask'
:
self
.
mask_head_out
}
}
output
.
update
(
mask
)
return
output
return
output
ppdet/modeling/architecture/meta_arch.py
浏览文件 @
a1446709
...
@@ -16,18 +16,24 @@ class BaseArch(nn.Layer):
...
@@ -16,18 +16,24 @@ class BaseArch(nn.Layer):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
BaseArch
,
self
).
__init__
()
super
(
BaseArch
,
self
).
__init__
()
def
forward
(
self
,
data
,
input_def
,
mode
,
input_tensor
=
None
):
def
forward
(
self
,
input_tensor
=
None
,
data
=
None
,
input_def
=
None
,
mode
=
'infer'
):
if
input_tensor
is
None
:
if
input_tensor
is
None
:
assert
data
is
not
None
and
input_def
is
not
None
self
.
inputs
=
self
.
build_inputs
(
data
,
input_def
)
self
.
inputs
=
self
.
build_inputs
(
data
,
input_def
)
else
:
else
:
self
.
inputs
=
input_tensor
self
.
inputs
=
input_tensor
self
.
inputs
[
'mode'
]
=
mode
self
.
inputs
[
'mode'
]
=
mode
self
.
model_arch
()
self
.
model_arch
()
if
mode
==
'train'
:
if
mode
==
'train'
:
out
=
self
.
get_loss
()
out
=
self
.
get_loss
()
elif
mode
==
'infer'
:
elif
mode
==
'infer'
:
out
=
self
.
get_pred
(
input_tensor
is
None
)
out
=
self
.
get_pred
()
else
:
else
:
out
=
None
out
=
None
raise
"Now, only support train and infer mode!"
raise
"Now, only support train and infer mode!"
...
@@ -47,6 +53,3 @@ class BaseArch(nn.Layer):
...
@@ -47,6 +53,3 @@ class BaseArch(nn.Layer):
def
get_pred
(
self
,
):
def
get_pred
(
self
,
):
raise
NotImplementedError
(
"Should implement get_pred method!"
)
raise
NotImplementedError
(
"Should implement get_pred method!"
)
def
get_export_model
(
self
,
input_tensor
):
return
self
.
forward
(
None
,
None
,
'infer'
,
input_tensor
)
ppdet/modeling/architecture/yolo.py
浏览文件 @
a1446709
...
@@ -43,16 +43,12 @@ class YOLOv3(BaseArch):
...
@@ -43,16 +43,12 @@ class YOLOv3(BaseArch):
loss
=
self
.
yolo_head
.
get_loss
(
self
.
yolo_head_outs
,
self
.
inputs
)
loss
=
self
.
yolo_head
.
get_loss
(
self
.
yolo_head_outs
,
self
.
inputs
)
return
loss
return
loss
def
get_pred
(
self
,
return_numpy
=
True
):
def
get_pred
(
self
):
bbox
,
bbox_num
=
self
.
post_process
(
bbox
,
bbox_num
=
self
.
post_process
(
self
.
yolo_head_outs
,
self
.
yolo_head
.
mask_anchors
,
self
.
yolo_head_outs
,
self
.
yolo_head
.
mask_anchors
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
if
return_numpy
:
outs
=
{
outs
=
{
"bbox"
:
bbox
.
numpy
(),
"bbox"
:
bbox
,
"bbox_num"
:
bbox_num
.
numpy
(),
"bbox_num"
:
bbox_num
,
'im_id'
:
self
.
inputs
[
'im_id'
].
numpy
()
}
}
else
:
outs
=
[
bbox
,
bbox_num
]
return
outs
return
outs
ppdet/modeling/head/mask_head.py
浏览文件 @
a1446709
...
@@ -160,12 +160,19 @@ class MaskHead(Layer):
...
@@ -160,12 +160,19 @@ class MaskHead(Layer):
bbox
,
bbox_num
=
bboxes
bbox
,
bbox_num
=
bboxes
if
bbox
.
shape
[
0
]
==
0
:
if
bbox
.
shape
[
0
]
==
0
:
mask_head_out
=
bbox
mask_head_out
=
paddle
.
full
([
1
,
6
],
-
1
)
return
mask_head_out
else
:
else
:
scale_factor_list
=
[]
# TODO(guanghua): Remove fluid dependency
scale_factor_list
=
paddle
.
fluid
.
layers
.
create_array
(
'float32'
)
num_count
=
0
for
idx
,
num
in
enumerate
(
bbox_num
):
for
idx
,
num
in
enumerate
(
bbox_num
):
for
n
in
range
(
num
):
for
n
in
range
(
num
):
scale_factor_list
.
append
(
scale_factor
[
idx
,
0
])
paddle
.
fluid
.
layers
.
array_write
(
x
=
scale_factor
[
idx
,
0
],
i
=
paddle
.
to_tensor
(
num_count
),
array
=
scale_factor_list
)
num_count
+=
1
scale_factor_list
=
paddle
.
cast
(
scale_factor_list
=
paddle
.
cast
(
paddle
.
concat
(
scale_factor_list
),
'float32'
)
paddle
.
concat
(
scale_factor_list
),
'float32'
)
scale_factor_list
=
paddle
.
reshape
(
scale_factor_list
,
shape
=
[
-
1
,
1
])
scale_factor_list
=
paddle
.
reshape
(
scale_factor_list
,
shape
=
[
-
1
,
1
])
...
...
ppdet/modeling/head/roi_extractor.py
浏览文件 @
a1446709
...
@@ -36,7 +36,6 @@ class RoIAlign(object):
...
@@ -36,7 +36,6 @@ class RoIAlign(object):
def
__call__
(
self
,
feats
,
rois
,
spatial_scale
):
def
__call__
(
self
,
feats
,
rois
,
spatial_scale
):
roi
,
rois_num
=
rois
roi
,
rois_num
=
rois
if
self
.
start_level
==
self
.
end_level
:
if
self
.
start_level
==
self
.
end_level
:
rois_feat
=
ops
.
roi_align
(
rois_feat
=
ops
.
roi_align
(
feats
[
self
.
start_level
],
feats
[
self
.
start_level
],
...
@@ -44,7 +43,7 @@ class RoIAlign(object):
...
@@ -44,7 +43,7 @@ class RoIAlign(object):
self
.
resolution
,
self
.
resolution
,
spatial_scale
,
spatial_scale
,
rois_num
=
rois_num
)
rois_num
=
rois_num
)
return
rois_feat
else
:
offset
=
2
offset
=
2
k_min
=
self
.
start_level
+
offset
k_min
=
self
.
start_level
+
offset
k_max
=
self
.
end_level
+
offset
k_max
=
self
.
end_level
+
offset
...
...
ppdet/modeling/neck/fpn.py
浏览文件 @
a1446709
...
@@ -80,7 +80,8 @@ class FPN(Layer):
...
@@ -80,7 +80,8 @@ class FPN(Layer):
for
lvl
in
range
(
self
.
min_level
,
self
.
max_level
):
for
lvl
in
range
(
self
.
min_level
,
self
.
max_level
):
laterals
.
append
(
self
.
lateral_convs
[
lvl
](
body_feats
[
lvl
]))
laterals
.
append
(
self
.
lateral_convs
[
lvl
](
body_feats
[
lvl
]))
for
lvl
in
range
(
self
.
max_level
-
1
,
self
.
min_level
,
-
1
):
for
i
in
range
(
self
.
min_level
+
1
,
self
.
max_level
):
lvl
=
self
.
max_level
+
self
.
min_level
-
i
upsample
=
F
.
interpolate
(
upsample
=
F
.
interpolate
(
laterals
[
lvl
],
laterals
[
lvl
],
scale_factor
=
2.
,
scale_factor
=
2.
,
...
...
ppdet/modeling/ops.py
浏览文件 @
a1446709
...
@@ -29,10 +29,19 @@ import numpy as np
...
@@ -29,10 +29,19 @@ import numpy as np
from
functools
import
reduce
from
functools
import
reduce
__all__
=
[
__all__
=
[
'roi_pool'
,
'roi_align'
,
'prior_box'
,
'anchor_generator'
,
'roi_pool'
,
'generate_proposals'
,
'iou_similarity'
,
'box_coder'
,
'yolo_box'
,
'roi_align'
,
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'collect_fpn_proposals'
,
'prior_box'
,
'matrix_nms'
,
'batch_norm'
'anchor_generator'
,
'generate_proposals'
,
'iou_similarity'
,
'box_coder'
,
'yolo_box'
,
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'collect_fpn_proposals'
,
'matrix_nms'
,
'batch_norm'
,
]
]
...
@@ -51,6 +60,7 @@ def batch_norm(ch, norm_type='bn', name=None):
...
@@ -51,6 +60,7 @@ def batch_norm(ch, norm_type='bn', name=None):
name
=
bn_name
+
'.offset'
,
regularizer
=
L2Decay
(
0.
)))
name
=
bn_name
+
'.offset'
,
regularizer
=
L2Decay
(
0.
)))
@
paddle
.
jit
.
not_to_static
def
roi_pool
(
input
,
def
roi_pool
(
input
,
rois
,
rois
,
output_size
,
output_size
,
...
@@ -123,6 +133,7 @@ def roi_pool(input,
...
@@ -123,6 +133,7 @@ def roi_pool(input,
"pooled_width"
,
pooled_width
,
"spatial_scale"
,
spatial_scale
)
"pooled_width"
,
pooled_width
,
"spatial_scale"
,
spatial_scale
)
return
pool_out
,
argmaxes
return
pool_out
,
argmaxes
else
:
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
],
'roi_pool'
)
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
],
'roi_pool'
)
check_variable_and_dtype
(
rois
,
'rois'
,
[
'float32'
],
'roi_pool'
)
check_variable_and_dtype
(
rois
,
'rois'
,
[
'float32'
],
'roi_pool'
)
helper
=
LayerHelper
(
'roi_pool'
,
**
locals
())
helper
=
LayerHelper
(
'roi_pool'
,
**
locals
())
...
@@ -149,6 +160,7 @@ def roi_pool(input,
...
@@ -149,6 +160,7 @@ def roi_pool(input,
return
pool_out
,
argmaxes
return
pool_out
,
argmaxes
@
paddle
.
jit
.
not_to_static
def
roi_align
(
input
,
def
roi_align
(
input
,
rois
,
rois
,
output_size
,
output_size
,
...
@@ -228,9 +240,11 @@ def roi_align(input,
...
@@ -228,9 +240,11 @@ def roi_align(input,
"sampling_ratio"
,
sampling_ratio
)
"sampling_ratio"
,
sampling_ratio
)
return
align_out
return
align_out
else
:
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
'roi_align'
)
'roi_align'
)
check_variable_and_dtype
(
rois
,
'rois'
,
[
'float32'
,
'float64'
],
'roi_align'
)
check_variable_and_dtype
(
rois
,
'rois'
,
[
'float32'
,
'float64'
],
'roi_align'
)
helper
=
LayerHelper
(
'roi_align'
,
**
locals
())
helper
=
LayerHelper
(
'roi_align'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
dtype
=
helper
.
input_dtype
()
align_out
=
helper
.
create_variable_for_type_inference
(
dtype
)
align_out
=
helper
.
create_variable_for_type_inference
(
dtype
)
...
@@ -253,6 +267,7 @@ def roi_align(input,
...
@@ -253,6 +267,7 @@ def roi_align(input,
return
align_out
return
align_out
@
paddle
.
jit
.
not_to_static
def
iou_similarity
(
x
,
y
,
box_normalized
=
True
,
name
=
None
):
def
iou_similarity
(
x
,
y
,
box_normalized
=
True
,
name
=
None
):
"""
"""
Computes intersection-over-union (IOU) between two box lists.
Computes intersection-over-union (IOU) between two box lists.
...
@@ -303,7 +318,7 @@ def iou_similarity(x, y, box_normalized=True, name=None):
...
@@ -303,7 +318,7 @@ def iou_similarity(x, y, box_normalized=True, name=None):
if
in_dygraph_mode
():
if
in_dygraph_mode
():
out
=
core
.
ops
.
iou_similarity
(
x
,
y
,
'box_normalized'
,
box_normalized
)
out
=
core
.
ops
.
iou_similarity
(
x
,
y
,
'box_normalized'
,
box_normalized
)
return
out
return
out
else
:
helper
=
LayerHelper
(
"iou_similarity"
,
**
locals
())
helper
=
LayerHelper
(
"iou_similarity"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
...
@@ -316,6 +331,7 @@ def iou_similarity(x, y, box_normalized=True, name=None):
...
@@ -316,6 +331,7 @@ def iou_similarity(x, y, box_normalized=True, name=None):
return
out
return
out
@
paddle
.
jit
.
not_to_static
def
collect_fpn_proposals
(
multi_rois
,
def
collect_fpn_proposals
(
multi_rois
,
multi_scores
,
multi_scores
,
min_level
,
min_level
,
...
@@ -398,7 +414,9 @@ def collect_fpn_proposals(multi_rois,
...
@@ -398,7 +414,9 @@ def collect_fpn_proposals(multi_rois,
attrs
=
(
'post_nms_topN'
,
post_nms_top_n
)
attrs
=
(
'post_nms_topN'
,
post_nms_top_n
)
output_rois
,
rois_num
=
core
.
ops
.
collect_fpn_proposals
(
output_rois
,
rois_num
=
core
.
ops
.
collect_fpn_proposals
(
input_rois
,
input_scores
,
rois_num_per_level
,
*
attrs
)
input_rois
,
input_scores
,
rois_num_per_level
,
*
attrs
)
return
output_rois
,
rois_num
else
:
helper
=
LayerHelper
(
'collect_fpn_proposals'
,
**
locals
())
helper
=
LayerHelper
(
'collect_fpn_proposals'
,
**
locals
())
dtype
=
helper
.
input_dtype
(
'multi_rois'
)
dtype
=
helper
.
input_dtype
(
'multi_rois'
)
check_dtype
(
dtype
,
'multi_rois'
,
[
'float32'
,
'float64'
],
check_dtype
(
dtype
,
'multi_rois'
,
[
'float32'
,
'float64'
],
...
@@ -421,11 +439,10 @@ def collect_fpn_proposals(multi_rois,
...
@@ -421,11 +439,10 @@ def collect_fpn_proposals(multi_rois,
inputs
=
inputs
,
inputs
=
inputs
,
outputs
=
outputs
,
outputs
=
outputs
,
attrs
=
{
'post_nms_topN'
:
post_nms_top_n
})
attrs
=
{
'post_nms_topN'
:
post_nms_top_n
})
if
rois_num_per_level
is
not
None
:
return
output_rois
,
rois_num
return
output_rois
,
rois_num
return
output_rois
@
paddle
.
jit
.
not_to_static
def
distribute_fpn_proposals
(
fpn_rois
,
def
distribute_fpn_proposals
(
fpn_rois
,
min_level
,
min_level
,
max_level
,
max_level
,
...
@@ -510,12 +527,14 @@ def distribute_fpn_proposals(fpn_rois,
...
@@ -510,12 +527,14 @@ def distribute_fpn_proposals(fpn_rois,
fpn_rois
,
rois_num
,
num_lvl
,
num_lvl
,
*
attrs
)
fpn_rois
,
rois_num
,
num_lvl
,
num_lvl
,
*
attrs
)
return
multi_rois
,
restore_ind
,
rois_num_per_level
return
multi_rois
,
restore_ind
,
rois_num_per_level
else
:
check_variable_and_dtype
(
fpn_rois
,
'fpn_rois'
,
[
'float32'
,
'float64'
],
check_variable_and_dtype
(
fpn_rois
,
'fpn_rois'
,
[
'float32'
,
'float64'
],
'distribute_fpn_proposals'
)
'distribute_fpn_proposals'
)
helper
=
LayerHelper
(
'distribute_fpn_proposals'
,
**
locals
())
helper
=
LayerHelper
(
'distribute_fpn_proposals'
,
**
locals
())
dtype
=
helper
.
input_dtype
(
'fpn_rois'
)
dtype
=
helper
.
input_dtype
(
'fpn_rois'
)
multi_rois
=
[
multi_rois
=
[
helper
.
create_variable_for_type_inference
(
dtype
)
for
i
in
range
(
num_lvl
)
helper
.
create_variable_for_type_inference
(
dtype
)
for
i
in
range
(
num_lvl
)
]
]
restore_ind
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
restore_ind
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
...
@@ -544,11 +563,10 @@ def distribute_fpn_proposals(fpn_rois,
...
@@ -544,11 +563,10 @@ def distribute_fpn_proposals(fpn_rois,
'refer_level'
:
refer_level
,
'refer_level'
:
refer_level
,
'refer_scale'
:
refer_scale
'refer_scale'
:
refer_scale
})
})
if
rois_num
is
not
None
:
return
multi_rois
,
restore_ind
,
rois_num_per_level
return
multi_rois
,
restore_ind
,
rois_num_per_level
return
multi_rois
,
restore_ind
@
paddle
.
jit
.
not_to_static
def
yolo_box
(
def
yolo_box
(
x
,
x
,
origin_shape
,
origin_shape
,
...
@@ -685,6 +703,7 @@ def yolo_box(
...
@@ -685,6 +703,7 @@ def yolo_box(
return
boxes
,
scores
return
boxes
,
scores
@
paddle
.
jit
.
not_to_static
def
prior_box
(
input
,
def
prior_box
(
input
,
image
,
image
,
min_sizes
,
min_sizes
,
...
@@ -798,7 +817,7 @@ def prior_box(input,
...
@@ -798,7 +817,7 @@ def prior_box(input,
attrs
=
tuple
(
attrs
)
attrs
=
tuple
(
attrs
)
box
,
var
=
core
.
ops
.
prior_box
(
input
,
image
,
*
attrs
)
box
,
var
=
core
.
ops
.
prior_box
(
input
,
image
,
*
attrs
)
return
box
,
var
return
box
,
var
else
:
attrs
=
{
attrs
=
{
'min_sizes'
:
min_sizes
,
'min_sizes'
:
min_sizes
,
'aspect_ratios'
:
aspect_ratios
,
'aspect_ratios'
:
aspect_ratios
,
...
@@ -828,6 +847,7 @@ def prior_box(input,
...
@@ -828,6 +847,7 @@ def prior_box(input,
return
box
,
var
return
box
,
var
@
paddle
.
jit
.
not_to_static
def
anchor_generator
(
input
,
def
anchor_generator
(
input
,
anchor_sizes
=
None
,
anchor_sizes
=
None
,
aspect_ratios
=
None
,
aspect_ratios
=
None
,
...
@@ -916,6 +936,7 @@ def anchor_generator(input,
...
@@ -916,6 +936,7 @@ def anchor_generator(input,
anchor
,
var
=
core
.
ops
.
anchor_generator
(
input
,
*
attrs
)
anchor
,
var
=
core
.
ops
.
anchor_generator
(
input
,
*
attrs
)
return
anchor
,
var
return
anchor
,
var
else
:
attrs
=
{
attrs
=
{
'anchor_sizes'
:
anchor_sizes
,
'anchor_sizes'
:
anchor_sizes
,
'aspect_ratios'
:
aspect_ratios
,
'aspect_ratios'
:
aspect_ratios
,
...
@@ -937,6 +958,7 @@ def anchor_generator(input,
...
@@ -937,6 +958,7 @@ def anchor_generator(input,
return
anchor
,
var
return
anchor
,
var
@
paddle
.
jit
.
not_to_static
def
multiclass_nms
(
bboxes
,
def
multiclass_nms
(
bboxes
,
scores
,
scores
,
score_threshold
,
score_threshold
,
...
@@ -1091,6 +1113,7 @@ def multiclass_nms(bboxes,
...
@@ -1091,6 +1113,7 @@ def multiclass_nms(bboxes,
return
output
,
nms_rois_num
,
index
return
output
,
nms_rois_num
,
index
@
paddle
.
jit
.
not_to_static
def
matrix_nms
(
bboxes
,
def
matrix_nms
(
bboxes
,
scores
,
scores
,
score_threshold
,
score_threshold
,
...
@@ -1196,7 +1219,7 @@ def matrix_nms(bboxes,
...
@@ -1196,7 +1219,7 @@ def matrix_nms(bboxes,
if
return_rois_num
:
if
return_rois_num
:
return
out
,
rois_num
return
out
,
rois_num
return
out
return
out
else
:
helper
=
LayerHelper
(
'matrix_nms'
,
**
locals
())
helper
=
LayerHelper
(
'matrix_nms'
,
**
locals
())
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
bboxes
.
dtype
)
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
bboxes
.
dtype
)
index
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int'
)
index
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int'
)
...
@@ -1231,6 +1254,7 @@ def matrix_nms(bboxes,
...
@@ -1231,6 +1254,7 @@ def matrix_nms(bboxes,
return
output
return
output
@
paddle
.
jit
.
not_to_static
def
box_coder
(
prior_box
,
def
box_coder
(
prior_box
,
prior_box_var
,
prior_box_var
,
target_box
,
target_box
,
...
@@ -1357,7 +1381,7 @@ def box_coder(prior_box,
...
@@ -1357,7 +1381,7 @@ def box_coder(prior_box,
raise
TypeError
(
raise
TypeError
(
"Input variance of box_coder must be Variable or list"
)
"Input variance of box_coder must be Variable or list"
)
return
output_box
return
output_box
else
:
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
output_box
=
helper
.
create_variable_for_type_inference
(
output_box
=
helper
.
create_variable_for_type_inference
(
...
@@ -1374,7 +1398,8 @@ def box_coder(prior_box,
...
@@ -1374,7 +1398,8 @@ def box_coder(prior_box,
elif
isinstance
(
prior_box_var
,
list
):
elif
isinstance
(
prior_box_var
,
list
):
attrs
[
'variance'
]
=
prior_box_var
attrs
[
'variance'
]
=
prior_box_var
else
:
else
:
raise
TypeError
(
"Input variance of box_coder must be Variable or list"
)
raise
TypeError
(
"Input variance of box_coder must be Variable or list"
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"box_coder"
,
type
=
"box_coder"
,
inputs
=
inputs
,
inputs
=
inputs
,
...
@@ -1383,6 +1408,7 @@ def box_coder(prior_box,
...
@@ -1383,6 +1408,7 @@ def box_coder(prior_box,
return
output_box
return
output_box
@
paddle
.
jit
.
not_to_static
def
generate_proposals
(
scores
,
def
generate_proposals
(
scores
,
bbox_deltas
,
bbox_deltas
,
im_shape
,
im_shape
,
...
@@ -1472,6 +1498,7 @@ def generate_proposals(scores,
...
@@ -1472,6 +1498,7 @@ def generate_proposals(scores,
scores
,
bbox_deltas
,
im_shape
,
anchors
,
variances
,
*
attrs
)
scores
,
bbox_deltas
,
im_shape
,
anchors
,
variances
,
*
attrs
)
return
rpn_rois
,
rpn_roi_probs
,
rpn_rois_num
return
rpn_rois
,
rpn_roi_probs
,
rpn_rois_num
else
:
helper
=
LayerHelper
(
'generate_proposals_v2'
,
**
locals
())
helper
=
LayerHelper
(
'generate_proposals_v2'
,
**
locals
())
check_variable_and_dtype
(
scores
,
'scores'
,
[
'float32'
],
check_variable_and_dtype
(
scores
,
'scores'
,
[
'float32'
],
...
@@ -1494,7 +1521,8 @@ def generate_proposals(scores,
...
@@ -1494,7 +1521,8 @@ def generate_proposals(scores,
'RpnRoiProbs'
:
rpn_roi_probs
,
'RpnRoiProbs'
:
rpn_roi_probs
,
}
}
if
return_rois_num
:
if
return_rois_num
:
rpn_rois_num
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
rpn_rois_num
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
rpn_rois_num
.
stop_gradient
=
True
rpn_rois_num
.
stop_gradient
=
True
outputs
[
'RpnRoisNum'
]
=
rpn_rois_num
outputs
[
'RpnRoisNum'
]
=
rpn_rois_num
...
@@ -1518,10 +1546,7 @@ def generate_proposals(scores,
...
@@ -1518,10 +1546,7 @@ def generate_proposals(scores,
rpn_rois
.
stop_gradient
=
True
rpn_rois
.
stop_gradient
=
True
rpn_roi_probs
.
stop_gradient
=
True
rpn_roi_probs
.
stop_gradient
=
True
if
return_rois_num
:
return
rpn_rois
,
rpn_roi_probs
,
rpn_rois_num
return
rpn_rois
,
rpn_roi_probs
,
rpn_rois_num
else
:
return
rpn_rois
,
rpn_roi_probs
def
sigmoid_cross_entropy_with_logits
(
input
,
def
sigmoid_cross_entropy_with_logits
(
input
,
...
...
ppdet/py_op/post_process.py
浏览文件 @
a1446709
...
@@ -73,7 +73,8 @@ def bbox_post_process(bboxes,
...
@@ -73,7 +73,8 @@ def bbox_post_process(bboxes,
@
jit
@
jit
def
mask_post_process
(
bboxes
,
def
mask_post_process
(
bbox
,
bbox_nums
,
masks
,
masks
,
im_shape
,
im_shape
,
scale_factor
,
scale_factor
,
...
@@ -81,7 +82,6 @@ def mask_post_process(bboxes,
...
@@ -81,7 +82,6 @@ def mask_post_process(bboxes,
binary_thresh
=
0.5
):
binary_thresh
=
0.5
):
if
masks
.
shape
[
0
]
==
0
:
if
masks
.
shape
[
0
]
==
0
:
return
masks
return
masks
bbox
,
bbox_nums
=
bboxes
M
=
resolution
M
=
resolution
scale
=
(
M
+
2.0
)
/
M
scale
=
(
M
+
2.0
)
/
M
boxes
=
bbox
[:,
2
:]
boxes
=
bbox
[:,
2
:]
...
@@ -98,7 +98,6 @@ def mask_post_process(bboxes,
...
@@ -98,7 +98,6 @@ def mask_post_process(bboxes,
boxes_n
=
boxes
[
st_num
:
end_num
]
boxes_n
=
boxes
[
st_num
:
end_num
]
labels_n
=
labels
[
st_num
:
end_num
]
labels_n
=
labels
[
st_num
:
end_num
]
masks_n
=
masks
[
st_num
:
end_num
]
masks_n
=
masks
[
st_num
:
end_num
]
im_h
=
int
(
round
(
im_shape
[
i
][
0
]
/
scale_factor
[
i
]))
im_h
=
int
(
round
(
im_shape
[
i
][
0
]
/
scale_factor
[
i
]))
im_w
=
int
(
round
(
im_shape
[
i
][
1
]
/
scale_factor
[
i
]))
im_w
=
int
(
round
(
im_shape
[
i
][
1
]
/
scale_factor
[
i
]))
boxes_n
=
expand_bbox
(
boxes_n
,
scale
)
boxes_n
=
expand_bbox
(
boxes_n
,
scale
)
...
...
ppdet/utils/eval_utils.py
浏览文件 @
a1446709
...
@@ -5,7 +5,7 @@ from __future__ import print_function
...
@@ -5,7 +5,7 @@ from __future__ import print_function
import
os
import
os
import
sys
import
sys
import
json
import
json
from
ppdet.py_op.post_process
import
get_det_res
,
get_seg_res
from
ppdet.py_op.post_process
import
get_det_res
,
get_seg_res
,
mask_post_process
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -33,7 +33,8 @@ def json_eval_results(metric, json_directory=None, dataset=None):
...
@@ -33,7 +33,8 @@ def json_eval_results(metric, json_directory=None, dataset=None):
logger
.
info
(
"{} not exists!"
.
format
(
v_json
))
logger
.
info
(
"{} not exists!"
.
format
(
v_json
))
def
get_infer_results
(
outs_res
,
eval_type
,
catid
):
def
get_infer_results
(
outs_res
,
eval_type
,
catid
,
im_info
,
mask_resolution
=
None
):
"""
"""
Get result at the stage of inference.
Get result at the stage of inference.
The output format is dictionary containing bbox or mask result.
The output format is dictionary containing bbox or mask result.
...
@@ -49,16 +50,25 @@ def get_infer_results(outs_res, eval_type, catid):
...
@@ -49,16 +50,25 @@ def get_infer_results(outs_res, eval_type, catid):
if
'bbox'
in
eval_type
:
if
'bbox'
in
eval_type
:
box_res
=
[]
box_res
=
[]
for
outs
in
outs_res
:
for
i
,
outs
in
enumerate
(
outs_res
):
box_res
+=
get_det_res
(
outs
[
'bbox'
],
outs
[
'bbox_num'
],
im_ids
=
im_info
[
i
][
2
]
outs
[
'im_id'
],
catid
)
box_res
+=
get_det_res
(
outs
[
'bbox'
].
numpy
(),
outs
[
'bbox_num'
].
numpy
(),
im_ids
,
catid
)
infer_res
[
'bbox'
]
=
box_res
infer_res
[
'bbox'
]
=
box_res
if
'mask'
in
eval_type
:
if
'mask'
in
eval_type
:
seg_res
=
[]
seg_res
=
[]
for
outs
in
outs_res
:
# mask post process
seg_res
+=
get_seg_res
(
outs
[
'mask'
],
outs
[
'bbox_num'
],
for
i
,
outs
in
enumerate
(
outs_res
):
outs
[
'im_id'
],
catid
)
im_shape
=
im_info
[
i
][
0
]
scale_factor
=
im_info
[
i
][
1
]
im_ids
=
im_info
[
i
][
2
]
mask
=
mask_post_process
(
outs
[
'bbox'
].
numpy
(),
outs
[
'bbox_num'
].
numpy
(),
outs
[
'mask'
].
numpy
(),
im_shape
,
scale_factor
[
0
],
mask_resolution
)
seg_res
+=
get_seg_res
(
mask
,
outs
[
'bbox_num'
].
numpy
(),
im_ids
,
catid
)
infer_res
[
'mask'
]
=
seg_res
infer_res
[
'mask'
]
=
seg_res
return
infer_res
return
infer_res
...
...
tools/eval.py
浏览文件 @
a1446709
...
@@ -75,12 +75,18 @@ def run(FLAGS, cfg, place):
...
@@ -75,12 +75,18 @@ def run(FLAGS, cfg, place):
outs_res
=
[]
outs_res
=
[]
start_time
=
time
.
time
()
start_time
=
time
.
time
()
sample_num
=
0
sample_num
=
0
im_info
=
[]
for
iter_id
,
data
in
enumerate
(
eval_loader
):
for
iter_id
,
data
in
enumerate
(
eval_loader
):
# forward
# forward
fields
=
cfg
[
'EvalReader'
][
'inputs_def'
][
'fields'
]
model
.
eval
()
model
.
eval
()
outs
=
model
(
data
,
cfg
[
'EvalReader'
][
'inputs_def'
][
'fields'
],
'infer'
)
outs
=
model
(
data
=
data
,
input_def
=
fields
,
mode
=
'infer'
)
outs_res
.
append
(
outs
)
outs_res
.
append
(
outs
)
im_info
.
append
([
data
[
fields
.
index
(
'im_shape'
)].
numpy
(),
data
[
fields
.
index
(
'scale_factor'
)].
numpy
(),
data
[
fields
.
index
(
'im_id'
)].
numpy
()
])
# log
# log
sample_num
+=
len
(
data
)
sample_num
+=
len
(
data
)
if
iter_id
%
100
==
0
:
if
iter_id
%
100
==
0
:
...
@@ -102,7 +108,15 @@ def run(FLAGS, cfg, place):
...
@@ -102,7 +108,15 @@ def run(FLAGS, cfg, place):
clsid2catid
,
catid2name
=
get_category_info
(
anno_file
,
with_background
,
clsid2catid
,
catid2name
=
get_category_info
(
anno_file
,
with_background
,
use_default_label
)
use_default_label
)
infer_res
=
get_infer_results
(
outs_res
,
eval_type
,
clsid2catid
)
mask_resolution
=
None
if
cfg
[
'MaskPostProcess'
][
'mask_resolution'
]
is
not
None
:
mask_resolution
=
int
(
cfg
[
'MaskPostProcess'
][
'mask_resolution'
])
infer_res
=
get_infer_results
(
outs_res
,
eval_type
,
clsid2catid
,
im_info
,
mask_resolution
=
mask_resolution
)
eval_results
(
infer_res
,
cfg
.
metric
,
anno_file
)
eval_results
(
infer_res
,
cfg
.
metric
,
anno_file
)
...
...
tools/export_model.py
浏览文件 @
a1446709
...
@@ -53,63 +53,43 @@ def parse_args():
...
@@ -53,63 +53,43 @@ def parse_args():
return
args
return
args
def
dygraph_to_static
(
model
,
save_dir
,
cfg
):
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
inputs_def
=
cfg
[
'TestReader'
][
'inputs_def'
]
image_shape
=
inputs_def
.
get
(
'image_shape'
)
if
image_shape
is
None
:
image_shape
=
[
3
,
None
,
None
]
# Save infer cfg
dump_infer_config
(
cfg
,
os
.
path
.
join
(
save_dir
,
'infer_cfg.yml'
),
image_shape
)
input_spec
=
[{
"image"
:
InputSpec
(
shape
=
[
None
]
+
image_shape
,
name
=
'image'
),
"im_shape"
:
InputSpec
(
shape
=
[
None
,
2
],
name
=
'im_shape'
),
"scale_factor"
:
InputSpec
(
shape
=
[
None
,
2
],
name
=
'scale_factor'
)
}]
export_model
=
to_static
(
model
,
input_spec
=
input_spec
)
# save Model
paddle
.
jit
.
save
(
export_model
,
os
.
path
.
join
(
save_dir
,
'model'
))
def
run
(
FLAGS
,
cfg
):
def
run
(
FLAGS
,
cfg
):
# Model
# Model
main_arch
=
cfg
.
architecture
main_arch
=
cfg
.
architecture
model
=
create
(
cfg
.
architecture
)
model
=
create
(
cfg
.
architecture
)
inputs_def
=
cfg
[
'TestReader'
][
'inputs_def'
]
assert
'image_shape'
in
inputs_def
,
'image_shape must be specified.'
image_shape
=
inputs_def
.
get
(
'image_shape'
)
assert
not
None
in
image_shape
,
'image_shape should not contain None'
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
cfg_name
)
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
cfg_name
)
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
image_shape
=
dump_infer_config
(
cfg
,
os
.
path
.
join
(
save_dir
,
'infer_cfg.yml'
),
image_shape
)
class
ExportModel
(
nn
.
Layer
):
def
__init__
(
self
,
model
):
super
(
ExportModel
,
self
).
__init__
()
self
.
model
=
model
@
to_static
(
input_spec
=
[
{
'image'
:
InputSpec
(
shape
=
[
None
]
+
image_shape
,
name
=
'image'
)
},
{
'im_shape'
:
InputSpec
(
shape
=
[
None
,
2
],
name
=
'im_shape'
)
},
{
'scale_factor'
:
InputSpec
(
shape
=
[
None
,
2
],
name
=
'scale_factor'
)
},
])
def
forward
(
self
,
image
,
im_shape
,
scale_factor
):
inputs
=
{}
inputs_tensor
=
[
image
,
im_shape
,
scale_factor
]
for
t
in
inputs_tensor
:
inputs
.
update
(
t
)
outs
=
self
.
model
.
get_export_model
(
inputs
)
return
outs
export_model
=
ExportModel
(
model
)
# debug for dy2static, remove later
#paddle.jit.set_code_level()
# Init Model
# Init Model
load_weight
(
export_model
.
model
,
cfg
.
weights
)
load_weight
(
model
,
cfg
.
weights
)
export_model
.
eval
()
# export config and model
# export config and model
paddle
.
jit
.
save
(
export_model
,
os
.
path
.
join
(
save_dir
,
'model'
)
)
dygraph_to_static
(
model
,
save_dir
,
cfg
)
logger
.
info
(
'Export model to {}'
.
format
(
save_dir
))
logger
.
info
(
'Export model to {}'
.
format
(
save_dir
))
...
...
tools/export_utils.py
浏览文件 @
a1446709
...
@@ -109,7 +109,8 @@ def dump_infer_config(config, path, image_shape):
...
@@ -109,7 +109,8 @@ def dump_infer_config(config, path, image_shape):
os
.
_exit
(
0
)
os
.
_exit
(
0
)
if
'Mask'
in
config
[
'architecture'
]:
if
'Mask'
in
config
[
'architecture'
]:
infer_cfg
[
'mask_resolution'
]
=
config
[
'Mask'
][
'mask_resolution'
]
infer_cfg
[
'mask_resolution'
]
=
config
[
'MaskPostProcess'
][
'mask_resolution'
]
infer_cfg
[
'with_background'
],
infer_cfg
[
'Preprocess'
],
infer_cfg
[
infer_cfg
[
'with_background'
],
infer_cfg
[
'Preprocess'
],
infer_cfg
[
'label_list'
],
image_shape
=
parse_reader
(
'label_list'
],
image_shape
=
parse_reader
(
config
[
'TestReader'
],
config
[
'TestDataset'
],
config
[
'metric'
],
config
[
'TestReader'
],
config
[
'TestDataset'
],
config
[
'metric'
],
...
...
tools/infer.py
浏览文件 @
a1446709
...
@@ -147,15 +147,32 @@ def run(FLAGS, cfg, place):
...
@@ -147,15 +147,32 @@ def run(FLAGS, cfg, place):
# Run Infer
# Run Infer
for
iter_id
,
data
in
enumerate
(
test_loader
):
for
iter_id
,
data
in
enumerate
(
test_loader
):
# forward
# forward
fields
=
cfg
.
TestReader
[
'inputs_def'
][
'fields'
]
model
.
eval
()
model
.
eval
()
outs
=
model
(
data
,
cfg
.
TestReader
[
'inputs_def'
][
'fields'
],
'infer'
)
outs
=
model
(
data
=
data
,
batch_res
=
get_infer_results
([
outs
],
outs
.
keys
(),
clsid2catid
)
input_def
=
cfg
.
TestReader
[
'inputs_def'
][
'fields'
],
mode
=
'infer'
)
im_info
=
[[
data
[
fields
.
index
(
'im_shape'
)].
numpy
(),
data
[
fields
.
index
(
'scale_factor'
)].
numpy
(),
data
[
fields
.
index
(
'im_id'
)].
numpy
()
]]
im_ids
=
data
[
fields
.
index
(
'im_id'
)].
numpy
()
mask_resolution
=
None
if
cfg
[
'MaskPostProcess'
][
'mask_resolution'
]
is
not
None
:
mask_resolution
=
int
(
cfg
[
'MaskPostProcess'
][
'mask_resolution'
])
batch_res
=
get_infer_results
(
[
outs
],
outs
.
keys
(),
clsid2catid
,
im_info
,
mask_resolution
=
mask_resolution
)
logger
.
info
(
'Infer iter {}'
.
format
(
iter_id
))
logger
.
info
(
'Infer iter {}'
.
format
(
iter_id
))
bbox_res
=
None
bbox_res
=
None
mask_res
=
None
mask_res
=
None
im_ids
=
outs
[
'im_id'
]
bbox_num
=
outs
[
'bbox_num'
]
bbox_num
=
outs
[
'bbox_num'
]
start
=
0
start
=
0
for
i
,
im_id
in
enumerate
(
im_ids
):
for
i
,
im_id
in
enumerate
(
im_ids
):
...
...
tools/train.py
浏览文件 @
a1446709
...
@@ -35,6 +35,7 @@ from ppdet.utils.stats import TrainingStats
...
@@ -35,6 +35,7 @@ from ppdet.utils.stats import TrainingStats
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
,
save_model
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
,
save_model
from
export_model
import
dygraph_to_static
from
paddle.distributed
import
ParallelEnv
from
paddle.distributed
import
ParallelEnv
import
logging
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
...
@@ -149,6 +150,8 @@ def run(FLAGS, cfg, place):
...
@@ -149,6 +150,8 @@ def run(FLAGS, cfg, place):
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
fields
=
train_loader
.
collate_fn
.
output_fields
fields
=
train_loader
.
collate_fn
.
output_fields
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_dir
=
os
.
path
.
join
(
cfg
.
save_dir
,
cfg_name
)
# Run Train
# Run Train
time_stat
=
deque
(
maxlen
=
cfg
.
log_iter
)
time_stat
=
deque
(
maxlen
=
cfg
.
log_iter
)
start_time
=
time
.
time
()
start_time
=
time
.
time
()
...
@@ -167,7 +170,7 @@ def run(FLAGS, cfg, place):
...
@@ -167,7 +170,7 @@ def run(FLAGS, cfg, place):
# Model Forward
# Model Forward
model
.
train
()
model
.
train
()
outputs
=
model
(
data
,
fields
,
'train'
)
outputs
=
model
(
data
=
data
,
input_def
=
fields
,
mode
=
'train'
)
# Model Backward
# Model Backward
loss
=
outputs
[
'loss'
]
loss
=
outputs
[
'loss'
]
...
@@ -193,11 +196,12 @@ def run(FLAGS, cfg, place):
...
@@ -193,11 +196,12 @@ def run(FLAGS, cfg, place):
if
ParallelEnv
().
local_rank
==
0
and
(
if
ParallelEnv
().
local_rank
==
0
and
(
cur_eid
%
cfg
.
snapshot_epoch
==
0
or
cur_eid
%
cfg
.
snapshot_epoch
==
0
or
(
cur_eid
+
1
)
==
int
(
cfg
.
epoch
)):
(
cur_eid
+
1
)
==
int
(
cfg
.
epoch
)):
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_name
=
str
(
cur_eid
)
if
cur_eid
+
1
!=
int
(
save_name
=
str
(
cur_eid
)
if
cur_eid
+
1
!=
int
(
cfg
.
epoch
)
else
"model_final"
cfg
.
epoch
)
else
"model_final"
save_dir
=
os
.
path
.
join
(
cfg
.
save_dir
,
cfg_name
)
save_model
(
model
,
optimizer
,
save_dir
,
save_name
,
cur_eid
+
1
)
save_model
(
model
,
optimizer
,
save_dir
,
save_name
,
cur_eid
+
1
)
# TODO(guanghua): dygraph model to static model
# if ParallelEnv().local_rank == 0 and (cur_eid + 1) == int(cfg.epoch)):
# dygraph_to_static(model, os.path.join(save_dir, 'static_model_final'), cfg)
def
main
():
def
main
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录