diff --git a/python/paddle/v2/image.py b/python/paddle/v2/image.py index a7bb22a35519b87e196b014056649f3a1bfa504a..e5000e440cc8d822dbd38dce3978d2722d32ebe4 100644 --- a/python/paddle/v2/image.py +++ b/python/paddle/v2/image.py @@ -176,7 +176,6 @@ def resize_short(im, size): :param size: the shorter edge size of image after resizing. :type size: int """ - assert im.shape[-1] == 1 or im.shape[-1] == 3 h, w = im.shape[:2] h_new, w_new = size, size if h > w: @@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True): return im -def left_right_flip(im): +def left_right_flip(im, is_color=True): """ Flip an image along the horizontal direction. Return the flipped image. @@ -278,13 +277,15 @@ def left_right_flip(im): im = left_right_flip(im) - :paam im: input image with HWC layout + :param im: input image with HWC layout or HW layout for gray image :type im: ndarray + :param is_color: whether input image is color or not + :type is_color: bool """ - if len(im.shape) == 3: + if len(im.shape) == 3 and is_color: return im[:, ::-1, :] else: - return im[:, ::-1, :] + return im[:, ::-1] def simple_transform(im, @@ -321,8 +322,9 @@ def simple_transform(im, if is_train: im = random_crop(im, crop_size, is_color=is_color) if np.random.randint(2) == 0: - im = left_right_flip(im) + im = left_right_flip(im, is_color) else: + im = center_crop(im, crop_size, is_color) im = center_crop(im, crop_size, is_color=is_color) if len(im.shape) == 3: im = to_chw(im) @@ -331,8 +333,10 @@ def simple_transform(im, if mean is not None: mean = np.array(mean, dtype=np.float32) # mean value, may be one value per channel - if mean.ndim == 1: + if mean.ndim == 1 and is_color: mean = mean[:, np.newaxis, np.newaxis] + elif mean.ndim == 1: + mean = mean else: # elementwise mean assert len(mean.shape) == len(im) @@ -372,6 +376,6 @@ def load_and_transform(filename, mean values per channel. :type mean: numpy array | list """ - im = load_image(filename) + im = load_image(filename, is_color) im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean) return im