Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
41adc109
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
41adc109
编写于
4月 03, 2020
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add infer.py
上级
d9f6c64f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
134 addition
and
14 deletion
+134
-14
tsm/README.md
tsm/README.md
+22
-0
tsm/infer.py
tsm/infer.py
+91
-0
tsm/kinetics_dataset.py
tsm/kinetics_dataset.py
+21
-14
未找到文件。
tsm/README.md
浏览文件 @
41adc109
...
@@ -119,6 +119,28 @@ python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/fin
...
@@ -119,6 +119,28 @@ python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/fin
|:-:|:-:|
|:-:|:-:|
|76%|98%|
|76%|98%|
### 模型推断
可通过如下两种方式进行模型推断。
1.
自动下载Paddle发布的
[
TSM-ResNet50
](
https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams
)
权重推断
```
bash
python infer.py
--data
=
<path/to/dataset>
--label_list
=
<path/to/label_list>
--infer_file
=
<path/to/pickle>
```
2.
加载checkpoint进行精度推断
```
bash
python infer.py
--data
=
<path/to/dataset>
--label_list
=
<path/to/label_list>
--infer_file
=
<path/to/pickle>
--weights
=
tsm_checkpoint/final
```
模型推断结果会以如下日志形式输出
```
text
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```
## 参考论文
## 参考论文
-
[
Temporal Shift Module for Efficient Video Understanding
](
https://arxiv.org/abs/1811.08383v1
)
, Ji Lin, Chuang Gan, Song Han
-
[
Temporal Shift Module for Efficient Video Understanding
](
https://arxiv.org/abs/1811.08383v1
)
, Ji Lin, Chuang Gan, Song Han
...
...
tsm/infer.py
0 → 100644
浏览文件 @
41adc109
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
os
import
argparse
import
numpy
as
np
from
model
import
Input
,
set_device
from
check
import
check_gpu
,
check_version
from
modeling
import
tsm_resnet50
from
kinetics_dataset
import
KineticsDataset
from
transforms
import
*
import
logging
logger
=
logging
.
getLogger
(
__name__
)
def
main
():
device
=
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
device
)
if
FLAGS
.
dynamic
else
None
transform
=
Compose
([
GroupScale
(),
GroupCenterCrop
(),
NormalizeImage
()])
dataset
=
KineticsDataset
(
pickle_file
=
FLAGS
.
infer_file
,
label_list
=
FLAGS
.
label_list
,
mode
=
'test'
,
transform
=
transform
)
labels
=
dataset
.
label_list
model
=
tsm_resnet50
(
num_classes
=
len
(
labels
),
pretrained
=
FLAGS
.
weights
is
None
)
inputs
=
[
Input
([
None
,
8
,
3
,
224
,
224
],
'float32'
,
name
=
'image'
)]
model
.
prepare
(
inputs
=
inputs
,
device
=
FLAGS
.
device
)
if
FLAGS
.
weights
is
not
None
:
model
.
load
(
FLAGS
.
weights
,
reset_optimizer
=
True
)
imgs
,
label
=
dataset
[
0
]
pred
=
model
.
test
([
imgs
[
np
.
newaxis
,
:]])
pred
=
labels
[
np
.
argmax
(
pred
)]
logger
.
info
(
"Sample {} predict label: {}, ground truth label: {}"
\
.
format
(
FLAGS
.
infer_file
,
pred
,
labels
[
int
(
label
)]))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
"CNN training on TSM"
)
parser
.
add_argument
(
"--data"
,
type
=
str
,
default
=
'dataset/kinetics'
,
help
=
"path to dataset root directory"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
help
=
"device to use, gpu or cpu"
)
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
parser
.
add_argument
(
"--label_list"
,
type
=
str
,
default
=
None
,
help
=
"path to category index label list file"
)
parser
.
add_argument
(
"--infer_file"
,
type
=
str
,
default
=
None
,
help
=
"path to pickle file for inference"
)
parser
.
add_argument
(
"-w"
,
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"weights path for evaluation"
)
FLAGS
=
parser
.
parse_args
()
check_gpu
(
str
.
lower
(
FLAGS
.
device
)
==
'gpu'
)
check_version
()
main
()
tsm/kinetics_dataset.py
浏览文件 @
41adc109
...
@@ -56,13 +56,19 @@ class KineticsDataset(Dataset):
...
@@ -56,13 +56,19 @@ class KineticsDataset(Dataset):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
file_list
,
file_list
=
None
,
pickle_dir
,
pickle_dir
=
None
,
pickle_file
=
None
,
label_list
=
None
,
label_list
=
None
,
mode
=
'train'
,
mode
=
'train'
,
seg_num
=
8
,
seg_num
=
8
,
seg_len
=
1
,
seg_len
=
1
,
transform
=
None
):
transform
=
None
):
assert
str
.
lower
(
mode
)
in
[
'train'
,
'val'
,
'test'
],
\
"mode can only be 'train' 'val' or 'test'"
self
.
mode
=
str
.
lower
(
mode
)
if
self
.
mode
in
[
'train'
,
'val'
]:
assert
os
.
path
.
isfile
(
file_list
),
\
assert
os
.
path
.
isfile
(
file_list
),
\
"file_list {} not a file"
.
format
(
file_list
)
"file_list {} not a file"
.
format
(
file_list
)
with
open
(
file_list
)
as
f
:
with
open
(
file_list
)
as
f
:
...
@@ -71,6 +77,11 @@ class KineticsDataset(Dataset):
...
@@ -71,6 +77,11 @@ class KineticsDataset(Dataset):
assert
os
.
path
.
isdir
(
pickle_dir
),
\
assert
os
.
path
.
isdir
(
pickle_dir
),
\
"pickle_dir {} not a directory"
.
format
(
pickle_dir
)
"pickle_dir {} not a directory"
.
format
(
pickle_dir
)
self
.
pickle_dir
=
pickle_dir
self
.
pickle_dir
=
pickle_dir
else
:
assert
os
.
path
.
isfile
(
pickle_file
),
\
"pickle_file {} not a file"
.
format
(
pickle_file
)
self
.
pickle_dir
=
''
self
.
pickle_paths
=
[
pickle_file
]
self
.
label_list
=
label_list
self
.
label_list
=
label_list
if
self
.
label_list
is
not
None
:
if
self
.
label_list
is
not
None
:
...
@@ -79,10 +90,6 @@ class KineticsDataset(Dataset):
...
@@ -79,10 +90,6 @@ class KineticsDataset(Dataset):
with
open
(
self
.
label_list
)
as
f
:
with
open
(
self
.
label_list
)
as
f
:
self
.
label_list
=
[
int
(
l
.
strip
())
for
l
in
f
]
self
.
label_list
=
[
int
(
l
.
strip
())
for
l
in
f
]
assert
mode
in
[
'train'
,
'val'
],
\
"mode can only be 'train' or 'val'"
self
.
mode
=
mode
self
.
seg_num
=
seg_num
self
.
seg_num
=
seg_num
self
.
seg_len
=
seg_len
self
.
seg_len
=
seg_len
self
.
transform
=
transform
self
.
transform
=
transform
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录