diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index ef92fec75f3d2b143c665605b162cc318b80a4a0..fb9062fbb4f20f08f049427c3958ed3d2f861ba3 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -61,7 +61,8 @@ def default_mapper(is_train, sample): ''' img, label = sample img = load_image_bytes(img) - img = simple_transform(img, 256, 224, is_train) + img = simple_transform( + img, 256, 224, is_train, mean=[103.94, 116.78, 123, 68]) return img.flatten().astype('float32'), label diff --git a/python/paddle/v2/image.py b/python/paddle/v2/image.py index 0d648e9ae697ff0373c6cdc166608d395a8d8086..965d965335a56a97448bd8c738b03eceaee550e2 100644 --- a/python/paddle/v2/image.py +++ b/python/paddle/v2/image.py @@ -262,7 +262,12 @@ def left_right_flip(im): return im[:, ::-1, :] -def simple_transform(im, resize_size, crop_size, is_train, is_color=True): +def simple_transform(im, + resize_size, + crop_size, + is_train, + is_color=True, + mean=None): """ Simply data argumentation for training. These operations include resizing, croping and flipping. @@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True): im = left_right_flip(im) else: im = center_crop(im, crop_size) - im = to_chw(im) + if len(im.shape) == 3: + im = to_chw(im) + + im = im.astype('float32') + if mean is not None: + mean = np.array(mean, dtype=np.float32) + # mean value, may be one value per channel + if mean.ndim == 1: + mean = mean[:, np.newaxis, np.newaxis] + else: + # elementwise mean + assert len(mean.shape) == len(im) + im -= mean return im @@ -297,7 +314,8 @@ def load_and_transform(filename, resize_size, crop_size, is_train, - is_color=True): + is_color=True, + mean=None): """ Load image from the input file `filename` and transform image for data argumentation. Please refer to the `simple_transform` interface @@ -318,5 +336,5 @@ def load_and_transform(filename, :type is_train: bool """ im = load_image(filename) - im = simple_transform(im, resize_size, crop_size, is_train, is_color) + im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean) return im