diff --git a/tsm/README.md b/tsm/README.md index 5d4e2b708743f5aa4b451c646c67a70d6f6c408c..8e6138942c76ec577975c5cdfb9d89e578e7d17b 100644 --- a/tsm/README.md +++ b/tsm/README.md @@ -119,6 +119,28 @@ python main.py --data= --eval_only --weights=tsm_checkpoint/fin |:-:|:-:| |76%|98%| +### 模型推断 + +可通过如下两种方式进行模型推断。 + +1. 自动下载Paddle发布的[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams)权重推断 + +```bash +python infer.py --data= --label_list= --infer_file= +``` + +2. 加载checkpoint进行精度推断 + +```bash +python infer.py --data= --label_list= --infer_file= --weights=tsm_checkpoint/final +``` + +模型推断结果会以如下日志形式输出 + +```text +2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6 +``` + ## 参考论文 - [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han diff --git a/tsm/infer.py b/tsm/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7995dfa62b2eccde24a0e5110cd3731f4c82e12 --- /dev/null +++ b/tsm/infer.py @@ -0,0 +1,91 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import os +import argparse +import numpy as np + +from model import Input, set_device + +from check import check_gpu, check_version +from modeling import tsm_resnet50 +from kinetics_dataset import KineticsDataset +from transforms import * + +import logging +logger = logging.getLogger(__name__) + + +def main(): + device = set_device(FLAGS.device) + fluid.enable_dygraph(device) if FLAGS.dynamic else None + + transform = Compose([GroupScale(), + GroupCenterCrop(), + NormalizeImage()]) + dataset = KineticsDataset( + pickle_file=FLAGS.infer_file, + label_list=FLAGS.label_list, + mode='test', + transform=transform) + labels = dataset.label_list + + model = tsm_resnet50(num_classes=len(labels), + pretrained=FLAGS.weights is None) + + inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')] + + model.prepare(inputs=inputs, device=FLAGS.device) + + if FLAGS.weights is not None: + model.load(FLAGS.weights, reset_optimizer=True) + + imgs, label = dataset[0] + pred = model.test([imgs[np.newaxis, :]]) + pred = labels[np.argmax(pred)] + logger.info("Sample {} predict label: {}, ground truth label: {}" \ + .format(FLAGS.infer_file, pred, labels[int(label)])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("CNN training on TSM") + parser.add_argument( + "--data", type=str, default='dataset/kinetics', + help="path to dataset root directory") + parser.add_argument( + "--device", type=str, default='gpu', + help="device to use, gpu or cpu") + parser.add_argument( + "-d", "--dynamic", action='store_true', + help="enable dygraph mode") + parser.add_argument( + "--label_list", type=str, default=None, + help="path to category index label list file") + parser.add_argument( + "--infer_file", type=str, default=None, + help="path to pickle file for inference") + parser.add_argument( + "-w", + "--weights", + default=None, + type=str, + help="weights path for evaluation") + FLAGS = parser.parse_args() + + check_gpu(str.lower(FLAGS.device) == 'gpu') + check_version() + main() diff --git a/tsm/kinetics_dataset.py b/tsm/kinetics_dataset.py index 7d0e8fe11db6b0d1188fbd07086b4d79bd3e0ea4..7e07543f37392744a2bf82ecc9b038e78d2d5524 100644 --- a/tsm/kinetics_dataset.py +++ b/tsm/kinetics_dataset.py @@ -56,21 +56,32 @@ class KineticsDataset(Dataset): """ def __init__(self, - file_list, - pickle_dir, + file_list=None, + pickle_dir=None, + pickle_file=None, label_list=None, mode='train', seg_num=8, seg_len=1, transform=None): - assert os.path.isfile(file_list), \ - "file_list {} not a file".format(file_list) - with open(file_list) as f: - self.pickle_paths = [l.strip() for l in f] - - assert os.path.isdir(pickle_dir), \ - "pickle_dir {} not a directory".format(pickle_dir) - self.pickle_dir = pickle_dir + assert str.lower(mode) in ['train', 'val', 'test'], \ + "mode can only be 'train' 'val' or 'test'" + self.mode = str.lower(mode) + + if self.mode in ['train', 'val']: + assert os.path.isfile(file_list), \ + "file_list {} not a file".format(file_list) + with open(file_list) as f: + self.pickle_paths = [l.strip() for l in f] + + assert os.path.isdir(pickle_dir), \ + "pickle_dir {} not a directory".format(pickle_dir) + self.pickle_dir = pickle_dir + else: + assert os.path.isfile(pickle_file), \ + "pickle_file {} not a file".format(pickle_file) + self.pickle_dir = '' + self.pickle_paths = [pickle_file] self.label_list = label_list if self.label_list is not None: @@ -79,10 +90,6 @@ class KineticsDataset(Dataset): with open(self.label_list) as f: self.label_list = [int(l.strip()) for l in f] - assert mode in ['train', 'val'], \ - "mode can only be 'train' or 'val'" - self.mode = mode - self.seg_num = seg_num self.seg_len = seg_len self.transform = transform