提交 de91f9a0 编写于 作者: M MikoyChinese

Fix copy_paste no texts augment.

上级 473a8116
...@@ -35,10 +35,12 @@ class CopyPaste(object): ...@@ -35,10 +35,12 @@ class CopyPaste(object):
point_num = data['polys'].shape[1] point_num = data['polys'].shape[1]
src_img = data['image'] src_img = data['image']
src_polys = data['polys'].tolist() src_polys = data['polys'].tolist()
src_texts = data['texts']
src_ignores = data['ignore_tags'].tolist() src_ignores = data['ignore_tags'].tolist()
ext_data = data['ext_data'][0] ext_data = data['ext_data'][0]
ext_image = ext_data['image'] ext_image = ext_data['image']
ext_polys = ext_data['polys'] ext_polys = ext_data['polys']
ext_texts = ext_data['texts']
ext_ignores = ext_data['ignore_tags'] ext_ignores = ext_data['ignore_tags']
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
...@@ -53,7 +55,7 @@ class CopyPaste(object): ...@@ -53,7 +55,7 @@ class CopyPaste(object):
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
src_img = Image.fromarray(src_img).convert('RGBA') src_img = Image.fromarray(src_img).convert('RGBA')
for poly, tag in zip(select_polys, select_ignores): for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly) box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys) src_img, box = self.paste_img(src_img, box_img, src_polys)
...@@ -62,6 +64,7 @@ class CopyPaste(object): ...@@ -62,6 +64,7 @@ class CopyPaste(object):
for _ in range(len(box), point_num): for _ in range(len(box), point_num):
box.append(box[-1]) box.append(box[-1])
src_polys.append(box) src_polys.append(box)
src_texts.append(ext_texts[idx])
src_ignores.append(tag) src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
h, w = src_img.shape[:2] h, w = src_img.shape[:2]
...@@ -70,6 +73,7 @@ class CopyPaste(object): ...@@ -70,6 +73,7 @@ class CopyPaste(object):
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
data['image'] = src_img data['image'] = src_img
data['polys'] = src_polys data['polys'] = src_polys
data['texts'] = src_texts
data['ignore_tags'] = np.array(src_ignores) data['ignore_tags'] = np.array(src_ignores)
return data return data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册