提交 91c4fda1 编写于 作者: W wuzewu 提交者: zhangxuefei

update dataset to add label list

上级 fff81191
......@@ -28,7 +28,9 @@ class ImageClassificationDataset(object):
self.train_list_file = None
self.test_list_file = None
self.validate_list_file = None
self.label_list_file = None
self.num_labels = 0
self.label_list = []
def _download_dataset(self, dataset_path, url):
if not os.path.exists(dataset_path):
......@@ -70,6 +72,12 @@ class ImageClassificationDataset(object):
return _base_reader()
def label_dict(self):
if not self.label_list:
with open(self.label_list_file, "r") as file:
self.label_list = file.read().split("\n")
return {index: key for index, key in enumerate(self.label_list)}
def train_data(self, shuffle=True):
train_data_path = os.path.join(self.base_path, self.train_list_file)
return self._parse_data(train_data_path, shuffle)
......
......@@ -32,4 +32,5 @@ class DogCatDataset(ImageClassificationDataset):
self.train_list_file = "train_list.txt"
self.test_list_file = "test_list.txt"
self.validate_list_file = "validate_list.txt"
self.label_list_file = "label_list.txt"
self.num_labels = 2
......@@ -32,4 +32,5 @@ class FlowersDataset(ImageClassificationDataset):
self.train_list_file = "train_list.txt"
self.test_list_file = "test_list.txt"
self.validate_list_file = "validate_list.txt"
self.label_list_file = "label_list.txt"
self.num_labels = 5
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册