未验证 提交 4106e54c 编写于 作者: L LielinJiang 提交者: GitHub

Fix hapi transform bug (#26738)

* fix bug
上级 64df9b99
...@@ -64,6 +64,11 @@ class TestTransforms(unittest.TestCase): ...@@ -64,6 +64,11 @@ class TestTransforms(unittest.TestCase):
self.do_transform(trans) self.do_transform(trans)
def test_normalize(self):
normalize = transforms.Normalize(mean=0.5, std=0.5)
trans = transforms.Compose([transforms.Permute(mode='CHW'), normalize])
self.do_transform(trans)
def test_trans_resize(self): def test_trans_resize(self):
trans = transforms.Compose([ trans = transforms.Compose([
transforms.Resize(300, [0, 1]), transforms.Resize(300, [0, 1]),
...@@ -165,7 +170,7 @@ class TestTransforms(unittest.TestCase): ...@@ -165,7 +170,7 @@ class TestTransforms(unittest.TestCase):
fake_img = np.random.rand(500, 400, 3).astype('float32') fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_gray = trans_gray(fake_img) fake_img_gray = trans_gray(fake_img)
np.testing.assert_equal(len(fake_img_gray.shape), 2) np.testing.assert_equal(len(fake_img_gray.shape), 3)
np.testing.assert_equal(fake_img_gray.shape[0], 500) np.testing.assert_equal(fake_img_gray.shape[0], 500)
np.testing.assert_equal(fake_img_gray.shape[1], 400) np.testing.assert_equal(fake_img_gray.shape[1], 400)
......
...@@ -16,6 +16,7 @@ import sys ...@@ -16,6 +16,7 @@ import sys
import collections import collections
import random import random
import math import math
import functools
import cv2 import cv2
import numbers import numbers
...@@ -31,6 +32,23 @@ else: ...@@ -31,6 +32,23 @@ else:
__all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale'] __all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale']
def keepdims(func):
"""Keep the dimension of input images unchanged"""
@functools.wraps(func)
def wrapper(image, *args, **kwargs):
if len(image.shape) != 3:
raise ValueError("Expect image have 3 dims, but got {} dims".format(
len(image.shape)))
ret = func(image, *args, **kwargs)
if len(ret.shape) == 2:
ret = ret[:, :, np.newaxis]
return ret
return wrapper
@keepdims
def flip(image, code): def flip(image, code):
""" """
Accordding to the code (the type of flip), flip the input image Accordding to the code (the type of flip), flip the input image
...@@ -62,6 +80,7 @@ def flip(image, code): ...@@ -62,6 +80,7 @@ def flip(image, code):
return cv2.flip(image, flipCode=code) return cv2.flip(image, flipCode=code)
@keepdims
def resize(img, size, interpolation=cv2.INTER_LINEAR): def resize(img, size, interpolation=cv2.INTER_LINEAR):
""" """
resize the input data to given size resize the input data to given size
...@@ -103,6 +122,7 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR): ...@@ -103,6 +122,7 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
return cv2.resize(img, size[::-1], interpolation=interpolation) return cv2.resize(img, size[::-1], interpolation=interpolation)
@keepdims
def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'): def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'):
"""Pads the given CV Image on all sides with speficified padding mode and fill value. """Pads the given CV Image on all sides with speficified padding mode and fill value.
...@@ -193,6 +213,7 @@ def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'): ...@@ -193,6 +213,7 @@ def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'):
return img return img
@keepdims
def rotate(img, def rotate(img,
angle, angle,
interpolation=cv2.INTER_LINEAR, interpolation=cv2.INTER_LINEAR,
...@@ -266,6 +287,7 @@ def rotate(img, ...@@ -266,6 +287,7 @@ def rotate(img,
return dst.astype(dtype) return dst.astype(dtype)
@keepdims
def to_grayscale(img, num_output_channels=1): def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image. """Converts image to grayscale version of image.
......
...@@ -505,7 +505,7 @@ class Normalize(object): ...@@ -505,7 +505,7 @@ class Normalize(object):
mean = [mean, mean, mean] mean = [mean, mean, mean]
if isinstance(std, numbers.Number): if isinstance(std, numbers.Number):
mean = [std, std, std] std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1) self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1) self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册