提交 45165914 编写于 作者: K kinghuin

accelerate video reader

上级 16710e75
......@@ -88,12 +88,9 @@ class VideoTag(hub.Module):
extractor_model.load_test_weights(exe, args.extractor_weights,
extractor_main_prog)
# get reader and metrics
extractor_reader = get_reader("TSN", 'infer',
extractor_infer_config)
extractor_feeder = fluid.DataFeeder(
place=place, feed_list=extractor_feeds)
return extractor_reader, extractor_main_prog, extractor_fetch_list, extractor_feeder, extractor_scope
return extractor_main_prog, extractor_fetch_list, extractor_feeder, extractor_scope
def _predictor(self, args, exe, place):
predictor_scope = fluid.Scope()
......@@ -169,15 +166,20 @@ class VideoTag(hub.Module):
self.place = fluid.CUDAPlace(
0) if args.use_gpu else fluid.CPUPlace()
self.exe = fluid.Executor(self.place)
self.extractor_reader, self.extractor_main_prog, self.extractor_fetch_list, self.extractor_feeder, self.extractor_scope = self._extractor(
self.extractor_main_prog, self.extractor_fetch_list, self.extractor_feeder, self.extractor_scope = self._extractor(
args, self.exe, self.place)
self.predictor_main_prog, self.predictor_fetch_list, self.predictor_feeder, self.predictor_scope = self._predictor(
args, self.exe, self.place)
self._has_load = True
extractor_config = parse_config(args.extractor_config)
extractor_infer_config = merge_configs(extractor_config, 'infer',
vars(args))
extractor_reader = get_reader("TSN", 'infer', extractor_infer_config)
feature_list = []
file_list = []
for idx, data in enumerate(self.extractor_reader()):
for idx, data in enumerate(extractor_reader()):
file_id = [item[-1] for item in data]
feed_data = [item[:-1] for item in data]
feature_out = self.exe.run(
......
......@@ -87,38 +87,15 @@ class KineticsReader(DataReader):
def _batch_reader():
batch_out = []
for imgs, label in _reader():
#for imgs in _reader():
if imgs is None:
continue
batch_out.append((imgs, label))
#batch_out.append((imgs,))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return _batch_reader
def _inference_reader_creator(self, video_path, mode, seg_num, seglen,
short_size, target_size, img_mean, img_std):
def reader():
try:
imgs = mp4_loader(video_path, seg_num, seglen, mode)
if len(imgs) < 1:
logger.error('{} frame length {} less than 1.'.format(
video_path, len(imgs)))
yield None, None
except:
logger.error('Error when loading {}'.format(video_path))
yield None, None
imgs_ret = imgs_transform(imgs, mode, seg_num, seglen, short_size,
target_size, img_mean, img_std)
label_ret = video_path
yield imgs_ret, label_ret
return reader
def _reader_creator(self,
pickle_list,
mode,
......@@ -149,37 +126,7 @@ class KineticsReader(DataReader):
return imgs_transform(imgs, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std, name = self.name), mp4_path
def decode_pickle(sample, mode, seg_num, seglen, short_size,
target_size, img_mean, img_std):
pickle_path = sample[0]
try:
if python_ver < (3, 0):
data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
data_loaded = pickle.load(
open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data_loaded
if len(frames) < 1:
logger.error('{} frame length {} less than 1.'.format(
pickle_path, len(frames)))
return None, None
except:
logger.info('Error when loading {}'.format(pickle_path))
return None, None
if mode == 'train' or mode == 'valid' or mode == 'test':
ret_label = label
elif mode == 'infer':
ret_label = vid
imgs = video_loader(frames, seg_num, seglen, mode)
return imgs_transform(imgs, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std, name = self.name), ret_label
def reader():
# with open(pickle_list) as flist:
# lines = [line.strip() for line in flist]
lines = [line.strip() for line in pickle_list]
if shuffle:
random.shuffle(lines)
......@@ -187,15 +134,8 @@ class KineticsReader(DataReader):
pickle_path = line.strip()
yield [pickle_path]
if format == 'pkl':
decode_func = decode_pickle
elif format == 'mp4':
decode_func = decode_mp4
else:
raise "Not implemented format {}".format(format)
mapper = functools.partial(
decode_func,
decode_mp4,
mode=mode,
seg_num=seg_num,
seglen=seglen,
......@@ -218,142 +158,26 @@ def imgs_transform(imgs,
name=''):
imgs = group_scale(imgs, short_size)
if mode == 'train':
if name == "TSM":
imgs = group_multi_scale_crop(imgs, short_size)
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
imgs = group_center_crop(imgs, target_size)
np_imgs = (np.array(imgs[0]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
for i in range(len(imgs) - 1):
img = (np.array(imgs[i + 1]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
np_imgs = np.concatenate((np_imgs, img))
imgs = np_imgs
imgs -= img_mean
imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
return imgs
def group_multi_scale_crop(img_group, target_size, scales=None, \
max_distort=1, fix_crop=True, more_fix_crop=True):
scales = scales if scales is not None else [1, .875, .75, .66]
input_size = [target_size, target_size]
im_size = img_group[0].size
# get random crop offset
def _sample_crop_size(im_size):
image_w, image_h = im_size[0], im_size[1]
base_size = min(image_w, image_h)
crop_sizes = [int(base_size * x) for x in scales]
crop_h = [
input_size[1] if abs(x - input_size[1]) < 3 else x
for x in crop_sizes
]
crop_w = [
input_size[0] if abs(x - input_size[0]) < 3 else x
for x in crop_sizes
]
pairs = []
for i, h in enumerate(crop_h):
for j, w in enumerate(crop_w):
if abs(i - j) <= max_distort:
pairs.append((w, h))
crop_pair = random.choice(pairs)
if not fix_crop:
w_offset = random.randint(0, image_w - crop_pair[0])
h_offset = random.randint(0, image_h - crop_pair[1])
else:
w_step = (image_w - crop_pair[0]) / 4
h_step = (image_h - crop_pair[1]) / 4
ret = list()
ret.append((0, 0)) # upper left
if w_step != 0:
ret.append((4 * w_step, 0)) # upper right
if h_step != 0:
ret.append((0, 4 * h_step)) # lower left
if h_step != 0 and w_step != 0:
ret.append((4 * w_step, 4 * h_step)) # lower right
if h_step != 0 or w_step != 0:
ret.append((2 * w_step, 2 * h_step)) # center
if more_fix_crop:
ret.append((0, 2 * h_step)) # center left
ret.append((4 * w_step, 2 * h_step)) # center right
ret.append((2 * w_step, 4 * h_step)) # lower center
ret.append((2 * w_step, 0 * h_step)) # upper center
ret.append((1 * w_step, 1 * h_step)) # upper left quarter
ret.append((3 * w_step, 1 * h_step)) # upper right quarter
ret.append((1 * w_step, 3 * h_step)) # lower left quarter
ret.append((3 * w_step, 3 * h_step)) # lower righ quarter
w_offset, h_offset = random.choice(ret)
np_imgs = np.array([np.array(img).astype('float32') for img in imgs]) #dhwc
np_imgs = group_center_crop(np_imgs, target_size)
np_imgs = np_imgs.transpose(0, 3, 1, 2) / 255 #dchw
np_imgs -= img_mean
np_imgs /= img_std
return crop_pair[0], crop_pair[1], w_offset, h_offset
return np_imgs
crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
crop_img_group = [
img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
for img in img_group
]
ret_img_group = [
img.resize((input_size[0], input_size[1]), Image.BILINEAR)
for img in crop_img_group
]
return ret_img_group
def group_random_crop(img_group, target_size):
w, h = img_group[0].size
def group_center_crop(np_imgs, target_size):
d, h, w, c = np_imgs.shape
th, tw = target_size, target_size
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size".format(w, h, target_size)
out_images = []
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
"image width({}) and height({}) should be larger than crop size".format(w, h, target_size)
for img in img_group:
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
def group_random_flip(img_group):
v = random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
return ret
else:
return img_group
def group_center_crop(img_group, target_size):
img_crop = []
for img in img_group:
w, h = img.size
th, tw = target_size, target_size
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size".format(w, h, target_size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
img_crop.append(img.crop((x1, y1, x1 + tw, y1 + th)))
h_off = int(round((h - th) / 2.))
w_off = int(round((w - tw) / 2.))
img_crop = np_imgs[:, h_off:h_off + target_size, w_off:w_off +
target_size, :]
return img_crop
......@@ -378,47 +202,6 @@ def group_scale(imgs, target_size):
return resized_imgs
def imageloader(buf):
if isinstance(buf, str):
img = Image.open(StringIO(buf))
else:
img = Image.open(BytesIO(buf))
return img.convert('RGB')
def video_loader(frames, nsample, seglen, mode):
videolen = len(frames)
average_dur = int(videolen / nsample)
imgs = []
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= seglen:
idx = (average_dur - seglen) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + seglen):
imgbuf = frames[int(jj % videolen)]
img = imageloader(imgbuf)
imgs.append(img)
return imgs
def mp4_loader(filepath, nsample, seglen, mode):
cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
......@@ -434,26 +217,16 @@ def mp4_loader(filepath, nsample, seglen, mode):
imgs = []
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
if average_dur >= seglen:
idx = (average_dur - 1) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
if average_dur >= seglen:
idx = (average_dur - 1) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
idx = i
for jj in range(idx, idx + seglen):
imgbuf = sampledFrames[int(jj % len(sampledFrames))]
img = Image.fromarray(imgbuf, mode='RGB')
imgs.append(img)
return imgs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册