# preprocess_img.py
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import random
import StringIO
import sys

import numpy as np
import PIL.Image as Image

import preprocess_util
from image_util import crop_img


def resize_image(img, target_size):
    """
    Resize an image so that the shorter edge has length target_size.

    The aspect ratio is kept: both edges are scaled by the same factor and
    rounded to the nearest pixel.

    img: the input image to be resized. Only its `size` attribute and its
         `resize` method are used.
    target_size: the target length of the shorter edge.
    return: the image returned by img.resize.
    """
    width, height = img.size[0], img.size[1]
    percent = target_size / float(min(width, height))
    resized_size = (int(round(width * percent)),
                    int(round(height * percent)))
    # Image.ANTIALIAS is an alias of Image.LANCZOS that was removed in
    # Pillow 10. Prefer LANCZOS and only fall back to the old name on PIL
    # releases that predate it.
    resample = getattr(Image, "LANCZOS", None)
    if resample is None:
        resample = Image.ANTIALIAS
    return img.resize(resized_size, resample)

Q
qijun 已提交
37

Z
zhangjinchao01 已提交
38 39 40 41
class DiskImage:
    """
    A class of image data on disk.

    The file is not touched at construction time. It is opened and resized
    the first time read_image (or one of the convert_* methods) is called,
    and the resized image is then kept in self.img.
    """

    def __init__(self, path, target_size):
        """
        path: path of the image.
        target_size: target resize size.
        """
        self.path = path
        self.target_size = target_size
        # Filled in lazily by read_image.
        self.img = None

    def read_image(self):
        """
        Load and resize the image unless it has already been loaded.
        """
        if self.img is None:
            # Parenthesised single-argument form is valid on both
            # Python 2 and Python 3.
            print("reading: " + self.path)
            self.img = resize_image(Image.open(self.path), self.target_size)

    def convert_to_array(self):
        """
        Return the image as a numpy array.

        A 3-dimensional array has its last axis moved to the front
        (presumably height x width x channel -> channel x height x width;
        the code only checks the number of dimensions). Arrays with any
        other number of dimensions are returned unchanged.
        """
        self.read_image()
        np_array = np.array(self.img)
        if np_array.ndim == 3:
            # Same result as swapaxes(1, 2) followed by swapaxes(1, 0).
            np_array = np_array.transpose(2, 0, 1)
        return np_array

    def convert_to_paddle_format(self):
        """
        convert the image into the paddle batch format.

        return: the JPEG-encoded contents of the image.
        """
        self.read_image()
        # JPEG data is binary, so collect it in a bytes buffer.
        output = io.BytesIO()
        self.img.save(output, "jpeg")
        return output.getvalue()


class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater):
    """
    A class to process data for image classification.
    """

    def __init__(self, data_path, target_size, color=True):
        """
        data_path: the path to store the training data and batches.
        target_size: processed image size in a batch.
        color: whether to use color images.
        """
        preprocess_util.DatasetCreater.__init__(self, data_path)
        self.target_size = target_size
        self.color = color
        self.keys = ["images", "labels"]
        self.permute_key = "labels"

    def create_meta_file(self, data):
        """
        Create a meta file for image classification.
        The meta file contains the mean image, as well as some configs.
        data: the training Dataset.
        """
        output_path = os.path.join(self.data_path, self.batch_dir_name,
                                   self.meta_filename)
        if self.color:
            mean_img = np.zeros((3, self.target_size, self.target_size))
        else:
            mean_img = np.zeros((self.target_size, self.target_size))
        for d in data.data:
            img = d[0].convert_to_array()
            cropped_img = crop_img(img, self.target_size, self.color)
            mean_img += cropped_img
        num_images = len(data.data)
        if num_images:
            # Skip the division for an empty dataset so the mean stays all
            # zeros instead of becoming NaN.
            mean_img /= num_images
        mean_img = mean_img.astype('int32').flatten()
        preprocess_util.save_file({
            "data_mean": mean_img,
            "image_size": self.target_size,
            "mean_image_size": self.target_size,
            "num_classes": self.num_classes,
            "color": self.color
        }, output_path)

    def create_dataset_from_list(self, path):
        """
        Create a Dataset object for image classification from a list file.
        Each non-blank line of the file holds an image path followed by the
        name of its label, separated by whitespace. Labels are numbered in
        the order in which they first appear. The data keeps the order of
        the list file (it is not shuffled here).
        path: the path of the list file.
        return: the Dataset and a dict mapping label name to label id.
        """
        data = []
        label_set = {}
        with open(path) as list_file:
            for line in list_file:
                items = line.rstrip().split()
                if not items:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                image_path = items[0]
                label_name = items[1]
                if label_name not in label_set:
                    label_set[label_name] = len(label_set)
                img = DiskImage(path=image_path, target_size=self.target_size)
                label = preprocess_util.Label(
                    label=label_set[label_name], name=label_name)
                data.append((img, label))
        return preprocess_util.Dataset(data, self.keys), label_set

    def create_dataset_from_dir(self, path):
        """
        Create a Dataset object for image classification.
        Each folder in the path directory corresponds to a set of images of
        this label, and the name of the folder is the name of the label.
        If self.from_list is set, path is instead handed to
        create_dataset_from_list.
        path: the path of the image dataset.
        return: the shuffled Dataset and the label set.
        """
        if self.from_list:
            return self.create_dataset_from_list(path)
        label_set = preprocess_util.get_label_set_from_dir(path)
        data = []
        for l_name in label_set.keys():
            image_paths = preprocess_util.list_images(
                os.path.join(path, l_name))
            for p in image_paths:
                img = DiskImage(path=p, target_size=self.target_size)
                label = preprocess_util.Label(
                    label=label_set[l_name], name=l_name)
                data.append((img, label))
        random.shuffle(data)
        return preprocess_util.Dataset(data, self.keys), label_set