diff --git a/hapi/datasets/flowers.py b/hapi/datasets/flowers.py index 1f4f707888d460260d598826ba15ca3c69455f7b..aaae6107163d12b23a62366587f10c43d3579fab 100644 --- a/hapi/datasets/flowers.py +++ b/hapi/datasets/flowers.py @@ -121,7 +121,7 @@ class Flowers(Dataset): image = np.array(Image.open(io.BytesIO(image))) if self.transform is not None: - image, label = self.transform(image, label) + image = self.transform(image) return image, label diff --git a/hapi/datasets/mnist.py b/hapi/datasets/mnist.py index 18c62901edb95fd573334a4f3fe2201be7447711..e45aea711d8c4b8db2df5a558668939740bbc6bf 100644 --- a/hapi/datasets/mnist.py +++ b/hapi/datasets/mnist.py @@ -149,7 +149,7 @@ class MNIST(Dataset): def __getitem__(self, idx): image, label = self.images[idx], self.labels[idx] if self.transform is not None: - image, label = self.transform(image, label) + image = self.transform(image) return image, label def __len__(self): diff --git a/hapi/vision/transforms/transforms.py b/hapi/vision/transforms/transforms.py index 14bcd00221caf0f33d645d89478a9cf55d81c4ca..90b43b8d410aecd9fedbae3434c5d16c9351a411 100644 --- a/hapi/vision/transforms/transforms.py +++ b/hapi/vision/transforms/transforms.py @@ -61,6 +61,24 @@ class Compose(object): Args: transforms (list of ``Transform`` objects): list of transforms to compose. + Returns: + A compose object which is callable, __call__ for this Compose + object will call each given :attr:`transforms` sequencely. + + Examples: + + .. code-block:: python + + from hapi.datasets import Flowers + from hapi.vision.transforms import Compose, ColorJitter, Resize + + transform = Compose([ColorJitter(), Resize(size=608)]) + flowers = Flowers(mode='test', transform=transform) + + for i in range(10): + sample = flowers[i] + print(sample[0].shape, sample[1]) + """ def __init__(self, transforms):