diff --git a/python/paddle/v2/image.py b/python/paddle/v2/image.py index 1429d6b1e08fe4ab2d1c5a0f19f1cedbcbc85abd..de4c8abbc3b3c1f37bbb36f509e1c0f2dbe9a44e 100644 --- a/python/paddle/v2/image.py +++ b/python/paddle/v2/image.py @@ -176,7 +176,6 @@ def resize_short(im, size): :param size: the shorter edge size of image after resizing. :type size: int """ - assert im.shape[-1] == 1 or im.shape[-1] == 3 h, w = im.shape[:2] h_new, w_new = size, size if h > w: @@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True): return im -def left_right_flip(im): +def left_right_flip(im, is_color=True): """ Flip an image along the horizontal direction. Return the flipped image. @@ -278,13 +277,16 @@ def left_right_flip(im): im = left_right_flip(im) - :paam im: input image with HWC layout + :paam im: input image with HWC layout or HW layout for gray image + :type im: ndarray + :paam is_color: whether color input image or not + :type is_color: bool """ - if len(im.shape) == 3: + if len(im.shape) == 3 and is_color: return im[:, ::-1, :] else: - return im[:, ::-1, :] + return im[:, ::-1] def simple_transform(im, @@ -319,11 +321,11 @@ def simple_transform(im, """ im = resize_short(im, resize_size) if is_train: - im = random_crop(im, crop_size) + im = random_crop(im, crop_size, is_color) if np.random.randint(2) == 0: - im = left_right_flip(im) + im = left_right_flip(im, is_color) else: - im = center_crop(im, crop_size) + im = center_crop(im, crop_size, is_color) if len(im.shape) == 3: im = to_chw(im) @@ -331,8 +333,10 @@ def simple_transform(im, if mean is not None: mean = np.array(mean, dtype=np.float32) # mean value, may be one value per channel - if mean.ndim == 1: + if mean.ndim == 1 and is_color: mean = mean[:, np.newaxis, np.newaxis] + elif mean.ndim == 1: + mean = mean else: # elementwise mean assert len(mean.shape) == len(im) @@ -372,6 +376,6 @@ def load_and_transform(filename, mean values per channel. :type mean: numpy array | list """ - im = load_image(filename) + im = load_image(filename, is_color) im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean) return im