README.md 4.5 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# TSM 视频分类模型

---

## 内容

- [模型简介](#模型简介)
- [快速开始](#快速开始)
- [参考论文](#参考论文)


## 模型简介

Temporal Shift Module是由MIT和IBM Watson AI Lab的Ji Lin,Chuang Gan和Song Han等人提出的通过时间位移来提高网络视频理解能力的模块,其位移操作原理如下图所示。

<p align="center">
D
dengkaipeng 已提交
17
<img src="./images/temporal_shift.png" height=250 width=800 hspace='10'/> <br />
D
dengkaipeng 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
Temporal shift module
</p>

上图中矩阵表示特征图中的temporal和channel维度,通过将一部分的channel在temporal维度上向前位移一步,一部分的channel在temporal维度上向后位移一步,位移后的空缺补零。通过这种方式在特征图中引入temporal维度上的上下文交互,提高在时间维度上的视频理解能力。

TSM模型是将Temporal Shift Module插入到ResNet网络中构建的视频分类模型,本模型库实现版本为以ResNet-50作为主干网络的TSM模型。

详细内容请参考论文[Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1)

## 快速开始

### 安装说明

#### paddle安装

   本项目依赖于 PaddlePaddle 1.7及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装

#### 代码下载及环境变量设置

    克隆代码库到本地,并设置`PYTHONPATH`环境变量
D
dengkaipeng 已提交
38 39

    ```bash
D
dengkaipeng 已提交
40 41
    git clone https://github.com/PaddlePaddle/hapi
    cd hapi
D
dengkaipeng 已提交
42 43
    export PYTHONPATH=`pwd`:$PYTHONPATH
    cd examples/tsm
D
dengkaipeng 已提交
44 45 46 47
    ```

### 数据准备

D
dengkaipeng 已提交
48
TSM的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。数据下载及准备请参考[数据说明](./dataset/README.md)
D
dengkaipeng 已提交
49 50 51 52 53

#### 小数据集验证

为了便于快速迭代,我们采用了较小的数据集进行动态图训练验证,从Kinetics-400数据集中选取分类标签(label)分别为 0, 2, 3, 4, 6, 7, 9, 12, 14, 15的即前10类数据验证模型精度。

D
dengkaipeng 已提交
54
### 模型训练
D
dengkaipeng 已提交
55

D
dengkaipeng 已提交
56 57 58 59
数据准备完毕后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`tsm_checkpoint`目录下。

`main.py`脚本参数可通过如下命令查询

D
dengkaipeng 已提交
60
```bash
D
dengkaipeng 已提交
61 62
python main.py --help
```
D
dengkaipeng 已提交
63 64 65 66 67

#### 静态图训练

使用如下方式进行单卡训练:

D
dengkaipeng 已提交
68
```bash
D
dengkaipeng 已提交
69 70 71 72 73 74
export CUDA_VISIBLE_DEVICES=0
python main.py --data=<path/to/dataset> --batch_size=16
```

使用如下方式进行多卡训练:

D
dengkaipeng 已提交
75
```bash
D
dengkaipeng 已提交
76
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8
D
dengkaipeng 已提交
77 78 79 80 81 82 83 84
```

#### 动态图训练

动态图训练只需要在运行脚本时添加`-d`参数即可。

使用如下方式进行单卡训练:

D
dengkaipeng 已提交
85
```bash
D
dengkaipeng 已提交
86 87 88 89 90 91
export CUDA_VISIBLE_DEVICES=0
python main.py --data=<path/to/dataset> --batch_size=16 -d
```

使用如下方式进行多卡训练:

D
dengkaipeng 已提交
92
```bash
D
dengkaipeng 已提交
93
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8 -d
D
dengkaipeng 已提交
94 95 96 97
```

**注意:** 对于静态图和动态图,多卡训练中`--batch_size`为每卡上的batch_size,即总batch_size为`--batch_size`乘以卡数

D
dengkaipeng 已提交
98
### 模型评估
D
dengkaipeng 已提交
99

D
dengkaipeng 已提交
100 101 102 103
可通过如下两种方式进行模型评估。

1. 自动下载Paddle发布的[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams)权重评估

D
dengkaipeng 已提交
104 105
```bash
python main.py --data=<path/to/dataset> --eval_only
D
dengkaipeng 已提交
106 107 108 109
```

2. 加载checkpoint进行精度评估

D
dengkaipeng 已提交
110 111
```bash
python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/final
D
dengkaipeng 已提交
112 113 114 115 116
```

#### 评估精度

在10类小数据集下训练模型权重见[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams),评估精度如下:
D
dengkaipeng 已提交
117 118 119

|Top-1|Top-5|
|:-:|:-:|
D
dengkaipeng 已提交
120
|76%|98%|
D
dengkaipeng 已提交
121

D
dengkaipeng 已提交
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
### 模型推断

可通过如下两种方式进行模型推断。

1. 自动下载Paddle发布的[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams)权重推断

```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle>
```

2. 加载checkpoint进行精度推断

```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle> --weights=tsm_checkpoint/final
```

模型推断结果会以如下日志形式输出

```text
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```

D
dengkaipeng 已提交
144 145
**注意:** 推断时`--infer_file`需要指定到pickle文件路径。

D
dengkaipeng 已提交
146 147 148
## 参考论文

- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han