diff --git a/examples/image_classification/imagenet_dataset.py b/examples/image_classification/imagenet_dataset.py index 3d730de6e33f0f12474ded283630bb5736c3d8a7..25dcc338e20e53e75f3637ea8f8e3d492a1240e1 100644 --- a/examples/image_classification/imagenet_dataset.py +++ b/examples/image_classification/imagenet_dataset.py @@ -51,7 +51,7 @@ class ImageNetDataset(DatasetFolder): img_path, label = self.samples[idx] img = cv2.imread(img_path).astype(np.float32) label = np.array([label]) - return self.transform(img, label) + return self.transform(img), label def __len__(self): return len(self.samples) diff --git a/hapi/datasets/folder.py b/hapi/datasets/folder.py index 23f2c9592915e3e83d596c9cc3679eca306a4bd5..c13710ea033dd62b665d60967d3acc91cb84c4ef 100644 --- a/hapi/datasets/folder.py +++ b/hapi/datasets/folder.py @@ -150,7 +150,7 @@ class DatasetFolder(Dataset): path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: - sample, target = self.transform(sample, target) + sample, target = self.transform(sample) return sample, target diff --git a/hapi/vision/transforms/transforms.py b/hapi/vision/transforms/transforms.py index 3d974171ce0d6f5a80f2af6a272a4250d771fb4d..87e49862489c1d9284b9d3e6d018e0de2f183bcb 100644 --- a/hapi/vision/transforms/transforms.py +++ b/hapi/vision/transforms/transforms.py @@ -64,10 +64,10 @@ class Compose(object): def __init__(self, transforms): self.transforms = transforms - def __call__(self, *data): + def __call__(self, data): for f in self.transforms: try: - data = f(*data) + data = f(data) except Exception as e: stack_info = traceback.format_exc() print("fail to perform transform [{}] with error: " @@ -130,8 +130,8 @@ class Resize(object): self.size = size self.interpolation = interpolation - def __call__(self, img, lbl): - return F.resize(img, self.size, self.interpolation), lbl + def __call__(self, img): + return F.resize(img, self.size, self.interpolation) class RandomResizedCrop(object): @@ -193,10 +193,10 @@ class RandomResizedCrop(object): y = (height - h) // 2 return x, y, w, h - def __call__(self, img, lbl): + def __call__(self, img): x, y, w, h = self._get_params(img) cropped_img = img[y:y + h, x:x + w] - return F.resize(cropped_img, self.output_size, self.interpolation), lbl + return F.resize(cropped_img, self.output_size, self.interpolation) class CenterCropResize(object): @@ -224,10 +224,10 @@ class CenterCropResize(object): y = (w + 1 - c) // 2 return c, x, y - def __call__(self, img, lbl): + def __call__(self, img): c, x, y = self._get_params(img) cropped_img = img[x:x + c, y:y + c, :] - return F.resize(cropped_img, self.size, self.interpolation), lbl + return F.resize(cropped_img, self.size, self.interpolation) class CenterCrop(object): @@ -251,10 +251,10 @@ class CenterCrop(object): y = int(round((h - th) / 2.0)) return x, y - def __call__(self, img, lbl): + def __call__(self, img): x, y = self._get_params(img) th, tw = self.output_size - return img[y:y + th, x:x + tw], lbl + return img[y:y + th, x:x + tw] class RandomHorizontalFlip(object): @@ -267,10 +267,10 @@ class RandomHorizontalFlip(object): def __init__(self, prob=0.5): self.prob = prob - def __call__(self, img, lbl): + def __call__(self, img): if np.random.random() < self.prob: - return F.flip(img, code=1), lbl - return img, lbl + return F.flip(img, code=1) + return img class RandomVerticalFlip(object): @@ -283,10 +283,10 @@ class RandomVerticalFlip(object): def __init__(self, prob=0.5): self.prob = prob - def __call__(self, img, lbl): + def __call__(self, img): if np.random.random() < self.prob: - return F.flip(img, code=0), lbl - return img, lbl + return F.flip(img, code=0) + return img class Normalize(object): @@ -311,8 +311,8 @@ class Normalize(object): self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1) self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1) - def __call__(self, img, lbl): - return (img - self.mean) / self.std, lbl + def __call__(self, img): + return (img - self.mean) / self.std class Permute(object): @@ -333,12 +333,12 @@ class Permute(object): self.mode = mode self.to_rgb = to_rgb - def __call__(self, img, lbl): + def __call__(self, img): if self.to_rgb: img = img[..., ::-1] if self.mode == "CHW": - return img.transpose((2, 0, 1)), lbl - return img, lbl + return img.transpose((2, 0, 1)) + return img class GaussianNoise(object): @@ -354,11 +354,11 @@ class GaussianNoise(object): self.mean = np.array(mean, dtype=np.float32) self.std = np.array(std, dtype=np.float32) - def __call__(self, img, lbl): + def __call__(self, img): dtype = img.dtype noise = np.random.normal(self.mean, self.std, img.shape) * 255 img = img + noise.astype(np.float32) - return np.clip(img, 0, 255).astype(dtype), lbl + return np.clip(img, 0, 255).astype(dtype) class BrightnessTransform(object): @@ -374,15 +374,15 @@ class BrightnessTransform(object): raise ValueError("brightness value should be non-negative") self.value = value - def __call__(self, img, lbl): + def __call__(self, img): if self.value == 0: - return img, lbl + return img dtype = img.dtype img = img.astype(np.float32) alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) img = img * alpha - return img.clip(0, 255).astype(dtype), lbl + return img.clip(0, 255).astype(dtype) class ContrastTransform(object): @@ -398,16 +398,16 @@ class ContrastTransform(object): raise ValueError("contrast value should be non-negative") self.value = value - def __call__(self, img, lbl): + def __call__(self, img): if self.value == 0: - return img, lbl + return img dtype = img.dtype img = img.astype(np.float32) alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) img = img * alpha + cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean() * ( 1 - alpha) - return img.clip(0, 255).astype(dtype), lbl + return img.clip(0, 255).astype(dtype) class SaturationTransform(object): @@ -423,9 +423,9 @@ class SaturationTransform(object): raise ValueError("saturation value should be non-negative") self.value = value - def __call__(self, img, lbl): + def __call__(self, img): if self.value == 0: - return img, lbl + return img dtype = img.dtype img = img.astype(np.float32) @@ -433,7 +433,7 @@ class SaturationTransform(object): gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) gray_img = gray_img[..., np.newaxis] img = img * alpha + gray_img * (1 - alpha) - return img.clip(0, 255).astype(dtype), lbl + return img.clip(0, 255).astype(dtype) class HueTransform(object): @@ -449,9 +449,9 @@ class HueTransform(object): raise ValueError("hue value should be in [0.0, 0.5]") self.value = value - def __call__(self, img, lbl): + def __call__(self, img): if self.value == 0: - return img, lbl + return img dtype = img.dtype img = img.astype(np.uint8) @@ -464,7 +464,7 @@ class HueTransform(object): with np.errstate(over="ignore"): h += np.uint8(alpha * 255) hsv_img = cv2.merge([h, s, v]) - return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype), lbl + return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype) class ColorJitter(object): @@ -499,5 +499,5 @@ class ColorJitter(object): random.shuffle(transforms) self.transforms = Compose(transforms) - def __call__(self, img, lbl): - return self.transforms(img, lbl) + def __call__(self, img): + return self.transforms(img)