提交 353c3260 作者: LielinJiang

update transform: transforms now take only the image (label is no longer passed through)

上级 f2212a71
...@@ -51,7 +51,7 @@ class ImageNetDataset(DatasetFolder): ...@@ -51,7 +51,7 @@ class ImageNetDataset(DatasetFolder):
img_path, label = self.samples[idx] img_path, label = self.samples[idx]
img = cv2.imread(img_path).astype(np.float32) img = cv2.imread(img_path).astype(np.float32)
label = np.array([label]) label = np.array([label])
return self.transform(img, label) return self.transform(img), label
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
...@@ -150,7 +150,7 @@ class DatasetFolder(Dataset): ...@@ -150,7 +150,7 @@ class DatasetFolder(Dataset):
path, target = self.samples[index] path, target = self.samples[index]
sample = self.loader(path) sample = self.loader(path)
if self.transform is not None: if self.transform is not None:
sample, target = self.transform(sample, target) sample, target = self.transform(sample)
return sample, target return sample, target
......
...@@ -64,10 +64,10 @@ class Compose(object): ...@@ -64,10 +64,10 @@ class Compose(object):
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
def __call__(self, *data): def __call__(self, data):
for f in self.transforms: for f in self.transforms:
try: try:
data = f(*data) data = f(data)
except Exception as e: except Exception as e:
stack_info = traceback.format_exc() stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: " print("fail to perform transform [{}] with error: "
...@@ -130,8 +130,8 @@ class Resize(object): ...@@ -130,8 +130,8 @@ class Resize(object):
self.size = size self.size = size
self.interpolation = interpolation self.interpolation = interpolation
def __call__(self, img, lbl): def __call__(self, img):
return F.resize(img, self.size, self.interpolation), lbl return F.resize(img, self.size, self.interpolation)
class RandomResizedCrop(object): class RandomResizedCrop(object):
...@@ -193,10 +193,10 @@ class RandomResizedCrop(object): ...@@ -193,10 +193,10 @@ class RandomResizedCrop(object):
y = (height - h) // 2 y = (height - h) // 2
return x, y, w, h return x, y, w, h
def __call__(self, img, lbl): def __call__(self, img):
x, y, w, h = self._get_params(img) x, y, w, h = self._get_params(img)
cropped_img = img[y:y + h, x:x + w] cropped_img = img[y:y + h, x:x + w]
return F.resize(cropped_img, self.output_size, self.interpolation), lbl return F.resize(cropped_img, self.output_size, self.interpolation)
class CenterCropResize(object): class CenterCropResize(object):
...@@ -224,10 +224,10 @@ class CenterCropResize(object): ...@@ -224,10 +224,10 @@ class CenterCropResize(object):
y = (w + 1 - c) // 2 y = (w + 1 - c) // 2
return c, x, y return c, x, y
def __call__(self, img, lbl): def __call__(self, img):
c, x, y = self._get_params(img) c, x, y = self._get_params(img)
cropped_img = img[x:x + c, y:y + c, :] cropped_img = img[x:x + c, y:y + c, :]
return F.resize(cropped_img, self.size, self.interpolation), lbl return F.resize(cropped_img, self.size, self.interpolation)
class CenterCrop(object): class CenterCrop(object):
...@@ -251,10 +251,10 @@ class CenterCrop(object): ...@@ -251,10 +251,10 @@ class CenterCrop(object):
y = int(round((h - th) / 2.0)) y = int(round((h - th) / 2.0))
return x, y return x, y
def __call__(self, img, lbl): def __call__(self, img):
x, y = self._get_params(img) x, y = self._get_params(img)
th, tw = self.output_size th, tw = self.output_size
return img[y:y + th, x:x + tw], lbl return img[y:y + th, x:x + tw]
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
...@@ -267,10 +267,10 @@ class RandomHorizontalFlip(object): ...@@ -267,10 +267,10 @@ class RandomHorizontalFlip(object):
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
def __call__(self, img, lbl): def __call__(self, img):
if np.random.random() < self.prob: if np.random.random() < self.prob:
return F.flip(img, code=1), lbl return F.flip(img, code=1)
return img, lbl return img
class RandomVerticalFlip(object): class RandomVerticalFlip(object):
...@@ -283,10 +283,10 @@ class RandomVerticalFlip(object): ...@@ -283,10 +283,10 @@ class RandomVerticalFlip(object):
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
def __call__(self, img, lbl): def __call__(self, img):
if np.random.random() < self.prob: if np.random.random() < self.prob:
return F.flip(img, code=0), lbl return F.flip(img, code=0)
return img, lbl return img
class Normalize(object): class Normalize(object):
...@@ -311,8 +311,8 @@ class Normalize(object): ...@@ -311,8 +311,8 @@ class Normalize(object):
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1) self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1) self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def __call__(self, img, lbl): def __call__(self, img):
return (img - self.mean) / self.std, lbl return (img - self.mean) / self.std
class Permute(object): class Permute(object):
...@@ -333,12 +333,12 @@ class Permute(object): ...@@ -333,12 +333,12 @@ class Permute(object):
self.mode = mode self.mode = mode
self.to_rgb = to_rgb self.to_rgb = to_rgb
def __call__(self, img, lbl): def __call__(self, img):
if self.to_rgb: if self.to_rgb:
img = img[..., ::-1] img = img[..., ::-1]
if self.mode == "CHW": if self.mode == "CHW":
return img.transpose((2, 0, 1)), lbl return img.transpose((2, 0, 1))
return img, lbl return img
class GaussianNoise(object): class GaussianNoise(object):
...@@ -354,11 +354,11 @@ class GaussianNoise(object): ...@@ -354,11 +354,11 @@ class GaussianNoise(object):
self.mean = np.array(mean, dtype=np.float32) self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32) self.std = np.array(std, dtype=np.float32)
def __call__(self, img, lbl): def __call__(self, img):
dtype = img.dtype dtype = img.dtype
noise = np.random.normal(self.mean, self.std, img.shape) * 255 noise = np.random.normal(self.mean, self.std, img.shape) * 255
img = img + noise.astype(np.float32) img = img + noise.astype(np.float32)
return np.clip(img, 0, 255).astype(dtype), lbl return np.clip(img, 0, 255).astype(dtype)
class BrightnessTransform(object): class BrightnessTransform(object):
...@@ -374,15 +374,15 @@ class BrightnessTransform(object): ...@@ -374,15 +374,15 @@ class BrightnessTransform(object):
raise ValueError("brightness value should be non-negative") raise ValueError("brightness value should be non-negative")
self.value = value self.value = value
def __call__(self, img, lbl): def __call__(self, img):
if self.value == 0: if self.value == 0:
return img, lbl return img
dtype = img.dtype dtype = img.dtype
img = img.astype(np.float32) img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
img = img * alpha img = img * alpha
return img.clip(0, 255).astype(dtype), lbl return img.clip(0, 255).astype(dtype)
class ContrastTransform(object): class ContrastTransform(object):
...@@ -398,16 +398,16 @@ class ContrastTransform(object): ...@@ -398,16 +398,16 @@ class ContrastTransform(object):
raise ValueError("contrast value should be non-negative") raise ValueError("contrast value should be non-negative")
self.value = value self.value = value
def __call__(self, img, lbl): def __call__(self, img):
if self.value == 0: if self.value == 0:
return img, lbl return img
dtype = img.dtype dtype = img.dtype
img = img.astype(np.float32) img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
img = img * alpha + cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean() * ( img = img * alpha + cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean() * (
1 - alpha) 1 - alpha)
return img.clip(0, 255).astype(dtype), lbl return img.clip(0, 255).astype(dtype)
class SaturationTransform(object): class SaturationTransform(object):
...@@ -423,9 +423,9 @@ class SaturationTransform(object): ...@@ -423,9 +423,9 @@ class SaturationTransform(object):
raise ValueError("saturation value should be non-negative") raise ValueError("saturation value should be non-negative")
self.value = value self.value = value
def __call__(self, img, lbl): def __call__(self, img):
if self.value == 0: if self.value == 0:
return img, lbl return img
dtype = img.dtype dtype = img.dtype
img = img.astype(np.float32) img = img.astype(np.float32)
...@@ -433,7 +433,7 @@ class SaturationTransform(object): ...@@ -433,7 +433,7 @@ class SaturationTransform(object):
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = gray_img[..., np.newaxis] gray_img = gray_img[..., np.newaxis]
img = img * alpha + gray_img * (1 - alpha) img = img * alpha + gray_img * (1 - alpha)
return img.clip(0, 255).astype(dtype), lbl return img.clip(0, 255).astype(dtype)
class HueTransform(object): class HueTransform(object):
...@@ -449,9 +449,9 @@ class HueTransform(object): ...@@ -449,9 +449,9 @@ class HueTransform(object):
raise ValueError("hue value should be in [0.0, 0.5]") raise ValueError("hue value should be in [0.0, 0.5]")
self.value = value self.value = value
def __call__(self, img, lbl): def __call__(self, img):
if self.value == 0: if self.value == 0:
return img, lbl return img
dtype = img.dtype dtype = img.dtype
img = img.astype(np.uint8) img = img.astype(np.uint8)
...@@ -464,7 +464,7 @@ class HueTransform(object): ...@@ -464,7 +464,7 @@ class HueTransform(object):
with np.errstate(over="ignore"): with np.errstate(over="ignore"):
h += np.uint8(alpha * 255) h += np.uint8(alpha * 255)
hsv_img = cv2.merge([h, s, v]) hsv_img = cv2.merge([h, s, v])
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype), lbl return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
class ColorJitter(object): class ColorJitter(object):
...@@ -499,5 +499,5 @@ class ColorJitter(object): ...@@ -499,5 +499,5 @@ class ColorJitter(object):
random.shuffle(transforms) random.shuffle(transforms)
self.transforms = Compose(transforms) self.transforms = Compose(transforms)
def __call__(self, img, lbl): def __call__(self, img):
return self.transforms(img, lbl) return self.transforms(img)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册