You need to sign in or sign up before continuing.
提交 fcd73d26 编写于 作者: 文幕地方's avatar 文幕地方

fix bug in copypaste when point num > 8

上级 6cb99060
...@@ -32,6 +32,7 @@ class CopyPaste(object): ...@@ -32,6 +32,7 @@ class CopyPaste(object):
self.aug = IaaAugment(augmenter_args) self.aug = IaaAugment(augmenter_args)
def __call__(self, data): def __call__(self, data):
point_num = data['polys'].shape[1]
src_img = data['image'] src_img = data['image']
src_polys = data['polys'].tolist() src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist() src_ignores = data['ignore_tags'].tolist()
...@@ -57,6 +58,9 @@ class CopyPaste(object): ...@@ -57,6 +58,9 @@ class CopyPaste(object):
src_img, box = self.paste_img(src_img, box_img, src_polys) src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None: if box is not None:
box = box.tolist()
for _ in range(len(box), point_num):
box.append(box[-1])
src_polys.append(box) src_polys.append(box)
src_ignores.append(tag) src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import numpy as np import numpy as np
import os import os
import random import random
import traceback
from paddle.io import Dataset from paddle.io import Dataset
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset): ...@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset):
img = f.read() img = f.read()
data['image'] = img data['image'] = img
data = transform(data, load_data_ops) data = transform(data, load_data_ops)
if data is None:
if data is None or data['polys'].shape[1]!=4:
continue continue
ext_data.append(data) ext_data.append(data)
return ext_data return ext_data
...@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset): ...@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset):
data['image'] = img data['image'] = img
data['ext_data'] = self.get_ext_data() data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops) outs = transform(data, self.ops)
except Exception as e: except:
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(
data_line, e)) data_line, traceback.format_exc()))
outs = None outs = None
if outs is None: if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation. # during evaluation, we should fix the idx to get same results for many times of evaluation.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册