提交 b73e4832 编写于 作者: S SunGaofeng

Merge branch 'video_classification' of https://github.com/SunGaofeng/models...

Merge branch 'video_classification' of https://github.com/SunGaofeng/models into video_classification
...@@ -37,21 +37,21 @@ class FeatureReader(DataReader): ...@@ -37,21 +37,21 @@ class FeatureReader(DataReader):
NextVlad only: eigen_file NextVlad only: eigen_file
""" """
def __init__(self, name, phase, cfg): def __init__(self, name, mode, cfg):
self.name = name self.name = name
self.phase = phase self.mode = mode
self.num_classes = cfg.MODEL.num_classes self.num_classes = cfg.MODEL.num_classes
# set batch size and file list # set batch size and file list
self.batch_size = cfg[phase.upper()]['batch_size'] self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[phase.upper()]['filelist'] self.filelist = cfg[mode.upper()]['filelist']
self.eigen_file = cfg.MODEL.get('eigen_file', None) self.eigen_file = cfg.MODEL.get('eigen_file', None)
self.seg_num = cfg.MODEL.get('seg_num', None) self.seg_num = cfg.MODEL.get('seg_num', None)
def create_reader(self): def create_reader(self):
fl = open(self.filelist).readlines() fl = open(self.filelist).readlines()
fl = [line.strip() for line in fl if line.strip() != ''] fl = [line.strip() for line in fl if line.strip() != '']
if self.phase == 'train': if self.mode == 'train':
random.shuffle(fl) random.shuffle(fl)
def reader(): def reader():
...@@ -62,14 +62,14 @@ class FeatureReader(DataReader): ...@@ -62,14 +62,14 @@ class FeatureReader(DataReader):
else: else:
data = pickle.load(open(filepath, 'rb'), encoding='bytes') data = pickle.load(open(filepath, 'rb'), encoding='bytes')
indexes = list(range(len(data))) indexes = list(range(len(data)))
if self.phase == 'train': if self.mode == 'train':
random.shuffle(indexes) random.shuffle(indexes)
for i in indexes: for i in indexes:
record = data[i] record = data[i]
nframes = record[b'nframes'] nframes = record[b'nframes']
rgb = record[b'feature'].astype(float) rgb = record[b'feature'].astype(float)
audio = record[b'audio'].astype(float) audio = record[b'audio'].astype(float)
if self.phase != 'infer': if self.mode != 'infer':
label = record[b'label'] label = record[b'label']
one_hot_label = make_one_hot(label, self.num_classes) one_hot_label = make_one_hot(label, self.num_classes)
video = record[b'video'] video = record[b'video']
...@@ -94,7 +94,7 @@ class FeatureReader(DataReader): ...@@ -94,7 +94,7 @@ class FeatureReader(DataReader):
self.seg_num) self.seg_num)
rgb = rgb[sample_inds] rgb = rgb[sample_inds]
audio = audio[sample_inds] audio = audio[sample_inds]
if self.phase != 'infer': if self.mode != 'infer':
batch_out.append((rgb, audio, one_hot_label)) batch_out.append((rgb, audio, one_hot_label))
else: else:
batch_out.append((rgb, audio, video)) batch_out.append((rgb, audio, video))
......
...@@ -53,31 +53,31 @@ class KineticsReader(DataReader): ...@@ -53,31 +53,31 @@ class KineticsReader(DataReader):
list list
""" """
def __init__(self, name, phase, cfg): def __init__(self, name, mode, cfg):
self.name = name self.name = name
self.phase = phase self.mode = mode
self.format = cfg.MODEL.format self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.seg_num self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen self.seglen = cfg.MODEL.seglen
self.short_size = cfg[phase.upper()]['short_size'] self.short_size = cfg[mode.upper()]['short_size']
self.target_size = cfg[phase.upper()]['target_size'] self.target_size = cfg[mode.upper()]['target_size']
self.num_reader_threads = cfg[phase.upper()]['num_reader_threads'] self.num_reader_threads = cfg[mode.upper()]['num_reader_threads']
self.buf_size = cfg[phase.upper()]['buf_size'] self.buf_size = cfg[mode.upper()]['buf_size']
self.img_mean = np.array(cfg.MODEL.image_mean).reshape( self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32) [3, 1, 1]).astype(np.float32)
self.img_std = np.array(cfg.MODEL.image_std).reshape( self.img_std = np.array(cfg.MODEL.image_std).reshape(
[3, 1, 1]).astype(np.float32) [3, 1, 1]).astype(np.float32)
# set batch size and file list # set batch size and file list
self.batch_size = cfg[phase.upper()]['batch_size'] self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[phase.upper()]['filelist'] self.filelist = cfg[mode.upper()]['filelist']
def create_reader(self): def create_reader(self):
_reader = _reader_creator(self.filelist, self.phase, seg_num=self.seg_num, seglen = self.seglen, \ _reader = _reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
short_size = self.short_size, target_size = self.target_size, \ short_size = self.short_size, target_size = self.target_size, \
img_mean = self.img_mean, img_std = self.img_std, \ img_mean = self.img_mean, img_std = self.img_std, \
shuffle = (self.phase == 'train'), \ shuffle = (self.mode == 'train'), \
num_threads = self.num_reader_threads, \ num_threads = self.num_reader_threads, \
buf_size = self.buf_size, format = self.format) buf_size = self.buf_size, format = self.format)
...@@ -95,7 +95,7 @@ class KineticsReader(DataReader): ...@@ -95,7 +95,7 @@ class KineticsReader(DataReader):
def _reader_creator(pickle_list, def _reader_creator(pickle_list,
phase, mode,
seg_num, seg_num,
seglen, seglen,
short_size, short_size,
...@@ -124,7 +124,7 @@ def _reader_creator(pickle_list, ...@@ -124,7 +124,7 @@ def _reader_creator(pickle_list,
mapper = functools.partial( mapper = functools.partial(
decode_func, decode_func,
phase=phase, mode=mode,
seg_num=seg_num, seg_num=seg_num,
seglen=seglen, seglen=seglen,
short_size=short_size, short_size=short_size,
...@@ -135,16 +135,16 @@ def _reader_creator(pickle_list, ...@@ -135,16 +135,16 @@ def _reader_creator(pickle_list,
return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size) return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
def decode_mp4(sample, phase, seg_num, seglen, short_size, target_size, def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std): img_mean, img_std):
sample = sample[0].split(' ') sample = sample[0].split(' ')
mp4_path = sample[0] mp4_path = sample[0]
# when infer, we store vid as label # when infer, we store vid as label
label = int(sample[1]) label = int(sample[1])
imgs = mp4_loader(mp4_path, seg_num, seglen, phase) imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
imgs = group_scale(imgs, short_size) imgs = group_scale(imgs, short_size)
if phase == 'train': if mode == 'train':
imgs = group_random_crop(imgs, target_size) imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs) imgs = group_random_flip(imgs)
else: else:
...@@ -164,7 +164,7 @@ def decode_mp4(sample, phase, seg_num, seglen, short_size, target_size, ...@@ -164,7 +164,7 @@ def decode_mp4(sample, phase, seg_num, seglen, short_size, target_size,
return imgs, label return imgs, label
def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size, def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std): img_mean, img_std):
pickle_path = sample[0] pickle_path = sample[0]
try: try:
...@@ -182,10 +182,10 @@ def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size, ...@@ -182,10 +182,10 @@ def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size,
logger.info('Error when loading {}'.format(pickle_path)) logger.info('Error when loading {}'.format(pickle_path))
return None, None return None, None
imgs = video_loader(frames, seg_num, seglen, phase) imgs = video_loader(frames, seg_num, seglen, mode)
imgs = group_scale(imgs, short_size) imgs = group_scale(imgs, short_size)
if phase == 'train': if mode == 'train':
imgs = group_random_crop(imgs, target_size) imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs) imgs = group_random_flip(imgs)
else: else:
...@@ -202,9 +202,9 @@ def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size, ...@@ -202,9 +202,9 @@ def decode_pickle(sample, phase, seg_num, seglen, short_size, target_size,
imgs /= img_std imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size)) imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
if phase == 'train' or phase == 'valid' or phase == 'test': if mode == 'train' or mode == 'valid' or mode == 'test':
return imgs, label return imgs, label
elif phase == 'infer': elif mode == 'infer':
return imgs, vid return imgs, vid
...@@ -281,14 +281,14 @@ def imageloader(buf): ...@@ -281,14 +281,14 @@ def imageloader(buf):
return img.convert('RGB') return img.convert('RGB')
def video_loader(frames, nsample, seglen, phase): def video_loader(frames, nsample, seglen, mode):
videolen = len(frames) videolen = len(frames)
average_dur = int(videolen / nsample) average_dur = int(videolen / nsample)
imgs = [] imgs = []
for i in range(nsample): for i in range(nsample):
idx = 0 idx = 0
if phase == 'train': if mode == 'train':
if average_dur >= seglen: if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen) idx = random.randint(0, average_dur - seglen)
idx += i * average_dur idx += i * average_dur
...@@ -313,7 +313,7 @@ def video_loader(frames, nsample, seglen, phase): ...@@ -313,7 +313,7 @@ def video_loader(frames, nsample, seglen, phase):
return imgs return imgs
def mp4_loader(filepath, nsample, seglen, phase): def mp4_loader(filepath, nsample, seglen, mode):
cap = cv2.VideoCapture(filepath) cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
average_dur = int(videolen / nsample) average_dur = int(videolen / nsample)
...@@ -329,7 +329,7 @@ def mp4_loader(filepath, nsample, seglen, phase): ...@@ -329,7 +329,7 @@ def mp4_loader(filepath, nsample, seglen, phase):
imgs = [] imgs = []
for i in range(nsample): for i in range(nsample):
idx = 0 idx = 0
if phase == 'train': if mode == 'train':
if average_dur >= seglen: if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen) idx = random.randint(0, average_dur - seglen)
idx += i * average_dur idx += i * average_dur
......
...@@ -43,15 +43,15 @@ class NonlocalReader(DataReader): ...@@ -43,15 +43,15 @@ class NonlocalReader(DataReader):
use_multi_crop use_multi_crop
""" """
def __init__(self, name, phase, cfg): def __init__(self, name, mode, cfg):
self.name = name self.name = name
self.phase = phase self.mode = mode
self.cfg = cfg self.cfg = cfg
def create_reader(self): def create_reader(self):
cfg = self.cfg cfg = self.cfg
phase = self.phase mode = self.mode
num_reader_threads = cfg[phase.upper()]['num_reader_threads'] num_reader_threads = cfg[mode.upper()]['num_reader_threads']
assert num_reader_threads >=1, \ assert num_reader_threads >=1, \
"number of reader threads({}) should be a positive integer".format(num_reader_threads) "number of reader threads({}) should be a positive integer".format(num_reader_threads)
if num_reader_threads == 1: if num_reader_threads == 1:
...@@ -62,24 +62,24 @@ class NonlocalReader(DataReader): ...@@ -62,24 +62,24 @@ class NonlocalReader(DataReader):
dataset_args = {} dataset_args = {}
dataset_args['image_mean'] = cfg.MODEL.image_mean dataset_args['image_mean'] = cfg.MODEL.image_mean
dataset_args['image_std'] = cfg.MODEL.image_std dataset_args['image_std'] = cfg.MODEL.image_std
dataset_args['crop_size'] = cfg[phase.upper()]['crop_size'] dataset_args['crop_size'] = cfg[mode.upper()]['crop_size']
dataset_args['sample_rate'] = cfg[phase.upper()]['sample_rate'] dataset_args['sample_rate'] = cfg[mode.upper()]['sample_rate']
dataset_args['video_length'] = cfg[phase.upper()]['video_length'] dataset_args['video_length'] = cfg[mode.upper()]['video_length']
dataset_args['min_size'] = cfg[phase.upper()]['jitter_scales'][0] dataset_args['min_size'] = cfg[mode.upper()]['jitter_scales'][0]
dataset_args['max_size'] = cfg[phase.upper()]['jitter_scales'][1] dataset_args['max_size'] = cfg[mode.upper()]['jitter_scales'][1]
dataset_args['num_reader_threads'] = num_reader_threads dataset_args['num_reader_threads'] = num_reader_threads
filelist = cfg[phase.upper()]['list'] filelist = cfg[mode.upper()]['list']
batch_size = cfg[phase.upper()]['batch_size'] batch_size = cfg[mode.upper()]['batch_size']
if self.phase == 'train': if self.mode == 'train':
sample_times = 1 sample_times = 1
return reader_func(filelist, batch_size, sample_times, True, True, return reader_func(filelist, batch_size, sample_times, True, True,
**dataset_args) **dataset_args)
elif self.phase == 'valid': elif self.mode == 'valid':
sample_times = 1 sample_times = 1
return reader_func(filelist, batch_size, sample_times, False, False, return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args) **dataset_args)
elif self.phase == 'test': elif self.mode == 'test':
sample_times = cfg['TEST']['num_test_clips'] sample_times = cfg['TEST']['num_test_clips']
if cfg['TEST']['use_multi_crop'] == 1: if cfg['TEST']['use_multi_crop'] == 1:
sample_times = int(sample_times / 3) sample_times = int(sample_times / 3)
......
...@@ -37,7 +37,7 @@ class ReaderNotFoundError(Exception): ...@@ -37,7 +37,7 @@ class ReaderNotFoundError(Exception):
class DataReader(object): class DataReader(object):
"""data reader for video input""" """data reader for video input"""
def __init__(self, model_name, phase, cfg): def __init__(self, model_name, mode, cfg):
"""Not implemented""" """Not implemented"""
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册