提交 b25dfb58 编写于 作者: W wuzewu

add cv reader and dataset

上级 08ef63b6
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle_hub as hub
from paddle_hub.tools.downloader import default_downloader
class ImageClassificationDataset(object):
    """Base class for image-classification datasets described by list files.

    Subclasses fill in ``base_path`` and the three ``*_list_file`` names.
    Every non-blank line of a list file is read as
    ``<image path relative to base_path> <label>``.
    """

    def __init__(self):
        # Root directory of the dataset; list files and image paths are
        # resolved relative to it.
        self.base_path = None
        # List-file names for each split, joined onto ``base_path``.
        self.train_list_file = None
        self.test_list_file = None
        self.validate_list_file = None
        # Label count; subclasses overwrite it.
        self.num_labels = 0

    def _download_dataset(self, dataset_path, url):
        """Download and uncompress the dataset unless ``dataset_path`` exists.

        Args:
            dataset_path: Local path where the dataset is expected.
            url: Location of the compressed dataset archive.

        Returns:
            ``dataset_path`` unchanged when it already exists, otherwise the
            path reported by the downloader.

        Raises:
            SystemExit: If the downloader reports a failure (its message is
                printed first).
        """
        if not os.path.exists(dataset_path):
            result, tips, dataset_path = default_downloader.download_file_and_uncompress(
                url=url,
                save_path=hub.dir.DATA_HOME,
                print_progress=True,
                replace=True)
            if not result:
                print(tips)
                # Raise SystemExit directly: the ``exit`` helper is injected
                # by the ``site`` module and is not guaranteed to exist, and
                # a failed download should not end with a success status.
                raise SystemExit(1)
        return dataset_path

    def _parse_data(self, data_path, shuffle=False):
        """Return a generator over the samples listed in ``data_path``.

        Args:
            data_path: Path of the list file to read.
            shuffle: When true, the samples are shuffled with
                ``numpy.random.shuffle`` before being yielded.

        Returns:
            A generator of ``(image_path, label)`` tuples, where
            ``image_path`` is joined onto ``self.base_path`` and ``label`` is
            the raw string from the file. The file is read when iteration
            starts.
        """

        def _base_reader():
            samples = []
            with open(data_path, "r") as list_file:
                for line in list_file:
                    items = line.split()
                    if not items:
                        # Blank lines (e.g. a trailing newline at the end of
                        # the file) carry no sample; skip them instead of
                        # failing on the missing label.
                        continue
                    image_path = os.path.join(self.base_path, items[0])
                    samples.append((image_path, items[1]))
            if shuffle:
                # Imported locally because numpy is not among the imports
                # that precede this class.
                import numpy as np
                np.random.shuffle(samples)
            for sample in samples:
                yield sample

        return _base_reader()

    def train_data(self, shuffle=True):
        """Return a generator over the training samples (shuffled by default)."""
        train_data_path = os.path.join(self.base_path, self.train_list_file)
        return self._parse_data(train_data_path, shuffle)

    def test_data(self, shuffle=False):
        """Return a generator over the test samples."""
        test_data_path = os.path.join(self.base_path, self.test_list_file)
        return self._parse_data(test_data_path, shuffle)

    def validate_data(self, shuffle=False):
        """Return a generator over the validation samples."""
        validate_data_path = os.path.join(self.base_path,
                                          self.validate_list_file)
        return self._parse_data(validate_data_path, shuffle)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from PIL import Image
import paddle_hub.io.augmentation as image_augmentation
# Maps a channel-order name to the indices that select those channels, in
# that order, from the channel axis of an image already converted to RGB.
# The values stay lists on purpose: they are used to index a NumPy array.
color_mode_dict = dict(
    RGB=[0, 1, 2],
    RBG=[0, 2, 1],
    GBR=[1, 2, 0],
    GRB=[1, 0, 2],
    BGR=[2, 1, 0],
    BRG=[2, 0, 1])
class ImageClassificationReader:
def __init__(self,
image_width,
image_height,
dataset,
color_mode="RGB",
data_augmentation=False):
self.image_width = image_width
self.image_height = image_height
self.color_mode = color_mode
self.dataset = dataset
self.data_augmentation = data_augmentation
if self.color_mode not in color_mode_dict:
raise ValueError(
"Color_mode should in %s." % color_mode_dict.keys())
if self.image_width <= 0 or self.image_height <= 0:
raise ValueError("Image width and height should not be negative.")
def data_generator(self, phase, shuffle=False):
if phase == "train":
data = self.dataset.train_data(shuffle)
elif phase == "test":
shuffle = False
data = self.dataset.test_data(shuffle)
elif phase == "validate":
shuffle = False
data = self.dataset.validate_data(shuffle)
def _data_reader():
for image_path, label in data:
image = Image.open(image_path)
image = image_augmentation.image_resize(image, self.image_width,
self.image_height)
if self.data_augmentation:
image = image_augmentation.image_random_process(
image, enable_resize=False)
# only support RGB
image = image.convert('RGB')
# HWC to CHW
image = np.array(image)
if len(image.shape) == 3:
image = np.swapaxes(image, 1, 2)
image = np.swapaxes(image, 1, 0)
image = image[color_mode_dict[self.color_mode], :, :]
yield ((image, label))
return _data_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle_hub as hub
from paddle_hub.dataset.base_cv_dataset import ImageClassificationDataset
class DogCatDataset(ImageClassificationDataset):
    """Two-label dog/cat dataset stored under ``hub.dir.DATA_HOME``."""

    def __init__(self):
        super(DogCatDataset, self).__init__()

        archive_url = "https://paddlehub-dataset.bj.bcebos.com/dog-cat.tar.gz"
        local_dir = os.path.join(hub.dir.DATA_HOME, "dog-cat")
        # The path returned by the download helper becomes the dataset root.
        self.base_path = self._download_dataset(
            dataset_path=local_dir, url=archive_url)

        # List files for the three splits.
        (self.train_list_file,
         self.test_list_file,
         self.validate_list_file) = ("train_list.txt",
                                     "test_list.txt",
                                     "validate_list.txt")
        self.num_labels = 2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle_hub as hub
from paddle_hub.dataset.base_cv_dataset import ImageClassificationDataset
class FlowersDataset(ImageClassificationDataset):
    """Five-label flower-photo dataset stored under ``hub.dir.DATA_HOME``."""

    def __init__(self):
        super(FlowersDataset, self).__init__()

        archive_url = "https://paddlehub-dataset.bj.bcebos.com/flower_photos.tar.gz"
        local_dir = os.path.join(hub.dir.DATA_HOME, "flower_photos")
        # The path returned by the download helper becomes the dataset root.
        self.base_path = self._download_dataset(
            dataset_path=local_dir, url=archive_url)

        # List files for the three splits.
        (self.train_list_file,
         self.test_list_file,
         self.validate_list_file) = ("train_list.txt",
                                     "test_list.txt",
                                     "validate_list.txt")
        self.num_labels = 5
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册