# SlowFast Video Classification Model (Dygraph Implementation)
---
## Contents
- [Introduction](#introduction)
- [Code Structure](#code-structure)
- [Installation](#installation)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference](#inference)
- [Reference](#reference)
## Introduction
SlowFast is a high-accuracy video classification model built from two pathways. The Slow pathway takes sparsely sampled frames as input and captures the appearance content of the video, while the Fast pathway takes densely sampled frames as input and captures motion. The features of the two pathways are concatenated to produce the final prediction.
<p align="center">
<img src="./SLOWFAST.png" height=300 width=500 hspace='10'/> <br />
SlowFast Overview
</p>
For details, please refer to the ICCV 2019 paper [SlowFast Networks for Video Recognition](https://arxiv.org/abs/1812.03982)
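The frame-rate ratio between the two pathways is controlled by `alpha` (8 in the configs below): the Fast pathway sees all `num_frames` sampled frames, while the Slow pathway keeps only `num_frames // alpha` of them. A minimal NumPy sketch of this split, mirroring `pack_output` in `kinetics_dataset.py` (array contents are illustrative):
```
import numpy as np

num_frames, alpha = 32, 8  # values from slowfast.yaml
frames = np.zeros((num_frames, 224, 224, 3), dtype="float32")  # T H W C

fast_pathway = frames  # all 32 frames
slow_idx = np.linspace(0, num_frames - 1,
                       num_frames // alpha).astype("int64")  # 4 evenly spaced indices
slow_pathway = frames[slow_idx]

# both pathways are fed to the network as C T H W tensors
slow_pathway = slow_pathway.transpose(3, 0, 1, 2)
fast_pathway = fast_pathway.transpose(3, 0, 1, 2)
print(slow_pathway.shape, fast_pathway.shape)  # (3, 4, 224, 224) (3, 32, 224, 224)
```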
## Code Structure
```
├── slowfast.yaml # multi-GPU config file for conveniently setting hyper-parameters
├── slowfast-single.yaml # single-GPU evaluation/inference config file
├── run_train_multi.sh # multi-GPU training script
├── run_train_single.sh # single-GPU training script
├── run_eval_multi.sh # multi-GPU evaluation script
├── run_infer_multi.sh # multi-GPU inference script
├── run_eval_single.sh # single-GPU evaluation script
├── run_infer_single.sh # single-GPU inference script
├── train.py # training code
├── eval.py # evaluation code, measures model accuracy
├── predict.py # inference code
├── model.py # network structure
├── model_utils.py # building blocks of the network
├── lr_policy.py # learning rate schedule
├── kinetics_dataset.py # Kinetics-400 data preprocessing
└── config_utils.py # config handling utilities
```
## Installation
Please run the sample code with ```python3.7```
### Requirements:
```
CUDA >= 9.0
cudnn >= 7.5
```
### Installing dependencies:
- PaddlePaddle >= 2.0.0-alpha0:
``` pip3.7 install paddlepaddle-gpu==2.0.0a0 -i https://mirror.baidu.com/pypi/simple ```
- opencv 4.2:
``` pip3.7 install opencv-python==4.2.0.34```
- To visualize training curves, install VisualDL:
``` pip3.7 install visualdl -i https://mirror.baidu.com/pypi/simple```
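To quickly verify the installation, you can check the OpenCV version and whether PaddlePaddle was built with CUDA (a minimal sanity check using only the packages installed above):
```
import cv2
import paddle.fluid as fluid

print(cv2.__version__)                # expect 4.2.0
print(fluid.is_compiled_with_cuda())  # expect True for the GPU build
```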
## Data Preparation
SlowFast is trained on the Kinetics-400 dataset and takes mp4 files directly as input. Prepare the data as follows:
### Download the mp4 videos
Create the data directories under Code\_Root:
```
cd $Code_Root/ && mkdir data
cd data && mkdir data_k400 && cd data_k400
mkdir train_mp4 && mkdir val_mp4
```
ActivityNet provides an official download tool for Kinetics; see its [official repo](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) for instructions on downloading the Kinetics-400 mp4 videos. Download the training and validation sets into data/data\_k400/train\_mp4 and data/data\_k400/val\_mp4, respectively.
### Generate the training and validation file lists
Validation and evaluation share the same data; for inference you can pick a small number of samples from the validation set. The expected list format is described after the commands below.
```
cd $Code_Root/data/
ls $Code_Root/data/data_k400/train_mp4/* > train.csv
ls $Code_Root/data/data_k400/val_mp4/* > val.csv
ls $Code_Root/data/data_k400/val_mp4/* > test.csv
ls $Code_Root/data/data_k400/val_mp4/* > infer.csv
```
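Note that `kinetics_dataset.py` splits each line of train.csv/val.csv/test.csv into a file path and an integer label (`path_label.split()`), while infer.csv contains paths only. If the `ls` output does not already carry labels, append a label to each line yourself; an illustrative list entry (hypothetical path and label):
```
./data/data_k400/train_mp4/abseiling_0001.mp4 0
```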
## Training
Once the data is ready, training can be launched in either of the following two ways:
Training uses 8 GPUs by default:
```
bash run_train_multi.sh
```
For single-GPU training:
```
bash run_train_single.sh
```
- Multi-GPU training is recommended; with a single GPU, the smaller batch\_size may cost some accuracy.
- To train from scratch, simply run the commands or scripts above; no pretrained model is needed.
- VisualDL can be used to visualize the training process; see [VisualDL](https://github.com/PaddlePaddle/VisualDL) for usage.
**Training resource requirements:**
* 8 V100 GPUs, total batch\_size=64 (8 per GPU), about 9 GB of GPU memory per card.
* The Kinetics-400 training set is large (about 230k samples) and SlowFast trains for many epochs (196), so a full run takes about 200 hours.
* Work on speeding up training is in progress, stay tuned.
## Evaluation
After training, the model can be evaluated as follows:
Multi-GPU evaluation:
```
bash run_eval_multi.sh
```
Single-GPU evaluation:
```
bash run_eval_single.sh
```
- You can change the `weights` argument in the script to select a weight file; if unset, the default checkpoints/slowfast_epoch195.pdparams is used.
- Evaluation uses ```multi_crop``` testing and therefore takes a while; multi-GPU evaluation is recommended to speed it up. With the default settings, multi-GPU evaluation takes about 4 hours.
- The final evaluation accuracy is printed in the log file.
Accuracy on the Kinetics-400 dataset:
| Acc1 | Acc5 |
| :---: | :---: |
| 74.35 | 91.33 |
- Some source files of Kinetics-400 are missing and can no longer be downloaded, so the dataset we use is ~5% smaller than the official one; accuracy is therefore slightly below the results reported in the paper.
## Inference
After training, inference can be run as follows:
Multi-GPU inference:
```
bash run_infer_multi.sh
```
Single-GPU inference:
```
bash run_infer_single.sh
```
- You can change the `weights` argument in the script to select a weight file; if unset, the default checkpoints/slowfast_epoch195.pdparams is used.
- Inference also uses ```multi_crop```, so predicting even a single video file takes some time.
- Prediction results are saved to ./data/results/result.json.
## Reference
- [SlowFast Networks for Video Recognition](https://arxiv.org/abs/1812.03982)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except Exception:
pass
return cfg
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
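# Illustrative usage of the helpers above (config file name as in this repo):
#   cfg = parse_config('slowfast.yaml')  # nested AttrDict, e.g. cfg.MODEL.alpha
#   cfg = merge_configs(cfg, 'test', {'batch_size': 8, 'weights': None})  # None values are skipped
#   print_configs(cfg, 'Test')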
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import argparse
import ast
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.io import DataLoader, Dataset
from paddle.incubate.hapi.distributed import DistributedBatchSampler, _all_gather
from paddle.fluid.dygraph.parallel import ParallelEnv
from model import *
from config_utils import *
from kinetics_dataset import KineticsDataset
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(
"SLOWFAST test for performance evaluation.")
parser.add_argument(
'--config_file',
type=str,
default='slowfast.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='total eval batch size of all gpus.')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
parser.add_argument(
'--use_data_parallel',
type=ast.literal_eval,
default=True,
help='default use data parallel.')
parser.add_argument(
'--weights',
type=str,
default=None,
help='Weight path, None to use config setting.')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Performance Evaluation
def test_slowfast(args):
config = parse_config(args.config_file)
test_config = merge_configs(config, 'test', vars(args))
print_configs(test_config, "Test")
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
_nranks = ParallelEnv().nranks # num gpu
bs_single = int(test_config.TEST.batch_size /
_nranks) # batch_size of each gpu
with fluid.dygraph.guard(place):
#build model
slowfast = SlowFast(cfg=test_config, num_classes=400)
if args.weights:
assert os.path.exists(args.weights + '.pdparams'),\
"Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model_dict, _ = fluid.load_dygraph(args.weights)
slowfast.set_dict(model_dict)
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
slowfast = fluid.dygraph.parallel.DataParallel(slowfast, strategy)
#create reader
test_data = KineticsDataset(mode="test", cfg=test_config)
test_sampler = DistributedBatchSampler(
test_data, batch_size=bs_single, shuffle=False, drop_last=False)
test_loader = DataLoader(
test_data,
batch_sampler=test_sampler,
places=place,
feed_list=None,
num_workers=8,
return_list=True)
# start eval
num_ensemble_views = test_config.TEST.num_ensemble_views
num_spatial_crops = test_config.TEST.num_spatial_crops
num_cls = test_config.MODEL.num_classes
num_clips = num_ensemble_views * num_spatial_crops
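# with the default TEST config, 10 ensemble views x 3 spatial crops = 30
# clips per video; per-clip predictions are summed into video_preds below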
num_videos = len(test_data) // num_clips
video_preds = np.zeros((num_videos, num_cls))
video_labels = np.zeros((num_videos, 1), dtype="int64")
clip_count = {}
print(
"[EVAL] eval start, number of videos {}, total number of clips {}".
format(num_videos, num_clips * num_videos))
slowfast.eval()
for batch_id, data in enumerate(test_loader):
# call net
model_inputs = [data[0], data[1]]
preds = slowfast(model_inputs, training=False)
labels = data[2]
clip_ids = data[3]
# gather results from multiple cards; after this, each card holds the same results.
if _nranks > 1:
preds = _all_gather(preds, _nranks)
labels = _all_gather(labels, _nranks)
clip_ids = _all_gather(clip_ids, _nranks)
# to numpy
preds = preds.numpy()
labels = labels.numpy().astype("int64")
clip_ids = clip_ids.numpy()
# preds ensemble
for ind in range(preds.shape[0]):
vid_id = int(clip_ids[ind]) // num_clips
ts_idx = int(clip_ids[ind]) % num_clips
if vid_id not in clip_count:
clip_count[vid_id] = []
if ts_idx in clip_count[vid_id]:
print(
"[EVAL] Skipping duplicate clip: video {} clip index {} (clip id {}).".
format(vid_id, ts_idx, clip_ids[ind]))
else:
clip_count[vid_id].append(ts_idx)
video_preds[vid_id] += preds[ind] # ensemble method: sum
if video_labels[vid_id].sum() > 0:
assert video_labels[vid_id] == labels[ind]
video_labels[vid_id] = labels[ind]
if batch_id % args.log_interval == 0:
print("[EVAL] Processing batch {}/{} ...".format(
batch_id, len(test_data) // test_config.TEST.batch_size))
# check clip index of each video
for key in clip_count.keys():
if len(clip_count[key]) != num_clips or sum(clip_count[
key]) != num_clips * (num_clips - 1) / 2:
print(
"[EVAL] Warning!! video [{}] clip count [{}] not match number clips {}".
format(key, clip_count[key], num_clips))
video_preds = to_variable(video_preds)
video_labels = to_variable(video_labels)
acc_top1 = fluid.layers.accuracy(
input=video_preds, label=video_labels, k=1)
acc_top5 = fluid.layers.accuracy(
input=video_preds, label=video_labels, k=5)
print('[EVAL] eval finished, avg_acc1= {}, avg_acc5= {} '.format(
acc_top1.numpy(), acc_top5.numpy()))
if __name__ == "__main__":
args = parse_args()
logger.info(args)
test_slowfast(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import math
import random
import numpy as np
from PIL import Image, ImageEnhance
import logging
from paddle.io import Dataset
logger = logging.getLogger(__name__)
__all__ = ['KineticsDataset']
class KineticsDataset(Dataset):
def __init__(self, mode, cfg):
self.mode = mode
self.format = cfg.MODEL.format
self.num_frames = cfg.MODEL.num_frames
self.sampling_rate = cfg.MODEL.sampling_rate
self.target_fps = cfg.MODEL.target_fps
self.slowfast_alpha = cfg.MODEL.alpha
self.target_size = cfg[mode.upper()]['target_size']
self.img_mean = cfg.MODEL.image_mean
self.img_std = cfg.MODEL.image_std
self.filelist = cfg[mode.upper()]['filelist']
if self.mode in ["train", "valid"]:
self.min_size = cfg[mode.upper()]['min_size']
self.max_size = cfg[mode.upper()]['max_size']
self.num_ensemble_views = 1
self.num_spatial_crops = 1
self._num_clips = 1
elif self.mode in ['test', 'infer']:
self.min_size = self.max_size = self.target_size
self.num_ensemble_views = cfg.TEST.num_ensemble_views
self.num_spatial_crops = cfg.TEST.num_spatial_crops
self._num_clips = (self.num_ensemble_views * self.num_spatial_crops)
self._construct_loader()
def _construct_loader(self):
"""
Construct the video loader.
"""
self._num_retries = 5
self._path_to_videos = []
self._labels = []
self._spatial_temporal_idx = []
with open(self.filelist, "r") as f:
for clip_idx, path_label in enumerate(f.read().splitlines()):
if self.mode == 'infer':
path = path_label
label = 0  # placeholder; no labels are available in infer mode
else:
path, label = path_label.split()
for idx in range(self._num_clips):
self._path_to_videos.append(path)
self._labels.append(int(label))
self._spatial_temporal_idx.append(idx)
def __len__(self):
return len(self._path_to_videos)
def __getitem__(self, idx):
if self.mode in ["train", "valid"]:
temporal_sample_index = -1
spatial_sample_index = -1
elif self.mode in ["test", 'infer']:
temporal_sample_index = (self._spatial_temporal_idx[idx] //
self.num_spatial_crops)
spatial_sample_index = (self._spatial_temporal_idx[idx] %
self.num_spatial_crops)
for ir in range(self._num_retries):
mp4_path = self._path_to_videos[idx]
try:
pathways = self.mp4_loader(
mp4_path,
temporal_sample_index,
spatial_sample_index,
temporal_num_clips=self.num_ensemble_views,
spatial_num_clips=self.num_spatial_crops,
num_frames=self.num_frames,
sampling_rate=self.sampling_rate,
target_fps=self.target_fps,
target_size=self.target_size,
img_mean=self.img_mean,
img_std=self.img_std,
slowfast_alpha=self.slowfast_alpha,
min_size=self.min_size,
max_size=self.max_size)
except Exception:
if ir < self._num_retries - 1:
logger.error(
'Error when loading {}, tried {} times, will try again'.
format(mp4_path, ir + 1))
idx = random.randint(0, len(self._path_to_videos) - 1)
continue
else:
logger.error(
'Error when loading {}, tried {} times, will not try again'.
format(mp4_path, ir + 1))
return None, None
label = self._labels[idx]
return pathways[0], pathways[1], np.array([label]), np.array([idx])
def mp4_loader(self, filepath, temporal_sample_index, spatial_sample_index,
temporal_num_clips, spatial_num_clips, num_frames,
sampling_rate, target_fps, target_size, img_mean, img_std,
slowfast_alpha, min_size, max_size):
frames_sample, clip_size = self.decode_sampling(
filepath, temporal_sample_index, temporal_num_clips, num_frames,
sampling_rate, target_fps)
frames_select = self.temporal_sampling(
frames_sample, clip_size, num_frames, filepath,
temporal_sample_index, temporal_num_clips)
frames_resize = self.scale(frames_select, min_size, max_size)
frames_crop = self.crop(frames_resize, target_size,
spatial_sample_index, spatial_num_clips)
frames_flip = self.flip(frames_crop, spatial_sample_index)
#list to nparray
npframes = (np.stack(frames_flip)).astype('float32')
npframes_norm = self.color_norm(npframes, img_mean, img_std)
frames_list = self.pack_output(npframes_norm, slowfast_alpha)
return frames_list
def get_start_end_idx(self, video_size, clip_size, clip_idx,
temporal_num_clips):
delta = max(video_size - clip_size, 0)
if clip_idx == -1: # train/valid: random sampling, temporal_num_clips is not used
# Random temporal sampling.
start_idx = random.uniform(0, delta)
else:
# Uniformly sample the clip with the given index.
start_idx = delta * clip_idx / temporal_num_clips
end_idx = start_idx + clip_size - 1
return start_idx, end_idx
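# Illustrative numbers for get_start_end_idx above: with video_size=300,
# clip_size=128 and 10 temporal views, delta=172; view 0 starts at frame 0
# and view 9 at 172*9/10=154.8, spreading the clips across the video, while
# clip_idx=-1 (train/valid) draws a random start in [0, 172] instead.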
def decode_sampling(self, filepath, temporal_sample_index,
temporal_num_clips, num_frames, sampling_rate,
target_fps):
cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
if int(major_ver) < 3:
fps = cap.get(cv2.cv.CV_CAP_PROP_FPS)
else:
fps = cap.get(cv2.CAP_PROP_FPS)
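# clip_size: number of decoded frames needed so that num_frames frames at the
# given sampling_rate remain after rescaling the frame rate to target_fps;
# e.g. num_frames=32, sampling_rate=2, fps=30, target_fps=30 -> 64 frames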
clip_size = num_frames * sampling_rate * fps / target_fps
if filepath[-3:] != 'mp4':
start_idx, end_idx = 0, math.inf
else:
start_idx, end_idx = self.get_start_end_idx(
videolen, clip_size, temporal_sample_index, temporal_num_clips)
#print("filepath:",filepath,"start_idx:",start_idx,"end_idx:",end_idx)
frames_sample = [] #start randomly, decode clip size
start_idx = math.ceil(start_idx)
cap.set(cv2.CAP_PROP_POS_FRAMES, start_idx)
for i in range(videolen):
if i < start_idx:
continue
ret, frame = cap.read()
if not ret:
continue
if i <= end_idx + 1: # keep one extra frame as a buffer
img = frame[:, :, ::-1] #BGR -> RGB
frames_sample.append(img)
else:
break
return frames_sample, clip_size
def temporal_sampling(self, frames_sample, clip_size, num_frames, filepath,
temporal_sample_index, temporal_num_clips):
""" sample num_frames from clip_size """
fs_len = len(frames_sample)
if filepath[-3:] != 'mp4':
start_idx, end_idx = self.get_start_end_idx(
fs_len, clip_size, temporal_sample_index, temporal_num_clips)
else:
start_idx, end_idx = self.get_start_end_idx(fs_len, clip_size, 0, 1)
index = np.linspace(start_idx, end_idx, num_frames).astype("int64")
index = np.clip(index, 0, fs_len - 1)
frames_select = []
for i in range(index.shape[0]):
idx = index[i]
imgbuf = frames_sample[idx]
img = Image.fromarray(imgbuf, mode='RGB')
frames_select.append(img)
return frames_select
def scale(self, frames_select, min_size, max_size):
size = int(round(np.random.uniform(min_size, max_size)))
assert (len(frames_select) >= 1) , \
"len(frames_select):{} should be at least 1".format(len(frames_select))
width, height = frames_select[0].size
if (width <= height and width == size) or (height <= width and
height == size):
return frames_select
new_width = size
new_height = size
if width < height:
new_height = int(math.floor((float(height) / width) * size))
else:
new_width = int(math.floor((float(width) / height) * size))
frames_resize = []
for j in range(len(frames_select)):
img = frames_select[j]
scale_img = img.resize((new_width, new_height), Image.BILINEAR)
frames_resize.append(scale_img)
return frames_resize
def crop(self, frames_resize, target_size, spatial_sample_index,
spatial_num_clips):
w, h = frames_resize[0].size
if w == target_size and h == target_size:
return frames_resize
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size({},{})".format(w, h, target_size, target_size)
frames_crop = []
if spatial_sample_index == -1:
x_offset = random.randint(0, w - target_size)
y_offset = random.randint(0, h - target_size)
else:
x_gap = int(math.ceil((w - target_size) / (spatial_num_clips - 1)))
y_gap = int(math.ceil((h - target_size) / (spatial_num_clips - 1)))
if h > w:
x_offset = int(math.ceil((w - target_size) / 2))
if spatial_sample_index == 0:
y_offset = 0
elif spatial_sample_index == spatial_num_clips - 1:
y_offset = h - target_size
else:
y_offset = y_gap * spatial_sample_index
else:
y_offset = int(math.ceil((h - target_size) / 2))
if spatial_sample_index == 0:
x_offset = 0
elif spatial_sample_index == spatial_num_clips - 1:
x_offset = w - target_size
else:
x_offset = x_gap * spatial_sample_index
for img in frames_resize:
nimg = img.crop((x_offset, y_offset, x_offset + target_size,
y_offset + target_size))
frames_crop.append(nimg)
return frames_crop
def flip(self, frames_crop, spatial_sample_index):
# without flip when test
if spatial_sample_index != -1:
return frames_crop
frames_flip = []
if np.random.uniform() < 0.5:
for img in frames_crop:
nimg = img.transpose(Image.FLIP_LEFT_RIGHT)
frames_flip.append(nimg)
else:
frames_flip = frames_crop
return frames_flip
def color_norm(self, npframes_norm, c_mean, c_std):
npframes_norm /= 255.0
npframes_norm -= np.array(c_mean).reshape(
[1, 1, 1, 3]).astype(np.float32)
npframes_norm /= np.array(c_std).reshape(
[1, 1, 1, 3]).astype(np.float32)
return npframes_norm
def pack_output(self, npframes_norm, slowfast_alpha):
fast_pathway = npframes_norm
# sample num points between start and end
slow_idx_start = 0
slow_idx_end = fast_pathway.shape[0] - 1
slow_idx_num = fast_pathway.shape[0] // slowfast_alpha
slow_idxs_select = np.linspace(slow_idx_start, slow_idx_end,
slow_idx_num).astype("int64")
slow_pathway = fast_pathway[slow_idxs_select]
# T H W C -> C T H W.
slow_pathway = slow_pathway.transpose(3, 0, 1, 2)
fast_pathway = fast_pathway.transpose(3, 0, 1, 2)
# slow + fast
frames_list = [slow_pathway, fast_pathway]
return frames_list
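# Illustrative shapes returned by pack_output above with the default config
# (num_frames=32, alpha=8, crop size 224): fast_pathway [3, 32, 224, 224] and
# slow_pathway [3, 4, 224, 224], i.e. the Slow pathway keeps 1/alpha of the
# frames seen by the Fast pathway.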
"""Learning rate policy."""
import math
def get_epoch_lr(cur_epoch, cfg):
"""
Retrieve the learning rate of the current epoch with the option to perform
warm up in the beginning of the training stage.
Args:
cfg (CfgNode): configs. Details can be found in
slowfast/config/defaults.py
cur_epoch (float): the number of epoch of the current training stage.
#"""
warmup_epochs = cfg.warmup_epochs #34
warmup_start_lr = cfg.warmup_start_lr #0.01
lr = lr_func_cosine(cur_epoch, cfg)
# Perform warm up.
if cur_epoch < warmup_epochs:
lr_start = warmup_start_lr
lr_end = lr_func_cosine(warmup_epochs, cfg)
alpha = (lr_end - lr_start) / warmup_epochs
lr = cur_epoch * alpha + lr_start
return lr
def lr_func_cosine(cur_epoch, cfg):
"""
Retrieve the learning rate to specified values at specified epoch with the
cosine learning rate schedule. Details can be found in:
Ilya Loshchilov, and Frank Hutter
SGDR: Stochastic Gradient Descent With Warm Restarts.
Args:
cfg (CfgNode): configs. Details can be found in
slowfast/config/defaults.py
cur_epoch (float): the number of epoch of the current training stage.
"""
base_lr = cfg.base_lr #0.1
max_epoch = cfg.epoch #196
return base_lr * (math.cos(math.pi * cur_epoch / max_epoch) + 1.0) * 0.5
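# Illustrative values with the default config (base_lr=0.1, epoch=196,
# warmup_epochs=34, warmup_start_lr=0.01): get_epoch_lr(0, cfg) returns 0.01,
# warmup ramps linearly to lr_func_cosine(34, cfg) ~= 0.093 at epoch 34, and
# the cosine schedule then decays the lr towards 0 at epoch 196.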
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from model_utils import *
class ResNetBasicStem(fluid.dygraph.Layer):
"""
ResNe(X)t 3D stem module.
Performs spatiotemporal Convolution, BN, and Relu following by a
spatiotemporal pooling.
"""
def __init__(
self,
dim_in,
dim_out,
kernel,
stride,
padding,
eps=1e-5, ):
super(ResNetBasicStem, self).__init__()
self.kernel = kernel
self.stride = stride
self.padding = padding
self.eps = eps
self._construct_stem(dim_in, dim_out)
def _construct_stem(self, dim_in, dim_out):
fan = (dim_out) * (self.kernel[0] * self.kernel[1] * self.kernel[2])
initializer_tmp = get_conv_init(fan)
batchnorm_weight = 1.0
self._conv = fluid.dygraph.nn.Conv3D(
num_channels=dim_in,
num_filters=dim_out,
filter_size=self.kernel,
stride=self.stride,
padding=self.padding,
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=False)
self._bn = fluid.dygraph.BatchNorm(
num_channels=dim_out,
epsilon=self.eps,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(batchnorm_weight),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)))
def forward(self, x):
x = self._conv(x)
x = self._bn(x)
x = fluid.layers.relu(x)
x = fluid.layers.pool3d(
input=x,
pool_type="max",
pool_size=[1, 3, 3],
pool_stride=[1, 2, 2],
pool_padding=[0, 1, 1],
data_format="NCDHW")
return x
class VideoModelStem(fluid.dygraph.Layer):
"""
Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool
on input data tensor for slow and fast pathways.
"""
def __init__(
self,
dim_in,
dim_out,
kernel,
stride,
padding,
eps=1e-5, ):
"""
Args:
dim_in (list): the list of channel dimensions of the inputs.
dim_out (list): the output dimension of the convolution in the stem
layer.
kernel (list): the kernels' size of the convolutions in the stem
layers. Temporal kernel size, height kernel size, width kernel
size in order.
stride (list): the stride sizes of the convolutions in the stem
layer. Temporal kernel stride, height kernel size, width kernel
size in order.
padding (list): the paddings' sizes of the convolutions in the stem
layer. Temporal padding size, height padding size, width padding
size in order.
eps (float): epsilon for batch norm.
"""
super(VideoModelStem, self).__init__()
assert (len({
len(dim_in),
len(dim_out),
len(kernel),
len(stride),
len(padding),
}) == 1), "Input pathway dimensions are not consistent."
self.num_pathways = len(dim_in)
self.kernel = kernel
self.stride = stride
self.padding = padding
self.eps = eps
self._construct_stem(dim_in, dim_out)
def _construct_stem(self, dim_in, dim_out):
for pathway in range(len(dim_in)):
stem = ResNetBasicStem(
dim_in[pathway],
dim_out[pathway],
self.kernel[pathway],
self.stride[pathway],
self.padding[pathway],
self.eps, )
self.add_sublayer("pathway{}_stem".format(pathway), stem)
def forward(self, x):
assert (
len(x) == self.num_pathways
), "Input tensor does not contain {} pathway".format(self.num_pathways)
for pathway in range(len(x)):
m = getattr(self, "pathway{}_stem".format(pathway))
x[pathway] = m(to_variable(x[pathway]))
return x
class FuseFastToSlow(fluid.dygraph.Layer):
"""
Fuses the information from the Fast pathway to the Slow pathway. Given the
tensors from Slow pathway and Fast pathway, fuse information from Fast to
Slow, then return the fused tensors from Slow and Fast pathway in order.
"""
def __init__(
self,
dim_in,
fusion_conv_channel_ratio,
fusion_kernel,
alpha,
eps=1e-5, ):
"""
Args:
dim_in (int): the channel dimension of the input.
fusion_conv_channel_ratio (int): channel ratio for the convolution
used to fuse from Fast pathway to Slow pathway.
fusion_kernel (int): kernel size of the convolution used to fuse
from Fast pathway to Slow pathway.
alpha (int): the frame rate ratio between the Fast and Slow pathway.
eps (float): epsilon for batch norm.
"""
super(FuseFastToSlow, self).__init__()
fan = (dim_in * fusion_conv_channel_ratio) * (fusion_kernel * 1 * 1)
initializer_tmp = get_conv_init(fan)
batchnorm_weight = 1.0
self._conv_f2s = fluid.dygraph.nn.Conv3D(
num_channels=dim_in,
num_filters=dim_in * fusion_conv_channel_ratio,
filter_size=[fusion_kernel, 1, 1],
stride=[alpha, 1, 1],
padding=[fusion_kernel // 2, 0, 0],
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=False)
self._bn = fluid.dygraph.BatchNorm(
num_channels=dim_in * fusion_conv_channel_ratio,
epsilon=eps,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(batchnorm_weight),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)))
def forward(self, x):
x_s = x[0]
x_f = x[1]
fuse = self._conv_f2s(x_f)
fuse = self._bn(fuse)
fuse = fluid.layers.relu(fuse)
x_s_fuse = fluid.layers.concat(input=[x_s, fuse], axis=1, name=None)
return [x_s_fuse, x_f]
class SlowFast(fluid.dygraph.Layer):
"""
SlowFast model builder for SlowFast network.
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"Slowfast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
"""
def __init__(self, cfg, num_classes):
"""
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(SlowFast, self).__init__()
self.num_classes = num_classes
self.num_frames = cfg.MODEL.num_frames #32
self.alpha = cfg.MODEL.alpha #8
self.beta = cfg.MODEL.beta #8
self.crop_size = cfg.MODEL.crop_size #224
self.num_pathways = 2
self.res_depth = 50
self.num_groups = 1
self.input_channel_num = [3, 3]
self.width_per_group = 64
self.fusion_conv_channel_ratio = 2
self.fusion_kernel_sz = 5
self.dropout_rate = 0.5
self._construct_network(cfg)
def _construct_network(self, cfg):
"""
Builds a SlowFast model.
The first pathway is the Slow pathway
and the second pathway is the Fast pathway.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
temp_kernel = [
[[1], [5]], # conv1 temporal kernel for slow and fast pathway.
[[1], [3]], # res2 temporal kernel for slow and fast pathway.
[[1], [3]], # res3 temporal kernel for slow and fast pathway.
[[3], [3]], # res4 temporal kernel for slow and fast pathway.
[[3], [3]], # res5 temporal kernel for slow and fast pathway.
]
self.s1 = VideoModelStem(
dim_in=self.input_channel_num,
dim_out=[self.width_per_group, self.width_per_group // self.beta],
kernel=[temp_kernel[0][0] + [7, 7], temp_kernel[0][1] + [7, 7]],
stride=[[1, 2, 2]] * 2,
padding=[
[temp_kernel[0][0][0] // 2, 3, 3],
[temp_kernel[0][1][0] // 2, 3, 3],
], )
self.s1_fuse = FuseFastToSlow(
dim_in=self.width_per_group // self.beta,
fusion_conv_channel_ratio=self.fusion_conv_channel_ratio,
fusion_kernel=self.fusion_kernel_sz,
alpha=self.alpha, )
# ResNet backbone
MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3)}
(d2, d3, d4, d5) = MODEL_STAGE_DEPTH[self.res_depth]
num_block_temp_kernel = [[3, 3], [4, 4], [6, 6], [3, 3]]
spatial_dilations = [[1, 1], [1, 1], [1, 1], [1, 1]]
spatial_strides = [[1, 1], [2, 2], [2, 2], [2, 2]]
out_dim_ratio = self.beta // self.fusion_conv_channel_ratio #4
dim_inner = self.width_per_group * self.num_groups #64
self.s2 = ResStage(
dim_in=[
self.width_per_group + self.width_per_group // out_dim_ratio,
self.width_per_group // self.beta,
],
dim_out=[
self.width_per_group * 4,
self.width_per_group * 4 // self.beta,
],
dim_inner=[dim_inner, dim_inner // self.beta],
temp_kernel_sizes=temp_kernel[1],
stride=spatial_strides[0],
num_blocks=[d2] * 2,
num_groups=[self.num_groups] * 2,
num_block_temp_kernel=num_block_temp_kernel[0],
dilation=spatial_dilations[0], )
self.s2_fuse = FuseFastToSlow(
dim_in=self.width_per_group * 4 // self.beta,
fusion_conv_channel_ratio=self.fusion_conv_channel_ratio,
fusion_kernel=self.fusion_kernel_sz,
alpha=self.alpha, )
self.s3 = ResStage(
dim_in=[
self.width_per_group * 4 + self.width_per_group * 4 //
out_dim_ratio,
self.width_per_group * 4 // self.beta,
],
dim_out=[
self.width_per_group * 8,
self.width_per_group * 8 // self.beta,
],
dim_inner=[dim_inner * 2, dim_inner * 2 // self.beta],
temp_kernel_sizes=temp_kernel[2],
stride=spatial_strides[1],
num_blocks=[d3] * 2,
num_groups=[self.num_groups] * 2,
num_block_temp_kernel=num_block_temp_kernel[1],
dilation=spatial_dilations[1], )
self.s3_fuse = FuseFastToSlow(
dim_in=self.width_per_group * 8 // self.beta,
fusion_conv_channel_ratio=self.fusion_conv_channel_ratio,
fusion_kernel=self.fusion_kernel_sz,
alpha=self.alpha, )
self.s4 = ResStage(
dim_in=[
self.width_per_group * 8 + self.width_per_group * 8 //
out_dim_ratio,
self.width_per_group * 8 // self.beta,
],
dim_out=[
self.width_per_group * 16,
self.width_per_group * 16 // self.beta,
],
dim_inner=[dim_inner * 4, dim_inner * 4 // self.beta],
temp_kernel_sizes=temp_kernel[3],
stride=spatial_strides[2],
num_blocks=[d4] * 2,
num_groups=[self.num_groups] * 2,
num_block_temp_kernel=num_block_temp_kernel[2],
dilation=spatial_dilations[2], )
self.s4_fuse = FuseFastToSlow(
dim_in=self.width_per_group * 16 // self.beta,
fusion_conv_channel_ratio=self.fusion_conv_channel_ratio,
fusion_kernel=self.fusion_kernel_sz,
alpha=self.alpha, )
self.s5 = ResStage(
dim_in=[
self.width_per_group * 16 + self.width_per_group * 16 //
out_dim_ratio,
self.width_per_group * 16 // self.beta,
],
dim_out=[
self.width_per_group * 32,
self.width_per_group * 32 // self.beta,
],
dim_inner=[dim_inner * 8, dim_inner * 8 // self.beta],
temp_kernel_sizes=temp_kernel[4],
stride=spatial_strides[3],
num_blocks=[d5] * 2,
num_groups=[self.num_groups] * 2,
num_block_temp_kernel=num_block_temp_kernel[3],
dilation=spatial_dilations[3], )
self.pool_size = [[1, 1, 1], [1, 1, 1]]
self.head = ResNetBasicHead(
dim_in=[
self.width_per_group * 32,
self.width_per_group * 32 // self.beta,
],
num_classes=self.num_classes,
pool_size=[
[
self.num_frames // self.alpha // self.pool_size[0][0],
self.crop_size // 32 // self.pool_size[0][1],
self.crop_size // 32 // self.pool_size[0][2],
],
[
self.num_frames // self.pool_size[1][0],
self.crop_size // 32 // self.pool_size[1][1],
self.crop_size // 32 // self.pool_size[1][2],
],
],
dropout_rate=self.dropout_rate, )
def forward(self, x, training):
x = self.s1(x) #VideoModelStem
x = self.s1_fuse(x) #FuseFastToSlow
x = self.s2(x) #ResStage
x = self.s2_fuse(x)
for pathway in range(self.num_pathways):
x[pathway] = fluid.layers.pool3d(
input=x[pathway],
pool_type="max",
pool_size=self.pool_size[pathway],
pool_stride=self.pool_size[pathway],
pool_padding=[0, 0, 0],
data_format="NCDHW")
x = self.s3(x)
x = self.s3_fuse(x)
x = self.s4(x)
x = self.s4_fuse(x)
x = self.s5(x)
x = self.head(x, training) #ResNetBasicHead
return x
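# Illustrative input/output for SlowFast.forward above with the default
# config: x is a list of two NCDHW tensors, slow [N, 3, 4, 224, 224] and fast
# [N, 3, 32, 224, 224] (built by pack_output in kinetics_dataset.py); the head
# returns [N, 400] class scores, softmax-averaged over the spatial/temporal
# dims when training=False.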
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from paddle.fluid.initializer import MSRA
import paddle.fluid as fluid
# get init parameters for conv layer
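# note: passing fan_out as MSRA's fan_in argument emulates Kaiming
# initialization in "fan_out" mode (our reading of the intent here)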
def get_conv_init(fan_out):
return MSRA(uniform=False, fan_in=fan_out)
"""Video models."""
class BottleneckTransform(fluid.dygraph.Layer):
"""
Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of
temporal kernel.
"""
def __init__(
self,
dim_in,
dim_out,
temp_kernel_size,
stride,
dim_inner,
num_groups,
stride_1x1=False,
inplace_relu=True,
eps=1e-5,
dilation=1, ):
"""
Args:
dim_in (int): the channel dimensions of the input.
dim_out (int): the channel dimension of the output.
temp_kernel_size (int): the temporal kernel sizes of the middle
convolution in the bottleneck.
stride (int): the stride of the bottleneck.
dim_inner (int): the inner dimension of the block.
num_groups (int): number of groups for the convolution. num_groups=1
is for standard ResNet like networks, and num_groups>1 is for
ResNeXt like networks.
stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
apply stride to the 3x3 conv.
inplace_relu (bool): if True, calculate the relu on the original
input without allocating new memory.
eps (float): epsilon for batch norm.
dilation (int): size of dilation.
"""
super(BottleneckTransform, self).__init__()
self.temp_kernel_size = temp_kernel_size
self._inplace_relu = inplace_relu
self._eps = eps
self._stride_1x1 = stride_1x1
self._construct(dim_in, dim_out, stride, dim_inner, num_groups,
dilation)
def _construct(self, dim_in, dim_out, stride, dim_inner, num_groups,
dilation):
str1x1, str3x3 = (stride, 1) if self._stride_1x1 else (1, stride)
fan = (dim_inner) * (self.temp_kernel_size * 1 * 1)
initializer_tmp = get_conv_init(fan)
batchnorm_weight = 1.0
self.a = fluid.dygraph.nn.Conv3D(
num_channels=dim_in,
num_filters=dim_inner,
filter_size=[self.temp_kernel_size, 1, 1],
stride=[1, str1x1, str1x1],
padding=[int(self.temp_kernel_size // 2), 0, 0],
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=False)
self.a_bn = fluid.dygraph.BatchNorm(
num_channels=dim_inner,
epsilon=self._eps,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(batchnorm_weight),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)))
# 1x3x3, BN, ReLU.
fan = (dim_inner) * (1 * 3 * 3)
initializer_tmp = get_conv_init(fan)
batchnorm_weight = 1.0
self.b = fluid.dygraph.nn.Conv3D(
num_channels=dim_inner,
num_filters=dim_inner,
filter_size=[1, 3, 3],
stride=[1, str3x3, str3x3],
padding=[0, dilation, dilation],
groups=num_groups,
dilation=[1, dilation, dilation],
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=False)
self.b_bn = fluid.dygraph.BatchNorm(
num_channels=dim_inner,
epsilon=self._eps,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(batchnorm_weight),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)))
# 1x1x1, BN.
fan = (dim_out) * (1 * 1 * 1)
initializer_tmp = get_conv_init(fan)
batchnorm_weight = 0.0
self.c = fluid.dygraph.nn.Conv3D(
num_channels=dim_inner,
num_filters=dim_out,
filter_size=[1, 1, 1],
stride=[1, 1, 1],
padding=[0, 0, 0],
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=False)
self.c_bn = fluid.dygraph.BatchNorm(
num_channels=dim_out,
epsilon=self._eps,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(batchnorm_weight),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)))
def forward(self, x):
# Branch2a.
x = self.a(x)
x = self.a_bn(x)
x = fluid.layers.relu(x)
# Branch2b.
x = self.b(x)
x = self.b_bn(x)
x = fluid.layers.relu(x)
# Branch2c
x = self.c(x)
x = self.c_bn(x)
return x
class ResBlock(fluid.dygraph.Layer):
"""
Residual block.
"""
def __init__(
self,
dim_in,
dim_out,
temp_kernel_size,
stride,
dim_inner,
num_groups=1,
stride_1x1=False,
inplace_relu=True,
eps=1e-5,
dilation=1, ):
"""
ResBlock class constructs residual blocks. More details can be found in:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
"Deep residual learning for image recognition."
https://arxiv.org/abs/1512.03385
Args:
dim_in (int): the channel dimensions of the input.
dim_out (int): the channel dimension of the output.
temp_kernel_size (int): the temporal kernel sizes of the middle
convolution in the bottleneck.
stride (int): the stride of the bottleneck.
trans_func (string): transform function to be used to construct the
bottleneck.
dim_inner (int): the inner dimension of the block.
num_groups (int): number of groups for the convolution. num_groups=1
is for standard ResNet like networks, and num_groups>1 is for
ResNeXt like networks.
stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
apply stride to the 3x3 conv.
inplace_relu (bool): calculate the relu on the original input
without allocating new memory.
eps (float): epsilon for batch norm.
dilation (int): size of dilation.
"""
super(ResBlock, self).__init__()
self._inplace_relu = inplace_relu
self._eps = eps
self._construct(
dim_in,
dim_out,
temp_kernel_size,
stride,
dim_inner,
num_groups,
stride_1x1,
inplace_relu,
dilation, )
def _construct(
self,
dim_in,
dim_out,
temp_kernel_size,
stride,
dim_inner,
num_groups,
stride_1x1,
inplace_relu,
dilation, ):
# Use skip connection with projection if dim or res change.
if (dim_in != dim_out) or (stride != 1):
fan = (dim_out) * (1 * 1 * 1)
initializer_tmp = get_conv_init(fan)
batchnorm_weight = 1.0
self.branch1 = fluid.dygraph.nn.Conv3D(
num_channels=dim_in,
num_filters=dim_out,
filter_size=1,
stride=[1, stride, stride],
padding=0,
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=False,
dilation=1)
self.branch1_bn = fluid.dygraph.BatchNorm(
num_channels=dim_out,
epsilon=self._eps,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(batchnorm_weight),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(
regularization_coeff=0.0)))
self.branch2 = BottleneckTransform(
dim_in,
dim_out,
temp_kernel_size,
stride,
dim_inner,
num_groups,
stride_1x1=stride_1x1,
inplace_relu=inplace_relu,
dilation=dilation, )
def forward(self, x):
if hasattr(self, "branch1"):
x1 = self.branch1(x)
x1 = self.branch1_bn(x1)
x2 = self.branch2(x)
x = fluid.layers.elementwise_add(x=x1, y=x2)
else:
x2 = self.branch2(x)
x = fluid.layers.elementwise_add(x=x, y=x2)
x = fluid.layers.relu(x)
return x
class ResStage(fluid.dygraph.Layer):
"""
Stage of 3D ResNet. It expects to have one or more tensors as input for
multi-pathway (SlowFast) cases. More details can be found here:
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"Slowfast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
"""
def __init__(
self,
dim_in,
dim_out,
stride,
temp_kernel_sizes,
num_blocks,
dim_inner,
num_groups,
num_block_temp_kernel,
dilation,
stride_1x1=False,
inplace_relu=True, ):
"""
The `__init__` method of any subclass should also contain these arguments.
ResStage builds p streams, where p can be greater or equal to one.
Args:
dim_in (list): list of p the channel dimensions of the input.
Different channel dimensions control the input dimension of
different pathways.
dim_out (list): list of p the channel dimensions of the output.
Different channel dimensions control the input dimension of
different pathways.
temp_kernel_sizes (list): list of the p temporal kernel sizes of the
convolution in the bottleneck. Different temp_kernel_sizes
control different pathway.
stride (list): list of the p strides of the bottleneck. Different
stride control different pathway.
num_blocks (list): list of p numbers of blocks for each of the
pathway.
dim_inner (list): list of the p inner channel dimensions of the
input. Different channel dimensions control the input dimension
of different pathways.
num_groups (list): list of number of p groups for the convolution.
num_groups=1 is for standard ResNet like networks, and
num_groups>1 is for ResNeXt like networks.
num_block_temp_kernel (list): extend the temp_kernel_sizes to
num_block_temp_kernel blocks, then fill a temporal kernel size
of 1 for the rest of the layers.
dilation (list): size of dilation for each pathway.
"""
super(ResStage, self).__init__()
assert all((num_block_temp_kernel[i] <= num_blocks[i]
for i in range(len(temp_kernel_sizes))))
self.num_blocks = num_blocks
self.temp_kernel_sizes = [(temp_kernel_sizes[i] * num_blocks[i]
)[:num_block_temp_kernel[i]] + [1] *
(num_blocks[i] - num_block_temp_kernel[i])
for i in range(len(temp_kernel_sizes))]
assert (len({
len(dim_in),
len(dim_out),
len(temp_kernel_sizes),
len(stride),
len(num_blocks),
len(dim_inner),
len(num_groups),
len(num_block_temp_kernel),
}) == 1)
self.num_pathways = len(self.num_blocks)
self._construct(
dim_in,
dim_out,
stride,
dim_inner,
num_groups,
stride_1x1,
inplace_relu,
dilation, )
def _construct(
self,
dim_in,
dim_out,
stride,
dim_inner,
num_groups,
stride_1x1,
inplace_relu,
dilation, ):
for pathway in range(self.num_pathways):
for i in range(self.num_blocks[pathway]):
res_block = ResBlock(
dim_in[pathway] if i == 0 else dim_out[pathway],
dim_out[pathway],
self.temp_kernel_sizes[pathway][i],
stride[pathway] if i == 0 else 1,
dim_inner[pathway],
num_groups[pathway],
stride_1x1=stride_1x1,
inplace_relu=inplace_relu,
dilation=dilation[pathway], )
self.add_sublayer("pathway{}_res{}".format(pathway, i),
res_block)
def forward(self, inputs):
output = []
for pathway in range(self.num_pathways):
x = inputs[pathway]
for i in range(self.num_blocks[pathway]):
m = getattr(self, "pathway{}_res{}".format(pathway, i))
x = m(x)
output.append(x)
return output
"""ResNe(X)t Head helper."""
class ResNetBasicHead(fluid.dygraph.Layer):
"""
ResNe(X)t 3D head.
This layer performs a fully-connected projection during training, when the
input size is 1x1x1. It performs a convolutional projection during testing
when the input size is larger than 1x1x1. If the inputs are from multiple
different pathways, the inputs will be concatenated after pooling.
"""
def __init__(
self,
dim_in,
num_classes,
pool_size,
dropout_rate=0.0, ):
"""
ResNetBasicHead takes p pathways as input where p in [1, infty].
Args:
dim_in (list): the list of channel dimensions of the p inputs to the
ResNetHead.
num_classes (int): the channel dimensions of the p outputs to the
ResNetHead.
pool_size (list): the list of kernel sizes of p spatial temporal
poolings, temporal pool kernel size, spatial pool kernel size,
spatial pool kernel size in order.
dropout_rate (float): dropout rate. If equal to 0.0, perform no
dropout.
"""
super(ResNetBasicHead, self).__init__()
assert (len({len(pool_size), len(dim_in)}) == 1
), "pathway dimensions are not consistent."
self.num_pathways = len(pool_size)
self.pool_size = pool_size
self.dropout_rate = dropout_rate
fc_init_std = 0.01
initializer_tmp = fluid.initializer.NormalInitializer(
loc=0.0, scale=fc_init_std)
self.projection = fluid.dygraph.Linear(
input_dim=sum(dim_in),
output_dim=num_classes,
param_attr=fluid.ParamAttr(initializer=initializer_tmp),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0)), )
def forward(self, inputs, training):
assert (
len(inputs) == self.num_pathways
), "Input tensor does not contain {} pathway".format(self.num_pathways)
pool_out = []
for pathway in range(self.num_pathways):
tmp_out = fluid.layers.pool3d(
input=inputs[pathway],
pool_type="avg",
pool_size=self.pool_size[pathway],
pool_stride=1,
data_format="NCDHW")
pool_out.append(tmp_out)
x = fluid.layers.concat(input=pool_out, axis=1, name=None)
x = fluid.layers.transpose(x=x, perm=(0, 2, 3, 4, 1))
# Perform dropout.
if self.dropout_rate > 0.0:
x = fluid.layers.dropout(
x,
dropout_prob=self.dropout_rate,
dropout_implementation='upscale_in_train')
x = self.projection(x)
# Perform fully convolutional inference.
if not training:
x = fluid.layers.softmax(x, axis=4)
x = fluid.layers.reduce_mean(x, dim=[1, 2, 3])
x = fluid.layers.reshape(x, shape=(x.shape[0], -1))
return x
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import argparse
import ast
import logging
import numpy as np
import json
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.io import DataLoader, Dataset
from paddle.incubate.hapi.distributed import DistributedBatchSampler, _all_gather
from paddle.fluid.dygraph.parallel import ParallelEnv
from model import *
from config_utils import *
from kinetics_dataset import KineticsDataset
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(
"SLOWFAST inference.")
parser.add_argument(
'--config_file',
type=str,
default='slowfast.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='total eval batch size of all gpus.')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
parser.add_argument(
'--use_data_parallel',
type=ast.literal_eval,
default=True,
help='default use data parallel.')
parser.add_argument(
'--weights',
type=str,
default=None,
help='weight path, None to use config setting.')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
parser.add_argument(
'--save_path',
type=str,
default=None,
help='save path, None to use config setting.')
args = parser.parse_args()
return args
# Prediction
def infer_slowfast(args):
config = parse_config(args.config_file)
infer_config = merge_configs(config, 'infer', vars(args))
print_configs(infer_config, "Infer")
if not os.path.isdir(infer_config.INFER.save_path):
os.makedirs(infer_config.INFER.save_path)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
_nranks = ParallelEnv().nranks # num gpu
bs_single = int(infer_config.INFER.batch_size /
_nranks) # batch_size of each gpu
with fluid.dygraph.guard(place):
#build model
slowfast = SlowFast(cfg=infer_config, num_classes=400)
if args.weights:
assert os.path.exists(args.weights + '.pdparams'),\
"Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model_dict, _ = fluid.load_dygraph(args.weights)
slowfast.set_dict(model_dict)
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
slowfast = fluid.dygraph.parallel.DataParallel(slowfast, strategy)
#create reader
infer_data = KineticsDataset(mode="infer", cfg=infer_config)
infer_sampler = DistributedBatchSampler(
infer_data, batch_size=bs_single, shuffle=False, drop_last=False)
infer_loader = DataLoader(
infer_data,
batch_sampler=infer_sampler,
places=place,
feed_list=None,
num_workers=0,
return_list=True)
# start infer
num_ensemble_views = infer_config.INFER.num_ensemble_views
num_spatial_crops = infer_config.INFER.num_spatial_crops
num_cls = infer_config.MODEL.num_classes
num_clips = num_ensemble_views * num_spatial_crops
num_videos = len(infer_data) // num_clips
video_preds = np.zeros((num_videos, num_cls))
clip_count = {}
video_paths = []
with open(infer_config.INFER.filelist, "r") as f:
for path in f.read().splitlines():
video_paths.append(path)
print(
"[INFER] infer start, number of videos {}, number of clips {}, total number of clips {}".
format(num_videos, num_clips, num_clips * num_videos))
slowfast.eval()
for batch_id, data in enumerate(infer_loader):
# call net
model_inputs = [data[0], data[1]]
preds = slowfast(model_inputs, training=False)
clip_ids = data[3]
# gather results from multiple cards; after this, each card holds the same results.
if _nranks > 1:
preds = _all_gather(preds, _nranks)
clip_ids = _all_gather(clip_ids, _nranks)
# to numpy
preds = preds.numpy()
clip_ids = clip_ids.numpy()
# preds ensemble
for ind in range(preds.shape[0]):
vid_id = int(clip_ids[ind]) // num_clips
ts_idx = int(clip_ids[ind]) % num_clips
if vid_id not in clip_count:
clip_count[vid_id] = []
if ts_idx in clip_count[vid_id]:
print(
"[INFER] Skipping duplicate clip: video {} clip index {} (clip id {}).".
format(vid_id, ts_idx, clip_ids[ind]))
else:
clip_count[vid_id].append(ts_idx)
video_preds[vid_id] += preds[ind] # ensemble method: sum
if batch_id % args.log_interval == 0:
print("[INFER] Processing batch {}/{} ...".format(
batch_id, len(infer_data) // infer_config.INFER.batch_size))
# check clip index of each video
for key in clip_count.keys():
if len(clip_count[key]) != num_clips or sum(clip_count[
key]) != num_clips * (num_clips - 1) / 2:
print(
"[INFER] Warning!! video [{}] clip count [{}] not match number clips {}".
format(key, clip_count[key], num_clips))
res_list = []
for j in range(video_preds.shape[0]):
pred = to_variable(video_preds[j] / num_clips) #mean prob over all clips
video_path = video_paths[j]
top1_values, top1_indices = fluid.layers.topk(pred, k=1)
top5_values, top5_indices = fluid.layers.topk(pred, k=5)
top1_values = top1_values.numpy().astype("float64")[0]
top1_indices = int(top1_indices.numpy()[0])
top5_values = list(top5_values.numpy().astype("float64"))
top5_indices = [int(item) for item in top5_indices.numpy()
] #np.int is not JSON serializable
print("[INFER] video id [{}], top1 value {}, top1 indices {}".
format(video_path, top1_values, top1_indices))
print("[INFER] video id [{}], top5 value {}, top5 indices {}".
format(video_path, top5_values, top5_indices))
save_dict = {
'video_id': video_path,
'top1_values': top1_values,
'top1_indices': top1_indices,
'top5_values': top5_values,
'top5_indices': top5_indices
}
res_list.append(save_dict)
with open(
os.path.join(infer_config.INFER.save_path, 'result' + '.json'),
'w') as f:
json.dump(res_list, f)
print('[INFER] infer finished, results saved in {}'.format(
infer_config.INFER.save_path))
if __name__ == "__main__":
args = parse_args()
logger.info(args)
infer_slowfast(args)
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python3.7 -m paddle.distributed.launch \
eval.py \
--config=slowfast.yaml \
--use_gpu=True \
--use_data_parallel=1 \
--weights=checkpoints/slowfast_epoch195
export CUDA_VISIBLE_DEVICES=0
python3.7 eval.py \
--config=slowfast-single.yaml \
--use_gpu=True \
--use_data_parallel=0 \
--weights=checkpoints/slowfast_epoch195
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python3.7 -m paddle.distributed.launch \
predict.py \
--config=slowfast.yaml \
--use_gpu=True \
--use_data_parallel=1 \
--weights=checkpoints/slowfast_epoch195
export CUDA_VISIBLE_DEVICES=0
python3.7 predict.py \
--config=slowfast-single.yaml \
--use_gpu=True \
--use_data_parallel=0 \
--weights=checkpoints/slowfast_epoch195
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
start_time=$(date +%s)
python3.7 -m paddle.distributed.launch --log_dir=logs \
train.py \
--config=slowfast.yaml \
--use_gpu=True \
--use_data_parallel=1
end_time=$(date +%s)
cost_time=$(( end_time - start_time ))
echo "8 cards bs=64, 196 epochs (34 warmup), 400 classes, preciseBN 200 iters, total time is $(($cost_time/60))min $(($cost_time%60))s"
export CUDA_VISIBLE_DEVICES=0
start_time=$(date +%s)
python3.7 train.py --config=slowfast-single.yaml --use_gpu=True --use_data_parallel=0
end_time=$(date +%s)
cost_time=$(( end_time - start_time ))
echo "1 card bs=8, 196 epochs (34 warmup), 400 classes, preciseBN 200 iters, total time is $(($cost_time/60))min $(($cost_time%60))s"
MODEL:
name: "SLOWFAST"
format: "mp4"
num_classes: 400
num_frames: 32
sampling_rate: 2
target_fps: 30
alpha: 8
beta: 8
crop_size: 224
image_mean: [0.45, 0.45, 0.45]
image_std: [0.225, 0.225, 0.225]
TRAIN:
epoch: 196
target_size: 224
min_size: 256
max_size: 320
batch_size: 8
use_gpu: True
num_gpus: 1
filelist: "./data/train.csv"
base_lr: 0.1
warmup_epochs: 34
warmup_start_lr: 0.01
l2_weight_decay: 1e-4
momentum: 0.9
VALID:
use_preciseBN: True
preciseBN_interval: 10
num_batches_preciseBN: 200
target_size: 224
min_size: 256
max_size: 320
batch_size: 8
filelist: "./data/val.csv"
TEST:
target_size: 256
batch_size: 8
filelist: "./data/val.csv"
num_ensemble_views: 10
num_spatial_crops: 3
INFER:
target_size: 256
batch_size: 8
filelist: "./data/infer.csv"
num_ensemble_views: 10
num_spatial_crops: 3
save_path: "./data/results"
MODEL:
name: "SLOWFAST"
format: "mp4"
num_classes: 400
num_frames: 32
sampling_rate: 2
target_fps: 30
alpha: 8
beta: 8
crop_size: 224
image_mean: [0.45, 0.45, 0.45]
image_std: [0.225, 0.225, 0.225]
TRAIN:
epoch: 196
target_size: 224
min_size: 256
max_size: 320
batch_size: 64
use_gpu: True
num_gpus: 8
filelist: "./data/train.csv"
base_lr: 0.1
warmup_epochs: 34
warmup_start_lr: 0.01
l2_weight_decay: 1e-4
momentum: 0.9
VALID:
use_preciseBN: True
preciseBN_interval: 10
num_batches_preciseBN: 200
target_size: 224
min_size: 256
max_size: 320
batch_size: 64
filelist: "./data/val.csv"
TEST:
target_size: 256
batch_size: 64
filelist: "./data/val.csv"
weights: "checkpoints/slowfast_epoch195"
num_ensemble_views: 10
num_spatial_crops: 3
INFER:
target_size: 256
batch_size: 32
filelist: "./data/infer.csv"
weights: "checkpoints/slowfast_epoch195"
num_ensemble_views: 10
num_spatial_crops: 3
save_path: "./data/results"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import time
import argparse
import ast
import logging
import itertools
import numpy as np
import random
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.io import DataLoader, Dataset
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from model import *
from config_utils import *
from lr_policy import get_epoch_lr
from kinetics_dataset import KineticsDataset
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("SlowFast train")
parser.add_argument(
'--config',
type=str,
default='slowfast.yaml',
help='path to config file of model')
parser.add_argument(
'--use_visualdl',
type=ast.literal_eval,
default=False,
help='whether to use VisualDL.')
parser.add_argument(
'--vd_logdir',
type=str,
default='./vdlog',
help='directory to save VisualDL logs, default ./vdlog.')
parser.add_argument(
'--use_data_parallel',
type=ast.literal_eval,
default=True,
help='whether to use data parallel, default True.')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='whether to use GPU, default True.')
parser.add_argument(
'--epoch',
type=int,
default=None,
help='epoch number, None to read from config file')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='training batch size. None to use config file setting.')
parser.add_argument(
'--save_dir',
type=str,
default='./checkpoints',
help='default model save in ./checkpoints.')
parser.add_argument(
'--resume',
type=ast.literal_eval,
default=False,
help='whether to resume training')
parser.add_argument(
'--resume_epoch',
type=int,
default=100000,
help='epoch to resume training based on latest saved checkpoints. ')
parser.add_argument(
'--valid_interval',
type=int,
default=1,
help='validation epoch interval, 0 for no validation.')
parser.add_argument(
'--log_interval',
type=int,
default=10,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
def val(epoch, model, valid_loader, use_visualdl, vdl_writer=None):
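    """Run one validation pass: compute softmax cross-entropy loss and
    top-1/top-5 accuracy for each batch, print per-batch and per-epoch error
    rates, and optionally log them to VisualDL via the passed-in writer."""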
val_iter_num = len(valid_loader)
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id, data in enumerate(valid_loader):
y_data = data[2]
labels = to_variable(y_data)
labels.stop_gradient = True
model_inputs = [data[0], data[1]]
preds = model(model_inputs, training=False)
loss_out = fluid.layers.softmax_with_cross_entropy(
logits=preds, label=labels)
avg_loss = fluid.layers.mean(loss_out)
acc_top1 = fluid.layers.accuracy(input=preds, label=labels, k=1)
acc_top5 = fluid.layers.accuracy(input=preds, label=labels, k=5)
total_loss += avg_loss.numpy()[0]
total_acc1 += acc_top1.numpy()[0]
total_acc5 += acc_top5.numpy()[0]
total_sample += 1
if use_visualdl:
vdl_writer.add_scalar(
tag="val/loss",
step=epoch * val_iter_num + batch_id,
value=avg_loss.numpy())
vdl_writer.add_scalar(
tag="val/err1",
step=epoch * val_iter_num + batch_id,
value=1.0 - acc_top1.numpy())
vdl_writer.add_scalar(
tag="val/err5",
step=epoch * val_iter_num + batch_id,
value=1.0 - acc_top5.numpy())
print( "[Test Epoch %d, batch %d] loss %.5f, err1 %.5f, err5 %.5f" % \
(epoch, batch_id, avg_loss.numpy(), 1.0 - acc_top1.numpy(), 1. - acc_top5.numpy()))
print( '[TEST Epoch %d end] avg_loss %.5f, avg_err1 %.5f, avg_err5= %.5f' % \
(epoch, total_loss / total_sample, 1. - total_acc1 / total_sample, 1. - total_acc5 / total_sample))
if use_visualdl:
vdl_writer.add_scalar(
tag="val_epoch/loss", step=epoch, value=total_loss / total_sample)
vdl_writer.add_scalar(
tag="val_epoch/err1",
step=epoch,
value=1.0 - total_acc1 / total_sample)
vdl_writer.add_scalar(
tag="val_epoch/err5",
step=epoch,
value=1.0 - total_acc5 / total_sample)
def create_optimizer(cfg, data_size, params):
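    """Build an iteration-level piecewise learning-rate table from
    lr_policy.get_epoch_lr (covering the warmup and decay settings in the
    config) and wrap it in a Nesterov momentum optimizer with L2 decay."""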
l2_weight_decay = cfg.l2_weight_decay
momentum = cfg.momentum
lr_list = []
bd_list = []
cur_bd = 1
for cur_epoch in range(cfg.epoch):
for cur_iter in range(data_size):
cur_lr = get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg)
lr_list.append(cur_lr)
bd_list.append(cur_bd)
cur_bd += 1
bd_list.pop()
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd_list, values=lr_list),
momentum=momentum,
regularization=fluid.regularizer.L2Decay(l2_weight_decay),
use_nesterov=True,
parameter_list=params)
return optimizer
def precise_BN(model, data_loader, num_iters=200):
"""
Recompute and update the batch norm stats to make them more precise. During
training both BN stats and the weight are changing after every iteration, so
the running average can not precisely reflect the actual stats of the
current model.
In this function, the BN stats are recomputed with fixed weights, to make
the running average more precise. Specifically, it computes the true average
of per-batch mean/variance instead of the running average.
This is useful to improve validation accuracy.
:param model: the model whose bn stats will be recomputed
:param data_loader: an iterator. Produce data as input to the model
:param num_iters: number of iterations to compute the stats.
:return: the model with precise mean and variance in bn layers.
"""
bn_layers_list = [
m for m in model.sublayers()
if isinstance(m, paddle.fluid.dygraph.nn.BatchNorm) and not m._is_test
]
if len(bn_layers_list) == 0:
return
    # moving_mean = moving_mean * momentum + batch_mean * (1. - momentum)
    # we set momentum=0. to get the true mean and variance during forward
momentum_actual = [bn._momentum for bn in bn_layers_list]
for bn in bn_layers_list:
bn._momentum = 0.
running_mean = [
fluid.layers.zeros_like(bn._mean) for bn in bn_layers_list
    ]
running_var = [
fluid.layers.zeros_like(bn._variance) for bn in bn_layers_list
]
ind = -1
for ind, data in enumerate(itertools.islice(data_loader, num_iters)):
model_inputs = [data[0], data[1]]
model(model_inputs, training=True)
for i, bn in enumerate(bn_layers_list):
            # Accumulate an online mean: after k batches, running_mean/var
            # hold the average of the first k per-batch statistics.
running_mean[i] += (bn._mean - running_mean[i]) / (ind + 1)
running_var[i] += (bn._variance - running_var[i]) / (ind + 1)
assert ind == num_iters - 1, (
"update_bn_stats is meant to run for {} iterations, "
"but the dataloader stops at {} iterations.".format(num_iters, ind))
# Sets the precise bn stats.
for i, bn in enumerate(bn_layers_list):
bn._mean.set_value(running_mean[i])
bn._variance.set_value(running_var[i])
bn._momentum = momentum_actual[i]
def train(args):
config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args))
print_configs(train_config, 'Train')
    # visualdl for visualizing the training process
local_rank = fluid.dygraph.parallel.Env().local_rank
if args.use_visualdl:
try:
from visualdl import LogWriter
vdl_writer = LogWriter(args.vd_logdir + '/' + str(local_rank))
        except ImportError:
            print(
                "VisualDL is not installed, please install it if you want to use visualization."
            )
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
random.seed(0)
np.random.seed(0)
paddle.framework.manual_seed(0)
with fluid.dygraph.guard(place):
# 1. init net
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
video_model = SlowFast(cfg=train_config, num_classes=400)
if args.use_data_parallel:
video_model = fluid.dygraph.parallel.DataParallel(video_model,
strategy)
bs_denominator = 1
if args.use_gpu:
gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
if gpus == "":
pass
else:
gpus = gpus.split(",")
num_gpus = len(gpus)
                assert num_gpus == train_config.TRAIN.num_gpus, \
                    "num_gpus({}) set by CUDA_VISIBLE_DEVICES " \
                    "should be the same as that " \
                    "set in {}({})".format(
                        num_gpus, args.config, train_config.TRAIN.num_gpus)
bs_denominator = train_config.TRAIN.num_gpus
# 2. reader and optimizer
bs_train_single = int(train_config.TRAIN.batch_size / bs_denominator)
bs_val_single = int(train_config.VALID.batch_size / bs_denominator)
train_data = KineticsDataset(mode="train", cfg=train_config)
valid_data = KineticsDataset(mode="valid", cfg=valid_config)
train_sampler = DistributedBatchSampler(
train_data,
batch_size=bs_train_single,
shuffle=True,
drop_last=True)
train_loader = DataLoader(
train_data,
batch_sampler=train_sampler,
places=place,
feed_list=None,
num_workers=8,
return_list=True)
valid_sampler = DistributedBatchSampler(
valid_data,
batch_size=bs_val_single,
shuffle=False,
drop_last=False)
valid_loader = DataLoader(
valid_data,
batch_sampler=valid_sampler,
places=place,
feed_list=None,
num_workers=8,
return_list=True)
train_iter_num = len(train_loader)
optimizer = create_optimizer(train_config.TRAIN, train_iter_num,
video_model.parameters())
#3. load checkpoint
if args.resume:
            saved_path = "slowfast_epoch"  # default checkpoint name prefix
            model_path = os.path.join(args.save_dir,
                                      saved_path + str(args.resume_epoch))
assert os.path.exists(model_path + ".pdparams"), \
"Given dir {}.pdparams not exist.".format(model_path)
assert os.path.exists(model_path + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(model_path)
para_dict, opti_dict = fluid.dygraph.load_dygraph(model_path)
video_model.set_dict(para_dict)
optimizer.set_dict(opti_dict)
if args.use_visualdl:
# change log path otherwise log history will be overwritten
                vdl_writer = LogWriter(args.vd_logdir + str(args.resume_epoch)
                                       + '/' + str(local_rank))
# 4. train loop
for epoch in range(train_config.TRAIN.epoch):
epoch_start = time.time()
if args.resume and epoch <= args.resume_epoch:
print("epoch:{}<=args.resume_epoch:{}, pass".format(
epoch, args.resume_epoch))
continue
video_model.train()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
            print('start training epoch {}/{}'.format(
                epoch, train_config.TRAIN.epoch))
batch_start = time.time()
for batch_id, data in enumerate(train_loader):
batch_reader_end = time.time()
y_data = data[2]
labels = to_variable(y_data)
labels.stop_gradient = True
model_inputs = [data[0], data[1]]
# 4.1.1 call net()
preds = video_model(model_inputs, training=True)
loss_out = fluid.layers.softmax_with_cross_entropy(
logits=preds, label=labels)
avg_loss = fluid.layers.mean(loss_out)
acc_top1 = fluid.layers.accuracy(input=preds, label=labels, k=1)
acc_top5 = fluid.layers.accuracy(input=preds, label=labels, k=5)
                # 4.1.2 call backward(); with data parallelism, scale the loss
                # by 1/num_trainers, then all-reduce gradients across cards
                if args.use_data_parallel:
                    avg_loss = video_model.scale_loss(avg_loss)
                    avg_loss.backward()
                    video_model.apply_collective_grads()
else:
avg_loss.backward()
# 4.1.3 call minimize()
optimizer.minimize(avg_loss)
video_model.clear_gradients()
total_loss += avg_loss.numpy()[0]
total_acc1 += acc_top1.numpy()[0]
total_acc5 += acc_top5.numpy()[0]
total_sample += 1
if args.use_visualdl:
vdl_writer.add_scalar(
tag="train/loss",
step=epoch * train_iter_num + batch_id,
value=avg_loss.numpy())
vdl_writer.add_scalar(
tag="train/err1",
step=epoch * train_iter_num + batch_id,
value=1.0 - acc_top1.numpy())
vdl_writer.add_scalar(
tag="train/err5",
step=epoch * train_iter_num + batch_id,
value=1.0 - acc_top5.numpy())
train_batch_cost = time.time() - batch_start
train_reader_cost = batch_reader_end - batch_start
batch_start = time.time()
if batch_id % args.log_interval == 0:
print( "[Epoch %d, batch %d] loss %.5f, err1 %.5f, err5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s" % \
(epoch, batch_id, avg_loss.numpy(), 1.0 - acc_top1.numpy(), 1. - acc_top5.numpy(), train_batch_cost, train_reader_cost))
train_epoch_cost = time.time() - epoch_start
print( '[Epoch %d end] avg_loss %.5f, avg_err1 %.5f, avg_err5= %.5f, epoch_cost: %.5f s' % \
(epoch, total_loss / total_sample, 1. - total_acc1 / total_sample, 1. - total_acc5 / total_sample, train_epoch_cost))
if args.use_visualdl:
vdl_writer.add_scalar(
tag="train_epoch/loss",
step=epoch,
value=total_loss / total_sample)
vdl_writer.add_scalar(
tag="train_epoch/err1",
step=epoch,
value=1. - total_acc1 / total_sample)
vdl_writer.add_scalar(
tag="train_epoch/err5",
step=epoch,
value=1. - total_acc5 / total_sample)
# 4.3 do preciseBN
if valid_config.VALID.use_preciseBN and epoch % valid_config.VALID.preciseBN_interval == 0:
print("do precise BN in epoch {} ...".format(epoch))
precise_BN(video_model, train_loader,
min(valid_config.VALID.num_batches_preciseBN,
len(train_loader)))
            # 4.4 save checkpoint
if local_rank == 0:
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
model_path = os.path.join(args.save_dir,
"slowfast_epoch{}".format(epoch))
fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
                print('checkpoint saved, epoch {}/{}'.format(
                    epoch, train_config.TRAIN.epoch))
            # 4.5 validation
            if args.valid_interval > 0 and (epoch + 1) % args.valid_interval == 0:
                video_model.eval()
                val(epoch, video_model, valid_loader, args.use_visualdl,
                    vdl_writer if args.use_visualdl else None)
logger.info('[TRAIN] training finished')
if __name__ == "__main__":
args = parse_args()
logger.info(args)
train(args)