提交 e845a1d5 编写于 作者: Q qingqing01

Enhance image.py for gray image.

上级 57528920
...@@ -176,7 +176,6 @@ def resize_short(im, size): ...@@ -176,7 +176,6 @@ def resize_short(im, size):
:param size: the shorter edge size of image after resizing. :param size: the shorter edge size of image after resizing.
:type size: int :type size: int
""" """
assert im.shape[-1] == 1 or im.shape[-1] == 3
h, w = im.shape[:2] h, w = im.shape[:2]
h_new, w_new = size, size h_new, w_new = size, size
if h > w: if h > w:
...@@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True): ...@@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True):
return im return im
def left_right_flip(im): def left_right_flip(im, is_color=True):
""" """
Flip an image along the horizontal direction. Flip an image along the horizontal direction.
Return the flipped image. Return the flipped image.
...@@ -278,13 +277,16 @@ def left_right_flip(im): ...@@ -278,13 +277,16 @@ def left_right_flip(im):
im = left_right_flip(im) im = left_right_flip(im)
:paam im: input image with HWC layout :paam im: input image with HWC layout or HW layout for gray image
:type im: ndarray :type im: ndarray
:paam is_color: whether color input image or not
:type is_color: bool
""" """
if len(im.shape) == 3: if len(im.shape) == 3 and is_color:
return im[:, ::-1, :] return im[:, ::-1, :]
else: else:
return im[:, ::-1, :] return im[:, ::-1]
def simple_transform(im, def simple_transform(im,
...@@ -319,11 +321,11 @@ def simple_transform(im, ...@@ -319,11 +321,11 @@ def simple_transform(im,
""" """
im = resize_short(im, resize_size) im = resize_short(im, resize_size)
if is_train: if is_train:
im = random_crop(im, crop_size) im = random_crop(im, crop_size, is_color)
if np.random.randint(2) == 0: if np.random.randint(2) == 0:
im = left_right_flip(im) im = left_right_flip(im, is_color)
else: else:
im = center_crop(im, crop_size) im = center_crop(im, crop_size, is_color)
if len(im.shape) == 3: if len(im.shape) == 3:
im = to_chw(im) im = to_chw(im)
...@@ -331,8 +333,10 @@ def simple_transform(im, ...@@ -331,8 +333,10 @@ def simple_transform(im,
if mean is not None: if mean is not None:
mean = np.array(mean, dtype=np.float32) mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel # mean value, may be one value per channel
if mean.ndim == 1: if mean.ndim == 1 and is_color:
mean = mean[:, np.newaxis, np.newaxis] mean = mean[:, np.newaxis, np.newaxis]
elif mean.ndim == 1:
mean = mean
else: else:
# elementwise mean # elementwise mean
assert len(mean.shape) == len(im) assert len(mean.shape) == len(im)
...@@ -372,6 +376,6 @@ def load_and_transform(filename, ...@@ -372,6 +376,6 @@ def load_and_transform(filename,
mean values per channel. mean values per channel.
:type mean: numpy array | list :type mean: numpy array | list
""" """
im = load_image(filename) im = load_image(filename, is_color)
im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean) im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
return im return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册