From e66d91b39ebffb6f26ed6023c213879092f5bec6 Mon Sep 17 00:00:00 2001
From: JYChen
Date: Fri, 29 Apr 2022 16:37:24 +0800
Subject: [PATCH] add Tensor support colorjitter (#42382)

* add Tensor support for sub-functions of colorjitter

* add UT
---
 python/paddle/tests/test_transforms.py        |  57 ++++++
 python/paddle/vision/transforms/functional.py |  52 +++--
 .../vision/transforms/functional_tensor.py    | 186 ++++++++++++++++++
 3 files changed, 275 insertions(+), 20 deletions(-)

diff --git a/python/paddle/tests/test_transforms.py b/python/paddle/tests/test_transforms.py
index 974943a99d8..119b1037278 100644
--- a/python/paddle/tests/test_transforms.py
+++ b/python/paddle/tests/test_transforms.py
@@ -355,6 +355,10 @@ class TestTransformsTensor(TestTransformsCV2):
         trans = transforms.Compose([normalize])
         self.do_transform(trans)
 
+    def test_color_jitter(self):
+        trans = transforms.Compose([transforms.ColorJitter(1.1, 2.2, 0.8, 0.1)])
+        self.do_transform(trans)
+
     def test_pad(self):
         trans = transforms.Compose([transforms.Pad(2)])
         self.do_transform(trans)
@@ -562,6 +566,59 @@ class TestFunctional(unittest.TestCase):
             tensor_cropped_img.numpy().transpose((1, 2, 0)),
             decimal=4)
 
+    def test_color_jitter_sub_function(self):
+        np.random.seed(555)
+        np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8')
+        pil_img = Image.fromarray(np_img)
+        tensor_img = F.to_tensor(np_img)
+        np_img = pil_img
+
+        np_img_gray = (np.random.rand(28, 28, 1) * 255).astype('uint8')
+        tensor_img_gray = F.to_tensor(np_img_gray)
+
+        places = ['cpu']
+        if paddle.device.is_compiled_with_cuda():
+            places.append('gpu')
+
+        def test_adjust_brightness(np_img, tensor_img):
+            result_pil = np.array(F.adjust_brightness(np_img, 1.2))
+            result_tensor = F.adjust_brightness(tensor_img, 1.2).numpy()
+            result_tensor = np.transpose(result_tensor * 255,
+                                         (1, 2, 0)).astype('uint8')
+            np.testing.assert_equal(result_pil, result_tensor)
+
+        # For adjust_contrast / adjust_saturation / adjust_hue the implementations
+        # differ between PIL and Tensor, so the results cannot be exactly equal.
+
+        def test_adjust_contrast(np_img, tensor_img):
+            result_pil = np.array(F.adjust_contrast(np_img, 0.36))
+            result_tensor = F.adjust_contrast(tensor_img, 0.36).numpy()
+            result_tensor = np.transpose(result_tensor * 255, (1, 2, 0))
+            diff = np.max(np.abs(result_tensor - result_pil))
+            self.assertTrue(diff < 1.1)
+
+        def test_adjust_saturation(np_img, tensor_img):
+            result_pil = np.array(F.adjust_saturation(np_img, 1.0))
+            result_tensor = F.adjust_saturation(tensor_img, 1.0).numpy()
+            result_tensor = np.transpose(result_tensor * 255., (1, 2, 0))
+            diff = np.max(np.abs(result_tensor - result_pil))
+            self.assertTrue(diff < 1.1)
+
+        def test_adjust_hue(np_img, tensor_img):
+            result_pil = np.array(F.adjust_hue(np_img, 0.45))
+            result_tensor = F.adjust_hue(tensor_img, 0.45).numpy()
+            result_tensor = np.transpose(result_tensor * 255, (1, 2, 0))
+            diff = np.max(np.abs(result_tensor - result_pil))
+            self.assertTrue(diff <= 16.0)
+
+        for place in places:
+            paddle.set_device(place)
+
+            test_adjust_brightness(np_img, tensor_img)
+            test_adjust_contrast(np_img, tensor_img)
+            test_adjust_saturation(np_img, tensor_img)
+            test_adjust_hue(np_img, tensor_img)
+
     def test_pad(self):
         np_img = (np.random.rand(28, 24, 3) * 255).astype('uint8')
         pil_img = Image.fromarray(np_img)
diff --git a/python/paddle/vision/transforms/functional.py b/python/paddle/vision/transforms/functional.py
index 8caab964bf8..1afac6e48be 100644
--- a/python/paddle/vision/transforms/functional.py
+++ b/python/paddle/vision/transforms/functional.py
@@ -370,13 +370,13 @@ def adjust_brightness(img, brightness_factor):
     """Adjusts brightness of an Image.
 
     Args:
-        img (PIL.Image|np.array): Image to be adjusted.
+        img (PIL.Image|np.array|paddle.Tensor): Image to be adjusted.
         brightness_factor (float): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
             original image while 2 increases the brightness by a factor of 2.
 
     Returns:
-        PIL.Image or np.array: Brightness adjusted image.
+        PIL.Image|np.array|paddle.Tensor: Brightness adjusted image.
 
     Examples:
         .. code-block:: python
@@ -392,28 +392,31 @@ def adjust_brightness(img, brightness_factor):
             converted_img = F.adjust_brightness(fake_img, 0.4)
             print(converted_img.size)
     """
-    if not (_is_pil_image(img) or _is_numpy_image(img)):
+    if not (_is_pil_image(img) or _is_numpy_image(img) or
+            _is_tensor_image(img)):
         raise TypeError(
-            'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.
+            'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
             format(type(img)))
 
     if _is_pil_image(img):
         return F_pil.adjust_brightness(img, brightness_factor)
-    else:
+    elif _is_numpy_image(img):
         return F_cv2.adjust_brightness(img, brightness_factor)
+    else:
+        return F_t.adjust_brightness(img, brightness_factor)
 
 
 def adjust_contrast(img, contrast_factor):
     """Adjusts contrast of an Image.
 
     Args:
-        img (PIL.Image|np.array): Image to be adjusted.
+        img (PIL.Image|np.array|paddle.Tensor): Image to be adjusted.
         contrast_factor (float): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.
 
     Returns:
-        PIL.Image or np.array: Contrast adjusted image.
+        PIL.Image|np.array|paddle.Tensor: Contrast adjusted image.
 
     Examples:
         .. code-block:: python
@@ -429,28 +432,31 @@ def adjust_contrast(img, contrast_factor):
             converted_img = F.adjust_contrast(fake_img, 0.4)
             print(converted_img.size)
     """
-    if not (_is_pil_image(img) or _is_numpy_image(img)):
+    if not (_is_pil_image(img) or _is_numpy_image(img) or
+            _is_tensor_image(img)):
         raise TypeError(
-            'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.
+            'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
             format(type(img)))
 
     if _is_pil_image(img):
         return F_pil.adjust_contrast(img, contrast_factor)
-    else:
+    elif _is_numpy_image(img):
         return F_cv2.adjust_contrast(img, contrast_factor)
+    else:
+        return F_t.adjust_contrast(img, contrast_factor)
 
 
 def adjust_saturation(img, saturation_factor):
     """Adjusts color saturation of an image.
 
     Args:
-        img (PIL.Image|np.array): Image to be adjusted.
+        img (PIL.Image|np.array|paddle.Tensor): Image to be adjusted.
         saturation_factor (float): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.
 
     Returns:
-        PIL.Image or np.array: Saturation adjusted image.
+        PIL.Image|np.array|paddle.Tensor: Saturation adjusted image.
 
     Examples:
         .. code-block:: python
@@ -467,15 +473,18 @@ def adjust_saturation(img, saturation_factor):
             print(converted_img.size)
 
     """
-    if not (_is_pil_image(img) or _is_numpy_image(img)):
+    if not (_is_pil_image(img) or _is_numpy_image(img) or
+            _is_tensor_image(img)):
         raise TypeError(
-            'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.
+            'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
             format(type(img)))
 
     if _is_pil_image(img):
         return F_pil.adjust_saturation(img, saturation_factor)
-    else:
+    elif _is_numpy_image(img):
         return F_cv2.adjust_saturation(img, saturation_factor)
+    else:
+        return F_t.adjust_saturation(img, saturation_factor)
 
 
 def adjust_hue(img, hue_factor):
@@ -489,7 +498,7 @@ def adjust_hue(img, hue_factor):
     interval `[-0.5, 0.5]`.
 
     Args:
-        img (PIL.Image|np.array): Image to be adjusted.
+        img (PIL.Image|np.array|paddle.Tensor): Image to be adjusted.
         hue_factor (float): How much to shift the hue channel. Should be in
             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
             HSV space in positive and negative direction respectively.
@@ -497,7 +506,7 @@ def adjust_hue(img, hue_factor):
             with complementary colors while 0 gives the original image.
 
     Returns:
-        PIL.Image or np.array: Hue adjusted image.
+        PIL.Image|np.array|paddle.Tensor: Hue adjusted image.
 
     Examples:
         .. code-block:: python
@@ -514,15 +523,18 @@ def adjust_hue(img, hue_factor):
             print(converted_img.size)
 
     """
-    if not (_is_pil_image(img) or _is_numpy_image(img)):
+    if not (_is_pil_image(img) or _is_numpy_image(img) or
+            _is_tensor_image(img)):
         raise TypeError(
-            'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.
+            'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
             format(type(img)))
 
     if _is_pil_image(img):
         return F_pil.adjust_hue(img, hue_factor)
-    else:
+    elif _is_numpy_image(img):
         return F_cv2.adjust_hue(img, hue_factor)
+    else:
+        return F_t.adjust_hue(img, hue_factor)
 
 
 def rotate(img,
diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py
index 5e5cf465425..2d6dc125d42 100644
--- a/python/paddle/vision/transforms/functional_tensor.py
+++ b/python/paddle/vision/transforms/functional_tensor.py
@@ -86,6 +86,68 @@ def _get_image_size(img, data_format):
         _get_image_h_axis(data_format)]
 
 
+def _rgb_to_hsv(img):
+    """Convert an image Tensor from RGB to HSV. This implementation is based on Pillow (
+        https://github.com/python-pillow/Pillow/blob/main/src/libImaging/Convert.c)
+    """
+    maxc = img.max(axis=-3)
+    minc = img.min(axis=-3)
+
+    is_equal = paddle.equal(maxc, minc)
+    one_divisor = paddle.ones_like(maxc)
+    c_delta = maxc - minc
+    # s is 0 when maxc == minc, set the divisor to 1 to avoid zero divide.
+    s = c_delta / paddle.where(is_equal, one_divisor, maxc)
+
+    r, g, b = img.unbind(axis=-3)
+    c_delta_divisor = paddle.where(is_equal, one_divisor, c_delta)
+    # when maxc == minc, there is r == g == b, set the divisor to 1 to avoid zero divide.
+    rc = (maxc - r) / c_delta_divisor
+    gc = (maxc - g) / c_delta_divisor
+    bc = (maxc - b) / c_delta_divisor
+
+    hr = (maxc == r).astype(maxc.dtype) * (bc - gc)
+    hg = ((maxc == g) & (maxc != r)).astype(maxc.dtype) * (rc - bc + 2.0)
+    hb = ((maxc != r) & (maxc != g)).astype(maxc.dtype) * (gc - rc + 4.0)
+    h = (hr + hg + hb) / 6.0 + 1.0
+    h = h - h.trunc()
+    return paddle.stack([h, s, maxc], axis=-3)
+
+
+def _hsv_to_rgb(img):
+    """Convert an image Tensor from HSV to RGB.
+    """
+    h, s, v = img.unbind(axis=-3)
+    f = h * 6.0
+    i = paddle.floor(f)
+    f = f - i
+    i = i.astype(paddle.int32) % 6
+
+    p = paddle.clip(v * (1.0 - s), 0.0, 1.0)
+    q = paddle.clip(v * (1.0 - s * f), 0.0, 1.0)
+    t = paddle.clip(v * (1.0 - s * (1.0 - f)), 0.0, 1.0)
+
+    mask = paddle.equal(
+        i.unsqueeze(axis=-3),
+        paddle.arange(
+            6, dtype=i.dtype).reshape((-1, 1, 1))).astype(img.dtype)
+    matrix = paddle.stack(
+        [
+            paddle.stack(
+                [v, q, p, p, t, v], axis=-3), paddle.stack(
+                    [t, v, v, q, p, p], axis=-3), paddle.stack(
+                        [p, p, t, v, v, q], axis=-3)
+        ],
+        axis=-4)
+    return paddle.einsum("...ijk, ...xijk -> ...xjk", mask, matrix)
+
+
+def _blend_images(img1, img2, ratio):
+    max_value = 1.0 if paddle.is_floating_point(img1) else 255.0
+    return paddle.lerp(img2, img1, float(ratio)).clip(
+        0, max_value).astype(img1.dtype)
+
+
 def normalize(img, mean, std, data_format='CHW'):
     """Normalizes a tensor image given mean and standard deviation.
 
@@ -514,3 +576,127 @@ def resize(img, size, interpolation='bilinear', data_format='CHW'):
         data_format='N' + data_format.upper())
 
     return img.squeeze(0)
+
+
+def adjust_brightness(img, brightness_factor):
+    """Adjusts brightness of an Image.
+
+    Args:
+        img (paddle.Tensor): Image to be adjusted.
+        brightness_factor (float): How much to adjust the brightness. Can be
+            any non negative number. 0 gives a black image, 1 gives the
+            original image while 2 increases the brightness by a factor of 2.
+
+    Returns:
+        paddle.Tensor: Brightness adjusted image.
+
+    """
+    _assert_image_tensor(img, 'CHW')
+    assert brightness_factor >= 0, "brightness_factor should be non-negative."
+    assert _get_image_num_channels(
+        img, 'CHW') in [1, 3], "channels of input should be either 1 or 3."
+
+    extreme_target = paddle.zeros_like(img, img.dtype)
+    return _blend_images(img, extreme_target, brightness_factor)
+
+
+def adjust_contrast(img, contrast_factor):
+    """Adjusts contrast of an image.
+
+    Args:
+        img (paddle.Tensor): Image to be adjusted.
+        contrast_factor (float): How much to adjust the contrast. Can be any
+            non negative number. 0 gives a solid gray image, 1 gives the
+            original image while 2 increases the contrast by a factor of 2.
+
+    Returns:
+        paddle.Tensor: Contrast adjusted image.
+
+    """
+    _assert_image_tensor(img, 'CHW')
+    assert contrast_factor >= 0, "contrast_factor should be non-negative."
+
+    channels = _get_image_num_channels(img, 'CHW')
+    dtype = img.dtype if paddle.is_floating_point(img) else paddle.float32
+    if channels == 1:
+        extreme_target = paddle.mean(
+            img.astype(dtype), axis=(-3, -2, -1), keepdim=True)
+    elif channels == 3:
+        extreme_target = paddle.mean(
+            to_grayscale(img).astype(dtype), axis=(-3, -2, -1), keepdim=True)
+    else:
+        raise ValueError("channels of input should be either 1 or 3.")
+
+    return _blend_images(img, extreme_target, contrast_factor)
+
+
+def adjust_saturation(img, saturation_factor):
+    """Adjusts color saturation of an image.
+
+    Args:
+        img (paddle.Tensor): Image to be adjusted.
+        saturation_factor (float): How much to adjust the saturation. 0 will
+            give a black and white image, 1 will give the original image while
+            2 will enhance the saturation by a factor of 2.
+
+    Returns:
+        paddle.Tensor: Saturation adjusted image.
+
+    """
+    _assert_image_tensor(img, 'CHW')
+    assert saturation_factor >= 0, "saturation_factor should be non-negative."
+    channels = _get_image_num_channels(img, 'CHW')
+    if channels == 1:
+        return img
+    elif channels == 3:
+        extreme_target = to_grayscale(img)
+    else:
+        raise ValueError("channels of input should be either 1 or 3.")
+
+    return _blend_images(img, extreme_target, saturation_factor)
+
+
+def adjust_hue(img, hue_factor):
+    """Adjusts hue of an image.
+
+    The image hue is adjusted by converting the image to HSV and
+    cyclically shifting the intensities in the hue channel (H).
+    The image is then converted back to original image mode.
+
+    `hue_factor` is the amount of shift in H channel and must be in the
+    interval `[-0.5, 0.5]`.
+
+    Args:
+        img (paddle.Tensor): Image to be adjusted.
+        hue_factor (float): How much to shift the hue channel. Should be in
+            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+            HSV space in positive and negative direction respectively.
+            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+            with complementary colors while 0 gives the original image.
+
+    Returns:
+        paddle.Tensor: Hue adjusted image.
+
+    """
+    _assert_image_tensor(img, 'CHW')
+    assert hue_factor >= -0.5 and hue_factor <= 0.5, "hue_factor should be in range [-0.5, 0.5]"
+    channels = _get_image_num_channels(img, 'CHW')
+    if channels == 1:
+        return img
+    elif channels == 3:
+        dtype = img.dtype
+        if dtype == paddle.uint8:
+            img = img.astype(paddle.float32) / 255.0
+
+        img_hsv = _rgb_to_hsv(img)
+        h, s, v = img_hsv.unbind(axis=-3)
+        h = (h + hue_factor)
+        h = h - h.floor()
+        img_adjusted = _hsv_to_rgb(paddle.stack([h, s, v], axis=-3))
+
+        if dtype == paddle.uint8:
+            img_adjusted = (img_adjusted * 255.0).astype(dtype)
+    else:
+        raise ValueError("channels of input should be either 1 or 3.")
+
+    return img_adjusted
-- 
GitLab