未验证 提交 9c2f0fe4 编写于 作者: J Jason 提交者: GitHub

Merge pull request #44 from PaddlePaddle/develop_imgaug

add imgaug support
...@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection): ...@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection):
gt_score = np.ones((num_bbox, 1), dtype=np.float32) gt_score = np.ones((num_bbox, 1), dtype=np.float32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32) difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox gt_poly = None
for i, box in enumerate(bboxes): for i, box in enumerate(bboxes):
catid = box['category_id'] catid = box['category_id']
...@@ -108,6 +108,8 @@ class CocoDetection(VOCDetection): ...@@ -108,6 +108,8 @@ class CocoDetection(VOCDetection):
gt_bbox[i, :] = box['clean_bbox'] gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd'] is_crowd[i][0] = box['iscrowd']
if 'segmentation' in box: if 'segmentation' in box:
if gt_poly is None:
gt_poly = [None] * num_bbox
gt_poly[i] = box['segmentation'] gt_poly[i] = box['segmentation']
im_info = { im_info = {
...@@ -119,9 +121,11 @@ class CocoDetection(VOCDetection): ...@@ -119,9 +121,11 @@ class CocoDetection(VOCDetection):
'gt_class': gt_class, 'gt_class': gt_class,
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_score': gt_score, 'gt_score': gt_score,
'gt_poly': gt_poly,
'difficult': difficult 'difficult': difficult
} }
if gt_poly is not None:
label_info['gt_poly'] = gt_poly
coco_rec = (im_info, label_info) coco_rec = (im_info, label_info)
self.file_list.append([im_fname, coco_rec]) self.file_list.append([im_fname, coco_rec])
......
...@@ -153,7 +153,6 @@ class VOCDetection(Dataset): ...@@ -153,7 +153,6 @@ class VOCDetection(Dataset):
'gt_class': gt_class, 'gt_class': gt_class,
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_score': gt_score, 'gt_score': gt_score,
'gt_poly': [],
'difficult': difficult 'difficult': difficult
} }
voc_rec = (im_info, label_info) voc_rec = (im_info, label_info)
......
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import cv2 import cv2
import colorsys import colorsys
import numpy as np import numpy as np
import time
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
from .detection_eval import fixed_linspace, backup_linspace, loadRes from .detection_eval import fixed_linspace, backup_linspace, loadRes
...@@ -25,8 +26,12 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'): ...@@ -25,8 +26,12 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
Visualize bbox and mask results Visualize bbox and mask results
""" """
image_name = os.path.split(image)[-1] if isinstance(image, np.ndarray):
image = cv2.imread(image) image_name = str(int(time.time())) + '.jpg'
else:
image = cv2.imread(image)
image_name = os.path.split(image)[-1]
image = draw_bbox_mask(image, result, threshold=threshold) image = draw_bbox_mask(image, result, threshold=threshold)
if save_dir is not None: if save_dir is not None:
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
...@@ -56,13 +61,18 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'): ...@@ -56,13 +61,18 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
c3 = cv2.LUT(label_map, color_map[:, 2]) c3 = cv2.LUT(label_map, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3)) pseudo_img = np.dstack((c1, c2, c3))
im = cv2.imread(image) if isinstance(image, np.ndarray):
im = image
image_name = str(int(time.time())) + '.jpg'
else:
image = cv2.imread(image)
image_name = os.path.split(image)[-1]
vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0) vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
if save_dir is not None: if save_dir is not None:
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
image_name = os.path.split(image)[-1]
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name)) out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
cv2.imwrite(out_path, vis_result) cv2.imwrite(out_path, vis_result)
logging.info('The visualized result is saved as {}'.format(out_path)) logging.info('The visualized result is saved as {}'.format(out_path))
......
...@@ -13,13 +13,22 @@ ...@@ -13,13 +13,22 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
class Compose: class ClsTransform:
"""分类Transform的基类
"""
def __init__(self):
pass
class Compose(ClsTransform):
"""根据数据预处理/增强算子对输入数据进行操作。 """根据数据预处理/增强算子对输入数据进行操作。
所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
...@@ -39,6 +48,15 @@ class Compose: ...@@ -39,6 +48,15 @@ class Compose:
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for op in self.transforms:
if not isinstance(op, ClsTransform):
import imgaug.augmenters as iaa
if not isinstance(op, iaa.Augmenter):
raise Exception(
"Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
def __call__(self, im, label=None): def __call__(self, im, label=None):
""" """
Args: Args:
...@@ -48,20 +66,34 @@ class Compose: ...@@ -48,20 +66,34 @@ class Compose:
tuple: 根据网络所需字段所组成的tuple; tuple: 根据网络所需字段所组成的tuple;
字段由transforms中的最后一个数据预处理操作决定。 字段由transforms中的最后一个数据预处理操作决定。
""" """
try: if isinstance(im, np.ndarray):
im = cv2.imread(im).astype('float32') if len(im.shape) != 3:
except: raise Exception(
raise TypeError('Can\'t read The image file {}!'.format(im)) "im should be 3-dimension, but now is {}-dimensions".
format(len(im.shape)))
else:
try:
im = cv2.imread(im).astype('float32')
except:
raise TypeError('Can\'t read The image file {}!'.format(im))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
for op in self.transforms: for op in self.transforms:
outputs = op(im, label) if isinstance(op, ClsTransform):
im = outputs[0] outputs = op(im, label)
if len(outputs) == 2: im = outputs[0]
label = outputs[1] if len(outputs) == 2:
label = outputs[1]
else:
import imgaug.augmenters as iaa
if isinstance(op, iaa.Augmenter):
im, = execute_imgaug(op, im)
output = (im, )
if label is not None:
output = (im, label)
return outputs return outputs
class RandomCrop: class RandomCrop(ClsTransform):
"""对图像进行随机剪裁,模型训练时的数据增强操作。 """对图像进行随机剪裁,模型训练时的数据增强操作。
1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。 1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
...@@ -104,7 +136,7 @@ class RandomCrop: ...@@ -104,7 +136,7 @@ class RandomCrop:
return (im, label) return (im, label)
class RandomHorizontalFlip: class RandomHorizontalFlip(ClsTransform):
"""以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。 """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
Args: Args:
...@@ -132,7 +164,7 @@ class RandomHorizontalFlip: ...@@ -132,7 +164,7 @@ class RandomHorizontalFlip:
return (im, label) return (im, label)
class RandomVerticalFlip: class RandomVerticalFlip(ClsTransform):
"""以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。 """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
Args: Args:
...@@ -160,7 +192,7 @@ class RandomVerticalFlip: ...@@ -160,7 +192,7 @@ class RandomVerticalFlip:
return (im, label) return (im, label)
class Normalize: class Normalize(ClsTransform):
"""对图像进行标准化。 """对图像进行标准化。
1. 对图像进行归一化到区间[0.0, 1.0]。 1. 对图像进行归一化到区间[0.0, 1.0]。
...@@ -195,7 +227,7 @@ class Normalize: ...@@ -195,7 +227,7 @@ class Normalize:
return (im, label) return (im, label)
class ResizeByShort: class ResizeByShort(ClsTransform):
"""根据图像短边对图像重新调整大小(resize)。 """根据图像短边对图像重新调整大小(resize)。
1. 获取图像的长边和短边长度。 1. 获取图像的长边和短边长度。
...@@ -242,7 +274,7 @@ class ResizeByShort: ...@@ -242,7 +274,7 @@ class ResizeByShort:
return (im, label) return (im, label)
class CenterCrop: class CenterCrop(ClsTransform):
"""以图像中心点扩散裁剪长宽为`crop_size`的正方形 """以图像中心点扩散裁剪长宽为`crop_size`的正方形
1. 计算剪裁的起始点。 1. 计算剪裁的起始点。
...@@ -272,7 +304,7 @@ class CenterCrop: ...@@ -272,7 +304,7 @@ class CenterCrop:
return (im, label) return (im, label)
class RandomRotate: class RandomRotate(ClsTransform):
def __init__(self, rotate_range=30, prob=0.5): def __init__(self, rotate_range=30, prob=0.5):
"""以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。 """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
...@@ -306,7 +338,7 @@ class RandomRotate: ...@@ -306,7 +338,7 @@ class RandomRotate:
return (im, label) return (im, label)
class RandomDistort: class RandomDistort(ClsTransform):
"""以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。 """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
1. 对变换的操作顺序进行随机化操作。 1. 对变换的操作顺序进行随机化操作。
...@@ -397,7 +429,7 @@ class RandomDistort: ...@@ -397,7 +429,7 @@ class RandomDistort:
return (im, label) return (im, label)
class ArrangeClassifier: class ArrangeClassifier(ClsTransform):
"""获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用 """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
Args: Args:
......
...@@ -24,11 +24,20 @@ import numpy as np ...@@ -24,11 +24,20 @@ import numpy as np
import cv2 import cv2
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from .imgaug_support import execute_imgaug
from .ops import * from .ops import *
from .box_utils import * from .box_utils import *
class Compose: class DetTransform:
"""检测数据处理基类
"""
def __init__(self):
pass
class Compose(DetTransform):
"""根据数据预处理/增强列表对输入数据进行操作。 """根据数据预处理/增强列表对输入数据进行操作。
所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
...@@ -49,8 +58,16 @@ class Compose: ...@@ -49,8 +58,16 @@ class Compose:
self.transforms = transforms self.transforms = transforms
self.use_mixup = False self.use_mixup = False
for t in self.transforms: for t in self.transforms:
if t.__class__.__name__ == 'MixupImage': if type(t).__name__ == 'MixupImage':
self.use_mixup = True self.use_mixup = True
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for op in self.transforms:
if not isinstance(op, DetTransform):
import imgaug.augmenters as iaa
if not isinstance(op, iaa.Augmenter):
raise Exception(
"Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
def __call__(self, im, im_info=None, label_info=None): def __call__(self, im, im_info=None, label_info=None):
""" """
...@@ -84,11 +101,18 @@ class Compose: ...@@ -84,11 +101,18 @@ class Compose:
def decode_image(im_file, im_info, label_info): def decode_image(im_file, im_info, label_info):
if im_info is None: if im_info is None:
im_info = dict() im_info = dict()
try: if isinstance(im_file, np.ndarray):
im = cv2.imread(im_file).astype('float32') if len(im_file.shape) != 3:
except: raise Exception(
raise TypeError( "im should be 3-dimensions, but now is {}-dimensions".
'Can\'t read The image file {}!'.format(im_file)) format(len(im_file.shape)))
im = im_file
else:
try:
im = cv2.imread(im_file).astype('float32')
except:
raise TypeError(
'Can\'t read The image file {}!'.format(im_file))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# make default im_info with [h, w, 1] # make default im_info with [h, w, 1]
im_info['im_resize_info'] = np.array( im_info['im_resize_info'] = np.array(
...@@ -117,12 +141,28 @@ class Compose: ...@@ -117,12 +141,28 @@ class Compose:
for op in self.transforms: for op in self.transforms:
if im is None: if im is None:
return None return None
outputs = op(im, im_info, label_info) if isinstance(op, DetTransform):
im = outputs[0] outputs = op(im, im_info, label_info)
im = outputs[0]
else:
if label_info is not None:
gt_poly = label_info.get('gt_poly', None)
gt_bbox = label_info['gt_bbox']
if gt_poly is None:
im, aug_bbox = execute_imgaug(op, im, bboxes=gt_bbox)
else:
im, aug_bbox, aug_poly = execute_imgaug(
op, im, bboxes=gt_bbox, polygons=gt_poly)
label_info['gt_poly'] = aug_poly
label_info['gt_bbox'] = aug_bbox
outputs = (im, im_info, label_info)
else:
im, = execute_imgaug(op, im)
outputs = (im, im_info)
return outputs return outputs
class ResizeByShort: class ResizeByShort(DetTransform):
"""根据图像的短边调整图像大小(resize)。 """根据图像的短边调整图像大小(resize)。
1. 获取图像的长边和短边长度。 1. 获取图像的长边和短边长度。
...@@ -194,7 +234,7 @@ class ResizeByShort: ...@@ -194,7 +234,7 @@ class ResizeByShort:
return (im, im_info, label_info) return (im, im_info, label_info)
class Padding: class Padding(DetTransform):
"""1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640], """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
`coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值 `coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
进行padding,最终输出图像为[320, 640]。 进行padding,最终输出图像为[320, 640]。
...@@ -290,7 +330,7 @@ class Padding: ...@@ -290,7 +330,7 @@ class Padding:
return (padding_im, im_info, label_info) return (padding_im, im_info, label_info)
class Resize: class Resize(DetTransform):
"""调整图像大小(resize)。 """调整图像大小(resize)。
- 当目标大小(target_size)类型为int时,根据插值方式, - 当目标大小(target_size)类型为int时,根据插值方式,
...@@ -369,7 +409,7 @@ class Resize: ...@@ -369,7 +409,7 @@ class Resize:
return (im, im_info, label_info) return (im, im_info, label_info)
class RandomHorizontalFlip: class RandomHorizontalFlip(DetTransform):
"""随机翻转图像、标注框、分割信息,模型训练时的数据增强操作。 """随机翻转图像、标注框、分割信息,模型训练时的数据增强操作。
1. 随机采样一个0-1之间的小数,当小数小于水平翻转概率时, 1. 随机采样一个0-1之间的小数,当小数小于水平翻转概率时,
...@@ -447,7 +487,7 @@ class RandomHorizontalFlip: ...@@ -447,7 +487,7 @@ class RandomHorizontalFlip:
return (im, im_info, label_info) return (im, im_info, label_info)
class Normalize: class Normalize(DetTransform):
"""对图像进行标准化。 """对图像进行标准化。
1. 归一化图像到到区间[0.0, 1.0]。 1. 归一化图像到到区间[0.0, 1.0]。
...@@ -491,7 +531,7 @@ class Normalize: ...@@ -491,7 +531,7 @@ class Normalize:
return (im, im_info, label_info) return (im, im_info, label_info)
class RandomDistort: class RandomDistort(DetTransform):
"""以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作 """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
1. 对变换的操作顺序进行随机化操作。 1. 对变换的操作顺序进行随机化操作。
...@@ -585,7 +625,7 @@ class RandomDistort: ...@@ -585,7 +625,7 @@ class RandomDistort:
return (im, im_info, label_info) return (im, im_info, label_info)
class MixupImage: class MixupImage(DetTransform):
"""对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform。 """对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform。
当label_info中不存在mixup字段时,直接返回,否则进行下述操作: 当label_info中不存在mixup字段时,直接返回,否则进行下述操作:
...@@ -714,7 +754,7 @@ class MixupImage: ...@@ -714,7 +754,7 @@ class MixupImage:
return (im, im_info, label_info) return (im, im_info, label_info)
class RandomExpand: class RandomExpand(DetTransform):
"""随机扩张图像,模型训练时的数据增强操作。 """随机扩张图像,模型训练时的数据增强操作。
1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。 1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
2. 计算扩张后图像大小。 2. 计算扩张后图像大小。
...@@ -796,7 +836,7 @@ class RandomExpand: ...@@ -796,7 +836,7 @@ class RandomExpand:
return (canvas, im_info, label_info) return (canvas, im_info, label_info)
class RandomCrop: class RandomCrop(DetTransform):
"""随机裁剪图像。 """随机裁剪图像。
1. 若allow_no_crop为True,则在thresholds加入’no_crop’。 1. 若allow_no_crop为True,则在thresholds加入’no_crop’。
2. 随机打乱thresholds。 2. 随机打乱thresholds。
...@@ -944,7 +984,7 @@ class RandomCrop: ...@@ -944,7 +984,7 @@ class RandomCrop:
return (im, im_info, label_info) return (im, im_info, label_info)
class ArrangeFasterRCNN: class ArrangeFasterRCNN(DetTransform):
"""获取FasterRCNN模型训练/验证/预测所需信息。 """获取FasterRCNN模型训练/验证/预测所需信息。
Args: Args:
...@@ -1019,7 +1059,7 @@ class ArrangeFasterRCNN: ...@@ -1019,7 +1059,7 @@ class ArrangeFasterRCNN:
return outputs return outputs
class ArrangeMaskRCNN: class ArrangeMaskRCNN(DetTransform):
"""获取MaskRCNN模型训练/验证/预测所需信息。 """获取MaskRCNN模型训练/验证/预测所需信息。
Args: Args:
...@@ -1103,7 +1143,7 @@ class ArrangeMaskRCNN: ...@@ -1103,7 +1143,7 @@ class ArrangeMaskRCNN:
return outputs return outputs
class ArrangeYOLOv3: class ArrangeYOLOv3(DetTransform):
"""获取YOLOv3模型训练/验证/预测所需信息。 """获取YOLOv3模型训练/验证/预测所需信息。
Args: Args:
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def execute_imgaug(augmenter, im, bboxes=None, polygons=None,
segment_map=None):
# 预处理,将bboxes, polygons转换成imgaug格式
import imgaug.augmentables.polys as polys
import imgaug.augmentables.bbs as bbs
aug_im = im.astype('uint8')
aug_bboxes = None
if bboxes is not None:
aug_bboxes = list()
for i in range(len(bboxes)):
x1 = bboxes[i, 0] - 1
y1 = bboxes[i, 1]
x2 = bboxes[i, 2]
y2 = bboxes[i, 3]
aug_bboxes.append(bbs.BoundingBox(x1, y1, x2, y2))
aug_polygons = None
lod_info = list()
if polygons is not None:
aug_polygons = list()
for i in range(len(polygons)):
num = len(polygons[i])
lod_info.append(num)
for j in range(num):
points = np.reshape(polygons[i][j], (-1, 2))
aug_polygons.append(polys.Polygon(points))
aug_segment_map = None
if segment_map is not None:
if len(segment_map.shape) == 2:
h, w = segment_map.shape
aug_segment_map = np.reshape(segment_map, (1, h, w, 1))
elif len(segment_map.shape) == 3:
h, w, c = segment_map.shape
aug_segment_map = np.reshape(segment_map, (1, h, w, c))
else:
raise Exception(
"Only support 2-dimensions for 3-dimensions for segment_map")
aug_im, aug_bboxes, aug_polygons, aug_seg_map = augmenter.augment(
image=aug_im,
bounding_boxes=aug_bboxes,
polygons=aug_polygons,
segmentation_maps=aug_segment_map)
aug_im = aug_im.astype('float32')
if aug_polygons is not None:
assert len(aug_bboxes) == len(
lod_info
), "Number of aug_bboxes should be equal to number of aug_polygons"
if aug_bboxes is not None:
# 裁剪掉在图像之外的bbox和polygon
for i in range(len(aug_bboxes)):
aug_bboxes[i] = aug_bboxes[i].clip_out_of_image(aug_im)
if aug_polygons is not None:
for i in range(len(aug_polygons)):
aug_polygons[i] = aug_polygons[i].clip_out_of_image(aug_im)
# 过滤掉无效的bbox和polygon,并转换为训练数据格式
converted_bboxes = list()
converted_polygons = list()
poly_index = 0
for i in range(len(aug_bboxes)):
# 过滤width或height不足1像素的框
if aug_bboxes[i].width < 1 or aug_bboxes[i].height < 1:
continue
if aug_polygons is None:
converted_bboxes.append([
aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
aug_bboxes[i].y2
])
continue
# 如若有polygons,将会继续执行下面代码
polygons_this_box = list()
for ps in aug_polygons[poly_index:poly_index + lod_info[i]]:
if len(ps) == 0:
continue
for p in ps:
# 没有3个point的polygon被过滤
if len(p.exterior) < 3:
continue
polygons_this_box.append(p.exterior.flatten().tolist())
poly_index += lod_info[i]
if len(polygons_this_box) == 0:
continue
converted_bboxes.append([
aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
aug_bboxes[i].y2
])
converted_polygons.append(polygons_this_box)
if len(converted_bboxes) == 0:
aug_im = im
converted_bboxes = bboxes
converted_polygons = polygons
result = [aug_im]
if bboxes is not None:
result.append(np.array(converted_bboxes))
if polygons is not None:
result.append(converted_polygons)
if segment_map is not None:
n, h, w, c = aug_seg_map.shape
if len(segment_map.shape) == 2:
aug_seg_map = np.reshape(aug_seg_map, (h, w))
elif len(segment_map.shape) == 3:
aug_seg_map = np.reshape(aug_seg_map, (h, w, c))
result.append(aug_seg_map)
return result
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -22,7 +23,15 @@ import cv2 ...@@ -22,7 +23,15 @@ import cv2
from collections import OrderedDict from collections import OrderedDict
class Compose: class SegTransform:
""" 分割transform基类
"""
def __init__(self):
pass
class Compose(SegTransform):
"""根据数据预处理/增强算子对输入数据进行操作。 """根据数据预处理/增强算子对输入数据进行操作。
所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
...@@ -43,6 +52,14 @@ class Compose: ...@@ -43,6 +52,14 @@ class Compose:
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
self.to_rgb = False self.to_rgb = False
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for op in self.transforms:
if not isinstance(op, SegTransform):
import imgaug.augmenters as iaa
if not isinstance(op, iaa.Augmenter):
raise Exception(
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
def __call__(self, im, im_info=None, label=None): def __call__(self, im, im_info=None, label=None):
""" """
...@@ -60,26 +77,40 @@ class Compose: ...@@ -60,26 +77,40 @@ class Compose:
if im_info is None: if im_info is None:
im_info = list() im_info = list()
try: if isinstance(im, np.ndarray):
im = cv2.imread(im).astype('float32') if len(im.shape) != 3:
except: raise Exception(
raise ValueError('Can\'t read The image file {}!'.format(im)) "im should be 3-dimensions, but now is {}-dimensions".
format(len(im.shape)))
else:
try:
im = cv2.imread(im).astype('float32')
except:
raise ValueError('Can\'t read The image file {}!'.format(im))
if self.to_rgb: if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if label is not None: if label is not None:
if not isinstance(label, np.ndarray): if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label)) label = np.asarray(Image.open(label))
for op in self.transforms: for op in self.transforms:
outputs = op(im, im_info, label) if isinstance(op, SegTransform):
im = outputs[0] outputs = op(im, im_info, label)
if len(outputs) >= 2: im = outputs[0]
im_info = outputs[1] if len(outputs) >= 2:
if len(outputs) == 3: im_info = outputs[1]
label = outputs[2] if len(outputs) == 3:
label = outputs[2]
else:
if label is not None:
im, label = execute_imgaug(op, im, segment_map=label)
outputs = (im, im_info, label)
else:
im, = execute_imgaug(op, im)
outputs = (im, im_info)
return outputs return outputs
class RandomHorizontalFlip: class RandomHorizontalFlip(SegTransform):
"""以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。 """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
Args: Args:
...@@ -115,7 +146,7 @@ class RandomHorizontalFlip: ...@@ -115,7 +146,7 @@ class RandomHorizontalFlip:
return (im, im_info, label) return (im, im_info, label)
class RandomVerticalFlip: class RandomVerticalFlip(SegTransform):
"""以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。 """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
Args: Args:
...@@ -150,7 +181,7 @@ class RandomVerticalFlip: ...@@ -150,7 +181,7 @@ class RandomVerticalFlip:
return (im, im_info, label) return (im, im_info, label)
class Resize: class Resize(SegTransform):
"""调整图像大小(resize),当存在标注图像时,则同步进行处理。 """调整图像大小(resize),当存在标注图像时,则同步进行处理。
- 当目标大小(target_size)类型为int时,根据插值方式, - 当目标大小(target_size)类型为int时,根据插值方式,
...@@ -260,7 +291,7 @@ class Resize: ...@@ -260,7 +291,7 @@ class Resize:
return (im, im_info, label) return (im, im_info, label)
class ResizeByLong: class ResizeByLong(SegTransform):
"""对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
Args: Args:
...@@ -301,7 +332,7 @@ class ResizeByLong: ...@@ -301,7 +332,7 @@ class ResizeByLong:
return (im, im_info, label) return (im, im_info, label)
class ResizeByShort: class ResizeByShort(SegTransform):
"""根据图像的短边调整图像大小(resize)。 """根据图像的短边调整图像大小(resize)。
1. 获取图像的长边和短边长度。 1. 获取图像的长边和短边长度。
...@@ -378,7 +409,7 @@ class ResizeByShort: ...@@ -378,7 +409,7 @@ class ResizeByShort:
return (im, im_info, label) return (im, im_info, label)
class ResizeRangeScaling: class ResizeRangeScaling(SegTransform):
"""对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
Args: Args:
...@@ -427,7 +458,7 @@ class ResizeRangeScaling: ...@@ -427,7 +458,7 @@ class ResizeRangeScaling:
return (im, im_info, label) return (im, im_info, label)
class ResizeStepScaling: class ResizeStepScaling(SegTransform):
"""对图像按照某一个比例resize,这个比例以scale_step_size为步长 """对图像按照某一个比例resize,这个比例以scale_step_size为步长
在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
...@@ -502,7 +533,7 @@ class ResizeStepScaling: ...@@ -502,7 +533,7 @@ class ResizeStepScaling:
return (im, im_info, label) return (im, im_info, label)
class Normalize: class Normalize(SegTransform):
"""对图像进行标准化。 """对图像进行标准化。
1.尺度缩放到 [0,1]。 1.尺度缩放到 [0,1]。
2.对图像进行减均值除以标准差操作。 2.对图像进行减均值除以标准差操作。
...@@ -550,7 +581,7 @@ class Normalize: ...@@ -550,7 +581,7 @@ class Normalize:
return (im, im_info, label) return (im, im_info, label)
class Padding: class Padding(SegTransform):
"""对图像或标注图像进行padding,padding方向为右和下。 """对图像或标注图像进行padding,padding方向为右和下。
根据提供的值对图像或标注图像进行padding操作。 根据提供的值对图像或标注图像进行padding操作。
...@@ -642,7 +673,7 @@ class Padding: ...@@ -642,7 +673,7 @@ class Padding:
return (im, im_info, label) return (im, im_info, label)
class RandomPaddingCrop: class RandomPaddingCrop(SegTransform):
"""对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。 """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
Args: Args:
...@@ -741,7 +772,7 @@ class RandomPaddingCrop: ...@@ -741,7 +772,7 @@ class RandomPaddingCrop:
return (im, im_info, label) return (im, im_info, label)
class RandomBlur: class RandomBlur(SegTransform):
"""以一定的概率对图像进行高斯模糊。 """以一定的概率对图像进行高斯模糊。
Args: Args:
...@@ -787,7 +818,7 @@ class RandomBlur: ...@@ -787,7 +818,7 @@ class RandomBlur:
return (im, im_info, label) return (im, im_info, label)
class RandomRotate: class RandomRotate(SegTransform):
"""对图像进行随机旋转, 模型训练时的数据增强操作。 """对图像进行随机旋转, 模型训练时的数据增强操作。
在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行, 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
并对旋转后的图像和标注图像进行相应的padding。 并对旋转后的图像和标注图像进行相应的padding。
...@@ -859,7 +890,7 @@ class RandomRotate: ...@@ -859,7 +890,7 @@ class RandomRotate:
return (im, im_info, label) return (im, im_info, label)
class RandomScaleAspect: class RandomScaleAspect(SegTransform):
"""裁剪并resize回原始尺寸的图像和标注图像。 """裁剪并resize回原始尺寸的图像和标注图像。
按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
...@@ -922,7 +953,7 @@ class RandomScaleAspect: ...@@ -922,7 +953,7 @@ class RandomScaleAspect:
return (im, im_info, label) return (im, im_info, label)
class RandomDistort: class RandomDistort(SegTransform):
"""对图像进行随机失真。 """对图像进行随机失真。
1. 对变换的操作顺序进行随机化操作。 1. 对变换的操作顺序进行随机化操作。
...@@ -1018,7 +1049,7 @@ class RandomDistort: ...@@ -1018,7 +1049,7 @@ class RandomDistort:
return (im, im_info, label) return (im, im_info, label)
class ArrangeSegmenter: class ArrangeSegmenter(SegTransform):
"""获取训练/验证/预测所需的信息。 """获取训练/验证/预测所需的信息。
Args: Args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册