提交 67fb7c0d 编写于 作者: S SunGaofeng

remove redundant code in kinetics_reader in decode_mp4 and decode_pickle

上级 c39c6e22
......@@ -135,33 +135,24 @@ def _reader_creator(pickle_list,
return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size, img_mean,
img_std):
sample = sample[0].split(' ')
mp4_path = sample[0]
# when infer, we store vid as label
label = int(sample[1])
try:
imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
imgs = group_scale(imgs, short_size)
if mode == 'train':
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
imgs = group_center_crop(imgs, target_size)
np_imgs = (np.array(imgs[0]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
for i in range(len(imgs) - 1):
img = (np.array(imgs[i + 1]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
np_imgs = np.concatenate((np_imgs, img))
imgs = np_imgs
imgs -= img_mean
imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
if len(imgs) < 1:
logger.error('{} frame length {} less than 1.'.format(mp4_path,
len(imgs)))
return None, None
except:
logger.error('Error when loading {}'.format(mp4_path))
return None, None
return imgs, label
return imgs_transform(imgs, label, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std)
def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
......@@ -175,14 +166,25 @@ def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
vid, label, frames = data_loaded
if len(frames) < 1:
logger.info('{} frame length {} less than 1.'.format(pickle_path,
logger.error('{} frame length {} less than 1.'.format(pickle_path,
len(frames)))
raise
return None, None
except:
logger.info('Error when loading {}'.format(pickle_path))
return None, None
if mode == 'train' or mode == 'valid' or mode == 'test':
ret_label = label
elif mode == 'infer':
ret_label = vid
imgs = video_loader(frames, seg_num, seglen, mode)
return imgs_transform(imgs, ret_label, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std)
def imgs_transform(imgs, label, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
imgs = group_scale(imgs, short_size)
if mode == 'train':
......@@ -202,10 +204,7 @@ def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
if mode == 'train' or mode == 'valid' or mode == 'test':
return imgs, label
elif mode == 'infer':
return imgs, vid
def group_random_crop(img_group, target_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册