From 1d2e74d9540d1e2fa7fbc25fc595b1cc8a4a122b Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Wed, 6 May 2020 20:40:33 +0800
Subject: [PATCH] randomcrop and randomexpand for mask

---
 docs/apis/transforms/det_transforms.md  |  64 ++--
 docs/deploy.md                          |   4 +-
 paddlex/cv/transforms/box_utils.py      | 475 +++++++-----------------
 paddlex/cv/transforms/det_transforms.py | 326 ++++++++--------
 4 files changed, 325 insertions(+), 544 deletions(-)

diff --git a/docs/apis/transforms/det_transforms.md b/docs/apis/transforms/det_transforms.md
index 7d059f9..9be5a7b 100644
--- a/docs/apis/transforms/det_transforms.md
+++ b/docs/apis/transforms/det_transforms.md
@@ -122,56 +122,42 @@ paddlex.det.transforms.MixupImage(alpha=1.5, beta=1.5, mixup_epoch=-1)
 ## RandomExpand class
 ```python
-paddlex.det.transforms.RandomExpand(max_ratio=4., prob=0.5, mean=[127.5, 127.5, 127.5])
+paddlex.det.transforms.RandomExpand(ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53])
 ```
-Randomly expands the image; a data augmentation operation used during model training; a data augmentation operation used during model training.
-1. Randomly pick an expansion ratio (the image is expanded only when the ratio is greater than 1).
-2. Compute the size of the expanded image.
-3. Create an image whose pixels are initialized to the dataset mean, and paste the original image onto it at a random position.
+Randomly expands the image; a data augmentation operation used during model training.
+1. Randomly pick an expansion ratio (the image is expanded only when the ratio is greater than 1).
+2. Compute the size of the expanded image.
+3. Create an image whose pixels are initialized to the given fill value, and paste the original image onto it at a random position.
 4. Convert the ground-truth boxes to their coordinates in the expanded image based on where the original image was pasted.
+5. Convert the ground-truth segmentation regions to their coordinates in the expanded image based on where the original image was pasted.

 ### Parameters
-* **max_ratio** (float): Maximum image expansion ratio. Defaults to 4.0.
+* **ratio** (float): Maximum image expansion ratio. Defaults to 4.0.
 * **prob** (float): Probability of performing random expansion. Defaults to 0.5.
-* **mean** (list): Per-channel mean of the image dataset (0-255). Defaults to [127.5, 127.5, 127.5].
+* **fill_value** (list): Initial fill value of the expanded image (0-255). Defaults to [123.675, 116.28, 103.53].

 ## RandomCrop class
 ```python
-paddlex.det.transforms.RandomCrop(batch_sampler=None, satisfy_all=False, avoid_no_bbox=True)
+paddlex.det.transforms.RandomCrop(aspect_ratio=[.5, 2.], thresholds=[.0, .1, .3, .5, .7, .9], scaling=[.3, 1.], num_attempts=50, allow_no_crop=True, cover_all_box=False)
 ```
 Randomly crops the image; a data augmentation operation used during model training.
-1. Compute candidate crop regions according to batch_sampler:
-   (1) Compute the height and width of a random crop from min scale, max scale, min aspect ratio and max aspect ratio.
-   (2) Given that height and width, randomly pick the crop's starting point.
-   (3) Filter the candidate crop regions:
-   * When satisfy_all is True, a candidate crop region is kept only if every ground-truth box overlaps it sufficiently.
-   * When satisfy_all is False, a candidate crop region is kept as soon as one ground-truth box overlaps it sufficiently.
-2. Iterate over all candidate crop regions:
-   (1) Drop a ground-truth box if it does not overlap the candidate crop region or its center does not lie inside that region.
-   (2) Compute each remaining ground-truth box's position relative to the candidate crop region, and select the corresponding classes and mixup scores.
-   (3) If avoid_no_bbox is False, return the current crop immediately; otherwise, only return a crop once a region containing at least one ground-truth box has been found.
+1. If allow_no_crop is True, append 'no_crop' to thresholds.
+2. Randomly shuffle thresholds.
+3. Iterate over the elements of thresholds:
+   (1) If the current thresh is 'no_crop', return the original image and annotations.
+   (2) Randomly draw values from aspect_ratio and scaling, and compute the height, width and starting point of a candidate crop region from them.
+   (3) Compute the IoU between every ground-truth box and the candidate crop region; if all of the IoUs are below thresh, continue with step 3.
+   (4) If cover_all_box is True and any ground-truth box's IoU is below thresh, continue with step 3.
+   (5) Keep the ground-truth boxes located inside the candidate crop region; if none remain, continue with step 3, otherwise go to step 4.
+4. Convert the valid ground-truth boxes to coordinates relative to the candidate crop region.
+5. Convert the valid segmentation regions to coordinates relative to the candidate crop region.

 ### Parameters
-* **batch_sampler** (list): Combinations of random-crop parameters; each combination holds 8 values, as follows:
-   - max sample (int): Upper bound on the number of crop regions accepted for this combination.
-   - max trial (int): Number of attempts made for this combination.
-   - min scale (float): Lower bound on the per-side shrink ratio of the crop area relative to the original area.
-   - max scale (float): Upper bound on the per-side shrink ratio of the crop area relative to the original area.
-   - min aspect ratio (float): Lower bound on the shorter side's scaling ratio after cropping.
-   - max aspect ratio (float): Upper bound on the shorter side's scaling ratio after cropping.
-   - min overlap (float): Lower bound on the overlap area between a ground-truth box and the cropped image.
-   - max overlap (float): Upper bound on the overlap area between a ground-truth box and the cropped image.
-
-   Defaults to None, in which case the following settings are used:
-
-   [[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
-    [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
-    [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
-    [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
-    [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
-    [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
-    [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
-* **satisfy_all** (bool): Whether every ground-truth box must satisfy the constraints for a candidate crop region to be kept. Defaults to False.
-* **avoid_no_bbox** (bool): Whether to discard crops that contain no ground-truth box. Defaults to True.
+* **aspect_ratio** (list): Value range of the crop's aspect ratio, as [min, max]. Defaults to [.5, 2.].
+* **thresholds** (list): List of IoU thresholds used to decide whether a candidate crop region is valid. Defaults to [.0, .1, .3, .5, .7, .9].
+* **scaling** (list): Value range of the crop area relative to the original area, as [min, max]. Defaults to [.3, 1.].
+* **num_attempts** (int): Number of attempts before giving up the search for a valid crop region. Defaults to 50.
+* **allow_no_crop** (bool): Whether returning the image uncropped is allowed. Defaults to True.
+* **cover_all_box** (bool): Whether all ground-truth boxes must lie inside the crop region. Defaults to False.
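Taken together, the two transforms above drop into a training pipeline through `Compose`. The following is a minimal sketch under the new signatures, separate from the patch itself; it assumes only the `paddlex.det.transforms` entry point documented above, and the surrounding pipeline is hypothetical:

```python
import paddlex

# Hypothetical train-time augmentation: RandomExpand pastes the image onto a
# larger canvas filled with fill_value; RandomCrop then samples an
# IoU-constrained crop. With this patch, both also remap segmentation masks.
train_transforms = paddlex.det.transforms.Compose([
    paddlex.det.transforms.RandomExpand(
        ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53]),
    paddlex.det.transforms.RandomCrop(
        aspect_ratio=[.5, 2.], scaling=[.3, 1.], num_attempts=50),
])
```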
diff --git a/docs/deploy.md b/docs/deploy.md
index 21bd9be..198f4c4 100644
--- a/docs/deploy.md
+++ b/docs/deploy.md
@@ -4,10 +4,10 @@

 A model to be deployed on the server must first be exported as an inference-format model. The export produces three files, `__model__`, `__params__` and `model.yml`: the model's network structure, the model weights, and the model's configuration file (which includes the data preprocessing parameters, among other things). After installing PaddleX, run the following command in a terminal to export the model into the directory `inference_model` under the current path.

-> To try out the workflow in this document you can directly download the garbage-detection model [garbage_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/garbage_epoch_12.tar.gz)
+> To try out the workflow in this document you can directly download the sample detection model [xiaoduxiong_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)

 ```
-paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model
+paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model
 ```

 ## C++ and Python deployment for models is expected within a week...
diff --git a/paddlex/cv/transforms/box_utils.py b/paddlex/cv/transforms/box_utils.py
index a52631c..02f3c4d 100644
--- a/paddlex/cv/transforms/box_utils.py
+++ b/paddlex/cv/transforms/box_utils.py
@@ -19,25 +19,6 @@ import cv2
 import scipy


-def meet_emit_constraint(src_bbox, sample_bbox):
-    center_x = (src_bbox[2] + src_bbox[0]) / 2
-    center_y = (src_bbox[3] + src_bbox[1]) / 2
-    if center_x >= sample_bbox[0] and \
-            center_x <= sample_bbox[2] and \
-            center_y >= sample_bbox[1] and \
-            center_y <= sample_bbox[3]:
-        return True
-    return False
-
-
-def clip_bbox(src_bbox):
-    src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0)
-    src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0)
-    src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0)
-    src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0)
-    return src_bbox
-
-
 def bbox_area(src_bbox):
     if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
         return 0.
@@ -47,189 +28,6 @@ def bbox_area(src_bbox): return width * height -def is_overlap(object_bbox, sample_bbox): - if object_bbox[0] >= sample_bbox[2] or \ - object_bbox[2] <= sample_bbox[0] or \ - object_bbox[1] >= sample_bbox[3] or \ - object_bbox[3] <= sample_bbox[1]: - return False - else: - return True - - -def filter_and_process(sample_bbox, bboxes, labels, scores=None): - new_bboxes = [] - new_labels = [] - new_scores = [] - for i in range(len(bboxes)): - new_bbox = [0, 0, 0, 0] - obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]] - if not meet_emit_constraint(obj_bbox, sample_bbox): - continue - if not is_overlap(obj_bbox, sample_bbox): - continue - sample_width = sample_bbox[2] - sample_bbox[0] - sample_height = sample_bbox[3] - sample_bbox[1] - new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width - new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height - new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width - new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height - new_bbox = clip_bbox(new_bbox) - if bbox_area(new_bbox) > 0: - new_bboxes.append(new_bbox) - new_labels.append([labels[i][0]]) - if scores is not None: - new_scores.append([scores[i][0]]) - bboxes = np.array(new_bboxes) - labels = np.array(new_labels) - scores = np.array(new_scores) - return bboxes, labels, scores - - -def bbox_area_sampling(bboxes, labels, scores, target_size, min_size): - new_bboxes = [] - new_labels = [] - new_scores = [] - for i, bbox in enumerate(bboxes): - w = float((bbox[2] - bbox[0]) * target_size) - h = float((bbox[3] - bbox[1]) * target_size) - if w * h < float(min_size * min_size): - continue - else: - new_bboxes.append(bbox) - new_labels.append(labels[i]) - if scores is not None and scores.size != 0: - new_scores.append(scores[i]) - bboxes = np.array(new_bboxes) - labels = np.array(new_labels) - scores = np.array(new_scores) - return bboxes, labels, scores - - -def generate_sample_bbox(sampler): - scale = np.random.uniform(sampler[2], sampler[3]) - aspect_ratio = np.random.uniform(sampler[4], sampler[5]) - aspect_ratio = max(aspect_ratio, (scale**2.0)) - aspect_ratio = min(aspect_ratio, 1 / (scale**2.0)) - bbox_width = scale * (aspect_ratio**0.5) - bbox_height = scale / (aspect_ratio**0.5) - xmin_bound = 1 - bbox_width - ymin_bound = 1 - bbox_height - xmin = np.random.uniform(0, xmin_bound) - ymin = np.random.uniform(0, ymin_bound) - xmax = xmin + bbox_width - ymax = ymin + bbox_height - sampled_bbox = [xmin, ymin, xmax, ymax] - return sampled_bbox - - -def generate_sample_bbox_square(sampler, image_width, image_height): - scale = np.random.uniform(sampler[2], sampler[3]) - aspect_ratio = np.random.uniform(sampler[4], sampler[5]) - aspect_ratio = max(aspect_ratio, (scale**2.0)) - aspect_ratio = min(aspect_ratio, 1 / (scale**2.0)) - bbox_width = scale * (aspect_ratio**0.5) - bbox_height = scale / (aspect_ratio**0.5) - if image_height < image_width: - bbox_width = bbox_height * image_height / image_width - else: - bbox_height = bbox_width * image_width / image_height - xmin_bound = 1 - bbox_width - ymin_bound = 1 - bbox_height - xmin = np.random.uniform(0, xmin_bound) - ymin = np.random.uniform(0, ymin_bound) - xmax = xmin + bbox_width - ymax = ymin + bbox_height - sampled_bbox = [xmin, ymin, xmax, ymax] - return sampled_bbox - - -def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array, - resize_width): - num_gt = len(bbox_labels) - # np.random.randint range: [low, high) - rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0 
- - if num_gt != 0: - norm_xmin = bbox_labels[rand_idx][0] - norm_ymin = bbox_labels[rand_idx][1] - norm_xmax = bbox_labels[rand_idx][2] - norm_ymax = bbox_labels[rand_idx][3] - - xmin = norm_xmin * image_width - ymin = norm_ymin * image_height - wid = image_width * (norm_xmax - norm_xmin) - hei = image_height * (norm_ymax - norm_ymin) - range_size = 0 - - area = wid * hei - for scale_ind in range(0, len(scale_array) - 1): - if area > scale_array[scale_ind] ** 2 and area < \ - scale_array[scale_ind + 1] ** 2: - range_size = scale_ind + 1 - break - - if area > scale_array[len(scale_array) - 2]**2: - range_size = len(scale_array) - 2 - - scale_choose = 0.0 - if range_size == 0: - rand_idx_size = 0 - else: - # np.random.randint range: [low, high) - rng_rand_size = np.random.randint(0, range_size + 1) - rand_idx_size = rng_rand_size % (range_size + 1) - - if rand_idx_size == range_size: - min_resize_val = scale_array[rand_idx_size] / 2.0 - max_resize_val = min(2.0 * scale_array[rand_idx_size], - 2 * math.sqrt(wid * hei)) - scale_choose = random.uniform(min_resize_val, max_resize_val) - else: - min_resize_val = scale_array[rand_idx_size] / 2.0 - max_resize_val = 2.0 * scale_array[rand_idx_size] - scale_choose = random.uniform(min_resize_val, max_resize_val) - - sample_bbox_size = wid * resize_width / scale_choose - - w_off_orig = 0.0 - h_off_orig = 0.0 - if sample_bbox_size < max(image_height, image_width): - if wid <= sample_bbox_size: - w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size, - xmin) - else: - w_off_orig = np.random.uniform(xmin, - xmin + wid - sample_bbox_size) - - if hei <= sample_bbox_size: - h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size, - ymin) - else: - h_off_orig = np.random.uniform(ymin, - ymin + hei - sample_bbox_size) - - else: - w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0) - h_off_orig = np.random.uniform(image_height - sample_bbox_size, - 0.0) - - w_off_orig = math.floor(w_off_orig) - h_off_orig = math.floor(h_off_orig) - - # Figure out top left coordinates. - w_off = float(w_off_orig / image_width) - h_off = float(h_off_orig / image_height) - - sampled_bbox = [ - w_off, h_off, w_off + float(sample_bbox_size / image_width), - h_off + float(sample_bbox_size / image_height) - ] - return sampled_bbox - else: - return 0 - - def jaccard_overlap(sample_bbox, object_bbox): if sample_bbox[0] >= object_bbox[2] or \ sample_bbox[2] <= object_bbox[0] or \ @@ -249,143 +47,143 @@ def jaccard_overlap(sample_bbox, object_bbox): return overlap -def intersect_bbox(bbox1, bbox2): - if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \ - bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]: - intersection_box = [0.0, 0.0, 0.0, 0.0] - else: - intersection_box = [ - max(bbox1[0], bbox2[0]), - max(bbox1[1], bbox2[1]), - min(bbox1[2], bbox2[2]), - min(bbox1[3], bbox2[3]) - ] - return intersection_box - - -def bbox_coverage(bbox1, bbox2): - inter_box = intersect_bbox(bbox1, bbox2) - intersect_size = bbox_area(inter_box) - - if intersect_size > 0: - bbox1_size = bbox_area(bbox1) - return intersect_size / bbox1_size - else: - return 0. 
+def iou_matrix(a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + area_o = (area_a[:, np.newaxis] + area_b - area_i) + return area_i / (area_o + 1e-10) -def satisfy_sample_constraint(sampler, - sample_bbox, - gt_bboxes, - satisfy_all=False): - if sampler[6] == 0 and sampler[7] == 0: - return True - satisfied = [] - for i in range(len(gt_bboxes)): - object_bbox = [ - gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3] - ] - overlap = jaccard_overlap(sample_bbox, object_bbox) - if sampler[6] != 0 and \ - overlap < sampler[6]: - satisfied.append(False) - continue - if sampler[7] != 0 and \ - overlap > sampler[7]: - satisfied.append(False) - continue - satisfied.append(True) - if not satisfy_all: - return True - - if satisfy_all: - return np.all(satisfied) - else: - return False +def crop_box_with_center_constraint(box, crop): + cropped_box = box.copy() + + cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2]) + cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:]) + cropped_box[:, :2] -= crop[:2] + cropped_box[:, 2:] -= crop[:2] + + centers = (box[:, :2] + box[:, 2:]) / 2 + valid = np.logical_and(crop[:2] <= centers, centers < crop[2:]).all(axis=1) + valid = np.logical_and( + valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1)) + + return cropped_box, np.where(valid)[0] + + +def is_poly(segm): + if not isinstance(segm, (list, dict)): + raise Exception("Invalid segm type: {}".format(type(segm))) + return isinstance(segm, list) -def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes): - if sampler[6] == 0 and sampler[7] == 0: - has_jaccard_overlap = False - else: - has_jaccard_overlap = True - if sampler[8] == 0 and sampler[9] == 0: - has_object_coverage = False - else: - has_object_coverage = True - - if not has_jaccard_overlap and not has_object_coverage: - return True - found = False - for i in range(len(gt_bboxes)): - object_bbox = [ - gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3] - ] - if has_jaccard_overlap: - overlap = jaccard_overlap(sample_bbox, object_bbox) - if sampler[6] != 0 and \ - overlap < sampler[6]: - continue - if sampler[7] != 0 and \ - overlap > sampler[7]: - continue - found = True - if has_object_coverage: - object_coverage = bbox_coverage(object_bbox, sample_bbox) - if sampler[8] != 0 and \ - object_coverage < sampler[8]: - continue - if sampler[9] != 0 and \ - object_coverage > sampler[9]: - continue - found = True - if found: - return True - return found - - -def crop_image_sampling(img, sample_bbox, image_width, image_height, - target_size): - # no clipping here - xmin = int(sample_bbox[0] * image_width) - xmax = int(sample_bbox[2] * image_width) - ymin = int(sample_bbox[1] * image_height) - ymax = int(sample_bbox[3] * image_height) - - w_off = xmin - h_off = ymin - width = xmax - xmin - height = ymax - ymin - cross_xmin = max(0.0, float(w_off)) - cross_ymin = max(0.0, float(h_off)) - cross_xmax = min(float(w_off + width - 1.0), float(image_width)) - cross_ymax = min(float(h_off + height - 1.0), float(image_height)) - cross_width = cross_xmax - cross_xmin - cross_height = cross_ymax - cross_ymin - - roi_xmin = 0 if w_off >= 0 else abs(w_off) - roi_ymin = 0 if h_off >= 0 else abs(h_off) - roi_width = cross_width - roi_height = cross_height - - roi_y1 = int(roi_ymin) - roi_y2 = 
int(roi_ymin + roi_height) - roi_x1 = int(roi_xmin) - roi_x2 = int(roi_xmin + roi_width) - - cross_y1 = int(cross_ymin) - cross_y2 = int(cross_ymin + cross_height) - cross_x1 = int(cross_xmin) - cross_x2 = int(cross_xmin + cross_width) - - sample_img = np.zeros((height, width, 3)) - sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \ - img[cross_y1: cross_y2, cross_x1: cross_x2] - - sample_img = cv2.resize( - sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA) - - return sample_img + +def crop_image(img, crop): + x1, y1, x2, y2 = crop + return img[y1:y2, x1:x2, :] + + +def crop_segms(segms, valid_ids, crop, height, width): + def _crop_poly(segm, crop): + xmin, ymin, xmax, ymax = crop + crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin] + crop_p = np.array(crop_coord).reshape(4, 2) + crop_p = Polygon(crop_p) + + crop_segm = list() + for poly in segm: + poly = np.array(poly).reshape(len(poly) // 2, 2) + polygon = Polygon(poly) + if not polygon.is_valid: + exterior = polygon.exterior + multi_lines = exterior.intersection(exterior) + polygons = shapely.ops.polygonize(multi_lines) + polygon = MultiPolygon(polygons) + multi_polygon = list() + if isinstance(polygon, MultiPolygon): + multi_polygon = copy.deepcopy(polygon) + else: + multi_polygon.append(copy.deepcopy(polygon)) + for per_polygon in multi_polygon: + inter = per_polygon.intersection(crop_p) + if not inter: + continue + if isinstance(inter, (MultiPolygon, GeometryCollection)): + for part in inter: + if not isinstance(part, Polygon): + continue + part = np.squeeze( + np.array(part.exterior.coords[:-1]).reshape(1, -1)) + part[0::2] -= xmin + part[1::2] -= ymin + crop_segm.append(part.tolist()) + elif isinstance(inter, Polygon): + crop_poly = np.squeeze( + np.array(inter.exterior.coords[:-1]).reshape(1, -1)) + crop_poly[0::2] -= xmin + crop_poly[1::2] -= ymin + crop_segm.append(crop_poly.tolist()) + else: + continue + return crop_segm + + def _crop_rle(rle, crop, height, width): + if 'counts' in rle and type(rle['counts']) == list: + rle = mask_util.frPyObjects(rle, height, width) + mask = mask_util.decode(rle) + mask = mask[crop[1]:crop[3], crop[0]:crop[2]] + rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) + return rle + + crop_segms = [] + for id in valid_ids: + segm = segms[id] + if is_poly(segm): + import copy + import shapely.ops + import logging + from shapely.geometry import Polygon, MultiPolygon, GeometryCollection + logging.getLogger("shapely").setLevel(logging.WARNING) + # Polygon format + crop_segms.append(_crop_poly(segm, crop)) + else: + # RLE format + import pycocotools.mask as mask_util + crop_segms.append(_crop_rle(segm, crop, height, width)) + return crop_segms + + +def expand_segms(segms, x, y, height, width, ratio): + def _expand_poly(poly, x, y): + expanded_poly = np.array(poly) + expanded_poly[0::2] += x + expanded_poly[1::2] += y + return expanded_poly.tolist() + + def _expand_rle(rle, x, y, height, width, ratio): + if 'counts' in rle and type(rle['counts']) == list: + rle = mask_util.frPyObjects(rle, height, width) + mask = mask_util.decode(rle) + expanded_mask = np.full((int(height * ratio), int(width * ratio)), + 0).astype(mask.dtype) + expanded_mask[y:y + height, x:x + width] = mask + rle = mask_util.encode( + np.array(expanded_mask, order='F', dtype=np.uint8)) + return rle + + expanded_segms = [] + for segm in segms: + if is_poly(segm): + # Polygon format + expanded_segms.append([_expand_poly(poly, x, y) for poly in segm]) + else: + # RLE format + import pycocotools.mask 
as mask_util
+            expanded_segms.append(
+                _expand_rle(segm, x, y, height, width, ratio))
+    return expanded_segms
 
 
 def box_horizontal_flip(bboxes, width):
@@ -409,15 +207,10 @@ def segms_horizontal_flip(segms, height, width):
         if 'counts' in rle and type(rle['counts']) == list:
             rle = mask_util.frPyObjects([rle], height, width)
         mask = mask_util.decode(rle)
-        mask = mask[:, ::-1, :]
+        mask = mask[:, ::-1]
         rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
         return rle
 
-    def is_poly(segm):
-        if not isinstance(segm, (list, dict)):
-            raise Exception("Invalid segm type: {}".format(type(segm)))
-        return isinstance(segm, list)
-
     flipped_segms = []
     for segm in segms:
         if is_poly(segm):
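The behaviour of the two new box helpers can be sanity-checked on toy boxes. A quick sketch, separate from the patch itself, assuming the patched module above is importable as `paddlex.cv.transforms.box_utils`:

```python
import numpy as np
from paddlex.cv.transforms.box_utils import (iou_matrix,
                                             crop_box_with_center_constraint)

# Two ground-truth boxes in [x1, y1, x2, y2] form and one candidate crop.
gt = np.array([[10., 10., 50., 50.], [60., 60., 90., 90.]], dtype=np.float32)
crop = np.array([0., 0., 55., 55.], dtype=np.float32)

# IoU of each box against the crop: ~0.53 for the first box, 0 for the second.
print(iou_matrix(gt, crop[np.newaxis, :]))

# Only the first box keeps its center inside the crop, so valid == [0]; the
# returned boxes are clipped to, and re-expressed relative to, the crop.
boxes, valid = crop_box_with_center_constraint(gt, crop)
print(boxes[valid])  # [[10. 10. 50. 50.]]
```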
diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py
index 5d58007..47d5929 100644
--- a/paddlex/cv/transforms/det_transforms.py
+++ b/paddlex/cv/transforms/det_transforms.py
@@ -12,13 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .ops import *
-from .box_utils import *
+try:
+    from collections.abc import Sequence
+except Exception:
+    from collections import Sequence
+
+from numbers import Number
+
 import random
 import os.path as osp
 import numpy as np
-from PIL import Image, ImageEnhance
+
 import cv2
+from PIL import Image, ImageEnhance
+
+from .ops import *
+from .box_utils import *
 
 
 class Compose:
@@ -81,7 +90,7 @@ class Compose:
                 im = cv2.imread(im_file).astype('float32')
             except:
                 raise TypeError(
-                    'Can\'t read The image file {}!'.format(im_file))
+                    'Can\'t read the image file {}!'.format(im_file))
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             # make default im_info with [h, w, 1]
             im_info['im_resize_info'] = np.array(
@@ -658,9 +667,17 @@ class MixupImage:
             gt_score2 = im_info['mixup'][2]['gt_score']
             gt_score = np.concatenate(
                 (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
+            if 'gt_poly' in label_info:
+                gt_poly1 = label_info['gt_poly']
+                gt_poly2 = im_info['mixup'][2]['gt_poly']
+                label_info['gt_poly'] = gt_poly1 + gt_poly2
+            is_crowd1 = label_info['is_crowd']
+            is_crowd2 = im_info['mixup'][2]['is_crowd']
+            is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
             label_info['gt_bbox'] = gt_bbox
             label_info['gt_score'] = gt_score
             label_info['gt_class'] = gt_class
+            label_info['is_crowd'] = is_crowd
             im_info['augment_shape'] = np.array([im.shape[0],
                                                  im.shape[1]]).astype('int32')
             im_info.pop('mixup')
@@ -672,23 +689,32 @@ class MixupImage:
 
 class RandomExpand:
     """Randomly expands the image; a data augmentation operation used during model training.
-
     1. Randomly pick an expansion ratio (the image is expanded only when the ratio is greater than 1).
     2. Compute the size of the expanded image.
-    3. Create an image whose pixels are initialized to the dataset mean, and paste the original image onto it at a random position.
+    3. Create an image whose pixels are initialized to the given fill value, and paste the original image onto it at a random position.
     4. Convert the ground-truth boxes to their coordinates in the expanded image based on where the original image was pasted.
-
+    5. Convert the ground-truth segmentation regions to their coordinates in the expanded image based on where the original image was pasted.
     Args:
-        max_ratio (float): Maximum image expansion ratio. Defaults to 4.0.
+        ratio (float): Maximum image expansion ratio. Defaults to 4.0.
         prob (float): Probability of performing random expansion. Defaults to 0.5.
-        mean (list): Per-channel mean of the image dataset (0-255). Defaults to [127.5, 127.5, 127.5].
-
+        fill_value (list): Initial fill value of the expanded image (0-255). Defaults to [123.675, 116.28, 103.53].
     """
 
-    def __init__(self, max_ratio=4., prob=0.5, mean=[127.5, 127.5, 127.5]):
-        self.max_ratio = max_ratio
-        self.mean = mean
+    def __init__(self,
+                 ratio=4.,
+                 prob=0.5,
+                 fill_value=[123.675, 116.28, 103.53]):
+        super(RandomExpand, self).__init__()
+        assert ratio > 1.01, "expand ratio must be larger than 1.01"
+        self.ratio = ratio
         self.prob = prob
+        assert isinstance(fill_value, (Number, Sequence)), \
+            "fill value must be either float or sequence"
+        if isinstance(fill_value, Number):
+            fill_value = (fill_value, ) * 3
+        if not isinstance(fill_value, tuple):
+            fill_value = tuple(fill_value)
+        self.fill_value = fill_value
 
     def __call__(self, im, im_info=None, label_info=None):
         """
@@ -696,7 +722,6 @@ class RandomExpand:
             im (np.ndarray): Image as an np.ndarray.
             im_info (dict, optional): Stores information related to the image.
             label_info (dict, optional): Stores information related to the box annotations.
-
         Returns:
             tuple: When label_info is None, the returned tuple is (im, im_info), i.e. the image np.ndarray and the dict of image-related information;
                    when label_info is not None, the returned tuple is (im, im_info, label_info), i.e. the image np.ndarray,
@@ -708,7 +733,6 @@ class RandomExpand:
                    where n is the number of ground-truth boxes.
               - gt_class (np.ndarray): Class index of each ground-truth box after random expansion, with shape (n, 1),
                    where n is the number of ground-truth boxes.
-
         Raises:
             TypeError: The type of a parameter does not meet the requirement.
         """
@@ -723,108 +747,68 @@ class RandomExpand:
                 'gt_class' not in label_info:
             raise TypeError('Cannot do RandomExpand! ' + \
                             'Because gt_bbox/gt_class is not in label_info!')
-        prob = np.random.uniform(0, 1)
+        if np.random.uniform(0., 1.) < self.prob:
+            return (im, im_info, label_info)
+
         augment_shape = im_info['augment_shape']
-        im_width = augment_shape[1]
-        im_height = augment_shape[0]
-        gt_bbox = label_info['gt_bbox']
-        gt_class = label_info['gt_class']
-
-        if prob < self.prob:
-            if self.max_ratio - 1 >= 0.01:
-                expand_ratio = np.random.uniform(1, self.max_ratio)
-                height = int(im_height * expand_ratio)
-                width = int(im_width * expand_ratio)
-                h_off = math.floor(np.random.uniform(0, height - im_height))
-                w_off = math.floor(np.random.uniform(0, width - im_width))
-                expand_bbox = [
-                    -w_off / im_width, -h_off / im_height,
-                    (width - w_off) / im_width, (height - h_off) / im_height
-                ]
-                expand_im = np.ones((height, width, 3))
-                expand_im = np.uint8(expand_im * np.squeeze(self.mean))
-                expand_im = Image.fromarray(expand_im)
-                im = im.astype('uint8')
-                im = Image.fromarray(im)
-                expand_im.paste(im, (int(w_off), int(h_off)))
-                expand_im = np.asarray(expand_im)
-                for i in range(gt_bbox.shape[0]):
-                    gt_bbox[i][0] = gt_bbox[i][0] / im_width
-                    gt_bbox[i][1] = gt_bbox[i][1] / im_height
-                    gt_bbox[i][2] = gt_bbox[i][2] / im_width
-                    gt_bbox[i][3] = gt_bbox[i][3] / im_height
-                gt_bbox, gt_class, _ = filter_and_process(
-                    expand_bbox, gt_bbox, gt_class)
-                for i in range(gt_bbox.shape[0]):
-                    gt_bbox[i][0] = gt_bbox[i][0] * width
-                    gt_bbox[i][1] = gt_bbox[i][1] * height
-                    gt_bbox[i][2] = gt_bbox[i][2] * width
-                    gt_bbox[i][3] = gt_bbox[i][3] * height
-                im = expand_im.astype('float32')
-                label_info['gt_bbox'] = gt_bbox
-                label_info['gt_class'] = gt_class
-                im_info['augment_shape'] = np.array([height,
-                                                     width]).astype('int32')
-        if label_info is None:
-            return (im, im_info)
-        else:
-            return (im, im_info, label_info)
+        height = int(augment_shape[0])
+        width = int(augment_shape[1])
+
+        expand_ratio = np.random.uniform(1., self.ratio)
+        h = int(height * expand_ratio)
+        w = int(width * expand_ratio)
+        if not h > height or not w > width:
+            return (im, im_info, label_info)
+        y = np.random.randint(0, h - height)
+        x = np.random.randint(0, w - width)
+        canvas = np.ones((h, w, 3), dtype=np.uint8)
+        canvas *= np.array(self.fill_value, dtype=np.uint8)
+        canvas[y:y + height, x:x + width, :] = im.astype(np.uint8)
+
+        im_info['augment_shape'] = np.array([h, w]).astype('int32')
+        if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
+            label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
+        if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
+            label_info['gt_poly'] = expand_segms(label_info['gt_poly'], x, y,
+                                                 height, width, expand_ratio)
+        return (canvas, im_info, label_info)
 
 
 class RandomCrop:
     """Randomly crops the image.
-
-    1. Compute candidate crop regions according to batch_sampler:
-        (1) Compute the height and width of a random crop from min scale, max scale, min aspect ratio and max aspect ratio.
-        (2) Given that height and width, randomly pick the crop's starting point.
-        (3) Filter the candidate crop regions:
-        - When satisfy_all is True, a candidate crop region is kept only if every ground-truth box overlaps it sufficiently.
-        - When satisfy_all is False, a candidate crop region is kept as soon as one ground-truth box overlaps it sufficiently.
-    2. Iterate over all candidate crop regions:
-        (1) Drop a ground-truth box if it does not overlap the candidate crop region,
-            or its center does not lie inside that region.
-        (2) Compute each remaining ground-truth box's position relative to the candidate crop region, and select the corresponding classes and mixup scores.
-        (3) If avoid_no_bbox is False, return the current crop immediately;
-            otherwise, only return a crop once a region containing at least one ground-truth box has been found.
+    1. If allow_no_crop is True, append 'no_crop' to thresholds.
+    2. Randomly shuffle thresholds.
+    3. Iterate over the elements of thresholds:
+        (1) If the current thresh is 'no_crop', return the original image and annotations.
+        (2) Randomly draw values from aspect_ratio and scaling, and compute the height, width and starting point of a candidate crop region from them.
+        (3) Compute the IoU between every ground-truth box and the candidate crop region; if all of the IoUs are below thresh, continue with step 3.
+        (4) If cover_all_box is True and any ground-truth box's IoU is below thresh, continue with step 3.
+        (5) Keep the ground-truth boxes located inside the candidate crop region; if none remain, continue with step 3, otherwise go to step 4.
+    4. Convert the valid ground-truth boxes to coordinates relative to the candidate crop region.
+    5. Convert the valid segmentation regions to coordinates relative to the candidate crop region.
 
     Args:
-        batch_sampler (list): Combinations of random-crop parameters; each combination holds 8 values:
-            - max sample (int): Upper bound on the number of crop regions accepted for this combination.
-            - max trial (int): Number of attempts made for this combination.
-            - min scale (float): Lower bound on the per-side shrink ratio of the crop area relative to the original area.
-            - max scale (float): Upper bound on the per-side shrink ratio of the crop area relative to the original area.
-            - min aspect ratio (float): Lower bound on the shorter side's scaling ratio after cropping.
-            - max aspect ratio (float): Upper bound on the shorter side's scaling ratio after cropping.
-            - min overlap (float): Lower bound on the overlap area between a ground-truth box and the cropped image.
-            - max overlap (float): Upper bound on the overlap area between a ground-truth box and the cropped image.
-            Defaults to None, in which case the following settings are used:
-            [[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
-             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
-             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
-             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
-             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
-             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
-             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
-        satisfy_all (bool): Whether every ground-truth box must satisfy the constraints for a candidate crop region to be kept. Defaults to False.
-        avoid_no_bbox (bool): Whether to discard crops that contain no ground-truth box. Defaults to True.
-
+        aspect_ratio (list): Value range of the crop's aspect ratio, as [min, max]. Defaults to [.5, 2.].
+        thresholds (list): List of IoU thresholds used to decide whether a candidate crop region is valid. Defaults to [.0, .1, .3, .5, .7, .9].
+        scaling (list): Value range of the crop area relative to the original area, as [min, max]. Defaults to [.3, 1.].
+        num_attempts (int): Number of attempts before giving up the search for a valid crop region. Defaults to 50.
+        allow_no_crop (bool): Whether returning the image uncropped is allowed. Defaults to True.
+        cover_all_box (bool): Whether all ground-truth boxes must lie inside the crop region. Defaults to False.
     """
 
     def __init__(self,
-                 batch_sampler=None,
-                 satisfy_all=False,
-                 avoid_no_bbox=True):
-        if batch_sampler is None:
-            batch_sampler = [[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
-                             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
-                             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
-                             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
-                             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
-                             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
-                             [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
-        self.batch_sampler = batch_sampler
-        self.satisfy_all = satisfy_all
-        self.avoid_no_bbox = avoid_no_bbox
+                 aspect_ratio=[.5, 2.],
+                 thresholds=[.0, .1, .3, .5, .7, .9],
+                 scaling=[.3, 1.],
+                 num_attempts=50,
+                 allow_no_crop=True,
+                 cover_all_box=False):
+        self.aspect_ratio = aspect_ratio
+        self.thresholds = thresholds
+        self.scaling = scaling
+        self.num_attempts = num_attempts
+        self.allow_no_crop = allow_no_crop
+        self.cover_all_box = cover_all_box
 
     def __call__(self, im, im_info=None, label_info=None):
         """
@@ -859,66 +843,84 @@
                 'gt_class' not in label_info:
             raise TypeError('Cannot do RandomCrop! ' + \
                             'Because gt_bbox/gt_class is not in label_info!')
+
+        if len(label_info['gt_bbox']) == 0:
+            return (im, im_info, label_info)
+
         augment_shape = im_info['augment_shape']
-        im_width = augment_shape[1]
-        im_height = augment_shape[0]
+        w = augment_shape[1]
+        h = augment_shape[0]
         gt_bbox = label_info['gt_bbox']
-        gt_bbox_tmp = gt_bbox.copy()
-        for i in range(gt_bbox_tmp.shape[0]):
-            gt_bbox_tmp[i][0] = gt_bbox[i][0] / im_width
-            gt_bbox_tmp[i][1] = gt_bbox[i][1] / im_height
-            gt_bbox_tmp[i][2] = gt_bbox[i][2] / im_width
-            gt_bbox_tmp[i][3] = gt_bbox[i][3] / im_height
-        gt_class = label_info['gt_class']
-
-        gt_score = None
-        if 'gt_score' in label_info:
-            gt_score = label_info['gt_score']
-        sampled_bbox = []
-        gt_bbox_tmp = gt_bbox_tmp.tolist()
-        for sampler in self.batch_sampler:
-            found = 0
-            for i in range(sampler[1]):
-                if found >= sampler[0]:
-                    break
-                sample_bbox = generate_sample_bbox(sampler)
-                if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox_tmp,
-                                             self.satisfy_all):
-                    sampled_bbox.append(sample_bbox)
-                    found = found + 1
-        im = np.array(im)
-        while sampled_bbox:
-            idx = int(np.random.uniform(0, len(sampled_bbox)))
-            sample_bbox = sampled_bbox.pop(idx)
-            sample_bbox = clip_bbox(sample_bbox)
-            crop_bbox, crop_class, crop_score = \
-                filter_and_process(sample_bbox, gt_bbox_tmp, gt_class, gt_score)
-            if self.avoid_no_bbox:
-                if len(crop_bbox) < 1:
+        thresholds = list(self.thresholds)
+        if self.allow_no_crop:
+            thresholds.append('no_crop')
+        np.random.shuffle(thresholds)
+
+        for thresh in thresholds:
+            if thresh == 'no_crop':
+                return (im, im_info, label_info)
+
+            found = False
+            for i in range(self.num_attempts):
+                scale = np.random.uniform(*self.scaling)
+                min_ar, max_ar = self.aspect_ratio
+                aspect_ratio = np.random.uniform(
+                    max(min_ar, scale**2), min(max_ar, scale**-2))
+                crop_h = int(h * scale / np.sqrt(aspect_ratio))
+                crop_w = int(w * scale * np.sqrt(aspect_ratio))
+                crop_y = np.random.randint(0, h - crop_h)
+                crop_x = np.random.randint(0, w - crop_w)
+                crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
+                iou = iou_matrix(gt_bbox, np.array([crop_box],
+                                                   dtype=np.float32))
+                if iou.max() < thresh:
                     continue
-            xmin = int(sample_bbox[0] * im_width)
-            xmax = int(sample_bbox[2] * im_width)
-            ymin = int(sample_bbox[1] * im_height)
-            ymax = int(sample_bbox[3] * im_height)
-            im = im[ymin:ymax, xmin:xmax]
-            for i in range(crop_bbox.shape[0]):
-                crop_bbox[i][0] = crop_bbox[i][0] * (xmax - xmin)
-                crop_bbox[i][1] = crop_bbox[i][1] * (ymax - ymin)
-                crop_bbox[i][2] = crop_bbox[i][2] * (xmax - xmin)
-                crop_bbox[i][3] = crop_bbox[i][3] * (ymax - ymin)
-            label_info['gt_bbox'] = crop_bbox
-            label_info['gt_class'] = crop_class
-            label_info['gt_score'] = crop_score
-            im_info['augment_shape'] = np.array([ymax - ymin,
-                                                 xmax - xmin]).astype('int32')
-            if label_info is None:
-                return (im, im_info)
-            else:
+
+                if self.cover_all_box and iou.min() < thresh:
+                    continue
+
+                cropped_box, valid_ids = crop_box_with_center_constraint(
+                    gt_bbox, np.array(crop_box, dtype=np.float32))
+                if valid_ids.size > 0:
+                    found = True
+                    break
+
+            if found:
+                if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
+                    crop_polys = crop_segms(label_info['gt_poly'], valid_ids,
+                                            np.array(crop_box, dtype=np.int64),
+                                            h, w)
+                    if [] in crop_polys:
+                        delete_id = list()
+                        valid_polys = list()
+                        for id, crop_poly in enumerate(crop_polys):
+                            if crop_poly == []:
+                                delete_id.append(id)
+                            else:
+                                valid_polys.append(crop_poly)
+                        valid_ids = np.delete(valid_ids, delete_id)
+                        if len(valid_polys) == 0:
+                            return (im, im_info, label_info)
+                        label_info['gt_poly'] = valid_polys
+                    else:
+                        label_info['gt_poly'] = crop_polys
+                im = crop_image(im, crop_box)
+                label_info['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
+                label_info['gt_class'] = np.take(
+                    label_info['gt_class'], valid_ids, axis=0)
+                im_info['augment_shape'] = np.array(
+                    [crop_box[3] - crop_box[1],
+                     crop_box[2] - crop_box[0]]).astype('int32')
+                if 'gt_score' in label_info:
+                    label_info['gt_score'] = np.take(
+                        label_info['gt_score'], valid_ids, axis=0)
+
+                if 'is_crowd' in label_info:
+                    label_info['is_crowd'] = np.take(
+                        label_info['is_crowd'], valid_ids, axis=0)
                 return (im, im_info, label_info)
-        if label_info is None:
-            return (im, im_info)
-        else:
-            return (im, im_info, label_info)
+
+        return (im, im_info, label_info)
 
 
 class ArrangeFasterRCNN:
-- 
GitLab
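End to end, the patched transforms can also be exercised directly on a toy sample. A minimal sketch, separate from the patch itself: the image, box and polygon below are illustrative; `gt_poly` uses the nested list-of-polygons layout consumed by `expand_segms` and `crop_segms` above, and shapely/pycocotools must be installed for the mask paths:

```python
import numpy as np
import paddlex

expand = paddlex.det.transforms.RandomExpand(ratio=2.)
crop = paddlex.det.transforms.RandomCrop()

# A dummy 100x100 RGB image with one box and a matching square polygon.
im = np.random.randint(0, 255, (100, 100, 3)).astype('float32')
im_info = {'augment_shape': np.array([100, 100]).astype('int32')}
label_info = {
    'gt_bbox': np.array([[20., 20., 80., 80.]], dtype=np.float32),
    'gt_class': np.array([[1]], dtype=np.int32),
    'gt_poly': [[[20., 20., 80., 20., 80., 80., 20., 80.]]],
}

# Both transforms are stochastic; either may return the sample unchanged
# (RandomExpand via prob, RandomCrop via the shuffled 'no_crop' threshold).
im, im_info, label_info = expand(im, im_info, label_info)
im, im_info, label_info = crop(im, im_info, label_info)
print(im_info['augment_shape'], label_info['gt_bbox'])
```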