提交 a8d23e8e 编写于 作者: S sunyanfang01

modify imgaug support

上级 48598f27
...@@ -95,8 +95,8 @@ class VOCDetection(Dataset): ...@@ -95,8 +95,8 @@ class VOCDetection(Dataset):
if not osp.isfile(xml_file): if not osp.isfile(xml_file):
continue continue
if not osp.exists(img_file): if not osp.exists(img_file):
raise IOError( raise IOError('The image file {} is not exist!'.format(
'The image file {} is not exist!'.format(img_file)) img_file))
tree = ET.parse(xml_file) tree = ET.parse(xml_file)
if tree.find('id') is None: if tree.find('id') is None:
im_id = np.array([ct]) im_id = np.array([ct])
...@@ -122,25 +122,20 @@ class VOCDetection(Dataset): ...@@ -122,25 +122,20 @@ class VOCDetection(Dataset):
y2 = float(obj.find('bndbox').find('ymax').text) y2 = float(obj.find('bndbox').find('ymax').text)
x1 = max(0, x1) x1 = max(0, x1)
y1 = max(0, y1) y1 = max(0, y1)
if im_w > 0.5 and im_h > 0.5:
x2 = min(im_w - 1, x2) x2 = min(im_w - 1, x2)
y2 = min(im_h - 1, y2) y2 = min(im_h - 1, y2)
gt_bbox[i] = [x1, y1, x2, y2] gt_bbox[i] = [x1, y1, x2, y2]
is_crowd[i][0] = 0 is_crowd[i][0] = 0
difficult[i][0] = _difficult difficult[i][0] = _difficult
annotations['annotations'].append({ annotations['annotations'].append({
'iscrowd': 'iscrowd': 0,
0, 'image_id': int(im_id[0]),
'image_id':
int(im_id[0]),
'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1], 'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
'area': 'area': float((x2 - x1 + 1) * (y2 - y1 + 1)),
float((x2 - x1 + 1) * (y2 - y1 + 1)), 'category_id': cname2cid[cname],
'category_id': 'id': ann_ct,
cname2cid[cname], 'difficult': _difficult
'id':
ann_ct,
'difficult':
_difficult
}) })
ann_ct += 1 ann_ct += 1
...@@ -160,14 +155,10 @@ class VOCDetection(Dataset): ...@@ -160,14 +155,10 @@ class VOCDetection(Dataset):
self.file_list.append([img_file, voc_rec]) self.file_list.append([img_file, voc_rec])
ct += 1 ct += 1
annotations['images'].append({ annotations['images'].append({
'height': 'height': im_h,
im_h, 'width': im_w,
'width': 'id': int(im_id[0]),
im_w, 'file_name': osp.split(img_file)[1]
'id':
int(im_id[0]),
'file_name':
osp.split(img_file)[1]
}) })
if not len(self.file_list) > 0: if not len(self.file_list) > 0:
...@@ -198,8 +189,7 @@ class VOCDetection(Dataset): ...@@ -198,8 +189,7 @@ class VOCDetection(Dataset):
else: else:
mix_pos = 0 mix_pos = 0
im_info['mixup'] = [ im_info['mixup'] = [
files[mix_pos][0], files[mix_pos][0], copy.deepcopy(files[mix_pos][1][0]),
copy.deepcopy(files[mix_pos][1][0]),
copy.deepcopy(files[mix_pos][1][1]) copy.deepcopy(files[mix_pos][1][1])
] ]
self._pos += 1 self._pos += 1
......
...@@ -111,8 +111,8 @@ class Compose(DetTransform): ...@@ -111,8 +111,8 @@ class Compose(DetTransform):
try: try:
im = cv2.imread(im_file).astype('float32') im = cv2.imread(im_file).astype('float32')
except: except:
raise TypeError( raise TypeError('Can\'t read The image file {}!'.format(
'Can\'t read The image file {}!'.format(im_file)) im_file))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# make default im_info with [h, w, 1] # make default im_info with [h, w, 1]
im_info['im_resize_info'] = np.array( im_info['im_resize_info'] = np.array(
...@@ -145,19 +145,10 @@ class Compose(DetTransform): ...@@ -145,19 +145,10 @@ class Compose(DetTransform):
outputs = op(im, im_info, label_info) outputs = op(im, im_info, label_info)
im = outputs[0] im = outputs[0]
else: else:
im = execute_imgaug(op, im)
if label_info is not None: if label_info is not None:
gt_poly = label_info.get('gt_poly', None)
gt_bbox = label_info['gt_bbox']
if gt_poly is None:
im, aug_bbox = execute_imgaug(op, im, bboxes=gt_bbox)
else:
im, aug_bbox, aug_poly = execute_imgaug(
op, im, bboxes=gt_bbox, polygons=gt_poly)
label_info['gt_poly'] = aug_poly
label_info['gt_bbox'] = aug_bbox
outputs = (im, im_info, label_info) outputs = (im, im_info, label_info)
else: else:
im, = execute_imgaug(op, im)
outputs = (im, im_info) outputs = (im, im_info)
return outputs return outputs
...@@ -218,8 +209,8 @@ class ResizeByShort(DetTransform): ...@@ -218,8 +209,8 @@ class ResizeByShort(DetTransform):
im_short_size = min(im.shape[0], im.shape[1]) im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1]) im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size scale = float(self.short_size) / im_short_size
if self.max_size > 0 and np.round( if self.max_size > 0 and np.round(scale *
scale * im_long_size) > self.max_size: im_long_size) > self.max_size:
scale = float(self.max_size) / float(im_long_size) scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale)) resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale)) resized_height = int(round(im.shape[0] * scale))
...@@ -302,8 +293,8 @@ class Padding(DetTransform): ...@@ -302,8 +293,8 @@ class Padding(DetTransform):
if isinstance(self.target_size, int): if isinstance(self.target_size, int):
padding_im_h = self.target_size padding_im_h = self.target_size
padding_im_w = self.target_size padding_im_w = self.target_size
elif isinstance(self.target_size, list) or isinstance( elif isinstance(self.target_size, list) or isinstance(self.target_size,
self.target_size, tuple): tuple):
padding_im_w = self.target_size[0] padding_im_w = self.target_size[0]
padding_im_h = self.target_size[1] padding_im_h = self.target_size[1]
elif self.coarsest_stride > 0: elif self.coarsest_stride > 0:
...@@ -321,8 +312,8 @@ class Padding(DetTransform): ...@@ -321,8 +312,8 @@ class Padding(DetTransform):
raise ValueError( raise ValueError(
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_w, im_h, padding_im_w, padding_im_h)) .format(im_w, im_h, padding_im_w, padding_im_h))
padding_im = np.zeros((padding_im_h, padding_im_w, im_c), padding_im = np.zeros(
dtype=np.float32) (padding_im_h, padding_im_w, im_c), dtype=np.float32)
padding_im[:im_h, :im_w, :] = im padding_im[:im_h, :im_w, :] = im
if label_info is None: if label_info is None:
return (padding_im, im_info) return (padding_im, im_info)
...@@ -932,8 +923,9 @@ class RandomCrop(DetTransform): ...@@ -932,8 +923,9 @@ class RandomCrop(DetTransform):
crop_y = np.random.randint(0, h - crop_h) crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w) crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = iou_matrix(gt_bbox, np.array([crop_box], iou = iou_matrix(
dtype=np.float32)) gt_bbox, np.array(
[crop_box], dtype=np.float32))
if iou.max() < thresh: if iou.max() < thresh:
continue continue
...@@ -941,16 +933,21 @@ class RandomCrop(DetTransform): ...@@ -941,16 +933,21 @@ class RandomCrop(DetTransform):
continue continue
cropped_box, valid_ids = crop_box_with_center_constraint( cropped_box, valid_ids = crop_box_with_center_constraint(
gt_bbox, np.array(crop_box, dtype=np.float32)) gt_bbox, np.array(
crop_box, dtype=np.float32))
if valid_ids.size > 0: if valid_ids.size > 0:
found = True found = True
break break
if found: if found:
if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0: if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
crop_polys = crop_segms(label_info['gt_poly'], valid_ids, crop_polys = crop_segms(
np.array(crop_box, dtype=np.int64), label_info['gt_poly'],
h, w) valid_ids,
np.array(
crop_box, dtype=np.int64),
h,
w)
if [] in crop_polys: if [] in crop_polys:
delete_id = list() delete_id = list()
valid_polys = list() valid_polys = list()
......
...@@ -13,36 +13,41 @@ ...@@ -13,36 +13,41 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import copy
def execute_imgaug(augmenter, im, bboxes=None, polygons=None, def execute_imgaug(augmenter, im, bboxes=None, polygons=None,
segment_map=None): segment_map=None):
# 预处理,将bboxes, polygons转换成imgaug格式 # 预处理,将bboxes, polygons转换成imgaug格式
import imgaug.augmentables.polys as polys import imgaug.augmentables.kps as kps
import imgaug.augmentables.bbs as bbs import imgaug.augmentables.bbs as bbs
aug_im = im.astype('uint8') aug_im = im.astype('uint8')
aug_im = augmenter.augment(image=aug_im)
return aug_im
# TODO imgaug的标注处理逻辑与paddlex已存的transform存在部分差异
# 目前仅支持对原图进行处理,因此只能使用pixlevel的imgaug增强操作
# 以下代码暂不会执行
aug_bboxes = None aug_bboxes = None
if bboxes is not None: if bboxes is not None:
aug_bboxes = list() aug_bboxes = list()
for i in range(len(bboxes)): for i in range(len(bboxes)):
x1 = bboxes[i, 0] - 1 x1 = bboxes[i, 0]
y1 = bboxes[i, 1] y1 = bboxes[i, 1]
x2 = bboxes[i, 2] x2 = bboxes[i, 2]
y2 = bboxes[i, 3] y2 = bboxes[i, 3]
aug_bboxes.append(bbs.BoundingBox(x1, y1, x2, y2)) aug_bboxes.append(bbs.BoundingBox(x1, y1, x2, y2))
aug_polygons = None aug_points = None
lod_info = list()
if polygons is not None: if polygons is not None:
aug_polygons = list() aug_points = list()
for i in range(len(polygons)): for i in range(len(polygons)):
num = len(polygons[i]) num = len(polygons[i])
lod_info.append(num)
for j in range(num): for j in range(num):
points = np.reshape(polygons[i][j], (-1, 2)) tmp = np.reshape(polygons[i][j], (-1, 2))
aug_polygons.append(polys.Polygon(points)) for k in range(len(tmp)):
aug_points.append(kps.Keypoint(tmp[k, 0], tmp[k, 1]))
aug_segment_map = None aug_segment_map = None
if segment_map is not None: if segment_map is not None:
...@@ -56,72 +61,47 @@ def execute_imgaug(augmenter, im, bboxes=None, polygons=None, ...@@ -56,72 +61,47 @@ def execute_imgaug(augmenter, im, bboxes=None, polygons=None,
raise Exception( raise Exception(
"Only support 2-dimensions for 3-dimensions for segment_map") "Only support 2-dimensions for 3-dimensions for segment_map")
aug_im, aug_bboxes, aug_polygons, aug_seg_map = augmenter.augment( unnormalized_batch = augmenter.augment(
image=aug_im, image=aug_im,
bounding_boxes=aug_bboxes, bounding_boxes=aug_bboxes,
polygons=aug_polygons, keypoints=aug_points,
segmentation_maps=aug_segment_map) segmentation_maps=aug_segment_map,
return_batch=True)
aug_im = unnormalized_batch.images_aug[0]
aug_bboxes = unnormalized_batch.bounding_boxes_aug
aug_points = unnormalized_batch.keypoints_aug
aug_seg_map = unnormalized_batch.segmentation_maps_aug
aug_im = aug_im.astype('float32') aug_im = aug_im.astype('float32')
if aug_polygons is not None:
assert len(aug_bboxes) == len(
lod_info
), "Number of aug_bboxes should be equal to number of aug_polygons"
if aug_bboxes is not None: if aug_bboxes is not None:
# 裁剪掉在图像之外的bbox和polygon
for i in range(len(aug_bboxes)):
aug_bboxes[i] = aug_bboxes[i].clip_out_of_image(aug_im)
if aug_polygons is not None:
for i in range(len(aug_polygons)):
aug_polygons[i] = aug_polygons[i].clip_out_of_image(aug_im)
# 过滤掉无效的bbox和polygon,并转换为训练数据格式
converted_bboxes = list() converted_bboxes = list()
converted_polygons = list()
poly_index = 0
for i in range(len(aug_bboxes)): for i in range(len(aug_bboxes)):
# 过滤width或height不足1像素的框
if aug_bboxes[i].width < 1 or aug_bboxes[i].height < 1:
continue
if aug_polygons is None:
converted_bboxes.append([ converted_bboxes.append([
aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2, aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
aug_bboxes[i].y2 aug_bboxes[i].y2
]) ])
continue aug_bboxes = converted_bboxes
# 如若有polygons,将会继续执行下面代码
polygons_this_box = list()
for ps in aug_polygons[poly_index:poly_index + lod_info[i]]:
if len(ps) == 0:
continue
for p in ps:
# 没有3个point的polygon被过滤
if len(p.exterior) < 3:
continue
polygons_this_box.append(p.exterior.flatten().tolist())
poly_index += lod_info[i]
if len(polygons_this_box) == 0: aug_polygons = None
continue if aug_points is not None:
converted_bboxes.append([ aug_polygons = copy.deepcopy(polygons)
aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2, idx = 0
aug_bboxes[i].y2 for i in range(len(aug_polygons)):
]) num = len(aug_polygons[i])
converted_polygons.append(polygons_this_box) for j in range(num):
if len(converted_bboxes) == 0: num_points = len(aug_polygons[i][j]) // 2
aug_im = im for k in range(num_points):
converted_bboxes = bboxes aug_polygons[i][j][k * 2] = aug_points[idx].x
converted_polygons = polygons aug_polygons[i][j][k * 2 + 1] = aug_points[idx].y
idx += 1
result = [aug_im] result = [aug_im]
if bboxes is not None: if aug_bboxes is not None:
result.append(np.array(converted_bboxes)) result.append(np.array(aug_bboxes))
if polygons is not None: if aug_polygons is not None:
result.append(converted_polygons) result.append(aug_polygons)
if segment_map is not None: if aug_seg_map is not None:
n, h, w, c = aug_seg_map.shape n, h, w, c = aug_seg_map.shape
if len(segment_map.shape) == 2: if len(segment_map.shape) == 2:
aug_seg_map = np.reshape(aug_seg_map, (h, w)) aug_seg_map = np.reshape(aug_seg_map, (h, w))
......
...@@ -101,11 +101,10 @@ class Compose(SegTransform): ...@@ -101,11 +101,10 @@ class Compose(SegTransform):
if len(outputs) == 3: if len(outputs) == 3:
label = outputs[2] label = outputs[2]
else: else:
im = execute_imgaug(op, im)
if label is not None: if label is not None:
im, label = execute_imgaug(op, im, segment_map=label)
outputs = (im, im_info, label) outputs = (im, im_info, label)
else: else:
im, = execute_imgaug(op, im)
outputs = (im, im_info) outputs = (im, im_info)
return outputs return outputs
...@@ -391,8 +390,8 @@ class ResizeByShort(SegTransform): ...@@ -391,8 +390,8 @@ class ResizeByShort(SegTransform):
im_short_size = min(im.shape[0], im.shape[1]) im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1]) im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size scale = float(self.short_size) / im_short_size
if self.max_size > 0 and np.round( if self.max_size > 0 and np.round(scale *
scale * im_long_size) > self.max_size: im_long_size) > self.max_size:
scale = float(self.max_size) / float(im_long_size) scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale)) resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale)) resized_height = int(round(im.shape[0] * scale))
...@@ -423,8 +422,8 @@ class ResizeRangeScaling(SegTransform): ...@@ -423,8 +422,8 @@ class ResizeRangeScaling(SegTransform):
def __init__(self, min_value=400, max_value=600): def __init__(self, min_value=400, max_value=600):
if min_value > max_value: if min_value > max_value:
raise ValueError('min_value must be less than max_value, ' raise ValueError('min_value must be less than max_value, '
'but they are {} and {}.'.format( 'but they are {} and {}.'.format(min_value,
min_value, max_value)) max_value))
self.min_value = min_value self.min_value = min_value
self.max_value = max_value self.max_value = max_value
...@@ -761,8 +760,8 @@ class RandomPaddingCrop(SegTransform): ...@@ -761,8 +760,8 @@ class RandomPaddingCrop(SegTransform):
h_off = np.random.randint(img_height - crop_height + 1) h_off = np.random.randint(img_height - crop_height + 1)
w_off = np.random.randint(img_width - crop_width + 1) w_off = np.random.randint(img_width - crop_width + 1)
im = im[h_off:(crop_height + h_off), w_off:( im = im[h_off:(crop_height + h_off), w_off:(w_off + crop_width
w_off + crop_width), :] ), :]
if label is not None: if label is not None:
label = label[h_off:(crop_height + h_off), w_off:( label = label[h_off:(crop_height + h_off), w_off:(
w_off + crop_width)] w_off + crop_width)]
......
...@@ -27,7 +27,7 @@ setuptools.setup( ...@@ -27,7 +27,7 @@ setuptools.setup(
long_description_content_type="text/plain", long_description_content_type="text/plain",
url="https://github.com/PaddlePaddle/PaddleX", url="https://github.com/PaddlePaddle/PaddleX",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
setup_requires=['cython', 'numpy', 'sklearn'], setup_requires=['cython', 'numpy'],
install_requires=[ install_requires=[
"pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm', "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
'visualdl==1.3.0', 'paddleslim==1.0.1', 'visualdl==2.0.0a2' 'visualdl==1.3.0', 'paddleslim==1.0.1', 'visualdl==2.0.0a2'
...@@ -38,6 +38,4 @@ setuptools.setup( ...@@ -38,6 +38,4 @@ setuptools.setup(
"Operating System :: OS Independent", "Operating System :: OS Independent",
], ],
license='Apache 2.0', license='Apache 2.0',
entry_points={'console_scripts': [ entry_points={'console_scripts': ['paddlex=paddlex.command:main', ]})
'paddlex=paddlex.command:main',
]})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册