提交 91c4fda1 编写于 作者: W wuzewu 提交者: zhangxuefei

update dataset to add label list

上级 fff81191
...@@ -28,7 +28,9 @@ class ImageClassificationDataset(object): ...@@ -28,7 +28,9 @@ class ImageClassificationDataset(object):
self.train_list_file = None self.train_list_file = None
self.test_list_file = None self.test_list_file = None
self.validate_list_file = None self.validate_list_file = None
self.label_list_file = None
self.num_labels = 0 self.num_labels = 0
self.label_list = []
def _download_dataset(self, dataset_path, url): def _download_dataset(self, dataset_path, url):
if not os.path.exists(dataset_path): if not os.path.exists(dataset_path):
...@@ -70,6 +72,12 @@ class ImageClassificationDataset(object): ...@@ -70,6 +72,12 @@ class ImageClassificationDataset(object):
return _base_reader() return _base_reader()
def label_dict(self):
if not self.label_list:
with open(self.label_list_file, "r") as file:
self.label_list = file.read().split("\n")
return {index: key for index, key in enumerate(self.label_list)}
def train_data(self, shuffle=True): def train_data(self, shuffle=True):
train_data_path = os.path.join(self.base_path, self.train_list_file) train_data_path = os.path.join(self.base_path, self.train_list_file)
return self._parse_data(train_data_path, shuffle) return self._parse_data(train_data_path, shuffle)
......
...@@ -32,4 +32,5 @@ class DogCatDataset(ImageClassificationDataset): ...@@ -32,4 +32,5 @@ class DogCatDataset(ImageClassificationDataset):
self.train_list_file = "train_list.txt" self.train_list_file = "train_list.txt"
self.test_list_file = "test_list.txt" self.test_list_file = "test_list.txt"
self.validate_list_file = "validate_list.txt" self.validate_list_file = "validate_list.txt"
self.label_list_file = "label_list.txt"
self.num_labels = 2 self.num_labels = 2
...@@ -32,4 +32,5 @@ class FlowersDataset(ImageClassificationDataset): ...@@ -32,4 +32,5 @@ class FlowersDataset(ImageClassificationDataset):
self.train_list_file = "train_list.txt" self.train_list_file = "train_list.txt"
self.test_list_file = "test_list.txt" self.test_list_file = "test_list.txt"
self.validate_list_file = "validate_list.txt" self.validate_list_file = "validate_list.txt"
self.label_list_file = "label_list.txt"
self.num_labels = 5 self.num_labels = 5
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册