diff --git a/fluid/video_classification/README.md b/fluid/video_classification/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..34c361f2ab685b683bfbd73b5b41aa65c69e4f98
--- /dev/null
+++ b/fluid/video_classification/README.md
@@ -0,0 +1,141 @@
+# Video Classification Based on Temporal Segment Network
+
+Video classification has drawn a significant amount of attention in the past few years. This page introduces how to perform video classification with PaddlePaddle Fluid on the public UCF-101 dataset, based on the state-of-the-art Temporal Segment Network (TSN) method.
+
+______________________________________________________________________________
+
+## Table of Contents
+- [Data preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Inference](#inference)
+- [Performance](#performance)
+
+### Data preparation
+
+#### download UCF-101 dataset
+Users can download the UCF-101 dataset and the official train/test split files with the provided script `data/download.sh`.
+
+#### decode video into frames
+To avoid decoding videos during network training, we decode them into frames offline and save each video in `pickle` format, which is easily readable in Python.
+
+Users can refer to the script `data/video_decode.py` for video decoding.
+
+#### split data into train and test
+We follow split 1 of the UCF-101 dataset. After splitting, there are 9537 videos for training and 3783 videos for validation. The reference script is `data/split_data.py`.
+
+#### save pickle for training
+As stated above, we save all data in `pickle` format for training. All information of each video is saved into one pickle file, including the video id, the binary frames, and the label. Please refer to the script `data/generate_train_data.py`.
+After this operation, one gets two directories containing the training and testing data in `pickle` format, and two list files, `train.list` and `val.list` (the names expected by `reader.py`), each line holding the path of one pickle file.
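+
+Each generated pickle can be loaded back for a quick sanity check. A minimal sketch (the file name below is only an example of the naming produced by `data/generate_train_data.py`):
+```python
+import cPickle
+
+# each pickle stores [video id, label, list of JPEG-encoded frames]
+with open('data/train_pkl/ApplyEyeMakeup_g08_c01.pkl', 'rb') as f:
+    vid, label, frames = cPickle.load(f)
+print vid, label, len(frames)
+```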
+
+### Training
+After data preparation, users can start training with PaddlePaddle Fluid by running:
+```
+python train.py \
+ --batch_size=128 \
+ --total_videos=9537 \
+ --class_dim=101 \
+ --num_epochs=60 \
+ --image_shape=3,224,224 \
+ --model_save_dir=output/ \
+ --with_mem_opt=True \
+ --lr_init=0.01 \
+ --num_layers=50 \
+ --seg_num=7 \
+ --pretrained_model={path_to_pretrained_model}
+```
+
+**parameter introduction:**
+* **batch_size**: the minibatch size. Default: 128.
+* **total_videos**: the total number of training videos, used to compute the learning rate decay boundaries. Default: 9537.
+* **class_dim**: the number of classes. Default: 101.
+* **num_epochs**: the number of training epochs. Default: 60.
+* **image_shape**: the input image size. Default: 3,224,224.
+* **model_save_dir**: the directory for saving models. Default: output.
+* **with_mem_opt**: whether to use memory optimization. Default: True.
+* **lr_init**: the initial learning rate. Default: 0.01.
+* **num_layers**: the number of layers of the ResNet model. Default: 50.
+* **seg_num**: the number of segments sampled from each video. Default: 7.
+* **pretrained_model**: the path of a pretrained model used for initialization. Default: None.
+
+**data reader introduction:** The data reader is defined in `reader.py`. Note that we use a group operation for all frames in one video: the `seg_num` frames sampled from a video are packed together as a single sample, as shown in the sketch below.
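+
+As a sketch of how this grouped reader is consumed (mirroring `train.py`), each sample yielded by `reader.train(seg_num)` pairs a `[seg_num, 3, 224, 224]` array with an integer label:
+```python
+import paddle.v2 as paddle
+import reader
+
+# batch the grouped samples exactly as train.py does
+train_reader = paddle.batch(reader.train(seg_num=7), batch_size=128)
+for batch in train_reader():
+    imgs, label = batch[0]   # first sample of the first batch
+    print imgs.shape, label  # (7, 3, 224, 224) and an int in [0, 100]
+    break
+```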
+
+
+**training:** The training log looks like:
+```
+[TRAIN] Pass: 0 trainbatch: 0 loss: 4.630959 acc1: 0.0 acc5: 0.0390625 time: 3.09 sec
+[TRAIN] Pass: 0 trainbatch: 10 loss: 4.559069 acc1: 0.0546875 acc5: 0.1171875 time: 3.91 sec
+[TRAIN] Pass: 0 trainbatch: 20 loss: 4.040092 acc1: 0.09375 acc5: 0.3515625 time: 3.88 sec
+[TRAIN] Pass: 0 trainbatch: 30 loss: 3.478214 acc1: 0.3203125 acc5: 0.5546875 time: 3.32 sec
+[TRAIN] Pass: 0 trainbatch: 40 loss: 3.005404 acc1: 0.3515625 acc5: 0.6796875 time: 3.33 sec
+[TRAIN] Pass: 0 trainbatch: 50 loss: 2.585245 acc1: 0.4609375 acc5: 0.7265625 time: 3.13 sec
+[TRAIN] Pass: 0 trainbatch: 60 loss: 2.151489 acc1: 0.4921875 acc5: 0.8203125 time: 3.35 sec
+[TRAIN] Pass: 0 trainbatch: 70 loss: 1.981680 acc1: 0.578125 acc5: 0.8359375 time: 3.30 sec
+```
+
+### Evaluation
+Evaluation measures the performance of a trained model. One can download a pretrained model and set its path to `path_to_pretrained_model`. The top-1/top-5 accuracy can then be obtained by running the following command:
+```
+python eval.py \
+ --batch_size=128 \
+ --class_dim=101 \
+ --image_shape=3,224,224 \
+ --with_mem_opt=True \
+ --num_layers=50 \
+ --seg_num=7 \
+ --test_model={path_to_pretrained_model}
+```
+
+With this evaluation configuration, the output log looks like:
+```
+[TEST] Pass: 0 testbatch: 0 loss: 0.011551 acc1: 1.0 acc5: 1.0 time: 0.48 sec
+[TEST] Pass: 0 testbatch: 10 loss: 0.710330 acc1: 0.75 acc5: 1.0 time: 0.49 sec
+[TEST] Pass: 0 testbatch: 20 loss: 0.000547 acc1: 1.0 acc5: 1.0 time: 0.48 sec
+[TEST] Pass: 0 testbatch: 30 loss: 0.036623 acc1: 1.0 acc5: 1.0 time: 0.48 sec
+[TEST] Pass: 0 testbatch: 40 loss: 0.138705 acc1: 1.0 acc5: 1.0 time: 0.48 sec
+[TEST] Pass: 0 testbatch: 50 loss: 0.056909 acc1: 1.0 acc5: 1.0 time: 0.49 sec
+[TEST] Pass: 0 testbatch: 60 loss: 0.742937 acc1: 0.75 acc5: 1.0 time: 0.49 sec
+[TEST] Pass: 0 testbatch: 70 loss: 1.720186 acc1: 0.5 acc5: 0.875 time: 0.48 sec
+[TEST] Pass: 0 testbatch: 80 loss: 0.199669 acc1: 0.875 acc5: 1.0 time: 0.48 sec
+[TEST] Pass: 0 testbatch: 90 loss: 0.195510 acc1: 1.0 acc5: 1.0 time: 0.48 sec
+```
+
+### Inference
+Inference is used to get prediction scores or video features from a trained model.
+```
+python infer.py \
+ --batch_size=128 \
+ --class_dim=101 \
+ --image_shape=3,224,224 \
+ --with_mem_opt=True \
+ --num_layers=50 \
+ --seg_num=7 \
+ --test_model={path_to_pretrained_model}
+```
+
+The output contains the prediction results, including the maximum score (before softmax) and the corresponding predicted label.
+```
+Test sample: PlayingGuitar_g01_c03, score: [21.418629], class [62]
+Test sample: SalsaSpin_g05_c06, score: [13.238657], class [76]
+Test sample: TrampolineJumping_g04_c01, score: [21.722862], class [93]
+Test sample: JavelinThrow_g01_c04, score: [16.27892], class [44]
+Test sample: PlayingTabla_g01_c01, score: [15.366951], class [65]
+Test sample: ParallelBars_g04_c07, score: [18.42596], class [56]
+Test sample: PlayingCello_g05_c05, score: [18.795723], class [58]
+Test sample: LongJump_g03_c04, score: [7.100088], class [50]
+Test sample: SkyDiving_g06_c03, score: [15.144707], class [82]
+Test sample: UnevenBars_g07_c04, score: [22.114838], class [95]
+```
+```
+
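+The predicted class index is zero-based, following the `label - 1` convention in `data/generate_train_data.py`. It can be mapped back to a class name with `classInd.txt`, assuming the split files sit under `data/ucfTrainTestlist/`. A minimal sketch:
+```python
+# build an {index: name} map with the same zero-based offset as generate_train_data.py
+ind2name = {}
+with open('data/ucfTrainTestlist/classInd.txt') as f:
+    for line in f:
+        label, name = line.split()
+        ind2name[int(label) - 1] = name
+print ind2name[62]  # PlayingGuitar, matching the first sample above
+```
+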
+### Performance
+Configuration | Top-1 acc
+------------- | ---------------:
+seg=7, size=224 | 0.859
+seg=10, size=224 | 0.863
diff --git a/fluid/video_classification/data/download.sh b/fluid/video_classification/data/download.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f7a8045d19907824249575b8538692f59325aa28
--- /dev/null
+++ b/fluid/video_classification/data/download.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Download the dataset.
+echo "Downloading..."
+wget http://crcv.ucf.edu/data/UCF101/UCF101.rar
+wget http://crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip
+
+# Extract the data.
+echo "Extracting..."
+unrar x UCF101.rar
+unzip UCF101TrainTestSplits-RecognitionTask.zip
diff --git a/fluid/video_classification/data/generate_train_data.py b/fluid/video_classification/data/generate_train_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..547ebbaa746ab87987c5207784b2a9c1b212315d
--- /dev/null
+++ b/fluid/video_classification/data/generate_train_data.py
@@ -0,0 +1,35 @@
+import os
+import cPickle
+
+# read class file
+dd = {}
+f = open('ucfTrainTestlist/classInd.txt')
+for line in f.readlines():
+ label, name = line.split()
+ dd[name.lower()] = int(label) - 1
+f.close()
+
+# generate pkl
+path = 'train/'
+savepath = 'train_pkl/'
+if not os.path.exists(savepath):
+ os.makedirs(savepath)
+
+fw = open('train.list', 'w')
+for folder in os.listdir(path):
+ vidid = folder.split('_', 1)[1]
+ this_label = dd[folder.split('_')[1].lower()]
+ this_feat = []
+ for img in sorted(os.listdir(path + folder)):
+ fout = open(path + folder + '/' + img, 'rb')
+ this_feat.append(fout.read())
+ fout.close()
+
+ res = [vidid, this_label, this_feat]
+
+ outp = open(savepath + vidid + '.pkl', 'wb')
+ cPickle.dump(res, outp, protocol=cPickle.HIGHEST_PROTOCOL)
+ outp.close()
+
+ fw.write('data/train_pkl/%s.pkl\n' % vidid)
+fw.close()
diff --git a/fluid/video_classification/data/split_data.py b/fluid/video_classification/data/split_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..79eb3b4340b4b6fa36c09adf0653f2fbaa76fbf5
--- /dev/null
+++ b/fluid/video_classification/data/split_data.py
@@ -0,0 +1,29 @@
+import os
+import shutil
+
+# set path
+train_path = 'train/'
+if not os.path.exists(train_path):
+ os.makedirs(train_path)
+
+test_path = 'test/'
+if not os.path.exists(test_path):
+ os.makedirs(test_path)
+
+# move data
+frame_dir = 'frame/'
+f = open('ucfTrainTestlist/trainlist01.txt')
+for line in f.readlines():
+ folder = line.split('.')[0]
+ vidid = folder.split('/')[-1]
+
+ shutil.move(frame_dir + folder, train_path + vidid)
+f.close()
+
+f = open('ucfTrainTestlist/testlist01.txt')
+for line in f.readlines():
+ folder = line.split('.')[0]
+ vidid = folder.split('/')[-1]
+
+ shutil.move(frame_dir + folder, test_path + vidid)
+f.close()
diff --git a/fluid/video_classification/data/video_decode.py b/fluid/video_classification/data/video_decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4892530cd4ecd88271cf9ccf89c558dc105c65b
--- /dev/null
+++ b/fluid/video_classification/data/video_decode.py
@@ -0,0 +1,19 @@
+import os
+
+
+def decode():
+ path = './UCF-101/'
+ for folder in os.listdir(path):
+ for vid in os.listdir(path + folder):
+ print vid
+ video_path = path + folder + '/' + vid
+ image_folder = './frame/' + folder + '/' + vid.split('.')[0] + '/'
+ if not os.path.exists(image_folder):
+ os.makedirs(image_folder)
+
+            os.system('./ffmpeg -i ' + video_path + ' -q 0 ' + image_folder +
+                      '%06d.jpg')
+
+
+if __name__ == '__main__':
+ decode()
diff --git a/fluid/video_classification/eval.py b/fluid/video_classification/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b8d445978f78b9b0883ac2717831904eca09bb
--- /dev/null
+++ b/fluid/video_classification/eval.py
@@ -0,0 +1,121 @@
+import os
+import numpy as np
+import time
+import sys
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+from resnet import TSN_ResNet
+import reader
+
+import argparse
+import functools
+from paddle.fluid.framework import Parameter
+from utility import add_arguments, print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 128, "Minibatch size.")
+add_arg('num_layers', int, 50, "How many layers for ResNet model.")
+add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
+add_arg('class_dim', int, 101, "Number of class.")
+add_arg('seg_num', int, 7, "Number of segments.")
+add_arg('image_shape', str, "3,224,224", "Input image size.")
+add_arg('test_model', str, None, "Test model path.")
+# yapf: enable
+
+
+def eval(args):
+ # parameters from arguments
+ seg_num = args.seg_num
+ class_dim = args.class_dim
+ num_layers = args.num_layers
+ batch_size = args.batch_size
+ test_model = args.test_model
+
+    if test_model is None:
+ print('Please specify the test model ...')
+ return
+
+ image_shape = [int(m) for m in args.image_shape.split(",")]
+ image_shape = [seg_num] + image_shape
+
+ # model definition
+ model = TSN_ResNet(layers=num_layers, seg_num=seg_num)
+ image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+ label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+ out = model.net(input=image, class_dim=class_dim)
+ cost = fluid.layers.cross_entropy(input=out, label=label)
+
+ avg_cost = fluid.layers.mean(x=cost)
+ acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+ acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+
+ # for test
+ inference_program = fluid.default_main_program().clone(for_test=True)
+
+ if args.with_mem_opt:
+ fluid.memory_optimize(fluid.default_main_program())
+
+ place = fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+    def is_parameter(var):
+        # select the trained parameters to load from the checkpoint
+        return isinstance(var, Parameter)
+
+ if test_model is not None:
+ vars = filter(is_parameter, inference_program.list_vars())
+ fluid.io.load_vars(exe, test_model, vars=vars)
+
+ # reader
+ test_reader = paddle.batch(reader.test(seg_num), batch_size=batch_size / 16)
+ feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
+
+ fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
+
+ # test
+ cnt = 0
+ pass_id = 0
+ test_info = [[], [], []]
+ for batch_id, data in enumerate(test_reader()):
+ t1 = time.time()
+ loss, acc1, acc5 = exe.run(inference_program,
+ fetch_list=fetch_list,
+ feed=feeder.feed(data))
+ t2 = time.time()
+ period = t2 - t1
+ loss = np.mean(loss)
+ acc1 = np.mean(acc1)
+ acc5 = np.mean(acc5)
+ test_info[0].append(loss * len(data))
+ test_info[1].append(acc1 * len(data))
+ test_info[2].append(acc5 * len(data))
+ cnt += len(data)
+ if batch_id % 10 == 0:
+ print(
+ "[TEST] Pass: {0}\ttestbatch: {1}\tloss: {2}\tacc1: {3}\tacc5: {4}\ttime: {5}"
+ .format(pass_id, batch_id, '%.6f' % loss, acc1, acc5,
+ "%2.2f sec" % period))
+ sys.stdout.flush()
+
+ test_loss = np.sum(test_info[0]) / cnt
+ test_acc1 = np.sum(test_info[1]) / cnt
+ test_acc5 = np.sum(test_info[2]) / cnt
+
+ print("+ End pass: {0}, test_loss: {1}, test_acc1: {2}, test_acc5: {3}"
+ .format(pass_id, '%.3f' % test_loss, '%.3f' % test_acc1, '%.3f' %
+ test_acc5))
+ sys.stdout.flush()
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ eval(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/fluid/video_classification/infer.py b/fluid/video_classification/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f3048b1a89f1f218248d6e5760d08683000343d
--- /dev/null
+++ b/fluid/video_classification/infer.py
@@ -0,0 +1,93 @@
+import os
+import numpy as np
+import time
+import sys
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+from resnet import TSN_ResNet
+import reader
+
+import argparse
+import functools
+from paddle.fluid.framework import Parameter
+from utility import add_arguments, print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('num_layers', int, 50, "How many layers for ResNet model.")
+add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
+add_arg('class_dim', int, 101, "Number of class.")
+add_arg('seg_num', int, 7, "Number of segments.")
+add_arg('image_shape', str, "3,224,224", "Input image size.")
+add_arg('test_model', str, None, "Test model path.")
+# yapf: enable
+
+
+def infer(args):
+ # parameters from arguments
+ seg_num = args.seg_num
+ class_dim = args.class_dim
+ num_layers = args.num_layers
+ test_model = args.test_model
+
+    if test_model is None:
+ print('Please specify the test model ...')
+ return
+
+ image_shape = [int(m) for m in args.image_shape.split(",")]
+ image_shape = [seg_num] + image_shape
+
+ # model definition
+ model = TSN_ResNet(layers=num_layers, seg_num=seg_num)
+ image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+
+ out = model.net(input=image, class_dim=class_dim)
+
+ # for test
+ inference_program = fluid.default_main_program().clone(for_test=True)
+
+ if args.with_mem_opt:
+ fluid.memory_optimize(fluid.default_main_program())
+
+ place = fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+    def is_parameter(var):
+        # select the trained parameters to load from the checkpoint
+        return isinstance(var, Parameter)
+
+ if test_model is not None:
+ vars = filter(is_parameter, inference_program.list_vars())
+ fluid.io.load_vars(exe, test_model, vars=vars)
+
+ # reader
+ test_reader = paddle.batch(reader.infer(seg_num), batch_size=1)
+ feeder = fluid.DataFeeder(place=place, feed_list=[image])
+
+ fetch_list = [out.name]
+
+ # test
+ TOPK = 1
+ for batch_id, data in enumerate(test_reader()):
+ data, vid = data[0]
+ data = [[data]]
+ result = exe.run(inference_program,
+ fetch_list=fetch_list,
+ feed=feeder.feed(data))
+ result = result[0][0]
+ pred_label = np.argsort(result)[::-1][:TOPK]
+ print("Test sample: {0}, score: {1}, class {2}".format(vid, result[
+ pred_label], pred_label))
+ sys.stdout.flush()
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ infer(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/fluid/video_classification/reader.py b/fluid/video_classification/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..530383f9cf1443c96bc8651fbd5d2e4beba2efd0
--- /dev/null
+++ b/fluid/video_classification/reader.py
@@ -0,0 +1,212 @@
+import os
+import math
+import random
+import functools
+import collections
+import cPickle
+from cStringIO import StringIO
+import numpy as np
+import paddle.v2 as paddle
+from PIL import Image, ImageEnhance
+
+random.seed(0)
+
+DATA_DIM = 224
+
+THREAD = 8
+BUF_SIZE = 1024
+
+TRAIN_LIST = 'data/train.list'
+TEST_LIST = 'data/val.list'
+INFER_LIST = 'data/val.list'
+
+img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+
+
+def imageloader(buf):
+ if isinstance(buf, str):
+ tempbuff = StringIO()
+ tempbuff.write(buf)
+ tempbuff.seek(0)
+ img = Image.open(tempbuff)
+ elif isinstance(buf, collections.Sequence):
+ img = Image.open(StringIO(buf[-1]))
+ else:
+ img = Image.open(StringIO(buf))
+
+ return img.convert('RGB')
+
+
+def group_scale(imgs, target_size):
+ resized_imgs = []
+ for i in range(len(imgs)):
+ img = imgs[i]
+ w, h = img.size
+ if (w <= h and w == target_size) or (h <= w and h == target_size):
+ resized_imgs.append(img)
+ continue
+
+ if w < h:
+ ow = target_size
+ oh = int(target_size * 4.0 / 3.0)
+ resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
+ else:
+ oh = target_size
+ ow = int(target_size * 4.0 / 3.0)
+ resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
+
+ return resized_imgs
+
+
+def group_random_crop(img_group, target_size):
+ w, h = img_group[0].size
+ th, tw = target_size, target_size
+
+ out_images = []
+ x1 = random.randint(0, w - tw)
+ y1 = random.randint(0, h - th)
+
+ for img in img_group:
+ if w == tw and h == th:
+ out_images.append(img)
+ else:
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+ return out_images
+
+
+def group_random_flip(img_group):
+ v = random.random()
+ if v < 0.5:
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
+ return ret
+ else:
+ return img_group
+
+
+def group_center_crop(img_group, target_size):
+ img_crop = []
+ for img in img_group:
+ w, h = img.size
+ th, tw = target_size, target_size
+ x1 = int(round((w - tw) / 2.))
+ y1 = int(round((h - th) / 2.))
+ img_crop.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+ return img_crop
+
+
+def video_loader(frames, nsample, mode):
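+    # TSN sampling: split the video into `nsample` equal segments and pick one
+    # frame per segment (a random one in training, the center one otherwise).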
+ videolen = len(frames)
+ average_dur = videolen / nsample
+
+ imgs = []
+ for i in range(nsample):
+ idx = 0
+ if mode == 'train':
+ if average_dur >= 1:
+ idx = random.randint(0, average_dur - 1)
+ idx += i * average_dur
+ else:
+ idx = i
+ else:
+ if average_dur >= 1:
+ idx = (average_dur - 1) / 2
+ idx += i * average_dur
+ else:
+ idx = i
+
+ imgbuf = frames[idx % videolen]
+ img = imageloader(imgbuf)
+ imgs.append(img)
+
+ return imgs
+
+
+def decode_pickle(sample, mode, seg_num, short_size, target_size):
+ pickle_path = sample[0]
+    data_loaded = cPickle.load(open(pickle_path, 'rb'))
+ vid, label, frames = data_loaded
+
+ imgs = video_loader(frames, seg_num, mode)
+ imgs = group_scale(imgs, short_size)
+
+ if mode == 'train':
+ imgs = group_random_crop(imgs, target_size)
+ imgs = group_random_flip(imgs)
+ else:
+ imgs = group_center_crop(imgs, target_size)
+
+    # stack the sampled frames into one [seg_num, 3, target_size, target_size] array
+    np_imgs = (np.array(imgs[0]).astype('float32').transpose(
+        (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
+    for i in range(len(imgs) - 1):
+        img = (np.array(imgs[i + 1]).astype('float32').transpose(
+            (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
+        np_imgs = np.concatenate((np_imgs, img))
+    imgs = np_imgs
+ imgs -= img_mean
+ imgs /= img_std
+
+ if mode == 'train' or mode == 'test':
+ return imgs, label
+ elif mode == 'infer':
+ return imgs, vid
+
+
+def _reader_creator(pickle_list,
+ mode,
+ seg_num,
+ short_size,
+ target_size,
+ shuffle=False):
+ def reader():
+ with open(pickle_list) as flist:
+ lines = [line.strip() for line in flist]
+ if shuffle:
+ random.shuffle(lines)
+ for line in lines:
+ pickle_path = line.strip()
+ yield [pickle_path]
+
+ mapper = functools.partial(
+ decode_pickle,
+ mode=mode,
+ seg_num=seg_num,
+ short_size=short_size,
+ target_size=target_size)
+
+ return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
+
+
+def train(seg_num):
+ return _reader_creator(
+ TRAIN_LIST,
+ 'train',
+ shuffle=True,
+ seg_num=seg_num,
+ short_size=256,
+ target_size=224)
+
+
+def test(seg_num):
+ return _reader_creator(
+ TEST_LIST,
+ 'test',
+ shuffle=False,
+ seg_num=seg_num,
+ short_size=256,
+ target_size=224)
+
+
+def infer(seg_num):
+ return _reader_creator(
+ INFER_LIST,
+ 'infer',
+ shuffle=False,
+ seg_num=seg_num,
+ short_size=256,
+ target_size=224)
diff --git a/fluid/video_classification/resnet.py b/fluid/video_classification/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6eeb5890f4d7624b776cb81f60dc2bc601cfd0b
--- /dev/null
+++ b/fluid/video_classification/resnet.py
@@ -0,0 +1,107 @@
+import os
+import time
+import sys
+import paddle.fluid as fluid
+import math
+
+
+class TSN_ResNet():
+ def __init__(self, layers=50, seg_num=7):
+ self.layers = layers
+ self.seg_num = seg_num
+
+ def conv_bn_layer(self,
+ input,
+ num_filters,
+ filter_size,
+ stride=1,
+ groups=1,
+ act=None):
+ conv = fluid.layers.conv2d(
+ input=input,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) / 2,
+ groups=groups,
+ act=None,
+ bias_attr=False)
+ return fluid.layers.batch_norm(input=conv, act=act)
+
+ def shortcut(self, input, ch_out, stride):
+ ch_in = input.shape[1]
+ if ch_in != ch_out or stride != 1:
+ return self.conv_bn_layer(input, ch_out, 1, stride)
+ else:
+ return input
+
+ def bottleneck_block(self, input, num_filters, stride):
+ conv0 = self.conv_bn_layer(
+ input=input, num_filters=num_filters, filter_size=1, act='relu')
+ conv1 = self.conv_bn_layer(
+ input=conv0,
+ num_filters=num_filters,
+ filter_size=3,
+ stride=stride,
+ act='relu')
+ conv2 = self.conv_bn_layer(
+ input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
+
+ short = self.shortcut(input, num_filters * 4, stride)
+
+ return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
+
+ def net(self, input, class_dim=101):
+ layers = self.layers
+ seg_num = self.seg_num
+ supported_layers = [50, 101, 152]
+ if layers not in supported_layers:
+ print("supported layers are", supported_layers, \
+ "but input layer is ", layers)
+ exit()
+
+ # reshape input
+ channels = input.shape[2]
+ short_size = input.shape[3]
+ input = fluid.layers.reshape(
+ x=input, shape=[-1, channels, short_size, short_size])
+
+ if layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ num_filters = [64, 128, 256, 512]
+
+ conv = self.conv_bn_layer(
+ input=input, num_filters=64, filter_size=7, stride=2, act='relu')
+ conv = fluid.layers.pool2d(
+ input=conv,
+ pool_size=3,
+ pool_stride=2,
+ pool_padding=1,
+ pool_type='max')
+
+ for block in range(len(depth)):
+ for i in range(depth[block]):
+ conv = self.bottleneck_block(
+ input=conv,
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1)
+ pool = fluid.layers.pool2d(
+ input=conv, pool_size=7, pool_type='avg', global_pooling=True)
+
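+        # segmental consensus: average the per-segment features across seg_num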
+ feature = fluid.layers.reshape(
+ x=pool, shape=[-1, seg_num, pool.shape[1]])
+ out = fluid.layers.reduce_mean(feature, dim=1)
+
+ stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
+ out = fluid.layers.fc(input=out,
+ size=class_dim,
+ act='softmax',
+ param_attr=fluid.param_attr.ParamAttr(
+ initializer=fluid.initializer.Uniform(-stdv,
+ stdv)))
+ return out
diff --git a/fluid/video_classification/train.py b/fluid/video_classification/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e8552d4514440e0a3b9f7eab8781dd0be54632f
--- /dev/null
+++ b/fluid/video_classification/train.py
@@ -0,0 +1,180 @@
+import os
+import numpy as np
+import time
+import sys
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+from resnet import TSN_ResNet
+import reader
+
+import argparse
+import functools
+from paddle.fluid.framework import Parameter
+from utility import add_arguments, print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 128, "Minibatch size.")
+add_arg('num_layers', int, 50, "How many layers for ResNet model.")
+add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
+add_arg('num_epochs', int, 60, "Number of epochs.")
+add_arg('class_dim', int, 101, "Number of class.")
+add_arg('seg_num', int, 7, "Number of segments.")
+add_arg('image_shape', str, "3,224,224", "Input image size.")
+add_arg('model_save_dir', str, "output", "Model save directory.")
+add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
+add_arg('total_videos', int, 9537, "Training video number.")
+add_arg('lr_init', float, 0.01, "Set initial learning rate.")
+# yapf: enable
+
+
+def train(args):
+ # parameters from arguments
+ seg_num = args.seg_num
+ class_dim = args.class_dim
+ num_layers = args.num_layers
+ num_epochs = args.num_epochs
+ batch_size = args.batch_size
+ pretrained_model = args.pretrained_model
+ model_save_dir = args.model_save_dir
+
+ image_shape = [int(m) for m in args.image_shape.split(",")]
+ image_shape = [seg_num] + image_shape
+
+ # model definition
+ model = TSN_ResNet(layers=num_layers, seg_num=seg_num)
+
+ image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+ label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+ out = model.net(input=image, class_dim=class_dim)
+ cost = fluid.layers.cross_entropy(input=out, label=label)
+
+ avg_cost = fluid.layers.mean(x=cost)
+ acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+ acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+
+ # for test
+ inference_program = fluid.default_main_program().clone(for_test=True)
+
+ # learning rate strategy
+ epoch_points = [num_epochs / 3, num_epochs * 2 / 3]
+ total_videos = args.total_videos
+ step = int(total_videos / batch_size + 1)
+ bd = [e * step for e in epoch_points]
+
+ lr_init = args.lr_init
+ lr = [lr_init, lr_init / 10, lr_init / 100]
+
+ # initialize optimizer
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=fluid.layers.piecewise_decay(
+ boundaries=bd, values=lr),
+ momentum=0.9,
+ regularization=fluid.regularizer.L2Decay(1e-4))
+
+ opts = optimizer.minimize(avg_cost)
+ if args.with_mem_opt:
+ fluid.memory_optimize(fluid.default_main_program())
+
+ place = fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+    def is_parameter(var):
+        # load all pretrained weights except the final fc layer (fc_0)
+        return isinstance(var, Parameter) and "fc_0" not in var.name
+
+ if pretrained_model is not None:
+ vars = filter(is_parameter, inference_program.list_vars())
+ fluid.io.load_vars(exe, pretrained_model, vars=vars)
+
+ # reader
+ train_reader = paddle.batch(reader.train(seg_num), batch_size=batch_size)
+ # test in single GPU
+ test_reader = paddle.batch(reader.test(seg_num), batch_size=batch_size / 16)
+ feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
+
+ train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
+
+ fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
+
+ # train
+ for pass_id in range(num_epochs):
+ train_info = [[], [], []]
+ test_info = [[], [], []]
+ for batch_id, data in enumerate(train_reader()):
+ t1 = time.time()
+ loss, acc1, acc5 = train_exe.run(fetch_list, feed=feeder.feed(data))
+ t2 = time.time()
+ period = t2 - t1
+ loss = np.mean(np.array(loss))
+ acc1 = np.mean(np.array(acc1))
+ acc5 = np.mean(np.array(acc5))
+ train_info[0].append(loss)
+ train_info[1].append(acc1)
+ train_info[2].append(acc5)
+
+ if batch_id % 10 == 0:
+ print(
+ "[TRAIN] Pass: {0}\ttrainbatch: {1}\tloss: {2}\tacc1: {3}\tacc5: {4}\ttime: {5}"
+ .format(pass_id, batch_id, '%.6f' % loss, acc1, acc5,
+ "%2.2f sec" % period))
+ sys.stdout.flush()
+
+ train_loss = np.array(train_info[0]).mean()
+ train_acc1 = np.array(train_info[1]).mean()
+ train_acc5 = np.array(train_info[2]).mean()
+
+ # test
+ cnt = 0
+ for batch_id, data in enumerate(test_reader()):
+ t1 = time.time()
+ loss, acc1, acc5 = exe.run(inference_program,
+ fetch_list=fetch_list,
+ feed=feeder.feed(data))
+ t2 = time.time()
+ period = t2 - t1
+ loss = np.mean(loss)
+ acc1 = np.mean(acc1)
+ acc5 = np.mean(acc5)
+ test_info[0].append(loss * len(data))
+ test_info[1].append(acc1 * len(data))
+ test_info[2].append(acc5 * len(data))
+ cnt += len(data)
+ if batch_id % 10 == 0:
+ print(
+ "[TEST] Pass: {0}\ttestbatch: {1}\tloss: {2}\tacc1: {3}\tacc5: {4}\ttime: {5}"
+ .format(pass_id, batch_id, '%.6f' % loss, acc1, acc5,
+ "%2.2f sec" % period))
+ sys.stdout.flush()
+
+ test_loss = np.sum(test_info[0]) / cnt
+ test_acc1 = np.sum(test_info[1]) / cnt
+ test_acc5 = np.sum(test_info[2]) / cnt
+
+ print(
+ "+ End pass: {0}, train_loss: {1}, train_acc1: {2}, train_acc5: {3}"
+ .format(pass_id, '%.3f' % train_loss, '%.3f' % train_acc1, '%.3f' %
+ train_acc5))
+ print("+ End pass: {0}, test_loss: {1}, test_acc1: {2}, test_acc5: {3}"
+ .format(pass_id, '%.3f' % test_loss, '%.3f' % test_acc1, '%.3f' %
+ test_acc5))
+ sys.stdout.flush()
+
+ # save model
+ model_path = os.path.join(model_save_dir, str(pass_id))
+ if not os.path.isdir(model_path):
+ os.makedirs(model_path)
+ fluid.io.save_persistables(exe, model_path)
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ train(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/fluid/video_classification/utility.py b/fluid/video_classification/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5237341fdedaa940d5d6d1231d38bb461f0590
--- /dev/null
+++ b/fluid/video_classification/utility.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+"""Contains common utility functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import distutils.util
+import numpy as np
+from paddle.fluid import core
+
+
+def print_arguments(args):
+ """Print argparse's arguments.
+
+ Usage:
+
+ .. code-block:: python
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("name", default="Jonh", type=str, help="User name.")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ :param args: Input argparse.Namespace for printing.
+ :type args: argparse.Namespace
+ """
+ print("----------- Configuration Arguments -----------")
+ for arg, value in sorted(vars(args).iteritems()):
+ print("%s: %s" % (arg, value))
+ print("------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+ """Add argparse's argument.
+
+ Usage:
+
+ .. code-block:: python
+
+ parser = argparse.ArgumentParser()
+ add_argument("name", str, "Jonh", "User name.", parser)
+ args = parser.parse_args()
+ """
+ type = distutils.util.strtobool if type == bool else type
+ argparser.add_argument(
+ "--" + argname,
+ default=default,
+ type=type,
+ help=help + ' Default: %(default)s.',
+ **kwargs)