# reader.py
import functools
import math
import os
import random
import sys
from io import BytesIO

try:
    import cPickle as pickle
    from cStringIO import StringIO
except ImportError:
    import pickle
    from io import BytesIO

import numpy as np
import paddle
from PIL import Image, ImageEnhance

# Seed both random generators at import time so that list shuffling and the
# random crop/flip/frame choices below are repeatable from run to run.
random.seed(0)
np.random.seed(0)

# Passed, in this order, as the last two arguments of
# paddle.reader.xmap_readers in _reader_creator.
THREAD = 8
BUF_SIZE = 1024

# List files; each line is the path of one pickled video sample.
TRAIN_LIST = 'data/train.list'
TEST_LIST = 'data/test.list'
# Inference reads the same list file as testing.
INFER_LIST = 'data/test.list'

# Per-channel normalisation constants, shaped (3, 1, 1) so they broadcast
# against the (N, 3, H, W) frame array built in decode_pickle.
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

# Interpreter version; decode_pickle uses it to choose the pickle.load call.
python_ver = sys.version_info
def imageloader(buf):
    """Decode an in-memory encoded image into an RGB PIL image.

    Args:
        buf: The encoded image as a byte string (``str`` on Python 2,
            ``bytes`` on Python 3).

    Returns:
        The image opened from ``buf`` and converted to mode ``'RGB'``.

    Raises:
        TypeError: If ``buf`` is text rather than bytes (raised by
            ``BytesIO``).
    """
    # BytesIO accepts the byte string on both Python 2 and Python 3, so no
    # per-version stream class is needed. It is built before touching PIL so
    # that a wrong input type fails immediately with a TypeError.
    stream = BytesIO(buf)
    return Image.open(stream).convert('RGB')


def group_scale(imgs, target_size):
    """Resize every image so that its shorter side equals ``target_size``.

    An image whose shorter side already equals ``target_size`` is passed
    through untouched. Any other image is resized with bilinear filtering
    to ``target_size`` by ``int(target_size * 4 / 3)``: the width gets
    ``target_size`` when the image is taller than wide, otherwise the
    height does. The source aspect ratio is not consulted.

    Args:
        imgs: Sequence of objects exposing ``size`` and ``resize`` the way
            PIL images do.
        target_size: Length wanted for the shorter side.

    Returns:
        A new list with one entry per input image, in the same order.
    """
    scaled = []
    for frame in imgs:
        width, height = frame.size
        if min(width, height) == target_size:
            # Shorter side is already right; keep the very same object.
            scaled.append(frame)
            continue

        long_side = int(target_size * 4.0 / 3.0)
        if width < height:
            new_size = (target_size, long_side)
        else:
            new_size = (long_side, target_size)
        scaled.append(frame.resize(new_size, Image.BILINEAR))

    return scaled


def group_random_crop(img_group, target_size):
    """Cut the same random square window out of every image in the group.

    The window position is drawn once, from the size of the first image,
    and reused for all images. When the first image is already exactly
    ``target_size`` square, the images are returned uncropped (the two
    random draws still happen).

    Args:
        img_group: Indexable sequence of objects exposing ``size`` and
            ``crop`` the way PIL images do.
        target_size: Side length of the square window.

    Returns:
        A new list holding one (possibly cropped) image per input image.
    """
    width, height = img_group[0].size

    # Horizontal offset first, then vertical, so the random stream is
    # consumed in a fixed order.
    left = random.randint(0, width - target_size)
    top = random.randint(0, height - target_size)

    if width == target_size and height == target_size:
        return list(img_group)

    box = (left, top, left + target_size, top + target_size)
    return [frame.crop(box) for frame in img_group]


def group_random_flip(img_group):
    """Mirror the whole group left-to-right with probability one half.

    Args:
        img_group: Iterable of objects exposing ``transpose`` the way PIL
            images do.

    Returns:
        A new list of flipped images when the random draw is below 0.5,
        otherwise ``img_group`` itself, unchanged.
    """
    if random.random() >= 0.5:
        return img_group
    return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in img_group]


def group_center_crop(img_group, target_size):
    """Cut a centred ``target_size`` square out of every image.

    Each image gets its own window, computed from its own size.

    Args:
        img_group: Iterable of objects exposing ``size`` and ``crop`` the
            way PIL images do.
        target_size: Side length of the square window.

    Returns:
        A new list with one cropped image per input image.
    """

    def _centered_box(frame):
        # Top-left corner is half of the leftover margin on each axis.
        width, height = frame.size
        left = int(round((width - target_size) / 2.))
        top = int(round((height - target_size) / 2.))
        return (left, top, left + target_size, top + target_size)

    return [frame.crop(_centered_box(frame)) for frame in img_group]


def video_loader(frames, nsample, mode):
    """Pick ``nsample`` frames spread across a video and decode them.

    The frame list is split into ``nsample`` consecutive segments of
    ``len(frames) // nsample`` frames and one frame is taken from each:
    a random one when ``mode == 'train'``, the middle one for any other
    mode. When there are fewer frames than segments, segment ``i`` uses
    frame ``i``, wrapping around the end of the list.

    Args:
        frames: Sequence of encoded frame buffers accepted by
            ``imageloader``.
        nsample: Number of frames to pick; must be at least 1.
        mode: ``'train'`` for random picks, anything else for fixed picks.

    Returns:
        A list of ``nsample`` decoded images, in segment order.

    Raises:
        ValueError: If ``nsample`` is smaller than 1 or ``frames`` is empty.
    """
    videolen = len(frames)
    # Fail with a clear message instead of a ZeroDivisionError from the
    # floor division / modulo below.
    if nsample < 1:
        raise ValueError('nsample must be >= 1, got %r' % (nsample,))
    if videolen == 0:
        raise ValueError('frames is empty; cannot sample from it')

    average_dur = videolen // nsample

    imgs = []
    for i in range(nsample):
        if average_dur < 1:
            # Fewer frames than segments: one frame index per segment.
            idx = i
        elif mode == 'train':
            idx = i * average_dur + random.randint(0, average_dur - 1)
        else:
            idx = i * average_dur + (average_dur - 1) // 2

        # The modulo only matters in the "fewer frames than segments" case.
        imgs.append(imageloader(frames[idx % videolen]))

    return imgs


def decode_pickle(sample, mode, seg_num, short_size, target_size):
    """Load one pickled video sample and turn it into a normalised array.

    The pickle is expected to unpack into ``(vid, label, frames)``.
    ``seg_num`` frames are picked with ``video_loader``, rescaled with
    ``group_scale``, then randomly cropped and flipped when
    ``mode == 'train'`` or centre-cropped otherwise. The frames are then
    stacked as float32, divided by 255 and normalised with ``img_mean``
    and ``img_std``.

    Args:
        sample: Sequence whose first item is the path of the pickle file.
        mode: ``'train'``, ``'test'``, ``'train_ce'`` or ``'infer'``.
        seg_num: Number of frames to sample from the video.
        short_size: Shorter-side length handed to ``group_scale``.
        target_size: Side length of the square crop.

    Returns:
        ``(imgs, label)`` for ``'train'``, ``'test'`` and ``'train_ce'``,
        ``(imgs, vid)`` for ``'infer'``, and ``None`` for any other mode.
        ``imgs`` has shape ``(seg_num, 3, target_size, target_size)``.
    """
    pickle_path = sample[0]
    # NOTE(security): unpickling can execute arbitrary code; the list files
    # must only reference pickles from a trusted source.
    # The context manager closes the file even if unpickling fails.
    with open(pickle_path, 'rb') as pickle_file:
        if python_ver < (3, 0):
            data_loaded = pickle.load(pickle_file)
        else:
            data_loaded = pickle.load(pickle_file, encoding='bytes')
    vid, label, frames = data_loaded

    imgs = video_loader(frames, seg_num, mode)
    imgs = group_scale(imgs, short_size)

    if mode == 'train':
        imgs = group_random_crop(imgs, target_size)
        imgs = group_random_flip(imgs)
    else:
        imgs = group_center_crop(imgs, target_size)

    # Height-width-channel -> channel-height-width for each frame, using the
    # requested crop size rather than a fixed 224, then join all frames with
    # a single concatenate.
    np_imgs = np.concatenate([
        np.array(img).astype('float32').transpose((2, 0, 1)).reshape(
            1, 3, target_size, target_size) / 255 for img in imgs
    ])
    np_imgs -= img_mean
    np_imgs /= img_std

    if mode == 'train' or mode == 'test' or mode == 'train_ce':
        return np_imgs, label
    elif mode == 'infer':
        return np_imgs, vid


def _reader_creator(pickle_list,
                    mode,
                    seg_num,
                    short_size,
                    target_size,
                    shuffle=False):
    """Build a reader that decodes every sample named in a list file.

    Args:
        pickle_list: Path of a text file with one pickle path per line.
        mode: Mode string forwarded to ``decode_pickle``.
        seg_num: Number of frames per video, forwarded to ``decode_pickle``.
        short_size: Shorter-side length, forwarded to ``decode_pickle``.
        target_size: Crop size, forwarded to ``decode_pickle``.
        shuffle: When true, the paths are shuffled each time the list is
            read.

    Returns:
        The object returned by ``paddle.reader.xmap_readers`` for the
        path generator and the ``decode_pickle`` mapper.
    """
    decode_kwargs = {
        'mode': mode,
        'seg_num': seg_num,
        'short_size': short_size,
        'target_size': target_size,
    }
    mapper = functools.partial(decode_pickle, **decode_kwargs)

    def reader():
        # Yields one single-item list per line; decode_pickle reads item 0.
        with open(pickle_list) as list_file:
            paths = [raw.strip() for raw in list_file]
            if shuffle:
                random.shuffle(paths)
            for path in paths:
                yield [path]

    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)


def train(seg_num):
    """Return the shuffled ``'train'``-mode reader over ``TRAIN_LIST``.

    Args:
        seg_num: Number of frames sampled from each video.

    Returns:
        The reader built by ``_reader_creator`` with a shorter side of 256
        and a crop size of 224.
    """
    return _reader_creator(TRAIN_LIST, 'train', seg_num, 256, 224,
                           shuffle=True)


def test(seg_num):
    """Return the unshuffled ``'test'``-mode reader over ``TEST_LIST``.

    Args:
        seg_num: Number of frames sampled from each video.

    Returns:
        The reader built by ``_reader_creator`` with a shorter side of 256
        and a crop size of 224.
    """
    return _reader_creator(TEST_LIST, 'test', seg_num, 256, 224,
                           shuffle=False)


def infer(seg_num):
    """Return the unshuffled ``'infer'``-mode reader over ``INFER_LIST``.

    Args:
        seg_num: Number of frames sampled from each video.

    Returns:
        The reader built by ``_reader_creator`` with a shorter side of 256
        and a crop size of 224.
    """
    return _reader_creator(INFER_LIST, 'infer', seg_num, 256, 224,
                           shuffle=False)


def train_ce(seg_num):
    """Return the unshuffled ``'train_ce'``-mode reader over ``TRAIN_LIST``.

    Args:
        seg_num: Number of frames sampled from each video.

    Returns:
        The reader built by ``_reader_creator`` with a shorter side of 256
        and a crop size of 224.
    """
    return _reader_creator(TRAIN_LIST, 'train_ce', seg_num, 256, 224,
                           shuffle=False)