diff --git a/paddlehub/dataset/base_cv_dataset.py b/paddlehub/dataset/base_cv_dataset.py
index 2da367722a8ec1ef9c88d37946f470cb62d77268..5725c4cd2c3e2f84fccf022fccbd6bbd88f46596 100644
--- a/paddlehub/dataset/base_cv_dataset.py
+++ b/paddlehub/dataset/base_cv_dataset.py
@@ -28,7 +28,9 @@ class ImageClassificationDataset(object):
         self.train_list_file = None
         self.test_list_file = None
         self.validate_list_file = None
+        self.label_list_file = None
         self.num_labels = 0
+        self.label_list = []
 
     def _download_dataset(self, dataset_path, url):
         if not os.path.exists(dataset_path):
@@ -70,6 +72,21 @@ class ImageClassificationDataset(object):
 
         return _base_reader()
 
+    def label_dict(self):
+        """Return a dict mapping each label index to its label name.
+
+        The names are read once from ``label_list_file`` (one name per
+        line, resolved against ``base_path`` the same way ``train_data``
+        resolves ``train_list_file``) and cached in ``label_list``.
+        """
+        if not self.label_list:
+            label_path = os.path.join(self.base_path, self.label_list_file)
+            with open(label_path, "r") as label_file:
+                # splitlines() avoids the empty trailing label that
+                # split("\n") yields when the file ends with a newline.
+                self.label_list = label_file.read().splitlines()
+        return dict(enumerate(self.label_list))
+
     def train_data(self, shuffle=True):
         train_data_path = os.path.join(self.base_path, self.train_list_file)
         return self._parse_data(train_data_path, shuffle)
diff --git a/paddlehub/dataset/dogcat.py b/paddlehub/dataset/dogcat.py
index 3bdca560c19142eac6f93679b1fe635e693f7e1d..70600b34e18640bf8b7b10502e7233524ae330bb 100644
--- a/paddlehub/dataset/dogcat.py
+++ b/paddlehub/dataset/dogcat.py
@@ -32,4 +32,5 @@ class DogCatDataset(ImageClassificationDataset):
         self.train_list_file = "train_list.txt"
         self.test_list_file = "test_list.txt"
         self.validate_list_file = "validate_list.txt"
+        self.label_list_file = "label_list.txt"
         self.num_labels = 2
diff --git a/paddlehub/dataset/flowers.py b/paddlehub/dataset/flowers.py
index 8f7a02c09c4c4c5ad99dff446bbd4064e62e3dc8..adef50aea3f7d248668342b8de5cf495d80f911f 100644
--- a/paddlehub/dataset/flowers.py
+++ b/paddlehub/dataset/flowers.py
@@ -32,4 +32,5 @@ class FlowersDataset(ImageClassificationDataset):
         self.train_list_file = "train_list.txt"
         self.test_list_file = "test_list.txt"
         self.validate_list_file = "validate_list.txt"
+        self.label_list_file = "label_list.txt"
         self.num_labels = 5