From 4106e54c50057161cba15ec273268d4a96349c3c Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Fri, 28 Aug 2020 19:59:23 +0800 Subject: [PATCH] Fix hapi transform bug (#26738) * fix bug --- .../incubate/hapi/tests/test_transforms.py | 7 +++++- .../hapi/vision/transforms/functional.py | 22 +++++++++++++++++++ .../hapi/vision/transforms/transforms.py | 2 +- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/hapi/tests/test_transforms.py b/python/paddle/incubate/hapi/tests/test_transforms.py index 087f2d1615f..84208fda1e9 100644 --- a/python/paddle/incubate/hapi/tests/test_transforms.py +++ b/python/paddle/incubate/hapi/tests/test_transforms.py @@ -64,6 +64,11 @@ class TestTransforms(unittest.TestCase): self.do_transform(trans) + def test_normalize(self): + normalize = transforms.Normalize(mean=0.5, std=0.5) + trans = transforms.Compose([transforms.Permute(mode='CHW'), normalize]) + self.do_transform(trans) + def test_trans_resize(self): trans = transforms.Compose([ transforms.Resize(300, [0, 1]), @@ -165,7 +170,7 @@ class TestTransforms(unittest.TestCase): fake_img = np.random.rand(500, 400, 3).astype('float32') fake_img_gray = trans_gray(fake_img) - np.testing.assert_equal(len(fake_img_gray.shape), 2) + np.testing.assert_equal(len(fake_img_gray.shape), 3) np.testing.assert_equal(fake_img_gray.shape[0], 500) np.testing.assert_equal(fake_img_gray.shape[1], 400) diff --git a/python/paddle/incubate/hapi/vision/transforms/functional.py b/python/paddle/incubate/hapi/vision/transforms/functional.py index f76aa6be8b4..b118ee3fc75 100644 --- a/python/paddle/incubate/hapi/vision/transforms/functional.py +++ b/python/paddle/incubate/hapi/vision/transforms/functional.py @@ -16,6 +16,7 @@ import sys import collections import random import math +import functools import cv2 import numbers @@ -31,6 +32,23 @@ else: __all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale'] +def keepdims(func): + """Keep the dimension of input images unchanged""" + + @functools.wraps(func) + def wrapper(image, *args, **kwargs): + if len(image.shape) != 3: + raise ValueError("Expect image have 3 dims, but got {} dims".format( + len(image.shape))) + ret = func(image, *args, **kwargs) + if len(ret.shape) == 2: + ret = ret[:, :, np.newaxis] + return ret + + return wrapper + + +@keepdims def flip(image, code): """ Accordding to the code (the type of flip), flip the input image @@ -62,6 +80,7 @@ def flip(image, code): return cv2.flip(image, flipCode=code) +@keepdims def resize(img, size, interpolation=cv2.INTER_LINEAR): """ resize the input data to given size @@ -103,6 +122,7 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR): return cv2.resize(img, size[::-1], interpolation=interpolation) +@keepdims def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'): """Pads the given CV Image on all sides with speficified padding mode and fill value. @@ -193,6 +213,7 @@ def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'): return img +@keepdims def rotate(img, angle, interpolation=cv2.INTER_LINEAR, @@ -266,6 +287,7 @@ def rotate(img, return dst.astype(dtype) +@keepdims def to_grayscale(img, num_output_channels=1): """Converts image to grayscale version of image. diff --git a/python/paddle/incubate/hapi/vision/transforms/transforms.py b/python/paddle/incubate/hapi/vision/transforms/transforms.py index 90c6e279959..d46faa0685a 100644 --- a/python/paddle/incubate/hapi/vision/transforms/transforms.py +++ b/python/paddle/incubate/hapi/vision/transforms/transforms.py @@ -505,7 +505,7 @@ class Normalize(object): mean = [mean, mean, mean] if isinstance(std, numbers.Number): - mean = [std, std, std] + std = [std, std, std] self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1) self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1) -- GitLab