README.md 3.3 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
# StNet 视频分类模型

---
## 内容

- [简介](#简介)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推断](#模型推断)


## 简介

StNet为百度自研模型,该框架为百度在ActivityNet Kinetics Challenge 2018中夺冠的基础网络框架,本次开源的是基于ResNet50实现的StNet模型,基于其他backbone网络的框架用户可以依样配置。该模型提出“super-image"的概念,在super-image上进行2D卷积,建模视频中局部时空相关性。另外通过temporal modeling block建模视频的全局时空依赖,最后用一个temporal Xception block对抽取的特征序列进行长时序建模。StNet主体网络结构如下图所示:

<p align="center">
<img src="../../images/StNet.png" height=300 width=500 hspace='10'/> <br />
StNet Framework Overview
</p>

详细内容请参考AAAI'2019年论文[StNet:Local and Global Spatial-Temporal Modeling for Human Action Recognition](https://arxiv.org/abs/1811.01549)

## 数据准备

StNet的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。数据下载及准备请参考[数据说明](../../dataset/README.md)

## 模型训练

数据准备完毕后,可以通过如下两种方式启动训练:

    python train.py --model-name=STNET
            --config=./configs/attention_stnet.txt
            --save-dir=checkpoints 
            --epoch=20 
            --log-interval=10 
            --valid-interval=1

    bash scripts/train/train_attention_stnet.sh

**数据读取器说明:** 模型读取Kinetics-400数据集中的`mp4`数据,每条数据抽取`seg_num`段,每段抽取`seg_len`帧图像,对每帧图像做随机增强后,缩放至`target_size`

**训练策略:**

*  采用Momentum优化算法训练,momentum=0.9
*  权重衰减系数为1e-4
*  学习率在训练的总epoch数的1/3和2/3时分别做0.1的衰减

## 模型评估

可通过如下两种方式进行模型评估:

    python test.py --model-name=STNET
            --config=configs/attention_stnet.txt
            --log-interval=1 --weights=$PATH_TO_WEIGHTS

    bash scripts/test/test_attention_stnet.sh

- 使用`scripts/test/test_attention_stnet.sh`进行评估时,需要修改脚本中的`--weights`参数指定需要评估的权重。

- 若未指定`--weights`参数,脚本会下载Paddle release权重[PaddleStNet](https://paddlemodels.bj.bcebos.com/video_classification/attention_stnet_kinetics.tar.gz)进行评估

当取如下参数时:

| 参数 | 取值 |
| :---------: | :----: |
| seg\_num | 25 |
| seglen | 5 |
| target\_size | 256 |

在Kinetics400的validation数据集下评估精度如下:

| 精度指标 | 模型精度 |
| :---------: | :----: |
| Prec@1 | 0.69 |


## 模型推断

可通过如下命令进行模型推断:

    python infer.py --model-name=attention_stnet
            --config=configs/attention_stnet.txt
            --log-interval=1 
            --weights=$PATH_TO_WEIGHTS 
            --filelist=$FILELIST

- 模型推断结果存储于`STNET_infer_result`中,通过`pickle`格式存储。

- 若未指定`--weights`参数,脚本会下载Paddle release权重[PaddleStNet](https://paddlemodels.bj.bcebos.com/video_classification/attention_stnet_kinetics.tar.gz)进行推断