# reader.py
import functools
import io
import math
import os
import random
import sys

try:
    import cPickle as pickle
    from cStringIO import StringIO
except ImportError:
    import pickle
    from io import BytesIO

import numpy as np
import paddle
from PIL import Image, ImageEnhance

# Seed the module-level RNG with a fixed value; it drives list shuffling,
# frame sampling, random cropping and random flipping below.
random.seed(0)

# Third and fourth positional arguments handed to
# paddle.reader.xmap_readers in _reader_creator.
THREAD = 8
BUF_SIZE = 1024

# Text files listing one pickle path per line.
TRAIN_LIST = 'data/train.list'
TEST_LIST = 'data/test.list'
# NOTE: inference currently reads the same list as testing.
INFER_LIST = 'data/test.list'

# Per-channel normalisation constants, shaped (3, 1, 1) so they broadcast
# over channel-first image arrays in decode_pickle.
# NOTE(review): these look like the usual ImageNet RGB statistics - confirm.
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

# Interpreter version, used to pick the right pickle.load call signature.
python_ver = sys.version_info

def imageloader(buf):
    """Decode an in-memory encoded image into an RGB PIL image.

    Args:
        buf: The encoded image file contents as a byte string (``str`` on
            Python 2, ``bytes`` on Python 3).

    Returns:
        The image opened with ``PIL.Image.open`` and converted to ``'RGB'``.

    Raises:
        TypeError: If ``buf`` is a text string rather than bytes
            (raised by ``io.BytesIO``).
    """
    # io.BytesIO wraps byte strings on both Python 2 and Python 3. The
    # module's conditional imports bind only one of StringIO / BytesIO per
    # interpreter, so branching between them always left one path hitting
    # an undefined name.
    return Image.open(io.BytesIO(buf)).convert('RGB')


def group_scale(imgs, target_size):
    """Resize every image so that its shorter side equals ``target_size``.

    Images whose shorter side already equals ``target_size`` are passed
    through untouched. All others are resized with bilinear interpolation
    to a fixed 3:4 / 4:3 shape: the shorter side becomes ``target_size``
    and the longer side becomes ``int(target_size * 4 / 3)``, regardless of
    the image's original aspect ratio.

    Args:
        imgs: Sequence of PIL images.
        target_size: Desired length of the shorter side, in pixels.

    Returns:
        A new list with one image per input image, in the same order.
    """
    scaled = []
    for frame in imgs:
        width, height = frame.size

        # Shorter side already matches: keep the very same image object.
        if min(width, height) == target_size:
            scaled.append(frame)
            continue

        long_side = int(target_size * 4.0 / 3.0)
        if width < height:
            new_size = (target_size, long_side)
        else:
            new_size = (long_side, target_size)
        scaled.append(frame.resize(new_size, Image.BILINEAR))

    return scaled


def group_random_crop(img_group, target_size):
    """Cut the same random square window out of every image in the group.

    The window position is drawn once, using the size of the first image,
    and then applied to all images so the frames stay spatially aligned.

    Args:
        img_group: Non-empty sequence of PIL images; the first image's
            size is used for the whole group.
        target_size: Side length of the square crop, in pixels.

    Returns:
        A new list of cropped images. If the first image is already
        ``target_size`` x ``target_size`` the original image objects are
        returned unchanged (in a new list).

    Raises:
        IndexError: If ``img_group`` is empty.
        ValueError: If the first image is smaller than ``target_size``
            in either dimension (raised by ``random.randint``).
    """
    width, height = img_group[0].size
    crop_w = crop_h = target_size

    # The offsets are always drawn, even when no crop is needed, so the
    # random stream advances identically in both cases.
    left = random.randint(0, width - crop_w)
    top = random.randint(0, height - crop_h)

    if width == crop_w and height == crop_h:
        return list(img_group)

    box = (left, top, left + crop_w, top + crop_h)
    return [frame.crop(box) for frame in img_group]


def group_random_flip(img_group):
    """Mirror the whole group left-to-right with probability one half.

    A single random number decides for the entire group, so either every
    image is flipped or none is.

    Args:
        img_group: Iterable of PIL images.

    Returns:
        A new list of horizontally flipped images, or ``img_group`` itself
        (the same object) when no flip is applied.
    """
    if random.random() >= 0.5:
        return img_group
    return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in img_group]


def group_center_crop(img_group, target_size):
    """Cut a centred ``target_size`` x ``target_size`` window from each image.

    The crop box is computed per image from that image's own size.

    Args:
        img_group: Iterable of PIL images.
        target_size: Side length of the square crop, in pixels.

    Returns:
        A new list with one cropped image per input image.
    """
    side = target_size

    def _centered_box(size):
        # Top-left corner such that the window sits in the image centre.
        width, height = size
        left = int(round((width - side) / 2.))
        top = int(round((height - side) / 2.))
        return (left, top, left + side, top + side)

    return [frame.crop(_centered_box(frame.size)) for frame in img_group]


def video_loader(frames, nsample, mode):
    """Pick ``nsample`` frames from a video and decode them.

    The video is split into ``nsample`` consecutive segments of
    ``len(frames) // nsample`` frames each and one frame is taken from
    every segment:

    * ``mode == 'train'``: a uniformly random frame within the segment.
    * any other mode: the middle frame of the segment.

    When there are fewer frames than segments, frame ``i`` is used for
    segment ``i``; indices wrap around via modulo ``len(frames)``.

    Args:
        frames: Sequence of encoded image buffers accepted by
            ``imageloader``.
        nsample: Number of frames to sample.
        mode: ``'train'`` for random sampling; anything else for
            deterministic centre sampling.

    Returns:
        List of ``nsample`` decoded images as returned by ``imageloader``.

    Raises:
        ZeroDivisionError: If ``nsample`` is 0, or if ``frames`` is empty
            while ``nsample`` is positive.
    """
    videolen = len(frames)
    average_dur = videolen // nsample

    imgs = []
    for i in range(nsample):
        if average_dur >= 1:
            if mode == 'train':
                offset = random.randint(0, average_dur - 1)
            else:
                offset = (average_dur - 1) // 2
            idx = i * average_dur + offset
        else:
            # Fewer frames than segments: fall back to the segment number.
            idx = i

        imgs.append(imageloader(frames[int(idx % videolen)]))

    return imgs


def decode_pickle(sample, mode, seg_num, short_size, target_size):
    """Load one pickled video sample and turn it into a normalised array.

    The pickle is expected to hold a ``(vid, label, frames)`` triple.
    Frames are sampled with ``video_loader``, rescaled with ``group_scale``
    and then cropped: random crop plus random flip for ``'train'``, centre
    crop for every other mode. The result is scaled to ``[0, 1]`` and
    normalised with ``img_mean`` / ``img_std``.

    Args:
        sample: Sequence whose first element is the pickle file path.
        mode: ``'train'``, ``'test'`` or ``'infer'``.
        seg_num: Number of frames to sample from the video.
        short_size: Target length of the shorter image side before cropping.
        target_size: Side length of the final square crop.

    Returns:
        ``(imgs, label)`` for ``'train'`` / ``'test'`` and ``(imgs, vid)``
        for ``'infer'``, where ``imgs`` is a float32 array of shape
        ``(seg_num, 3, H, W)`` taken from the cropped frames. ``None`` for
        any other mode.
    """
    pickle_path = sample[0]
    # NOTE(security): unpickling can execute arbitrary code; only list
    # files from trusted sources.
    with open(pickle_path, 'rb') as pickle_file:
        if python_ver < (3, 0):
            data_loaded = pickle.load(pickle_file)
        else:
            data_loaded = pickle.load(pickle_file, encoding='bytes')
    vid, label, frames = data_loaded

    imgs = video_loader(frames, seg_num, mode)
    imgs = group_scale(imgs, short_size)

    if mode == 'train':
        imgs = group_random_crop(imgs, target_size)
        imgs = group_random_flip(imgs)
    else:
        imgs = group_center_crop(imgs, target_size)

    # Stack all frames in one go, (N, H, W, C), then move channels first.
    # The shape comes from the cropped frames themselves rather than a
    # hard-coded 224, so any target_size works.
    batch = np.stack([np.array(img) for img in imgs]).astype('float32')
    batch = np.ascontiguousarray(batch.transpose((0, 3, 1, 2)))
    # In-place arithmetic keeps the array float32 although the mean/std
    # constants are float64.
    batch /= 255
    batch -= img_mean
    batch /= img_std

    if mode == 'train' or mode == 'test':
        return batch, label
    elif mode == 'infer':
        return batch, vid
    # Unknown modes yield no sample.
    return None


def _reader_creator(pickle_list,
                    mode,
                    seg_num,
                    short_size,
                    target_size,
                    shuffle=False):
    """Build a multi-worker reader over the pickles named in a list file.

    Args:
        pickle_list: Path to a text file with one pickle path per line.
        mode: Passed through to ``decode_pickle``.
        seg_num: Passed through to ``decode_pickle``.
        short_size: Passed through to ``decode_pickle``.
        target_size: Passed through to ``decode_pickle``.
        shuffle: If true, the path list is shuffled each time the reader
            is iterated.

    Returns:
        The reader produced by ``paddle.reader.xmap_readers``, which maps
        ``decode_pickle`` over the listed paths using ``THREAD`` and
        ``BUF_SIZE``.
    """

    def reader():
        # Read the whole list up front so the file is closed before any
        # sample is yielded rather than staying open during iteration.
        with open(pickle_list) as flist:
            stripped = [line.strip() for line in flist]
        # Skip blank lines (e.g. a trailing newline) instead of yielding
        # an empty path that cannot be opened.
        paths = [path for path in stripped if path]
        if shuffle:
            random.shuffle(paths)
        for pickle_path in paths:
            yield [pickle_path]

    mapper = functools.partial(
        decode_pickle,
        mode=mode,
        seg_num=seg_num,
        short_size=short_size,
        target_size=target_size)

    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)


def train(seg_num):
    """Return the training reader.

    Reads ``TRAIN_LIST`` in shuffled order with mode ``'train'``, short
    side 256 and crop size 224.

    Args:
        seg_num: Number of frames sampled per video.
    """
    options = dict(
        seg_num=seg_num, short_size=256, target_size=224, shuffle=True)
    return _reader_creator(TRAIN_LIST, 'train', **options)


def test(seg_num):
    """Return the evaluation reader.

    Reads ``TEST_LIST`` in file order with mode ``'test'``, short side 256
    and crop size 224.

    Args:
        seg_num: Number of frames sampled per video.
    """
    options = dict(
        seg_num=seg_num, short_size=256, target_size=224, shuffle=False)
    return _reader_creator(TEST_LIST, 'test', **options)


def infer(seg_num):
    """Return the inference reader.

    Reads ``INFER_LIST`` in file order with mode ``'infer'``, short side
    256 and crop size 224.

    Args:
        seg_num: Number of frames sampled per video.
    """
    options = dict(
        seg_num=seg_num, short_size=256, target_size=224, shuffle=False)
    return _reader_creator(INFER_LIST, 'infer', **options)