提交 b73e4832 编写于 作者: S SunGaofeng

Merge branch 'video_classification' of https://github.com/SunGaofeng/models...

Merge branch 'video_classification' of https://github.com/SunGaofeng/models into video_classification
......@@ -37,21 +37,21 @@ class FeatureReader(DataReader):
NextVlad only: eigen_file
"""
def __init__(self, name, phase, cfg):
def __init__(self, name, mode, cfg):
self.name = name
self.phase = phase
self.mode = mode
self.num_classes = cfg.MODEL.num_classes
# set batch size and file list
self.batch_size = cfg[phase.upper()]['batch_size']
self.filelist = cfg[phase.upper()]['filelist']
self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[mode.upper()]['filelist']
self.eigen_file = cfg.MODEL.get('eigen_file', None)
self.seg_num = cfg.MODEL.get('seg_num', None)
def create_reader(self):
fl = open(self.filelist).readlines()
fl = [line.strip() for line in fl if line.strip() != '']
if self.phase == 'train':
if self.mode == 'train':
random.shuffle(fl)
def reader():
......@@ -62,14 +62,14 @@ class FeatureReader(DataReader):
else:
data = pickle.load(open(filepath, 'rb'), encoding='bytes')
indexes = list(range(len(data)))
if self.phase == 'train':
if self.mode == 'train':
random.shuffle(indexes)
for i in indexes:
record = data[i]
nframes = record[b'nframes']
rgb = record[b'feature'].astype(float)
audio = record[b'audio'].astype(float)
if self.phase != 'infer':
if self.mode != 'infer':
label = record[b'label']
one_hot_label = make_one_hot(label, self.num_classes)
video = record[b'video']
......@@ -94,7 +94,7 @@ class FeatureReader(DataReader):
self.seg_num)
rgb = rgb[sample_inds]
audio = audio[sample_inds]
if self.phase != 'infer':
if self.mode != 'infer':
batch_out.append((rgb, audio, one_hot_label))
else:
batch_out.append((rgb, audio, video))
......
......@@ -53,31 +53,31 @@ class KineticsReader(DataReader):
list
"""
def __init__(self, name, phase, cfg):
def __init__(self, name, mode, cfg):
self.name = name
self.phase = phase
self.mode = mode
self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen
self.short_size = cfg[phase.upper()]['short_size']
self.target_size = cfg[phase.upper()]['target_size']
self.num_reader_threads = cfg[phase.upper()]['num_reader_threads']
self.buf_size = cfg[phase.upper()]['buf_size']
self.short_size = cfg[mode.upper()]['short_size']
self.target_size = cfg[mode.upper()]['target_size']
self.num_reader_threads = cfg[mode.upper()]['num_reader_threads']
self.buf_size = cfg[mode.upper()]['buf_size']
self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32)
self.img_std = np.array(cfg.MODEL.image_std).reshape(
[3, 1, 1]).astype(np.float32)
# set batch size and file list
self.batch_size = cfg[phase.upper()]['batch_size']
self.filelist = cfg[phase.upper()]['filelist']
self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[mode.upper()]['filelist']
def create_reader(self):
_reader = _reader_creator(self.filelist, self.phase, seg_num=self.seg_num, seglen = self.seglen, \
_reader = _reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
short_size = self.short_size, target_size = self.target_size, \
img_mean = self.img_mean, img_std = self.img_std, \
shuffle = (self.phase == 'train'), \
shuffle = (self.mode == 'train'), \
num_threads = self.num_reader_threads, \
buf_size = self.buf_size, format = self.format)
......@@ -95,7 +95,7 @@ class KineticsReader(DataReader):
def _reader_creator(pickle_list,
phase,
mode,
seg_num,
seglen,
short_size,
......@@ -124,7 +124,7 @@ def _reader_creator(pickle_list,
mapper = functools.partial(
decode_func,
phase=phase,
mode=mode,
seg_num=seg_num,
seglen=seglen,
short_size=short_size,
......@@ -135,16 +135,16 @@ def _reader_creator(pickle_list,
return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
def decode_mp4(sample, phase, seg_num, seglen, short_size, target_size,
def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
sample = sample[0].split(' ')
mp4_path = sample[0]
# when infer, we store vid as label
label = int(sample[1])
imgs = mp4_loader(mp4_path, seg_num, seglen, phase)
imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
imgs = group_scale(imgs, short_size)
if phase == 'train':
if mode == 'train':
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
......@@ -164,7 +164,7 @@ def decode_mp4(sample, phase, seg_num, seglen, short_size, target_size,
return imgs, label
def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size,
def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
pickle_path = sample[0]
try:
......@@ -182,10 +182,10 @@ def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size,
logger.info('Error when loading {}'.format(pickle_path))
return None, None
imgs = video_loader(frames, seg_num, seglen, phase)
imgs = video_loader(frames, seg_num, seglen, mode)
imgs = group_scale(imgs, short_size)
if phase == 'train':
if mode == 'train':
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
......@@ -202,9 +202,9 @@ def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size,
imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
if phase == 'train' or phase == 'valid' or phase == 'test':
if mode == 'train' or mode == 'valid' or mode == 'test':
return imgs, label
elif phase == 'infer':
elif mode == 'infer':
return imgs, vid
......@@ -281,14 +281,14 @@ def imageloader(buf):
return img.convert('RGB')
def video_loader(frames, nsample, seglen, phase):
def video_loader(frames, nsample, seglen, mode):
videolen = len(frames)
average_dur = int(videolen / nsample)
imgs = []
for i in range(nsample):
idx = 0
if phase == 'train':
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
......@@ -313,7 +313,7 @@ def video_loader(frames, nsample, seglen, phase):
return imgs
def mp4_loader(filepath, nsample, seglen, phase):
def mp4_loader(filepath, nsample, seglen, mode):
cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
average_dur = int(videolen / nsample)
......@@ -329,7 +329,7 @@ def mp4_loader(filepath, nsample, seglen, phase):
imgs = []
for i in range(nsample):
idx = 0
if phase == 'train':
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
......
......@@ -43,15 +43,15 @@ class NonlocalReader(DataReader):
use_multi_crop
"""
def __init__(self, name, phase, cfg):
def __init__(self, name, mode, cfg):
self.name = name
self.phase = phase
self.mode = mode
self.cfg = cfg
def create_reader(self):
cfg = self.cfg
phase = self.phase
num_reader_threads = cfg[phase.upper()]['num_reader_threads']
mode = self.mode
num_reader_threads = cfg[mode.upper()]['num_reader_threads']
assert num_reader_threads >=1, \
"number of reader threads({}) should be a positive integer".format(num_reader_threads)
if num_reader_threads == 1:
......@@ -62,24 +62,24 @@ class NonlocalReader(DataReader):
dataset_args = {}
dataset_args['image_mean'] = cfg.MODEL.image_mean
dataset_args['image_std'] = cfg.MODEL.image_std
dataset_args['crop_size'] = cfg[phase.upper()]['crop_size']
dataset_args['sample_rate'] = cfg[phase.upper()]['sample_rate']
dataset_args['video_length'] = cfg[phase.upper()]['video_length']
dataset_args['min_size'] = cfg[phase.upper()]['jitter_scales'][0]
dataset_args['max_size'] = cfg[phase.upper()]['jitter_scales'][1]
dataset_args['crop_size'] = cfg[mode.upper()]['crop_size']
dataset_args['sample_rate'] = cfg[mode.upper()]['sample_rate']
dataset_args['video_length'] = cfg[mode.upper()]['video_length']
dataset_args['min_size'] = cfg[mode.upper()]['jitter_scales'][0]
dataset_args['max_size'] = cfg[mode.upper()]['jitter_scales'][1]
dataset_args['num_reader_threads'] = num_reader_threads
filelist = cfg[phase.upper()]['list']
batch_size = cfg[phase.upper()]['batch_size']
filelist = cfg[mode.upper()]['list']
batch_size = cfg[mode.upper()]['batch_size']
if self.phase == 'train':
if self.mode == 'train':
sample_times = 1
return reader_func(filelist, batch_size, sample_times, True, True,
**dataset_args)
elif self.phase == 'valid':
elif self.mode == 'valid':
sample_times = 1
return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args)
elif self.phase == 'test':
elif self.mode == 'test':
sample_times = cfg['TEST']['num_test_clips']
if cfg['TEST']['use_multi_crop'] == 1:
sample_times = int(sample_times / 3)
......
......@@ -37,7 +37,7 @@ class ReaderNotFoundError(Exception):
class DataReader(object):
"""data reader for video input"""
def __init__(self, model_name, phase, cfg):
def __init__(self, model_name, mode, cfg):
"""Not implemented"""
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册