Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
5ec5f453
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5ec5f453
编写于
3月 03, 2020
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fit for dygraph
上级
074a08e5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
52 addition
and
53 deletion
+52
-53
metrics/coco.py
metrics/coco.py
+33
-33
metrics/metric.py
metrics/metric.py
+7
-5
model.py
model.py
+7
-11
yolov3.py
yolov3.py
+5
-4
未找到文件。
metrics/coco.py
浏览文件 @
5ec5f453
...
...
@@ -16,6 +16,8 @@ from __future__ import absolute_import
import
sys
import
json
from
pycocotools.cocoeval
import
COCOeval
from
pycocotools.coco
import
COCO
from
.metric
import
Metric
...
...
@@ -32,22 +34,21 @@ OUTFILE = './bbox.json'
class
COCOMetric
(
Metric
):
"""
Base class for metric, encapsulates metric logic and APIs
Metrci for MS-COCO dataset, only support update with batch
size as 1.
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
Args:
anno_path(str): path to COCO annotation json file
with_background(bool): whether load category id with
background as 0, default True
"""
def
__init__
(
self
,
anno_path
,
with_background
=
True
,
**
kwargs
):
super
(
COCOMetric
,
self
).
__init__
(
**
kwargs
)
self
.
states
[
'bbox'
]
=
[]
self
.
anno_path
=
anno_path
self
.
with_background
=
with_background
self
.
bbox_results
=
[]
from
pycocotools.coco
import
COCO
self
.
coco_gt
=
COCO
(
anno_path
)
cat_ids
=
self
.
coco_gt
.
getCatIds
()
self
.
clsid2catid
=
dict
(
...
...
@@ -56,39 +57,40 @@ class COCOMetric(Metric):
def
update
(
self
,
preds
,
*
args
,
**
kwargs
):
im_ids
,
bboxes
=
preds
if
bboxes
[
0
].
shape
[
1
]
!=
6
:
assert
im_ids
.
shape
[
0
]
==
1
,
\
"COCOMetric can only update with batch size = 1"
if
bboxes
.
shape
[
1
]
!=
6
:
# no bbox detected in this batch
return
idx
=
0
bboxes
,
lods
=
bboxes
for
i
,
(
im_id
,
lod
)
in
enumerate
(
zip
(
im_ids
,
lods
[
0
])):
im_id
=
int
(
im_id
)
for
i
in
range
(
lod
):
dt
=
bboxes
[
idx
]
clsid
,
score
,
xmin
,
ymin
,
xmax
,
ymax
=
dt
.
tolist
()
catid
=
(
self
.
clsid2catid
[
int
(
clsid
)])
w
=
xmax
-
xmin
+
1
h
=
ymax
-
ymin
+
1
bbox
=
[
xmin
,
ymin
,
w
,
h
]
coco_res
=
{
'image_id'
:
im_id
,
'category_id'
:
catid
,
'bbox'
:
bbox
,
'score'
:
score
}
self
.
states
[
'bbox'
].
append
(
coco_res
)
idx
+=
1
im_id
=
int
(
im_ids
)
for
i
in
range
(
bboxes
.
shape
[
0
]):
dt
=
bboxes
[
i
,
:]
clsid
,
score
,
xmin
,
ymin
,
xmax
,
ymax
=
dt
.
tolist
()
catid
=
(
self
.
clsid2catid
[
int
(
clsid
)])
w
=
xmax
-
xmin
+
1
h
=
ymax
-
ymin
+
1
bbox
=
[
xmin
,
ymin
,
w
,
h
]
coco_res
=
{
'image_id'
:
im_id
,
'category_id'
:
catid
,
'bbox'
:
bbox
,
'score'
:
score
}
self
.
bbox_results
.
append
(
coco_res
)
def
reset
(
self
):
self
.
bbox_results
=
[]
def
accumulate
(
self
):
if
len
(
self
.
states
[
'bbox'
]
)
==
0
:
if
len
(
self
.
bbox_results
)
==
0
:
logger
.
warning
(
"The number of valid bbox detected is zero.
\n
\
Please use reasonable model and check input data.
\n
\
stop COCOMetric accumulate!"
)
return
[
0.0
]
with
open
(
OUTFILE
,
'w'
)
as
f
:
json
.
dump
(
self
.
states
[
'bbox'
]
,
f
)
json
.
dump
(
self
.
bbox_results
,
f
)
map_stats
=
self
.
cocoapi_eval
(
OUTFILE
,
'bbox'
,
coco_gt
=
self
.
coco_gt
)
# flush coco evaluation result
...
...
@@ -98,10 +100,8 @@ class COCOMetric(Metric):
def
cocoapi_eval
(
self
,
jsonfile
,
style
,
coco_gt
=
None
,
anno_file
=
None
):
assert
coco_gt
!=
None
or
anno_file
!=
None
from
pycocotools.cocoeval
import
COCOeval
if
coco_gt
==
None
:
from
pycocotools.coco
import
COCO
coco_gt
=
COCO
(
anno_file
)
logger
.
info
(
"Start evaluate..."
)
coco_dt
=
coco_gt
.
loadRes
(
jsonfile
)
...
...
metrics/metric.py
浏览文件 @
5ec5f453
...
...
@@ -17,6 +17,11 @@ from __future__ import absolute_import
import
six
import
abc
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'Metric'
]
...
...
@@ -32,15 +37,12 @@ class Metric(object):
m.accumulate()
"""
def
__init__
(
self
,
**
kwargs
):
self
.
reset
()
@
abc
.
abstractmethod
def
reset
(
self
):
"""
Reset states and result
"""
self
.
states
=
{}
self
.
result
=
None
raise
NotImplementedError
(
"function 'reset' not implemented in {}."
.
format
(
self
.
__class__
.
__name__
))
@
abc
.
abstractmethod
def
update
(
self
,
*
args
,
**
kwargs
):
...
...
model.py
浏览文件 @
5ec5f453
...
...
@@ -285,17 +285,11 @@ class StaticGraphAdapter(object):
compiled_prog
,
feed
=
feed
,
fetch_list
=
fetch_list
,
return_numpy
=
False
)
# rets = [(np.array(v), v.recursive_sequence_lengths()) if v.lod() for v in rets]
np_rets
=
[]
for
ret
in
rets
:
seq_len
=
ret
.
recursive_sequence_lengths
()
if
len
(
seq_len
)
==
0
:
np_rets
.
append
(
np
.
array
(
ret
))
else
:
np_rets
.
append
((
np
.
array
(
ret
),
seq_len
))
outputs
=
np_rets
[:
num_output
]
labels
=
np_rets
[
num_output
:
num_output
+
num_label
]
losses
=
np_rets
[
num_output
+
num_label
:]
# LoDTensor cannot be fetch as numpy directly
rets
=
[
np
.
array
(
v
)
for
v
in
rets
]
outputs
=
rets
[:
num_output
]
labels
=
rets
[
num_output
:
num_output
+
num_label
]
losses
=
rets
[
num_output
+
num_label
:]
if
self
.
mode
==
'test'
:
return
outputs
elif
self
.
mode
==
'eval'
:
...
...
@@ -443,6 +437,8 @@ class DynamicGraphAdapter(object):
labels
=
to_list
(
labels
)
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
losses
=
self
.
model
.
_loss_function
(
outputs
[
0
],
labels
)
for
metric
in
self
.
model
.
_metrics
:
metric
.
update
([
to_numpy
(
o
)
for
o
in
outputs
[
1
:]],
labels
)
return
[
to_numpy
(
o
)
for
o
in
to_list
(
outputs
[
0
])],
\
[
to_numpy
(
l
)
for
l
in
losses
]
...
...
yolov3.py
浏览文件 @
5ec5f453
...
...
@@ -515,7 +515,7 @@ def main():
coco2017
(
FLAGS
.
data
,
'val'
),
process_num
=
8
,
buffer_size
=
4
*
batch_size
),
batch_size
=
batch_size
),
batch_size
=
1
),
process_num
=
2
,
buffer_size
=
4
)
if
not
os
.
path
.
exists
(
'yolo_checkpoints'
):
...
...
@@ -536,14 +536,15 @@ def main():
metrics
=
COCOMetric
(
anno_path
,
with_background
=
False
))
for
e
in
range
(
epoch
):
logger
.
info
(
"======== train epoch {} ========"
.
format
(
e
))
run
(
model
,
train_loader
)
model
.
save
(
'yolo_checkpoints/{:02d}'
.
format
(
e
))
#
logger.info("======== train epoch {} ========".format(e))
#
run(model, train_loader)
#
model.save('yolo_checkpoints/{:02d}'.format(e))
logger
.
info
(
"======== eval epoch {} ========"
.
format
(
e
))
run
(
model
,
val_loader
,
mode
=
'eval'
)
# should be called in fit()
for
metric
in
model
.
_metrics
:
metric
.
accumulate
()
metric
.
reset
()
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录