提交 41adc109 编写于 作者: D dengkaipeng

add infer.py

上级 d9f6c64f
...@@ -119,6 +119,28 @@ python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/fin ...@@ -119,6 +119,28 @@ python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/fin
|:-:|:-:| |:-:|:-:|
|76%|98%| |76%|98%|
### 模型推断
可通过如下两种方式进行模型推断。
1. 自动下载Paddle发布的[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams)权重推断
```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle>
```
2. 加载checkpoint进行精度推断
```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle> --weights=tsm_checkpoint/final
```
模型推断结果会以如下日志形式输出
```text
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```
## 参考论文 ## 参考论文
- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han - [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from model import Input, set_device
from check import check_gpu, check_version
from modeling import tsm_resnet50
from kinetics_dataset import KineticsDataset
from transforms import *
import logging
logger = logging.getLogger(__name__)
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
transform = Compose([GroupScale(),
GroupCenterCrop(),
NormalizeImage()])
dataset = KineticsDataset(
pickle_file=FLAGS.infer_file,
label_list=FLAGS.label_list,
mode='test',
transform=transform)
labels = dataset.label_list
model = tsm_resnet50(num_classes=len(labels),
pretrained=FLAGS.weights is None)
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
model.prepare(inputs=inputs, device=FLAGS.device)
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
imgs, label = dataset[0]
pred = model.test([imgs[np.newaxis, :]])
pred = labels[np.argmax(pred)]
logger.info("Sample {} predict label: {}, ground truth label: {}" \
.format(FLAGS.infer_file, pred, labels[int(label)]))
if __name__ == '__main__':
parser = argparse.ArgumentParser("CNN training on TSM")
parser.add_argument(
"--data", type=str, default='dataset/kinetics',
help="path to dataset root directory")
parser.add_argument(
"--device", type=str, default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true',
help="enable dygraph mode")
parser.add_argument(
"--label_list", type=str, default=None,
help="path to category index label list file")
parser.add_argument(
"--infer_file", type=str, default=None,
help="path to pickle file for inference")
parser.add_argument(
"-w",
"--weights",
default=None,
type=str,
help="weights path for evaluation")
FLAGS = parser.parse_args()
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
...@@ -56,13 +56,19 @@ class KineticsDataset(Dataset): ...@@ -56,13 +56,19 @@ class KineticsDataset(Dataset):
""" """
def __init__(self, def __init__(self,
file_list, file_list=None,
pickle_dir, pickle_dir=None,
pickle_file=None,
label_list=None, label_list=None,
mode='train', mode='train',
seg_num=8, seg_num=8,
seg_len=1, seg_len=1,
transform=None): transform=None):
assert str.lower(mode) in ['train', 'val', 'test'], \
"mode can only be 'train' 'val' or 'test'"
self.mode = str.lower(mode)
if self.mode in ['train', 'val']:
assert os.path.isfile(file_list), \ assert os.path.isfile(file_list), \
"file_list {} not a file".format(file_list) "file_list {} not a file".format(file_list)
with open(file_list) as f: with open(file_list) as f:
...@@ -71,6 +77,11 @@ class KineticsDataset(Dataset): ...@@ -71,6 +77,11 @@ class KineticsDataset(Dataset):
assert os.path.isdir(pickle_dir), \ assert os.path.isdir(pickle_dir), \
"pickle_dir {} not a directory".format(pickle_dir) "pickle_dir {} not a directory".format(pickle_dir)
self.pickle_dir = pickle_dir self.pickle_dir = pickle_dir
else:
assert os.path.isfile(pickle_file), \
"pickle_file {} not a file".format(pickle_file)
self.pickle_dir = ''
self.pickle_paths = [pickle_file]
self.label_list = label_list self.label_list = label_list
if self.label_list is not None: if self.label_list is not None:
...@@ -79,10 +90,6 @@ class KineticsDataset(Dataset): ...@@ -79,10 +90,6 @@ class KineticsDataset(Dataset):
with open(self.label_list) as f: with open(self.label_list) as f:
self.label_list = [int(l.strip()) for l in f] self.label_list = [int(l.strip()) for l in f]
assert mode in ['train', 'val'], \
"mode can only be 'train' or 'val'"
self.mode = mode
self.seg_num = seg_num self.seg_num = seg_num
self.seg_len = seg_len self.seg_len = seg_len
self.transform = transform self.transform = transform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册