提交 d3a1a1fc 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix type of label to int64

上级 fbe88807
...@@ -32,5 +32,5 @@ class ICartoonDataset(CommonDataset): ...@@ -32,5 +32,5 @@ class ICartoonDataset(CommonDataset):
for l in lines: for l in lines:
l = l.strip().split("\t") l = l.strip().split("\t")
self.images.append(os.path.join(self._img_root, l[0])) self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(int(l[1])) self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
...@@ -34,5 +34,5 @@ class ImageNetDataset(CommonDataset): ...@@ -34,5 +34,5 @@ class ImageNetDataset(CommonDataset):
for l in lines: for l in lines:
l = l.strip().split(" ") l = l.strip().split(" ")
self.images.append(os.path.join(self._img_root, l[0])) self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(int(l[1])) self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
...@@ -28,6 +28,7 @@ import random ...@@ -28,6 +28,7 @@ import random
from .common_dataset import CommonDataset from .common_dataset import CommonDataset
class LogoDataset(CommonDataset): class LogoDataset(CommonDataset):
def _load_anno(self): def _load_anno(self):
assert os.path.exists(self._cls_path) assert os.path.exists(self._cls_path)
...@@ -41,7 +42,5 @@ class LogoDataset(CommonDataset): ...@@ -41,7 +42,5 @@ class LogoDataset(CommonDataset):
if l[0] == 'image_id': if l[0] == 'image_id':
continue continue
self.images.append(os.path.join(self._img_root, l[3])) self.images.append(os.path.join(self._img_root, l[3]))
self.labels.append(int(l[1])-1) self.labels.append(np.int64(l[1]) - 1)
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
...@@ -37,7 +37,7 @@ class MultiLabelDataset(CommonDataset): ...@@ -37,7 +37,7 @@ class MultiLabelDataset(CommonDataset):
self.images.append(os.path.join(self._img_root, l[0])) self.images.append(os.path.join(self._img_root, l[0]))
labels = l[1].split(',') labels = l[1].split(',')
labels = [int(i) for i in labels] labels = [np.int64(i) for i in labels]
self.labels.append(labels) self.labels.append(labels)
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
......
...@@ -112,8 +112,8 @@ class VeriWild(Dataset): ...@@ -112,8 +112,8 @@ class VeriWild(Dataset):
for l in lines: for l in lines:
l = l.strip().split() l = l.strip().split()
self.images.append(os.path.join(self._img_root, l[0])) self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(int(l[1])) self.labels.append(np.int64(l[1]))
self.cameras.append(int(l[2])) self.cameras.append(np.int64(l[2]))
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
def __getitem__(self, idx): def __getitem__(self, idx):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册