Unverified · Commit 3203066d authored by Kaipeng Deng, committed by GitHub

Merge pull request #20 from heavengate/add_tsm

Add tsm
```diff
@@ -18,6 +18,7 @@ from . import mobilenetv1
 from . import mobilenetv2
 from . import darknet
 from . import yolov3
+from . import tsm
 from .resnet import *
 from .mobilenetv1 import *
@@ -25,10 +26,12 @@ from .mobilenetv2 import *
 from .vgg import *
 from .darknet import *
 from .yolov3 import *
+from .tsm import *
 __all__ = resnet.__all__ \
     + vgg.__all__ \
     + mobilenetv1.__all__ \
     + mobilenetv2.__all__ \
     + darknet.__all__ \
-    + yolov3.__all__
+    + yolov3.__all__ \
+    + tsm.__all__
```
# TSM Video Classification Model
---
## Contents
- [Introduction](#introduction)
- [Quick Start](#quick-start)
- [Reference](#reference)
## Introduction
The Temporal Shift Module (TSM), proposed by Ji Lin, Chuang Gan, and Song Han of MIT and the IBM Watson AI Lab, improves a network's video understanding ability through temporal shifting. The shift operation is illustrated below.
<p align="center">
<img src="./images/temporal_shift.png" height=250 width=800 hspace='10'/> <br />
Temporal shift module
</p>
The matrix in the figure above represents the temporal and channel dimensions of a feature map. Part of the channels are shifted one step forward along the temporal dimension and another part one step backward, with the vacated positions zero-padded. This introduces contextual interaction along the temporal dimension of the feature map and improves video understanding across time.
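As an illustration only, the shift operation can be sketched in NumPy as follows (a minimal sketch of the principle; the 1/8-per-direction shift ratio follows a common setting in the paper, not necessarily this repository's exact implementation):
```python
import numpy as np

def temporal_shift(x, shift_ratio=0.125):
    """Shift a fraction of channels along the temporal axis.

    x has shape (N, T, C, H, W). `shift_ratio` of the channels move one
    step backward in time, the same fraction one step forward, and the
    vacated positions are filled with zeros.
    """
    n, t, c, h, w = x.shape
    fold = int(c * shift_ratio)
    out = np.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # shift toward the past
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift toward the future
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # remaining channels unchanged
    return out

feat = np.random.rand(2, 8, 64, 7, 7).astype('float32')
print(temporal_shift(feat).shape)  # (2, 8, 64, 7, 7)
```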
The TSM model is a video classification model built by inserting the Temporal Shift Module into a ResNet backbone; this model zoo implements the TSM model with ResNet-50 as the backbone network.
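For reference, the model provided by this repository can be constructed as below (a minimal sketch assuming the hapi repository root is on `PYTHONPATH`; the signature mirrors its use in `main.py`):
```python
from models import tsm_resnet50

# Build a TSM model with a ResNet-50 backbone; pretrained=True would
# load the released weights instead of random initialization.
model = tsm_resnet50(num_classes=400, pretrained=False)
```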
For details, please refer to the paper [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1).
## Quick Start
### Installation
#### Installing PaddlePaddle
This project requires PaddlePaddle 1.7 or higher, or an appropriate develop version. Please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
#### Downloading the code and setting environment variables
Clone the repository and set the `PYTHONPATH` environment variable:
```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=$PYTHONPATH:`pwd`
cd tsm
```
### Data Preparation
TSM is trained on the Kinetics-400 action recognition dataset released by DeepMind. For data download and preparation, see the [data instructions](./dataset/README.md).
#### Validation on a small dataset
To allow fast iteration, we validate the model on a smaller dataset: the 10 classes with labels 0, 2, 3, 4, 6, 7, 9, 12, 14, and 15 are selected from Kinetics-400 to verify model accuracy.
### Model Training
Once the data is ready, training and evaluation can be launched with the `main.py` script. The script alternates training and evaluation every epoch and saves checkpoints to the `tsm_checkpoint` directory by default.
The arguments of `main.py` can be listed with:
```bash
python main.py --help
```
#### Static graph training
Single-GPU training:
```bash
export CUDA_VISIBLE_DEVICES=0
python main.py --data=<path/to/dataset> --batch_size=16
```
Multi-GPU training:
```bash
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8
```
#### Dynamic graph training
For dynamic graph (dygraph) training, simply add the `-d` flag when running the script.
Single-GPU training:
```bash
export CUDA_VISIBLE_DEVICES=0
python main.py --data=<path/to/dataset> --batch_size=16 -d
```
Multi-GPU training:
```bash
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8 -d
```
**Note:** For both static and dynamic graph modes, `--batch_size` in multi-GPU training is the batch size per GPU; the total batch size is `--batch_size` multiplied by the number of GPUs (e.g., `--batch_size=8` on 2 GPUs gives a total batch size of 16).
### Model Evaluation
The model can be evaluated in the following two ways.
1. Evaluate with the [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams) weights released by Paddle, which are downloaded automatically:
```bash
python main.py --data=<path/to/dataset> --eval_only
```
2. Evaluate with a saved checkpoint:
```bash
python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/final
```
#### Evaluation accuracy
The model weights trained on the 10-class small dataset are available at [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams); the evaluation accuracy is as follows:
|Top-1|Top-5|
|:-:|:-:|
|76%|98%|
### Model Inference
Inference can be run in the following two ways.
1. Run inference with the [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams) weights released by Paddle, which are downloaded automatically:
```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle>
```
2. Run inference with a saved checkpoint:
```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle> --weights=tsm_checkpoint/final
```
Inference results are written to the log in the following form:
```text
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```
## Reference
- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
def check_version(version='1.7.0'):
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version {} or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
.format(version)
try:
fluid.require_version(version)
except Exception as e:
logger.error(err)
sys.exit(1)
# Data Preparation
## Kinetics dataset
Kinetics is a large-scale video action recognition dataset released by DeepMind, available in two versions: Kinetics-400 and Kinetics-600. Kinetics-400 is used here; the preprocessing steps are as follows.
### Downloading the mp4 videos
Create the directories under the Code\_Root directory:
```bash
cd $Code_Root/data/dataset && mkdir kinetics
cd kinetics && mkdir data_k400 && cd data_k400
mkdir train_mp4 && mkdir val_mp4
```
ActivityNet provides an official download tool for Kinetics; see its [official repo](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) to download the Kinetics-400 mp4 video collection. Download the training and validation sets to data/dataset/kinetics/data\_k400/train\_mp4 and data/dataset/kinetics/data\_k400/val\_mp4 respectively.
### Preprocessing the mp4 files
To speed up data loading, the mp4 files are decoded into frames and packed into pickle files in advance, and the dataloader reads each video's data from its pkl file (this method costs extra storage). Each pkl file packs a tuple of (video-id, label, [frame1, frame2, ..., frameN]).
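For illustration, a packed pkl can be inspected as follows (a minimal sketch; the file name `data_batch_100-097` is just an example taken from the list samples below):
```python
import pickle

# Each pkl stores (video-id, label, [frame1, ..., frameN]); every frame
# is the raw bytes of one JPEG image extracted from the video.
with open('data_k400/train_pkl/data_batch_100-097', 'rb') as f:
    vid, label, frames = pickle.load(f, encoding='bytes')
print(vid, label, len(frames))
```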
Create the train\_pkl and val\_pkl directories under data/dataset/kinetics/data\_k400:
```bash
cd $Code_Root/data/dataset/kinetics/data_k400
mkdir train_pkl && mkdir val_pkl
```
Enter the $Code\_Root/data/dataset/kinetics directory and convert the data with the video2pkl.py script. First download the file lists of the [train](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_train.csv) and [validation](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_val.csv) sets.
First generate the dataset label file needed for preprocessing:
```bash
python generate_label.py kinetics-400_train.csv kinetics400_label.txt
```
Then run the following (here with 8 processes as an example):
```bash
python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8
```
- This script depends on `ffmpeg`; please install `ffmpeg` beforehand.
For the train data:
```text
Source_dir = $Code_Root/data/dataset/kinetics/data_k400/train_mp4
Target_dir = $Code_Root/data/dataset/kinetics/data_k400/train_pkl
```
For the val data:
```text
Source_dir = $Code_Root/data/dataset/kinetics/data_k400/val_mp4
Target_dir = $Code_Root/data/dataset/kinetics/data_k400/val_pkl
```
This decodes the mp4 files and saves them as pkl files.
### Generating the train and validation lists
```bash
cd $Code_Root/data/dataset/kinetics
ls $Code_Root/data/dataset/kinetics/data_k400/train_pkl/* > train.list
ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > val.list
ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > test.list
ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > infer.list
```
This generates the corresponding file lists. Each line of train.list and val.list is the absolute path of one pkl file, for example:
```text
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-097
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-114
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-118
...
```
or
```text
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-085
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-086
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-090
...
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
# kinetics-400_train.csv should be downloaded first and passed as sys.argv[1]
# sys.argv[2] can be set as kinetics400_label.txt
# python generate_label.py kinetics-400_train.csv kinetics400_label.txt
num_classes = 400
fname = sys.argv[1]
outname = sys.argv[2]
fl = open(fname).readlines()
fl = fl[1:]
outf = open(outname, 'w')
label_list = []
for line in fl:
label = line.strip().split(',')[0].strip('"')
if label in label_list:
continue
else:
label_list.append(label)
assert len(label_list) == num_classes, \
    "there should be {} labels in list, but got {}".format(
        num_classes, len(label_list))
label_list.sort()
for i in range(num_classes):
outf.write('{} {}'.format(label_list[i], i) + '\n')
outf.close()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
try:
import cPickle as pickle
except:
import pickle
from multiprocessing import Pool
# example command line: python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8
#
# kinetics-400_train.csv is the training set file of the K400 official release;
# each line contains label,youtube_id,time_start,time_end,split,is_cc
assert (len(sys.argv) == 5)
f = open(sys.argv[1])
source_dir = sys.argv[2]
target_dir = sys.argv[3]
num_threads = sys.argv[4]
all_video_entries = [x.strip().split(',') for x in f.readlines()]
all_video_entries = all_video_entries[1:]
f.close()
category_label_map = {}
f = open('kinetics400_label.txt')
for line in f:
ens = line.strip().split(' ')
category = " ".join(ens[0:-1])
label = int(ens[-1])
category_label_map[category] = label
f.close()
def generate_pkl(entry):
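    # Process one video entry: extract its JPEG frames with ffmpeg, pack
    # (video-id, label, frames) into a pkl file, then remove the
    # temporary frame directory.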
mode = entry[4]
category = entry[0].strip('"')
category_dir = category
video_path = os.path.join(
'./',
entry[1] + "_%06d" % int(entry[2]) + "_%06d" % int(entry[3]) + ".mp4")
video_path = os.path.join(source_dir, category_dir, video_path)
label = category_label_map[category]
vid = './' + video_path.split('/')[-1].split('.')[0]
if os.path.exists(video_path):
if not os.path.exists(vid):
os.makedirs(vid)
os.system('ffmpeg -i ' + video_path + ' -q 0 ' + vid + '/%06d.jpg')
else:
print("File not exists {}".format(video_path))
return
images = sorted(glob.glob(vid + '/*.jpg'))
ims = []
for img in images:
f = open(img, 'rb')
ims.append(f.read())
f.close()
output_pkl = vid + ".pkl"
output_pkl = os.path.join(target_dir, output_pkl)
f = open(output_pkl, 'wb')
pickle.dump((vid, label, ims), f, protocol=2)
f.close()
os.system('rm -rf %s' % vid)
pool = Pool(processes=int(num_threads))
pool.map(generate_pkl, all_video_entries)
pool.close()
pool.join()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np

from paddle import fluid
from model import Input, set_device
from models import tsm_resnet50
from check import check_gpu, check_version
from kinetics_dataset import KineticsDataset
from transforms import *
import logging
logger = logging.getLogger(__name__)
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
transform = Compose([GroupScale(),
GroupCenterCrop(),
NormalizeImage()])
dataset = KineticsDataset(
pickle_file=FLAGS.infer_file,
label_list=FLAGS.label_list,
mode='test',
transform=transform)
labels = dataset.label_list
model = tsm_resnet50(num_classes=len(labels),
pretrained=FLAGS.weights is None)
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
model.prepare(inputs=inputs, device=FLAGS.device)
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
imgs, label = dataset[0]
pred = model.test([imgs[np.newaxis, :]])
pred = labels[np.argmax(pred)]
logger.info("Sample {} predict label: {}, ground truth label: {}" \
.format(FLAGS.infer_file, pred, labels[int(label)]))
if __name__ == '__main__':
parser = argparse.ArgumentParser("CNN training on TSM")
parser.add_argument(
"--data", type=str, default='dataset/kinetics',
help="path to dataset root directory")
parser.add_argument(
"--device", type=str, default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true',
help="enable dygraph mode")
parser.add_argument(
"--label_list", type=str, default=None,
help="path to category index label list file")
parser.add_argument(
"--infer_file", type=str, default=None,
help="path to pickle file for inference")
parser.add_argument(
"-w",
"--weights",
default=None,
type=str,
help="weights path for evaluation")
FLAGS = parser.parse_args()
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import sys
import random
import numpy as np
from PIL import Image, ImageEnhance
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
from paddle.fluid.io import Dataset
import logging
logger = logging.getLogger(__name__)
__all__ = ['KineticsDataset']
KINETICS_CLASS_NUM = 400
class KineticsDataset(Dataset):
"""
Kinetics dataset
Args:
        file_list (str): path to the file list, required in 'train'
            and 'val' modes. Default None
        pickle_dir (str): path to the pickle file directory, required
            in 'train' and 'val' modes. Default None
        pickle_file (str): path to a single pickle file, required in
            'test' mode. Default None
        label_list (str): path to the label_list file; if None, the
            default Kinetics class number of 400 is used. Default None
        mode (str): 'train', 'val' or 'test' mode; frame sampling
            differs between these modes. Default 'train'
        seg_num (int): number of segments to sample from each video.
            Default 8
        seg_len (int): number of frames in each segment. Default 1
        transform (callable): transforms to perform on video samples,
            None for no transforms. Default None
"""
def __init__(self,
file_list=None,
pickle_dir=None,
pickle_file=None,
label_list=None,
mode='train',
seg_num=8,
seg_len=1,
transform=None):
assert str.lower(mode) in ['train', 'val', 'test'], \
"mode can only be 'train' 'val' or 'test'"
self.mode = str.lower(mode)
if self.mode in ['train', 'val']:
assert os.path.isfile(file_list), \
"file_list {} not a file".format(file_list)
with open(file_list) as f:
self.pickle_paths = [l.strip() for l in f]
assert os.path.isdir(pickle_dir), \
"pickle_dir {} not a directory".format(pickle_dir)
self.pickle_dir = pickle_dir
else:
assert os.path.isfile(pickle_file), \
"pickle_file {} not a file".format(pickle_file)
self.pickle_dir = ''
self.pickle_paths = [pickle_file]
self.label_list = label_list
if self.label_list is not None:
assert os.path.isfile(self.label_list), \
"label_list {} not a file".format(self.label_list)
with open(self.label_list) as f:
self.label_list = [int(l.strip()) for l in f]
self.seg_num = seg_num
self.seg_len = seg_len
self.transform = transform
def __len__(self):
return len(self.pickle_paths)
def __getitem__(self, idx):
pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx])
try:
if six.PY2:
data = pickle.load(open(pickle_path, 'rb'))
else:
data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data
if len(frames) < 1:
logger.error("{} contains no frame".format(pickle_path))
sys.exit(-1)
except Exception as e:
logger.error("Load {} failed: {}".format(pickle_path, e))
sys.exit(-1)
if self.label_list is not None:
label = self.label_list.index(label)
imgs = self._video_loader(frames)
if self.transform:
imgs, label = self.transform(imgs, label)
return imgs, np.array([label])
@property
def num_classes(self):
return KINETICS_CLASS_NUM if self.label_list is None \
else len(self.label_list)
def _video_loader(self, frames):
videolen = len(frames)
average_dur = int(videolen / self.seg_num)
imgs = []
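        # Sample one frame index per segment: a random position inside
        # the segment in 'train' mode, the segment center otherwise.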
for i in range(self.seg_num):
idx = 0
if self.mode == 'train':
if average_dur >= self.seg_len:
idx = random.randint(0, average_dur - self.seg_len)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= self.seg_len:
idx = (average_dur - self.seg_len) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + self.seg_len):
imgbuf = frames[int(jj % videolen)]
img = self._imageloader(imgbuf)
imgs.append(img)
return imgs
def _imageloader(self, buf):
if isinstance(buf, str):
img = Image.open(StringIO(buf))
else:
img = Image.open(BytesIO(buf))
return img.convert('RGB')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from paddle import fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from model import Model, CrossEntropy, Input, set_device
from metrics import Accuracy
from models import tsm_resnet50
from check import check_gpu, check_version
from kinetics_dataset import KineticsDataset
from transforms import *
def make_optimizer(step_per_epoch, parameter_list=None):
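    # Piecewise decay: multiply the base learning rate by 0.1 at
    # epochs 40 and 60; SGD with momentum 0.9 and L2 weight decay 1e-4.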
boundaries = [e * step_per_epoch for e in [40, 60]]
values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries,
values=values)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-4),
momentum=0.9,
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
train_transform = Compose([GroupScale(),
GroupMultiScaleCrop(),
GroupRandomCrop(),
GroupRandomFlip(),
NormalizeImage()])
train_dataset = KineticsDataset(
file_list=os.path.join(FLAGS.data, 'train_10.list'),
pickle_dir=os.path.join(FLAGS.data, 'train_10'),
label_list=os.path.join(FLAGS.data, 'label_list'),
transform=train_transform)
val_transform = Compose([GroupScale(),
GroupCenterCrop(),
NormalizeImage()])
val_dataset = KineticsDataset(
file_list=os.path.join(FLAGS.data, 'val_10.list'),
pickle_dir=os.path.join(FLAGS.data, 'val_10'),
label_list=os.path.join(FLAGS.data, 'label_list'),
mode='val',
transform=val_transform)
pretrained = FLAGS.eval_only and FLAGS.weights is None
model = tsm_resnet50(num_classes=train_dataset.num_classes,
pretrained=pretrained)
step_per_epoch = int(len(train_dataset) / FLAGS.batch_size \
/ ParallelEnv().nranks)
optim = make_optimizer(step_per_epoch, model.parameters())
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(
optim,
CrossEntropy(),
metrics=Accuracy(topk=(1, 5)),
inputs=inputs,
labels=labels,
device=FLAGS.device)
if FLAGS.eval_only:
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
model.evaluate(
val_dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.num_workers)
return
if FLAGS.resume is not None:
model.load(FLAGS.resume)
model.fit(train_data=train_dataset,
eval_data=val_dataset,
epochs=FLAGS.epoch,
batch_size=FLAGS.batch_size,
save_dir='tsm_checkpoint',
num_workers=FLAGS.num_workers,
drop_last=True,
shuffle=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser("CNN training on TSM")
parser.add_argument(
"--data", type=str, default='dataset/kinetics',
help="path to dataset root directory")
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--eval_only", action='store_true', help="run evaluation only")
parser.add_argument(
"-e", "--epoch", default=70, type=int, help="number of epoch")
parser.add_argument(
"-j", "--num_workers", default=4, type=int, help="read worker number")
parser.add_argument(
'--lr',
'--learning-rate',
default=1e-2,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=16, type=int, help="batch size")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
parser.add_argument(
"-w",
"--weights",
default=None,
type=str,
help="weights path for evaluation")
FLAGS = parser.parse_args()
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import traceback
import numpy as np
from PIL import Image
import logging
logger = logging.getLogger(__name__)
__all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop',
'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage',
'Compose']
class Compose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
class GroupScale(object):
"""
Group scale image
Args:
target_size (int): image resize target size
"""
def __init__(self, target_size=224):
self.target_size = target_size
def __call__(self, imgs, label):
resized_imgs = []
for i in range(len(imgs)):
img = imgs[i]
w, h = img.size
if (w <= h and w == self.target_size) or \
(h <= w and h == self.target_size):
resized_imgs.append(img)
continue
if w < h:
ow = self.target_size
oh = int(self.target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
else:
oh = self.target_size
ow = int(self.target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
return resized_imgs, label
class GroupMultiScaleCrop(object):
"""
FIXME: add comments
"""
def __init__(self,
short_size=256,
scales=None,
max_distort=1,
fix_crop=True,
more_fix_crop=True):
self.short_size = short_size
self.scales = scales if scales is not None \
else [1, .875, .75, .66]
self.max_distort = max_distort
self.fix_crop = fix_crop
self.more_fix_crop = more_fix_crop
def __call__(self, imgs, label):
input_size = [self.short_size, self.short_size]
im_size = imgs[0].size
# get random crop offset
def _sample_crop_size(im_size):
image_w, image_h = im_size[0], im_size[1]
base_size = min(image_w, image_h)
crop_sizes = [int(base_size * x) for x in self.scales]
crop_h = [
input_size[1] if abs(x - input_size[1]) < 3 else x
for x in crop_sizes
]
crop_w = [
input_size[0] if abs(x - input_size[0]) < 3 else x
for x in crop_sizes
]
pairs = []
for i, h in enumerate(crop_h):
for j, w in enumerate(crop_w):
if abs(i - j) <= self.max_distort:
pairs.append((w, h))
crop_pair = random.choice(pairs)
if not self.fix_crop:
w_offset = np.random.randint(0, image_w - crop_pair[0])
h_offset = np.random.randint(0, image_h - crop_pair[1])
else:
                # step size: a quarter of the leftover space (integer, as
                # in the original Python 2 code where / floored for ints)
                w_step = (image_w - crop_pair[0]) // 4
                h_step = (image_h - crop_pair[1]) // 4
ret = list()
ret.append((0, 0)) # upper left
if w_step != 0:
ret.append((4 * w_step, 0)) # upper right
if h_step != 0:
ret.append((0, 4 * h_step)) # lower left
if h_step != 0 and w_step != 0:
ret.append((4 * w_step, 4 * h_step)) # lower right
if h_step != 0 or w_step != 0:
ret.append((2 * w_step, 2 * h_step)) # center
if self.more_fix_crop:
ret.append((0, 2 * h_step)) # center left
ret.append((4 * w_step, 2 * h_step)) # center right
ret.append((2 * w_step, 4 * h_step)) # lower center
ret.append((2 * w_step, 0 * h_step)) # upper center
ret.append((1 * w_step, 1 * h_step)) # upper left quarter
ret.append((3 * w_step, 1 * h_step)) # upper right quarter
ret.append((1 * w_step, 3 * h_step)) # lower left quarter
                    ret.append((3 * w_step, 3 * h_step))  # lower right quarter
w_offset, h_offset = random.choice(ret)
return crop_pair[0], crop_pair[1], w_offset, h_offset
crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
crop_imgs = [
img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
for img in imgs
]
ret_imgs = [
img.resize((input_size[0], input_size[1]), Image.BILINEAR)
for img in crop_imgs
]
return ret_imgs, label
class GroupRandomCrop(object):
def __init__(self, target_size=224):
self.target_size = target_size
def __call__(self, imgs, label):
w, h = imgs[0].size
th, tw = self.target_size, self.target_size
        assert (w >= self.target_size) and (h >= self.target_size), \
            "image width({}) and height({}) should be no smaller than " \
            "crop size({})".format(w, h, self.target_size)
        out_images = []
        # np.random.randint requires low < high, so guard the case where
        # the image size equals the crop size exactly.
        x1 = np.random.randint(0, w - tw) if w > tw else 0
        y1 = np.random.randint(0, h - th) if h > th else 0
for img in imgs:
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images, label
class GroupRandomFlip(object):
def __call__(self, imgs, label):
v = np.random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
return ret, label
else:
return imgs, label
class GroupCenterCrop(object):
def __init__(self, target_size=224):
self.target_size = target_size
def __call__(self, imgs, label):
crop_imgs = []
for img in imgs:
w, h = img.size
th, tw = self.target_size, self.target_size
            assert (w >= self.target_size) and (h >= self.target_size), \
                "image width({}) and height({}) should be no smaller " \
                "than crop size({})".format(w, h, self.target_size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
crop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return crop_imgs, label
class NormalizeImage(object):
def __init__(self,
target_size=224,
img_mean=[0.485, 0.456, 0.406],
img_std=[0.229, 0.224, 0.225],
seg_num=8,
seg_len=1):
self.target_size = target_size
self.img_mean = np.array(img_mean).reshape((3, 1, 1)).astype('float32')
self.img_std = np.array(img_std).reshape((3, 1, 1)).astype('float32')
self.seg_num = seg_num
self.seg_len = seg_len
def __call__(self, imgs, label):
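        # Stack frames into an (N, 3, H, W) array scaled to [0, 1],
        # normalize with ImageNet mean/std, then regroup into
        # (seg_num, seg_len * 3, H, W).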
np_imgs = (np.array(imgs[0]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, self.target_size,
self.target_size) / 255
for i in range(len(imgs) - 1):
img = (np.array(imgs[i + 1]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, self.target_size,
self.target_size) / 255
np_imgs = np.concatenate((np_imgs, img))
np_imgs -= self.img_mean
np_imgs /= self.img_std
np_imgs = np.reshape(np_imgs, (self.seg_num, self.seg_len * 3,
self.target_size, self.target_size))
return np_imgs, label
````diff
@@ -117,7 +117,7 @@ python main.py --help
 Multi-GPU training:
 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data=<path/to/dataset> --batch_size=16
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16
 ```
 #### Dynamic graph training
@@ -127,7 +127,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data=<path/to/dataset> --batch_siz
 Multi-GPU training:
 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data=<path/to/dataset> --batch_size=16 -d
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16 -d
 ```
````