# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import random
import tarfile
import functools
import numpy as np
from PIL import Image, ImageOps
import paddle
# for python2/python3 compatibility
try:
    import cPickle
except ImportError:
    import _pickle as cPickle

IMAGE_SIZE = 32
IMAGE_DEPTH = 3
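# Per-channel mean/std of CIFAR-100 images scaled to [0, 1], used for normalization.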
CIFAR_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
CIFAR_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
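# Store downloaded datasets under a local "dataset/" directory.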
paddle.dataset.common.DATA_HOME = "dataset/"


def preprocess(sample, is_training):
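    """Convert a flat CIFAR sample into a normalized CHW float32 image.

    `sample` is a 3*32*32 vector in channel-major order. When `is_training`
    is True, pad-and-crop, random horizontal flip, and random rotation
    augmentation are applied before normalization.
    """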
    image_array = sample.reshape(IMAGE_DEPTH, IMAGE_SIZE, IMAGE_SIZE)
    rgb_array = np.transpose(image_array, (1, 2, 0))
    img = Image.fromarray(rgb_array, 'RGB')

    if is_training:
        # pad, random crop, random flip left/right, random rotation
        img = ImageOps.expand(img, (4, 4, 4, 4), fill=0)
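        # after padding, the image is 40x40; take a random 32x32 crop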
        left_top = np.random.randint(8, size=2)
        img = img.crop((left_top[1], left_top[0], left_top[1] + IMAGE_SIZE,
                        left_top[0] + IMAGE_SIZE))
        if np.random.randint(2):
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        random_angle = np.random.randint(-15, 15)
        img = img.rotate(random_angle, Image.NEAREST)
    img = np.array(img).astype(np.float32)

    # scale to [0, 1], then normalize with the per-channel CIFAR statistics
    img_float = img / 255.0
    img = (img_float - CIFAR_MEAN) / CIFAR_STD

    img = np.transpose(img, (2, 0, 1))
    return img


def reader_generator(datasets, batch_size, is_training, is_shuffle):
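    """Return a batch reader over `datasets`.

    Each call to the returned reader yields [batch_data, batch_label], where
    batch_data is float32 with shape (batch_size, 3, 32, 32) and batch_label
    is int64 with shape (batch_size, 1). A trailing partial batch is dropped.
    """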
    def read_batch(datasets):
        if is_shuffle:
            random.shuffle(datasets)
        for im, label in datasets:
            im = preprocess(im, is_training)
            yield im, [int(label)]

    def reader():
        batch_data = []
        batch_label = []
        for data in read_batch(datasets):
            batch_data.append(data[0])
            batch_label.append(data[1])
            if len(batch_data) == batch_size:
                batch_data = np.array(batch_data, dtype='float32')
                batch_label = np.array(batch_label, dtype='int64')
                batch_out = [batch_data, batch_label]
                yield batch_out
                batch_data = []
                batch_label = []

    return reader


def cifar100_reader(file_name, data_name, is_shuffle):
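    """Load all pickled batches whose archive member name contains `data_name`
    ('train' or 'test') from the CIFAR-100 tarball and return a list of
    (image_vector, label) pairs, optionally shuffled.
    """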
    with tarfile.open(file_name, mode='r') as f:
        names = [
            each_item.name for each_item in f if data_name in each_item.name
        ]
        names.sort()
        datasets = []
        for name in names:
            print("Reading file " + name)
            try:
                batch = cPickle.load(f.extractfile(name), encoding='iso-8859-1')
            except TypeError:
                # Python 2's cPickle.load() has no encoding argument.
                batch = cPickle.load(f.extractfile(name))
            data = batch['data']
            labels = batch.get('labels', batch.get('fine_labels', None))
            assert labels is not None
            dataset = zip(data, labels)
            datasets.extend(dataset)
        if is_shuffle:
            random.shuffle(datasets)
    return datasets


def train_valid(batch_size, is_train, is_shuffle):
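    """Download CIFAR-100 if needed and build a batch reader over the train
    or test split, with augmentation enabled only for training.
    """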
    name = 'train' if is_train else 'test'
    datasets = cifar100_reader(
        paddle.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
        name, is_shuffle)
    reader = reader_generator(datasets, batch_size, is_train, is_shuffle)
    return reader
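

# Example usage (a minimal sketch; the batch size below is illustrative):
#
#   train_reader = train_valid(batch_size=64, is_train=True, is_shuffle=True)
#   for images, labels in train_reader():
#       # images: float32, shape (64, 3, 32, 32); labels: int64, shape (64, 1)
#       ...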