diff --git a/examples/tsm/kinetics_dataset.py b/examples/tsm/kinetics_dataset.py index 123d89814a8c631569cd0503750cafac631cca22..6cfd0d1c6be615c8cd862a59ca3254a47f336c99 100644 --- a/examples/tsm/kinetics_dataset.py +++ b/examples/tsm/kinetics_dataset.py @@ -113,7 +113,7 @@ class KineticsDataset(Dataset): if self.transform: imgs, label = self.transform(imgs, label) - return imgs, np.array([label]) + return imgs, np.array([label]).astype('int64') @property def num_classes(self): diff --git a/hapi/datasets/flowers.py b/hapi/datasets/flowers.py index 1f4f707888d460260d598826ba15ca3c69455f7b..c360e8fc287dd97fc5747ae9f65c668e3b7a1cf1 100644 --- a/hapi/datasets/flowers.py +++ b/hapi/datasets/flowers.py @@ -123,7 +123,7 @@ class Flowers(Dataset): if self.transform is not None: image, label = self.transform(image, label) - return image, label + return image, label.astype('int64') def __len__(self): return len(self.indexes) diff --git a/hapi/datasets/mnist.py b/hapi/datasets/mnist.py index 18c62901edb95fd573334a4f3fe2201be7447711..11b5f310ffc6baf2df85a9bcae716c54715097fe 100644 --- a/hapi/datasets/mnist.py +++ b/hapi/datasets/mnist.py @@ -144,7 +144,7 @@ class MNIST(Dataset): for i in range(buffer_size): self.images.append(images[i, :]) - self.labels.append(np.array([labels[i]])) + self.labels.append(np.array([labels[i]]).astype('int64')) def __getitem__(self, idx): image, label = self.images[idx], self.labels[idx]