reader.py 5.7 KB
Newer Older
W
wangmeng28 已提交
1
import os
W
wangmeng28 已提交
2
import math
W
wangmeng28 已提交
3 4 5
import random
import functools
import numpy as np
6
import paddle
W
wangmeng28 已提交
7 8 9
from PIL import Image, ImageEnhance

# Seed both RNGs at import time so list shuffling and augmentation draws are
# repeatable. NOTE(review): in distributed mode `_reader_creator` shards the
# shuffled list per trainer, which presumably relies on every trainer getting
# the same shuffle from this seed; augmentation threads share the same global
# numpy RNG, so confirm later epochs still shuffle identically.
random.seed(0)
np.random.seed(0)

# Side length, in pixels, of the square crop fed to the network.
DATA_DIM = 224

# Number of mapper workers passed to paddle.reader.xmap_readers.
THREAD = 8
# Buffer size passed to paddle.reader.xmap_readers.
BUF_SIZE = 102400

# Default dataset root; list files and image paths are resolved against it.
DATA_DIR = 'data/ILSVRC2012'

# Per-channel mean / std in RGB order, shaped (3, 1, 1) so they broadcast over
# CHW images whose pixel values have been scaled to [0, 1].
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
W
wangmeng28 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37


def resize_short(img, target_size):
    """Resize ``img`` so that its shorter side becomes ``target_size``.

    The aspect ratio is preserved; both sides are scaled by the same factor
    and rounded to the nearest pixel.

    Args:
        img: PIL image.
        target_size: desired length of the shorter side, in pixels.

    Returns:
        The resized PIL image (Lanczos resampling).
    """
    src_w, src_h = img.size
    factor = float(target_size) / min(src_w, src_h)
    new_size = (int(round(src_w * factor)), int(round(src_h * factor)))
    return img.resize(new_size, Image.LANCZOS)


def crop_image(img, target_size, center):
    """Cut a square ``target_size`` x ``target_size`` patch out of ``img``.

    Args:
        img: PIL image (only ``.size`` and ``.crop`` are used).
        target_size: side length of the patch, in pixels.
        center: if truthy, take the centered patch; otherwise pick the
            top-left corner uniformly at random with ``np.random``.

    Returns:
        The cropped image.
    """
    width, height = img.size
    size = target_size
    if center:
        # Floor division keeps the crop box on integer pixel coordinates;
        # plain ``/`` yields floats under Python 3.
        w_start = (width - size) // 2
        h_start = (height - size) // 2
    else:
        # randint's upper bound is exclusive, hence the +1. numpy raises
        # ValueError when the image is smaller than ``size`` on an axis.
        w_start = np.random.randint(0, width - size + 1)
        h_start = np.random.randint(0, height - size + 1)
    w_end = w_start + size
    h_end = h_start + size
    return img.crop((w_start, h_start, w_end, h_end))


W
wangmeng28 已提交
46
def random_crop(img, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
    """Random-resized crop: random area and aspect ratio, then resize.

    An aspect ratio (width / height) is drawn from ``ratio`` and an area
    fraction from ``scale``. Both are clamped so the crop fits inside ``img``.
    The region is cut at a random position and resized to ``size`` x ``size``.

    Args:
        img: PIL image.
        size: side length of the output image, in pixels.
        scale: (min, max) fraction of the image area to keep. A tuple so the
            default is immutable; any 2-element sequence is accepted.
        ratio: (min, max) range for the crop's width / height ratio.

    Returns:
        The cropped and resized PIL image (Lanczos resampling).
    """
    aspect_ratio = math.sqrt(np.random.uniform(*ratio))
    w_factor = 1. * aspect_ratio
    h_factor = 1. / aspect_ratio

    img_w, img_h = img.size
    # Largest area fraction for which a crop with this aspect ratio still
    # fits inside the image along both axes.
    bound = min((float(img_w) / img_h) / (w_factor**2),
                (float(img_h) / img_w) / (h_factor**2))
    scale_max = min(scale[1], bound)
    scale_min = min(scale[0], bound)

    target_area = img_w * img_h * np.random.uniform(scale_min, scale_max)
    target_size = math.sqrt(target_area)
    # Clamp to [1, image side]: on very small images int() truncation can
    # otherwise produce a zero-sized crop that cannot be resized.
    crop_w = min(img_w, max(1, int(target_size * w_factor)))
    crop_h = min(img_h, max(1, int(target_size * h_factor)))

    # randint's upper bound is exclusive, hence the +1.
    left = np.random.randint(0, img_w - crop_w + 1)
    top = np.random.randint(0, img_h - crop_h + 1)

    img = img.crop((left, top, left + crop_w, top + crop_h))
    return img.resize((size, size), Image.LANCZOS)


70
def rotate_image(img, max_angle=10):
    """Rotate ``img`` by a random integer angle in [-max_angle, max_angle].

    Args:
        img: PIL image (only ``.rotate`` is used).
        max_angle: largest absolute rotation, in degrees. The default keeps
            the previously hard-coded range of +/-10.

    Returns:
        The rotated image.
    """
    # randint's upper bound is exclusive, so +1 makes max_angle reachable.
    angle = np.random.randint(-max_angle, max_angle + 1)
    return img.rotate(angle)


W
wangmeng28 已提交
76 77
def distort_color(img):
    """Randomly jitter the brightness, contrast and color balance of ``img``.

    The three enhancements are applied in an order shuffled by
    ``np.random.shuffle``. Each uses a factor drawn from
    ``np.random.uniform(0.5, 1.5)``.

    Args:
        img: PIL image.

    Returns:
        The enhanced PIL image.
    """

    def random_brightness(img, lower=0.5, upper=1.5):
        e = np.random.uniform(lower, upper)
        return ImageEnhance.Brightness(img).enhance(e)

    def random_contrast(img, lower=0.5, upper=1.5):
        e = np.random.uniform(lower, upper)
        return ImageEnhance.Contrast(img).enhance(e)

    def random_color(img, lower=0.5, upper=1.5):
        e = np.random.uniform(lower, upper)
        return ImageEnhance.Color(img).enhance(e)

    ops = [random_brightness, random_contrast, random_color]
    np.random.shuffle(ops)
    for op in ops:
        img = op(img)
    return img


W
wangmeng28 已提交
99
def process_image(sample, mode, color_jitter, rotate):
    """Load one sample from disk and turn it into a normalized CHW array.

    Args:
        sample: sequence whose first element is the image path; for 'train'
            and 'val' the second element is the label.
        mode: 'train' (random crop, optional rotation/color jitter, random
            horizontal flip), 'val' or 'test' (resize + center crop).
        color_jitter: in 'train' mode, apply ``distort_color``.
        rotate: in 'train' mode, apply ``rotate_image`` before cropping.

    Returns:
        ``(img, label)`` for 'train'/'val', or ``[img]`` for 'test'. ``img``
        is a float32 array of shape (3, DATA_DIM, DATA_DIM) normalized with
        ``img_mean`` / ``img_std``.

    Raises:
        ValueError: if ``mode`` is not 'train', 'val' or 'test'.
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError("unknown mode: %r" % (mode, ))
    img_path = sample[0]

    # Open via a context manager so the file handle is released as soon as
    # this sample is done, rather than whenever the image is collected.
    with Image.open(img_path) as src:
        img = src
        if mode == 'train':
            if rotate:
                img = rotate_image(img)
            img = random_crop(img, DATA_DIM)
            if color_jitter:
                img = distort_color(img)
            # Horizontal flip with probability 1/2.
            if np.random.randint(0, 2) == 1:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
        else:
            img = resize_short(img, target_size=256)
            img = crop_image(img, target_size=DATA_DIM, center=True)

        # Non-RGB inputs (e.g. grayscale or CMYK) are converted so the
        # array below always has three channels.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # HWC uint8 -> CHW float32 scaled to [0, 1].
        img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255

    img -= img_mean
    img /= img_std

    if mode == 'test':
        return [img]
    return img, sample[1]
W
wangmeng28 已提交
126 127


W
wangmeng28 已提交
128 129 130 131
def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
                    rotate=False,
                    data_dir=DATA_DIR):
    """Create a paddle reader over the images listed in ``file_list``.

    Every non-blank line of ``file_list`` holds an image path relative to
    ``data_dir``. For 'train'/'val' the path is followed by an integer
    label. In 'test' mode only the first token is used as the path, so a
    labelled list (``test()`` reads ``val_list.txt``) does not leak its
    labels into the file name. Paths are resolved the same way in every mode.

    Args:
        file_list: path of the list file.
        mode: 'train', 'val' or 'test'; forwarded to ``process_image``.
        shuffle: shuffle the lines with ``np.random`` each time the reader
            is iterated.
        color_jitter: forwarded to ``process_image``.
        rotate: forwarded to ``process_image``.
        data_dir: directory the listed paths are relative to.

    Returns:
        A reader built by ``paddle.reader.xmap_readers`` that maps
        ``process_image`` over the listed samples with ``THREAD`` workers
        and a ``BUF_SIZE`` buffer.
    """

    def _resolve(rel_path):
        # NOTE(review): this replaces every "JPEG" in the path, not only the
        # extension. Presumably the files on disk use a lowercase "jpeg"
        # spelling; confirm against the dataset layout.
        return os.path.join(data_dir, rel_path.replace("JPEG", "jpeg"))

    def reader():
        # Read the whole list and close the file before yielding, so it is
        # not held open for the entire pass. Blank lines (e.g. a trailing
        # empty line) are dropped because they cannot be parsed.
        with open(file_list) as flist:
            full_lines = [
                stripped for stripped in (line.strip() for line in flist)
                if stripped
            ]
        if shuffle:
            np.random.shuffle(full_lines)
        if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
            # Distributed mode if the env var `PADDLE_TRAINING_ROLE` exists:
            # each trainer takes its own contiguous, equally sized shard.
            # The last len(full_lines) % trainer_count lines are skipped.
            # Shards are only disjoint if all trainers shuffled identically.
            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
            trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
            per_node_lines = len(full_lines) // trainer_count
            start = trainer_id * per_node_lines
            lines = full_lines[start:start + per_node_lines]
            print(
                "read images from %d, length: %d, lines length: %d, total: %d"
                % (start, per_node_lines, len(lines), len(full_lines)))
        else:
            lines = full_lines

        for line in lines:
            if mode == 'train' or mode == 'val':
                img_path, label = line.split()
                yield _resolve(img_path), int(label)
            elif mode == 'test':
                yield [_resolve(line.split()[0])]

    mapper = functools.partial(
        process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)

    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)


Y
Yancey1989 已提交
169 170
def train(data_dir=DATA_DIR):
    """Return the training reader over ``data_dir/train_list.txt``.

    Lines are shuffled on every pass; color jitter and rotation
    augmentation are disabled.
    """
    file_list = os.path.join(data_dir, 'train_list.txt')
    return _reader_creator(
        file_list,
        'train',
        shuffle=True,
        color_jitter=False,
        rotate=False,
        data_dir=data_dir)
W
wangmeng28 已提交
173 174


Y
Yancey1989 已提交
175 176 177
def val(data_dir=DATA_DIR):
    """Return the unshuffled validation reader over ``data_dir/val_list.txt``."""
    return _reader_creator(
        os.path.join(data_dir, 'val_list.txt'),
        'val',
        shuffle=False,
        data_dir=data_dir)
W
wangmeng28 已提交
178 179


Y
Yancey1989 已提交
180 181 182
def test(data_dir=DATA_DIR):
    """Return an unshuffled 'test'-mode reader.

    There is no separate test list: the reader reuses
    ``data_dir/val_list.txt``. Samples are produced without labels.
    """
    return _reader_creator(
        os.path.join(data_dir, 'val_list.txt'),
        'test',
        shuffle=False,
        data_dir=data_dir)