提交 3a238aa0 编写于 作者: S syyxsxx

fix python raspberry

上级 2af6a45f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import math
import cv2
import scipy
def bbox_area(src_bbox):
if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
return 0.
else:
width = src_bbox[2] - src_bbox[0]
height = src_bbox[3] - src_bbox[1]
return width * height
def jaccard_overlap(sample_bbox, object_bbox):
if sample_bbox[0] >= object_bbox[2] or \
sample_bbox[2] <= object_bbox[0] or \
sample_bbox[1] >= object_bbox[3] or \
sample_bbox[3] <= object_bbox[1]:
return 0
intersect_xmin = max(sample_bbox[0], object_bbox[0])
intersect_ymin = max(sample_bbox[1], object_bbox[1])
intersect_xmax = min(sample_bbox[2], object_bbox[2])
intersect_ymax = min(sample_bbox[3], object_bbox[3])
intersect_size = (intersect_xmax - intersect_xmin) * (
intersect_ymax - intersect_ymin)
sample_bbox_size = bbox_area(sample_bbox)
object_bbox_size = bbox_area(object_bbox)
overlap = intersect_size / (
sample_bbox_size + object_bbox_size - intersect_size)
return overlap
def iou_matrix(a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def crop_box_with_center_constraint(box, crop):
cropped_box = box.copy()
cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
cropped_box[:, :2] -= crop[:2]
cropped_box[:, 2:] -= crop[:2]
centers = (box[:, :2] + box[:, 2:]) / 2
valid = np.logical_and(crop[:2] <= centers, centers < crop[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return cropped_box, np.where(valid)[0]
def is_poly(segm):
if not isinstance(segm, (list, dict)):
raise Exception("Invalid segm type: {}".format(type(segm)))
return isinstance(segm, list)
def crop_image(img, crop):
x1, y1, x2, y2 = crop
return img[y1:y2, x1:x2, :]
def crop_segms(segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop):
xmin, ymin, xmax, ymax = crop
crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]
crop_p = np.array(crop_coord).reshape(4, 2)
crop_p = Polygon(crop_p)
crop_segm = list()
for poly in segm:
poly = np.array(poly).reshape(len(poly) // 2, 2)
polygon = Polygon(poly)
if not polygon.is_valid:
exterior = polygon.exterior
multi_lines = exterior.intersection(exterior)
polygons = shapely.ops.polygonize(multi_lines)
polygon = MultiPolygon(polygons)
multi_polygon = list()
if isinstance(polygon, MultiPolygon):
multi_polygon = copy.deepcopy(polygon)
else:
multi_polygon.append(copy.deepcopy(polygon))
for per_polygon in multi_polygon:
inter = per_polygon.intersection(crop_p)
if not inter:
continue
if isinstance(inter, (MultiPolygon, GeometryCollection)):
for part in inter:
if not isinstance(part, Polygon):
continue
part = np.squeeze(
np.array(part.exterior.coords[:-1]).reshape(1, -1))
part[0::2] -= xmin
part[1::2] -= ymin
crop_segm.append(part.tolist())
elif isinstance(inter, Polygon):
crop_poly = np.squeeze(
np.array(inter.exterior.coords[:-1]).reshape(1, -1))
crop_poly[0::2] -= xmin
crop_poly[1::2] -= ymin
crop_segm.append(crop_poly.tolist())
else:
continue
return crop_segm
def _crop_rle(rle, crop, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
mask = mask[crop[1]:crop[3], crop[0]:crop[2]]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
crop_segms = []
for id in valid_ids:
segm = segms[id]
if is_poly(segm):
import copy
import shapely.ops
import logging
from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
logging.getLogger("shapely").setLevel(logging.WARNING)
# Polygon format
crop_segms.append(_crop_poly(segm, crop))
else:
# RLE format
import pycocotools.mask as mask_util
crop_segms.append(_crop_rle(segm, crop, height, width))
return crop_segms
def expand_segms(segms, x, y, height, width, ratio):
def _expand_poly(poly, x, y):
expanded_poly = np.array(poly)
expanded_poly[0::2] += x
expanded_poly[1::2] += y
return expanded_poly.tolist()
def _expand_rle(rle, x, y, height, width, ratio):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
expanded_mask = np.full((int(height * ratio), int(width * ratio)),
0).astype(mask.dtype)
expanded_mask[y:y + height, x:x + width] = mask
rle = mask_util.encode(
np.array(expanded_mask, order='F', dtype=np.uint8))
return rle
expanded_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
expanded_segms.append([_expand_poly(poly, x, y) for poly in segm])
else:
# RLE format
import pycocotools.mask as mask_util
expanded_segms.append(
_expand_rle(segm, x, y, height, width, ratio))
return expanded_segms
def box_horizontal_flip(bboxes, width):
oldx1 = bboxes[:, 0].copy()
oldx2 = bboxes[:, 2].copy()
bboxes[:, 0] = width - oldx2 - 1
bboxes[:, 2] = width - oldx1 - 1
if bboxes.shape[0] != 0 and (bboxes[:, 2] < bboxes[:, 0]).all():
raise ValueError(
"RandomHorizontalFlip: invalid box, x2 should be greater than x1")
return bboxes
def segms_horizontal_flip(segms, height, width):
def _flip_poly(poly, width):
flipped_poly = np.array(poly)
flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
return flipped_poly.tolist()
def _flip_rle(rle, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects([rle], height, width)
mask = mask_util.decode(rle)
mask = mask[:, ::-1]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
flipped_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
flipped_segms.append([_flip_poly(poly, width) for poly in segm])
else:
# RLE format
import pycocotools.mask as mask_util
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -18,7 +18,6 @@ import random ...@@ -18,7 +18,6 @@ import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
import utils.logging as logging
class ClsTransform: class ClsTransform:
...@@ -49,14 +48,7 @@ class Compose(ClsTransform): ...@@ -49,14 +48,7 @@ class Compose(ClsTransform):
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for op in self.transforms:
if not isinstance(op, ClsTransform):
import imgaug.augmenters as iaa
if not isinstance(op, iaa.Augmenter):
raise Exception(
"Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
def __call__(self, im, label=None): def __call__(self, im, label=None):
""" """
...@@ -79,18 +71,10 @@ class Compose(ClsTransform): ...@@ -79,18 +71,10 @@ class Compose(ClsTransform):
raise TypeError('Can\'t read The image file {}!'.format(im)) raise TypeError('Can\'t read The image file {}!'.format(im))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
for op in self.transforms: for op in self.transforms:
if isinstance(op, ClsTransform): outputs = op(im, label)
outputs = op(im, label) im = outputs[0]
im = outputs[0] if len(outputs) == 2:
if len(outputs) == 2: label = outputs[1]
label = outputs[1]
else:
import imgaug.augmenters as iaa
if isinstance(op, iaa.Augmenter):
im = execute_imgaug(op, im)
outputs = (im, )
if label is not None:
outputs = (im, label)
return outputs return outputs
def add_augmenters(self, augmenters): def add_augmenters(self, augmenters):
...@@ -100,109 +84,10 @@ class Compose(ClsTransform): ...@@ -100,109 +84,10 @@ class Compose(ClsTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
logging.error("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
class RandomCrop(ClsTransform):
"""对图像进行随机剪裁,模型训练时的数据增强操作。
1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
2. 根据随机剪裁的高、宽随机选取剪裁的起始点。
3. 剪裁图像。
4. 调整剪裁后的图像的大小到crop_size*crop_size。
Args:
crop_size (int): 随机裁剪后重新调整的目标边长。默认为224。
lower_scale (float): 裁剪面积相对原面积比例的最小限制。默认为0.08。
lower_ratio (float): 宽变换比例的最小限制。默认为3. / 4。
upper_ratio (float): 宽变换比例的最大限制。默认为4. / 3。
"""
def __init__(self,
crop_size=224,
lower_scale=0.08,
lower_ratio=3. / 4,
upper_ratio=4. / 3):
self.crop_size = crop_size
self.lower_scale = lower_scale
self.lower_ratio = lower_ratio
self.upper_ratio = upper_ratio
def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
label (int): 每张图像所对应的类别序号。
Returns:
tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
"""
im = random_crop(im, self.crop_size, self.lower_scale,
self.lower_ratio, self.upper_ratio)
if label is None:
return (im, )
else:
return (im, label)
class RandomHorizontalFlip(ClsTransform):
"""以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
Args:
prob (float): 随机水平翻转的概率。默认为0.5。
"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
label (int): 每张图像所对应的类别序号。
Returns:
tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
"""
if random.random() < self.prob:
im = horizontal_flip(im)
if label is None:
return (im, )
else:
return (im, label)
class RandomVerticalFlip(ClsTransform):
"""以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
Args:
prob (float): 随机垂直翻转的概率。默认为0.5。
"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
label (int): 每张图像所对应的类别序号。
Returns:
tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
"""
if random.random() < self.prob:
im = vertical_flip(im)
if label is None:
return (im, )
else:
return (im, label)
class Normalize(ClsTransform): class Normalize(ClsTransform):
"""对图像进行标准化。 """对图像进行标准化。
...@@ -315,131 +200,6 @@ class CenterCrop(ClsTransform): ...@@ -315,131 +200,6 @@ class CenterCrop(ClsTransform):
return (im, label) return (im, label)
class RandomRotate(ClsTransform):
def __init__(self, rotate_range=30, prob=0.5):
"""以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
Args:
rotate_range (int): 旋转度数的范围。默认为30。
prob (float): 随机旋转的概率。默认为0.5。
"""
self.rotate_range = rotate_range
self.prob = prob
def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
label (int): 每张图像所对应的类别序号。
Returns:
tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
"""
rotate_lower = -self.rotate_range
rotate_upper = self.rotate_range
im = im.astype('uint8')
im = Image.fromarray(im)
if np.random.uniform(0, 1) < self.prob:
im = rotate(im, rotate_lower, rotate_upper)
im = np.asarray(im).astype('float32')
if label is None:
return (im, )
else:
return (im, label)
class RandomDistort(ClsTransform):
"""以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
1. 对变换的操作顺序进行随机化操作。
2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。
Args:
brightness_range (float): 明亮度因子的范围。默认为0.9。
brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
contrast_range (float): 对比度因子的范围。默认为0.9。
contrast_prob (float): 随机调整对比度的概率。默认为0.5。
saturation_range (float): 饱和度因子的范围。默认为0.9。
saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。
"""
def __init__(self,
brightness_range=0.9,
brightness_prob=0.5,
contrast_range=0.9,
contrast_prob=0.5,
saturation_range=0.9,
saturation_prob=0.5,
hue_range=18,
hue_prob=0.5):
self.brightness_range = brightness_range
self.brightness_prob = brightness_prob
self.contrast_range = contrast_range
self.contrast_prob = contrast_prob
self.saturation_range = saturation_range
self.saturation_prob = saturation_prob
self.hue_range = hue_range
self.hue_prob = hue_prob
def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
label (int): 每张图像所对应的类别序号。
Returns:
tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
"""
brightness_lower = 1 - self.brightness_range
brightness_upper = 1 + self.brightness_range
contrast_lower = 1 - self.contrast_range
contrast_upper = 1 + self.contrast_range
saturation_lower = 1 - self.saturation_range
saturation_upper = 1 + self.saturation_range
hue_lower = -self.hue_range
hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue]
random.shuffle(ops)
params_dict = {
'brightness': {
'brightness_lower': brightness_lower,
'brightness_upper': brightness_upper
},
'contrast': {
'contrast_lower': contrast_lower,
'contrast_upper': contrast_upper
},
'saturation': {
'saturation_lower': saturation_lower,
'saturation_upper': saturation_upper
},
'hue': {
'hue_lower': hue_lower,
'hue_upper': hue_upper
}
}
prob_dict = {
'brightness': self.brightness_prob,
'contrast': self.contrast_prob,
'saturation': self.saturation_prob,
'hue': self.hue_prob,
}
for id in range(len(ops)):
params = params_dict[ops[id].__name__]
prob = prob_dict[ops[id].__name__]
params['im'] = im
if np.random.uniform(0, 1) < prob:
im = ops[id](**params)
if label is None:
return (im, )
else:
return (im, label)
class ArrangeClassifier(ClsTransform): class ArrangeClassifier(ClsTransform):
"""获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用 """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
...@@ -510,12 +270,7 @@ class ComposedClsTransforms(Compose): ...@@ -510,12 +270,7 @@ class ComposedClsTransforms(Compose):
) )
if mode == 'train': if mode == 'train':
# 训练时的transforms,包含数据增强 pass
transforms = [
RandomCrop(crop_size=width), RandomHorizontalFlip(prob=0.5),
Normalize(
mean=mean, std=std)
]
else: else:
# 验证/预测时的transforms # 验证/预测时的transforms
transforms = [ transforms = [
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import copy
def execute_imgaug(augmenter, im, bboxes=None, polygons=None,
segment_map=None):
# 预处理,将bboxes, polygons转换成imgaug格式
import imgaug.augmentables.kps as kps
import imgaug.augmentables.bbs as bbs
aug_im = im.astype('uint8')
aug_im = augmenter.augment(image=aug_im).astype('float32')
return aug_im
# TODO imgaug的标注处理逻辑与paddlex已存的transform存在部分差异
# 目前仅支持对原图进行处理,因此只能使用pixlevel的imgaug增强操作
# 以下代码暂不会执行
aug_bboxes = None
if bboxes is not None:
aug_bboxes = list()
for i in range(len(bboxes)):
x1 = bboxes[i, 0]
y1 = bboxes[i, 1]
x2 = bboxes[i, 2]
y2 = bboxes[i, 3]
aug_bboxes.append(bbs.BoundingBox(x1, y1, x2, y2))
aug_points = None
if polygons is not None:
aug_points = list()
for i in range(len(polygons)):
num = len(polygons[i])
for j in range(num):
tmp = np.reshape(polygons[i][j], (-1, 2))
for k in range(len(tmp)):
aug_points.append(kps.Keypoint(tmp[k, 0], tmp[k, 1]))
aug_segment_map = None
if segment_map is not None:
if len(segment_map.shape) == 2:
h, w = segment_map.shape
aug_segment_map = np.reshape(segment_map, (1, h, w, 1))
elif len(segment_map.shape) == 3:
h, w, c = segment_map.shape
aug_segment_map = np.reshape(segment_map, (1, h, w, c))
else:
raise Exception(
"Only support 2-dimensions for 3-dimensions for segment_map")
unnormalized_batch = augmenter.augment(
image=aug_im,
bounding_boxes=aug_bboxes,
keypoints=aug_points,
segmentation_maps=aug_segment_map,
return_batch=True)
aug_im = unnormalized_batch.images_aug[0]
aug_bboxes = unnormalized_batch.bounding_boxes_aug
aug_points = unnormalized_batch.keypoints_aug
aug_seg_map = unnormalized_batch.segmentation_maps_aug
aug_im = aug_im.astype('float32')
if aug_bboxes is not None:
converted_bboxes = list()
for i in range(len(aug_bboxes)):
converted_bboxes.append([
aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
aug_bboxes[i].y2
])
aug_bboxes = converted_bboxes
aug_polygons = None
if aug_points is not None:
aug_polygons = copy.deepcopy(polygons)
idx = 0
for i in range(len(aug_polygons)):
num = len(aug_polygons[i])
for j in range(num):
num_points = len(aug_polygons[i][j]) // 2
for k in range(num_points):
aug_polygons[i][j][k * 2] = aug_points[idx].x
aug_polygons[i][j][k * 2 + 1] = aug_points[idx].y
idx += 1
result = [aug_im]
if aug_bboxes is not None:
result.append(np.array(aug_bboxes))
if aug_polygons is not None:
result.append(aug_polygons)
if aug_seg_map is not None:
n, h, w, c = aug_seg_map.shape
if len(segment_map.shape) == 2:
aug_seg_map = np.reshape(aug_seg_map, (h, w))
elif len(segment_map.shape) == 3:
aug_seg_map = np.reshape(aug_seg_map, (h, w, c))
result.append(aug_seg_map)
return result
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -14,14 +14,12 @@ ...@@ -14,14 +14,12 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import cv2 import cv2
from collections import OrderedDict from collections import OrderedDict
import utils.logging as logging
class SegTransform: class SegTransform:
...@@ -53,14 +51,7 @@ class Compose(SegTransform): ...@@ -53,14 +51,7 @@ class Compose(SegTransform):
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
self.to_rgb = False self.to_rgb = False
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for op in self.transforms:
if not isinstance(op, SegTransform):
import imgaug.augmenters as iaa
if not isinstance(op, iaa.Augmenter):
raise Exception(
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
def __call__(self, im, im_info=None, label=None): def __call__(self, im, im_info=None, label=None):
""" """
...@@ -116,7 +107,7 @@ class Compose(SegTransform): ...@@ -116,7 +107,7 @@ class Compose(SegTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
logging.error("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -828,76 +819,6 @@ class RandomBlur(SegTransform): ...@@ -828,76 +819,6 @@ class RandomBlur(SegTransform):
return (im, im_info, label) return (im, im_info, label)
class RandomRotate(SegTransform):
"""对图像进行随机旋转, 模型训练时的数据增强操作。
在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
并对旋转后的图像和标注图像进行相应的padding。
Args:
rotate_range (float): 最大旋转角度。默认为15度。
im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
label_padding_value (int): 标注图像padding的值。默认为255。
"""
def __init__(self,
rotate_range=15,
im_padding_value=[127.5, 127.5, 127.5],
label_padding_value=255):
self.rotate_range = rotate_range
self.im_padding_value = im_padding_value
self.label_padding_value = label_padding_value
def __call__(self, im, im_info=None, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (list): 存储图像reisze或padding前的shape信息,如
[('resize', [200, 300]), ('padding', [400, 600])]表示
图像在过resize前shape为(200, 300), 过padding前shape为
(400, 600)
label (np.ndarray): 标注图像np.ndarray数据。
Returns:
tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
存储与图像相关信息的字典和标注图像np.ndarray数据。
"""
if self.rotate_range > 0:
(h, w) = im.shape[:2]
do_rotation = np.random.uniform(-self.rotate_range,
self.rotate_range)
pc = (w // 2, h // 2)
r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
cos = np.abs(r[0, 0])
sin = np.abs(r[0, 1])
nw = int((h * sin) + (w * cos))
nh = int((h * cos) + (w * sin))
(cx, cy) = pc
r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy
dsize = (nw, nh)
im = cv2.warpAffine(
im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value)
label = cv2.warpAffine(
label,
r,
dsize=dsize,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.label_padding_value)
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
class RandomScaleAspect(SegTransform): class RandomScaleAspect(SegTransform):
...@@ -1125,11 +1046,7 @@ class ComposedSegTransforms(Compose): ...@@ -1125,11 +1046,7 @@ class ComposedSegTransforms(Compose):
std=[0.5, 0.5, 0.5]): std=[0.5, 0.5, 0.5]):
if mode == 'train': if mode == 'train':
# 训练时的transforms,包含数据增强 # 训练时的transforms,包含数据增强
transforms = [ pass
RandomHorizontalFlip(prob=0.5), ResizeStepScaling(),
RandomPaddingCrop(crop_size=train_crop_size), Normalize(
mean=mean, std=std)
]
else: else:
# 验证/预测时的transforms # 验证/预测时的transforms
transforms = [Normalize(mean=mean, std=std)] transforms = [Normalize(mean=mean, std=std)]
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import sys
import colorama
from colorama import init
init(autoreset=True)
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
log_level = 2
def log(level=2, message="", use_color=False):
current_time = time.time()
time_array = time.localtime(current_time)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
if log_level >= level:
if use_color:
print("\033[1;31;40m{} [{}]\t{}\033[0m".format(current_time, levels[
level], message).encode("utf-8").decode("latin1"))
else:
print("{} [{}]\t{}".format(current_time, levels[level], message)
.encode("utf-8").decode("latin1"))
sys.stdout.flush()
def debug(message="", use_color=False):
log(level=3, message=message, use_color=use_color)
def info(message="", use_color=False):
log(level=2, message=message, use_color=use_color)
def warning(message="", use_color=True):
log(level=1, message=message, use_color=use_color)
def error(message="", use_color=True, exit=True):
log(level=0, message=message, use_color=use_color)
if exit:
sys.exit(-1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册