class ImageClassificationDataset(object):
    """Base class for image-classification datasets described by list files.

    Subclasses are expected to set ``base_path``, the three ``*_list_file``
    names and ``num_labels``. Each list file holds one sample per line in the
    form ``<relative image path> <label>``, separated by a single space.

    The class derives from ``object`` explicitly so that it is a new-style
    class on Python 2 as well; subclasses call
    ``super(SubClass, self).__init__()``, which requires that.
    """

    def __init__(self):
        # Directory that image paths and list-file names are joined onto.
        self.base_path = None
        # List-file names, relative to base_path.
        self.train_list_file = None
        self.test_list_file = None
        self.validate_list_file = None
        self.num_labels = 0

    def _download_dataset(self, dataset_path, url):
        """Return the dataset directory, downloading it first when absent.

        Args:
            dataset_path: Local directory where the dataset is expected.
            url: Archive to download when ``dataset_path`` does not exist.

        Returns:
            ``dataset_path`` unchanged if it already exists, otherwise the
            path reported by the downloader.

        Raises:
            SystemExit: If the download fails. The downloader's message is
                printed first and the exit status is non-zero so that calling
                scripts can detect the failure.
        """
        if not os.path.exists(dataset_path):
            result, tips, dataset_path = default_downloader.download_file_and_uncompress(
                url=url,
                save_path=hub.dir.DATA_HOME,
                print_progress=True,
                replace=True)
            if not result:
                print(tips)
                raise SystemExit(1)
        return dataset_path

    def _parse_data(self, data_path, shuffle=False):
        """Return a generator over the samples listed in ``data_path``.

        Args:
            data_path: Path of the list file to read.
            shuffle: When true, the samples are shuffled in memory with
                ``numpy.random.shuffle`` before being yielded.

        Returns:
            A generator yielding ``(image_path, label)`` tuples, where
            ``image_path`` is the first field joined onto ``self.base_path``
            and ``label`` is the second field as a string. The file is only
            opened once iteration starts.

        Raises:
            IndexError: During iteration, if a non-blank line has fewer than
                two space-separated fields.
        """

        def _base_reader():
            samples = []
            with open(data_path, "r") as list_file:
                for line in list_file:
                    line = line.strip()
                    # Tolerate blank lines (e.g. a trailing empty line).
                    if not line:
                        continue
                    items = line.split(" ")
                    image_path = os.path.join(self.base_path, items[0])
                    samples.append((image_path, items[1]))

            if shuffle:
                # Imported here because this module has no file-level numpy
                # import; only the shuffling path needs it.
                import numpy as np
                np.random.shuffle(samples)

            for sample in samples:
                yield sample

        return _base_reader()

    def train_data(self, shuffle=True):
        """Return a generator over the training samples (shuffled by default)."""
        train_data_path = os.path.join(self.base_path, self.train_list_file)
        return self._parse_data(train_data_path, shuffle)

    def test_data(self, shuffle=False):
        """Return a generator over the test samples."""
        test_data_path = os.path.join(self.base_path, self.test_list_file)
        return self._parse_data(test_data_path, shuffle)

    def validate_data(self, shuffle=False):
        """Return a generator over the validation samples."""
        validate_data_path = os.path.join(self.base_path,
                                          self.validate_list_file)
        return self._parse_data(validate_data_path, shuffle)
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from PIL import Image + +import paddle_hub.io.augmentation as image_augmentation + +color_mode_dict = { + "RGB": [0, 1, 2], + "RBG": [0, 2, 1], + "GBR": [1, 2, 0], + "GRB": [1, 0, 2], + "BGR": [2, 1, 0], + "BRG": [2, 0, 1] +} + + +class ImageClassificationReader: + def __init__(self, + image_width, + image_height, + dataset, + color_mode="RGB", + data_augmentation=False): + self.image_width = image_width + self.image_height = image_height + self.color_mode = color_mode + self.dataset = dataset + self.data_augmentation = data_augmentation + if self.color_mode not in color_mode_dict: + raise ValueError( + "Color_mode should in %s." % color_mode_dict.keys()) + + if self.image_width <= 0 or self.image_height <= 0: + raise ValueError("Image width and height should not be negative.") + + def data_generator(self, phase, shuffle=False): + if phase == "train": + data = self.dataset.train_data(shuffle) + elif phase == "test": + shuffle = False + data = self.dataset.test_data(shuffle) + elif phase == "validate": + shuffle = False + data = self.dataset.validate_data(shuffle) + + def _data_reader(): + for image_path, label in data: + image = Image.open(image_path) + image = image_augmentation.image_resize(image, self.image_width, + self.image_height) + if self.data_augmentation: + image = image_augmentation.image_random_process( + image, enable_resize=False) + + # only support RGB + image = image.convert('RGB') + + # HWC to CHW + image = np.array(image) + if len(image.shape) == 3: + image = np.swapaxes(image, 1, 2) + image = np.swapaxes(image, 1, 0) + + image = image[color_mode_dict[self.color_mode], :, :] + yield ((image, label)) + + return _data_reader diff --git a/paddle_hub/dataset/dogcat.py b/paddle_hub/dataset/dogcat.py new file mode 100644 index 0000000000000000000000000000000000000000..37a1291ca8c91e178a0c3cd1e9e313a4ee6b3b86 --- 
class DogCatDataset(ImageClassificationDataset):
    """Dog/cat image-classification dataset with two labels.

    On construction the dataset directory is located under
    ``hub.dir.DATA_HOME`` and handed to ``_download_dataset`` together with
    the archive URL.
    """

    def __init__(self):
        super(DogCatDataset, self).__init__()

        self.num_labels = 2
        self.train_list_file = "train_list.txt"
        self.test_list_file = "test_list.txt"
        self.validate_list_file = "validate_list.txt"

        archive_url = "https://paddlehub-dataset.bj.bcebos.com/dog-cat.tar.gz"
        local_dir = os.path.join(hub.dir.DATA_HOME, "dog-cat")
        self.base_path = self._download_dataset(
            dataset_path=local_dir, url=archive_url)
class FlowersDataset(ImageClassificationDataset):
    """Flower-photo image-classification dataset with five labels.

    On construction the dataset directory is located under
    ``hub.dir.DATA_HOME`` and handed to ``_download_dataset`` together with
    the archive URL.
    """

    def __init__(self):
        super(FlowersDataset, self).__init__()

        self.num_labels = 5
        self.train_list_file = "train_list.txt"
        self.test_list_file = "test_list.txt"
        self.validate_list_file = "validate_list.txt"

        archive_url = "https://paddlehub-dataset.bj.bcebos.com/flower_photos.tar.gz"
        local_dir = os.path.join(hub.dir.DATA_HOME, "flower_photos")
        self.base_path = self._download_dataset(
            dataset_path=local_dir, url=archive_url)