提交 d81baf51 编写于 作者: W wuzewu

image-classification reader add standardization operations

上级 c8404395
...@@ -10,7 +10,11 @@ def train(): ...@@ -10,7 +10,11 @@ def train():
sign_name="feature_map", trainable=True) sign_name="feature_map", trainable=True)
dataset = hub.dataset.Flowers() dataset = hub.dataset.Flowers()
data_reader = hub.reader.ImageClassificationReader( data_reader = hub.reader.ImageClassificationReader(
image_width=224, image_height=224, dataset=dataset) image_width=resnet_module.get_excepted_image_width(),
image_height=resnet_module.get_excepted_image_height(),
images_mean=resnet_module.get_pretrained_images_mean(),
images_std=resnet_module.get_pretrained_images_std(),
dataset=dataset)
with fluid.program_guard(program): with fluid.program_guard(program):
label = fluid.layers.data(name="label", dtype="int64", shape=[1]) label = fluid.layers.data(name="label", dtype="int64", shape=[1])
img = input_dict[0] img = input_dict[0]
...@@ -20,6 +24,7 @@ def train(): ...@@ -20,6 +24,7 @@ def train():
use_cuda=True, use_cuda=True,
num_epoch=10, num_epoch=10,
batch_size=32, batch_size=32,
enable_memory_optim=False,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy()) strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
feed_list = [img.name, label.name] feed_list = [img.name, label.name]
......
...@@ -22,7 +22,7 @@ from PIL import Image ...@@ -22,7 +22,7 @@ from PIL import Image
import paddlehub.io.augmentation as image_augmentation import paddlehub.io.augmentation as image_augmentation
color_mode_dict = { channel_order_dict = {
"RGB": [0, 1, 2], "RGB": [0, 1, 2],
"RBG": [0, 2, 1], "RBG": [0, 2, 1],
"GBR": [1, 2, 0], "GBR": [1, 2, 0],
...@@ -37,16 +37,35 @@ class ImageClassificationReader(object): ...@@ -37,16 +37,35 @@ class ImageClassificationReader(object):
image_width, image_width,
image_height, image_height,
dataset, dataset,
color_mode="RGB", channel_order="RGB",
images_mean=None,
images_std=None,
data_augmentation=False): data_augmentation=False):
self.image_width = image_width self.image_width = image_width
self.image_height = image_height self.image_height = image_height
self.color_mode = color_mode self.channel_order = channel_order
self.dataset = dataset self.dataset = dataset
self.data_augmentation = data_augmentation self.data_augmentation = data_augmentation
if self.color_mode not in color_mode_dict: self.images_std = images_std
self.images_mean = images_mean
if self.images_mean is None:
try:
self.images_mean = self.dataset.images_mean
except:
self.images_mean = [0, 0, 0]
self.images_mean = np.array(self.images_mean).reshape(3, 1, 1)
if self.images_std is None:
try:
self.images_std = self.dataset.images_std
except:
self.images_std = [1, 1, 1]
self.images_std = np.array(self.images_std).reshape(3, 1, 1)
if self.channel_order not in channel_order_dict:
raise ValueError( raise ValueError(
"Color_mode should in %s." % color_mode_dict.keys()) "The channel_order should in %s." % channel_order_dict.keys())
if self.image_width <= 0 or self.image_height <= 0: if self.image_width <= 0 or self.image_height <= 0:
raise ValueError("Image width and height should not be negative.") raise ValueError("Image width and height should not be negative.")
...@@ -74,12 +93,16 @@ class ImageClassificationReader(object): ...@@ -74,12 +93,16 @@ class ImageClassificationReader(object):
image = image.convert('RGB') image = image.convert('RGB')
# HWC to CHW # HWC to CHW
image = np.array(image) image = np.array(image).astype('float32')
if len(image.shape) == 3: if len(image.shape) == 3:
image = np.swapaxes(image, 1, 2) image = np.swapaxes(image, 1, 2)
image = np.swapaxes(image, 1, 0) image = np.swapaxes(image, 1, 0)
image = image[color_mode_dict[self.color_mode], :, :] # standardization
image /= 255
image -= self.images_mean
image /= self.images_std
image = image[channel_order_dict[self.channel_order], :, :]
yield ((image, label)) yield ((image, label))
return paddle.batch(_data_reader, batch_size=batch_size) return paddle.batch(_data_reader, batch_size=batch_size)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册