Commit e035d917 authored by LiuChaoXD

refined dynamic tsn 2020-08-30

Parent ba0dd405
# TSN Video Classification Model
This directory contains a TSN video classification model implemented with the PaddlePaddle dynamic graph (DyGraph) API. The model supports PaddlePaddle Fluid 1.8, GPU, and Linux.
---
## Contents
- [Model Overview](#model-overview)
- [Installation](#installation)
- [Data Preparation](#data-preparation)
- [Model Training](#model-training)
- [Model Evaluation](#model-evaluation)
@@ -13,10 +14,39 @@
## Model Overview
Temporal Segment Network (TSN) is a classic 2D-CNN-based solution for video classification. It targets the recognition of long-range actions in videos: dense frame sampling is replaced with sparse sampling, which captures global information about the video while removing redundancy and reducing computation. The features of the sampled frames are finally fused by averaging into a single video-level feature that is used for classification. The model implemented here is the single-stream RGB variant of TSN with a ResNet50 backbone.
For details, see the ECCV 2016 paper [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859)
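To make the sparse-sampling-plus-average-fusion idea concrete, here is a small numpy sketch (an illustration only, not the repository's code; `frame_scores` stands in for the per-frame CNN outputs):

```python
import numpy as np

def tsn_segment_indices(num_frames, seg_num, train=True):
    """Sparsely sample one frame index from each of seg_num equal segments."""
    avg_dur = num_frames // seg_num
    idxs = []
    for i in range(seg_num):
        if train and avg_dur > 0:
            off = np.random.randint(avg_dur)  # random offset inside the segment
        else:
            off = avg_dur // 2                # segment center at test time
        idxs.append(min(i * avg_dur + off, num_frames - 1))
    return idxs

# toy example: a 300-frame video, 101 classes, random per-frame scores
frame_scores = np.random.rand(300, 101)
idxs = tsn_segment_indices(300, seg_num=3, train=False)
video_score = frame_scores[idxs].mean(axis=0)  # average fusion over segments
print(idxs, video_score.shape)                 # [50, 150, 250] (101,)
```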
## Installation
### Requirements:
```
python=3.7
paddlepaddle-gpu==1.8.3.post97
opencv=4.3
CUDA >= 9.0
cudnn >= 7.5
wget
numpy
```
### Installing the dependencies:
- Install PaddlePaddle (GPU version); a quick sanity check is sketched after this list:
`pip3 install paddlepaddle-gpu==1.8.3.post97 -i https://mirror.baidu.com/pypi/simple`
- Install OpenCV 4.3:
`pip3 install opencv-python==4.3.0.36`
- Install wget:
`pip3 install wget`
- Install numpy:
`pip3 install numpy`
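Optionally, a quick sanity check of the GPU install (a sketch assuming the install-check utility shipped with PaddlePaddle 1.8):

```python
import paddle.fluid as fluid

print(fluid.is_compiled_with_cuda())  # should print True for the GPU wheel
fluid.install_check.run_check()       # runs a small program to verify the setup
```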
## Data Preparation
TSN is trained on the UCF101 action recognition dataset. For data download and preparation, see the [data instructions](./data/dataset/ucf101/README.md)
@@ -27,32 +57,30 @@ TSN is trained on the UCF101 action recognition dataset
1. Multi-GPU training
```bash
bash multi_gpus_run.sh ./configs/tsn.yaml
```
The GPUs used for multi-GPU training can be set as follows:
- Modify `export CUDA_VISIBLE_DEVICES=0,1,2,3` in `multi_gpus_run.sh` (the default 0,1,2,3 means GPUs 0, 1, 2 and 3 are used for training)
- Note: if you change the batchsize, the learning rate must be changed proportionally; a larger batchsize takes a larger lr, scaling both by the same factor. For example, the defaults are batchsize=128 and lr=0.001, so batchsize=64 should use lr=0.0005 (see the sketch below this list)
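A minimal helper spelling out that linear scaling rule (illustrative only, using the defaults from this README):

```python
# linear scaling rule: keep lr / batchsize constant
BASE_BATCH, BASE_LR = 128, 0.001

def scaled_lr(batch_size):
    return BASE_LR * batch_size / BASE_BATCH

print(scaled_lr(64))   # 0.0005
print(scaled_lr(256))  # 0.002
```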
2. Single-GPU training
```bash
bash single_gpu_run.sh ./configs/tsn.yaml
```
The GPU used for single-GPU training can be set as follows:
- Modify `export CUDA_VISIBLE_DEVICES=0` in `single_gpu_run.sh` (meaning GPU 0 is used for training); a minimal sketch of device selection follows this list
- Note: if you change the batchsize, scale the learning rate by the same factor, as above. For example, the defaults are batchsize=128 and lr=0.001, so batchsize=64 should use lr=0.0005
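A minimal sketch of what single-card dygraph execution looks like once `CUDA_VISIBLE_DEVICES=0` is exported (illustration only, not the training script itself):

```python
import numpy as np
import paddle.fluid as fluid

# with CUDA_VISIBLE_DEVICES=0, card 0 is the only GPU this process can see
place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
    x = fluid.dygraph.to_variable(np.ones((2, 3), dtype='float32'))
    print(x.numpy().sum())  # 6.0, computed on GPU 0
```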
## Model Evaluation
The model can be evaluated as follows:
```bash
bash run_eval.sh ./configs/tsn_test.yaml ./weights/final.pdparams
```
- When evaluating with `run_eval.sh`, set the `weights` argument in the script to the weights you want to evaluate
- `./tsn_test.yaml` is the configuration file used during evaluation; `./weights/final.pdparams` is the model file saved when training finishes
- Evaluation results are printed directly to the log as accuracy metrics such as TOP1\_ACC and TOP5\_ACC; a sketch of these top-k metrics follows
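For reference, TOP1_ACC and TOP5_ACC are plain top-k accuracies; a self-contained numpy sketch of the metric (not the repository's implementation):

```python
import numpy as np

def topk_acc(scores, labels, k):
    """scores: [N, num_classes] class scores; labels: [N] ground-truth ids."""
    topk = np.argsort(-scores, axis=1)[:, :k]   # the k best classes per clip
    hits = [labels[i] in topk[i] for i in range(len(labels))]
    return float(np.mean(hits))

scores = np.array([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
labels = np.array([1, 2])
print(topk_acc(scores, labels, 1))  # 0.5: only the first clip's top-1 is correct
print(topk_acc(scores, labels, 2))  # 0.5: class 2 is still outside the second clip's top-2
```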
@@ -62,7 +90,6 @@ bash run-eval.sh ./configs/tsn-test.yaml ./weights/final.pdparams
| | seg\_num | Top-1 | Top-5 |
| :------: | :----------: | :----: | :----: |
| PyTorch TSN | 3 | 83.88% | 96.78% |
| Paddle TSN (static graph) | 3 | 84.00% | 97.38% |
| Paddle TSN (dynamic graph) | 3 | 84.27% | 97.27% |
@@ -14,29 +14,30 @@ bash download_annotations.sh
### Download the UCF101 video files
Again make sure you are in the `./data/dataset/ucf101/` directory, then run the following command to download the video files
```shell
bash download_videos.sh
```
Once the download finishes, the video files are stored under the `./data/dataset/ucf101/videos/` folder; the videos take up 6.8 GB.
---
## 2. Extract frames from the video files
To speed up network training, we first extract frames from the video files (the UCF101 videos are in avi format). Compared with training directly from the video files, reading pre-extracted frames makes training noticeably faster.
Run the following command to extract frames from the UCF101 videos
```bash
python extract_rawframes.py ./videos/ ./rawframes/ --level 2 --ext avi
```
After extraction finishes, the frames are stored under the `./rawframes` folder and take up 56 GB.
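For reference, the core per-video step of `extract_rawframes.py` is roughly the following (a simplified sketch, not the script itself; it assumes the `img_{:05d}.jpg` naming that the frames reader expects):

```python
import os
import cv2

def dump_frames(video_path, out_dir):
    """Decode one video into JPEG frames named img_00001.jpg, img_00002.jpg, ..."""
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    count = 0
    while True:
        ret, frame = cap.read()  # frame is BGR, which is what cv2.imwrite expects
        if not ret:
            break
        count += 1
        cv2.imwrite(os.path.join(out_dir, 'img_{:05d}.jpg'.format(count)), frame)
    cap.release()
    return count
```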
---
## 3. Generate path lists for the frames and the video files
To generate the path list for the video files, run
```bash
python build_ucf101_file_list.py videos/ --level 2 --format videos --out_list_path ./
```
To generate the path list for the frames, run:
```bash
python build_ucf101_file_list.py rawframes/ --level 2 --format rawframes --out_list_path ./
```
**Arguments** (an example of the generated list format follows this list)
......@@ -49,8 +50,6 @@ python extract_rawframes.py ./videos/ ./rawframes/ --level 2 --ext avi
`--out_list_path`: where the generated list files are stored
`--shuffle`: whether to shuffle the order of entries in the path list
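As an illustration (my reading of the scripts, with made-up values): with `--format rawframes` each line of a generated list is `relative/path num_frames label`, which the reader later parses along these lines:

```python
class VideoRecord(object):
    """One line of a rawframes list: '<relative/path> <num_frames> <label>'."""
    def __init__(self, row):
        self.path = row[0]
        self.num_frames = int(row[1])
        self.label = int(row[2])

# hypothetical line from ucf101_train_split_1_rawframes.txt
line = 'ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01 121 0'
rec = VideoRecord(line.split(' '))
print(rec.path, rec.num_frames, rec.label)
```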
# After the steps above are complete, the files are organized as shown below
@@ -50,7 +50,6 @@ def build_split_list(split, frame_info, shuffle=False):
rgb_list = list()
for item in set_list:
if item[0] not in frame_info:
# print("item:", item)
continue
elif frame_info[item[0]][1] > 0:
rgb_cnt = frame_info[item[0]][1]
@@ -97,9 +96,6 @@ def parse_args():
'frame_path', type=str, help='root directory for the frames')
parser.add_argument('--rgb_prefix', type=str, default='img_')
parser.add_argument('--num_split', type=int, default=3)
parser.add_argument(
'--subset', type=str, default='train',
choices=['train', 'val', 'test'])
parser.add_argument('--level', type=int, default=2, choices=[1, 2])
parser.add_argument(
'--format',
@@ -145,29 +141,16 @@ def main():
assert len(split_tp) == args.num_split
out_path = args.out_list_path
if len(split_tp) > 1:
for i, split in enumerate(split_tp):
lists = build_split_list(
split_tp[i], frame_info, shuffle=args.shuffle)
filename = 'ucf101_train_split_{}_{}.txt'.format(i + 1, args.format)
with open(os.path.join(out_path, filename), 'w') as f:
f.writelines(lists[0])
filename = 'ucf101_val_split_{}_{}.txt'.format(i + 1, args.format)
with open(os.path.join(out_path, filename), 'w') as f:
f.writelines(lists[1])
else:
lists = build_split_list(split_tp[0], frame_info, shuffle=args.shuffle)
filename = '{}_{}_list_{}.txt'.format(args.dataset, args.subset,
args.format)
if args.subset == 'train':
ind = 0
elif args.subset == 'val':
ind = 1
elif args.subset == 'test':
ind = 2
for i, split in enumerate(split_tp):
lists = build_split_list(split_tp[i], frame_info, shuffle=args.shuffle)
filename = 'ucf101_train_split_{}_{}.txt'.format(i + 1, args.format)
with open(os.path.join(out_path, filename), 'w') as f:
f.writelines(lists[0])
filename = 'ucf101_val_split_{}_{}.txt'.format(i + 1, args.format)
with open(os.path.join(out_path, filename), 'w') as f:
f.writelines(lists[0][ind])
f.writelines(lists[1])
if __name__ == "__main__":
@@ -23,9 +23,10 @@ def dump_frames(vid_item):
if not ret:
continue
# frame comes from cv2 in BGR order; reversing the channels gives RGB
img = frame[:, :, ::-1]
# convert it back to BGR, since cv2.imwrite expects a BGR image
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if img is not None:
# cv2.imwrite interprets the array it is given as BGR
cv2.imwrite('{}/img_{:05d}.jpg'.format(out_full_path, i + 1), img)
else:
print('[Warning] length inconsistent!'
@@ -37,27 +38,18 @@
def parse_args():
parser = argparse.ArgumentParser(description='extract frames')
parser.add_argument('src_dir', type=str)
parser.add_argument('out_dir', type=str)
parser.add_argument('--level', type=int, choices=[1, 2], default=2)
parser.add_argument('--num_worker', type=int, default=8)
parser.add_argument(
"--out_format",
type=str,
default='dir',
choices=['dir', 'zip'],
help='output format')
parser.add_argument(
"--ext",
type=str,
default='avi',
choices=['avi', 'mp4'],
help='video file extensions')
parser.add_argument(
"--new_width", type=int, default=0, help='resize image width')
parser.add_argument(
"--new_height", type=int, default=0, help='resize image height')
parser.add_argument(
"--resume",
action='store_true',
@@ -37,7 +37,7 @@ def parse_args():
parser.add_argument(
'--config',
type=str,
default='./tsn_test.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
@@ -69,7 +69,6 @@ def test(args):
video_model.set_dict(model_dict)
test_reader = UCF101Reader(name="TSN", mode="test", cfg=test_config)
#test_reader = KineticsReader(mode='test', cfg=test_config)
test_reader = test_reader.create_reader()
video_model.eval()
@@ -25,8 +25,6 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
import math
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
@@ -119,9 +117,6 @@ class BottleneckBlock(fluid.dygraph.Layer):
y = fluid.layers.elementwise_add(x=short, y=conv2)
return fluid.layers.relu(y)
# layer_helper = LayerHelper(self.full_name(), act="relu")
# return layer_helper.append_activation(y)
class BasicBlock(fluid.dygraph.Layer):
def __init__(self,
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import cv2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,24 +18,9 @@ import cv2
import math
import random
import functools
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
import numpy as np
import paddle
import paddle.fluid as fluid
try:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import tempfile
from nvidia.dali.plugin.paddle import DALIGenericIterator
except:
Pipeline = object
print("DALI is not installed, you can improve performance if use DALI")
from PIL import Image, ImageEnhance
import logging
@@ -115,9 +100,6 @@ class UCF101Reader(DataReader):
# set num_trainers and trainer_id when distributed training is implemented
self.num_trainers = self.get_config_from_sec(mode, 'num_trainers', 1)
self.trainer_id = self.get_config_from_sec(mode, 'trainer_id', 0)
self.use_dali = self.get_config_from_sec(mode, 'use_dali', False)
self.dali_mean = cfg.MODEL.image_mean * (self.seg_num * self.seglen)
self.dali_std = cfg.MODEL.image_std * (self.seg_num * self.seglen)
if self.mode == 'infer':
self.video_path = cfg[mode.upper()]['video_path']
@@ -129,9 +111,6 @@
self.num_reader_threads = 1
def create_reader(self):
# if use_dali is set, use the DALI reader to improve performance
if self.use_dali:
return self.build_dali_reader()
# if video_path is set in inference mode, just load that single video
if (self.mode == 'infer') and (self.video_path != ''):
@@ -237,42 +216,6 @@
img_std,
name=self.name), label
def decode_pickle(sample, mode, seg_num, seglen, short_size,
target_size, img_mean, img_std):
pickle_path = sample[0]
try:
if python_ver < (3, 0):
data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
data_loaded = pickle.load(
open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data_loaded
if len(frames) < 1:
logger.error('{} frame length {} less than 1.'.format(
pickle_path, len(frames)))
return None, None
except:
logger.info('Error when loading {}'.format(pickle_path))
return None, None
if mode == 'train' or mode == 'valid' or mode == 'test':
ret_label = label
elif mode == 'infer':
ret_label = vid
imgs = video_loader(frames, seg_num, seglen, mode)
return imgs_transform(
imgs,
mode,
seg_num,
seglen,
short_size,
target_size,
img_mean,
img_std,
name=self.name), ret_label
def decode_frames(sample, mode, seg_num, seglen, short_size,
target_size, img_mean, img_std):
recode = VideoRecord(sample[0].split(' '))
@@ -334,11 +277,9 @@
pickle_path = line.strip()
yield [pickle_path]
if format == 'pkl':
decode_func = decode_pickle
if format == 'frames':
decode_func = decode_frames
elif format == 'mp4' or 'avi':
elif format == 'videos':
decode_func = decode_mp4
else:
raise "Not implemented format {}".format(format)
@@ -355,249 +296,6 @@
return fluid.io.xmap_readers(mapper, reader_, num_threads, buf_size)
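# Illustration (not part of this file): fluid.io.xmap_readers fans the
# per-sample mapper out over worker threads. A minimal standalone sketch,
# with toy data standing in for the real decode functions:
#
#     import paddle.fluid as fluid
#
#     def toy_reader():        # a reader: calling it yields raw samples
#         for i in range(8):
#             yield i
#
#     def toy_mapper(x):       # the per-sample decode step (here it squares x)
#         return x * x
#
#     # 4 worker threads and a buffer of 16 mapped samples, as above
#     toy_mapped = fluid.io.xmap_readers(toy_mapper, toy_reader, 4, 16)
#     print(sorted(toy_mapped()))  # [0, 1, 4, 9, 16, 25, 36, 49]; order not preserved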
def build_dali_reader(self):
"""
build dali training reader
"""
def reader_():
with open(self.filelist) as flist:
full_lines = [line for line in flist]
if self.mode == 'train':
if (not hasattr(reader_, 'seed')):
reader_.seed = 0
random.Random(reader_.seed).shuffle(full_lines)
print("reader shuffle seed", reader_.seed)
if reader_.seed is not None:
reader_.seed += 1
per_node_lines = int(
math.ceil(len(full_lines) * 1.0 / self.num_trainers))
total_lines = per_node_lines * self.num_trainers
# pad full_lines so that its length is evenly divisible
full_lines += full_lines[:(total_lines - len(full_lines))]
assert len(full_lines) == total_lines
# each trainer gets its own shard of samples
lines = full_lines[self.trainer_id:total_lines:
self.num_trainers]
assert len(lines) == per_node_lines
logger.info("trainerid %d, trainer_count %d" %
(self.trainer_id, self.num_trainers))
logger.info(
"read images from %d, length: %d, lines length: %d, total: %d"
% (self.trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
video_files = ''
for item in lines:
video_files += item
tf = tempfile.NamedTemporaryFile()
tf.write(str.encode(video_files))
tf.flush()
video_files = tf.name
device_id = int(os.getenv('FLAGS_selected_gpus', 0))
print('---------- device id -----------', device_id)
if self.mode == 'train':
pipe = VideoPipe(
batch_size=self.batch_size,
num_threads=1,
device_id=device_id,
file_list=video_files,
sequence_length=self.seg_num * self.seglen,
seg_num=self.seg_num,
seg_length=self.seglen,
resize_shorter_scale=self.short_size,
crop_target_size=self.target_size,
is_training=(self.mode == 'train'),
dali_mean=self.dali_mean,
dali_std=self.dali_std)
else:
pipe = VideoTestPipe(
batch_size=self.batch_size,
num_threads=1,
device_id=device_id,
file_list=video_files,
sequence_length=self.seg_num * self.seglen,
seg_num=self.seg_num,
seg_length=self.seglen,
resize_shorter_scale=self.short_size,
crop_target_size=self.target_size,
is_training=(self.mode == 'train'),
dali_mean=self.dali_mean,
dali_std=self.dali_std)
logger.info(
'initializing dataset, this may take several minutes for large datasets ...'
)
video_loader = DALIGenericIterator(
[pipe], ['image', 'label'],
len(lines),
dynamic_shape=True,
auto_reset=True)
return video_loader
dali_reader = reader_()
def ret_reader():
for data in dali_reader:
yield data[0]['image'], data[0]['label']
return ret_reader
class VideoPipe(Pipeline):
def __init__(self,
batch_size,
num_threads,
device_id,
file_list,
sequence_length,
seg_num,
seg_length,
resize_shorter_scale,
crop_target_size,
is_training=False,
initial_prefetch_size=10,
num_shards=1,
shard_id=0,
dali_mean=0.,
dali_std=1.0):
super(VideoPipe, self).__init__(batch_size, num_threads, device_id)
self.input = ops.VideoReader(
device="gpu",
file_list=file_list,
sequence_length=sequence_length,
seg_num=seg_num,
seg_length=seg_length,
is_training=is_training,
num_shards=num_shards,
shard_id=shard_id,
random_shuffle=is_training,
initial_fill=initial_prefetch_size)
# the sequence data read by ops.VideoReader is of shape [F, H, W, C].
# Because ops.Resize does not support sequence data,
# it will be transposed into [H, W, F, C],
# then reshaped to [H, W, FC], and then resized like a 2-D image.
self.transpose = ops.Transpose(device="gpu", perm=[1, 2, 0, 3])
self.reshape = ops.Reshape(
device="gpu", rel_shape=[1.0, 1.0, -1], layout='HWC')
self.resize = ops.Resize(
device="gpu", resize_shorter=resize_shorter_scale)
# crop and mirror are applied by ops.CropMirrorNormalize.
# Normalization is done in Paddle instead, because it is not certain that
# DALI implements the required dimension broadcast correctly.
self.pos_rng_x = ops.Uniform(range=(0.0, 1.0))
self.pos_rng_y = ops.Uniform(range=(0.0, 1.0))
self.mirror_generator = ops.Uniform(range=(0.0, 1.0))
self.cast_mirror = ops.Cast(dtype=types.DALIDataType.INT32)
self.crop_mirror_norm = ops.CropMirrorNormalize(
device="gpu",
crop=[crop_target_size, crop_target_size],
mean=dali_mean,
std=dali_std)
self.reshape_back = ops.Reshape(
device="gpu",
shape=[
seg_num, seg_length * 3, crop_target_size, crop_target_size
],
layout='FCHW')
self.cast_label = ops.Cast(device="gpu", dtype=types.DALIDataType.INT64)
def define_graph(self):
output, label = self.input(name="Reader")
output = self.transpose(output)
output = self.reshape(output)
output = self.resize(output)
output = output / 255.
pos_x = self.pos_rng_x()
pos_y = self.pos_rng_y()
mirror_flag = self.mirror_generator()
mirror_flag = (mirror_flag > 0.5)
mirror_flag = self.cast_mirror(mirror_flag)
#output = self.crop(output, crop_pos_x=pos_x, crop_pos_y=pos_y)
output = self.crop_mirror_norm(
output, crop_pos_x=pos_x, crop_pos_y=pos_y, mirror=mirror_flag)
output = self.reshape_back(output)
label = self.cast_label(label)
return output, label
class VideoTestPipe(Pipeline):
def __init__(self,
batch_size,
num_threads,
device_id,
file_list,
sequence_length,
seg_num,
seg_length,
resize_shorter_scale,
crop_target_size,
is_training=False,
initial_prefetch_size=10,
num_shards=1,
shard_id=0,
dali_mean=0.,
dali_std=1.0):
super(VideoTestPipe, self).__init__(batch_size, num_threads, device_id)
self.input = ops.VideoReader(
device="gpu",
file_list=file_list,
sequence_length=sequence_length,
seg_num=seg_num,
seg_length=seg_length,
is_training=is_training,
num_shards=num_shards,
shard_id=shard_id,
random_shuffle=is_training,
initial_fill=initial_prefetch_size)
# the sequence data read by ops.VideoReader is of shape [F, H, W, C].
# Because ops.Resize does not support sequence data,
# it will be transposed into [H, W, F, C],
# then reshaped to [H, W, FC], and then resized like a 2-D image.
self.transpose = ops.Transpose(device="gpu", perm=[1, 2, 0, 3])
self.reshape = ops.Reshape(
device="gpu", rel_shape=[1.0, 1.0, -1], layout='HWC')
self.resize = ops.Resize(
device="gpu", resize_shorter=resize_shorter_scale)
# crop and mirror are applied by ops.CropMirrorNormalize.
# Normalization is done in Paddle instead, because it is not certain that
# DALI implements the required dimension broadcast correctly.
self.crop_mirror_norm = ops.CropMirrorNormalize(
device="gpu",
crop=[crop_target_size, crop_target_size],
crop_pos_x=0.5,
crop_pos_y=0.5,
mirror=0,
mean=dali_mean,
std=dali_std)
self.reshape_back = ops.Reshape(
device="gpu",
shape=[
seg_num, seg_length * 3, crop_target_size, crop_target_size
],
layout='FCHW')
self.cast_label = ops.Cast(device="gpu", dtype=types.DALIDataType.INT64)
def define_graph(self):
output, label = self.input(name="Reader")
output = self.transpose(output)
output = self.reshape(output)
output = self.resize(output)
output = output / 255.
#output = self.crop(output, crop_pos_x=pos_x, crop_pos_y=pos_y)
output = self.crop_mirror_norm(output)
output = self.reshape_back(output)
label = self.cast_label(label)
return output, label
def imgs_transform(imgs,
mode,
@@ -611,8 +309,7 @@
imgs = group_scale(imgs, short_size)
if mode == 'train':
if name == "TSM":
imgs = group_multi_scale_crop(imgs, short_size)
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
@@ -777,47 +474,6 @@ def group_scale(imgs, target_size):
return resized_imgs
def imageloader(buf):
if isinstance(buf, str):
img = Image.open(StringIO(buf))
else:
img = Image.open(BytesIO(buf))
return img.convert('RGB')
def video_loader(frames, nsample, seglen, mode):
videolen = len(frames)
average_dur = int(videolen / nsample)
imgs = []
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= seglen:
idx = (average_dur - seglen) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + seglen):
imgbuf = frames[int(jj % videolen)]
img = imageloader(imgbuf)
imgs.append(img)
return imgs
def mp4_loader(filepath, nsample, seglen, mode):
cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -858,20 +514,9 @@
return imgs
# additional helper function used to load frames
# loading images by PIL
# def load_image(directory, idx):
# return Image.open(os.path.join(
# directory, 'img_{:05d}.jpg'.format(idx))).convert('RGB')
# loading images by opencv
def load_image(directory, idx):
img = cv2.imread(os.path.join(directory, 'img_{:05d}.jpg'.format(idx)))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
return Image.open(os.path.join(directory, 'img_{:05d}.jpg'.format(
idx))).convert('RGB')
def frames_loader(recode, nsample, seglen, mode):
@@ -899,7 +544,5 @@
for jj in range(idx, idx + seglen):
img = load_image(imgpath, jj + 1)
img = Image.fromarray(img, mode='RGB')
# print("the readed image shape {}".format(img.shape))
imgs.append(img)
return imgs
configs="tsn-test.yaml"
configs="tsn_test.yaml"
use_gpu=True
use_data_parallel=False
@@ -84,6 +84,12 @@ def parse_args():
default="./weights",
help='path to save the final optimized model.'
'default path is "./weights".')
parser.add_argument(
'--validate',
action='store_true',
help='whether to run validation during the training phase; '
'disabled by default.')
args = parser.parse_args()
return args
@@ -187,6 +193,7 @@ def val(epoch, model, cfg, args):
print('Finish loss {} , acc1 {} , acc5 {}'.format(
total_loss / total_sample, total_acc1 / total_sample, total_acc5 /
total_sample))
return total_acc1 / total_sample
def create_optimizer(cfg, params):
@@ -248,13 +255,7 @@
else:
gpus = gpus.split(",")
num_gpus = len(gpus)
assert num_gpus == train_config.TRAIN.num_gpus, \
"num_gpus({}) set by CUDA_VISIBLE_DEVICES" \
"shoud be the same as that" \
"set in {}({})".format(
num_gpus, args.config, train_config.TRAIN.num_gpus)
bs_denominator = train_config.TRAIN.num_gpus
bs_denominator = num_gpus
train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
bs_denominator)
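# Illustration (hypothetical helper, not part of this script): the batch size
# in the config is treated as the global batch size and split evenly across
# the cards exposed by CUDA_VISIBLE_DEVICES, exactly as the division above:
#
#     import os
#
#     def per_card_batch_size(global_batch_size, default_num_gpus=1):
#         gpus = os.getenv('CUDA_VISIBLE_DEVICES', '')
#         num_gpus = len(gpus.split(',')) if gpus else default_num_gpus
#         return global_batch_size // num_gpus
#
#     # with CUDA_VISIBLE_DEVICES=0,1,2,3 and batch_size: 256 in tsn.yaml,
#     # each card processes 256 // 4 = 64 samples per step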
@@ -314,7 +315,7 @@
total_sample += 1
train_batch_cost = time.time() - batch_start
print(
'TRAIN Epoch: {}, iter: {}, batch_cost: {: .5f} s, reader_cost: {: .5f} s, loss={: .6f}, acc1 {: .6f}, acc5 {: .6f} \t'.
format(epoch, batch_id, train_batch_cost, train_reader_cost,
avg_loss.numpy()[0],
acc_top1.numpy()[0], acc_top5.numpy()[0]))
@@ -339,14 +340,27 @@
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
if args.validate:
video_model.eval()
val_acc = val(epoch, video_model, valid_config, args)
# save the best parameters seen during the training stage
if epoch == 1:
best_acc = val_acc
else:
if val_acc > best_acc:
best_acc = val_acc
if fluid.dygraph.parallel.Env().local_rank == 0:
if not os.path.isdir(args.weights):
os.makedirs(args.weights)
fluid.dygraph.save_dygraph(video_model.state_dict(),
args.weights + "/final")
else:
if fluid.dygraph.parallel.Env().local_rank == 0:
if not os.path.isdir(args.weights):
os.makedirs(args.weights)
fluid.dygraph.save_dygraph(video_model.state_dict(),
args.weights + "/final")
logger.info('[TRAIN] training finished')
MODEL:
name: "TSN"
format: "frames"
format: "frames" # support for "frames" or "videos"
num_classes: 101
seg_num: 3
seglen: 1
@@ -15,9 +15,8 @@ TRAIN:
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 256
use_gpu: True
num_gpus: 4 #8
filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
learning_rate: 0.001
learning_rate_decay: 0.1
@@ -40,4 +39,4 @@ TEST:
num_reader_threads: 12
buf_size: 1024
batch_size: 64
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
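A note on the learning_rate / learning_rate_decay pair above: the base lr is presumably multiplied by the decay factor at fixed epochs (step decay). A hedged sketch of the implied schedule, where the decay epochs and steps per epoch are assumptions for illustration, not values from this config:

LR, DECAY = 0.001, 0.1         # from tsn.yaml above
DECAY_EPOCHS = [30, 60]        # assumed decay points
STEPS_PER_EPOCH = 75           # assumed; depends on dataset size and batch size
boundaries = [e * STEPS_PER_EPOCH for e in DECAY_EPOCHS]
values = [LR * DECAY ** i for i in range(len(DECAY_EPOCHS) + 1)]
print(boundaries)  # [2250, 4500]
print(values)      # [0.001, 0.0001, 1e-05] (up to float rounding)
# lists in this boundaries/values form are what fluid.layers.piecewise_decay takes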
@@ -14,14 +14,14 @@ VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 4
batch_size: 32
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 4
batch_size: 32
filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"