未验证 提交 9f9817d3 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #1331 from CrossLee1/develop

support python3 in video classification
import os import os
import sys
import math import math
import random import random
import functools import functools
import cPickle try:
from cStringIO import StringIO import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
import numpy as np import numpy as np
import paddle import paddle
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
random.seed(0) random.seed(0)
DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 1024 BUF_SIZE = 1024
...@@ -22,17 +25,13 @@ INFER_LIST = 'data/test.list' ...@@ -22,17 +25,13 @@ INFER_LIST = 'data/test.list'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
python_ver = sys.version_info
def imageloader(buf): def imageloader(buf):
if isinstance(buf, str): if isinstance(buf, str):
tempbuff = StringIO()
tempbuff.write(buf)
tempbuff.seek(0)
img = Image.open(tempbuff)
elif isinstance(buf, collections.Sequence):
img = Image.open(StringIO(buf[-1]))
else:
img = Image.open(StringIO(buf)) img = Image.open(StringIO(buf))
else:
img = Image.open(BytesIO(buf))
return img.convert('RGB') return img.convert('RGB')
...@@ -98,7 +97,7 @@ def group_center_crop(img_group, target_size): ...@@ -98,7 +97,7 @@ def group_center_crop(img_group, target_size):
def video_loader(frames, nsample, mode): def video_loader(frames, nsample, mode):
videolen = len(frames) videolen = len(frames)
average_dur = videolen / nsample average_dur = videolen // nsample
imgs = [] imgs = []
for i in range(nsample): for i in range(nsample):
...@@ -111,12 +110,12 @@ def video_loader(frames, nsample, mode): ...@@ -111,12 +110,12 @@ def video_loader(frames, nsample, mode):
idx = i idx = i
else: else:
if average_dur >= 1: if average_dur >= 1:
idx = (average_dur - 1) / 2 idx = (average_dur - 1) // 2
idx += i * average_dur idx += i * average_dur
else: else:
idx = i idx = i
imgbuf = frames[idx % videolen] imgbuf = frames[int(idx % videolen)]
img = imageloader(imgbuf) img = imageloader(imgbuf)
imgs.append(img) imgs.append(img)
...@@ -125,7 +124,10 @@ def video_loader(frames, nsample, mode): ...@@ -125,7 +124,10 @@ def video_loader(frames, nsample, mode):
def decode_pickle(sample, mode, seg_num, short_size, target_size): def decode_pickle(sample, mode, seg_num, short_size, target_size):
pickle_path = sample[0] pickle_path = sample[0]
data_loaded = cPickle.load(open(pickle_path)) if python_ver < (3, 0):
data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data_loaded vid, label, frames = data_loaded
imgs = video_loader(frames, seg_num, mode) imgs = video_loader(frames, seg_num, mode)
......
...@@ -22,7 +22,7 @@ class TSN_ResNet(): ...@@ -22,7 +22,7 @@ class TSN_ResNet():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
bias_attr=False) bias_attr=False)
......
...@@ -91,7 +91,7 @@ def train(args): ...@@ -91,7 +91,7 @@ def train(args):
fluid.io.load_vars(exe, pretrained_model, vars=vars) fluid.io.load_vars(exe, pretrained_model, vars=vars)
# reader # reader
train_reader = paddle.batch(reader.train(seg_num), batch_size=batch_size) train_reader = paddle.batch(reader.train(seg_num), batch_size=batch_size, drop_last=True)
# test in single GPU # test in single GPU
test_reader = paddle.batch(reader.test(seg_num), batch_size=batch_size / 16) test_reader = paddle.batch(reader.test(seg_num), batch_size=batch_size / 16)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import distutils.util import distutils.util
import numpy as np import numpy as np
import six
from paddle.fluid import core from paddle.fluid import core
...@@ -36,7 +37,7 @@ def print_arguments(args): ...@@ -36,7 +37,7 @@ def print_arguments(args):
:type args: argparse.Namespace :type args: argparse.Namespace
""" """
print("----------- Configuration Arguments -----------") print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()): for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value)) print("%s: %s" % (arg, value))
print("------------------------------------------------") print("------------------------------------------------")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册