# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import random
import sys

import numpy as np
import PIL.Image as Image
from six.moves import cStringIO as StringIO

from . import preprocess_util
from .image_util import crop_img


def resize_image(img, target_size):
    """
    Resize an image so that the shorter edge has length target_size.

    Both edges are scaled by the same factor, so the aspect ratio is kept
    (up to rounding of the longer edge to a whole number of pixels).

    img: the input image to be resized. It must expose a two-element
         ``size`` and a ``resize(size, resample)`` method; in this module
         it is the object returned by ``PIL.Image.open``.
    target_size: the target length of the shorter edge.
    return: whatever ``img.resize`` returns for the computed size.
    """
    # Image.ANTIALIAS was removed in Pillow 10. LANCZOS is the same filter
    # under its current name; fall back only for Pillow releases that
    # predate the LANCZOS alias.
    resample = getattr(Image, "LANCZOS", None)
    if resample is None:
        resample = Image.ANTIALIAS

    first_edge, second_edge = img.size[0], img.size[1]
    # float() keeps the division exact under Python 2 integer semantics.
    percent = target_size / float(min(first_edge, second_edge))
    resized_size = (int(round(first_edge * percent)),
                    int(round(second_edge * percent)))
    return img.resize(resized_size, resample)

class DiskImage:
    """
    A class of image data on disk.

    Nothing is read when the object is created. The first call to
    read_image (directly or through one of the convert_* methods) opens
    the file, resizes it and caches the result in ``self.img``.
    """

    def __init__(self, path, target_size):
        """
        path: path of the image.
        target_size: target resize size.
        """
        self.path = path
        self.target_size = target_size
        # Filled in by read_image() on first use.
        self.img = None

    def read_image(self):
        """
        Open and resize the image if it has not been loaded yet.

        Subsequent calls are no-ops because the resized image is cached.
        """
        if self.img is None:
            print("reading: " + self.path)
            self.img = resize_image(Image.open(self.path), self.target_size)

    def convert_to_array(self):
        """
        Return the (resized) image as a numpy array.

        A 3-D array has its last axis moved to the front, i.e. an array
        indexed as [a, b, c] becomes one indexed as [c, a, b]. Arrays of
        any other rank are returned as produced by ``np.array``.
        """
        self.read_image()
        np_array = np.array(self.img)
        if np_array.ndim == 3:
            # Same permutation as swapaxes(1, 2) followed by swapaxes(1, 0).
            np_array = np.transpose(np_array, (2, 0, 1))
        return np_array

    def convert_to_paddle_format(self):
        """
        convert the image into the paddle batch format.

        return: the image encoded as JPEG, as a byte string.
        """
        self.read_image()
        # JPEG data is binary, so the buffer must accept bytes. A text
        # buffer (what six.moves.cStringIO is on Python 3) raises TypeError.
        output = io.BytesIO()
        self.img.save(output, "jpeg")
        return output.getvalue()


class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater):
    """
    A class to process data for image classification.

    Several attributes read below (data_path, batch_dir_name,
    meta_filename, num_classes, from_list) are not assigned in this class;
    they are presumably provided by preprocess_util.DatasetCreater --
    confirm against the base class.
    """

    def __init__(self, data_path, target_size, color=True):
        """
        data_path: the path to store the training data and batches.
        target_size: processed image size in a batch.
        color: whether to use color images.
        """
        preprocess_util.DatasetCreater.__init__(self, data_path)
        self.target_size = target_size
        self.color = color
        self.keys = ["images", "labels"]
        self.permute_key = "labels"

    def create_meta_file(self, data):
        """
        Create a meta file for image classification.
        The meta file contains the mean image, as well as some configs.
        data: the training Dataset; ``data.data`` is iterated and the first
              element of every entry must offer ``convert_to_array()``.
        """
        output_path = os.path.join(self.data_path, self.batch_dir_name,
                                   self.meta_filename)
        if self.color:
            mean_img = np.zeros((3, self.target_size, self.target_size))
        else:
            mean_img = np.zeros((self.target_size, self.target_size))
        for d in data.data:
            img = d[0].convert_to_array()
            cropped_img = crop_img(img, self.target_size, self.color)
            mean_img += cropped_img
        num_samples = len(data.data)
        # With no samples the accumulator is still all zeros; skipping the
        # division avoids 0/0 turning the stored mean into NaN garbage.
        if num_samples:
            mean_img /= num_samples
        mean_img = mean_img.astype('int32').flatten()
        preprocess_util.save_file({
            "data_mean": mean_img,
            "image_size": self.target_size,
            "mean_image_size": self.target_size,
            "num_classes": self.num_classes,
            "color": self.color
        }, output_path)

    def create_dataset_from_list(self, path):
        """
        Create a Dataset object for image classification from a list file.
        Each non-blank line holds an image path followed by a label name,
        separated by whitespace. Label ids are assigned in order of first
        appearance, starting from 0.
        path: the path of the list file.
        return: a (Dataset, label_set) pair, where label_set maps each
                label name to its integer id.
        """
        data = []
        # Maps label name -> integer id.
        label_set = {}
        with open(path) as list_file:
            for line in list_file:
                # split() without arguments also discards the newline.
                items = line.split()
                if not items:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                image_path = items[0]
                label_name = items[1]
                if label_name not in label_set:
                    label_set[label_name] = len(label_set)
                img = DiskImage(path=image_path, target_size=self.target_size)
                label = preprocess_util.Label(
                    label=label_set[label_name], name=label_name)
                data.append((img, label))
        return preprocess_util.Dataset(data, self.keys), label_set

    def create_dataset_from_dir(self, path):
        """
        Create a Dataset object for image classification.
        Each folder in the path directory corresponds to a set of images of
        this label, and the name of the folder is the name of the label.
        When ``self.from_list`` is true, path is instead handed to
        create_dataset_from_list.
        path: the path of the image dataset.
        return: a (Dataset, label_set) pair; the samples are shuffled.
        """
        if self.from_list:
            return self.create_dataset_from_list(path)
        label_set = preprocess_util.get_label_set_from_dir(path)
        data = []
        for l_name in list(label_set.keys()):
            image_paths = preprocess_util.list_images(
                os.path.join(path, l_name))
            for p in image_paths:
                img = DiskImage(path=p, target_size=self.target_size)
                label = preprocess_util.Label(
                    label=label_set[l_name], name=l_name)
                data.append((img, label))
        random.shuffle(data)
        return preprocess_util.Dataset(data, self.keys), label_set