未验证 提交 4106e54c 编写于 作者: L LielinJiang 提交者: GitHub

Fix hapi transform bug (#26738)

* fix bug
上级 64df9b99
......@@ -64,6 +64,11 @@ class TestTransforms(unittest.TestCase):
self.do_transform(trans)
def test_normalize(self):
normalize = transforms.Normalize(mean=0.5, std=0.5)
trans = transforms.Compose([transforms.Permute(mode='CHW'), normalize])
self.do_transform(trans)
def test_trans_resize(self):
trans = transforms.Compose([
transforms.Resize(300, [0, 1]),
......@@ -165,7 +170,7 @@ class TestTransforms(unittest.TestCase):
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_gray = trans_gray(fake_img)
np.testing.assert_equal(len(fake_img_gray.shape), 2)
np.testing.assert_equal(len(fake_img_gray.shape), 3)
np.testing.assert_equal(fake_img_gray.shape[0], 500)
np.testing.assert_equal(fake_img_gray.shape[1], 400)
......
......@@ -16,6 +16,7 @@ import sys
import collections
import random
import math
import functools
import cv2
import numbers
......@@ -31,6 +32,23 @@ else:
__all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale']
def keepdims(func):
"""Keep the dimension of input images unchanged"""
@functools.wraps(func)
def wrapper(image, *args, **kwargs):
if len(image.shape) != 3:
raise ValueError("Expect image have 3 dims, but got {} dims".format(
len(image.shape)))
ret = func(image, *args, **kwargs)
if len(ret.shape) == 2:
ret = ret[:, :, np.newaxis]
return ret
return wrapper
@keepdims
def flip(image, code):
"""
Accordding to the code (the type of flip), flip the input image
......@@ -62,6 +80,7 @@ def flip(image, code):
return cv2.flip(image, flipCode=code)
@keepdims
def resize(img, size, interpolation=cv2.INTER_LINEAR):
"""
resize the input data to given size
......@@ -103,6 +122,7 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
return cv2.resize(img, size[::-1], interpolation=interpolation)
@keepdims
def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'):
"""Pads the given CV Image on all sides with speficified padding mode and fill value.
......@@ -193,6 +213,7 @@ def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'):
return img
@keepdims
def rotate(img,
angle,
interpolation=cv2.INTER_LINEAR,
......@@ -266,6 +287,7 @@ def rotate(img,
return dst.astype(dtype)
@keepdims
def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image.
......
......@@ -505,7 +505,7 @@ class Normalize(object):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
mean = [std, std, std]
std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册