# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import random

import numpy as np
import PIL.Image as Image
import StringIO

import preprocess_util
from image_util import crop_img

# Resampling filter used by resize_image. Image.ANTIALIAS is an alias of
# Image.LANCZOS and was removed in Pillow 10; fall back to it only on old
# PIL releases that predate the LANCZOS name.
try:
    _RESAMPLE_FILTER = Image.LANCZOS
except AttributeError:
    _RESAMPLE_FILTER = Image.ANTIALIAS


def resize_image(img, target_size):
    """
    Resize an image so that its shorter edge has length target_size,
    keeping the aspect ratio.

    img: the input PIL image to be resized.
    target_size: the target length of the shorter edge, in pixels.
    Returns the resized image.
    """
    width, height = img.size
    scale = target_size / float(min(width, height))
    resized_size = (int(round(width * scale)), int(round(height * scale)))
    return img.resize(resized_size, _RESAMPLE_FILTER)


class DiskImage:
    """
    A class of image data on disk.

    The image is read and resized the first time it is needed, then cached
    in self.img for later calls.
    """

    def __init__(self, path, target_size):
        """
        path: path of the image.
        target_size: target resize size (length of the shorter edge).
        """
        self.path = path
        self.target_size = target_size
        self.img = None  # cached resized image; filled in by read_image()

    def read_image(self):
        """Load and resize the image on first use; later calls do nothing."""
        if self.img is None:
            print("reading: " + self.path)
            self.img = resize_image(Image.open(self.path), self.target_size)

    def convert_to_array(self):
        """
        Return the image as a numpy array.

        A 3-D array has its last axis moved to the front (HWC -> CHW);
        a 2-D (single-channel) array is returned unchanged.
        """
        self.read_image()
        np_array = np.array(self.img)
        if np_array.ndim == 3:
            # Same result as swapaxes(1, 2) followed by swapaxes(1, 0).
            np_array = np.transpose(np_array, (2, 0, 1))
        return np_array

    def convert_to_paddle_format(self):
        """
        Convert the image into the paddle batch format: the JPEG-encoded
        contents of the resized image.
        """
        self.read_image()
        output = StringIO.StringIO()
        try:
            self.img.save(output, "jpeg")
            return output.getvalue()
        finally:
            output.close()


class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater):
    """
    A class to process data for image classification.
    """

    def __init__(self, data_path, target_size, color=True):
        """
        data_path: the path to store the training data and batches.
        target_size: processed image size in a batch.
        color: whether to use color images.
        """
        preprocess_util.DatasetCreater.__init__(self, data_path)
        self.target_size = target_size
        self.color = color
        self.keys = ["images", "labels"]
        self.permute_key = "labels"

    def create_meta_file(self, data):
        """
        Create a meta file for image classification.

        The meta file contains the mean image, as well as some configs.
        data: the training Dataset; each entry of data.data is a tuple whose
              first element is a DiskImage.
        Raises ValueError if data contains no images.
        """
        num_images = len(data.data)
        if num_images == 0:
            # Dividing by zero would give NaNs, which astype('int32') turns
            # into meaningless values in the saved mean image.
            raise ValueError("cannot compute the mean image of an empty dataset")
        # NOTE(review): data_path, batch_dir_name, meta_filename and
        # num_classes are not assigned in this class; presumably they come
        # from preprocess_util.DatasetCreater -- confirm.
        output_path = os.path.join(self.data_path,
                                   self.batch_dir_name,
                                   self.meta_filename)
        if self.color:
            mean_img = np.zeros((3, self.target_size, self.target_size))
        else:
            mean_img = np.zeros((self.target_size, self.target_size))
        for d in data.data:
            img = d[0].convert_to_array()
            mean_img += crop_img(img, self.target_size, self.color)
        mean_img /= num_images
        mean_img = mean_img.astype('int32').flatten()
        preprocess_util.save_file({"data_mean": mean_img,
                                   "image_size": self.target_size,
                                   "mean_image_size": self.target_size,
                                   "num_classes": self.num_classes,
                                   "color": self.color},
                                  output_path)

    def create_dataset_from_list(self, path):
        """
        Create a Dataset object for image classification from a list file.

        Each non-blank line of the file holds an image path followed by its
        label name, separated by whitespace; further fields are ignored.
        Label ids are assigned in order of first appearance.

        path: the path of the list file.
        Returns a (Dataset, label_set) pair, where label_set maps each label
        name to its integer id.
        Raises ValueError on a non-blank line with fewer than two fields.
        """
        data = []
        label_set = {}
        with open(path) as list_file:
            for line_no, line in enumerate(list_file, 1):
                items = line.split()
                if not items:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(items) < 2:
                    raise ValueError(
                        "%s:%d: expected '<image_path> <label_name>', got %r"
                        % (path, line_no, line.rstrip()))
                image_path, label_name = items[0], items[1]
                if label_name not in label_set:
                    label_set[label_name] = len(label_set)
                img = DiskImage(path=image_path, target_size=self.target_size)
                label = preprocess_util.Label(label=label_set[label_name],
                                              name=label_name)
                data.append((img, label))
        return preprocess_util.Dataset(data, self.keys), label_set

    def create_dataset_from_dir(self, path):
        """
        Create a Dataset object for image classification.

        Each folder in the path directory corresponds to a set of images of
        this label, and the name of the folder is the name of the label.
        When self.from_list is truthy, path is instead treated as a list
        file and handled by create_dataset_from_list.

        path: the path of the image dataset.
        Returns a (Dataset, label_set) pair; the samples are shuffled.
        """
        # from_list is never assigned in this class; presumably the base
        # class or the caller sets it. Default to directory mode if absent.
        if getattr(self, "from_list", False):
            return self.create_dataset_from_list(path)
        label_set = preprocess_util.get_label_set_from_dir(path)
        data = []
        for l_name in label_set.keys():
            image_paths = preprocess_util.list_images(
                os.path.join(path, l_name))
            for p in image_paths:
                img = DiskImage(path=p, target_size=self.target_size)
                label = preprocess_util.Label(label=label_set[l_name],
                                              name=l_name)
                data.append((img, label))
        random.shuffle(data)
        return preprocess_util.Dataset(data, self.keys), label_set