From 91c4fda1ec3d4dfc0132aab9d8d49aecb2bb7f22 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Thu, 18 Apr 2019 16:11:03 +0800 Subject: [PATCH] update dataset to add label list --- paddlehub/dataset/base_cv_dataset.py | 8 ++++++++ paddlehub/dataset/dogcat.py | 1 + paddlehub/dataset/flowers.py | 1 + 3 files changed, 10 insertions(+) diff --git a/paddlehub/dataset/base_cv_dataset.py b/paddlehub/dataset/base_cv_dataset.py index 2da36772..5725c4cd 100644 --- a/paddlehub/dataset/base_cv_dataset.py +++ b/paddlehub/dataset/base_cv_dataset.py @@ -28,7 +28,9 @@ class ImageClassificationDataset(object): self.train_list_file = None self.test_list_file = None self.validate_list_file = None + self.label_list_file = None self.num_labels = 0 + self.label_list = [] def _download_dataset(self, dataset_path, url): if not os.path.exists(dataset_path): @@ -70,6 +72,12 @@ class ImageClassificationDataset(object): return _base_reader() + def label_dict(self): + if not self.label_list: + with open(self.label_list_file, "r") as file: + self.label_list = file.read().split("\n") + return {index: key for index, key in enumerate(self.label_list)} + def train_data(self, shuffle=True): train_data_path = os.path.join(self.base_path, self.train_list_file) return self._parse_data(train_data_path, shuffle) diff --git a/paddlehub/dataset/dogcat.py b/paddlehub/dataset/dogcat.py index 3bdca560..70600b34 100644 --- a/paddlehub/dataset/dogcat.py +++ b/paddlehub/dataset/dogcat.py @@ -32,4 +32,5 @@ class DogCatDataset(ImageClassificationDataset): self.train_list_file = "train_list.txt" self.test_list_file = "test_list.txt" self.validate_list_file = "validate_list.txt" + self.label_list_file = "label_list.txt" self.num_labels = 2 diff --git a/paddlehub/dataset/flowers.py b/paddlehub/dataset/flowers.py index 8f7a02c0..adef50ae 100644 --- a/paddlehub/dataset/flowers.py +++ b/paddlehub/dataset/flowers.py @@ -32,4 +32,5 @@ class FlowersDataset(ImageClassificationDataset): self.train_list_file = "train_list.txt" self.test_list_file = "test_list.txt" self.validate_list_file = "validate_list.txt" + self.label_list_file = "label_list.txt" self.num_labels = 5 -- GitLab