Commit e8726492 authored by dengkaipeng

add multi-scale crop for TSM.

Parent 99dcaf65
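
The multi-scale crop added by this commit picks one crop size per clip from a few fixed scales of the frame's short side, allows only mildly distorted width/height pairings, and applies the same crop to every frame of a segment. Below is a minimal sketch of that size-sampling step (an editor's illustration, not part of the commit), assuming the default scales [1, .875, .75, .66], max_distort=1 and a hypothetical 340x256 frame:

    import random

    scales = [1, .875, .75, .66]
    max_distort = 1
    target_size = 224
    image_w, image_h = 340, 256      # assumed input resolution, not from the diff

    base_size = min(image_w, image_h)                   # 256
    crop_sizes = [int(base_size * s) for s in scales]   # [256, 224, 192, 168]
    # snap sizes within 3 px of the target back to the exact target size
    crop_sizes = [target_size if abs(x - target_size) < 3 else x for x in crop_sizes]

    # keep (width, height) pairs whose scale indices differ by at most max_distort
    pairs = [(w, h) for i, h in enumerate(crop_sizes)
             for j, w in enumerate(crop_sizes) if abs(i - j) <= max_distort]
    print(random.choice(pairs))                         # e.g. (224, 256)

The committed version additionally restricts the crop offset to a fixed grid of positions (corners, centre, edge midpoints and, with more_fix_crop, quarter points) instead of a uniformly random offset, and finally resizes the crop back to target_size.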
@@ -74,7 +74,7 @@ class KineticsReader(DataReader):
         self.filelist = cfg[mode.upper()]['filelist']
 
     def create_reader(self):
-        _reader = _reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
+        _reader = self._reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
                         short_size = self.short_size, target_size = self.target_size, \
                         img_mean = self.img_mean, img_std = self.img_std, \
                         shuffle = (self.mode == 'train'), \

@@ -94,117 +94,183 @@ class KineticsReader(DataReader):
        return _batch_reader

    def _reader_creator(self,
                        pickle_list,
                        mode,
                        seg_num,
                        seglen,
                        short_size,
                        target_size,
                        img_mean,
                        img_std,
                        shuffle=False,
                        num_threads=1,
                        buf_size=1024,
                        format='pkl'):
        def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size, img_mean,
                       img_std):
            sample = sample[0].split(' ')
            mp4_path = sample[0]
            # when infer, we store vid as label
            label = int(sample[1])
            try:
                imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
                if len(imgs) < 1:
                    logger.error('{} frame length {} less than 1.'.format(mp4_path,
                                                                           len(imgs)))
                    return None, None
            except:
                logger.error('Error when loading {}'.format(mp4_path))
                return None, None

            return imgs_transform(imgs, label, mode, seg_num, seglen, \
                        short_size, target_size, img_mean, img_std)

        def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
                          img_mean, img_std):
            pickle_path = sample[0]
            try:
                if python_ver < (3, 0):
                    data_loaded = pickle.load(open(pickle_path, 'rb'))
                else:
                    data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')

                vid, label, frames = data_loaded
                if len(frames) < 1:
                    logger.error('{} frame length {} less than 1.'.format(pickle_path,
                                                                           len(frames)))
                    return None, None
            except:
                logger.info('Error when loading {}'.format(pickle_path))
                return None, None

            if mode == 'train' or mode == 'valid' or mode == 'test':
                ret_label = label
            elif mode == 'infer':
                ret_label = vid

            imgs = video_loader(frames, seg_num, seglen, mode)
            return imgs_transform(imgs, ret_label, mode, seg_num, seglen, \
                        short_size, target_size, img_mean, img_std)

        def imgs_transform(imgs, label, mode, seg_num, seglen, short_size, target_size,
                           img_mean, img_std):
            imgs = group_scale(imgs, short_size)

            if mode == 'train':
                if self.name == "TSM":
                    imgs = group_multi_scale_crop(imgs, short_size)
                imgs = group_random_crop(imgs, target_size)
                imgs = group_random_flip(imgs)
            else:
                imgs = group_center_crop(imgs, target_size)
            np_imgs = (np.array(imgs[0]).astype('float32').transpose(
                (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
            for i in range(len(imgs) - 1):
                img = (np.array(imgs[i + 1]).astype('float32').transpose(
                    (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
                np_imgs = np.concatenate((np_imgs, img))
            imgs = np_imgs
            imgs -= img_mean
            imgs /= img_std
            imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))

            return imgs, label

        def reader():
            with open(pickle_list) as flist:
                lines = [line.strip() for line in flist]
                if shuffle:
                    random.shuffle(lines)
                for line in lines:
                    pickle_path = line.strip()
                    yield [pickle_path]

        if format == 'pkl':
            decode_func = decode_pickle
        elif format == 'mp4':
            decode_func = decode_mp4
        else:
            raise "Not implemented format {}".format(format)

        mapper = functools.partial(
            decode_func,
            mode=mode,
            seg_num=seg_num,
            seglen=seglen,
            short_size=short_size,
            target_size=target_size,
            img_mean=img_mean,
            img_std=img_std)

        return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)


def group_multi_scale_crop(img_group, target_size, scales=None, \
        max_distort=1, fix_crop=True, more_fix_crop=True):
    scales = scales if scales is not None else [1, .875, .75, .66]
    input_size = [target_size, target_size]

    im_size = img_group[0].size

    # get random crop offset
    def _sample_crop_size(im_size):
        image_w, image_h = im_size[0], im_size[1]

        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in scales]
        crop_h = [input_size[1] if abs(x - input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [input_size[0] if abs(x - input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_step = (image_w - crop_pair[0]) / 4
            h_step = (image_h - crop_pair[1]) / 4

            ret = list()
            ret.append((0, 0))  # upper left
            if w_step != 0:
                ret.append((4 * w_step, 0))  # upper right
            if h_step != 0:
                ret.append((0, 4 * h_step))  # lower left
            if h_step != 0 and w_step != 0:
                ret.append((4 * w_step, 4 * h_step))  # lower right
            if h_step != 0 or w_step != 0:
                ret.append((2 * w_step, 2 * h_step))  # center

            if more_fix_crop:
                ret.append((0, 2 * h_step))  # center left
                ret.append((4 * w_step, 2 * h_step))  # center right
                ret.append((2 * w_step, 4 * h_step))  # lower center
                ret.append((2 * w_step, 0 * h_step))  # upper center

                ret.append((1 * w_step, 1 * h_step))  # upper left quarter
                ret.append((3 * w_step, 1 * h_step))  # upper right quarter
                ret.append((1 * w_step, 3 * h_step))  # lower left quarter
                ret.append((3 * w_step, 3 * h_step))  # lower righ quarter

            w_offset, h_offset = random.choice(ret)

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
    crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
    ret_img_group = [img.resize((input_size[0], input_size[1]), Image.BILINEAR) for img in crop_img_group]

    return ret_img_group


def group_random_crop(img_group, target_size):
......
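
A short usage sketch of the new transform (an editor's addition, not part of the commit). It assumes group_multi_scale_crop from the diff above has been pasted into the current scope, and that Pillow and NumPy are installed; the frame count and resolution are arbitrary:

    import random                     # used inside group_multi_scale_crop
    import numpy as np
    from PIL import Image

    # a fake "segment" of 8 frames at an assumed 340x256 resolution
    frames = [
        Image.fromarray(np.random.randint(0, 256, (256, 340, 3), dtype='uint8'))
        for _ in range(8)
    ]

    # one crop size and offset is sampled and applied to every frame, so the
    # frames of a segment stay spatially aligned; the crop is then resized
    # back to the target size
    cropped = group_multi_scale_crop(frames, 224)
    print(len(cropped), cropped[0].size)   # 8 (224, 224)

In the reader itself this path only runs inside imgs_transform when self.name == "TSM" and mode == 'train', before the existing group_random_crop and group_random_flip steps.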