Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
2b426605
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2b426605
编写于
4月 26, 2020
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add test_metrics
上级
15aef487
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
254 addition
and
15 deletion
+254
-15
examples/tsm/README.md
examples/tsm/README.md
+4
-2
examples/tsm/infer.py
examples/tsm/infer.py
+3
-1
examples/tsm/main.py
examples/tsm/main.py
+9
-1
examples/tsm/utils.py
examples/tsm/utils.py
+44
-0
examples/yolov3/README.md
examples/yolov3/README.md
+5
-5
examples/yolov3/infer.py
examples/yolov3/infer.py
+3
-2
examples/yolov3/main.py
examples/yolov3/main.py
+12
-2
examples/yolov3/utils.py
examples/yolov3/utils.py
+44
-0
hapi/metrics.py
hapi/metrics.py
+2
-2
tests/test_metrics.py
tests/test_metrics.py
+128
-0
未找到文件。
examples/tsm/README.md
浏览文件 @
2b426605
...
...
@@ -39,8 +39,8 @@ TSM模型是将Temporal Shift Module插入到ResNet网络中构建的视频分
```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=
$PYTHONPATH:`pwd`
cd tsm
export PYTHONPATH=
`pwd`:$PYTHONPATH
cd
examples/
tsm
```
### 数据准备
...
...
@@ -141,6 +141,8 @@ python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --inf
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```
**注意:**
推断时
`--infer_file`
需要指定到pickle文件路径。
## 参考论文
-
[
Temporal Shift Module for Efficient Video Understanding
](
https://arxiv.org/abs/1811.08383v1
)
, Ji Lin, Chuang Gan, Song Han
...
...
examples/tsm/infer.py
浏览文件 @
2b426605
...
...
@@ -25,6 +25,7 @@ from check import check_gpu, check_version
from
modeling
import
tsm_resnet50
from
kinetics_dataset
import
KineticsDataset
from
transforms
import
*
from
utils
import
print_arguments
import
logging
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -55,7 +56,7 @@ def main():
model
.
load
(
FLAGS
.
weights
,
reset_optimizer
=
True
)
imgs
,
label
=
dataset
[
0
]
pred
=
model
.
test
([
imgs
[
np
.
newaxis
,
:]])
pred
=
model
.
test
_batch
([
imgs
[
np
.
newaxis
,
:]])
pred
=
labels
[
np
.
argmax
(
pred
)]
logger
.
info
(
"Sample {} predict label: {}, ground truth label: {}"
\
.
format
(
FLAGS
.
infer_file
,
pred
,
labels
[
int
(
label
)]))
...
...
@@ -85,6 +86,7 @@ if __name__ == '__main__':
type
=
str
,
help
=
"weights path for evaluation"
)
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
check_gpu
(
str
.
lower
(
FLAGS
.
device
)
==
'gpu'
)
check_version
()
...
...
examples/tsm/main.py
浏览文件 @
2b426605
...
...
@@ -29,6 +29,7 @@ from modeling import tsm_resnet50
from
check
import
check_gpu
,
check_version
from
kinetics_dataset
import
KineticsDataset
from
transforms
import
*
from
utils
import
print_arguments
def
make_optimizer
(
step_per_epoch
,
parameter_list
=
None
):
...
...
@@ -107,7 +108,7 @@ def main():
eval_data
=
val_dataset
,
epochs
=
FLAGS
.
epoch
,
batch_size
=
FLAGS
.
batch_size
,
save_dir
=
'tsm_checkpoint'
,
save_dir
=
FLAGS
.
save_dir
or
'tsm_checkpoint'
,
num_workers
=
FLAGS
.
num_workers
,
drop_last
=
True
,
shuffle
=
True
)
...
...
@@ -149,7 +150,14 @@ if __name__ == '__main__':
default
=
None
,
type
=
str
,
help
=
"weights path for evaluation"
)
parser
.
add_argument
(
"-s"
,
"--save_dir"
,
default
=
None
,
type
=
str
,
help
=
"directory path for checkpoint saving, default ./yolo_checkpoint"
)
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
check_gpu
(
str
.
lower
(
FLAGS
.
device
)
==
'gpu'
)
check_version
()
...
...
examples/tsm/utils.py
0 → 100644
浏览文件 @
2b426605
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
six
import
logging
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'print_ar'
]
def
print_arguments
(
args
):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
logger
.
info
(
"----------- Configuration Arguments -----------"
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
logger
.
info
(
"%s: %s"
%
(
arg
,
value
))
logger
.
info
(
"------------------------------------------------"
)
examples/yolov3/README.md
浏览文件 @
2b426605
...
...
@@ -53,8 +53,8 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=
$PYTHONPATH:`pwd`
cd
tsm
export PYTHONPATH=
`pwd`:$PYTHONPATH
cd
examples/yolov3
```
#### 安装COCO-API
...
...
@@ -126,13 +126,13 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=
使用如下方式进行多卡训练:
```
bash
CUDA_VISIBLE_DEVICES
=
0,1,2,3 python
main.py
-m
paddle.distributed.launch
--data
=
<path/to/dataset>
--batch_size
=
16
-d
CUDA_VISIBLE_DEVICES
=
0,1,2,3 python
-m
paddle.distributed.launch main.py
--data
=
<path/to/dataset>
--batch_size
=
16
-d
```
### 模型评估
YOLOv3模型输出为LoDTensor,只支持使用batch_size为1进行评估,可通过如下两种方式进行模型评估。
YOLOv3模型输出为LoDTensor,只支持使用
单卡且
batch_size为1进行评估,可通过如下两种方式进行模型评估。
1.
自动下载Paddle发布的
[
YOLOv3-DarkNet53
](
https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams
)
权重评估
...
...
@@ -180,7 +180,7 @@ python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.
2.
加载checkpoint进行精度评估
```
bash
python infer.py
--label_list
=
dataset/voc/label_list.txt
--infer_image
=
image/dog.jpg
--weights
=
yolo_checkpoint/
m
o_mixup/final
python infer.py
--label_list
=
dataset/voc/label_list.txt
--infer_image
=
image/dog.jpg
--weights
=
yolo_checkpoint/
n
o_mixup/final
```
推断结果可视化图像会保存于
`--output`
指定的文件夹下,默认保存于
`./output`
目录。
...
...
examples/yolov3/infer.py
浏览文件 @
2b426605
...
...
@@ -28,7 +28,7 @@ from hapi.model import Model, Input, set_device
from
modeling
import
yolov3_darknet53
,
YoloLoss
from
transforms
import
*
from
utils
import
print_arguments
from
visualizer
import
draw_bbox
import
logging
...
...
@@ -91,7 +91,7 @@ def main():
img_id
=
np
.
array
([
0
]).
astype
(
'int64'
)[
np
.
newaxis
,
:]
img_shape
=
np
.
array
([
h
,
w
]).
astype
(
'int32'
)[
np
.
newaxis
,
:]
_
,
bboxes
=
model
.
test
([
img_id
,
img_shape
,
img
])
_
,
bboxes
=
model
.
test
_batch
([
img_id
,
img_shape
,
img
])
vis_img
=
draw_bbox
(
orig_img
,
cat2name
,
bboxes
,
FLAGS
.
draw_threshold
)
save_name
=
get_save_image_name
(
FLAGS
.
output_dir
,
FLAGS
.
infer_image
)
...
...
@@ -121,6 +121,7 @@ if __name__ == '__main__':
"-w"
,
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"path to weights for inference"
)
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
assert
os
.
path
.
isfile
(
FLAGS
.
infer_image
),
\
"infer_image {} not a file"
.
format
(
FLAGS
.
infer_image
)
assert
os
.
path
.
isfile
(
FLAGS
.
label_list
),
\
...
...
examples/yolov3/main.py
浏览文件 @
2b426605
...
...
@@ -33,6 +33,7 @@ from modeling import yolov3_darknet53, YoloLoss
from
coco
import
COCODataset
from
coco_metric
import
COCOMetric
from
transforms
import
*
from
utils
import
print_arguments
NUM_MAX_BOXES
=
50
...
...
@@ -171,16 +172,18 @@ def main():
if
FLAGS
.
resume
is
not
None
:
model
.
load
(
FLAGS
.
resume
)
save_dir
=
FLAGS
.
save_dir
or
'yolo_checkpoint'
model
.
fit
(
train_data
=
loader
,
epochs
=
FLAGS
.
epoch
-
FLAGS
.
no_mixup_epoch
,
save_dir
=
"yolo_checkpoint/mixup"
,
save_dir
=
os
.
path
.
join
(
save_dir
,
"mixup"
)
,
save_freq
=
10
)
# do not use image mixup transfrom in the last FLAGS.no_mixup_epoch epoches
dataset
.
mixup
=
False
model
.
fit
(
train_data
=
loader
,
epochs
=
FLAGS
.
no_mixup_epoch
,
save_dir
=
"yolo_checkpoint/no_mixup"
,
save_dir
=
os
.
path
.
join
(
save_dir
,
"no_mixup"
)
,
save_freq
=
5
)
...
...
@@ -233,6 +236,13 @@ if __name__ == '__main__':
default
=
None
,
type
=
str
,
help
=
"path to weights for evaluation"
)
parser
.
add_argument
(
"-s"
,
"--save_dir"
,
default
=
None
,
type
=
str
,
help
=
"directory path for checkpoint saving, default ./yolo_checkpoint"
)
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
assert
FLAGS
.
data
,
"error: must provide data path"
main
()
examples/yolov3/utils.py
0 → 100644
浏览文件 @
2b426605
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
six
import
logging
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'print_ar'
]
def
print_arguments
(
args
):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
logger
.
info
(
"----------- Configuration Arguments -----------"
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
logger
.
info
(
"%s: %s"
%
(
arg
,
value
))
logger
.
info
(
"------------------------------------------------"
)
hapi/metrics.py
浏览文件 @
2b426605
...
...
@@ -116,7 +116,7 @@ class Accuracy(Metric):
def
add_metric_op
(
self
,
pred
,
label
,
*
args
):
pred
=
fluid
.
layers
.
argsort
(
pred
,
descending
=
True
)[
1
][:,
:
self
.
maxk
]
correct
=
pred
==
label
return
correct
return
fluid
.
layers
.
cast
(
correct
,
dtype
=
'float32'
)
def
update
(
self
,
correct
,
*
args
):
accs
=
[]
...
...
@@ -143,7 +143,7 @@ class Accuracy(Metric):
if
self
.
maxk
!=
1
:
self
.
_name
=
[
'{}_top{}'
.
format
(
name
,
k
)
for
k
in
self
.
topk
]
else
:
self
.
_name
=
[
'acc'
]
self
.
_name
=
[
name
]
def
name
(
self
):
return
self
.
_name
tests/test_metrics.py
0 → 100644
浏览文件 @
2b426605
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
os
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.base
import
to_variable
from
hapi.metrics
import
*
from
hapi.model
import
to_list
def
accuracy
(
pred
,
label
,
topk
=
(
1
,
)):
maxk
=
max
(
topk
)
pred
=
np
.
argsort
(
pred
)[:,
::
-
1
][:,
:
maxk
]
correct
=
(
pred
==
np
.
repeat
(
label
,
maxk
,
1
))
batch_size
=
label
.
shape
[
0
]
res
=
[]
for
k
in
topk
:
correct_k
=
correct
[:,
:
k
].
sum
()
res
.
append
(
correct_k
/
batch_size
)
return
res
def
convert_to_one_hot
(
y
,
C
):
oh
=
np
.
random
.
random
((
y
.
shape
[
0
],
C
)).
astype
(
'float32'
)
*
.
5
for
i
in
range
(
y
.
shape
[
0
]):
oh
[
i
,
int
(
y
[
i
])]
=
1.
return
oh
class
TestAccuracyDynamic
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
topk
=
(
1
,
)
self
.
class_num
=
5
self
.
sample_num
=
1000
self
.
name
=
None
def
random_pred_label
(
self
):
label
=
np
.
random
.
randint
(
0
,
self
.
class_num
,
(
self
.
sample_num
,
1
)).
astype
(
'int64'
)
pred
=
np
.
random
.
randint
(
0
,
self
.
class_num
,
(
self
.
sample_num
,
1
)).
astype
(
'int32'
)
pred_one_hot
=
convert_to_one_hot
(
pred
,
self
.
class_num
)
pred_one_hot
=
pred_one_hot
.
astype
(
'float32'
)
return
label
,
pred_one_hot
def
test_main
(
self
):
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
acc
=
Accuracy
(
topk
=
self
.
topk
,
name
=
self
.
name
)
for
i
in
range
(
10
):
label
,
pred
=
self
.
random_pred_label
()
label_var
=
to_variable
(
label
)
pred_var
=
to_variable
(
pred
)
state
=
to_list
(
acc
.
add_metric_op
(
pred_var
,
label_var
))
acc
.
update
(
*
[
s
.
numpy
()
for
s
in
state
])
res_m
=
acc
.
accumulate
()
res_f
=
accuracy
(
pred
,
label
,
self
.
topk
)
assert
np
.
all
(
np
.
isclose
(
np
.
array
(
res_m
),
np
.
array
(
res_f
),
rtol
=
1e-3
)),
\
"Accuracy precision error: {} != {}"
.
format
(
res_m
,
res_f
)
acc
.
reset
()
assert
np
.
sum
(
acc
.
total
)
==
0
assert
np
.
sum
(
acc
.
count
)
==
0
class
TestAccuracyDynamicMultiTopk
(
TestAccuracyDynamic
):
def
setUp
(
self
):
self
.
topk
=
(
1
,
5
)
self
.
class_num
=
10
self
.
sample_num
=
1000
self
.
name
=
"accuracy"
class
TestAccuracyStatic
(
TestAccuracyDynamic
):
def
test_main
(
self
):
main_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
pred
=
fluid
.
data
(
name
=
'pred'
,
shape
=
[
None
,
self
.
class_num
],
dtype
=
'float32'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
acc
=
Accuracy
(
topk
=
self
.
topk
,
name
=
self
.
name
)
state
=
acc
.
add_metric_op
(
pred
,
label
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
compiled_main_prog
=
fluid
.
CompiledProgram
(
main_prog
)
for
i
in
range
(
10
):
label
,
pred
=
self
.
random_pred_label
()
state_ret
=
exe
.
run
(
compiled_main_prog
,
feed
=
{
'pred'
:
pred
,
'label'
:
label
},
fetch_list
=
[
s
.
name
for
s
in
to_list
(
state
)],
return_numpy
=
True
)
acc
.
update
(
*
state_ret
)
res_m
=
acc
.
accumulate
()
res_f
=
accuracy
(
pred
,
label
,
self
.
topk
)
assert
np
.
all
(
np
.
isclose
(
np
.
array
(
res_m
),
np
.
array
(
res_f
),
rtol
=
1e-3
)),
\
"Accuracy precision error: {} != {}"
.
format
(
res_m
,
res_f
)
acc
.
reset
()
assert
np
.
sum
(
acc
.
total
)
==
0
assert
np
.
sum
(
acc
.
count
)
==
0
class
TestAccuracyStaticMultiTopk
(
TestAccuracyStatic
):
def
setUp
(
self
):
self
.
topk
=
(
1
,
5
)
self
.
class_num
=
10
self
.
sample_num
=
1000
self
.
name
=
"accuracy"
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录