From d81baf511af4fe901320a4f11cb87e7620af59b0 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Fri, 12 Apr 2019 12:02:05 +0800 Subject: [PATCH] image-classification reader add standardization operations --- demo/image-classification/retrain.py | 7 +++++- paddlehub/reader/cv_reader.py | 37 ++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/demo/image-classification/retrain.py b/demo/image-classification/retrain.py index 1e6e47ce..52a0ea19 100644 --- a/demo/image-classification/retrain.py +++ b/demo/image-classification/retrain.py @@ -10,7 +10,11 @@ def train(): sign_name="feature_map", trainable=True) dataset = hub.dataset.Flowers() data_reader = hub.reader.ImageClassificationReader( - image_width=224, image_height=224, dataset=dataset) + image_width=resnet_module.get_excepted_image_width(), + image_height=resnet_module.get_excepted_image_height(), + images_mean=resnet_module.get_pretrained_images_mean(), + images_std=resnet_module.get_pretrained_images_std(), + dataset=dataset) with fluid.program_guard(program): label = fluid.layers.data(name="label", dtype="int64", shape=[1]) img = input_dict[0] @@ -20,6 +24,7 @@ def train(): use_cuda=True, num_epoch=10, batch_size=32, + enable_memory_optim=False, strategy=hub.finetune.strategy.DefaultFinetuneStrategy()) feed_list = [img.name, label.name] diff --git a/paddlehub/reader/cv_reader.py b/paddlehub/reader/cv_reader.py index aeeec007..9048df02 100644 --- a/paddlehub/reader/cv_reader.py +++ b/paddlehub/reader/cv_reader.py @@ -22,7 +22,7 @@ from PIL import Image import paddlehub.io.augmentation as image_augmentation -color_mode_dict = { +channel_order_dict = { "RGB": [0, 1, 2], "RBG": [0, 2, 1], "GBR": [1, 2, 0], @@ -37,16 +37,35 @@ class ImageClassificationReader(object): image_width, image_height, dataset, - color_mode="RGB", + channel_order="RGB", + images_mean=None, + images_std=None, data_augmentation=False): self.image_width = image_width self.image_height = image_height - self.color_mode = color_mode + self.channel_order = channel_order self.dataset = dataset self.data_augmentation = data_augmentation - if self.color_mode not in color_mode_dict: + self.images_std = images_std + self.images_mean = images_mean + + if self.images_mean is None: + try: + self.images_mean = self.dataset.images_mean + except: + self.images_mean = [0, 0, 0] + self.images_mean = np.array(self.images_mean).reshape(3, 1, 1) + + if self.images_std is None: + try: + self.images_std = self.dataset.images_std + except: + self.images_std = [1, 1, 1] + self.images_std = np.array(self.images_std).reshape(3, 1, 1) + + if self.channel_order not in channel_order_dict: raise ValueError( - "Color_mode should in %s." % color_mode_dict.keys()) + "The channel_order should in %s." % channel_order_dict.keys()) if self.image_width <= 0 or self.image_height <= 0: raise ValueError("Image width and height should not be negative.") @@ -74,12 +93,16 @@ class ImageClassificationReader(object): image = image.convert('RGB') # HWC to CHW - image = np.array(image) + image = np.array(image).astype('float32') if len(image.shape) == 3: image = np.swapaxes(image, 1, 2) image = np.swapaxes(image, 1, 0) - image = image[color_mode_dict[self.color_mode], :, :] + # standardization + image /= 255 + image -= self.images_mean + image /= self.images_std + image = image[channel_order_dict[self.channel_order], :, :] yield ((image, label)) return paddle.batch(_data_reader, batch_size=batch_size) -- GitLab