From d3a1a1fc457dea96d55f895c719ae41cc9840ceb Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 28 Oct 2021 07:09:41 +0000 Subject: [PATCH] fix: fix type of label to int64 --- ppcls/data/dataloader/icartoon_dataset.py | 2 +- ppcls/data/dataloader/imagenet_dataset.py | 2 +- ppcls/data/dataloader/logo_dataset.py | 5 ++--- ppcls/data/dataloader/multilabel_dataset.py | 2 +- ppcls/data/dataloader/vehicle_dataset.py | 4 ++-- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/ppcls/data/dataloader/icartoon_dataset.py b/ppcls/data/dataloader/icartoon_dataset.py index 21234148..18e3b4b7 100644 --- a/ppcls/data/dataloader/icartoon_dataset.py +++ b/ppcls/data/dataloader/icartoon_dataset.py @@ -32,5 +32,5 @@ class ICartoonDataset(CommonDataset): for l in lines: l = l.strip().split("\t") self.images.append(os.path.join(self._img_root, l[0])) - self.labels.append(int(l[1])) + self.labels.append(np.int64(l[1])) assert os.path.exists(self.images[-1]) diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py index e084bb74..1166ab38 100644 --- a/ppcls/data/dataloader/imagenet_dataset.py +++ b/ppcls/data/dataloader/imagenet_dataset.py @@ -34,5 +34,5 @@ class ImageNetDataset(CommonDataset): for l in lines: l = l.strip().split(" ") self.images.append(os.path.join(self._img_root, l[0])) - self.labels.append(int(l[1])) + self.labels.append(np.int64(l[1])) assert os.path.exists(self.images[-1]) diff --git a/ppcls/data/dataloader/logo_dataset.py b/ppcls/data/dataloader/logo_dataset.py index 3e05e7fe..132ead98 100644 --- a/ppcls/data/dataloader/logo_dataset.py +++ b/ppcls/data/dataloader/logo_dataset.py @@ -28,6 +28,7 @@ import random from .common_dataset import CommonDataset + class LogoDataset(CommonDataset): def _load_anno(self): assert os.path.exists(self._cls_path) @@ -41,7 +42,5 @@ class LogoDataset(CommonDataset): if l[0] == 'image_id': continue self.images.append(os.path.join(self._img_root, l[3])) - self.labels.append(int(l[1])-1) + self.labels.append(np.int64(l[1]) - 1) assert os.path.exists(self.images[-1]) - - diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py index 08d2ba15..2c1ed770 100644 --- a/ppcls/data/dataloader/multilabel_dataset.py +++ b/ppcls/data/dataloader/multilabel_dataset.py @@ -37,7 +37,7 @@ class MultiLabelDataset(CommonDataset): self.images.append(os.path.join(self._img_root, l[0])) labels = l[1].split(',') - labels = [int(i) for i in labels] + labels = [np.int64(i) for i in labels] self.labels.append(labels) assert os.path.exists(self.images[-1]) diff --git a/ppcls/data/dataloader/vehicle_dataset.py b/ppcls/data/dataloader/vehicle_dataset.py index 80fc6bb0..2981a57a 100644 --- a/ppcls/data/dataloader/vehicle_dataset.py +++ b/ppcls/data/dataloader/vehicle_dataset.py @@ -112,8 +112,8 @@ class VeriWild(Dataset): for l in lines: l = l.strip().split() self.images.append(os.path.join(self._img_root, l[0])) - self.labels.append(int(l[1])) - self.cameras.append(int(l[2])) + self.labels.append(np.int64(l[1])) + self.cameras.append(np.int64(l[2])) assert os.path.exists(self.images[-1]) def __getitem__(self, idx): -- GitLab