# data_reader.py (2.6 KB)
# Committed by xiaoting.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from PIL import Image, ImageOps
import numpy as np

# Alternative list-file locations (inactive, kept for reference):
# A_LIST_FILE = "./train_data/trainA.txt"
# B_LIST_FILE = "./train_data/trainB.txt"
# A_TEST_LIST_FILE = "./train_data/testA.txt"
# B_TEST_LIST_FILE = "./train_data/testB.txt"
# IMAGES_ROOT = "./train_data/"

# List files for the two image domains (A and B), train and test splits.
# reader_creater prefixes every line of a list file with IMAGES_ROOT by
# plain string concatenation, so IMAGES_ROOT must keep its trailing slash.
A_LIST_FILE = "./data/cityscapes/trainA.txt"
B_LIST_FILE = "./data/cityscapes/trainB.txt"
A_TEST_LIST_FILE = "./data/cityscapes/testA.txt"
B_TEST_LIST_FILE = "./data/cityscapes/testB.txt"
IMAGES_ROOT = "./data/cityscapes/"

def image_shape():
    """Return the reader's image shape as [channels, height, width].

    A new list is built on every call, so callers may mutate it freely.
    """
    channels, height, width = 3, 256, 256
    return [channels, height, width]


def max_images_num():
    """Return the hard-coded image count 2974.

    NOTE(review): presumably the number of entries in the training list
    files; confirm against trainA.txt / trainB.txt.
    """
    num_images = 2974
    return num_images


def reader_creater(list_file, cycle=True, shuffle=True, return_name=False,
                   load_size=286, crop_size=256):
    """Build a generator function that yields preprocessed images.

    Every non-blank line of ``list_file`` is appended to ``IMAGES_ROOT`` to
    form an image path. Each image is:
      - converted to RGB;
      - resized to ``load_size`` x ``load_size`` (bicubic);
      - randomly cropped to ``crop_size`` x ``crop_size``;
      - mirrored horizontally with probability 0.5;
      - turned into a float32 channel-first (CHW) array;
      - normalized from [0, 255] to [-1, 1].

    Args:
        list_file: Text file listing image paths relative to IMAGES_ROOT,
            one per line. It is read once, when this function is called.
        cycle: If True, the generator loops over the list forever;
            otherwise it stops after a single pass.
        shuffle: If True, the path list is shuffled in place before every
            pass.
        return_name: If True, yield ``(image[np.newaxis, :], basename)``,
            i.e. the image with a leading axis of size 1 plus the file's
            base name. Otherwise yield the bare CHW image.
        load_size: Side length images are resized to before cropping.
        crop_size: Side length of the random square crop.

    Returns:
        A zero-argument generator function producing the samples above.

    Raises:
        ValueError: If ``crop_size`` is larger than ``load_size``.
    """
    if crop_size > load_size:
        raise ValueError("crop_size (%d) must not exceed load_size (%d)"
                         % (crop_size, load_size))
    # Largest valid crop offset (30 with the defaults). np.random.randint
    # excludes its upper bound, hence the +1 when sampling below.
    max_offset = load_size - crop_size

    # Read the list once and close the file. Skip blank lines, which would
    # otherwise become the bare IMAGES_ROOT directory path.
    with open(list_file, 'r') as list_fp:
        images = [(IMAGES_ROOT + line).strip("\n\r\t ")
                  for line in list_fp if line.strip()]

    def reader():
        if not images:
            # Nothing to yield; without this, cycle=True would spin forever.
            return
        while True:
            if shuffle:
                np.random.shuffle(images)
            for path in images:
                with Image.open(path) as img:
                    # Force 3 channels so the CHW transpose below also works
                    # for grayscale / palette / RGBA inputs.
                    image = img.convert('RGB')
                # Resize
                image = image.resize((load_size, load_size), Image.BICUBIC)
                # RandomCrop
                i = np.random.randint(0, max_offset + 1)
                j = np.random.randint(0, max_offset + 1)
                image = image.crop((i, j, i + crop_size, j + crop_size))
                # RandomHorizontalFlip
                if np.random.rand() > 0.5:
                    image = ImageOps.mirror(image)
                # ToTensor: HWC uint8 -> CHW float32 in [0, 1]
                image = np.array(image).transpose([2, 0, 1]).astype('float32')
                image = image / 255.0
                # Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5] -> [-1, 1]
                image = (image - 0.5) / 0.5

                if return_name:
                    yield image[np.newaxis, :], os.path.basename(path)
                else:
                    yield image
            if not cycle:
                break

    return reader


def a_reader(shuffle=True):
    """Return a cycling training reader over the domain-A images.

    Args:
        shuffle: Whether to reshuffle the image list before every pass.
    """
    list_file = A_LIST_FILE
    return reader_creater(list_file, shuffle=shuffle)


def b_reader(shuffle=True):
    """Return a cycling training reader over the domain-B images.

    Args:
        shuffle: Whether to reshuffle the image list before every pass.
    """
    list_file = B_LIST_FILE
    return reader_creater(list_file, shuffle=shuffle)


def a_test_reader():
    """Return a single-pass reader over the domain-A test images.

    The reader yields (image, file basename) pairs.
    """
    list_file = A_TEST_LIST_FILE
    return reader_creater(list_file, return_name=True, cycle=False)


def b_test_reader():
    """Return a single-pass reader over the domain-B test images.

    The reader yields (image, file basename) pairs.
    """
    list_file = B_TEST_LIST_FILE
    return reader_creater(list_file, return_name=True, cycle=False)