diff --git a/fluid/video_classification/reader.py b/fluid/video_classification/reader.py
index 11cfaa5b3ddc949d20f7f33d15f957cba5225919..9c4fd812a63cd1e2b581dadaf6eaf70373e76b3c 100644
--- a/fluid/video_classification/reader.py
+++ b/fluid/video_classification/reader.py
@@ -1,17 +1,21 @@
 import os
+import sys
 import math
 import random
 import functools
-import cPickle
-from cStringIO import StringIO
+try:
+    import cPickle as pickle
+    from cStringIO import StringIO
+    BytesIO = StringIO  # cStringIO accepts byte strings on Python 2
+except ImportError:
+    import pickle
+    from io import BytesIO, StringIO
 import numpy as np
 import paddle
 from PIL import Image, ImageEnhance
 
 random.seed(0)
 
-DATA_DIM = 224
-
 THREAD = 8
 BUF_SIZE = 1024
 
@@ -22,17 +26,13 @@ INFER_LIST = 'data/test.list'
 img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
 img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
 
+python_ver = sys.version_info
 
 def imageloader(buf):
     if isinstance(buf, str):
-        tempbuff = StringIO()
-        tempbuff.write(buf)
-        tempbuff.seek(0)
-        img = Image.open(tempbuff)
-    elif isinstance(buf, collections.Sequence):
-        img = Image.open(StringIO(buf[-1]))
-    else:
         img = Image.open(StringIO(buf))
+    else:
+        img = Image.open(BytesIO(buf))
 
     return img.convert('RGB')
 
@@ -98,7 +98,7 @@ def group_center_crop(img_group, target_size):
 
 def video_loader(frames, nsample, mode):
     videolen = len(frames)
-    average_dur = videolen / nsample
+    average_dur = videolen // nsample
 
     imgs = []
     for i in range(nsample):
@@ -111,12 +111,12 @@ def video_loader(frames, nsample, mode):
                 idx = i
         else:
             if average_dur >= 1:
-                idx = (average_dur - 1) / 2
+                idx = (average_dur - 1) // 2
                 idx += i * average_dur
             else:
                 idx = i
 
-        imgbuf = frames[idx % videolen]
+        imgbuf = frames[int(idx % videolen)]
         img = imageloader(imgbuf)
         imgs.append(img)
 
@@ -125,7 +125,13 @@ def video_loader(frames, nsample, mode):
 
 def decode_pickle(sample, mode, seg_num, short_size, target_size):
     pickle_path = sample[0]
-    data_loaded = cPickle.load(open(pickle_path))
+    # Close the file handle promptly; encoding='bytes' keeps data that was
+    # pickled under Python 2 readable as bytes under Python 3.
+    with open(pickle_path, 'rb') as f:
+        if python_ver < (3, 0):
+            data_loaded = pickle.load(f)
+        else:
+            data_loaded = pickle.load(f, encoding='bytes')
 
     vid, label, frames = data_loaded
     imgs = video_loader(frames, seg_num, mode)
diff --git a/fluid/video_classification/resnet.py b/fluid/video_classification/resnet.py
index a6eeb5890f4d7624b776cb81f60dc2bc601cfd0b..494235469a37939d67f4239d2b47f8d9461264f4 100644
--- a/fluid/video_classification/resnet.py
+++ b/fluid/video_classification/resnet.py
@@ -22,7 +22,7 @@ class TSN_ResNet():
             num_filters=num_filters,
             filter_size=filter_size,
             stride=stride,
-            padding=(filter_size - 1) / 2,
+            padding=(filter_size - 1) // 2,
             groups=groups,
             act=None,
             bias_attr=False)
diff --git a/fluid/video_classification/train.py b/fluid/video_classification/train.py
index c879bf688233dce5d1ce839af76ca41164e3a571..e873cdb608ccfd83a8600e77b4837e2e52872549 100644
--- a/fluid/video_classification/train.py
+++ b/fluid/video_classification/train.py
@@ -91,7 +91,7 @@ def train(args):
         fluid.io.load_vars(exe, pretrained_model, vars=vars)
 
     # reader
-    train_reader = paddle.batch(reader.train(seg_num), batch_size=batch_size)
+    train_reader = paddle.batch(reader.train(seg_num), batch_size=batch_size, drop_last=True)
     # test in single GPU
-    test_reader = paddle.batch(reader.test(seg_num), batch_size=batch_size / 16)
+    test_reader = paddle.batch(reader.test(seg_num), batch_size=batch_size // 16)
     feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
diff --git a/fluid/video_classification/utility.py b/fluid/video_classification/utility.py
index 6e5237341fdedaa940d5d6d1231d38bb461f0590..20b4141f7fb24b1617a1ef0f1d4a3c2536213b14 100644
--- a/fluid/video_classification/utility.py
+++ b/fluid/video_classification/utility.py
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 import distutils.util
 import numpy as np
+import six
 from paddle.fluid import core
 
 
@@ -36,7 +37,7 @@ def print_arguments(args):
     :type args: argparse.Namespace
     """
     print("----------- Configuration Arguments -----------")
-    for arg, value in sorted(vars(args).iteritems()):
+    for arg, value in sorted(six.iteritems(vars(args))):
         print("%s: %s" % (arg, value))
     print("------------------------------------------------")
 