preprocess_img.py 5.5 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import random
import numpy as np
import PIL.Image as Image
import StringIO
import preprocess_util
from image_util import crop_img


def resize_image(img, target_size):
    """
    Resize an image so that its shorter edge has length target_size.

    Both edges are scaled by the same factor (aspect ratio is preserved) and
    rounded to the nearest integer.

    img: the input image to be resized (must expose .size and .resize).
    target_size: the desired length of the shorter edge, in pixels.
    Returns the resized image returned by img.resize.
    """
    width, height = img.size
    scale = target_size / float(min(width, height))
    resized_size = (int(round(width * scale)), int(round(height * scale)))
    # Image.ANTIALIAS was deprecated in Pillow 9.1 and removed in Pillow 10;
    # it was always an alias of LANCZOS. Prefer LANCZOS and fall back only
    # for very old PIL releases that predate that name.
    resample = getattr(Image, "LANCZOS", None)
    if resample is None:
        resample = Image.ANTIALIAS
    return img.resize(resized_size, resample)

class DiskImage:
    """
    A class of image data on disk.

    The image is read lazily: nothing is loaded until read_image() (or one of
    the convert_* methods, which call it) runs. The resized image is then
    cached in self.img and reused by later calls.
    """
    def __init__(self, path, target_size):
        """
        path: path of the image.
        target_size: target resize size (length of the shorter edge).
        """
        self.path = path
        self.target_size = target_size
        # Cached resized image; filled in by the first read_image() call.
        self.img = None

    def read_image(self):
        """
        Open the image at self.path and resize it, unless already cached.
        """
        if self.img is None:
            # A single-argument print(...) behaves the same on Python 2 and 3.
            print("reading: " + self.path)
            self.img = resize_image(Image.open(self.path), self.target_size)

    def convert_to_array(self):
        """
        Return the (resized) image as a numpy array.

        For a 3-D array the axes (0, 1, 2) are reordered to (2, 0, 1), i.e.
        the last axis is moved to the front (presumably HWC -> CHW for a PIL
        image). Arrays of any other rank are returned unchanged.
        """
        self.read_image()
        np_array = np.array(self.img)
        if np_array.ndim == 3:
            # Same result as swapaxes(1, 2) followed by swapaxes(1, 0).
            np_array = np.transpose(np_array, (2, 0, 1))
        return np_array

    def convert_to_paddle_format(self):
        """
        Convert the image into the paddle batch format.

        Returns the JPEG-encoded bytes of the resized image.
        """
        import io  # JPEG output is binary data, so buffer it in BytesIO
        self.read_image()
        output = io.BytesIO()
        try:
            self.img.save(output, "jpeg")
            return output.getvalue()
        finally:
            output.close()


class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater):
    """
    A class to process data for image classification.
    """
    def __init__(self, data_path, target_size, color=True):
        """
        data_path: the path to store the training data and batches.
        target_size: processed image size in a batch.
        color: whether to use color images.
        """
        preprocess_util.DatasetCreater.__init__(self, data_path)
        self.target_size = target_size
        self.color = color
        self.keys = ["images", "labels"]
        self.permute_key = "labels"

    def create_meta_file(self, data):
        """
        Create a meta file for image classification.

        The meta file contains the mean image, as well as some configs.
        data: the training Dataset; the first element of each entry of
            data.data must provide convert_to_array() (e.g. a DiskImage).
        Raises ValueError if data holds no samples, since the mean image is
        undefined in that case.
        """
        samples = data.data
        if len(samples) == 0:
            raise ValueError(
                "cannot compute the mean image of an empty dataset")
        output_path = os.path.join(self.data_path,
                                   self.batch_dir_name,
                                   self.meta_filename)
        if self.color:
            mean_img = np.zeros((3, self.target_size, self.target_size))
        else:
            mean_img = np.zeros((self.target_size, self.target_size))
        for d in samples:
            img = d[0].convert_to_array()
            mean_img += crop_img(img, self.target_size, self.color)
        mean_img /= len(samples)
        mean_img = mean_img.astype('int32').flatten()
        preprocess_util.save_file({"data_mean": mean_img,
                                   "image_size": self.target_size,
                                   "mean_image_size": self.target_size,
                                   "num_classes": self.num_classes,
                                   "color": self.color},
                                  output_path)

    def create_dataset_from_list(self, path):
        """
        Create a Dataset object for image classification from a list file.

        Each non-blank line of the file at `path` holds whitespace-separated
        fields "<image_path> <label_name>". Label ids are assigned in order
        of first appearance. The samples are shuffled, matching
        create_dataset_from_dir.
        path: the path of the list file.
        Returns (Dataset, label_set), where label_set maps label name -> id.
        Raises ValueError on a non-blank line with fewer than two fields.
        """
        data = []
        label_set = {}
        with open(path) as list_file:
            for line_no, line in enumerate(list_file, 1):
                items = line.split()
                if not items:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                if len(items) < 2:
                    raise ValueError(
                        "%s:%d: expected '<image_path> <label>', got %r"
                        % (path, line_no, line.rstrip()))
                image_path, label_name = items[0], items[1]
                if label_name not in label_set:
                    label_set[label_name] = len(label_set)
                img = DiskImage(path=image_path, target_size=self.target_size)
                label = preprocess_util.Label(label=label_set[label_name],
                                              name=label_name)
                data.append((img, label))
        random.shuffle(data)
        return preprocess_util.Dataset(data, self.keys), label_set

    def create_dataset_from_dir(self, path):
        """
        Create a Dataset object for image classification.

        Each folder in the path directory corresponds to a set of images of
        one label, and the name of the folder is the name of that label.
        If self.from_list is true, `path` is instead treated as a list file
        and handled by create_dataset_from_list.
        path: the path of the image dataset.
        Returns (Dataset, label_set); the samples are shuffled.
        """
        # from_list is not assigned in this class; presumably the base class
        # or the caller sets it. Default to directory mode when it is absent.
        if getattr(self, "from_list", False):
            return self.create_dataset_from_list(path)
        label_set = preprocess_util.get_label_set_from_dir(path)
        data = []
        for l_name in label_set.keys():
            image_paths = preprocess_util.list_images(os.path.join(path, l_name))
            for p in image_paths:
                img = DiskImage(path=p, target_size=self.target_size)
                label = preprocess_util.Label(label=label_set[l_name],
                                              name=l_name)
                data.append((img, label))
        random.shuffle(data)
        return preprocess_util.Dataset(data, self.keys), label_set