diff --git a/dygraph/slowfast/README.md b/dygraph/slowfast/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3a448d3cb104a2a7456f6a67980a9937d9941061 --- /dev/null +++ b/dygraph/slowfast/README.md @@ -0,0 +1,170 @@
+# SlowFast Video Classification Model (Dynamic Graph Implementation)
+
+---
+## Contents
+
+- [Model Introduction](#model-introduction)
+- [Code Structure](#code-structure)
+- [Installation](#installation)
+- [Data Preparation](#data-preparation)
+- [Model Training](#model-training)
+- [Model Evaluation](#model-evaluation)
+- [Model Prediction](#model-prediction)
+- [Reference Paper](#reference-paper)
+
+
+## Model Introduction
+
+SlowFast is a high-accuracy model for video classification that uses two pathways, slow and fast. The slow pathway takes sparsely sampled frames as input and captures the appearance information of the video. The fast pathway takes densely sampled frames as input and captures the motion information of the video. The features of the two pathways are concatenated to produce the final prediction.
+
+<p align="center">
+<img src="./SLOWFAST.png" alt="SlowFast Overview"/> <br/>
+SlowFast Overview
+</p>
+
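How the two pathways divide the frames is easiest to see in the input packing. The sketch below mirrors `pack_output` in `kinetics_dataset.py` (added later in this diff) under the default config values (`num_frames=32`, `alpha=8`); it is illustrative, not part of the patch:

```python
import numpy as np

# Sketch: how one decoded clip becomes the two pathway inputs,
# mirroring pack_output() in kinetics_dataset.py. alpha=8 as in slowfast.yaml.
alpha = 8
num_frames = 32                     # frames sampled for the fast pathway
frames = np.zeros((num_frames, 224, 224, 3), dtype="float32")  # T H W C

fast_pathway = frames
# The slow pathway keeps num_frames // alpha uniformly spaced frames.
slow_idx = np.linspace(0, num_frames - 1, num_frames // alpha).astype("int64")
slow_pathway = frames[slow_idx]

# T H W C -> C T H W, since the Conv3D layers expect NCDHW input.
fast_pathway = fast_pathway.transpose(3, 0, 1, 2)
slow_pathway = slow_pathway.transpose(3, 0, 1, 2)
print(slow_pathway.shape, fast_pathway.shape)  # (3, 4, 224, 224) (3, 32, 224, 224)
```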
+For details, please refer to the ICCV 2019 paper [SlowFast Networks for Video Recognition](https://arxiv.org/abs/1812.03982)
+
+
+## Code Structure
+```
+├── slowfast.yaml # multi-GPU config file for conveniently setting hyper-parameters
+├── slowfast-single.yaml # single-GPU evaluation/prediction config file for conveniently setting hyper-parameters
+├── run_train_multi.sh # multi-GPU training script
+├── run_train_single.sh # single-GPU training script
+├── run_eval_multi.sh # multi-GPU evaluation script
+├── run_infer_multi.sh # multi-GPU prediction script
+├── run_eval_single.sh # single-GPU evaluation script
+├── run_infer_single.sh # single-GPU prediction script
+├── train.py # training code
+├── eval.py # evaluation code measuring network performance
+├── predict.py # prediction code
+├── model.py # network structure
+├── model_utils.py # network structure details
+├── lr_policy.py # learning rate schedule
+├── kinetics_dataset.py # Kinetics-400 data preprocessing code
+└── config_utils.py # config handling code
+```
+
+## Installation
+
+Please run the sample code with ```python3.7```.
+
+### Environment requirements:
+
+```
+ CUDA >= 9.0
+ cudnn >= 7.5
+```
+
+### Installing dependencies:
+
+- PaddlePaddle >= 2.0.0-alpha0:
+``` pip3.7 install paddlepaddle-gpu==2.0.0a0 -i https://mirror.baidu.com/pypi/simple ```
+- Install opencv 4.2:
+``` pip3.7 install opencv-python==4.2.0.34```
+- To visualize training curves, install VisualDL:
+``` pip3.7 install visualdl -i https://mirror.baidu.com/pypi/simple```
+
+
+## Data Preparation
+
+SlowFast is trained on the Kinetics-400 dataset and consumes mp4 files directly. Prepare the data as follows:
+
+### Downloading the mp4 videos
+
+Create the data folders under Code\_Root:
+
+```
+ cd $Code_Root/ && mkdir data
+ cd data && mkdir data_k400 && cd data_k400
+ mkdir train_mp4 && mkdir val_mp4
+```
+
+ActivityNet provides an official download tool for Kinetics; see its [official repo](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) to download the Kinetics-400 mp4 video collection. Download the training and validation sets to data/data\_k400/train\_mp4 and data/data\_k400/val\_mp4, respectively.
+
+### Generating the training and validation lists
+
+Validation and evaluation use the same dataset; prediction can use a small number of samples picked from the validation set.
+
+```
+ cd $Code_Root/data/
+ ls $Code_Root/data/data_k400/train_mp4/* > train.csv
+ ls $Code_Root/data/data_k400/val_mp4/* > val.csv
+ ls $Code_Root/data/data_k400/val_mp4/* > test.csv
+ ls $Code_Root/data/data_k400/val_mp4/* > infer.csv
+```
+
+## Model Training
+
+Once the data is ready, training can be started in either of two ways:
+
+By default training uses 8 GPUs:
+
+    bash run_train_multi.sh
+
+To train on a single GPU:
+
+    bash run_train_single.sh
+
+- Multi-GPU training is recommended; with a single GPU the batch\_size is smaller and accuracy may suffer.
+
+- Training starts from scratch with the commands or scripts above; no pretrained model is needed.
+
+- VisualDL can be used to visualize the training process; see [VisualDL](https://github.com/PaddlePaddle/VisualDL) for usage.
+
+**Training resource requirements:**
+
+* 8 V100 GPUs, total batch\_size=64, batch\_size=8 per GPU, about 9 GB of GPU memory per GPU.
+* The Kinetics-400 training set is large (about 230,000 samples) and SlowFast trains for many epochs (196), so training takes a long time, about 200 hours.
+* Work on speeding up training is in progress; stay tuned.
+
+
+## Model Evaluation
+
+After training, the model can be evaluated as follows:
+
+Multi-GPU evaluation:
+
+    bash run_eval_multi.sh
+
+Single-GPU evaluation:
+
+    bash run_eval_single.sh
+
+- The `weights` parameter in the script specifies the weight file to evaluate; if unset, the default checkpoints/slowfast_epoch195.pdparams is used.
+
+- Evaluation uses ```multi_crop```, so it takes some time; multi-GPU evaluation is recommended to speed it up. With the default settings, multi-GPU evaluation takes about 4 hours.
+
+- The final evaluation accuracy is printed in the log file.
+
+
+Accuracy on the Kinetics-400 dataset:
+
+| Acc1 | Acc5 |
+| :---: | :---: |
+| 74.35 | 91.33 |
+
+- Because some source files of the Kinetics-400 dataset are missing and can no longer be downloaded, our dataset is ~5% smaller than the official one, so accuracy is slightly below the results published in the paper.
+
+## Model Prediction
+
+After training, the model can be used for prediction as follows:
+
+Multi-GPU prediction:
+
+    bash run_infer_multi.sh
+
+Single-GPU prediction:
+
+    bash run_infer_single.sh
+
+- The `weights` parameter in the script specifies the weight file to use; if unset, the default checkpoints/slowfast_epoch195.pdparams is used.
+
+- Prediction also uses ```multi_crop```, so predicting even a single video takes some time.
+
+- Prediction results are saved in ./data/results/result.json.
+
+
+## Reference Paper
+
+- [SlowFast Networks for Video Recognition](https://arxiv.org/abs/1812.03982)
diff --git a/dygraph/slowfast/SLOWFAST.png b/dygraph/slowfast/SLOWFAST.png new file mode 100644 index 0000000000000000000000000000000000000000..8de85e26aabb1468da97a3503b6cb87317ffdda3 Binary files /dev/null and b/dygraph/slowfast/SLOWFAST.png differ
diff --git a/dygraph/slowfast/config_utils.py b/dygraph/slowfast/config_utils.py new file mode 100644 index
0000000000000000000000000000000000000000..4fedd1b246b27f6e3ddfd8d12dfcec51e7737e5b --- /dev/null +++ b/dygraph/slowfast/config_utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import yaml +import logging +logger = logging.getLogger(__name__) + +CONFIG_SECS = [ + 'train', + 'valid', + 'test', + 'infer', +] + + +class AttrDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self[key] = value + + +def parse_config(cfg_file): + """Load a config file into AttrDict""" + import yaml + with open(cfg_file, 'r') as fopen: + yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader)) + create_attr_dict(yaml_config) + return yaml_config + + +def create_attr_dict(yaml_config): + from ast import literal_eval + for key, value in yaml_config.items(): + if type(value) is dict: + yaml_config[key] = value = AttrDict(value) + if isinstance(value, str): + try: + value = literal_eval(value) + except BaseException: + pass + if isinstance(value, AttrDict): + create_attr_dict(yaml_config[key]) + else: + yaml_config[key] = value + return + + +def merge_configs(cfg, sec, args_dict): + assert sec in CONFIG_SECS, "invalid config section {}".format(sec) + sec_dict = getattr(cfg, sec.upper()) + for k, v in args_dict.items(): + if v is None: + continue + try: + if hasattr(sec_dict, k): + setattr(sec_dict, k, v) + except: + pass + return cfg + + +def print_configs(cfg, mode): + logger.info("---------------- {:>5} Arguments ----------------".format( + mode)) + for sec, sec_items in cfg.items(): + logger.info("{}:".format(sec)) + for k, v in sec_items.items(): + logger.info(" {}:{}".format(k, v)) + logger.info("-------------------------------------------------") diff --git a/dygraph/slowfast/eval.py b/dygraph/slowfast/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..985a8eca07501c08d95ab674ed72cfa84fe1af20 --- /dev/null +++ b/dygraph/slowfast/eval.py @@ -0,0 +1,191 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
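A quick note on `config_utils.py` above: `parse_config` turns the YAML into nested `AttrDict`s with attribute access, and `merge_configs` overrides one section with any non-`None` CLI arguments. A minimal usage sketch, assuming `slowfast.yaml` from this patch is on disk (illustrative, not part of the diff):

```python
import argparse
import logging
from config_utils import parse_config, merge_configs, print_configs

logging.basicConfig(level=logging.INFO)  # so print_configs' logger.info shows up

# Illustrative only: load slowfast.yaml into nested AttrDicts, then override
# the TEST section with command-line arguments that are not None.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=None)
args = parser.parse_args([])            # e.g. parse_args(['--batch_size', '16'])

cfg = parse_config('slowfast.yaml')     # AttrDict: cfg.MODEL, cfg.TRAIN, ...
test_cfg = merge_configs(cfg, 'test', vars(args))
print_configs(test_cfg, 'Test')
print(test_cfg.MODEL.num_frames)        # attribute access via AttrDict
```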
+ +import os +import sys +import argparse +import ast +import logging +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph.base import to_variable +from paddle.io import DataLoader, Dataset +from paddle.incubate.hapi.distributed import DistributedBatchSampler, _all_gather +from paddle.fluid.dygraph.parallel import ParallelEnv + +from model import * +from config_utils import * +from kinetics_dataset import KineticsDataset + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + "SLOWFAST test for performance evaluation.") + parser.add_argument( + '--config_file', + type=str, + default='slowfast.yaml', + help='path to config file of model') + parser.add_argument( + '--batch_size', + type=int, + default=None, + help='total eval batch size of all gpus.') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') + parser.add_argument( + '--use_data_parallel', + type=ast.literal_eval, + default=True, + help='default use data parallel.') + parser.add_argument( + '--weights', + type=str, + default=None, + help='Weight path, None to use config setting.') + parser.add_argument( + '--log_interval', + type=int, + default=1, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +# Performance Evaluation +def test_slowfast(args): + config = parse_config(args.config_file) + test_config = merge_configs(config, 'test', vars(args)) + print_configs(test_config, "Test") + + if not args.use_gpu: + place = fluid.CPUPlace() + elif not args.use_data_parallel: + place = fluid.CUDAPlace(0) + else: + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) + + _nranks = ParallelEnv().nranks # num gpu + bs_single = int(test_config.TEST.batch_size / + _nranks) # batch_size of each gpu + + with fluid.dygraph.guard(place): + #build model + slowfast = SlowFast(cfg=test_config, num_classes=400) + if args.weights: + assert os.path.exists(args.weights + '.pdparams'),\ + "Given weight dir {} not exist.".format(args.weights) + + logger.info('load test weights from {}'.format(args.weights)) + model_dict, _ = fluid.load_dygraph(args.weights) + slowfast.set_dict(model_dict) + + if args.use_data_parallel: + strategy = fluid.dygraph.parallel.prepare_context() + slowfast = fluid.dygraph.parallel.DataParallel(slowfast, strategy) + + #create reader + test_data = KineticsDataset(mode="test", cfg=test_config) + test_sampler = DistributedBatchSampler( + test_data, batch_size=bs_single, shuffle=False, drop_last=False) + test_loader = DataLoader( + test_data, + batch_sampler=test_sampler, + places=place, + feed_list=None, + num_workers=8, + return_list=True) + + # start eval + num_ensemble_views = test_config.TEST.num_ensemble_views + num_spatial_crops = test_config.TEST.num_spatial_crops + num_cls = test_config.MODEL.num_classes + num_clips = num_ensemble_views * num_spatial_crops + num_videos = len(test_data) // num_clips + video_preds = np.zeros((num_videos, num_cls)) + video_labels = np.zeros((num_videos, 1), dtype="int64") + clip_count = {} + + print( + "[EVAL] eval start, number of videos {}, total number of clips {}". 
+ format(num_videos, num_clips * num_videos)) + slowfast.eval() + for batch_id, data in enumerate(test_loader): + # call net + model_inputs = [data[0], data[1]] + preds = slowfast(model_inputs, training=False) + labels = data[2] + clip_ids = data[3] + + # gather multi-card results; the following processing is then identical on each card. + if _nranks > 1: + preds = _all_gather(preds, _nranks) + labels = _all_gather(labels, _nranks) + clip_ids = _all_gather(clip_ids, _nranks) + + # to numpy + preds = preds.numpy() + labels = labels.numpy().astype("int64") + clip_ids = clip_ids.numpy() + + # preds ensemble + for ind in range(preds.shape[0]): + vid_id = int(clip_ids[ind]) // num_clips + ts_idx = int(clip_ids[ind]) % num_clips + if vid_id not in clip_count: + clip_count[vid_id] = [] + if ts_idx in clip_count[vid_id]: + print( + "[EVAL] Skipped!! video {} clip index {} / {} was read repeatedly.". + format(vid_id, ts_idx, clip_ids[ind])) + else: + clip_count[vid_id].append(ts_idx) + video_preds[vid_id] += preds[ind] # ensemble method: sum + if video_labels[vid_id].sum() > 0: + assert video_labels[vid_id] == labels[ind] + video_labels[vid_id] = labels[ind] + if batch_id % args.log_interval == 0: + print("[EVAL] Processing batch {}/{} ...".format( + batch_id, len(test_data) // test_config.TEST.batch_size)) + + # check clip index of each video + for key in clip_count.keys(): + if len(clip_count[key]) != num_clips or sum(clip_count[ + key]) != num_clips * (num_clips - 1) / 2: + print( + "[EVAL] Warning!! video [{}] clip count [{}] does not match the expected number of clips {}". + format(key, clip_count[key], num_clips)) + + video_preds = to_variable(video_preds) + video_labels = to_variable(video_labels) + acc_top1 = fluid.layers.accuracy( + input=video_preds, label=video_labels, k=1) + acc_top5 = fluid.layers.accuracy( + input=video_preds, label=video_labels, k=5) + print('[EVAL] eval finished, avg_acc1= {}, avg_acc5= {} '.format( + acc_top1.numpy(), acc_top5.numpy())) + + +if __name__ == "__main__": + args = parse_args() + logger.info(args) + test_slowfast(args) diff --git a/dygraph/slowfast/kinetics_dataset.py b/dygraph/slowfast/kinetics_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f5689708f7b5fdcfc7334bf7ae91bc8445dc18b1 --- /dev/null +++ b/dygraph/slowfast/kinetics_dataset.py @@ -0,0 +1,315 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
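`eval.py` above and `KineticsDataset` below share one indexing convention: each video contributes `num_ensemble_views * num_spatial_crops` consecutive clip ids, so integer division and modulo recover the video id and the view/crop index. A self-contained sketch of the sum-ensemble bookkeeping, using the 10x3 setting from `slowfast.yaml` (illustrative, not part of the patch):

```python
import numpy as np

# Sketch of the prediction-ensembling scheme used in eval.py / predict.py.
num_ensemble_views, num_spatial_crops, num_cls = 10, 3, 400
num_clips = num_ensemble_views * num_spatial_crops   # 30 clips per video
num_videos = 2

video_preds = np.zeros((num_videos, num_cls))
for clip_id in range(num_videos * num_clips):
    vid_id = clip_id // num_clips        # which video this clip belongs to
    ts_idx = clip_id % num_clips         # which of the 30 views/crops it is
    clip_pred = np.random.rand(num_cls)  # stand-in for one clip's class scores
    video_preds[vid_id] += clip_pred     # ensemble method: sum over clips

# A video has seen all of its clips exactly once iff its ts_idx values are
# 0..num_clips-1, whose sum is num_clips * (num_clips - 1) / 2; this is the
# consistency check performed after the eval loop.
assert sum(range(num_clips)) == num_clips * (num_clips - 1) // 2
```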
+ +import cv2 +import math +import random +import numpy as np +from PIL import Image, ImageEnhance +import logging +from paddle.io import Dataset + +logger = logging.getLogger(__name__) + +__all__ = ['KineticsDataset'] + + +class KineticsDataset(Dataset): + def __init__(self, mode, cfg): + self.mode = mode + self.format = cfg.MODEL.format + self.num_frames = cfg.MODEL.num_frames + self.sampling_rate = cfg.MODEL.sampling_rate + self.target_fps = cfg.MODEL.target_fps + self.slowfast_alpha = cfg.MODEL.alpha + + self.target_size = cfg[mode.upper()]['target_size'] + self.img_mean = cfg.MODEL.image_mean + self.img_std = cfg.MODEL.image_std + self.filelist = cfg[mode.upper()]['filelist'] + + if self.mode in ["train", "valid"]: + self.min_size = cfg[mode.upper()]['min_size'] + self.max_size = cfg[mode.upper()]['max_size'] + self.num_ensemble_views = 1 + self.num_spatial_crops = 1 + self._num_clips = 1 + elif self.mode in ['test', 'infer']: + self.min_size = self.max_size = self.target_size + self.num_ensemble_views = cfg.TEST.num_ensemble_views + self.num_spatial_crops = cfg.TEST.num_spatial_crops + self._num_clips = (self.num_ensemble_views * self.num_spatial_crops) + + self._construct_loader() + + def _construct_loader(self): + """ + Construct the video loader. + """ + self._num_retries = 5 + self._path_to_videos = [] + self._labels = [] + self._spatial_temporal_idx = [] + with open(self.filelist, "r") as f: + for clip_idx, path_label in enumerate(f.read().splitlines()): + if self.mode == 'infer': + path = path_label + label = 0 # the infer list has no labels, use 0 as a placeholder + else: + path, label = path_label.split() + for idx in range(self._num_clips): + self._path_to_videos.append(path) + self._labels.append(int(label)) + self._spatial_temporal_idx.append(idx) + + def __len__(self): + return len(self._path_to_videos) + + def __getitem__(self, idx): + if self.mode in ["train", "valid"]: + temporal_sample_index = -1 + spatial_sample_index = -1 + elif self.mode in ["test", 'infer']: + temporal_sample_index = (self._spatial_temporal_idx[idx] // + self.num_spatial_crops) + spatial_sample_index = (self._spatial_temporal_idx[idx] % + self.num_spatial_crops) + + for ir in range(self._num_retries): + mp4_path = self._path_to_videos[idx] + try: + pathways = self.mp4_loader( + mp4_path, + temporal_sample_index, + spatial_sample_index, + temporal_num_clips=self.num_ensemble_views, + spatial_num_clips=self.num_spatial_crops, + num_frames=self.num_frames, + sampling_rate=self.sampling_rate, + target_fps=self.target_fps, + target_size=self.target_size, + img_mean=self.img_mean, + img_std=self.img_std, + slowfast_alpha=self.slowfast_alpha, + min_size=self.min_size, + max_size=self.max_size) + except: + if ir < self._num_retries - 1: + logger.error( + 'Error when loading {}, attempt {} failed, will try again'. + format(mp4_path, ir)) + idx = random.randint(0, len(self._path_to_videos) - 1) + continue + else: + logger.error( + 'Error when loading {}, attempt {} failed, will not try again'.
+ format(mp4_path, ir)) + return None, None + label = self._labels[idx] + return pathways[0], pathways[1], np.array([label]), np.array([idx]) + + def mp4_loader(self, filepath, temporal_sample_index, spatial_sample_index, + temporal_num_clips, spatial_num_clips, num_frames, + sampling_rate, target_fps, target_size, img_mean, img_std, + slowfast_alpha, min_size, max_size): + frames_sample, clip_size = self.decode_sampling( + filepath, temporal_sample_index, temporal_num_clips, num_frames, + sampling_rate, target_fps) + frames_select = self.temporal_sampling( + frames_sample, clip_size, num_frames, filepath, + temporal_sample_index, temporal_num_clips) + frames_resize = self.scale(frames_select, min_size, max_size) + frames_crop = self.crop(frames_resize, target_size, + spatial_sample_index, spatial_num_clips) + frames_flip = self.flip(frames_crop, spatial_sample_index) + + #list to nparray + npframes = (np.stack(frames_flip)).astype('float32') + npframes_norm = self.color_norm(npframes, img_mean, img_std) + frames_list = self.pack_output(npframes_norm, slowfast_alpha) + + return frames_list + + def get_start_end_idx(self, video_size, clip_size, clip_idx, + temporal_num_clips): + delta = max(video_size - clip_size, 0) + if clip_idx == -1: # when test, temporal_num_clips is not used + # Random temporal sampling. + start_idx = random.uniform(0, delta) + else: + # Uniformly sample the clip with the given index. + start_idx = delta * clip_idx / temporal_num_clips + end_idx = start_idx + clip_size - 1 + return start_idx, end_idx + + def decode_sampling(self, filepath, temporal_sample_index, + temporal_num_clips, num_frames, sampling_rate, + target_fps): + cap = cv2.VideoCapture(filepath) + videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.') + if int(major_ver) < 3: + fps = cap.get(cv2.cv.CV_CAP_PROP_FPS) + else: + fps = cap.get(cv2.CAP_PROP_FPS) + + clip_size = num_frames * sampling_rate * fps / target_fps + + if filepath[-3:] != 'mp4': + start_idx, end_idx = 0, math.inf + else: + start_idx, end_idx = self.get_start_end_idx( + videolen, clip_size, temporal_sample_index, temporal_num_clips) + #print("filepath:",filepath,"start_idx:",start_idx,"end_idx:",end_idx) + + frames_sample = [] #start randomly, decode clip size + start_idx = math.ceil(start_idx) + cap.set(cv2.CAP_PROP_POS_FRAMES, start_idx) + for i in range(videolen): + if i < start_idx: + continue + ret, frame = cap.read() + if ret == False: + continue + if i <= end_idx + 1: #buffer + img = frame[:, :, ::-1] #BGR -> RGB + frames_sample.append(img) + else: + break + return frames_sample, clip_size + + def temporal_sampling(self, frames_sample, clip_size, num_frames, filepath, + temporal_sample_index, temporal_num_clips): + """ sample num_frames from clip_size """ + fs_len = len(frames_sample) + + if filepath[-3:] != 'mp4': + start_idx, end_idx = self.get_start_end_idx( + fs_len, clip_size, temporal_sample_index, temporal_num_clips) + else: + start_idx, end_idx = self.get_start_end_idx(fs_len, clip_size, 0, 1) + + index = np.linspace(start_idx, end_idx, num_frames).astype("int64") + index = np.clip(index, 0, fs_len - 1) + frames_select = [] + for i in range(index.shape[0]): + idx = index[i] + imgbuf = frames_sample[idx] + img = Image.fromarray(imgbuf, mode='RGB') + frames_select.append(img) + + return frames_select + + def scale(self, frames_select, min_size, max_size): + size = int(round(np.random.uniform(min_size, max_size))) + assert (len(frames_select) >= 1) , \ + 
"len(frames_select):{} should be larger than 1".format(len(frames_select)) + width, height = frames_select[0].size + if (width <= height and width == size) or (height <= width and + height == size): + return frames_select + + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + else: + new_width = int(math.floor((float(width) / height) * size)) + + frames_resize = [] + for j in range(len(frames_select)): + img = frames_select[j] + scale_img = img.resize((new_width, new_height), Image.BILINEAR) + frames_resize.append(scale_img) + + return frames_resize + + def crop(self, frames_resize, target_size, spatial_sample_index, + spatial_num_clips): + w, h = frames_resize[0].size + if w == target_size and h == target_size: + return frames_resize + + assert (w >= target_size) and (h >= target_size), \ + "image width({}) and height({}) should be larger than crop size({},{})".format(w, h, target_size, target_size) + frames_crop = [] + if spatial_sample_index == -1: + x_offset = random.randint(0, w - target_size) + y_offset = random.randint(0, h - target_size) + else: + x_gap = int(math.ceil((w - target_size) / (spatial_num_clips - 1))) + y_gap = int(math.ceil((h - target_size) / (spatial_num_clips - 1))) + if h > w: + x_offset = int(math.ceil((w - target_size) / 2)) + if spatial_sample_index == 0: + y_offset = 0 + elif spatial_sample_index == spatial_num_clips - 1: + y_offset = h - target_size + else: + y_offset = y_gap * spatial_sample_index + else: + y_offset = int(math.ceil((h - target_size) / 2)) + if spatial_sample_index == 0: + x_offset = 0 + elif spatial_sample_index == spatial_num_clips - 1: + x_offset = w - target_size + else: + x_offset = x_gap * spatial_sample_index + + for img in frames_resize: + nimg = img.crop((x_offset, y_offset, x_offset + target_size, + y_offset + target_size)) + frames_crop.append(nimg) + return frames_crop + + def flip(self, frames_crop, spatial_sample_index): + # without flip when test + if spatial_sample_index != -1: + return frames_crop + + frames_flip = [] + if np.random.uniform() < 0.5: + for img in frames_crop: + nimg = img.transpose(Image.FLIP_LEFT_RIGHT) + frames_flip.append(nimg) + else: + frames_flip = frames_crop + return frames_flip + + def color_norm(self, npframes_norm, c_mean, c_std): + npframes_norm /= 255.0 + npframes_norm -= np.array(c_mean).reshape( + [1, 1, 1, 3]).astype(np.float32) + npframes_norm /= np.array(c_std).reshape( + [1, 1, 1, 3]).astype(np.float32) + return npframes_norm + + def pack_output(self, npframes_norm, slowfast_alpha): + fast_pathway = npframes_norm + + # sample num points between start and end + slow_idx_start = 0 + slow_idx_end = fast_pathway.shape[0] - 1 + slow_idx_num = fast_pathway.shape[0] // slowfast_alpha + slow_idxs_select = np.linspace(slow_idx_start, slow_idx_end, + slow_idx_num).astype("int64") + slow_pathway = fast_pathway[slow_idxs_select] + + # T H W C -> C T H W. 
+ slow_pathway = slow_pathway.transpose(3, 0, 1, 2) + fast_pathway = fast_pathway.transpose(3, 0, 1, 2) + + # slow + fast + frames_list = [slow_pathway, fast_pathway] + return frames_list diff --git a/dygraph/slowfast/lr_policy.py b/dygraph/slowfast/lr_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..b834d5a1aa4d2b70094dea2be1e19b5b2c701190 --- /dev/null +++ b/dygraph/slowfast/lr_policy.py @@ -0,0 +1,41 @@ +"""Learning rate policy.""" + +import math + + +def get_epoch_lr(cur_epoch, cfg): + """ + Retrieve the learning rate of the current epoch with the option to perform + warm up in the beginning of the training stage. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (float): the number of epoch of the current training stage. + #""" + warmup_epochs = cfg.warmup_epochs #34 + warmup_start_lr = cfg.warmup_start_lr #0.01 + lr = lr_func_cosine(cur_epoch, cfg) + + # Perform warm up. + if cur_epoch < warmup_epochs: + lr_start = warmup_start_lr + lr_end = lr_func_cosine(warmup_epochs, cfg) + alpha = (lr_end - lr_start) / warmup_epochs + lr = cur_epoch * alpha + lr_start + return lr + + +def lr_func_cosine(cur_epoch, cfg): + """ + Retrieve the learning rate to specified values at specified epoch with the + cosine learning rate schedule. Details can be found in: + Ilya Loshchilov, and Frank Hutter + SGDR: Stochastic Gradient Descent With Warm Restarts. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (float): the number of epoch of the current training stage. + """ + base_lr = cfg.base_lr #0.1 + max_epoch = cfg.epoch #196 + return base_lr * (math.cos(math.pi * cur_epoch / max_epoch) + 1.0) * 0.5 diff --git a/dygraph/slowfast/model.py b/dygraph/slowfast/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d1dabab2e2412d18ee506c22e6abb8640c84b611 --- /dev/null +++ b/dygraph/slowfast/model.py @@ -0,0 +1,417 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid.dygraph.base import to_variable + +from model_utils import * + + +class ResNetBasicStem(fluid.dygraph.Layer): + """ + ResNe(X)t 3D stem module. + Performs spatiotemporal Convolution, BN, and Relu following by a + spatiotemporal pooling. 
+ """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + eps=1e-5, ): + super(ResNetBasicStem, self).__init__() + self.kernel = kernel + self.stride = stride + self.padding = padding + self.eps = eps + self._construct_stem(dim_in, dim_out) + + def _construct_stem(self, dim_in, dim_out): + fan = (dim_out) * (self.kernel[0] * self.kernel[1] * self.kernel[2]) + initializer_tmp = get_conv_init(fan) + batchnorm_weight = 1.0 + + self._conv = fluid.dygraph.nn.Conv3D( + num_channels=dim_in, + num_filters=dim_out, + filter_size=self.kernel, + stride=self.stride, + padding=self.padding, + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=False) + self._bn = fluid.dygraph.BatchNorm( + num_channels=dim_out, + epsilon=self.eps, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(batchnorm_weight), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0))) + + def forward(self, x): + x = self._conv(x) + x = self._bn(x) + x = fluid.layers.relu(x) + x = fluid.layers.pool3d( + input=x, + pool_type="max", + pool_size=[1, 3, 3], + pool_stride=[1, 2, 2], + pool_padding=[0, 1, 1], + data_format="NCDHW") + return x + + +class VideoModelStem(fluid.dygraph.Layer): + """ + Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool + on input data tensor for slow and fast pathways. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + eps=1e-5, ): + """ + Args: + dim_in (list): the list of channel dimensions of the inputs. + dim_out (list): the output dimension of the convolution in the stem + layer. + kernel (list): the kernels' size of the convolutions in the stem + layers. Temporal kernel size, height kernel size, width kernel + size in order. + stride (list): the stride sizes of the convolutions in the stem + layer. Temporal kernel stride, height kernel size, width kernel + size in order. + padding (list): the paddings' sizes of the convolutions in the stem + layer. Temporal padding size, height padding size, width padding + size in order. + eps (float): epsilon for batch norm. + """ + super(VideoModelStem, self).__init__() + + assert (len({ + len(dim_in), + len(dim_out), + len(kernel), + len(stride), + len(padding), + }) == 1), "Input pathway dimensions are not consistent." + self.num_pathways = len(dim_in) + self.kernel = kernel + self.stride = stride + self.padding = padding + self.eps = eps + self._construct_stem(dim_in, dim_out) + + def _construct_stem(self, dim_in, dim_out): + for pathway in range(len(dim_in)): + stem = ResNetBasicStem( + dim_in[pathway], + dim_out[pathway], + self.kernel[pathway], + self.stride[pathway], + self.padding[pathway], + self.eps, ) + self.add_sublayer("pathway{}_stem".format(pathway), stem) + + def forward(self, x): + assert ( + len(x) == self.num_pathways + ), "Input tensor does not contain {} pathway".format(self.num_pathways) + + for pathway in range(len(x)): + m = getattr(self, "pathway{}_stem".format(pathway)) + x[pathway] = m(to_variable(x[pathway])) + + return x + + +class FuseFastToSlow(fluid.dygraph.Layer): + """ + Fuses the information from the Fast pathway to the Slow pathway. Given the + tensors from Slow pathway and Fast pathway, fuse information from Fast to + Slow, then return the fused tensors from Slow and Fast pathway in order. 
+ """ + + def __init__( + self, + dim_in, + fusion_conv_channel_ratio, + fusion_kernel, + alpha, + eps=1e-5, ): + """ + Args: + dim_in (int): the channel dimension of the input. + fusion_conv_channel_ratio (int): channel ratio for the convolution + used to fuse from Fast pathway to Slow pathway. + fusion_kernel (int): kernel size of the convolution used to fuse + from Fast pathway to Slow pathway. + alpha (int): the frame rate ratio between the Fast and Slow pathway. + eps (float): epsilon for batch norm. + """ + super(FuseFastToSlow, self).__init__() + fan = (dim_in * fusion_conv_channel_ratio) * (fusion_kernel * 1 * 1) + initializer_tmp = get_conv_init(fan) + batchnorm_weight = 1.0 + + self._conv_f2s = fluid.dygraph.nn.Conv3D( + num_channels=dim_in, + num_filters=dim_in * fusion_conv_channel_ratio, + filter_size=[fusion_kernel, 1, 1], + stride=[alpha, 1, 1], + padding=[fusion_kernel // 2, 0, 0], + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=False) + self._bn = fluid.dygraph.BatchNorm( + num_channels=dim_in * fusion_conv_channel_ratio, + epsilon=eps, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(batchnorm_weight), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0))) + + def forward(self, x): + x_s = x[0] + x_f = x[1] + fuse = self._conv_f2s(x_f) + fuse = self._bn(fuse) + fuse = fluid.layers.relu(fuse) + x_s_fuse = fluid.layers.concat(input=[x_s, fuse], axis=1, name=None) + + return [x_s_fuse, x_f] + + +class SlowFast(fluid.dygraph.Layer): + """ + SlowFast model builder for SlowFast network. + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "Slowfast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + """ + + def __init__(self, cfg, num_classes): + """ + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(SlowFast, self).__init__() + self.num_classes = num_classes + self.num_frames = cfg.MODEL.num_frames #32 + self.alpha = cfg.MODEL.alpha #8 + self.beta = cfg.MODEL.beta #8 + self.crop_size = cfg.MODEL.crop_size #224 + self.num_pathways = 2 + self.res_depth = 50 + self.num_groups = 1 + self.input_channel_num = [3, 3] + self.width_per_group = 64 + self.fusion_conv_channel_ratio = 2 + self.fusion_kernel_sz = 5 + self.dropout_rate = 0.5 + self._construct_network(cfg) + + def _construct_network(self, cfg): + """ + Builds a SlowFast model. + The first pathway is the Slow pathway + and the second pathway is the Fast pathway. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + temp_kernel = [ + [[1], [5]], # conv1 temporal kernel for slow and fast pathway. + [[1], [3]], # res2 temporal kernel for slow and fast pathway. + [[1], [3]], # res3 temporal kernel for slow and fast pathway. + [[3], [3]], # res4 temporal kernel for slow and fast pathway. + [[3], [3]], + ] # res5 temporal kernel for slow and fast pathway. 
+ + self.s1 = VideoModelStem( + dim_in=self.input_channel_num, + dim_out=[self.width_per_group, self.width_per_group // self.beta], + kernel=[temp_kernel[0][0] + [7, 7], temp_kernel[0][1] + [7, 7]], + stride=[[1, 2, 2]] * 2, + padding=[ + [temp_kernel[0][0][0] // 2, 3, 3], + [temp_kernel[0][1][0] // 2, 3, 3], + ], ) + self.s1_fuse = FuseFastToSlow( + dim_in=self.width_per_group // self.beta, + fusion_conv_channel_ratio=self.fusion_conv_channel_ratio, + fusion_kernel=self.fusion_kernel_sz, + alpha=self.alpha, ) + + # ResNet backbone + MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3)} + (d2, d3, d4, d5) = MODEL_STAGE_DEPTH[self.res_depth] + + num_block_temp_kernel = [[3, 3], [4, 4], [6, 6], [3, 3]] + spatial_dilations = [[1, 1], [1, 1], [1, 1], [1, 1]] + spatial_strides = [[1, 1], [2, 2], [2, 2], [2, 2]] + + out_dim_ratio = self.beta // self.fusion_conv_channel_ratio #4 + dim_inner = self.width_per_group * self.num_groups #64 + + self.s2 = ResStage( + dim_in=[ + self.width_per_group + self.width_per_group // out_dim_ratio, + self.width_per_group // self.beta, + ], + dim_out=[ + self.width_per_group * 4, + self.width_per_group * 4 // self.beta, + ], + dim_inner=[dim_inner, dim_inner // self.beta], + temp_kernel_sizes=temp_kernel[1], + stride=spatial_strides[0], + num_blocks=[d2] * 2, + num_groups=[self.num_groups] * 2, + num_block_temp_kernel=num_block_temp_kernel[0], + dilation=spatial_dilations[0], ) + + self.s2_fuse = FuseFastToSlow( + dim_in=self.width_per_group * 4 // self.beta, + fusion_conv_channel_ratio=self.fusion_conv_channel_ratio, + fusion_kernel=self.fusion_kernel_sz, + alpha=self.alpha, ) + + self.s3 = ResStage( + dim_in=[ + self.width_per_group * 4 + self.width_per_group * 4 // + out_dim_ratio, + self.width_per_group * 4 // self.beta, + ], + dim_out=[ + self.width_per_group * 8, + self.width_per_group * 8 // self.beta, + ], + dim_inner=[dim_inner * 2, dim_inner * 2 // self.beta], + temp_kernel_sizes=temp_kernel[2], + stride=spatial_strides[1], + num_blocks=[d3] * 2, + num_groups=[self.num_groups] * 2, + num_block_temp_kernel=num_block_temp_kernel[1], + dilation=spatial_dilations[1], ) + + self.s3_fuse = FuseFastToSlow( + dim_in=self.width_per_group * 8 // self.beta, + fusion_conv_channel_ratio=self.fusion_conv_channel_ratio, + fusion_kernel=self.fusion_kernel_sz, + alpha=self.alpha, ) + + self.s4 = ResStage( + dim_in=[ + self.width_per_group * 8 + self.width_per_group * 8 // + out_dim_ratio, + self.width_per_group * 8 // self.beta, + ], + dim_out=[ + self.width_per_group * 16, + self.width_per_group * 16 // self.beta, + ], + dim_inner=[dim_inner * 4, dim_inner * 4 // self.beta], + temp_kernel_sizes=temp_kernel[3], + stride=spatial_strides[2], + num_blocks=[d4] * 2, + num_groups=[self.num_groups] * 2, + num_block_temp_kernel=num_block_temp_kernel[2], + dilation=spatial_dilations[2], ) + + self.s4_fuse = FuseFastToSlow( + dim_in=self.width_per_group * 16 // self.beta, + fusion_conv_channel_ratio=self.fusion_conv_channel_ratio, + fusion_kernel=self.fusion_kernel_sz, + alpha=self.alpha, ) + + self.s5 = ResStage( + dim_in=[ + self.width_per_group * 16 + self.width_per_group * 16 // + out_dim_ratio, + self.width_per_group * 16 // self.beta, + ], + dim_out=[ + self.width_per_group * 32, + self.width_per_group * 32 // self.beta, + ], + dim_inner=[dim_inner * 8, dim_inner * 8 // self.beta], + temp_kernel_sizes=temp_kernel[4], + stride=spatial_strides[3], + num_blocks=[d5] * 2, + num_groups=[self.num_groups] * 2, + num_block_temp_kernel=num_block_temp_kernel[3], + dilation=spatial_dilations[3], ) + 
+ self.pool_size = [[1, 1, 1], [1, 1, 1]] + self.head = ResNetBasicHead( + dim_in=[ + self.width_per_group * 32, + self.width_per_group * 32 // self.beta, + ], + num_classes=self.num_classes, + pool_size=[ + [ + self.num_frames // self.alpha // self.pool_size[0][0], + self.crop_size // 32 // self.pool_size[0][1], + self.crop_size // 32 // self.pool_size[0][2], + ], + [ + self.num_frames // self.pool_size[1][0], + self.crop_size // 32 // self.pool_size[1][1], + self.crop_size // 32 // self.pool_size[1][2], + ], + ], + dropout_rate=self.dropout_rate, ) + + def forward(self, x, training): + x = self.s1(x) #VideoModelStem + x = self.s1_fuse(x) #FuseFastToSlow + x = self.s2(x) #ResStage + x = self.s2_fuse(x) + + for pathway in range(self.num_pathways): + x[pathway] = fluid.layers.pool3d( + input=x[pathway], + pool_type="max", + pool_size=self.pool_size[pathway], + pool_stride=self.pool_size[pathway], + pool_padding=[0, 0, 0], + data_format="NCDHW") + + x = self.s3(x) + x = self.s3_fuse(x) + x = self.s4(x) + x = self.s4_fuse(x) + x = self.s5(x) + x = self.head(x, training) #ResNetBasicHead + return x diff --git a/dygraph/slowfast/model_utils.py b/dygraph/slowfast/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..777a4163bc69ca32de42d9ab7ad87dba856bfb6d --- /dev/null +++ b/dygraph/slowfast/model_utils.py @@ -0,0 +1,484 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from paddle.fluid.initializer import MSRA +import paddle.fluid as fluid + + +# get init parameters for conv layer +def get_conv_init(fan_out): + return MSRA(uniform=False, fan_in=fan_out) + + +"""Video models.""" + + +class BottleneckTransform(fluid.dygraph.Layer): + """ + Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of + temporal kernel. + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + dilation=1, ): + """ + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the middle + convolution in the bottleneck. + stride (int): the stride of the bottleneck. + dim_inner (int): the inner dimension of the block. + num_groups (int): number of groups for the convolution. num_groups=1 + is for standard ResNet like networks, and num_groups>1 is for + ResNeXt like networks. + stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise + apply stride to the 3x3 conv. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + eps (float): epsilon for batch norm. + dilation (int): size of dilation. 
+ """ + super(BottleneckTransform, self).__init__() + self.temp_kernel_size = temp_kernel_size + self._inplace_relu = inplace_relu + self._eps = eps + self._stride_1x1 = stride_1x1 + self._construct(dim_in, dim_out, stride, dim_inner, num_groups, + dilation) + + def _construct(self, dim_in, dim_out, stride, dim_inner, num_groups, + dilation): + str1x1, str3x3 = (stride, 1) if self._stride_1x1 else (1, stride) + + fan = (dim_inner) * (self.temp_kernel_size * 1 * 1) + initializer_tmp = get_conv_init(fan) + batchnorm_weight = 1.0 + + self.a = fluid.dygraph.nn.Conv3D( + num_channels=dim_in, + num_filters=dim_inner, + filter_size=[self.temp_kernel_size, 1, 1], + stride=[1, str1x1, str1x1], + padding=[int(self.temp_kernel_size // 2), 0, 0], + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=False) + self.a_bn = fluid.dygraph.BatchNorm( + num_channels=dim_inner, + epsilon=self._eps, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(batchnorm_weight), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0))) + + # 1x3x3, BN, ReLU. + fan = (dim_inner) * (1 * 3 * 3) + initializer_tmp = get_conv_init(fan) + batchnorm_weight = 1.0 + + self.b = fluid.dygraph.nn.Conv3D( + num_channels=dim_inner, + num_filters=dim_inner, + filter_size=[1, 3, 3], + stride=[1, str3x3, str3x3], + padding=[0, dilation, dilation], + groups=num_groups, + dilation=[1, dilation, dilation], + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=False) + self.b_bn = fluid.dygraph.BatchNorm( + num_channels=dim_inner, + epsilon=self._eps, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(batchnorm_weight), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0))) + + # 1x1x1, BN. + fan = (dim_out) * (1 * 1 * 1) + initializer_tmp = get_conv_init(fan) + batchnorm_weight = 0.0 + + self.c = fluid.dygraph.nn.Conv3D( + num_channels=dim_inner, + num_filters=dim_out, + filter_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=False) + self.c_bn = fluid.dygraph.BatchNorm( + num_channels=dim_out, + epsilon=self._eps, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(batchnorm_weight), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0))) + + def forward(self, x): + # Branch2a. + x = self.a(x) + x = self.a_bn(x) + x = fluid.layers.relu(x) + + # Branch2b. + x = self.b(x) + x = self.b_bn(x) + x = fluid.layers.relu(x) + + # Branch2c + x = self.c(x) + x = self.c_bn(x) + return x + + +class ResBlock(fluid.dygraph.Layer): + """ + Residual block. + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups=1, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + dilation=1, ): + """ + ResBlock class constructs redisual blocks. More details can be found in: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. + "Deep residual learning for image recognition." + https://arxiv.org/abs/1512.03385 + Args: + dim_in (int): the channel dimensions of the input. 
+ dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the middle + convolution in the bottleneck. + stride (int): the stride of the bottleneck. + trans_func (string): transform function to be used to construct the + bottleneck. + dim_inner (int): the inner dimension of the block. + num_groups (int): number of groups for the convolution. num_groups=1 + is for standard ResNet like networks, and num_groups>1 is for + ResNeXt like networks. + stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise + apply stride to the 3x3 conv. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + dilation (int): size of dilation. + """ + super(ResBlock, self).__init__() + self._inplace_relu = inplace_relu + self._eps = eps + self._construct( + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1, + inplace_relu, + dilation, ) + + def _construct( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1, + inplace_relu, + dilation, ): + # Use skip connection with projection if dim or res change. + if (dim_in != dim_out) or (stride != 1): + fan = (dim_out) * (1 * 1 * 1) + initializer_tmp = get_conv_init(fan) + batchnorm_weight = 1.0 + self.branch1 = fluid.dygraph.nn.Conv3D( + num_channels=dim_in, + num_filters=dim_out, + filter_size=1, + stride=[1, stride, stride], + padding=0, + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=False, + dilation=1) + self.branch1_bn = fluid.dygraph.BatchNorm( + num_channels=dim_out, + epsilon=self._eps, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(batchnorm_weight), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=fluid.regularizer.L2Decay( + regularization_coeff=0.0))) + + self.branch2 = BottleneckTransform( + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1=stride_1x1, + inplace_relu=inplace_relu, + dilation=dilation, ) + + def forward(self, x): + if hasattr(self, "branch1"): + x1 = self.branch1(x) + x1 = self.branch1_bn(x1) + x2 = self.branch2(x) + x = fluid.layers.elementwise_add(x=x1, y=x2) + else: + x2 = self.branch2(x) + x = fluid.layers.elementwise_add(x=x, y=x2) + + x = fluid.layers.relu(x) + return x + + +class ResStage(fluid.dygraph.Layer): + """ + Stage of 3D ResNet. It expects to have one or more tensors as input for + multi-pathway (SlowFast) cases. More details can be found here: + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "Slowfast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + """ + + def __init__( + self, + dim_in, + dim_out, + stride, + temp_kernel_sizes, + num_blocks, + dim_inner, + num_groups, + num_block_temp_kernel, + dilation, + stride_1x1=False, + inplace_relu=True, ): + """ + The `__init__` method of any subclass should also contain these arguments. + ResStage builds p streams, where p can be greater or equal to one. + Args: + dim_in (list): list of p the channel dimensions of the input. + Different channel dimensions control the input dimension of + different pathways. + dim_out (list): list of p the channel dimensions of the output. + Different channel dimensions control the input dimension of + different pathways. 
+ temp_kernel_sizes (list): list of the p temporal kernel sizes of the + convolution in the bottleneck. Different temp_kernel_sizes + control different pathways. + stride (list): list of the p strides of the bottleneck. Different + strides control different pathways. + num_blocks (list): list of p numbers of blocks for each of the + pathway. + dim_inner (list): list of the p inner channel dimensions of the + input. Different channel dimensions control the input dimension + of different pathways. + num_groups (list): list of number of p groups for the convolution. + num_groups=1 is for standard ResNet like networks, and + num_groups>1 is for ResNeXt like networks. + num_block_temp_kernel (list): extend the temp_kernel_sizes to + num_block_temp_kernel blocks, then fill temporal kernel size + of 1 for the rest of the layers. + dilation (list): size of dilation for each pathway. + """ + super(ResStage, self).__init__() + assert all((num_block_temp_kernel[i] <= num_blocks[i] + for i in range(len(temp_kernel_sizes)))) + self.num_blocks = num_blocks + self.temp_kernel_sizes = [(temp_kernel_sizes[i] * num_blocks[i] + )[:num_block_temp_kernel[i]] + [1] * + (num_blocks[i] - num_block_temp_kernel[i]) + for i in range(len(temp_kernel_sizes))] + assert (len({ + len(dim_in), + len(dim_out), + len(temp_kernel_sizes), + len(stride), + len(num_blocks), + len(dim_inner), + len(num_groups), + len(num_block_temp_kernel), + }) == 1) + self.num_pathways = len(self.num_blocks) + self._construct( + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + stride_1x1, + inplace_relu, + dilation, ) + + def _construct( + self, + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + stride_1x1, + inplace_relu, + dilation, ): + + for pathway in range(self.num_pathways): + for i in range(self.num_blocks[pathway]): + res_block = ResBlock( + dim_in[pathway] if i == 0 else dim_out[pathway], + dim_out[pathway], + self.temp_kernel_sizes[pathway][i], + stride[pathway] if i == 0 else 1, + dim_inner[pathway], + num_groups[pathway], + stride_1x1=stride_1x1, + inplace_relu=inplace_relu, + dilation=dilation[pathway], ) + self.add_sublayer("pathway{}_res{}".format(pathway, i), + res_block) + + def forward(self, inputs): + output = [] + for pathway in range(self.num_pathways): + x = inputs[pathway] + + for i in range(self.num_blocks[pathway]): + m = getattr(self, "pathway{}_res{}".format(pathway, i)) + x = m(x) + output.append(x) + + return output + + +"""ResNe(X)t Head helper.""" + + +class ResNetBasicHead(fluid.dygraph.Layer): + """ + ResNe(X)t 3D head. + This layer performs a fully-connected projection during training, when the + input size is 1x1x1. It performs a convolutional projection during testing + when the input size is larger than 1x1x1. If the inputs are from multiple + different pathways, the inputs will be concatenated after pooling. + """ + + def __init__( + self, + dim_in, + num_classes, + pool_size, + dropout_rate=0.0, ): + """ + ResNetBasicHead takes p pathways as input where p in [1, infty]. + + Args: + dim_in (list): the list of channel dimensions of the p inputs to the + ResNetHead. + num_classes (int): the channel dimensions of the p outputs to the + ResNetHead. + pool_size (list): the list of kernel sizes of p spatial temporal + poolings, temporal pool kernel size, spatial pool kernel size, + spatial pool kernel size in order. + dropout_rate (float): dropout rate. If equal to 0.0, perform no + dropout.
+ """ + super(ResNetBasicHead, self).__init__() + assert (len({len(pool_size), len(dim_in)}) == 1 + ), "pathway dimensions are not consistent." + self.num_pathways = len(pool_size) + self.pool_size = pool_size + self.dropout_rate = dropout_rate + fc_init_std = 0.01 + initializer_tmp = fluid.initializer.NormalInitializer( + loc=0.0, scale=fc_init_std) + self.projection = fluid.dygraph.Linear( + input_dim=sum(dim_in), + output_dim=num_classes, + param_attr=fluid.ParamAttr(initializer=initializer_tmp), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0)), ) + + def forward(self, inputs, training): + assert ( + len(inputs) == self.num_pathways + ), "Input tensor does not contain {} pathway".format(self.num_pathways) + pool_out = [] + for pathway in range(self.num_pathways): + tmp_out = fluid.layers.pool3d( + input=inputs[pathway], + pool_type="avg", + pool_size=self.pool_size[pathway], + pool_stride=1, + data_format="NCDHW") + pool_out.append(tmp_out) + + x = fluid.layers.concat(input=pool_out, axis=1, name=None) + x = fluid.layers.transpose(x=x, perm=(0, 2, 3, 4, 1)) + + # Perform dropout. + if self.dropout_rate > 0.0: + x = fluid.layers.dropout( + x, + dropout_prob=self.dropout_rate, + dropout_implementation='upscale_in_train') + + x = self.projection(x) + + # Performs fully convlutional inference. + if not training: + x = fluid.layers.softmax(x, axis=4) + x = fluid.layers.reduce_mean(x, dim=[1, 2, 3]) + + x = fluid.layers.reshape(x, shape=(x.shape[0], -1)) + return x diff --git a/dygraph/slowfast/predict.py b/dygraph/slowfast/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..04ad7f0cee781c944beab13d172b66350e587ae7 --- /dev/null +++ b/dygraph/slowfast/predict.py @@ -0,0 +1,221 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +import os +import sys +import argparse +import ast +import logging +import numpy as np +import json +import paddle.fluid as fluid +from paddle.fluid.dygraph.base import to_variable +from paddle.io import DataLoader, Dataset +from paddle.incubate.hapi.distributed import DistributedBatchSampler, _all_gather +from paddle.fluid.dygraph.parallel import ParallelEnv + +from model import * +from config_utils import * +from kinetics_dataset import KineticsDataset + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + "SLOWFAST test for performance evaluation.") + parser.add_argument( + '--config_file', + type=str, + default='slowfast.yaml', + help='path to config file of model') + parser.add_argument( + '--batch_size', + type=int, + default=None, + help='total eval batch size of all gpus.') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') + parser.add_argument( + '--use_data_parallel', + type=ast.literal_eval, + default=True, + help='default use data parallel.') + parser.add_argument( + '--weights', + type=str, + default=None, + help='weight path, None to use config setting.') + parser.add_argument( + '--log_interval', + type=int, + default=1, + help='mini-batch interval to log.') + parser.add_argument( + '--save_path', + type=str, + default=None, + help='save path, None to use config setting.') + args = parser.parse_args() + return args + + +# Prediction +def infer_slowfast(args): + config = parse_config(args.config_file) + infer_config = merge_configs(config, 'infer', vars(args)) + print_configs(infer_config, "Infer") + + if not os.path.isdir(infer_config.INFER.save_path): + os.makedirs(infer_config.INFER.save_path) + + if not args.use_gpu: + place = fluid.CPUPlace() + elif not args.use_data_parallel: + place = fluid.CUDAPlace(0) + else: + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) + + _nranks = ParallelEnv().nranks # num gpu + bs_single = int(infer_config.INFER.batch_size / + _nranks) # batch_size of each gpu + + with fluid.dygraph.guard(place): + #build model + slowfast = SlowFast(cfg=infer_config, num_classes=400) + if args.weights: + assert os.path.exists(args.weights + '.pdparams'),\ + "Given weight dir {} not exist.".format(args.weights) + + logger.info('load test weights from {}'.format(args.weights)) + model_dict, _ = fluid.load_dygraph(args.weights) + slowfast.set_dict(model_dict) + + if args.use_data_parallel: + strategy = fluid.dygraph.parallel.prepare_context() + slowfast = fluid.dygraph.parallel.DataParallel(slowfast, strategy) + + #create reader + infer_data = KineticsDataset(mode="infer", cfg=infer_config) + infer_sampler = DistributedBatchSampler( + infer_data, batch_size=bs_single, shuffle=False, drop_last=False) + infer_loader = DataLoader( + infer_data, + batch_sampler=infer_sampler, + places=place, + feed_list=None, + num_workers=0, + return_list=True) + + # start infer + num_ensemble_views = infer_config.INFER.num_ensemble_views + num_spatial_crops = infer_config.INFER.num_spatial_crops + num_cls = infer_config.MODEL.num_classes + num_clips = num_ensemble_views * num_spatial_crops + num_videos = len(infer_data) // num_clips + video_preds = np.zeros((num_videos, num_cls)) + clip_count = {} + + video_paths = [] + with open(infer_config.INFER.filelist, "r") as f: + for path in 
f.read().splitlines(): + video_paths.append(path) + + print( + "[INFER] infer start, number of videos {}, number of clips {}, total number of clips {}". + format(num_videos, num_clips, num_clips * num_videos)) + slowfast.eval() + for batch_id, data in enumerate(infer_loader): + # call net + model_inputs = [data[0], data[1]] + preds = slowfast(model_inputs, training=False) + clip_ids = data[3] + + # gather multi-card results; the following processing is then identical on each card. + if _nranks > 1: + preds = _all_gather(preds, _nranks) + clip_ids = _all_gather(clip_ids, _nranks) + + # to numpy + preds = preds.numpy() + clip_ids = clip_ids.numpy() + + # preds ensemble + for ind in range(preds.shape[0]): + vid_id = int(clip_ids[ind]) // num_clips + ts_idx = int(clip_ids[ind]) % num_clips + if vid_id not in clip_count: + clip_count[vid_id] = [] + if ts_idx in clip_count[vid_id]: + print( + "[INFER] Skipped!! video {} clip index {} / {} was read repeatedly.". + format(vid_id, ts_idx, clip_ids[ind])) + else: + clip_count[vid_id].append(ts_idx) + video_preds[vid_id] += preds[ind] # ensemble method: sum + if batch_id % args.log_interval == 0: + print("[INFER] Processing batch {}/{} ...".format( + batch_id, len(infer_data) // infer_config.INFER.batch_size)) + + # check clip index of each video + for key in clip_count.keys(): + if len(clip_count[key]) != num_clips or sum(clip_count[ + key]) != num_clips * (num_clips - 1) / 2: + print( + "[INFER] Warning!! video [{}] clip count [{}] does not match the expected number of clips {}". + format(key, clip_count[key], num_clips)) + + res_list = [] + for j in range(video_preds.shape[0]): + pred = to_variable(video_preds[j] / num_clips) #mean prob + video_path = video_paths[j] + top1_values, top1_indices = fluid.layers.topk(pred, k=1) + top5_values, top5_indices = fluid.layers.topk(pred, k=5) + top1_values = top1_values.numpy().astype("float64")[0] + top1_indices = int(top1_indices.numpy()[0]) + top5_values = list(top5_values.numpy().astype("float64")) + top5_indices = [int(item) for item in top5_indices.numpy() + ] #np.int is not JSON serializable + print("[INFER] video id [{}], top1 value {}, top1 indices {}". + format(video_path, top1_values, top1_indices)) + print("[INFER] video id [{}], top5 value {}, top5 indices {}".
+ format(video_path, top5_values, top5_indices)) + save_dict = { + 'video_id': video_path, + 'top1_values': top1_values, + 'top1_indices': top1_indices, + 'top5_values': top5_values, + 'top5_indices': top5_indices + } + res_list.append(save_dict) + + with open( + os.path.join(infer_config.INFER.save_path, 'result' + '.json'), + 'w') as f: + json.dump(res_list, f) + print('[INFER] infer finished, results saved in {}'.format( + infer_config.INFER.save_path)) + + +if __name__ == "__main__": + args = parse_args() + logger.info(args) + infer_slowfast(args) diff --git a/dygraph/slowfast/run_eval_multi.sh b/dygraph/slowfast/run_eval_multi.sh new file mode 100644 index 0000000000000000000000000000000000000000..b53e83c06347ad5c5f11ebdd31d9da54cbfb8f04 --- /dev/null +++ b/dygraph/slowfast/run_eval_multi.sh @@ -0,0 +1,7 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python3.7 -m paddle.distributed.launch \ + eval.py \ + --config=slowfast.yaml \ + --use_gpu=True \ + --use_data_parallel=1 \ + --weights=checkpoints/slowfast_epoch195 diff --git a/dygraph/slowfast/run_eval_single.sh b/dygraph/slowfast/run_eval_single.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5d5e900fd9134587c37992fe295dc125ea65c3b --- /dev/null +++ b/dygraph/slowfast/run_eval_single.sh @@ -0,0 +1,6 @@ +export CUDA_VISIBLE_DEVICES=0 +python3.7 eval.py \ + --config=slowfast-single.yaml \ + --use_gpu=True \ + --use_data_parallel=0 \ + --weights=checkpoints/slowfast_epoch195 diff --git a/dygraph/slowfast/run_infer_multi.sh b/dygraph/slowfast/run_infer_multi.sh new file mode 100644 index 0000000000000000000000000000000000000000..422e1cb58d6e93b742e24460d89955e6a291af19 --- /dev/null +++ b/dygraph/slowfast/run_infer_multi.sh @@ -0,0 +1,7 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python3.7 -m paddle.distributed.launch \ + predict.py \ + --config=slowfast.yaml \ + --use_gpu=True \ + --use_data_parallel=1 \ + --weights=checkpoints/slowfast_epoch195 diff --git a/dygraph/slowfast/run_infer_single.sh b/dygraph/slowfast/run_infer_single.sh new file mode 100644 index 0000000000000000000000000000000000000000..419bf0fe7070c04627b47cc0241e883126a09b0f --- /dev/null +++ b/dygraph/slowfast/run_infer_single.sh @@ -0,0 +1,6 @@ +export CUDA_VISIBLE_DEVICES=0 +python3.7 predict.py \ + --config=slowfast-single.yaml \ + --use_gpu=True \ + --use_data_parallel=0 \ + --weights=checkpoints/slowfast_epoch195 diff --git a/dygraph/slowfast/run_train_multi.sh b/dygraph/slowfast/run_train_multi.sh new file mode 100644 index 0000000000000000000000000000000000000000..501bbd18751e1cb0a5a37953067c05365b82b7ed --- /dev/null +++ b/dygraph/slowfast/run_train_multi.sh @@ -0,0 +1,13 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +start_time=$(date +%s) + +python3.7 -m paddle.distributed.launch --log_dir=logs \ + train.py \ + --config=slowfast.yaml \ + --use_gpu=True \ + --use_data_parallel=1 \ + +end_time=$(date +%s) +cost_time=$[ $end_time-$start_time ] +echo "8 card bs=64, 196 epoch 34 warmup epoch, 400 class, preciseBN 200 iter build kernel time is $(($cost_time/60))min $(($cost_time%60))s" diff --git a/dygraph/slowfast/run_train_single.sh b/dygraph/slowfast/run_train_single.sh new file mode 100644 index 0000000000000000000000000000000000000000..35fae3b9d787f89430083bd1d0fffd8baad3a242 --- /dev/null +++ b/dygraph/slowfast/run_train_single.sh @@ -0,0 +1,9 @@ +export CUDA_VISIBLE_DEVICES=0 + +start_time=$(date +%s) + +python3.7 train.py --config=slowfast-single.yaml --use_gpu=True --use_data_parallel=0 + +end_time=$(date +%s) 
+cost_time=$((end_time - start_time))
+echo "1 card, bs=8, 196 epochs (34 warmup), 400 classes, preciseBN 200 iters; total time: $(($cost_time/60))min $(($cost_time%60))s"
diff --git a/dygraph/slowfast/slowfast-single.yaml b/dygraph/slowfast/slowfast-single.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f674bb73ec69857eecb6c3eaed85bb862f1b44d4
--- /dev/null
+++ b/dygraph/slowfast/slowfast-single.yaml
@@ -0,0 +1,53 @@
+MODEL:
+    name: "SLOWFAST"
+    format: "mp4"
+    num_classes: 400
+    num_frames: 32
+    sampling_rate: 2
+    target_fps: 30
+    alpha: 8
+    beta: 8
+    crop_size: 224
+    image_mean: [0.45, 0.45, 0.45]
+    image_std: [0.225, 0.225, 0.225]
+
+TRAIN:
+    epoch: 196
+    target_size: 224
+    min_size: 256
+    max_size: 320
+    batch_size: 8
+    use_gpu: True
+    num_gpus: 1
+    filelist: "./data/train.csv"
+    base_lr: 0.1
+    warmup_epochs: 34
+    warmup_start_lr: 0.01
+    l2_weight_decay: 1e-4
+    momentum: 0.9
+
+VALID:
+    use_preciseBN: True
+    preciseBN_interval: 10
+    num_batches_preciseBN: 200
+    target_size: 224
+    min_size: 256
+    max_size: 320
+    batch_size: 8
+    filelist: "./data/val.csv"
+
+TEST:
+    target_size: 256
+    batch_size: 8
+    filelist: "./data/val.csv"
+    num_ensemble_views: 10
+    num_spatial_crops: 3
+
+INFER:
+    target_size: 256
+    batch_size: 8
+    filelist: "./data/infer.csv"
+    num_ensemble_views: 10
+    num_spatial_crops: 3
+    save_path: "./data/results"
+
diff --git a/dygraph/slowfast/slowfast.yaml b/dygraph/slowfast/slowfast.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e762f71853e2cdbe7f7053c29c63b34bdb931bf
--- /dev/null
+++ b/dygraph/slowfast/slowfast.yaml
@@ -0,0 +1,55 @@
+MODEL:
+    name: "SLOWFAST"
+    format: "mp4"
+    num_classes: 400
+    num_frames: 32
+    sampling_rate: 2
+    target_fps: 30
+    alpha: 8
+    beta: 8
+    crop_size: 224
+    image_mean: [0.45, 0.45, 0.45]
+    image_std: [0.225, 0.225, 0.225]
+
+TRAIN:
+    epoch: 196
+    target_size: 224
+    min_size: 256
+    max_size: 320
+    batch_size: 64
+    use_gpu: True
+    num_gpus: 8
+    filelist: "./data/train.csv"
+    base_lr: 0.1
+    warmup_epochs: 34
+    warmup_start_lr: 0.01
+    l2_weight_decay: 1e-4
+    momentum: 0.9
+
+VALID:
+    use_preciseBN: True
+    preciseBN_interval: 10
+    num_batches_preciseBN: 200
+    target_size: 224
+    min_size: 256
+    max_size: 320
+    batch_size: 64
+    filelist: "./data/val.csv"
+
+TEST:
+    target_size: 256
+    batch_size: 64
+    filelist: "./data/val.csv"
+    weights: "checkpoints/slowfast_epoch195"
+    num_ensemble_views: 10
+    num_spatial_crops: 3
+
+INFER:
+    target_size: 256
+    batch_size: 32
+    filelist: "./data/infer.csv"
+    weights: "checkpoints/slowfast_epoch195"
+    num_ensemble_views: 10
+    num_spatial_crops: 3
+    save_path: "./data/results"
+
diff --git a/dygraph/slowfast/train.py b/dygraph/slowfast/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a8176590f5e82cea73a19ba09a32e4b93c7c869
--- /dev/null
+++ b/dygraph/slowfast/train.py
@@ -0,0 +1,461 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
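+
+# Usage sketch: the commands below mirror run_train_single.sh and
+# run_train_multi.sh shipped with this model; paths are the repo defaults,
+# adjust to your setup.
+#   single card: python3.7 train.py --config=slowfast-single.yaml --use_data_parallel=0
+#   multi card:  python3.7 -m paddle.distributed.launch --log_dir=logs train.py --config=slowfast.yaml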
+
+import os
+import sys
+import time
+import argparse
+import ast
+import logging
+import itertools
+import numpy as np
+import random
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.base import to_variable
+from paddle.io import DataLoader, Dataset
+from paddle.incubate.hapi.distributed import DistributedBatchSampler
+
+from model import *
+from config_utils import *
+from lr_policy import get_epoch_lr
+from kinetics_dataset import KineticsDataset
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("SlowFast train")
+    parser.add_argument(
+        '--config',
+        type=str,
+        default='slowfast.yaml',
+        help='path to config file of model')
+    parser.add_argument(
+        '--use_visualdl',
+        type=ast.literal_eval,
+        default=False,
+        help='whether to use visual dl.')
+    parser.add_argument(
+        '--vd_logdir',
+        type=str,
+        default='./vdlog',
+        help='default save visualdl_log in ./vdlog.')
+    parser.add_argument(
+        '--use_data_parallel',
+        type=ast.literal_eval,
+        default=True,
+        help='default use data parallel.')
+    parser.add_argument(
+        '--use_gpu',
+        type=ast.literal_eval,
+        default=True,
+        help='default use gpu.')
+    parser.add_argument(
+        '--epoch',
+        type=int,
+        default=None,
+        help='epoch number, None to read from config file')
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=None,
+        help='training batch size. None to use config file setting.')
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default='./checkpoints',
+        help='default model save in ./checkpoints.')
+    parser.add_argument(
+        '--resume',
+        type=ast.literal_eval,
+        default=False,
+        help='whether to resume training')
+    parser.add_argument(
+        '--resume_epoch',
+        type=int,
+        default=100000,
+        help='epoch to resume training based on latest saved checkpoints.')
+    parser.add_argument(
+        '--valid_interval',
+        type=int,
+        default=1,
+        help='validation epoch interval, 0 for no validation.')
+    parser.add_argument(
+        '--log_interval',
+        type=int,
+        default=10,
+        help='mini-batch interval to log.')
+
+    args = parser.parse_args()
+    return args
+
+
+def val(epoch, model, valid_loader, use_visualdl, vdl_writer=None):
+    # vdl_writer is passed in explicitly; it used to be read as a global,
+    # which raised NameError when use_visualdl was set.
+    val_iter_num = len(valid_loader)
+    total_loss = 0.0
+    total_acc1 = 0.0
+    total_acc5 = 0.0
+    total_sample = 0
+
+    for batch_id, data in enumerate(valid_loader):
+        y_data = data[2]
+        labels = to_variable(y_data)
+        labels.stop_gradient = True
+        model_inputs = [data[0], data[1]]
+
+        preds = model(model_inputs, training=False)
+
+        loss_out = fluid.layers.softmax_with_cross_entropy(
+            logits=preds, label=labels)
+        avg_loss = fluid.layers.mean(loss_out)
+        acc_top1 = fluid.layers.accuracy(input=preds, label=labels, k=1)
+        acc_top5 = fluid.layers.accuracy(input=preds, label=labels, k=5)
+
+        total_loss += avg_loss.numpy()[0]
+        total_acc1 += acc_top1.numpy()[0]
+        total_acc5 += acc_top5.numpy()[0]
+        total_sample += 1
+        if use_visualdl and vdl_writer:
+            vdl_writer.add_scalar(
+                tag="val/loss",
+                step=epoch * val_iter_num + batch_id,
+                value=avg_loss.numpy())
+            vdl_writer.add_scalar(
+                tag="val/err1",
+                step=epoch * val_iter_num + batch_id,
+                value=1.0 - acc_top1.numpy())
+            vdl_writer.add_scalar(
+                tag="val/err5",
+                step=epoch * val_iter_num + batch_id,
+                value=1.0 - acc_top5.numpy())
+        print("[Test Epoch %d, batch %d] loss %.5f, err1 %.5f, err5 %.5f" % \
+            (epoch, batch_id, avg_loss.numpy(), 1.0 - acc_top1.numpy(), 1. - acc_top5.numpy()))
+    print('[TEST Epoch %d end] avg_loss %.5f, avg_err1 %.5f, avg_err5 %.5f' % \
+        (epoch, total_loss / total_sample, 1. - total_acc1 / total_sample, 1. - total_acc5 / total_sample))
+
+    if use_visualdl and vdl_writer:
+        vdl_writer.add_scalar(
+            tag="val_epoch/loss", step=epoch, value=total_loss / total_sample)
+        vdl_writer.add_scalar(
+            tag="val_epoch/err1",
+            step=epoch,
+            value=1.0 - total_acc1 / total_sample)
+        vdl_writer.add_scalar(
+            tag="val_epoch/err5",
+            step=epoch,
+            value=1.0 - total_acc5 / total_sample)
+
+
+def create_optimizer(cfg, data_size, params):
+    l2_weight_decay = cfg.l2_weight_decay
+    momentum = cfg.momentum
+
+    # precompute the lr of every iteration, then hand the (boundary, value)
+    # pairs to piecewise_decay
+    lr_list = []
+    bd_list = []
+    cur_bd = 1
+    for cur_epoch in range(cfg.epoch):
+        for cur_iter in range(data_size):
+            cur_lr = get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg)
+            lr_list.append(cur_lr)
+            bd_list.append(cur_bd)
+            cur_bd += 1
+    bd_list.pop()  # boundaries must be one element shorter than values
+
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=fluid.layers.piecewise_decay(
+            boundaries=bd_list, values=lr_list),
+        momentum=momentum,
+        regularization=fluid.regularizer.L2Decay(l2_weight_decay),
+        use_nesterov=True,
+        parameter_list=params)
+
+    return optimizer
+
+
+def precise_BN(model, data_loader, num_iters=200):
+    """
+    Recompute and update the batch norm stats to make them more precise. During
+    training both BN stats and the weights are changing after every iteration, so
+    the running average cannot precisely reflect the actual stats of the
+    current model.
+    In this function, the BN stats are recomputed with fixed weights, to make
+    the running average more precise. Specifically, it computes the true average
+    of per-batch mean/variance instead of the running average.
+    This is useful to improve validation accuracy.
+    :param model: the model whose bn stats will be recomputed
+    :param data_loader: an iterator. Produce data as input to the model
+    :param num_iters: number of iterations to compute the stats.
+    :return: the model with precise mean and variance in bn layers.
+    """
+    bn_layers_list = [
+        m for m in model.sublayers()
+        if isinstance(m, paddle.fluid.dygraph.nn.BatchNorm) and not m._is_test
+    ]
+    if len(bn_layers_list) == 0:
+        return
+
+    # moving_mean = moving_mean * momentum + batch_mean * (1. - momentum)
+    # we set momentum=0. to get the true mean and variance during forward
+    momentum_actual = [bn._momentum for bn in bn_layers_list]
+    for bn in bn_layers_list:
+        bn._momentum = 0.
+
+    # accumulators for the true average of the per-batch statistics
+    running_mean = [
+        fluid.layers.zeros_like(bn._mean) for bn in bn_layers_list
+    ]
+    running_var = [
+        fluid.layers.zeros_like(bn._variance) for bn in bn_layers_list
+    ]
+
+    ind = -1
+    for ind, data in enumerate(itertools.islice(data_loader, num_iters)):
+        model_inputs = [data[0], data[1]]
+        model(model_inputs, training=True)
+
+        for i, bn in enumerate(bn_layers_list):
+            # incremental average: after k batches this equals the plain
+            # arithmetic mean of the k per-batch stats
+            running_mean[i] += (bn._mean - running_mean[i]) / (ind + 1)
+            running_var[i] += (bn._variance - running_var[i]) / (ind + 1)
+
+    assert ind == num_iters - 1, (
+        "update_bn_stats is meant to run for {} iterations, "
+        "but the dataloader stops at {} iterations.".format(num_iters, ind))
+
+    # Sets the precise bn stats.
+    for i, bn in enumerate(bn_layers_list):
+        bn._mean.set_value(running_mean[i])
+        bn._variance.set_value(running_var[i])
+        bn._momentum = momentum_actual[i]  # restore the original momentum
+
+
+def train(args):
+    config = parse_config(args.config)
+    train_config = merge_configs(config, 'train', vars(args))
+    valid_config = merge_configs(config, 'valid', vars(args))
+    print_configs(train_config, 'Train')
+
+    # VisualDL to visualize the training process
+    local_rank = fluid.dygraph.parallel.Env().local_rank
+    vdl_writer = None
+    if args.use_visualdl:
+        try:
+            from visualdl import LogWriter
+            vdl_writer = LogWriter(args.vd_logdir + '/' + str(local_rank))
+        except:
+            print("visualdl is not installed, please install visualdl "
+                  "if you want to visualize the training process.")
+            args.use_visualdl = False
+
+    if not args.use_gpu:
+        place = fluid.CPUPlace()
+    elif not args.use_data_parallel:
+        place = fluid.CUDAPlace(0)
+    else:
+        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+
+    random.seed(0)
+    np.random.seed(0)
+    paddle.framework.manual_seed(0)
+    with fluid.dygraph.guard(place):
+        # 1. init net
+        if args.use_data_parallel:
+            strategy = fluid.dygraph.parallel.prepare_context()
+
+        video_model = SlowFast(cfg=train_config, num_classes=400)
+        if args.use_data_parallel:
+            video_model = fluid.dygraph.parallel.DataParallel(video_model,
+                                                              strategy)
+
+        bs_denominator = 1
+        if args.use_gpu:
+            gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
+            if gpus == "":
+                pass
+            else:
+                gpus = gpus.split(",")
+                num_gpus = len(gpus)
+                assert num_gpus == train_config.TRAIN.num_gpus, \
+                    "num_gpus({}) set by CUDA_VISIBLE_DEVICES " \
+                    "should be the same as that " \
+                    "set in {}({})".format(
+                        num_gpus, args.config, train_config.TRAIN.num_gpus)
+            bs_denominator = train_config.TRAIN.num_gpus
+
+        # 2. reader and optimizer
+        bs_train_single = int(train_config.TRAIN.batch_size / bs_denominator)
+        bs_val_single = int(train_config.VALID.batch_size / bs_denominator)
+        train_data = KineticsDataset(mode="train", cfg=train_config)
+        valid_data = KineticsDataset(mode="valid", cfg=valid_config)
+        train_sampler = DistributedBatchSampler(
+            train_data,
+            batch_size=bs_train_single,
+            shuffle=True,
+            drop_last=True)
+        train_loader = DataLoader(
+            train_data,
+            batch_sampler=train_sampler,
+            places=place,
+            feed_list=None,
+            num_workers=8,
+            return_list=True)
+        valid_sampler = DistributedBatchSampler(
+            valid_data,
+            batch_size=bs_val_single,
+            shuffle=False,
+            drop_last=False)
+        valid_loader = DataLoader(
+            valid_data,
+            batch_sampler=valid_sampler,
+            places=place,
+            feed_list=None,
+            num_workers=8,
+            return_list=True)
+
+        train_iter_num = len(train_loader)
+        optimizer = create_optimizer(train_config.TRAIN, train_iter_num,
+                                     video_model.parameters())
+
+        # 3. load checkpoint
+        if args.resume:
+            saved_path = "slowfast_epoch"  # default checkpoint name prefix
+            model_path = saved_path + str(args.resume_epoch)
+            assert os.path.exists(model_path + ".pdparams"), \
+                "Given file {}.pdparams does not exist.".format(model_path)
+            assert os.path.exists(model_path + ".pdopt"), \
+                "Given file {}.pdopt does not exist.".format(model_path)
+            para_dict, opti_dict = fluid.dygraph.load_dygraph(model_path)
+            video_model.set_dict(para_dict)
+            optimizer.set_dict(opti_dict)
+            if args.use_visualdl:
+                # change log path, otherwise the log history will be overwritten
+                vdl_writer = LogWriter(args.vd_logdir + str(args.resume_epoch)
+                                       + '/' + str(local_rank))
+
+        # 4. train loop
+        for epoch in range(train_config.TRAIN.epoch):
+            epoch_start = time.time()
+            if args.resume and epoch <= args.resume_epoch:
+                print("epoch {} <= args.resume_epoch {}, skip".format(
+                    epoch, args.resume_epoch))
+                continue
+            video_model.train()
+            total_loss = 0.0
+            total_acc1 = 0.0
+            total_acc5 = 0.0
+            total_sample = 0
+
+            print('start training, Epoch {}/{}'.format(
+                epoch, train_config.TRAIN.epoch))
+            batch_start = time.time()
+            for batch_id, data in enumerate(train_loader):
+                batch_reader_end = time.time()
+                y_data = data[2]
+                labels = to_variable(y_data)
+                labels.stop_gradient = True
+                model_inputs = [data[0], data[1]]
+
+                # 4.1.1 call net()
+                preds = video_model(model_inputs, training=True)
+                loss_out = fluid.layers.softmax_with_cross_entropy(
+                    logits=preds, label=labels)
+                avg_loss = fluid.layers.mean(loss_out)
+                acc_top1 = fluid.layers.accuracy(input=preds, label=labels, k=1)
+                acc_top5 = fluid.layers.accuracy(input=preds, label=labels, k=5)
+
+                # 4.1.2 call backward()
+                if args.use_data_parallel:
+                    avg_loss = video_model.scale_loss(avg_loss)
+                    avg_loss.backward()
+                    video_model.apply_collective_grads()
+                else:
+                    avg_loss.backward()
+
+                # 4.1.3 call minimize()
+                optimizer.minimize(avg_loss)
+                video_model.clear_gradients()
+
+                total_loss += avg_loss.numpy()[0]
+                total_acc1 += acc_top1.numpy()[0]
+                total_acc5 += acc_top5.numpy()[0]
+                total_sample += 1
+                if args.use_visualdl:
+                    vdl_writer.add_scalar(
+                        tag="train/loss",
+                        step=epoch * train_iter_num + batch_id,
+                        value=avg_loss.numpy())
+                    vdl_writer.add_scalar(
+                        tag="train/err1",
+                        step=epoch * train_iter_num + batch_id,
+                        value=1.0 - acc_top1.numpy())
+                    vdl_writer.add_scalar(
+                        tag="train/err5",
+                        step=epoch * train_iter_num + batch_id,
+                        value=1.0 - acc_top5.numpy())
+
+                train_batch_cost = time.time() - batch_start
+                train_reader_cost = batch_reader_end - batch_start
+                batch_start = time.time()
+                if batch_id % args.log_interval == 0:
+                    print("[Epoch %d, batch %d] loss %.5f, err1 %.5f, err5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s" % \
+                        (epoch, batch_id, avg_loss.numpy(), 1.0 - acc_top1.numpy(), 1. - acc_top5.numpy(), train_batch_cost, train_reader_cost))
+
+            train_epoch_cost = time.time() - epoch_start
+            print('[Epoch %d end] avg_loss %.5f, avg_err1 %.5f, avg_err5 %.5f, epoch_cost: %.5f s' % \
+                (epoch, total_loss / total_sample, 1. - total_acc1 / total_sample, 1. - total_acc5 / total_sample, train_epoch_cost))
+            if args.use_visualdl:
+                vdl_writer.add_scalar(
+                    tag="train_epoch/loss",
+                    step=epoch,
+                    value=total_loss / total_sample)
+                vdl_writer.add_scalar(
+                    tag="train_epoch/err1",
+                    step=epoch,
+                    value=1. - total_acc1 / total_sample)
+                vdl_writer.add_scalar(
+                    tag="train_epoch/err5",
+                    step=epoch,
+                    value=1. - total_acc5 / total_sample)
+
+            # 4.2 do preciseBN
+            if valid_config.VALID.use_preciseBN and epoch % valid_config.VALID.preciseBN_interval == 0:
+                print("do precise BN in epoch {} ...".format(epoch))
+                precise_BN(video_model, train_loader,
+                           min(valid_config.VALID.num_batches_preciseBN,
+                               len(train_loader)))
+
+            # 4.3 save checkpoint
+            if local_rank == 0:
+                if not os.path.isdir(args.save_dir):
+                    os.makedirs(args.save_dir)
+                model_path = os.path.join(args.save_dir,
+                                          "slowfast_epoch{}".format(epoch))
+                fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
+                fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
+                print('save_dygraph End, Epoch {}/{}'.format(
+                    epoch, train_config.TRAIN.epoch))
+
+            # 4.4 validation
+            video_model.eval()
+            val(epoch, video_model, valid_loader, args.use_visualdl,
+                vdl_writer)
+
+    logger.info('[TRAIN] training finished')
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    logger.info(args)
+    train(args)
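+
+# Note: checkpoints are written per epoch as <save_dir>/slowfast_epoch{N}.pdparams
+# and .pdopt, while the resume branch above loads "slowfast_epoch{N}" relative to
+# the working directory, so run from the directory that holds the checkpoint
+# files. Example, with an assumed epoch number:
+#   python3.7 train.py --config=slowfast.yaml --resume=True --resume_epoch=100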