# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import numpy as np
from PIL import Image, ImageOps

# Name of the dataset directory under ./data/.
DATASET = "cityscapes"

# Directory prefix shared by every list file and prepended to image paths.
IMAGES_ROOT = "./data/" + DATASET + "/"

# List files for the training splits of the two image styles.
A_LIST_FILE = IMAGES_ROOT + "trainA.txt"
B_LIST_FILE = IMAGES_ROOT + "trainB.txt"

# List files for the test splits of the two image styles.
A_TEST_LIST_FILE = IMAGES_ROOT + "testA.txt"
B_TEST_LIST_FILE = IMAGES_ROOT + "testB.txt"

import paddle.fluid as fluid


class Cityscapes(fluid.io.Dataset):
    """
    Image dataset driven by a text file that lists one image path per line.

    Args:
        root_path (str): prefix prepended verbatim to every listed path.
        file_path (str): text file listing the image paths, one per line.
        mode (str): 'train' applies resize + random crop + random flip;
            any other value applies a plain resize to 256x256.
        return_name (bool): when True, ``__getitem__`` also returns the
            base name of the image file.
    """

    def __init__(self, root_path, file_path, mode='train', return_name=False):
        self.root_path = root_path
        self.file_path = file_path
        self.mode = mode
        self.return_name = return_name
        # Use a context manager so the list file is closed deterministically,
        # and skip blank lines, which would otherwise become bogus entries
        # that fail when opened as images.
        with open(file_path, 'r') as list_file:
            self.images = [
                root_path + line for line in list_file if line.strip()
            ]

    def _train(self, image):
        """
        Apply the training-time augmentation to a PIL image.

        The image is resized to 286x286, a 256x256 window is cropped at a
        random offset, and the result is mirrored horizontally about half
        of the time.
        """
        ## Resize
        image = image.resize((286, 286), Image.BICUBIC)
        ## RandomCrop
        # np.random.randint excludes the upper bound, so offsets are 0..29.
        i = np.random.randint(0, 30)
        j = np.random.randint(0, 30)
        image = image.crop((i, j, i + 256, j + 256))
        # RandomHorizontalFlip
        if np.random.rand() > 0.5:
            image = ImageOps.mirror(image)
        return image

    def __getitem__(self, idx):
        """
        Load, preprocess and return the image at position ``idx``.

        Returns:
            ``[image]`` where ``image`` is a float32 array in channel-first
            layout scaled to [-1, 1]; when ``return_name`` is set, the tuple
            ``([image], basename)`` is returned instead.
        """
        f = self.images[idx].strip("\n\r\t ")
        # Force three channels so grayscale / palette / RGBA files do not
        # break the HWC -> CHW transpose below.
        image = Image.open(f).convert('RGB')
        if self.mode == 'train':
            image = self._train(image)
        else:
            image = image.resize((256, 256), Image.BICUBIC)
        # ToTensor
        image = np.array(image).transpose([2, 0, 1]).astype('float32')
        image = image / 255.0
        # Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]
        image = (image - 0.5) / 0.5
        if self.return_name:
            return [image], os.path.basename(f)
        else:
            return [image]

    def __len__(self):
        """Return the number of images listed for this dataset."""
        return len(self.images)


def DataA(root=IMAGES_ROOT, fpath=A_LIST_FILE):
    """
    Build the training dataset of images with A style.

    Args:
        root: prefix handed to Cityscapes as ``root_path``.
        fpath: list file handed to Cityscapes as ``file_path``.

    Returns:
        A Cityscapes instance created with its default mode and
        return_name settings.
    """
    dataset = Cityscapes(root_path=root, file_path=fpath)
    return dataset


def DataB(root=IMAGES_ROOT, fpath=B_LIST_FILE):
    """
    Build the training dataset of images with B style.

    Args:
        root: prefix handed to Cityscapes as ``root_path``.
        fpath: list file handed to Cityscapes as ``file_path``.

    Returns:
        A Cityscapes instance created with its default mode and
        return_name settings.
    """
    dataset = Cityscapes(root_path=root, file_path=fpath)
    return dataset


def TestDataA(root=IMAGES_ROOT, fpath=A_TEST_LIST_FILE):
    """
    Build the test dataset of images with A style.

    The dataset is created in 'test' mode and returns each image together
    with its file name.

    Args:
        root: prefix handed to Cityscapes as ``root_path``.
        fpath: list file handed to Cityscapes as ``file_path``.
    """
    options = dict(mode='test', return_name=True)
    return Cityscapes(root, fpath, **options)


def TestDataB(root=IMAGES_ROOT, fpath=B_TEST_LIST_FILE):
    """
    Build the test dataset of images with B style.

    The dataset is created in 'test' mode and returns each image together
    with its file name.

    Args:
        root: prefix handed to Cityscapes as ``root_path``.
        fpath: list file handed to Cityscapes as ``file_path``.
    """
    options = dict(mode='test', return_name=True)
    return Cityscapes(root, fpath, **options)


class ImagePool(object):
    """
    Fixed-capacity history buffer of previously seen images.

    Until the buffer holds ``pool_size`` items, every image passed to
    ``get`` is stored and returned unchanged. Once it is full, ``get``
    either returns the new image as-is or swaps it with a randomly chosen
    stored image and returns the stored one.

    Args:
        pool_size (int): maximum number of images kept. A value of zero
            or less disables the buffer, so ``get`` returns its input.
    """

    def __init__(self, pool_size=50):
        self.pool = []
        self.count = 0
        self.pool_size = pool_size

    def get(self, image):
        """
        Return either ``image`` or an older image from the pool.

        Args:
            image: the newly produced image.

        Returns:
            ``image`` itself, or a previously stored image that ``image``
            has just replaced in the pool.
        """
        # With no capacity there is nothing to swap with; without this guard
        # random.randint(0, pool_size - 1) would raise ValueError.
        if self.pool_size <= 0:
            return image

        # Fill phase: keep every image until the pool reaches capacity.
        if self.count < self.pool_size:
            self.pool.append(image)
            self.count += 1
            return image

        # Full pool: roughly half of the time hand back a stored image and
        # keep the new one in its slot.
        if random.random() > 0.5:
            random_id = random.randint(0, self.pool_size - 1)
            stored = self.pool[random_id]
            self.pool[random_id] = image
            return stored
        return image