提交 1d2e74d9 编写于 作者: F FlyingQianMM

randomcrop and randomexpand for mask

上级 cbf72fca
......@@ -122,56 +122,42 @@ paddlex.det.transforms.MixupImage(alpha=1.5, beta=1.5, mixup_epoch=-1)
## RandomExpand类
```python
paddlex.det.transforms.RandomExpand(max_ratio=4., prob=0.5, mean=[127.5, 127.5, 127.5])
paddlex.det.transforms.RandomExpand(ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53])
```
随机扩张图像,模型训练时的数据增强操作,模型训练时的数据增强操作。
随机扩张图像,模型训练时的数据增强操作
1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
2. 计算扩张后图像大小。
3. 初始化像素值为数据集均值的图像,并将原图像随机粘贴于该图像上。
3. 初始化像素值为输入填充值的图像,并将原图像随机粘贴于该图像上。
4. 根据原图像粘贴位置换算出扩张后真实标注框的位置坐标。
5. 根据原图像粘贴位置换算出扩张后真实分割区域的位置坐标。
### 参数
* **max_ratio** (float): 图像扩张的最大比例。默认为4.0。
* **ratio** (float): 图像扩张的最大比例。默认为4.0。
* **prob** (float): 随机扩张的概率。默认为0.5。
* **mean** (list): 图像数据集的均值(0-255)。默认为[127.5, 127.5, 127.5]。
* **fill_value** (list): 扩张图像的初始填充值(0-255)。默认为[123.675, 116.28, 103.53]。
## RandomCrop类
```python
paddlex.det.transforms.RandomCrop(batch_sampler=None, satisfy_all=False, avoid_no_bbox=True)
paddlex.det.transforms.RandomCrop(aspect_ratio=[.5, 2.], thresholds=[.0, .1, .3, .5, .7, .9], scaling=[.3, 1.], num_attempts=50, allow_no_crop=True, cover_all_box=False)
```
随机裁剪图像,模型训练时的数据增强操作。
1. 根据batch_sampler计算获取裁剪候选区域的位置。
(1) 根据min scale、max scale、min aspect ratio、max aspect ratio计算随机剪裁的高、宽。
(2) 根据随机剪裁的高、宽随机选取剪裁的起始点。
(3) 筛选出裁剪候选区域:
* 当satisfy_all为True时,需所有真实标注框与裁剪候选区域的重叠度满足需求时,该裁剪候选区域才可保留。
* 当satisfy_all为False时,当有一个真实标注框与裁剪候选区域的重叠度满足需求时,该裁剪候选区域就可保留。
2. 遍历所有裁剪候选区域:
(1) 若真实标注框与候选裁剪区域不重叠,或其中心点不在候选裁剪区域,则将该真实标注框去除。
(2) 计算相对于该候选裁剪区域,真实标注框的位置,并筛选出对应的类别、混合得分。
(3) 若avoid_no_bbox为False,返回当前裁剪后的信息即可;反之,要找到一个裁剪区域中真实标注框个数不为0的区域,才返回裁剪后的信息
1. 若allow_no_crop为True,则在thresholds加入’no_crop’
2. 随机打乱thresholds
3. 遍历thresholds中各元素:
(1) 如果当前thresh为’no_crop’,则返回原始图像和标注信息
(2) 随机取出aspect_ratio和scaling中的值并由此计算出候选裁剪区域的高、宽、起始点。
(3) 计算真实标注框与候选裁剪区域IoU,若全部真实标注框的IoU都小于thresh,则继续第3步
(4) 如果cover_all_box为True且存在真实标注框的IoU小于thresh,则继续第3步
(5) 筛选出位于候选裁剪区域内的真实标注框,若有效框的个数为0,则继续第3步,否则进行第4步。
4. 换算有效真值标注框相对候选裁剪区域的位置坐标。
5. 换算有效分割区域相对候选裁剪区域的位置坐标
### 参数
* **batch_sampler** (list): 随机裁剪参数的多种组合,每种组合包含8个值,如下:
- max sample (int):满足当前组合的裁剪区域的个数上限。
- max trial (int): 查找满足当前组合的次数。
- min scale (float): 裁剪面积相对原面积,每条边缩短比例的最小限制。
- max scale (float): 裁剪面积相对原面积,每条边缩短比例的最大限制。
- min aspect ratio (float): 裁剪后短边缩放比例的最小限制。
- max aspect ratio (float): 裁剪后短边缩放比例的最大限制。
- min overlap (float): 真实标注框与裁剪图像重叠面积的最小限制。
- max overlap (float): 真实标注框与裁剪图像重叠面积的最大限制。
默认值为None,当为None时采用如下设置:
[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
* **satisfy_all** (bool): 是否需要所有标注框满足条件,裁剪候选区域才保留。默认为False。
* **avoid_no_bbox** (bool): 是否对裁剪图像不存在标注框的图像进行保留。默认为True。
* **aspect_ratio** (list): 裁剪后短边缩放比例的取值范围,以[min, max]形式表示。默认值为[.5, 2.]。
* **thresholds** (list): 判断裁剪候选区域是否有效所需的IoU阈值取值列表。默认值为[.0, .1, .3, .5, .7, .9]。
* **scaling** (list): 裁剪面积相对原面积的取值范围,以[min, max]形式表示。默认值为[.3, 1.]。
* **num_attempts** (int): 在放弃寻找有效裁剪区域前尝试的次数。默认值为50。
* **allow_no_crop** (bool): 是否允许未进行裁剪。默认值为True。
* **cover_all_box** (bool): 是否要求所有的真实标注框都必须在裁剪区域内。默认值为False。
......@@ -4,10 +4,10 @@
在服务端部署的模型需要首先将模型导出为inference格式模型,导出的模型将包括`__model__``__params__``model.yml`三个文名,分别为模型的网络结构,模型权重和模型的配置文件(包括数据预处理参数等等)。在安装完PaddleX后,在命令行终端使用如下命令导出模型到当前目录`inferece_model`下。
> 可直接下载垃圾检测模型测试本文档的流程[garbage_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/garbage_epoch_12.tar.gz)
> 可直接下载垃圾检测模型测试本文档的流程[xiaoduxiong_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)
```
paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model
paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model
```
## 模型C++和Python部署方案预计一周内推出...
......@@ -19,25 +19,6 @@ import cv2
import scipy
def meet_emit_constraint(src_bbox, sample_bbox):
center_x = (src_bbox[2] + src_bbox[0]) / 2
center_y = (src_bbox[3] + src_bbox[1]) / 2
if center_x >= sample_bbox[0] and \
center_x <= sample_bbox[2] and \
center_y >= sample_bbox[1] and \
center_y <= sample_bbox[3]:
return True
return False
def clip_bbox(src_bbox):
src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0)
src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0)
src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0)
src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0)
return src_bbox
def bbox_area(src_bbox):
if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
return 0.
......@@ -47,189 +28,6 @@ def bbox_area(src_bbox):
return width * height
def is_overlap(object_bbox, sample_bbox):
if object_bbox[0] >= sample_bbox[2] or \
object_bbox[2] <= sample_bbox[0] or \
object_bbox[1] >= sample_bbox[3] or \
object_bbox[3] <= sample_bbox[1]:
return False
else:
return True
def filter_and_process(sample_bbox, bboxes, labels, scores=None):
new_bboxes = []
new_labels = []
new_scores = []
for i in range(len(bboxes)):
new_bbox = [0, 0, 0, 0]
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
if not meet_emit_constraint(obj_bbox, sample_bbox):
continue
if not is_overlap(obj_bbox, sample_bbox):
continue
sample_width = sample_bbox[2] - sample_bbox[0]
sample_height = sample_bbox[3] - sample_bbox[1]
new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width
new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height
new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width
new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height
new_bbox = clip_bbox(new_bbox)
if bbox_area(new_bbox) > 0:
new_bboxes.append(new_bbox)
new_labels.append([labels[i][0]])
if scores is not None:
new_scores.append([scores[i][0]])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
new_bboxes = []
new_labels = []
new_scores = []
for i, bbox in enumerate(bboxes):
w = float((bbox[2] - bbox[0]) * target_size)
h = float((bbox[3] - bbox[1]) * target_size)
if w * h < float(min_size * min_size):
continue
else:
new_bboxes.append(bbox)
new_labels.append(labels[i])
if scores is not None and scores.size != 0:
new_scores.append(scores[i])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def generate_sample_bbox(sampler):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
aspect_ratio = max(aspect_ratio, (scale**2.0))
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = np.random.uniform(0, xmin_bound)
ymin = np.random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox
def generate_sample_bbox_square(sampler, image_width, image_height):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
aspect_ratio = max(aspect_ratio, (scale**2.0))
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
if image_height < image_width:
bbox_width = bbox_height * image_height / image_width
else:
bbox_height = bbox_width * image_width / image_height
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = np.random.uniform(0, xmin_bound)
ymin = np.random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
resize_width):
num_gt = len(bbox_labels)
# np.random.randint range: [low, high)
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
if num_gt != 0:
norm_xmin = bbox_labels[rand_idx][0]
norm_ymin = bbox_labels[rand_idx][1]
norm_xmax = bbox_labels[rand_idx][2]
norm_ymax = bbox_labels[rand_idx][3]
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
wid = image_width * (norm_xmax - norm_xmin)
hei = image_height * (norm_ymax - norm_ymin)
range_size = 0
area = wid * hei
for scale_ind in range(0, len(scale_array) - 1):
if area > scale_array[scale_ind] ** 2 and area < \
scale_array[scale_ind + 1] ** 2:
range_size = scale_ind + 1
break
if area > scale_array[len(scale_array) - 2]**2:
range_size = len(scale_array) - 2
scale_choose = 0.0
if range_size == 0:
rand_idx_size = 0
else:
# np.random.randint range: [low, high)
rng_rand_size = np.random.randint(0, range_size + 1)
rand_idx_size = rng_rand_size % (range_size + 1)
if rand_idx_size == range_size:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = min(2.0 * scale_array[rand_idx_size],
2 * math.sqrt(wid * hei))
scale_choose = random.uniform(min_resize_val, max_resize_val)
else:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = 2.0 * scale_array[rand_idx_size]
scale_choose = random.uniform(min_resize_val, max_resize_val)
sample_bbox_size = wid * resize_width / scale_choose
w_off_orig = 0.0
h_off_orig = 0.0
if sample_bbox_size < max(image_height, image_width):
if wid <= sample_bbox_size:
w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
xmin)
else:
w_off_orig = np.random.uniform(xmin,
xmin + wid - sample_bbox_size)
if hei <= sample_bbox_size:
h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
ymin)
else:
h_off_orig = np.random.uniform(ymin,
ymin + hei - sample_bbox_size)
else:
w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
h_off_orig = np.random.uniform(image_height - sample_bbox_size,
0.0)
w_off_orig = math.floor(w_off_orig)
h_off_orig = math.floor(h_off_orig)
# Figure out top left coordinates.
w_off = float(w_off_orig / image_width)
h_off = float(h_off_orig / image_height)
sampled_bbox = [
w_off, h_off, w_off + float(sample_bbox_size / image_width),
h_off + float(sample_bbox_size / image_height)
]
return sampled_bbox
else:
return 0
def jaccard_overlap(sample_bbox, object_bbox):
if sample_bbox[0] >= object_bbox[2] or \
sample_bbox[2] <= object_bbox[0] or \
......@@ -249,143 +47,143 @@ def jaccard_overlap(sample_bbox, object_bbox):
return overlap
def intersect_bbox(bbox1, bbox2):
if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
intersection_box = [0.0, 0.0, 0.0, 0.0]
else:
intersection_box = [
max(bbox1[0], bbox2[0]),
max(bbox1[1], bbox2[1]),
min(bbox1[2], bbox2[2]),
min(bbox1[3], bbox2[3])
]
return intersection_box
def bbox_coverage(bbox1, bbox2):
inter_box = intersect_bbox(bbox1, bbox2)
intersect_size = bbox_area(inter_box)
if intersect_size > 0:
bbox1_size = bbox_area(bbox1)
return intersect_size / bbox1_size
else:
return 0.
def iou_matrix(a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def satisfy_sample_constraint(sampler,
sample_bbox,
gt_bboxes,
satisfy_all=False):
if sampler[6] == 0 and sampler[7] == 0:
return True
satisfied = []
for i in range(len(gt_bboxes)):
object_bbox = [
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
]
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler[6] != 0 and \
overlap < sampler[6]:
satisfied.append(False)
def crop_box_with_center_constraint(box, crop):
cropped_box = box.copy()
cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
cropped_box[:, :2] -= crop[:2]
cropped_box[:, 2:] -= crop[:2]
centers = (box[:, :2] + box[:, 2:]) / 2
valid = np.logical_and(crop[:2] <= centers, centers < crop[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return cropped_box, np.where(valid)[0]
def is_poly(segm):
if not isinstance(segm, (list, dict)):
raise Exception("Invalid segm type: {}".format(type(segm)))
return isinstance(segm, list)
def crop_image(img, crop):
x1, y1, x2, y2 = crop
return img[y1:y2, x1:x2, :]
def crop_segms(segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop):
xmin, ymin, xmax, ymax = crop
crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]
crop_p = np.array(crop_coord).reshape(4, 2)
crop_p = Polygon(crop_p)
crop_segm = list()
for poly in segm:
poly = np.array(poly).reshape(len(poly) // 2, 2)
polygon = Polygon(poly)
if not polygon.is_valid:
exterior = polygon.exterior
multi_lines = exterior.intersection(exterior)
polygons = shapely.ops.polygonize(multi_lines)
polygon = MultiPolygon(polygons)
multi_polygon = list()
if isinstance(polygon, MultiPolygon):
multi_polygon = copy.deepcopy(polygon)
else:
multi_polygon.append(copy.deepcopy(polygon))
for per_polygon in multi_polygon:
inter = per_polygon.intersection(crop_p)
if not inter:
continue
if sampler[7] != 0 and \
overlap > sampler[7]:
satisfied.append(False)
if isinstance(inter, (MultiPolygon, GeometryCollection)):
for part in inter:
if not isinstance(part, Polygon):
continue
satisfied.append(True)
if not satisfy_all:
return True
if satisfy_all:
return np.all(satisfied)
part = np.squeeze(
np.array(part.exterior.coords[:-1]).reshape(1, -1))
part[0::2] -= xmin
part[1::2] -= ymin
crop_segm.append(part.tolist())
elif isinstance(inter, Polygon):
crop_poly = np.squeeze(
np.array(inter.exterior.coords[:-1]).reshape(1, -1))
crop_poly[0::2] -= xmin
crop_poly[1::2] -= ymin
crop_segm.append(crop_poly.tolist())
else:
return False
continue
return crop_segm
def _crop_rle(rle, crop, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
mask = mask[crop[1]:crop[3], crop[0]:crop[2]]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
if sampler[6] == 0 and sampler[7] == 0:
has_jaccard_overlap = False
crop_segms = []
for id in valid_ids:
segm = segms[id]
if is_poly(segm):
import copy
import shapely.ops
import logging
from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
logging.getLogger("shapely").setLevel(logging.WARNING)
# Polygon format
crop_segms.append(_crop_poly(segm, crop))
else:
has_jaccard_overlap = True
if sampler[8] == 0 and sampler[9] == 0:
has_object_coverage = False
# RLE format
import pycocotools.mask as mask_util
crop_segms.append(_crop_rle(segm, crop, height, width))
return crop_segms
def expand_segms(segms, x, y, height, width, ratio):
def _expand_poly(poly, x, y):
expanded_poly = np.array(poly)
expanded_poly[0::2] += x
expanded_poly[1::2] += y
return expanded_poly.tolist()
def _expand_rle(rle, x, y, height, width, ratio):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
expanded_mask = np.full((int(height * ratio), int(width * ratio)),
0).astype(mask.dtype)
expanded_mask[y:y + height, x:x + width] = mask
rle = mask_util.encode(
np.array(expanded_mask, order='F', dtype=np.uint8))
return rle
expanded_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
expanded_segms.append([_expand_poly(poly, x, y) for poly in segm])
else:
has_object_coverage = True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(gt_bboxes)):
object_bbox = [
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
]
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler[6] != 0 and \
overlap < sampler[6]:
continue
if sampler[7] != 0 and \
overlap > sampler[7]:
continue
found = True
if has_object_coverage:
object_coverage = bbox_coverage(object_bbox, sample_bbox)
if sampler[8] != 0 and \
object_coverage < sampler[8]:
continue
if sampler[9] != 0 and \
object_coverage > sampler[9]:
continue
found = True
if found:
return True
return found
def crop_image_sampling(img, sample_bbox, image_width, image_height,
target_size):
# no clipping here
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
w_off = xmin
h_off = ymin
width = xmax - xmin
height = ymax - ymin
cross_xmin = max(0.0, float(w_off))
cross_ymin = max(0.0, float(h_off))
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
cross_width = cross_xmax - cross_xmin
cross_height = cross_ymax - cross_ymin
roi_xmin = 0 if w_off >= 0 else abs(w_off)
roi_ymin = 0 if h_off >= 0 else abs(h_off)
roi_width = cross_width
roi_height = cross_height
roi_y1 = int(roi_ymin)
roi_y2 = int(roi_ymin + roi_height)
roi_x1 = int(roi_xmin)
roi_x2 = int(roi_xmin + roi_width)
cross_y1 = int(cross_ymin)
cross_y2 = int(cross_ymin + cross_height)
cross_x1 = int(cross_xmin)
cross_x2 = int(cross_xmin + cross_width)
sample_img = np.zeros((height, width, 3))
sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
img[cross_y1: cross_y2, cross_x1: cross_x2]
sample_img = cv2.resize(
sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
return sample_img
# RLE format
import pycocotools.mask as mask_util
expanded_segms.append(
_expand_rle(segm, x, y, height, width, ratio))
return expanded_segms
def box_horizontal_flip(bboxes, width):
......@@ -409,15 +207,10 @@ def segms_horizontal_flip(segms, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects([rle], height, width)
mask = mask_util.decode(rle)
mask = mask[:, ::-1, :]
mask = mask[:, ::-1]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
def is_poly(segm):
if not isinstance(segm, (list, dict)):
raise Exception("Invalid segm type: {}".format(type(segm)))
return isinstance(segm, list)
flipped_segms = []
for segm in segms:
if is_poly(segm):
......
......@@ -12,13 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .ops import *
from .box_utils import *
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from numbers import Number
import random
import os.path as osp
import numpy as np
from PIL import Image, ImageEnhance
import cv2
from PIL import Image, ImageEnhance
from .ops import *
from .box_utils import *
class Compose:
......@@ -658,9 +667,17 @@ class MixupImage:
gt_score2 = im_info['mixup'][2]['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
if 'gt_poly' in label_info:
gt_poly1 = label_info['gt_poly']
gt_poly2 = im_info['mixup'][2]['gt_poly']
label_info['gt_poly'] = gt_poly1 + gt_poly2
is_crowd1 = label_info['is_crowd']
is_crowd2 = im_info['mixup'][2]['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
label_info['gt_bbox'] = gt_bbox
label_info['gt_score'] = gt_score
label_info['gt_class'] = gt_class
label_info['is_crowd'] = is_crowd
im_info['augment_shape'] = np.array([im.shape[0],
im.shape[1]]).astype('int32')
im_info.pop('mixup')
......@@ -672,23 +689,32 @@ class MixupImage:
class RandomExpand:
"""随机扩张图像,模型训练时的数据增强操作。
1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
2. 计算扩张后图像大小。
3. 初始化像素值为数据集均值的图像,并将原图像随机粘贴于该图像上。
3. 初始化像素值为输入填充值的图像,并将原图像随机粘贴于该图像上。
4. 根据原图像粘贴位置换算出扩张后真实标注框的位置坐标。
5. 根据原图像粘贴位置换算出扩张后真实分割区域的位置坐标。
Args:
max_ratio (float): 图像扩张的最大比例。默认为4.0。
ratio (float): 图像扩张的最大比例。默认为4.0。
prob (float): 随机扩张的概率。默认为0.5。
mean (list): 图像数据集的均值(0-255)。默认为[127.5, 127.5, 127.5]。
fill_value (list): 扩张图像的初始填充值(0-255)。默认为[123.675, 116.28, 103.53]。
"""
def __init__(self, max_ratio=4., prob=0.5, mean=[127.5, 127.5, 127.5]):
self.max_ratio = max_ratio
self.mean = mean
def __init__(self,
ratio=4.,
prob=0.5,
fill_value=[123.675, 116.28, 103.53]):
super(RandomExpand, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01"
self.ratio = ratio
self.prob = prob
assert isinstance(fill_value, (Number, Sequence)), \
"fill value must be either float or sequence"
if isinstance(fill_value, Number):
fill_value = (fill_value, ) * 3
if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value)
self.fill_value = fill_value
def __call__(self, im, im_info=None, label_info=None):
"""
......@@ -696,7 +722,6 @@ class RandomExpand:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict, 可选): 存储与图像相关的信息。
label_info (dict, 可选): 存储与标注框相关的信息。
Returns:
tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
......@@ -708,7 +733,6 @@ class RandomExpand:
其中n代表真实标注框的个数。
- gt_class (np.ndarray): 随机扩张后每个真实标注框对应的类别序号,形状为(n, 1),
其中n代表真实标注框的个数。
Raises:
TypeError: 形参数据类型不满足需求。
"""
......@@ -723,108 +747,68 @@ class RandomExpand:
'gt_class' not in label_info:
raise TypeError('Cannot do RandomExpand! ' + \
'Becasuse gt_bbox/gt_class is not in label_info!')
prob = np.random.uniform(0, 1)
if np.random.uniform(0., 1.) < self.prob:
return (im, im_info, label_info)
augment_shape = im_info['augment_shape']
im_width = augment_shape[1]
im_height = augment_shape[0]
gt_bbox = label_info['gt_bbox']
gt_class = label_info['gt_class']
height = int(augment_shape[0])
width = int(augment_shape[1])
if prob < self.prob:
if self.max_ratio - 1 >= 0.01:
expand_ratio = np.random.uniform(1, self.max_ratio)
height = int(im_height * expand_ratio)
width = int(im_width * expand_ratio)
h_off = math.floor(np.random.uniform(0, height - im_height))
w_off = math.floor(np.random.uniform(0, width - im_width))
expand_bbox = [
-w_off / im_width, -h_off / im_height,
(width - w_off) / im_width, (height - h_off) / im_height
]
expand_im = np.ones((height, width, 3))
expand_im = np.uint8(expand_im * np.squeeze(self.mean))
expand_im = Image.fromarray(expand_im)
im = im.astype('uint8')
im = Image.fromarray(im)
expand_im.paste(im, (int(w_off), int(h_off)))
expand_im = np.asarray(expand_im)
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / im_width
gt_bbox[i][1] = gt_bbox[i][1] / im_height
gt_bbox[i][2] = gt_bbox[i][2] / im_width
gt_bbox[i][3] = gt_bbox[i][3] / im_height
gt_bbox, gt_class, _ = filter_and_process(
expand_bbox, gt_bbox, gt_class)
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] * width
gt_bbox[i][1] = gt_bbox[i][1] * height
gt_bbox[i][2] = gt_bbox[i][2] * width
gt_bbox[i][3] = gt_bbox[i][3] * height
im = expand_im.astype('float32')
label_info['gt_bbox'] = gt_bbox
label_info['gt_class'] = gt_class
im_info['augment_shape'] = np.array([height,
width]).astype('int32')
if label_info is None:
return (im, im_info)
else:
expand_ratio = np.random.uniform(1., self.ratio)
h = int(height * expand_ratio)
w = int(width * expand_ratio)
if not h > height or not w > width:
return (im, im_info, label_info)
y = np.random.randint(0, h - height)
x = np.random.randint(0, w - width)
canvas = np.ones((h, w, 3), dtype=np.uint8)
canvas *= np.array(self.fill_value, dtype=np.uint8)
canvas[y:y + height, x:x + width, :] = im.astype(np.uint8)
im_info['augment_shape'] = np.array([h, w]).astype('int32')
if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
label_info['gt_poly'] = expand_segms(label_info['gt_poly'], x, y,
height, width, expand_ratio)
return (canvas, im_info, label_info)
class RandomCrop:
"""随机裁剪图像。
1. 根据batch_sampler计算获取裁剪候选区域的位置。
(1) 根据min scale、max scale、min aspect ratio、max aspect ratio计算随机剪裁的高、宽。
(2) 根据随机剪裁的高、宽随机选取剪裁的起始点。
(3) 筛选出裁剪候选区域:
- 当satisfy_all为True时,需所有真实标注框与裁剪候选区域的重叠度满足需求时,该裁剪候选区域才可保留。
- 当satisfy_all为False时,当有一个真实标注框与裁剪候选区域的重叠度满足需求时,该裁剪候选区域就可保留。
2. 遍历所有裁剪候选区域:
(1) 若真实标注框与候选裁剪区域不重叠,或其中心点不在候选裁剪区域,
则将该真实标注框去除。
(2) 计算相对于该候选裁剪区域,真实标注框的位置,并筛选出对应的类别、混合得分。
(3) 若avoid_no_bbox为False,返回当前裁剪后的信息即可;
反之,要找到一个裁剪区域中真实标注框个数不为0的区域,才返回裁剪后的信息。
1. 若allow_no_crop为True,则在thresholds加入’no_crop’
2. 随机打乱thresholds
3. 遍历thresholds中各元素:
(1) 如果当前thresh为’no_crop’,则返回原始图像和标注信息
(2) 随机取出aspect_ratio和scaling中的值并由此计算出候选裁剪区域的高、宽、起始点。
(3) 计算真实标注框与候选裁剪区域IoU,若全部真实标注框的IoU都小于thresh,则继续第3步
(4) 如果cover_all_box为True且存在真实标注框的IoU小于thresh,则继续第3步
(5) 筛选出位于候选裁剪区域内的真实标注框,若有效框的个数为0,则继续第3步,否则进行第4步。
4. 换算有效真值标注框相对候选裁剪区域的位置坐标。
5. 换算有效分割区域相对候选裁剪区域的位置坐标。
Args:
batch_sampler (list): 随机裁剪参数的多种组合,每种组合包含8个值,如下:
- max sample (int):满足当前组合的裁剪区域的个数上限。
- max trial (int): 查找满足当前组合的次数。
- min scale (float): 裁剪面积相对原面积,每条边缩短比例的最小限制。
- max scale (float): 裁剪面积相对原面积,每条边缩短比例的最大限制。
- min aspect ratio (float): 裁剪后短边缩放比例的最小限制。
- max aspect ratio (float): 裁剪后短边缩放比例的最大限制。
- min overlap (float): 真实标注框与裁剪图像重叠面积的最小限制。
- max overlap (float): 真实标注框与裁剪图像重叠面积的最大限制。
默认值为None,当为None时采用如下设置:
[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
satisfy_all (bool): 是否需要所有标注框满足条件,裁剪候选区域才保留。默认为False。
avoid_no_bbox (bool): 是否对裁剪图像不存在标注框的图像进行保留。默认为True。
aspect_ratio (list): 裁剪后短边缩放比例的取值范围,以[min, max]形式表示。默认值为[.5, 2.]。
thresholds (list): 判断裁剪候选区域是否有效所需的IoU阈值取值列表。默认值为[.0, .1, .3, .5, .7, .9]。
scaling (list): 裁剪面积相对原面积的取值范围,以[min, max]形式表示。默认值为[.3, 1.]。
num_attempts (int): 在放弃寻找有效裁剪区域前尝试的次数。默认值为50。
allow_no_crop (bool): 是否允许未进行裁剪。默认值为True。
cover_all_box (bool): 是否要求所有的真实标注框都必须在裁剪区域内。默认值为False。
"""
def __init__(self,
batch_sampler=None,
satisfy_all=False,
avoid_no_bbox=True):
if batch_sampler is None:
batch_sampler = [[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
self.batch_sampler = batch_sampler
self.satisfy_all = satisfy_all
self.avoid_no_bbox = avoid_no_bbox
aspect_ratio=[.5, 2.],
thresholds=[.0, .1, .3, .5, .7, .9],
scaling=[.3, 1.],
num_attempts=50,
allow_no_crop=True,
cover_all_box=False):
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
self.scaling = scaling
self.num_attempts = num_attempts
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box
def __call__(self, im, im_info=None, label_info=None):
"""
......@@ -859,65 +843,83 @@ class RandomCrop:
'gt_class' not in label_info:
raise TypeError('Cannot do RandomCrop! ' + \
'Becasuse gt_bbox/gt_class is not in label_info!')
if len(label_info['gt_bbox']) == 0:
return (im, im_info, label_info)
augment_shape = im_info['augment_shape']
im_width = augment_shape[1]
im_height = augment_shape[0]
w = augment_shape[1]
h = augment_shape[0]
gt_bbox = label_info['gt_bbox']
gt_bbox_tmp = gt_bbox.copy()
for i in range(gt_bbox_tmp.shape[0]):
gt_bbox_tmp[i][0] = gt_bbox[i][0] / im_width
gt_bbox_tmp[i][1] = gt_bbox[i][1] / im_height
gt_bbox_tmp[i][2] = gt_bbox[i][2] / im_width
gt_bbox_tmp[i][3] = gt_bbox[i][3] / im_height
gt_class = label_info['gt_class']
thresholds = list(self.thresholds)
if self.allow_no_crop:
thresholds.append('no_crop')
np.random.shuffle(thresholds)
gt_score = None
if 'gt_score' in label_info:
gt_score = label_info['gt_score']
sampled_bbox = []
gt_bbox_tmp = gt_bbox_tmp.tolist()
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox(sampler)
if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox_tmp,
self.satisfy_all):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = \
filter_and_process(sample_bbox, gt_bbox_tmp, gt_class, gt_score)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
for thresh in thresholds:
if thresh == 'no_crop':
return (im, im_info, label_info)
found = False
for i in range(self.num_attempts):
scale = np.random.uniform(*self.scaling)
min_ar, max_ar = self.aspect_ratio
aspect_ratio = np.random.uniform(
max(min_ar, scale**2), min(max_ar, scale**-2))
crop_h = int(h * scale / np.sqrt(aspect_ratio))
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = iou_matrix(gt_bbox, np.array([crop_box],
dtype=np.float32))
if iou.max() < thresh:
continue
xmin = int(sample_bbox[0] * im_width)
xmax = int(sample_bbox[2] * im_width)
ymin = int(sample_bbox[1] * im_height)
ymax = int(sample_bbox[3] * im_height)
im = im[ymin:ymax, xmin:xmax]
for i in range(crop_bbox.shape[0]):
crop_bbox[i][0] = crop_bbox[i][0] * (xmax - xmin)
crop_bbox[i][1] = crop_bbox[i][1] * (ymax - ymin)
crop_bbox[i][2] = crop_bbox[i][2] * (xmax - xmin)
crop_bbox[i][3] = crop_bbox[i][3] * (ymax - ymin)
label_info['gt_bbox'] = crop_bbox
label_info['gt_class'] = crop_class
label_info['gt_score'] = crop_score
im_info['augment_shape'] = np.array([ymax - ymin,
xmax - xmin]).astype('int32')
if label_info is None:
return (im, im_info)
if self.cover_all_box and iou.min() < thresh:
continue
cropped_box, valid_ids = crop_box_with_center_constraint(
gt_bbox, np.array(crop_box, dtype=np.float32))
if valid_ids.size > 0:
found = True
break
if found:
if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
crop_polys = crop_segms(label_info['gt_poly'], valid_ids,
np.array(crop_box, dtype=np.int64),
h, w)
if [] in crop_polys:
delete_id = list()
valid_polys = list()
for id, crop_poly in enumerate(crop_polys):
if crop_poly == []:
delete_id.append(id)
else:
valid_polys.append(crop_poly)
valid_ids = np.delete(valid_ids, delete_id)
if len(valid_polys) == 0:
return (im, im_info, label_info)
if label_info is None:
return (im, im_info)
label_info['gt_poly'] = valid_polys
else:
label_info['gt_poly'] = crop_polys
im = crop_image(im, crop_box)
label_info['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
label_info['gt_class'] = np.take(
label_info['gt_class'], valid_ids, axis=0)
im_info['augment_shape'] = np.array(
[crop_box[3] - crop_box[1],
crop_box[2] - crop_box[0]]).astype('int32')
if 'gt_score' in label_info:
label_info['gt_score'] = np.take(
label_info['gt_score'], valid_ids, axis=0)
if 'is_crowd' in label_info:
label_info['is_crowd'] = np.take(
label_info['is_crowd'], valid_ids, axis=0)
return (im, im_info, label_info)
return (im, im_info, label_info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册