提交 9503f211 编写于 作者: Z Zeyu Chen

migrate from paddle_hub.data to paddle_hub.io

上级 c492de97
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from PIL import Image, ImageEnhance
from paddle_hub.tools import utils
import numpy as np
def _check_range_0_1(value):
value = value if value <= 1 else 1
value = value if value >= 0 else 0
return value
def _check_bound(low, high):
    """Return a ``(low, high)`` pair clamped to [0, 1] with ``high >= low``.

    Each bound is first clamped with ``_check_range_0_1``; if the clamped
    upper bound ends up below the lower one it is raised to the lower bound.
    """
    lower = _check_range_0_1(low)
    upper = _check_range_0_1(high)
    if not upper >= lower:
        upper = lower
    return lower, upper
def _check_img(img):
if isinstance(img, str):
utils.check_path(img)
img = Image.open(img)
return img
def _check_img_and_size(img, width, height):
    """Load ``img`` and fit the requested size to the image dimensions.

    Returns ``(img, width, height)`` where each requested dimension is
    replaced by the image's own dimension when the request is not strictly
    smaller than the image, or when the request is zero or negative.
    """
    image = _check_img(img)
    full_width, full_height = image.size

    def _fit(requested, available):
        # A request that is not strictly smaller than the image, or that is
        # non-positive, falls back to the full image dimension.
        if not available > requested:
            return available
        if requested <= 0:
            return available
        return requested

    height = _fit(height, full_height)
    width = _fit(width, full_width)
    return image, width, height
def image_crop_from_position(img, width, height, w_start, h_start):
    """Crop a ``width`` x ``height`` region whose top-left corner is
    ``(w_start, h_start)``.

    The requested size is first fitted to the image by
    ``_check_img_and_size``; the start offsets are used as given.
    """
    img, width, height = _check_img_and_size(img, width, height)
    box = (w_start, h_start, w_start + width, h_start + height)
    return img.crop(box)
def image_crop_from_TL(img, width, height):
    """Crop a ``width`` x ``height`` region anchored at the top-left corner."""
    return image_crop_from_position(img, width, height, 0, 0)
def image_crop_from_TR(img, width, height):
    """Crop a ``width`` x ``height`` region anchored at the top-right corner."""
    img, width, height = _check_img_and_size(img, width, height)
    # Left edge is chosen so the crop ends at the image's right border.
    left = img.size[0] - width
    return image_crop_from_position(img, width, height, left, 0)
def image_crop_from_BL(img, width, height):
    """Crop a ``width`` x ``height`` region anchored at the bottom-left corner."""
    img, width, height = _check_img_and_size(img, width, height)
    # Top edge is chosen so the crop ends at the image's bottom border.
    top = img.size[1] - height
    return image_crop_from_position(img, width, height, 0, top)
def image_crop_from_BR(img, width, height):
    """Crop a ``width`` x ``height`` region anchored at the bottom-right corner."""
    img, width, height = _check_img_and_size(img, width, height)
    full_width, full_height = img.size
    left = full_width - width
    top = full_height - height
    return image_crop_from_position(img, width, height, left, top)
def image_crop_from_centor(img, width, height):
    """Crop a ``width`` x ``height`` region centred in the image.

    (The misspelt name is kept for backward compatibility.)

    Args:
        img: an image object, or a ``str`` path that ``_check_img`` opens.
        width: requested crop width; fitted to the image like the other
            ``image_crop_from_*`` helpers.
        height: requested crop height; fitted the same way.

    Returns:
        The result of ``img.crop`` for the centred box.
    """
    # Fit the size *before* computing the offsets, as the TR/BL/BR helpers
    # do. Otherwise an oversized or non-positive request yields offsets that
    # no longer match the size actually cropped.
    img, width, height = _check_img_and_size(img, width, height)
    # Floor division keeps the box on integer pixel coordinates; true
    # division (active via ``from __future__ import division``) would hand
    # fractional coordinates to ``crop``.
    w_start = (img.size[0] - width) // 2
    h_start = (img.size[1] - height) // 2
    return image_crop_from_position(img, width, height, w_start, h_start)
def image_crop_random(img, width=0, height=0):
    """Crop a region at a random position, optionally with a random size.

    Args:
        img: an image object, or a ``str`` path that ``_check_img`` opens.
        width: crop width; when falsy (default ``0``) a random width in
            ``[int(image_width / 10), image_width)`` is drawn.
        height: crop height; when falsy a random height is drawn the same way.

    Returns:
        The cropped image returned by ``image_crop_from_position``.
    """
    img = _check_img(img)
    width = width if width else np.random.randint(
        int(img.size[0] / 10), img.size[0])
    height = height if height else np.random.randint(
        int(img.size[1] / 10), img.size[1])
    # Fit the size to the image first so the offset ranges below can never
    # be negative (e.g. when the caller asks for more than the image has).
    img, width, height = _check_img_and_size(img, width, height)
    # randint's upper bound is exclusive: "+ 1" keeps the last valid offset
    # reachable and avoids a ValueError when the crop spans the full
    # dimension (where the range would otherwise be empty).
    w_start = np.random.randint(0, img.size[0] - width + 1)
    h_start = np.random.randint(0, img.size[1] - height + 1)
    return image_crop_from_position(img, width, height, w_start, h_start)
def image_resize(img, width, height, interpolation_method=Image.LANCZOS):
    """Resize ``img`` to exactly ``width`` x ``height``.

    ``interpolation_method`` is forwarded to the image's ``resize`` call and
    defaults to ``Image.LANCZOS``.
    """
    image = _check_img(img)
    target_size = (width, height)
    return image.resize(target_size, interpolation_method)
def image_resize_random(img,
                        width=0,
                        height=0,
                        interpolation_method=Image.LANCZOS):
    """Resize ``img``, drawing any unspecified dimension at random.

    A falsy ``width`` / ``height`` is replaced by a random integer in
    ``[int(dimension / 10), dimension)`` of the source image; the actual
    resizing is delegated to ``image_resize``.
    """
    img = _check_img(img)
    if not width:
        full_width = img.size[0]
        width = np.random.randint(int(full_width / 10), full_width)
    if not height:
        full_height = img.size[1]
        height = np.random.randint(int(full_height / 10), full_height)
    return image_resize(img, width, height, interpolation_method)
def image_rotate(img, angle, expand=False):
    """Rotate ``img`` by ``angle``, forwarding ``expand`` to ``rotate``."""
    image = _check_img(img)
    rotated = image.rotate(angle, expand=expand)
    return rotated
def image_rotate_random(img, low=0, high=360, expand=False):
    """Rotate ``img`` by a random integer angle drawn from ``[low, high)``."""
    return image_rotate(img, np.random.randint(low, high), expand)
def image_brightness_adjust(img, delta):
    """Apply a brightness enhancement with factor ``delta`` clamped to [0, 1]."""
    factor = _check_range_0_1(delta)
    enhancer = ImageEnhance.Brightness(_check_img(img))
    return enhancer.enhance(factor)
def image_brightness_adjust_random(img, low=0, high=1):
    """Adjust brightness by a factor drawn uniformly between the bounds.

    ``low`` and ``high`` are first normalised by ``_check_bound``.
    """
    bounds = _check_bound(low, high)
    return image_brightness_adjust(img, np.random.uniform(*bounds))
def image_contrast_adjust(img, delta):
    """Apply a contrast enhancement with factor ``delta`` clamped to [0, 1]."""
    factor = _check_range_0_1(delta)
    enhancer = ImageEnhance.Contrast(_check_img(img))
    return enhancer.enhance(factor)
def image_contrast_adjust_random(img, low=0, high=1):
    """Adjust contrast by a factor drawn uniformly between the bounds.

    ``low`` and ``high`` are first normalised by ``_check_bound``.
    """
    bounds = _check_bound(low, high)
    return image_contrast_adjust(img, np.random.uniform(*bounds))
def image_saturation_adjust(img, delta):
    """Apply a colour (saturation) enhancement with factor ``delta`` clamped
    to [0, 1]."""
    factor = _check_range_0_1(delta)
    enhancer = ImageEnhance.Color(_check_img(img))
    return enhancer.enhance(factor)
def image_saturation_adjust_random(img, low=0, high=1):
    """Adjust saturation by a factor drawn uniformly between the bounds.

    ``low`` and ``high`` are first normalised by ``_check_bound``.
    """
    bounds = _check_bound(low, high)
    return image_saturation_adjust(img, np.random.uniform(*bounds))
def image_flip_top_bottom(img):
    """Return ``img`` transposed with ``Image.FLIP_TOP_BOTTOM``."""
    return _check_img(img).transpose(Image.FLIP_TOP_BOTTOM)
def image_flip_left_right(img):
    """Return ``img`` transposed with ``Image.FLIP_LEFT_RIGHT``."""
    return _check_img(img).transpose(Image.FLIP_LEFT_RIGHT)
def image_flip_random(img):
    """Flip ``img`` along a randomly chosen axis.

    Args:
        img: an image object, or a ``str`` path that ``_check_img`` opens.

    Returns:
        The image flipped either top/bottom or left/right, each chosen with
        equal probability.
    """
    img = _check_img(img)
    # np.random.randint excludes its upper bound, so the range must be
    # [0, 2) to produce both 0 and 1. The previous (0, 1) always returned 0,
    # which made the top/bottom branch unreachable.
    if np.random.randint(0, 2):
        return image_flip_top_bottom(img)
    return image_flip_left_right(img)
def image_random_process(img,
                         enable_resize=True,
                         enable_crop=True,
                         enable_rotate=True,
                         enable_brightness_adjust=True,
                         enable_contrast_adjust=True,
                         enable_saturation_adjust=True,
                         enable_flip=True):
    """Apply one randomly selected augmentation to ``img``.

    Each ``enable_*`` flag adds the matching ``image_*_random`` operator to
    the candidate list; one candidate is then picked uniformly at random and
    called with ``img`` as its only argument.

    Args:
        img: the image (or path) handed to the selected operator.
        enable_resize: allow ``image_resize_random``.
        enable_crop: allow ``image_crop_random``.
        enable_rotate: allow ``image_rotate_random``.
        enable_brightness_adjust: allow ``image_brightness_adjust_random``.
        enable_contrast_adjust: allow ``image_contrast_adjust_random``.
        enable_saturation_adjust: allow ``image_saturation_adjust_random``.
        enable_flip: allow ``image_flip_random``.

    Returns:
        The selected operator's result, or ``img`` unchanged when every
        operator is disabled.
    """
    operator_list = []
    if enable_resize:
        operator_list.append(image_resize_random)
    if enable_crop:
        operator_list.append(image_crop_random)
    if enable_rotate:
        operator_list.append(image_rotate_random)
    if enable_brightness_adjust:
        operator_list.append(image_brightness_adjust_random)
    if enable_contrast_adjust:
        operator_list.append(image_contrast_adjust_random)
    if enable_saturation_adjust:
        operator_list.append(image_saturation_adjust_random)
    if enable_flip:
        operator_list.append(image_flip_random)
    if not operator_list:
        return img
    # randint's upper bound is exclusive, so pass len(...) itself. Using
    # len(...) - 1 never selected the last operator and raised ValueError
    # when exactly one operator was enabled.
    random_op_index = np.random.randint(0, len(operator_list))
    random_op = operator_list[random_op_index]
    return random_op(img)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
class CSVReader:
    """Read a CSV file into a column-oriented dictionary.

    After ``read`` the instance exposes:
      * ``title``   - list of column names taken from the first row;
      * ``content`` - dict mapping each column name to the list of values
        found in that column, in row order.
    """

    def __init__(self):
        pass

    def _check(self):
        pass

    def read(self, csv_file):
        """Parse ``csv_file`` and return the column dictionary.

        Args:
            csv_file: path of the CSV file to read.

        Returns:
            ``self.content``: ``{column_name: [value, ...]}`` with every
            value kept as a string. Blank lines are skipped.

        Raises:
            IndexError: if a data row has more fields than the header row.
        """
        # Local import so this block does not depend on a file-level import.
        import csv

        with open(csv_file, "r") as file:
            # The csv module handles quoted fields (e.g. "a, b"), which a
            # plain split(',') would break into several columns.
            rows = csv.reader(file)
            # An empty file or blank header line still yields one
            # empty-named column, matching what ''.split(',') produced.
            self.title = next(rows, None) or [""]
            self.content = {}
            for key in self.title:
                self.content[key] = []
            for row in rows:
                # csv.reader returns [] for blank lines, so they add nothing.
                for index, item in enumerate(row):
                    title = self.title[index]
                    self.content[title].append(item)
        return self.content
class YAMLReader:
    """Load a YAML document from a file path."""

    def __init__(self):
        pass

    def _check(self):
        pass

    def read(self, yaml_file):
        """Read ``yaml_file`` and return the parsed YAML document.

        Args:
            yaml_file: path of the YAML file to read.

        Returns:
            The Python object produced by the YAML parser.
        """
        with open(yaml_file, "r") as file:
            content = file.read()
        # NOTE(review): this used to be ``yaml.load(content)`` with no
        # Loader. That form can construct arbitrary Python objects on older
        # PyYAML releases and is rejected (Loader is mandatory) on newer
        # ones. ``safe_load`` only builds plain YAML types; if files with
        # Python-specific tags must be supported, pass an explicit Loader.
        return yaml.safe_load(content)
# Module-level reader instances for importers of this module.
csv_reader = CSVReader()
yaml_reader = YAMLReader()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from PIL import Image
from paddle_hub.tools.logger import logger
from paddle_hub.tools import utils
class DataType(Enum):
    """Kinds of data recognised by the readers in this module."""

    IMAGE = 0
    TEXT = 1
    AUDIO = 2
    VIDEO = 3
    INT = 4
    FLOAT = 5

    @classmethod
    def type(cls, data_type):
        """Resolve ``data_type`` to a ``DataType`` member.

        Args:
            data_type: a ``DataType`` member, or a member name in any
                letter case (e.g. ``"image"``).

        Returns:
            The matching member, or ``None`` when nothing matches.
        """
        # isinstance instead of ``data_type in DataType``: on several Python
        # 3 releases the containment test raises TypeError for non-members.
        if isinstance(data_type, DataType):
            return data_type
        try:
            name = data_type.upper()
        except AttributeError:
            # Not string-like (e.g. None or an int): no member name to match.
            return None
        # __members__ holds exactly the enum members, unlike the class
        # __dict__, which also contains methods and internals.
        return DataType.__members__.get(name)

    @classmethod
    def str(cls, data_type):
        """Return the upper-case name of a member, or ``None`` if
        ``data_type`` is not a ``DataType`` member."""
        if isinstance(data_type, DataType):
            return data_type.name
        return None

    @classmethod
    def is_valid_type(cls, data_type):
        """Return ``True`` when ``data_type`` resolves to a member."""
        return DataType.type(data_type) is not None

    @classmethod
    def type_reader(cls, data_type):
        """Return the reader class that handles ``data_type``.

        ``ImageReader`` is returned for IMAGE and ``TextReader`` for TEXT.
        An unresolvable type, or a type without a reader, is logged at
        critical level and the process exits with status 1.
        """
        resolved = DataType.type(data_type)
        if resolved is None:
            # Log the caller's original value; the resolved value would
            # always print as "None" here.
            logger.critical("invalid data type %s" % data_type)
            exit(1)
        if resolved is DataType.IMAGE:
            return ImageReader
        elif resolved is DataType.TEXT:
            return TextReader
        else:
            type_str = DataType.str(resolved)
            logger.critical(
                "data type %s not supported for the time being" % type_str)
            exit(1)
class ImageReader:
    """Reader for image data."""

    @classmethod
    def read(cls, path):
        """Pass ``path`` to ``utils.check_path`` and return the image opened
        from it with ``Image.open``."""
        utils.check_path(path)
        return Image.open(path)
class TextReader:
    """Reader for text data."""

    @classmethod
    def read(cls, text):
        """Return ``text`` unchanged."""
        return text
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册