diff --git a/PaddleCV/video/datareader/kinetics_reader.py b/PaddleCV/video/datareader/kinetics_reader.py
index c7bbf17241383ffc32778330db9ac78308683b46..ed1de044fdaef39f5ccda6c4684733c8dcd8c8ef 100644
--- a/PaddleCV/video/datareader/kinetics_reader.py
+++ b/PaddleCV/video/datareader/kinetics_reader.py
@@ -74,7 +74,7 @@ class KineticsReader(DataReader):
         self.filelist = cfg[mode.upper()]['filelist']
 
     def create_reader(self):
-        _reader = _reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
+        _reader = self._reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
                         short_size = self.short_size, target_size = self.target_size, \
                         img_mean = self.img_mean, img_std = self.img_std, \
                         shuffle = (self.mode == 'train'), \
@@ -94,117 +94,183 @@ class KineticsReader(DataReader):
         return _batch_reader
 
 
-def _reader_creator(pickle_list,
-                    mode,
-                    seg_num,
-                    seglen,
-                    short_size,
-                    target_size,
-                    img_mean,
-                    img_std,
-                    shuffle=False,
-                    num_threads=1,
-                    buf_size=1024,
-                    format='pkl'):
-    def reader():
-        with open(pickle_list) as flist:
-            lines = [line.strip() for line in flist]
-            if shuffle:
-                random.shuffle(lines)
-            for line in lines:
-                pickle_path = line.strip()
-                yield [pickle_path]
-
-    if format == 'pkl':
-        decode_func = decode_pickle
-    elif format == 'mp4':
-        decode_func = decode_mp4
-    else:
-        raise "Not implemented format {}".format(format)
-
-    mapper = functools.partial(
-        decode_func,
-        mode=mode,
-        seg_num=seg_num,
-        seglen=seglen,
-        short_size=short_size,
-        target_size=target_size,
-        img_mean=img_mean,
-        img_std=img_std)
-
-    return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
-
-
-def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size, img_mean,
-               img_std):
-    sample = sample[0].split(' ')
-    mp4_path = sample[0]
-    # when infer, we store vid as label
-    label = int(sample[1])
-    try:
-        imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
-        if len(imgs) < 1:
-            logger.error('{} frame length {} less than 1.'.format(mp4_path,
-                                                                  len(imgs)))
-            return None, None
-    except:
-        logger.error('Error when loading {}'.format(mp4_path))
-        return None, None
-
-    return imgs_transform(imgs, label, mode, seg_num, seglen, \
-                    short_size, target_size, img_mean, img_std)
-
-
-def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
-                  img_mean, img_std):
-    pickle_path = sample[0]
-    try:
-        if python_ver < (3, 0):
-            data_loaded = pickle.load(open(pickle_path, 'rb'))
+    def _reader_creator(self,
+                        pickle_list,
+                        mode,
+                        seg_num,
+                        seglen,
+                        short_size,
+                        target_size,
+                        img_mean,
+                        img_std,
+                        shuffle=False,
+                        num_threads=1,
+                        buf_size=1024,
+                        format='pkl'):
+        def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size, img_mean,
+                       img_std):
+            sample = sample[0].split(' ')
+            mp4_path = sample[0]
+            # when infer, we store vid as label
+            label = int(sample[1])
+            try:
+                imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
+                if len(imgs) < 1:
+                    logger.error('{} frame length {} less than 1.'.format(mp4_path,
+                                                                          len(imgs)))
+                    return None, None
+            except:
+                logger.error('Error when loading {}'.format(mp4_path))
+                return None, None
+
+            return imgs_transform(imgs, label, mode, seg_num, seglen, \
+                            short_size, target_size, img_mean, img_std)
+
+
+        def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
+                          img_mean, img_std):
+            pickle_path = sample[0]
+            try:
+                if python_ver < (3, 0):
+                    data_loaded = pickle.load(open(pickle_path, 'rb'))
+                else:
+                    data_loaded = pickle.load(open(pickle_path, 'rb'),
+                                              encoding='bytes')
+
+                vid, label, frames = data_loaded
+                if len(frames) < 1:
+                    logger.error('{} frame length {} less than 1.'.format(pickle_path,
+                                                                          len(frames)))
+                    return None, None
+            except:
+                logger.info('Error when loading {}'.format(pickle_path))
+                return None, None
+
+            if mode == 'train' or mode == 'valid' or mode == 'test':
+                ret_label = label
+            elif mode == 'infer':
+                ret_label = vid
+
+            imgs = video_loader(frames, seg_num, seglen, mode)
+            return imgs_transform(imgs, ret_label, mode, seg_num, seglen, \
+                            short_size, target_size, img_mean, img_std)
+
+
+        def imgs_transform(imgs, label, mode, seg_num, seglen, short_size, target_size,
+                           img_mean, img_std):
+            imgs = group_scale(imgs, short_size)
+
+            if mode == 'train':
+                if self.name == "TSM":
+                    imgs = group_multi_scale_crop(imgs, short_size)
+                imgs = group_random_crop(imgs, target_size)
+                imgs = group_random_flip(imgs)
+            else:
+                imgs = group_center_crop(imgs, target_size)
+
+            np_imgs = (np.array(imgs[0]).astype('float32').transpose(
+                (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
+            for i in range(len(imgs) - 1):
+                img = (np.array(imgs[i + 1]).astype('float32').transpose(
+                    (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
+                np_imgs = np.concatenate((np_imgs, img))
+            imgs = np_imgs
+            imgs -= img_mean
+            imgs /= img_std
+            imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
+
+            return imgs, label
+
+
+        def reader():
+            with open(pickle_list) as flist:
+                lines = [line.strip() for line in flist]
+                if shuffle:
+                    random.shuffle(lines)
+                for line in lines:
+                    pickle_path = line.strip()
+                    yield [pickle_path]
+
+        if format == 'pkl':
+            decode_func = decode_pickle
+        elif format == 'mp4':
+            decode_func = decode_mp4
         else:
-            data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
-
-        vid, label, frames = data_loaded
-        if len(frames) < 1:
-            logger.error('{} frame length {} less than 1.'.format(pickle_path,
-                                                                  len(frames)))
-            return None, None
-    except:
-        logger.info('Error when loading {}'.format(pickle_path))
-        return None, None
-
-    if mode == 'train' or mode == 'valid' or mode == 'test':
-        ret_label = label
-    elif mode == 'infer':
-        ret_label = vid
-
-    imgs = video_loader(frames, seg_num, seglen, mode)
-    return imgs_transform(imgs, ret_label, mode, seg_num, seglen, \
-                    short_size, target_size, img_mean, img_std)
-
-
-def imgs_transform(imgs, label, mode, seg_num, seglen, short_size, target_size,
-                   img_mean, img_std):
-    imgs = group_scale(imgs, short_size)
-
-    if mode == 'train':
-        imgs = group_random_crop(imgs, target_size)
-        imgs = group_random_flip(imgs)
-    else:
-        imgs = group_center_crop(imgs, target_size)
-
-    np_imgs = (np.array(imgs[0]).astype('float32').transpose(
-        (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
-    for i in range(len(imgs) - 1):
-        img = (np.array(imgs[i + 1]).astype('float32').transpose(
-            (2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
-        np_imgs = np.concatenate((np_imgs, img))
-    imgs = np_imgs
-    imgs -= img_mean
-    imgs /= img_std
-    imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
-
-    return imgs, label
+            raise NotImplementedError("Not implemented format {}".format(format))
+
+        mapper = functools.partial(
+            decode_func,
+            mode=mode,
+            seg_num=seg_num,
+            seglen=seglen,
+            short_size=short_size,
+            target_size=target_size,
+            img_mean=img_mean,
+            img_std=img_std)
+
+        return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
+
+
+def group_multi_scale_crop(img_group, target_size, scales=None, \
+                           max_distort=1, fix_crop=True, more_fix_crop=True):
+    scales = scales if scales is not None else [1, .875, .75, .66]
+    input_size = [target_size, target_size]
+
+    im_size = img_group[0].size
+
+    # get random crop offset
+    def _sample_crop_size(im_size):
+        image_w, image_h = im_size[0], im_size[1]
+
+        base_size = min(image_w, image_h)
+        crop_sizes = [int(base_size * x) for x in scales]
+        crop_h = [input_size[1] if abs(x - input_size[1]) < 3 else x for x in crop_sizes]
+        crop_w = [input_size[0] if abs(x - input_size[0]) < 3 else x for x in crop_sizes]
+
+        pairs = []
+        for i, h in enumerate(crop_h):
+            for j, w in enumerate(crop_w):
+                if abs(i - j) <= max_distort:
+                    pairs.append((w, h))
+
+        crop_pair = random.choice(pairs)
+        if not fix_crop:
+            w_offset = random.randint(0, image_w - crop_pair[0])
+            h_offset = random.randint(0, image_h - crop_pair[1])
+        else:
+            w_step = (image_w - crop_pair[0]) / 4
+            h_step = (image_h - crop_pair[1]) / 4
+
+            ret = list()
+            ret.append((0, 0)) # upper left
+            if w_step != 0:
+                ret.append((4 * w_step, 0)) # upper right
+            if h_step != 0:
+                ret.append((0, 4 * h_step)) # lower left
+            if h_step != 0 and w_step != 0:
+                ret.append((4 * w_step, 4 * h_step)) # lower right
+            if h_step != 0 or w_step != 0:
+                ret.append((2 * w_step, 2 * h_step)) # center
+
+            if more_fix_crop:
+                ret.append((0, 2 * h_step)) # center left
+                ret.append((4 * w_step, 2 * h_step)) # center right
+                ret.append((2 * w_step, 4 * h_step)) # lower center
+                ret.append((2 * w_step, 0 * h_step)) # upper center
+
+                ret.append((1 * w_step, 1 * h_step)) # upper left quarter
+                ret.append((3 * w_step, 1 * h_step)) # upper right quarter
+                ret.append((1 * w_step, 3 * h_step)) # lower left quarter
+                ret.append((3 * w_step, 3 * h_step)) # lower right quarter
+
+            w_offset, h_offset = random.choice(ret)
+
+        return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+    crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
+    crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
+    ret_img_group = [img.resize((input_size[0], input_size[1]), Image.BILINEAR) for img in crop_img_group]
+
+    return ret_img_group
 
 
 def group_random_crop(img_group, target_size):