# kinetics_reader.py
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import cv2
import math
import random
import functools
try:
    import cPickle as pickle
    from cStringIO import StringIO
except ImportError:
    import pickle
    from io import BytesIO
import numpy as np
import paddle
import paddle.fluid as fluid
try:
    from nvidia.dali.pipeline import Pipeline
    import nvidia.dali.ops as ops
    import nvidia.dali.types as types
    import tempfile
    from nvidia.dali.plugin.paddle import DALIGenericIterator
except Exception:
    # DALI is optional: fall back to a plain base class so VideoPipe and
    # VideoTestPipe can still be defined when it is missing.
    Pipeline = object
    print("DALI is not installed, you can improve performance if use DALI")

from PIL import Image, ImageEnhance
import logging

from .reader_utils import DataReader

# Interpreter version; decode_pickle uses it to choose the pickle.load call.
python_ver = sys.version_info
# Module-level logger shared by every reader in this file.
logger = logging.getLogger(__name__)


class KineticsReader(DataReader):
    """
    Data reader for the Kinetics dataset, stored in one of two formats:
    1. mp4, the original video files of Kinetics-400
    2. pkl, videos decoded in advance and stored as pickled frame lists
    In both cases the sampled frames are returned as a float32 numpy array
    together with an integer label. In 'infer' mode the label is replaced by
    the video id (pkl) or by the video path (single-video inference).

    Config keys read:
        model section: format, num_classes, seg_num, seglen, image_mean,
            image_std
        per-mode section: seg_num (optional override), short_size,
            target_size, num_reader_threads, buf_size, fix_random_seed,
            batch_size, filelist, num_trainers (default 1),
            trainer_id (default 0), use_dali (default False),
            video_path ('infer' only)
    """

    def __init__(self, name, mode, cfg):
        super(KineticsReader, self).__init__(name, mode, cfg)
        self.format = cfg.MODEL.format
        self.num_classes = self.get_config_from_sec('model', 'num_classes')
        self.seg_num = self.get_config_from_sec('model', 'seg_num')
        self.seglen = self.get_config_from_sec('model', 'seglen')

        # the per-mode section may override the model-level seg_num
        self.seg_num = self.get_config_from_sec(mode, 'seg_num', self.seg_num)
        self.short_size = self.get_config_from_sec(mode, 'short_size')
        self.target_size = self.get_config_from_sec(mode, 'target_size')
        self.num_reader_threads = self.get_config_from_sec(mode,
                                                           'num_reader_threads')
        self.buf_size = self.get_config_from_sec(mode, 'buf_size')
        self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed')

        # shape [3, 1, 1] so the statistics broadcast over (N, 3, H, W) frames
        self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
            [3, 1, 1]).astype(np.float32)
        self.img_std = np.array(cfg.MODEL.image_std).reshape(
            [3, 1, 1]).astype(np.float32)
        # set batch size and file list
        self.batch_size = cfg[mode.upper()]['batch_size']
        self.filelist = cfg[mode.upper()]['filelist']
        # set num_trainers and trainer_id when distributed training is implemented
        self.num_trainers = self.get_config_from_sec(mode, 'num_trainers', 1)
        self.trainer_id = self.get_config_from_sec(mode, 'trainer_id', 0)
        self.use_dali = self.get_config_from_sec(mode, 'use_dali', False)
        # DALI normalizes an [H, W, F*C] tensor, so the per-channel statistics
        # are repeated once per frame (assumes image_mean/std are lists).
        self.dali_mean = cfg.MODEL.image_mean * (self.seg_num * self.seglen)
        self.dali_std = cfg.MODEL.image_std * (self.seg_num * self.seglen)

        if self.mode == 'infer':
            self.video_path = cfg[mode.upper()]['video_path']
        else:
            self.video_path = ''
        if self.fix_random_seed:
            random.seed(0)
            np.random.seed(0)
            # a single worker keeps the sample order reproducible
            self.num_reader_threads = 1

    def create_reader(self):
        """Return a callable yielding lists of ``batch_size`` (imgs, label) pairs.

        When ``use_dali`` is set, the DALI-backed reader is returned instead.
        Samples that failed to load (``imgs is None``) are skipped, and a
        trailing partial batch is dropped.
        """
        # if use_dali to improve performance
        if self.use_dali:
            return self.build_dali_reader()

        if (self.mode == 'infer') and (self.video_path != ''):
            # load the single video stored at video_path
            _reader = self._inference_reader_creator(
                self.video_path,
                self.mode,
                seg_num=self.seg_num,
                seglen=self.seglen,
                short_size=self.short_size,
                target_size=self.target_size,
                img_mean=self.img_mean,
                img_std=self.img_std)
        else:
            assert os.path.exists(self.filelist), \
                        '{} not exist, please check the data list'.format(self.filelist)
            _reader = self._reader_creator(
                self.filelist,
                self.mode,
                seg_num=self.seg_num,
                seglen=self.seglen,
                short_size=self.short_size,
                target_size=self.target_size,
                img_mean=self.img_mean,
                img_std=self.img_std,
                shuffle=(self.mode == 'train'),
                num_threads=self.num_reader_threads,
                buf_size=self.buf_size,
                format=self.format)

        def _batch_reader():
            batch_out = []
            for imgs, label in _reader():
                if imgs is None:
                    continue
                batch_out.append((imgs, label))
                if len(batch_out) == self.batch_size:
                    yield batch_out
                    batch_out = []

        return _batch_reader

    def _inference_reader_creator(self, video_path, mode, seg_num, seglen,
                                  short_size, target_size, img_mean, img_std):
        """Build a reader yielding the single video at ``video_path``.

        Yields ``(imgs, video_path)``, or ``(None, None)`` if the video cannot
        be decoded or has no frames.
        """

        def reader():
            try:
                imgs = mp4_loader(video_path, seg_num, seglen, mode)
            except Exception:
                logger.error('Error when loading {}'.format(video_path))
                yield None, None
                # nothing was decoded, so there is nothing to transform
                return
            if len(imgs) < 1:
                logger.error('{} frame length {} less than 1.'.format(
                    video_path, len(imgs)))
                yield None, None
                return

            imgs_ret = imgs_transform(imgs, mode, seg_num, seglen, short_size,
                                      target_size, img_mean, img_std,
                                      name=self.name)
            yield imgs_ret, video_path

        return reader

    def _reader_creator(self,
                        pickle_list,
                        mode,
                        seg_num,
                        seglen,
                        short_size,
                        target_size,
                        img_mean,
                        img_std,
                        shuffle=False,
                        num_threads=1,
                        buf_size=1024,
                        format='pkl'):
        """Build a multi-threaded sample reader over the file list.

        Args:
            pickle_list: path of the file list, one entry per line
                ("<pkl path>" for 'pkl', "<mp4 path> <label>" for 'mp4').
            mode, seg_num, seglen, short_size, target_size, img_mean,
                img_std: forwarded to the frame sampling / transform helpers.
            shuffle: not used; shuffling is driven by ``self.mode == 'train'``.
            num_threads, buf_size: worker count and queue size for
                ``fluid.io.xmap_readers``.
            format: 'pkl' or 'mp4'.

        Returns:
            a reader yielding ``(imgs, label)``; ``(None, None)`` marks a
            sample that could not be loaded.

        Raises:
            NotImplementedError: for any other ``format``.
        """

        def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size,
                       img_mean, img_std):
            # a list line is "<mp4 path> <integer label>"
            sample = sample[0].split(' ')
            mp4_path = sample[0]
            label = int(sample[1])
            try:
                imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
            except Exception:
                logger.error('Error when loading {}'.format(mp4_path))
                return None, None
            if len(imgs) < 1:
                logger.error('{} frame length {} less than 1.'.format(
                    mp4_path, len(imgs)))
                return None, None

            return imgs_transform(imgs, mode, seg_num, seglen, short_size,
                                  target_size, img_mean, img_std,
                                  name=self.name), label

        def decode_pickle(sample, mode, seg_num, seglen, short_size,
                          target_size, img_mean, img_std):
            pickle_path = sample[0]
            try:
                # NOTE: pickle can execute arbitrary code; only load trusted
                # dataset files.
                with open(pickle_path, 'rb') as f:
                    if python_ver < (3, 0):
                        data_loaded = pickle.load(f)
                    else:
                        # encoding='bytes' loads Python 2 str payloads
                        # (the encoded frames) as bytes
                        data_loaded = pickle.load(f, encoding='bytes')

                vid, label, frames = data_loaded
                if len(frames) < 1:
                    logger.error('{} frame length {} less than 1.'.format(
                        pickle_path, len(frames)))
                    return None, None
            except Exception:
                logger.error('Error when loading {}'.format(pickle_path))
                return None, None

            # inference reports the video id instead of the class label
            ret_label = vid if mode == 'infer' else label

            imgs = video_loader(frames, seg_num, seglen, mode)
            return imgs_transform(imgs, mode, seg_num, seglen, short_size,
                                  target_size, img_mean, img_std,
                                  name=self.name), ret_label

        def reader_():
            with open(pickle_list) as flist:
                full_lines = [line.strip() for line in flist]
            # reader_ doubles as the holder of the per-epoch shuffle seed
            lines = self._split_lines_for_trainer(full_lines, reader_)
            for line in lines:
                yield [line]

        if format == 'pkl':
            decode_func = decode_pickle
        elif format == 'mp4':
            decode_func = decode_mp4
        else:
            raise NotImplementedError(
                "Not implemented format {}".format(format))

        mapper = functools.partial(
            decode_func,
            mode=mode,
            seg_num=seg_num,
            seglen=seglen,
            short_size=short_size,
            target_size=target_size,
            img_mean=img_mean,
            img_std=img_std)

        return fluid.io.xmap_readers(mapper, reader_, num_threads, buf_size)

    def _split_lines_for_trainer(self, full_lines, seed_holder):
        """Shuffle (train mode only) and shard ``full_lines`` for this trainer.

        Args:
            full_lines: list of file-list entries; padded in place.
            seed_holder: object whose ``seed`` attribute persists the shuffle
                seed between calls (the reader function itself).

        Returns:
            the ``per_node_lines`` entries assigned to ``self.trainer_id``.
        """
        if self.mode == 'train':
            if not hasattr(seed_holder, 'seed'):
                seed_holder.seed = 0
            # Deterministic seed per call: trainers reading the same list get
            # the same permutation, so the strided shards below are disjoint.
            random.Random(seed_holder.seed).shuffle(full_lines)
            print("reader shuffle seed", seed_holder.seed)
            if seed_holder.seed is not None:
                seed_holder.seed += 1

        per_node_lines = int(
            math.ceil(len(full_lines) * 1.0 / self.num_trainers))
        total_lines = per_node_lines * self.num_trainers

        # Pad with entries from the start so the list divides evenly. Loop
        # because one pass is too short when num_trainers is more than twice
        # the number of entries.
        while len(full_lines) < total_lines:
            full_lines += full_lines[:(total_lines - len(full_lines))]
        assert len(full_lines) == total_lines

        # trainer get own sample
        lines = full_lines[self.trainer_id:total_lines:self.num_trainers]
        logger.info("trainerid %d, trainer_count %d" %
                    (self.trainer_id, self.num_trainers))
        logger.info(
            "read images from %d, length: %d, lines length: %d, total: %d"
            % (self.trainer_id * per_node_lines, per_node_lines,
               len(lines), len(full_lines)))
        assert len(lines) == per_node_lines
        return lines

    def build_dali_reader(self):
        """
        build dali training reader

        Writes this trainer's shard of the file list to a temporary file,
        builds a VideoPipe ('train') or VideoTestPipe (other modes) over it
        and returns a callable yielding ``(image, label)`` from the DALI
        iterator.
        """

        def reader_():
            with open(self.filelist) as flist:
                # Keep each line's newline: the lines are written back
                # verbatim. Add it to a final line that lacks one, or it would
                # merge with the next entry after shuffling.
                full_lines = [
                    line if line.endswith('\n') else line + '\n'
                    for line in flist
                ]
            lines = self._split_lines_for_trainer(full_lines, reader_)

            tf = tempfile.NamedTemporaryFile()
            tf.write(str.encode(''.join(lines)))
            tf.flush()
            # keep a reference so the temporary list is not deleted while the
            # reader (and its DALI pipeline) is alive
            self._dali_file_list = tf
            video_files = tf.name

            device_id = int(os.getenv('FLAGS_selected_gpus', 0))
            print('---------- device id -----------', device_id)

            pipe_cls = VideoPipe if self.mode == 'train' else VideoTestPipe
            pipe = pipe_cls(
                batch_size=self.batch_size,
                num_threads=1,
                device_id=device_id,
                file_list=video_files,
                sequence_length=self.seg_num * self.seglen,
                seg_num=self.seg_num,
                seg_length=self.seglen,
                resize_shorter_scale=self.short_size,
                crop_target_size=self.target_size,
                is_training=(self.mode == 'train'),
                dali_mean=self.dali_mean,
                dali_std=self.dali_std)
            logger.info(
                'initializing dataset, it will take several minutes if it is too large .... '
            )
            dali_iter = DALIGenericIterator(
                [pipe], ['image', 'label'],
                len(lines),
                dynamic_shape=True,
                auto_reset=True)

            return dali_iter

        dali_reader = reader_()

        def ret_reader():
            for data in dali_reader:
                yield data[0]['image'], data[0]['label']

        return ret_reader


class VideoPipe(Pipeline):
    """DALI training pipeline: decode, resize, random crop + mirror, normalize.

    Clips are read on the GPU by ``ops.VideoReader``, resized so the shorter
    side equals ``resize_shorter_scale``, and cropped to a
    ``crop_target_size`` square at a random position. Each clip is mirrored
    with probability 0.5, normalized with ``dali_mean``/``dali_std``, and
    reshaped to ``[seg_num, seg_length * 3, crop_target_size,
    crop_target_size]``. Labels are cast to int64.

    NOTE(review): ``seg_num``, ``seg_length`` and ``is_training`` are passed to
    ``ops.VideoReader``. They do not look like arguments of upstream DALI's
    VideoReader, so this presumably targets a customized DALI build. Confirm.
    """

    def __init__(self,
                 batch_size,
                 num_threads,
                 device_id,
                 file_list,
                 sequence_length,
                 seg_num,
                 seg_length,
                 resize_shorter_scale,
                 crop_target_size,
                 is_training=False,
                 initial_prefetch_size=10,
                 num_shards=1,
                 shard_id=0,
                 dali_mean=0.,
                 dali_std=1.0):
        # batch_size / num_threads / device_id configure the DALI Pipeline;
        # file_list is the path of a DALI file list; dali_mean / dali_std are
        # the normalization statistics for CropMirrorNormalize (one value per
        # channel of the [H, W, F*C] tensor -- presumably the per-channel
        # values repeated once per frame, confirm with the caller).
        super(VideoPipe, self).__init__(batch_size, num_threads, device_id)
        # is_training also enables random shuffling of the file list
        self.input = ops.VideoReader(
            device="gpu",
            file_list=file_list,
            sequence_length=sequence_length,
            seg_num=seg_num,
            seg_length=seg_length,
            is_training=is_training,
            num_shards=num_shards,
            shard_id=shard_id,
            random_shuffle=is_training,
            initial_fill=initial_prefetch_size)
        # The sequence data read by ops.VideoReader has shape [F, H, W, C].
        # ops.Resize does not support sequence data, so the clip is
        # transposed into [H, W, F, C], reshaped to [H, W, F*C], and then
        # resized like a single 2-D image.
        self.transpose = ops.Transpose(device="gpu", perm=[1, 2, 0, 3])
        self.reshape = ops.Reshape(
            device="gpu", rel_shape=[1.0, 1.0, -1], layout='HWC')
        self.resize = ops.Resize(
            device="gpu", resize_shorter=resize_shorter_scale)
        # Crop, mirror and normalization are all done by
        # ops.CropMirrorNormalize (mean/std are passed below). The crop
        # position (relative, in [0, 1)) and the mirror flag are drawn per
        # sample from the uniform generators.
        self.pos_rng_x = ops.Uniform(range=(0.0, 1.0))
        self.pos_rng_y = ops.Uniform(range=(0.0, 1.0))
        self.mirror_generator = ops.Uniform(range=(0.0, 1.0))
        self.cast_mirror = ops.Cast(dtype=types.DALIDataType.INT32)
        self.crop_mirror_norm = ops.CropMirrorNormalize(
            device="gpu",
            crop=[crop_target_size, crop_target_size],
            mean=dali_mean,
            std=dali_std)
        # Split the F*C channel axis back into [seg_num, seg_length * 3];
        # presumably relies on CropMirrorNormalize emitting channel-first
        # output by default -- confirm.
        self.reshape_back = ops.Reshape(
            device="gpu",
            shape=[
                seg_num, seg_length * 3, crop_target_size, crop_target_size
            ],
            layout='FCHW')
        self.cast_label = ops.Cast(device="gpu", dtype=types.DALIDataType.INT64)

    def define_graph(self):
        """Wire the ops: read -> [H, W, F*C] -> resize -> /255 -> crop/mirror/normalize -> reshape."""
        output, label = self.input(name="Reader")
        output = self.transpose(output)
        output = self.reshape(output)

        output = self.resize(output)
        # scale pixel values before normalizing (the CPU path in
        # imgs_transform also divides by 255 before applying mean/std)
        output = output / 255.
        pos_x = self.pos_rng_x()
        pos_y = self.pos_rng_y()
        # mirror with probability 0.5; CropMirrorNormalize takes an int flag
        mirror_flag = self.mirror_generator()
        mirror_flag = (mirror_flag > 0.5)
        mirror_flag = self.cast_mirror(mirror_flag)
        output = self.crop_mirror_norm(
            output, crop_pos_x=pos_x, crop_pos_y=pos_y, mirror=mirror_flag)
        output = self.reshape_back(output)
        label = self.cast_label(label)
        return output, label


class VideoTestPipe(Pipeline):
    """DALI evaluation pipeline: decode, resize, center crop, normalize.

    Same as the training pipeline except that the crop is always centered
    (``crop_pos_x = crop_pos_y = 0.5``) and no mirroring is applied. Output is
    ``[seg_num, seg_length * 3, crop_target_size, crop_target_size]`` with
    int64 labels.

    NOTE(review): ``seg_num``, ``seg_length`` and ``is_training`` are passed to
    ``ops.VideoReader``. They do not look like arguments of upstream DALI's
    VideoReader, so this presumably targets a customized DALI build. Confirm.
    """

    def __init__(self,
                 batch_size,
                 num_threads,
                 device_id,
                 file_list,
                 sequence_length,
                 seg_num,
                 seg_length,
                 resize_shorter_scale,
                 crop_target_size,
                 is_training=False,
                 initial_prefetch_size=10,
                 num_shards=1,
                 shard_id=0,
                 dali_mean=0.,
                 dali_std=1.0):
        # batch_size / num_threads / device_id configure the DALI Pipeline;
        # file_list is the path of a DALI file list; dali_mean / dali_std are
        # the normalization statistics for CropMirrorNormalize.
        super(VideoTestPipe, self).__init__(batch_size, num_threads, device_id)
        # is_training also enables random shuffling of the file list
        self.input = ops.VideoReader(
            device="gpu",
            file_list=file_list,
            sequence_length=sequence_length,
            seg_num=seg_num,
            seg_length=seg_length,
            is_training=is_training,
            num_shards=num_shards,
            shard_id=shard_id,
            random_shuffle=is_training,
            initial_fill=initial_prefetch_size)
        # The sequence data read by ops.VideoReader has shape [F, H, W, C].
        # ops.Resize does not support sequence data, so the clip is
        # transposed into [H, W, F, C], reshaped to [H, W, F*C], and then
        # resized like a single 2-D image.
        self.transpose = ops.Transpose(device="gpu", perm=[1, 2, 0, 3])
        self.reshape = ops.Reshape(
            device="gpu", rel_shape=[1.0, 1.0, -1], layout='HWC')
        self.resize = ops.Resize(
            device="gpu", resize_shorter=resize_shorter_scale)
        # Crop and normalization are done by ops.CropMirrorNormalize with a
        # fixed center crop and mirroring disabled.
        self.crop_mirror_norm = ops.CropMirrorNormalize(
            device="gpu",
            crop=[crop_target_size, crop_target_size],
            crop_pos_x=0.5,
            crop_pos_y=0.5,
            mirror=0,
            mean=dali_mean,
            std=dali_std)
        # Split the F*C channel axis back into [seg_num, seg_length * 3];
        # presumably relies on CropMirrorNormalize emitting channel-first
        # output by default -- confirm.
        self.reshape_back = ops.Reshape(
            device="gpu",
            shape=[
                seg_num, seg_length * 3, crop_target_size, crop_target_size
            ],
            layout='FCHW')
        self.cast_label = ops.Cast(device="gpu", dtype=types.DALIDataType.INT64)

    def define_graph(self):
        """Wire the ops: read -> [H, W, F*C] -> resize -> /255 -> center crop/normalize -> reshape."""
        output, label = self.input(name="Reader")
        output = self.transpose(output)
        output = self.reshape(output)

        output = self.resize(output)
        # scale pixel values before normalizing, as in the training pipeline
        output = output / 255.
        output = self.crop_mirror_norm(output)
        output = self.reshape_back(output)
        label = self.cast_label(label)
        return output, label


def imgs_transform(imgs,
                   mode,
                   seg_num,
                   seglen,
                   short_size,
                   target_size,
                   img_mean,
                   img_std,
                   name=''):
    """Resize, crop, normalize and stack a list of frames into one clip array.

    Args:
        imgs: list of ``seg_num * seglen`` RGB PIL images.
        mode: 'train' applies random crop and flip (plus multi-scale crop
            for TSM); any other mode uses a center crop.
        seg_num: number of segments in the clip.
        seglen: frames per segment.
        short_size: size passed to ``group_scale`` (and to the TSM
            multi-scale crop).
        target_size: side length of the square crop.
        img_mean, img_std: per-channel statistics broadcastable against
            ``(N, 3, target_size, target_size)``.
        name: model name; "TSM" enables multi-scale cropping in training.

    Returns:
        float32 array of shape ``(seg_num, seglen * 3, target_size, target_size)``.
    """
    imgs = group_scale(imgs, short_size)

    if mode == 'train':
        if name == "TSM":
            imgs = group_multi_scale_crop(imgs, short_size)
        imgs = group_random_crop(imgs, target_size)
        imgs = group_random_flip(imgs)
    else:
        imgs = group_center_crop(imgs, target_size)

    # HWC -> CHW float32 per frame, then one stack: concatenating frame by
    # frame re-copied the growing array each time (quadratic).
    np_imgs = np.stack([
        np.array(img).astype('float32').transpose((2, 0, 1)).reshape(
            3, target_size, target_size) for img in imgs
    ]) / 255
    np_imgs -= img_mean
    np_imgs /= img_std
    return np.reshape(np_imgs, (seg_num, seglen * 3, target_size, target_size))

def group_multi_scale_crop(img_group,
                           target_size,
                           scales=None,
                           max_distort=1,
                           fix_crop=True,
                           more_fix_crop=True):
    """Crop all frames with one randomly chosen multi-scale box, then resize.

    Candidate crop sizes are ``scales`` fractions of the shorter side of the
    first frame. A candidate is snapped to ``target_size`` when within 3 px.
    Width and height scale indices may differ by at most ``max_distort``.
    When ``fix_crop`` is False, the offset is uniformly random. Otherwise it
    is picked from a fixed grid of corner/center positions, extended with
    edge-centre and quarter positions when ``more_fix_crop`` is True. All
    frames share the same box and are resized to
    ``target_size x target_size`` with bilinear interpolation.

    Args:
        img_group: list of PIL images of identical size.
        target_size: output side length.
        scales: crop scales; defaults to ``[1, .875, .75, .66]``.
        max_distort: maximum index distance between width and height scales.
        fix_crop: use the fixed offset grid instead of random offsets.
        more_fix_crop: add the extra grid positions.

    Returns:
        list of resized PIL images.
    """
    scales = scales if scales is not None else [1, .875, .75, .66]
    input_size = [target_size, target_size]

    im_size = img_group[0].size

    # get random crop offset
    def _sample_crop_size(im_size):
        image_w, image_h = im_size[0], im_size[1]

        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in scales]
        crop_h = [
            input_size[1] if abs(x - input_size[1]) < 3 else x
            for x in crop_sizes
        ]
        crop_w = [
            input_size[0] if abs(x - input_size[0]) < 3 else x
            for x in crop_sizes
        ]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            # Floor division keeps the offsets integral: plain '/' gives
            # fractional pixel offsets on Python 3 but integers on Python 2.
            w_step = (image_w - crop_pair[0]) // 4
            h_step = (image_h - crop_pair[1]) // 4

            ret = list()
            ret.append((0, 0))  # upper left
            if w_step != 0:
                ret.append((4 * w_step, 0))  # upper right
            if h_step != 0:
                ret.append((0, 4 * h_step))  # lower left
            if h_step != 0 and w_step != 0:
                ret.append((4 * w_step, 4 * h_step))  # lower right
            if h_step != 0 or w_step != 0:
                ret.append((2 * w_step, 2 * h_step))  # center

            if more_fix_crop:
                ret.append((0, 2 * h_step))  # center left
                ret.append((4 * w_step, 2 * h_step))  # center right
                ret.append((2 * w_step, 4 * h_step))  # lower center
                ret.append((2 * w_step, 0 * h_step))  # upper center

                ret.append((1 * w_step, 1 * h_step))  # upper left quarter
                ret.append((3 * w_step, 1 * h_step))  # upper right quarter
                ret.append((1 * w_step, 3 * h_step))  # lower left quarter
                ret.append((3 * w_step, 3 * h_step))  # lower right quarter

            w_offset, h_offset = random.choice(ret)

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
    crop_img_group = [
        img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
        for img in img_group
    ]
    ret_img_group = [
        img.resize((input_size[0], input_size[1]), Image.BILINEAR)
        for img in crop_img_group
    ]

    return ret_img_group


def group_random_crop(img_group, target_size):
    """Crop the same random ``target_size`` square out of every frame.

    Args:
        img_group: list of PIL images; the size of the first one is used for
            all of them.
        target_size: side length of the square crop.

    Returns:
        a new list of cropped images. If the frames already have exactly the
        crop size, the original images are returned uncropped.

    Raises:
        AssertionError: if the first frame is smaller than ``target_size`` in
            either dimension.
    """
    w, h = img_group[0].size
    th, tw = target_size, target_size

    assert (w >= target_size) and (h >= target_size), \
          "image width({}) and height({}) should be no smaller than crop size({})".format(w, h, target_size)

    x1 = random.randint(0, w - tw)
    y1 = random.randint(0, h - th)

    # the size check only depends on the first frame, so decide once
    if w == tw and h == th:
        return list(img_group)
    box = (x1, y1, x1 + tw, y1 + th)
    return [img.crop(box) for img in img_group]


def group_random_flip(img_group):
    """Mirror every frame horizontally with probability 0.5.

    One random draw decides for the whole group, so all frames of a clip are
    treated alike. When no flip happens the input list itself is returned;
    otherwise a new list of flipped images is built.
    """
    if random.random() >= 0.5:
        return img_group
    return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in img_group]


def group_center_crop(img_group, target_size):
    """Crop a centered ``target_size`` square out of every frame.

    Args:
        img_group: list of PIL images (each frame's own size is used).
        target_size: side length of the square crop.

    Returns:
        list of cropped images.

    Raises:
        AssertionError: if a frame is smaller than ``target_size`` in either
            dimension.
    """
    th, tw = target_size, target_size
    img_crop = []
    for img in img_group:
        w, h = img.size
        assert (w >= target_size) and (h >= target_size), \
             "image width({}) and height({}) should be no smaller than crop size({})".format(w, h, target_size)
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        img_crop.append(img.crop((x1, y1, x1 + tw, y1 + th)))

    return img_crop


def group_scale(imgs, target_size):
    """Resize frames so that their shorter side becomes ``target_size``.

    Frames whose shorter side already equals ``target_size`` are kept as is.
    All other frames are resized to a fixed 3:4 shape (portrait) or 4:3 shape
    (landscape or square), whatever their original aspect ratio. The long
    side is ``int(target_size * 4 / 3)``.
    """
    out = []
    for img in imgs:
        w, h = img.size
        short_side_ok = (w <= h and w == target_size) or \
            (h <= w and h == target_size)
        if short_side_ok:
            out.append(img)
            continue
        long_side = int(target_size * 4.0 / 3.0)
        new_size = (target_size, long_side) if w < h else (long_side,
                                                           target_size)
        out.append(img.resize(new_size, Image.BILINEAR))
    return out


def imageloader(buf):
    """Decode one encoded image buffer into an RGB PIL image.

    Args:
        buf: encoded image bytes. A text buffer (``str`` on Python 3,
            ``unicode`` on Python 2) is also accepted and turned back into
            bytes with latin-1. That losslessly reverses a Python 2 pickle
            loaded with ``encoding='latin1'``.

    Returns:
        the decoded image converted to 'RGB'.
    """
    # Local import: the module-level StringIO / BytesIO names each exist on
    # only one Python version, so the Python 3 str path raised NameError.
    import io

    if isinstance(buf, type(u'')):
        buf = buf.encode('latin-1')
    img = Image.open(io.BytesIO(buf))

    return img.convert('RGB')


def video_loader(frames, nsample, seglen, mode):
    """Sample ``nsample`` clips of ``seglen`` consecutive frames.

    The frame list is split into ``nsample`` segments of
    ``int(len(frames) / nsample)`` frames each. In 'train' mode the clip
    starts at a random offset inside its segment. In any other mode the clip
    is centered in the segment. Segments shorter than ``seglen`` start at the
    segment boundary, or at index ``i`` when segments are empty. Indices past
    the end wrap around modulo the number of frames.

    Args:
        frames: sequence of encoded image buffers, one per frame.
        nsample: number of segments.
        seglen: consecutive frames taken per segment.
        mode: 'train' for random offsets, anything else for centered clips.

    Returns:
        list of ``nsample * seglen`` RGB images decoded by ``imageloader``.
    """
    num_frames = len(frames)
    seg_span = int(num_frames / nsample)
    is_train = mode == 'train'

    sampled = []
    for seg_idx in range(nsample):
        if seg_span >= seglen:
            if is_train:
                offset = random.randint(0, seg_span - seglen)
            else:
                offset = (seg_span - seglen) // 2
            start = offset + seg_idx * seg_span
        elif seg_span >= 1:
            start = seg_idx * seg_span
        else:
            start = seg_idx

        for frame_idx in range(start, start + seglen):
            sampled.append(imageloader(frames[int(frame_idx % num_frames)]))

    return sampled


def mp4_loader(filepath, nsample, seglen, mode):
    """Decode an mp4 file and sample ``nsample`` clips of ``seglen`` frames.

    Every frame is decoded first. The decoded frames are then split into
    ``nsample`` segments. In 'train' mode each clip starts at a random offset
    inside its segment. In any other mode the clip is centered in the
    segment, so it never spills into the next segment. Indices past the end
    wrap around modulo the number of decoded frames.

    Args:
        filepath: path of the video file.
        nsample: number of segments.
        seglen: consecutive frames taken per segment.
        mode: 'train' for random offsets, anything else for centered clips.

    Returns:
        list of ``nsample * seglen`` RGB PIL images, or an empty list when no
        frame could be decoded.
    """
    cap = cv2.VideoCapture(filepath)
    try:
        videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        decoded = []
        for _ in range(videolen):
            ret, frame = cap.read()
            # maybe first frame is empty
            if not ret:
                continue
            # OpenCV returns BGR frames; reverse the channel axis to get RGB
            decoded.append(frame[:, :, ::-1])
    finally:
        # release the decoder now instead of waiting for garbage collection
        cap.release()

    num_frames = len(decoded)
    if num_frames == 0:
        # callers treat an empty result as "frame length less than 1"
        return []
    average_dur = int(num_frames / nsample)

    imgs = []
    for i in range(nsample):
        idx = 0
        if mode == 'train':
            if average_dur >= seglen:
                idx = random.randint(0, average_dur - seglen)
                idx += i * average_dur
            elif average_dur >= 1:
                idx += i * average_dur
            else:
                idx = i
        else:
            if average_dur >= seglen:
                # center the seglen-frame clip within its segment
                idx = (average_dur - seglen) // 2
                idx += i * average_dur
            elif average_dur >= 1:
                idx += i * average_dur
            else:
                idx = i

        for jj in range(idx, idx + seglen):
            imgbuf = decoded[int(jj % num_frames)]
            imgs.append(Image.fromarray(imgbuf, mode='RGB'))

    return imgs