提交 353c3260 编写于 作者: L LielinJiang

update transform

上级 f2212a71
......@@ -51,7 +51,7 @@ class ImageNetDataset(DatasetFolder):
img_path, label = self.samples[idx]
img = cv2.imread(img_path).astype(np.float32)
label = np.array([label])
return self.transform(img, label)
return self.transform(img), label
def __len__(self):
return len(self.samples)
......@@ -150,7 +150,7 @@ class DatasetFolder(Dataset):
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample, target = self.transform(sample, target)
sample, target = self.transform(sample)
return sample, target
......
......@@ -64,10 +64,10 @@ class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *data):
def __call__(self, data):
for f in self.transforms:
try:
data = f(*data)
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
......@@ -130,8 +130,8 @@ class Resize(object):
self.size = size
self.interpolation = interpolation
def __call__(self, img, lbl):
return F.resize(img, self.size, self.interpolation), lbl
def __call__(self, img):
return F.resize(img, self.size, self.interpolation)
class RandomResizedCrop(object):
......@@ -193,10 +193,10 @@ class RandomResizedCrop(object):
y = (height - h) // 2
return x, y, w, h
def __call__(self, img, lbl):
def __call__(self, img):
x, y, w, h = self._get_params(img)
cropped_img = img[y:y + h, x:x + w]
return F.resize(cropped_img, self.output_size, self.interpolation), lbl
return F.resize(cropped_img, self.output_size, self.interpolation)
class CenterCropResize(object):
......@@ -224,10 +224,10 @@ class CenterCropResize(object):
y = (w + 1 - c) // 2
return c, x, y
def __call__(self, img, lbl):
def __call__(self, img):
c, x, y = self._get_params(img)
cropped_img = img[x:x + c, y:y + c, :]
return F.resize(cropped_img, self.size, self.interpolation), lbl
return F.resize(cropped_img, self.size, self.interpolation)
class CenterCrop(object):
......@@ -251,10 +251,10 @@ class CenterCrop(object):
y = int(round((h - th) / 2.0))
return x, y
def __call__(self, img, lbl):
def __call__(self, img):
x, y = self._get_params(img)
th, tw = self.output_size
return img[y:y + th, x:x + tw], lbl
return img[y:y + th, x:x + tw]
class RandomHorizontalFlip(object):
......@@ -267,10 +267,10 @@ class RandomHorizontalFlip(object):
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, img, lbl):
def __call__(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1), lbl
return img, lbl
return F.flip(img, code=1)
return img
class RandomVerticalFlip(object):
......@@ -283,10 +283,10 @@ class RandomVerticalFlip(object):
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, img, lbl):
def __call__(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=0), lbl
return img, lbl
return F.flip(img, code=0)
return img
class Normalize(object):
......@@ -311,8 +311,8 @@ class Normalize(object):
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def __call__(self, img, lbl):
return (img - self.mean) / self.std, lbl
def __call__(self, img):
return (img - self.mean) / self.std
class Permute(object):
......@@ -333,12 +333,12 @@ class Permute(object):
self.mode = mode
self.to_rgb = to_rgb
def __call__(self, img, lbl):
def __call__(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1)), lbl
return img, lbl
return img.transpose((2, 0, 1))
return img
class GaussianNoise(object):
......@@ -354,11 +354,11 @@ class GaussianNoise(object):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
def __call__(self, img, lbl):
def __call__(self, img):
dtype = img.dtype
noise = np.random.normal(self.mean, self.std, img.shape) * 255
img = img + noise.astype(np.float32)
return np.clip(img, 0, 255).astype(dtype), lbl
return np.clip(img, 0, 255).astype(dtype)
class BrightnessTransform(object):
......@@ -374,15 +374,15 @@ class BrightnessTransform(object):
raise ValueError("brightness value should be non-negative")
self.value = value
def __call__(self, img, lbl):
def __call__(self, img):
if self.value == 0:
return img, lbl
return img
dtype = img.dtype
img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
img = img * alpha
return img.clip(0, 255).astype(dtype), lbl
return img.clip(0, 255).astype(dtype)
class ContrastTransform(object):
......@@ -398,16 +398,16 @@ class ContrastTransform(object):
raise ValueError("contrast value should be non-negative")
self.value = value
def __call__(self, img, lbl):
def __call__(self, img):
if self.value == 0:
return img, lbl
return img
dtype = img.dtype
img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
img = img * alpha + cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean() * (
1 - alpha)
return img.clip(0, 255).astype(dtype), lbl
return img.clip(0, 255).astype(dtype)
class SaturationTransform(object):
......@@ -423,9 +423,9 @@ class SaturationTransform(object):
raise ValueError("saturation value should be non-negative")
self.value = value
def __call__(self, img, lbl):
def __call__(self, img):
if self.value == 0:
return img, lbl
return img
dtype = img.dtype
img = img.astype(np.float32)
......@@ -433,7 +433,7 @@ class SaturationTransform(object):
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = gray_img[..., np.newaxis]
img = img * alpha + gray_img * (1 - alpha)
return img.clip(0, 255).astype(dtype), lbl
return img.clip(0, 255).astype(dtype)
class HueTransform(object):
......@@ -449,9 +449,9 @@ class HueTransform(object):
raise ValueError("hue value should be in [0.0, 0.5]")
self.value = value
def __call__(self, img, lbl):
def __call__(self, img):
if self.value == 0:
return img, lbl
return img
dtype = img.dtype
img = img.astype(np.uint8)
......@@ -464,7 +464,7 @@ class HueTransform(object):
with np.errstate(over="ignore"):
h += np.uint8(alpha * 255)
hsv_img = cv2.merge([h, s, v])
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype), lbl
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
class ColorJitter(object):
......@@ -499,5 +499,5 @@ class ColorJitter(object):
random.shuffle(transforms)
self.transforms = Compose(transforms)
def __call__(self, img, lbl):
return self.transforms(img, lbl)
def __call__(self, img):
return self.transforms(img)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册