Unverified commit 4257b82d authored by LiuChaoXD, committed by GitHub

add tsn model based on paddle 2.0 platform (#4837)

Parent 7a36ec59
# TSN Video Classification Model
This directory contains the TSN video classification model implemented with the PaddlePaddle dynamic graph (dygraph) API. The model requires PaddlePaddle 2.0.0a0, a GPU, and Linux.
---
## Contents
- [Introduction](#introduction)
- [Installation](#installation)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Results](#results)
- [References](#references)
## Introduction
Temporal Segment Network (TSN) is a classic 2D-CNN-based solution for video classification. It targets long-range temporal modeling of actions: instead of dense sampling, it sparsely samples frames across the whole video, which captures global video information while removing redundancy and reducing computation. The per-frame features are averaged into a single video-level representation used for classification. This implementation is the single-stream RGB variant of TSN with a ResNet-50 backbone.
For details, please refer to the ECCV 2016 paper [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859).
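At inference time the idea reduces to: sample `seg_num` frames uniformly across the video, run each through the 2D backbone, and average the per-segment predictions into one video-level prediction. Below is a minimal NumPy sketch of this segment consensus, for illustration only; the dygraph model in this directory implements it with `reshape` + `reduce_mean`.
```python
import numpy as np

def segment_consensus(segment_preds):
    # segment_preds: (seg_num, num_classes), one row of class scores per
    # sampled segment; average them into a single video-level prediction.
    return segment_preds.mean(axis=0)

# e.g. seg_num=3 segments over the 101 UCF101 classes
preds = np.random.rand(3, 101).astype("float32")
video_pred = segment_consensus(preds)
print(video_pred.argmax())  # predicted class id
```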
## Installation
### Requirements:
```
python=3.7
paddlepaddle-gpu==2.0.0a0
opencv=4.3
CUDA >= 9.0
cudnn >= 7.5
wget
numpy
```
### Installing the dependencies:
- Install PaddlePaddle (GPU version):
``` pip3 install paddlepaddle-gpu==2.0.0a0 -i https://mirror.baidu.com/pypi/simple```
- Install OpenCV 4.3:
``` pip3 install opencv-python==4.3.0.36```
- Install wget:
``` pip3 install wget```
- Install numpy:
``` pip3 install numpy```
## Data Preparation
TSN is trained on the UCF101 action recognition dataset. For data download and preprocessing, see the [data instructions](./data/dataset/ucf101/README.md). After preprocessing, the following files are generated under `./data/dataset/ucf101/`:
- `videos/` : the UCF101 video files.
- `rawframes/` : the frames extracted from the UCF101 video files.
- `annotations/` : the UCF101 annotation files.
- `ucf101_train_split_{1,2,3}_rawframes.txt`, `ucf101_val_split_{1,2,3}_rawframes.txt`, `ucf101_train_split_{1,2,3}_videos.txt`, `ucf101_val_split_{1,2,3}_videos.txt` : the data path list files.
Note: following the official UCF101 annotations, the list files come in three different train/val splits. For example, ucf101_train_split_1_rawframes.txt and ucf101_val_split_1_rawframes.txt form one train/val split of UCF101, while ucf101_train_split_2_rawframes.txt and ucf101_val_split_2_rawframes.txt form another. The list files used for training and testing must come from the same split.
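Each line of a `*_rawframes.txt` list file holds the frame directory (relative to `rawframes/`), the number of frames, and the zero-based class label; the `*_videos.txt` lists omit the frame count. Illustrative lines (the frame counts here are made up):
```
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01 121 0
YoYo/v_YoYo_g01_c02 104 100
```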
## Training
TSN supports input data in either video or frame format. Once the data is ready, training for each format can be launched as follows.
1. Multi-GPU training (frame input)
```bash
bash multi_gpus_run.sh ./multi_tsn_frame.yaml
```
The GPUs used for multi-GPU training can be set as follows:
- Edit `export CUDA_VISIBLE_DEVICES=0,1,2,3` in `multi_gpus_run.sh` (the default `0,1,2,3` trains on GPUs 0, 1, 2, and 3).
- Note: the config file for multi-GPU, frame-format training is `multi_tsn_frame.yaml`. If you change the batch size, scale the learning rate proportionally: larger batch size, larger lr, by the same factor (see the scaling sketch after this list). For example, the default four-GPU setting is batchsize=128 with lr=0.001; with batchsize=64 use lr=0.0005.
2. Multi-GPU training (video input)
```bash
bash multi_gpus_run.sh ./multi_tsn_video.yaml
```
The GPUs used for multi-GPU training can be set as follows:
- Edit `export CUDA_VISIBLE_DEVICES=0,1,2,3` in `multi_gpus_run.sh` (the default `0,1,2,3` trains on GPUs 0, 1, 2, and 3).
- Note: the config file for multi-GPU, video-format training is `multi_tsn_video.yaml`. If you change the batch size, scale the learning rate by the same rule.
3. Single-GPU training (frame input)
```bash
bash single_gpu_run.sh ./single_tsn_frame.yaml
```
The GPU used for single-GPU training can be set as follows:
- Edit `export CUDA_VISIBLE_DEVICES=0` in `single_gpu_run.sh` (trains the model on GPU 0).
- Note: the config file for single-GPU, frame-format training is `single_tsn_frame.yaml`. If you change the batch size, scale the learning rate by the same rule. The single-GPU default is batchsize=64 with lr=0.0005; with batchsize=32 use lr=0.00025.
4. Single-GPU training (video input)
```bash
bash single_gpu_run.sh ./single_tsn_video.yaml
```
The GPU used for single-GPU training can be set as follows:
- Edit `export CUDA_VISIBLE_DEVICES=0` in `single_gpu_run.sh` (trains the model on GPU 0).
- Note: the config file for single-GPU, video-format training is `single_tsn_video.yaml`. If you change the batch size, scale the learning rate by the same rule.
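The batch-size/learning-rate rule above is plain linear scaling. A small helper illustrating the arithmetic (a hypothetical convenience function, not part of this repo):
```python
def scaled_lr(batch_size, base_batch_size=128, base_lr=0.001):
    # linear scaling rule: lr changes by the same factor as the batch size
    return base_lr * batch_size / base_batch_size

print(scaled_lr(64))  # 0.0005
print(scaled_lr(32))  # 0.00025
```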
## Evaluation
The model can be evaluated as follows:
```bash
bash run_eval.sh ./tsn_test.yaml ./weights/final.pdparams
```
- `./tsn_test.yaml` is the config file used for evaluation; `./weights/final.pdparams` is the model file saved when training finishes.
- The evaluation results, including the TOP1\_ACC and TOP5\_ACC accuracy metrics, are printed directly to the log.
## Results
For training, Paddle TSN (both the static-graph and the dynamic-graph version) uses four GPUs, frame-format input, seg_num=3, batchsize=128, and lr=0.001.
For evaluation, the input is frame-format with seg_num=25.
Note: seg_num is the number of frames sampled from each video file during training or testing.
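For intuition, the reader splits each video into seg_num equal segments and samples one frame per segment: a random index inside the segment at training time, the segment center at test time. A simplified sketch of the test-time indices (mirroring `mp4_loader`/`frames_loader` in the reader, ignoring their short-video fallbacks):
```python
def test_sample_indices(num_frames, seg_num):
    # center-of-segment frame indices used at evaluation time
    average_dur = num_frames // seg_num
    return [i * average_dur + (average_dur - 1) // 2 for i in range(seg_num)]

print(test_sample_indices(100, 3))  # [16, 49, 82]
```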
Evaluation accuracy on the UCF101 validation set:
| | List file | seg\_num (train) | seg\_num (test) | Top-1 | Top-5 |
| :------: | :----------:| :----------: | :----------: | :----: | :----: |
| Paddle TSN (static graph) | ucf101_{train/val}_split_1_rawframes.txt | 3 | 25 | 84.00% | 97.38% |
| Paddle TSN (dynamic graph) | ucf101_{train/val}_split_1_rawframes.txt | 3 | 25 | 84.27% | 97.27% |
## References
- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool
# UCF101 Data Preparation
Preparation of the UCF101 data, covering downloading the UCF101 video files, extracting frames from the videos, and generating the file path lists.
---
## 1. Downloading the data
Detailed information about UCF101 is available on the [UCF101](https://www.crcv.ucf.edu/data/UCF101.php) website. For convenience, we provide download scripts for the UCF101 annotation and video files.
### Downloading the annotations
First, make sure you are in the `./data/dataset/ucf101/` directory, then run the following command to download the UCF101 annotation files.
```shell
bash download_annotations.sh
```
### Downloading the UCF101 videos
Again from the `./data/dataset/ucf101/` directory, run the following command to download the video files:
```shell
bash download_videos.sh
```
After downloading, the video files are stored under `./data/dataset/ucf101/videos/` and take about 6.8 GB.
---
## 2. Extracting frames from the videos
To speed up training, we first extract frames from the video files (the UCF101 videos are in avi format). Training from pre-extracted frames is faster than decoding the videos on the fly during training.
Run the following command to extract frames from the UCF101 videos:
```bash
python extract_rawframes.py ./videos/ ./rawframes/ --level 2 --ext avi
```
After extraction, the frames are stored under `./rawframes` and take about 56 GB.
---
## 3. Generating the path lists for frames and videos
To generate the video path list, run:
```bash
python build_ucf101_file_list.py videos/ --level 2 --format videos --out_list_path ./
```
To generate the frames path list, run:
```bash
python build_ucf101_file_list.py rawframes/ --level 2 --format rawframes --out_list_path ./
```
**Parameters**
`videos/` or `rawframes/` : the directory holding the video or frame files
`--level 2` : the directory nesting level of the files
`--format` : whether to generate the path list for videos or for frames
`--out_list_path` : where to write the generated list files
# After the steps above, the files are organized as follows
```
├── data
| ├── dataset
| │ ├── ucf101
| │ │ ├── ucf101_{train,val}_split_{1,2,3}_rawframes.txt
| │ │ ├── ucf101_{train,val}_split_{1,2,3}_videos.txt
| │ │ ├── annotations
| │ │ ├── videos
| │ │ │ ├── ApplyEyeMakeup
| │ │ │ │ ├── v_ApplyEyeMakeup_g01_c01.avi
| │ │ │ ├── ...
| │ │ │ ├── YoYo
| │ │ │ │ ├── v_YoYo_g25_c05.avi
| │ │ ├── rawframes
| │ │ │ ├── ApplyEyeMakeup
| │ │ │ │ ├── v_ApplyEyeMakeup_g01_c01
| │ │ │ │ │ ├── img_00001.jpg
| │ │ │ │ │ ├── img_00002.jpg
| │ │ │ │ │ ├── ...
| │ │ │ │ │ ├── flow_x_00001.jpg
| │ │ │ │ │ ├── flow_x_00002.jpg
| │ │ │ │ │ ├── ...
| │ │ │ │ │ ├── flow_y_00001.jpg
| │ │ │ │ │ ├── flow_y_00002.jpg
| │ │ │ ├── ...
| │ │ │ ├── YoYo
| │ │ │ │ ├── v_YoYo_g01_c01
| │ │ │ │ ├── ...
| │ │ │ │ ├── v_YoYo_g25_c05
```
import argparse
import os
import glob
import fnmatch
import random
def parse_directory(path,
key_func=lambda x: x[-11:],
rgb_prefix='img_',
level=1):
"""
Parse directories holding extracted frames from standard benchmarks
"""
print('parse frames under folder {}'.format(path))
if level == 1:
frame_folders = glob.glob(os.path.join(path, '*'))
elif level == 2:
frame_folders = glob.glob(os.path.join(path, '*', '*'))
else:
raise ValueError('level can be only 1 or 2')
def count_files(directory, prefix_list):
lst = os.listdir(directory)
cnt_list = [len(fnmatch.filter(lst, x + '*')) for x in prefix_list]
return cnt_list
# check RGB
frame_dict = {}
    for i, f in enumerate(frame_folders):
        # (rgb_prefix, ) must be a tuple of prefixes; only RGB frames are
        # extracted in this repo, so there are no flow_x/flow_y counts to
        # cross-check.
        all_cnt = count_files(f, (rgb_prefix, ))
        k = key_func(f)
        if i % 200 == 0:
            print('{} videos parsed'.format(i))
        frame_dict[k] = (f, all_cnt[0], 0)
print('frame folder analysis done')
return frame_dict
def build_split_list(split, frame_info, shuffle=False):
def build_set_list(set_list):
rgb_list = list()
for item in set_list:
if item[0] not in frame_info:
continue
elif frame_info[item[0]][1] > 0:
rgb_cnt = frame_info[item[0]][1]
rgb_list.append('{} {} {}\n'.format(item[0], rgb_cnt, item[1]))
else:
rgb_list.append('{} {}\n'.format(item[0], item[1]))
if shuffle:
random.shuffle(rgb_list)
return rgb_list
train_rgb_list = build_set_list(split[0])
test_rgb_list = build_set_list(split[1])
return (train_rgb_list, test_rgb_list)
def parse_ucf101_splits(level):
class_ind = [x.strip().split() for x in open('./annotations/classInd.txt')]
class_mapping = {x[1]: int(x[0]) - 1 for x in class_ind}
def line2rec(line):
items = line.strip().split(' ')
vid = items[0].split('.')[0]
vid = '/'.join(vid.split('/')[-level:])
label = class_mapping[items[0].split('/')[0]]
return vid, label
splits = []
for i in range(1, 4):
train_list = [
line2rec(x)
for x in open('./annotations/trainlist{:02d}.txt'.format(i))
]
test_list = [
line2rec(x)
for x in open('./annotations/testlist{:02d}.txt'.format(i))
]
splits.append((train_list, test_list))
return splits
def parse_args():
parser = argparse.ArgumentParser(description='Build file list')
parser.add_argument(
'frame_path', type=str, help='root directory for the frames')
parser.add_argument('--rgb_prefix', type=str, default='img_')
parser.add_argument('--num_split', type=int, default=3)
parser.add_argument('--level', type=int, default=2, choices=[1, 2])
parser.add_argument(
'--format',
type=str,
default='rawframes',
choices=['rawframes', 'videos'])
parser.add_argument('--out_list_path', type=str, default='./')
    parser.add_argument('--shuffle', action='store_true', default=False)
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.level == 2:
def key_func(x):
return '/'.join(x.split('/')[-2:])
else:
def key_func(x):
return x.split('/')[-1]
if args.format == 'rawframes':
frame_info = parse_directory(
args.frame_path,
key_func=key_func,
rgb_prefix=args.rgb_prefix,
level=args.level)
elif args.format == 'videos':
if args.level == 1:
video_list = glob.glob(os.path.join(args.frame_path, '*'))
elif args.level == 2:
video_list = glob.glob(os.path.join(args.frame_path, '*', '*'))
frame_info = {
os.path.relpath(x.split('.')[0], args.frame_path): (x, -1, -1)
for x in video_list
}
split_tp = parse_ucf101_splits(args.level)
assert len(split_tp) == args.num_split
out_path = args.out_list_path
for i, split in enumerate(split_tp):
lists = build_split_list(split_tp[i], frame_info, shuffle=args.shuffle)
filename = 'ucf101_train_split_{}_{}.txt'.format(i + 1, args.format)
with open(os.path.join(out_path, filename), 'w') as f:
f.writelines(lists[0])
filename = 'ucf101_val_split_{}_{}.txt'.format(i + 1, args.format)
with open(os.path.join(out_path, filename), 'w') as f:
f.writelines(lists[1])
if __name__ == "__main__":
main()
#!/usr/bin/env bash
DATA_DIR="./annotations"
if [[ ! -d "${DATA_DIR}" ]]; then
echo "${DATA_DIR} does not exist. Creating";
mkdir -p ${DATA_DIR}
fi
wget --no-check-certificate "https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip"
unzip -j UCF101TrainTestSplits-RecognitionTask.zip -d ${DATA_DIR}/
rm UCF101TrainTestSplits-RecognitionTask.zip
#!/usr/bin/env bash
wget --no-check-certificate "https://www.crcv.ucf.edu/data/UCF101/UCF101.rar"
unrar x UCF101.rar
mv ./UCF-101 ./videos
rm -rf ./UCF101.rar
import argparse
import sys
import os
import os.path as osp
import glob
from multiprocessing import Pool
import cv2
def dump_frames(vid_item):
full_path, vid_path, vid_id = vid_item
vid_name = vid_path.split('.')[0]
out_full_path = osp.join(args.out_dir, vid_name)
try:
os.mkdir(out_full_path)
except OSError:
pass
vr = cv2.VideoCapture(full_path)
videolen = int(vr.get(cv2.CAP_PROP_FRAME_COUNT))
for i in range(videolen):
ret, frame = vr.read()
        if not ret:
            continue
        # OpenCV decodes frames as BGR, which is also what cv2.imwrite
        # expects, so the frame can be written out without any conversion.
        img = frame
        if img is not None:
            cv2.imwrite('{}/img_{:05d}.jpg'.format(out_full_path, i + 1), img)
else:
print('[Warning] length inconsistent!'
'Early stop with {} out of {} frames'.format(i + 1, videolen))
break
print('{} done with {} frames'.format(vid_name, videolen))
sys.stdout.flush()
return True
def parse_args():
parser = argparse.ArgumentParser(description='extract frames')
parser.add_argument('src_dir', type=str)
parser.add_argument('out_dir', type=str)
parser.add_argument('--level', type=int, choices=[1, 2], default=2)
parser.add_argument('--num_worker', type=int, default=8)
parser.add_argument(
"--ext",
type=str,
default='avi',
choices=['avi', 'mp4'],
help='video file extensions')
parser.add_argument(
"--resume",
action='store_true',
default=False,
help='resume optical flow extraction '
'instead of overwriting')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if not osp.isdir(args.out_dir):
print('Creating folder: {}'.format(args.out_dir))
os.makedirs(args.out_dir)
if args.level == 2:
classes = os.listdir(args.src_dir)
for classname in classes:
new_dir = osp.join(args.out_dir, classname)
if not osp.isdir(new_dir):
print('Creating folder: {}'.format(new_dir))
os.makedirs(new_dir)
print('Reading videos from folder: ', args.src_dir)
print('Extension of videos: ', args.ext)
if args.level == 2:
fullpath_list = glob.glob(args.src_dir + '/*/*.' + args.ext)
done_fullpath_list = glob.glob(args.out_dir + '/*/*')
elif args.level == 1:
fullpath_list = glob.glob(args.src_dir + '/*.' + args.ext)
done_fullpath_list = glob.glob(args.out_dir + '/*')
print('Total number of videos found: ', len(fullpath_list))
if args.resume:
fullpath_list = set(fullpath_list).difference(set(done_fullpath_list))
fullpath_list = list(fullpath_list)
print('Resuming. number of videos to be done: ', len(fullpath_list))
if args.level == 2:
vid_list = list(
map(lambda p: osp.join('/'.join(p.split('/')[-2:])), fullpath_list))
elif args.level == 1:
vid_list = list(map(lambda p: p.split('/')[-1], fullpath_list))
    pool = Pool(args.num_worker)
    pool.map(dump_frames, zip(fullpath_list, vid_list, range(len(vid_list))))
    pool.close()
    pool.join()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import argparse
import ast
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from model import TSN_ResNet
from utils.config_utils import *
from reader.ucf101_reader import UCF101Reader
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("Paddle Video test script")
parser.add_argument(
'--config',
type=str,
default='./tsn_test.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='test batch size. None to use config file setting.')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
parser.add_argument(
'--weights', type=str, default="./weights/final", help="weight path")
args = parser.parse_args()
return args
def test(args):
# parse config
config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args))
print_configs(test_config, 'Test')
place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
video_model = TSN_ResNet(test_config)
model_dict, _ = fluid.load_dygraph(args.weights)
video_model.set_dict(model_dict)
test_reader = UCF101Reader(name="TSN", mode="test", cfg=test_config)
test_reader = test_reader.create_reader()
video_model.eval()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id, data in enumerate(test_reader()):
x_data = np.array([item[0] for item in data])
y_data = np.array([item[1] for item in data]).reshape([-1, 1])
imgs = to_variable(x_data)
labels = to_variable(y_data)
labels.stop_gradient = True
outputs = video_model(imgs)
loss = fluid.layers.cross_entropy(
input=outputs, label=labels, ignore_index=-1)
avg_loss = fluid.layers.mean(loss)
acc_top1 = fluid.layers.accuracy(input=outputs, label=labels, k=1)
acc_top5 = fluid.layers.accuracy(input=outputs, label=labels, k=5)
total_loss += avg_loss.numpy()
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
total_sample += 1
print('TEST iter {}, loss = {}, acc1 {}, acc5 {}'.format(
batch_id, avg_loss.numpy(), acc_top1.numpy(), acc_top5.numpy()))
print('Finish loss {}, acc1 {}, acc5 {}'.format(
total_loss / total_sample, total_acc1 / total_sample, total_acc5 /
total_sample))
if __name__ == "__main__":
args = parse_args()
logger.info(args)
test(args)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
import math
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(bn_name + "_offset"),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride,
name=name + "_branch1")
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
return fluid.layers.relu(y)
class BasicBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv1)
layer_helper = LayerHelper(self.full_name(), act="relu")
return layer_helper.append_activation(y)
class TSN_ResNet(fluid.dygraph.Layer):
def __init__(self, config):
super(TSN_ResNet, self).__init__()
self.layers = config.MODEL.num_layers
self.seg_num = config.MODEL.seg_num
self.class_dim = config.MODEL.num_classes
supported_layers = [18, 34, 50, 101, 152]
        assert self.layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, self.layers)
if self.layers == 18:
depth = [2, 2, 2, 2]
elif self.layers == 34 or self.layers == 50:
depth = [3, 4, 6, 3]
elif self.layers == 101:
depth = [3, 4, 23, 3]
elif self.layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if self.layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act="relu",
name="conv1")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type="max")
self.block_list = []
if self.layers >= 50:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if self.layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(bottleneck_block)
shortcut = True
else:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(basic_block)
shortcut = True
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_channels = num_channels[-1] * 2
self.out = Linear(
self.pool2d_avg_channels,
self.class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(
loc=0.0, scale=0.01),
name="fc_0.w_0"),
bias_attr=ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0.0),
name="fc_0.b_0"))
def forward(self, inputs):
y = fluid.layers.reshape(
inputs, [-1, inputs.shape[2], inputs.shape[3], inputs.shape[4]])
y = self.conv(y)
y = self.pool2d_max(y)
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = fluid.layers.dropout(
y, dropout_prob=0.2, dropout_implementation="upscale_in_train")
y = fluid.layers.reshape(y, [-1, self.seg_num, y.shape[1]])
y = fluid.layers.reduce_mean(y, dim=1)
        y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y
configs=$1
pretrain="" # set pretrain model path if needed
resume="" # set checkpoints model path if u want to resume training
save_dir=""
use_gpu=True
use_data_parallel=True
export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "TSN" $configs $resume $pretrain
if [ "$resume"x != ""x ]; then
python -m paddle.distributed.launch train.py \
--config=$configs \
--resume=$resume \
--use_gpu=$use_gpu \
--use_data_parallel=$use_data_parallel
elif [ "$pretrain"x != ""x ]; then
python -m paddle.distributed.launch train.py \
--config=$configs \
--pretrain=$pretrain \
--use_gpu=$use_gpu \
--use_data_parallel=$use_data_parallel
else
python -m paddle.distributed.launch train.py \
--config=$configs \
        --use_gpu=$use_gpu \
--use_data_parallel=$use_data_parallel
fi
MODEL:
name: "TSN"
format: "frames" # support for "frames" or "videos"
num_classes: 101
seg_num: 3
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
TRAIN:
epoch: 80
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 128
use_gpu: True
filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
learning_rate: 0.001
learning_rate_decay: 0.1
decay_epochs: [30, 60]
l2_weight_decay: 1e-4
momentum: 0.9
total_videos: 9738
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 128
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
MODEL:
name: "TSN"
format: "videos" # support for "frames" or "videos"
num_classes: 101
seg_num: 3
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
TRAIN:
epoch: 80
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 128
use_gpu: True
filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
learning_rate: 0.001
learning_rate_decay: 0.1
decay_epochs: [30, 60]
l2_weight_decay: 1e-4
momentum: 0.9
total_videos: 9738
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 128
filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import cv2
import numpy as np
import random
class ReaderNotFoundError(Exception):
"Error: reader not found"
def __init__(self, reader_name, avail_readers):
super(ReaderNotFoundError, self).__init__()
self.reader_name = reader_name
self.avail_readers = avail_readers
def __str__(self):
msg = "Reader {} Not Found.\nAvailiable readers:\n".format(
self.reader_name)
for reader in self.avail_readers:
msg += " {}\n".format(reader)
return msg
class DataReader(object):
"""data reader for video input"""
def __init__(self, model_name, mode, cfg):
self.name = model_name
self.mode = mode
self.cfg = cfg
def create_reader(self):
"""Not implemented"""
pass
def get_config_from_sec(self, sec, item, default=None):
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
class ReaderZoo(object):
def __init__(self):
self.reader_zoo = {}
def regist(self, name, reader):
        assert reader.__base__ == DataReader, "Unknown reader type {}".format(
            type(reader))
self.reader_zoo[name] = reader
def get(self, name, mode, cfg):
for k, v in self.reader_zoo.items():
if k == name:
return v(name, mode, cfg)
raise ReaderNotFoundError(name, self.reader_zoo.keys())
# singleton reader_zoo
reader_zoo = ReaderZoo()
def regist_reader(name, reader):
reader_zoo.regist(name, reader)
def get_reader(name, mode, cfg):
reader_model = reader_zoo.get(name, mode, cfg)
return reader_model.create_reader()
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import cv2
import math
import random
import functools
import numpy as np
import paddle
import paddle.fluid as fluid
from PIL import Image, ImageEnhance
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
class VideoRecord(object):
    '''
    Describes the extracted-frame information of a video:
    1. self._data[0] is the relative path of the video's frame directory
    2. self._data[1] is the number of frames
    3. self._data[2] is the label of the video
    '''
def __init__(self, row):
self._data = row
@property
def path(self):
return "./data/dataset/ucf101/rawframes/" + self._data[0]
@property
def num_frames(self):
return int(self._data[1])
@property
def label(self):
return int(self._data[2])
class UCF101Reader(DataReader):
"""
Data reader for kinetics dataset of two format mp4 and pkl.
1. mp4, the original format of kinetics400
2. pkl, the mp4 was decoded previously and stored as pkl
In both case, load the data, and then get the frame data in the form of numpy and label as an integer.
dataset cfg: format
num_classes
seg_num
short_size
target_size
num_reader_threads
buf_size
image_mean
image_std
batch_size
list
"""
def __init__(self, name, mode, cfg):
super(UCF101Reader, self).__init__(name, mode, cfg)
self.format = cfg.MODEL.format
self.num_classes = self.get_config_from_sec('model', 'num_classes')
self.seg_num = self.get_config_from_sec('model', 'seg_num')
self.seglen = self.get_config_from_sec('model', 'seglen')
self.seg_num = self.get_config_from_sec(mode, 'seg_num', self.seg_num)
self.short_size = self.get_config_from_sec(mode, 'short_size')
self.target_size = self.get_config_from_sec(mode, 'target_size')
self.num_reader_threads = self.get_config_from_sec(mode,
'num_reader_threads')
self.buf_size = self.get_config_from_sec(mode, 'buf_size')
self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed')
self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32)
self.img_std = np.array(cfg.MODEL.image_std).reshape(
[3, 1, 1]).astype(np.float32)
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[mode.upper()]['filelist']
# set num_trainers and trainer_id when distributed training is implemented
self.num_trainers = self.get_config_from_sec(mode, 'num_trainers', 1)
self.trainer_id = self.get_config_from_sec(mode, 'trainer_id', 0)
if self.mode == 'infer':
self.video_path = cfg[mode.upper()]['video_path']
else:
self.video_path = ''
if self.fix_random_seed:
random.seed(0)
np.random.seed(0)
self.num_reader_threads = 1
def create_reader(self):
# if set video_path for inference mode, just load this single video
if (self.mode == 'infer') and (self.video_path != ''):
# load video from file stored at video_path
_reader = self._inference_reader_creator(
self.video_path,
self.mode,
seg_num=self.seg_num,
seglen=self.seglen,
short_size=self.short_size,
target_size=self.target_size,
img_mean=self.img_mean,
img_std=self.img_std)
else:
assert os.path.exists(self.filelist), \
'{} not exist, please check the data list'.format(
self.filelist)
_reader = self._reader_creator(
self.filelist,
self.mode,
seg_num=self.seg_num,
seglen=self.seglen,
short_size=self.short_size,
target_size=self.target_size,
img_mean=self.img_mean,
img_std=self.img_std,
shuffle=(self.mode == 'train'),
num_threads=self.num_reader_threads,
buf_size=self.buf_size,
format=self.format)
def _batch_reader():
batch_out = []
for imgs, label in _reader():
if imgs is None:
continue
batch_out.append((imgs, label))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return _batch_reader
def _inference_reader_creator(self, video_path, mode, seg_num, seglen,
short_size, target_size, img_mean, img_std):
        def reader():
            try:
                imgs = mp4_loader(video_path, seg_num, seglen, mode)
                if len(imgs) < 1:
                    logger.error('{} frame length {} less than 1.'.format(
                        video_path, len(imgs)))
                    yield None, None
                    return
            except:
                logger.error('Error when loading {}'.format(video_path))
                yield None, None
                return
            imgs_ret = imgs_transform(imgs, mode, seg_num, seglen, short_size,
                                      target_size, img_mean, img_std)
            label_ret = video_path
            yield imgs_ret, label_ret
return reader
def _reader_creator(self,
pickle_list,
mode,
seg_num,
seglen,
short_size,
target_size,
img_mean,
img_std,
shuffle=False,
num_threads=1,
buf_size=1024,
format='avi'):
def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
sample = sample[0].split(' ')
mp4_path = "./data/dataset/ucf101/videos/" + sample[0] + ".avi"
            # labels in the list files are already zero-based
            # (see build_ucf101_file_list.py), so use them directly
            label = int(sample[1])
try:
imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
if len(imgs) < 1:
logger.error('{} frame length {} less than 1.'.format(
mp4_path, len(imgs)))
return None, None
except:
logger.error('Error when loading {}'.format(mp4_path))
return None, None
return imgs_transform(
imgs,
mode,
seg_num,
seglen,
short_size,
target_size,
img_mean,
img_std,
name=self.name), label
def decode_frames(sample, mode, seg_num, seglen, short_size,
target_size, img_mean, img_std):
recode = VideoRecord(sample[0].split(' '))
frames_dir_path = recode.path
# when infer, we store vid as label
label = recode.label
try:
imgs = frames_loader(recode, seg_num, seglen, mode)
if len(imgs) < 1:
logger.error('{} frame length {} less than 1.'.format(
frames_dir_path, len(imgs)))
return None, None
except:
logger.error('Error when loading {}'.format(frames_dir_path))
return None, None
return imgs_transform(
imgs,
mode,
seg_num,
seglen,
short_size,
target_size,
img_mean,
img_std,
name=self.name), label
def reader_():
with open(pickle_list) as flist:
full_lines = [line.strip() for line in flist]
if self.mode == 'train':
if (not hasattr(reader_, 'seed')):
reader_.seed = 0
random.Random(reader_.seed).shuffle(full_lines)
print("reader shuffle seed", reader_.seed)
if reader_.seed is not None:
reader_.seed += 1
per_node_lines = int(
math.ceil(len(full_lines) * 1.0 / self.num_trainers))
total_lines = per_node_lines * self.num_trainers
# aligned full_lines so that it can evenly divisible
full_lines += full_lines[:(total_lines - len(full_lines))]
assert len(full_lines) == total_lines
# trainer get own sample
lines = full_lines[self.trainer_id:total_lines:
self.num_trainers]
logger.info("trainerid %d, trainer_count %d" %
(self.trainer_id, self.num_trainers))
logger.info(
"read images from %d, length: %d, lines length: %d, total: %d"
% (self.trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
assert len(lines) == per_node_lines
for line in lines:
pickle_path = line.strip()
yield [pickle_path]
if format == 'frames':
decode_func = decode_frames
elif format == 'videos':
decode_func = decode_mp4
else:
raise "Not implemented format {}".format(format)
mapper = functools.partial(
decode_func,
mode=mode,
seg_num=seg_num,
seglen=seglen,
short_size=short_size,
target_size=target_size,
img_mean=img_mean,
img_std=img_std)
return fluid.io.xmap_readers(mapper, reader_, num_threads, buf_size)
def imgs_transform(imgs,
mode,
seg_num,
seglen,
short_size,
target_size,
img_mean,
img_std,
name=''):
imgs = group_scale(imgs, short_size)
if mode == 'train':
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
imgs = group_center_crop(imgs, target_size)
np_imgs = (np.array(imgs[0]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
for i in range(len(imgs) - 1):
img = (np.array(imgs[i + 1]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
np_imgs = np.concatenate((np_imgs, img))
imgs = np_imgs
imgs -= img_mean
imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
return imgs
def group_multi_scale_crop(img_group,
target_size,
scales=None,
max_distort=1,
fix_crop=True,
more_fix_crop=True):
scales = scales if scales is not None else [1, .875, .75, .66]
input_size = [target_size, target_size]
im_size = img_group[0].size
# get random crop offset
def _sample_crop_size(im_size):
image_w, image_h = im_size[0], im_size[1]
base_size = min(image_w, image_h)
crop_sizes = [int(base_size * x) for x in scales]
crop_h = [
input_size[1] if abs(x - input_size[1]) < 3 else x
for x in crop_sizes
]
crop_w = [
input_size[0] if abs(x - input_size[0]) < 3 else x
for x in crop_sizes
]
pairs = []
for i, h in enumerate(crop_h):
for j, w in enumerate(crop_w):
if abs(i - j) <= max_distort:
pairs.append((w, h))
crop_pair = random.choice(pairs)
if not fix_crop:
w_offset = random.randint(0, image_w - crop_pair[0])
h_offset = random.randint(0, image_h - crop_pair[1])
else:
w_step = (image_w - crop_pair[0]) / 4
h_step = (image_h - crop_pair[1]) / 4
ret = list()
ret.append((0, 0)) # upper left
if w_step != 0:
ret.append((4 * w_step, 0)) # upper right
if h_step != 0:
ret.append((0, 4 * h_step)) # lower left
if h_step != 0 and w_step != 0:
ret.append((4 * w_step, 4 * h_step)) # lower right
if h_step != 0 or w_step != 0:
ret.append((2 * w_step, 2 * h_step)) # center
if more_fix_crop:
ret.append((0, 2 * h_step)) # center left
ret.append((4 * w_step, 2 * h_step)) # center right
ret.append((2 * w_step, 4 * h_step)) # lower center
ret.append((2 * w_step, 0 * h_step)) # upper center
ret.append((1 * w_step, 1 * h_step)) # upper left quarter
ret.append((3 * w_step, 1 * h_step)) # upper right quarter
ret.append((1 * w_step, 3 * h_step)) # lower left quarter
ret.append((3 * w_step, 3 * h_step)) # lower righ quarter
w_offset, h_offset = random.choice(ret)
return crop_pair[0], crop_pair[1], w_offset, h_offset
crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
crop_img_group = [
img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
for img in img_group
]
ret_img_group = [
img.resize((input_size[0], input_size[1]), Image.BILINEAR)
for img in crop_img_group
]
return ret_img_group
def group_random_crop(img_group, target_size):
w, h = img_group[0].size
th, tw = target_size, target_size
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size".format(
w, h, target_size)
out_images = []
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in img_group:
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
def group_random_flip(img_group):
v = random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
return ret
else:
return img_group
def group_center_crop(img_group, target_size):
img_crop = []
for img in img_group:
w, h = img.size
th, tw = target_size, target_size
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size".format(
w, h, target_size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
img_crop.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return img_crop
def group_scale(imgs, target_size):
resized_imgs = []
for i in range(len(imgs)):
img = imgs[i]
w, h = img.size
if (w <= h and w == target_size) or (h <= w and h == target_size):
resized_imgs.append(img)
continue
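        # UCF101 frames are 4:3 (e.g. 320x240): the short side is scaled to
        # target_size and the long side to target_size * 4 / 3.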
if w < h:
ow = target_size
oh = int(target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
else:
oh = target_size
ow = int(target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
return resized_imgs
def mp4_loader(filepath, nsample, seglen, mode):
cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
sampledFrames = []
for i in range(videolen):
ret, frame = cap.read()
# maybe first frame is empty
if ret == False:
continue
img = frame[:, :, ::-1]
sampledFrames.append(img)
average_dur = int(len(sampledFrames) / nsample)
imgs = []
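    # Split the video into `nsample` equal segments of `average_dur` frames
    # and take `seglen` consecutive frames from each segment: a random offset
    # inside the segment when training, the segment center when testing.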
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= seglen:
idx = (average_dur - 1) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + seglen):
imgbuf = sampledFrames[int(jj % len(sampledFrames))]
img = Image.fromarray(imgbuf, mode='RGB')
imgs.append(img)
return imgs
def load_image(directory, idx):
return Image.open(os.path.join(directory, 'img_{:05d}.jpg'.format(
idx))).convert('RGB')
def frames_loader(recode, nsample, seglen, mode):
imgpath, num_frames = recode.path, recode.num_frames
average_dur = int(num_frames / nsample)
imgs = []
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= seglen:
idx = (average_dur - 1) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + seglen):
img = load_image(imgpath, jj + 1)
imgs.append(img)
return imgs
configs=$1
weights=$2
use_gpu=True
use_data_parallel=False
export CUDA_VISIBLE_DEVICES=0
echo $configs $weights
if [ "$weights"x != ""x ]; then
python eval.py --config=$configs \
--weights=$weights \
--use_gpu=$use_gpu
else
python eval.py --config=$configs \
--use_gpu=$use_gpu
fi
configs=$1
pretrain="" # set pretrain model path if needed
resume="" # set checkpoints model path if u want to resume training
save_dir=""
use_gpu=True
use_data_parallel=False
weights="" #set the path of weights to enable eval and predicut, just ignore this when training
export CUDA_VISIBLE_DEVICES=0
echo $mode "TSN" $configs $resume $pretrain
if [ "$resume"x != ""x ]; then
python train.py --config=$configs \
--resume=$resume \
--use_gpu=$use_gpu \
--use_data_parallel=$use_data_parallel
elif [ "$pretrain"x != ""x ]; then
python train.py --config=$configs \
--pretrain=$pretrain \
--use_gpu=$use_gpu \
--use_data_parallel=$use_data_parallel
else
python train.py --config=$configs \
--use_gpu=$use_gpu \
--use_data_parallel=$use_data_parallel
fi
MODEL:
name: "TSN"
format: "frames" # support for "frames" or "videos"
num_classes: 101
seg_num: 3
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
TRAIN:
epoch: 80
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
use_gpu: True
filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
learning_rate: 0.0005
learning_rate_decay: 0.1
decay_epochs: [30, 60]
l2_weight_decay: 1e-4
momentum: 0.9
total_videos: 9738
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 128
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
MODEL:
name: "TSN"
format: "videos" # support for "frames" or "videos"
num_classes: 101
seg_num: 3
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
TRAIN:
epoch: 80
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
use_gpu: True
filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
learning_rate: 0.0005
learning_rate_decay: 0.1
decay_epochs: [30, 60]
l2_weight_decay: 1e-4
momentum: 0.9
total_videos: 9738
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 128
filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import argparse
import ast
import wget
import tarfile
import logging
import numpy as np
import paddle.fluid as fluid
import glob
from paddle.fluid.dygraph.base import to_variable
from model import TSN_ResNet
from utils.config_utils import *
from reader.ucf101_reader import UCF101Reader
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("Paddle Video train script")
parser.add_argument(
'--config',
type=str,
default='tsn.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='training batch size. None to use config file setting.')
parser.add_argument(
'--pretrain',
type=str,
default=None,
help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
)
parser.add_argument(
'--resume',
type=str,
default=None,
help='path to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
parser.add_argument(
'--epoch',
type=int,
default=None,
help='epoch number, 0 for read from config file')
parser.add_argument(
'--use_data_parallel',
type=ast.literal_eval,
default=True,
help='default use data parallel.')
parser.add_argument(
'--checkpoint',
type=str,
default="./checkpoint",
help='path to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.')
parser.add_argument(
'--weights',
type=str,
default="./weights",
help='path to save the final optimized model.'
'default path is "./weights".')
    parser.add_argument(
        '--validate',
        type=ast.literal_eval,
        default=True,
        help='whether to validate during the training phase. '
        'default value is True.')
args = parser.parse_args()
return args
def decompress(path):
t = tarfile.open(path)
print("path[0] {}".format(os.path.split(path)[0]))
t.extractall(path=os.path.split(path)[0])
t.close()
def download(url, path):
weight_dir = os.path.split(path)[0]
if not os.path.exists(weight_dir):
os.makedirs(weight_dir)
path = path + ".tar.gz"
print("path {}".format(path))
wget.download(url, path)
decompress(path)
def pretrain_info():
return (
'ResNet50_pretrained',
'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz'
)
def download_pretrained(pretrained):
if pretrained is not None:
WEIGHT_DIR = pretrained
else:
WEIGHT_DIR = os.path.join(os.path.expanduser('~'), '.paddle', 'weights')
path, url = pretrain_info()
if not path:
return None
path = os.path.join(WEIGHT_DIR, path)
if not os.path.isdir(WEIGHT_DIR):
logger.info('{} not exists, will be created automatically.'.format(
WEIGHT_DIR))
os.makedirs(WEIGHT_DIR)
if os.path.exists(path):
return path
logger.info("Download pretrain weights of ResNet50 from {}".format(url))
download(url, path)
return path
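# Copy pretrained ImageNet ResNet weights into the model by parameter name,
# skipping the final FC layer (fc_0.w_0 / fc_0.b_0), whose shape depends on
# the number of classes (101 here instead of 1000).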
def init_model(model, pre_state_dict):
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
if weight_name in pre_state_dict.keys(
) and weight_name != "fc_0.w_0" and weight_name != "fc_0.b_0":
print('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
return model
def val(epoch, model, cfg, args):
reader = UCF101Reader(name="TSN", mode="valid", cfg=cfg)
reader = reader.create_reader()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id, data in enumerate(reader()):
x_data = np.array([item[0] for item in data])
y_data = np.array([item[1] for item in data]).reshape([-1, 1])
imgs = to_variable(x_data)
labels = to_variable(y_data)
labels.stop_gradient = True
outputs = model(imgs)
loss = fluid.layers.cross_entropy(
input=outputs, label=labels, ignore_index=-1)
avg_loss = fluid.layers.mean(loss)
acc_top1 = fluid.layers.accuracy(input=outputs, label=labels, k=1)
acc_top5 = fluid.layers.accuracy(input=outputs, label=labels, k=5)
dy_out = avg_loss.numpy()[0]
total_loss += dy_out
total_acc1 += acc_top1.numpy()[0]
total_acc5 += acc_top5.numpy()[0]
total_sample += 1
if batch_id % 5 == 0:
print(
"TEST Epoch {}, iter {}, loss={:.5f}, acc1 {:.5f}, acc5 {:.5f}".
format(epoch, batch_id, total_loss / total_sample, total_acc1 /
total_sample, total_acc5 / total_sample))
print('Finish loss {} , acc1 {} , acc5 {}'.format(
total_loss / total_sample, total_acc1 / total_sample, total_acc5 /
total_sample))
return total_acc1 / total_sample
def create_optimizer(cfg, params):
total_videos = cfg.total_videos
step = int(total_videos / cfg.batch_size + 1)
bd = [e * step for e in cfg.decay_epochs]
base_lr = cfg.learning_rate
lr_decay = cfg.learning_rate_decay
lr = [base_lr, base_lr * lr_decay, base_lr * lr_decay * lr_decay]
l2_weight_decay = cfg.l2_weight_decay
momentum = cfg.momentum
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=momentum,
regularization=fluid.regularizer.L2Decay(l2_weight_decay),
parameter_list=params)
return optimizer
def train(args):
config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args))
print_configs(train_config, 'Train')
# get the pretrained weights
pretrained_path = download_pretrained(args.pretrain)
use_data_parallel = args.use_data_parallel
trainer_count = fluid.dygraph.parallel.Env().nranks
# (data_parallel step1/6)
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if use_data_parallel else fluid.CUDAPlace(0)
pre_state_dict = fluid.load_program_state(pretrained_path)
with fluid.dygraph.guard(place):
if use_data_parallel:
# (data_parallel step2/6)
strategy = fluid.dygraph.parallel.prepare_context()
video_model = TSN_ResNet(train_config)
video_model = init_model(video_model, pre_state_dict)
optimizer = create_optimizer(train_config.TRAIN,
video_model.parameters())
if use_data_parallel:
# (data_parallel step3/6)
video_model = fluid.dygraph.parallel.DataParallel(video_model,
strategy)
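        # The batch_size in the config is the global batch size; divide it by
        # the number of visible GPUs to get the per-card batch size.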
bs_denominator = 1
if args.use_gpu:
# check number of GPUs
gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
if gpus == "":
pass
else:
gpus = gpus.split(",")
num_gpus = len(gpus)
bs_denominator = num_gpus
train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
bs_denominator)
train_reader = UCF101Reader(name="TSN", mode="train", cfg=train_config)
train_reader = train_reader.create_reader()
if use_data_parallel:
# (data_parallel step4/6)
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
# resume training the model
if args.resume is not None:
model_state, opt_state = fluid.load_dygraph(args.resume)
video_model.set_dict(model_state)
optimizer.set_dict(opt_state)
for epoch in range(1, train_config.TRAIN.epoch + 1):
video_model.train()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
batch_start = time.time()
for batch_id, data in enumerate(train_reader()):
train_reader_cost = time.time() - batch_start
x_data = np.array([item[0] for item in data]).astype("float32")
y_data = np.array([item[1] for item in data]).reshape([-1, 1])
imgs = to_variable(x_data)
labels = to_variable(y_data)
labels.stop_gradient = True
outputs = video_model(imgs)
loss = fluid.layers.cross_entropy(
input=outputs, label=labels, ignore_index=-1)
avg_loss = fluid.layers.mean(loss)
acc_top1 = fluid.layers.accuracy(
input=outputs, label=labels, k=1)
acc_top5 = fluid.layers.accuracy(
input=outputs, label=labels, k=5)
dy_out = avg_loss.numpy()[0]
if use_data_parallel:
# (data_parallel step5/6)
avg_loss = video_model.scale_loss(avg_loss)
avg_loss.backward()
video_model.apply_collective_grads()
else:
avg_loss.backward()
optimizer.minimize(avg_loss)
video_model.clear_gradients()
total_loss += dy_out
total_acc1 += acc_top1.numpy()[0]
total_acc5 += acc_top5.numpy()[0]
total_sample += 1
train_batch_cost = time.time() - batch_start
print(
'TRAIN Epoch: {}, iter: {}, batch_cost: {:.5f} s, reader_cost: {:.5f} s, loss={:.6f}, acc1 {:.6f}, acc5 {:.6f} '.
format(epoch, batch_id, train_batch_cost, train_reader_cost,
total_loss / total_sample, total_acc1 / total_sample,
total_acc5 / total_sample))
batch_start = time.time()
print(
'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}'.
format(epoch, total_loss / total_sample, total_acc1 /
total_sample, total_acc5 / total_sample))
            # save the model and optimizer parameters, used when resuming training
save_parameters = (not use_data_parallel) or (
use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
model_path_pre = "_tsn"
if not os.path.isdir(args.checkpoint):
os.makedirs(args.checkpoint)
model_path = os.path.join(
args.checkpoint,
"_" + model_path_pre + "_epoch{}".format(epoch))
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
if args.validate:
video_model.eval()
val_acc = val(epoch, video_model, valid_config, args)
                # save the best parameters of the training stage
if epoch == 1:
best_acc = val_acc
else:
if val_acc > best_acc:
best_acc = val_acc
if fluid.dygraph.parallel.Env().local_rank == 0:
if not os.path.isdir(args.weights):
os.makedirs(args.weights)
fluid.dygraph.save_dygraph(video_model.state_dict(),
args.weights + "/final")
else:
if fluid.dygraph.parallel.Env().local_rank == 0:
if not os.path.isdir(args.weights):
os.makedirs(args.weights)
fluid.dygraph.save_dygraph(video_model.state_dict(),
args.weights + "/final")
logger.info('[TRAIN] training finished')
if __name__ == "__main__":
args = parse_args()
logger.info(args)
train(args)
MODEL:
name: "TSN"
format: "frames"
num_classes: 101
seg_num: 25
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 4
batch_size: 32
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 4
batch_size: 32
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")