base_cv_dataset.py 3.4 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import paddlehub as hub
from paddlehub.common.downloader import default_downloader
W
wuzewu 已提交
23 24


W
wuzewu 已提交
25
class ImageClassificationDataset(object):
    """Base class for image-classification datasets described by list files.

    The attributes are all initialised to empty values here and are expected
    to be filled in before the data accessors are used. Each list file holds
    one sample per line in the form ``<image path> <label>``; the label is the
    last space-separated token, so image paths may themselves contain spaces.
    """

    def __init__(self):
        # Directory that relative list-file names and image paths are
        # resolved against.
        self.base_path = None
        # File names (joined onto ``base_path``) of the per-split sample lists.
        self.train_list_file = None
        self.test_list_file = None
        self.validate_list_file = None
        # File name (joined onto ``base_path``) holding one label per line.
        self.label_list_file = None
        self.num_labels = 0
        # Cache filled lazily by ``label_dict``.
        self.label_list = []

    def _download_dataset(self, dataset_path, url):
        """Return the dataset directory, downloading it first if it is missing.

        Args:
            dataset_path: Expected location of the dataset on disk.
            url: Location to fetch the archive from when ``dataset_path``
                does not exist.

        Returns:
            ``dataset_path`` unchanged when it already exists, otherwise the
            path reported by the downloader.

        Raises:
            SystemExit: If the downloader reports failure. The downloader's
                message is printed first.
        """
        if os.path.exists(dataset_path):
            return dataset_path

        result, tips, dataset_path = default_downloader.download_file_and_uncompress(
            url=url,
            save_path=hub.common.dir.DATA_HOME,
            print_progress=True,
            replace=True)
        if not result:
            print(tips)
            # Raise SystemExit directly: the ``exit`` builtin is injected by
            # the ``site`` module and is not guaranteed to exist, and a failed
            # download should not terminate with a success status.
            raise SystemExit(1)
        return dataset_path

    def _parse_data(self, data_path, shuffle=False):
        """Return a generator of ``(image_path, label)`` tuples.

        Args:
            data_path: Path of the list file to read.
            shuffle: When true, the samples are shuffled in memory with
                ``np.random.shuffle`` before being yielded.

        Returns:
            A generator. The list file is opened on the first iteration, not
            when this method is called. Labels are yielded as strings.
            Relative image paths are joined onto ``self.base_path`` when it
            is set. Blank lines are skipped.
        """

        def _base_reader():
            data = []
            with open(data_path, "r") as file:
                for line in file:
                    line = line.strip()
                    if not line:
                        # A blank (often trailing) line carries no sample.
                        continue
                    # Split on the last space only, so that image paths
                    # containing spaces stay intact. A line without any
                    # space yields the same token as both path and label.
                    items = line.rsplit(" ", 1)
                    image_path = items[0]
                    label = items[-1]
                    if (not os.path.isabs(image_path)
                            and self.base_path is not None):
                        image_path = os.path.join(self.base_path, image_path)
                    data.append((image_path, label))

            if shuffle:
                np.random.shuffle(data)

            for item in data:
                yield item

        return _base_reader()

    def label_dict(self):
        """Return a ``{index: label}`` mapping read from the label list file.

        The label file is read once and cached in ``self.label_list``. Line
        positions are preserved, so the index of a label equals its
        zero-based line number; the newline that terminates the file does not
        produce an extra empty label.
        """
        if not self.label_list:
            with open(os.path.join(self.base_path, self.label_list_file),
                      "r") as file:
                self.label_list = [line.rstrip("\n") for line in file]
        return {index: key for index, key in enumerate(self.label_list)}

    def train_data(self, shuffle=True):
        """Return a generator over the training samples (shuffled by default)."""
        train_data_path = os.path.join(self.base_path, self.train_list_file)
        return self._parse_data(train_data_path, shuffle)

    def test_data(self, shuffle=False):
        """Return a generator over the test samples."""
        test_data_path = os.path.join(self.base_path, self.test_list_file)
        return self._parse_data(test_data_path, shuffle)

    def validate_data(self, shuffle=False):
        """Return a generator over the validation samples."""
        validate_data_path = os.path.join(self.base_path,
                                          self.validate_list_file)
        return self._parse_data(validate_data_path, shuffle)