提交 6fda0a8c，作者：FlyingQianMM

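Fix bug in COCO loader: always populate `gt_poly` in `label_info` (and add an empty `gt_poly` to VOC records) so both loaders emit the same keys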

上级 be77e22f
...@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection): ...@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection):
gt_score = np.ones((num_bbox, 1), dtype=np.float32) gt_score = np.ones((num_bbox, 1), dtype=np.float32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32) difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = None gt_poly = [None] * num_bbox
for i, box in enumerate(bboxes): for i, box in enumerate(bboxes):
catid = box['category_id'] catid = box['category_id']
...@@ -108,8 +108,6 @@ class CocoDetection(VOCDetection): ...@@ -108,8 +108,6 @@ class CocoDetection(VOCDetection):
gt_bbox[i, :] = box['clean_bbox'] gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd'] is_crowd[i][0] = box['iscrowd']
if 'segmentation' in box: if 'segmentation' in box:
if gt_poly is None:
gt_poly = [None] * num_bbox
gt_poly[i] = box['segmentation'] gt_poly[i] = box['segmentation']
im_info = { im_info = {
...@@ -121,10 +119,9 @@ class CocoDetection(VOCDetection): ...@@ -121,10 +119,9 @@ class CocoDetection(VOCDetection):
'gt_class': gt_class, 'gt_class': gt_class,
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_score': gt_score, 'gt_score': gt_score,
'gt_poly': gt_poly,
'difficult': difficult 'difficult': difficult
} }
if gt_poly is not None:
label_info['gt_poly'] = gt_poly
coco_rec = (im_info, label_info) coco_rec = (im_info, label_info)
self.file_list.append([im_fname, coco_rec]) self.file_list.append([im_fname, coco_rec])
......
...@@ -106,16 +106,20 @@ class VOCDetection(Dataset): ...@@ -106,16 +106,20 @@ class VOCDetection(Dataset):
ct = int(tree.find('id').text) ct = int(tree.find('id').text)
im_id = np.array([int(tree.find('id').text)]) im_id = np.array([int(tree.find('id').text)])
pattern = re.compile('<object>', re.IGNORECASE) pattern = re.compile('<object>', re.IGNORECASE)
obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1] obj_tag = pattern.findall(
str(ET.tostringlist(tree.getroot())))[0][1:-1]
objs = tree.findall(obj_tag) objs = tree.findall(obj_tag)
pattern = re.compile('<size>', re.IGNORECASE) pattern = re.compile('<size>', re.IGNORECASE)
size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1] size_tag = pattern.findall(
str(ET.tostringlist(tree.getroot())))[0][1:-1]
size_element = tree.find(size_tag) size_element = tree.find(size_tag)
pattern = re.compile('<width>', re.IGNORECASE) pattern = re.compile('<width>', re.IGNORECASE)
width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1] width_tag = pattern.findall(
str(ET.tostringlist(size_element)))[0][1:-1]
im_w = float(size_element.find(width_tag).text) im_w = float(size_element.find(width_tag).text)
pattern = re.compile('<height>', re.IGNORECASE) pattern = re.compile('<height>', re.IGNORECASE)
height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1] height_tag = pattern.findall(
str(ET.tostringlist(size_element)))[0][1:-1]
im_h = float(size_element.find(height_tag).text) im_h = float(size_element.find(height_tag).text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32) gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32) gt_class = np.zeros((len(objs), 1), dtype=np.int32)
...@@ -124,29 +128,36 @@ class VOCDetection(Dataset): ...@@ -124,29 +128,36 @@ class VOCDetection(Dataset):
difficult = np.zeros((len(objs), 1), dtype=np.int32) difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs): for i, obj in enumerate(objs):
pattern = re.compile('<name>', re.IGNORECASE) pattern = re.compile('<name>', re.IGNORECASE)
name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1] name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
1:-1]
cname = obj.find(name_tag).text.strip() cname = obj.find(name_tag).text.strip()
gt_class[i][0] = cname2cid[cname] gt_class[i][0] = cname2cid[cname]
pattern = re.compile('<difficult>', re.IGNORECASE) pattern = re.compile('<difficult>', re.IGNORECASE)
diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1] diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
1:-1]
try: try:
_difficult = int(obj.find(diff_tag).text) _difficult = int(obj.find(diff_tag).text)
except Exception: except Exception:
_difficult = 0 _difficult = 0
pattern = re.compile('<bndbox>', re.IGNORECASE) pattern = re.compile('<bndbox>', re.IGNORECASE)
box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1] box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:
-1]
box_element = obj.find(box_tag) box_element = obj.find(box_tag)
pattern = re.compile('<xmin>', re.IGNORECASE) pattern = re.compile('<xmin>', re.IGNORECASE)
xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1] xmin_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
x1 = float(box_element.find(xmin_tag).text) x1 = float(box_element.find(xmin_tag).text)
pattern = re.compile('<ymin>', re.IGNORECASE) pattern = re.compile('<ymin>', re.IGNORECASE)
ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1] ymin_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
y1 = float(box_element.find(ymin_tag).text) y1 = float(box_element.find(ymin_tag).text)
pattern = re.compile('<xmax>', re.IGNORECASE) pattern = re.compile('<xmax>', re.IGNORECASE)
xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1] xmax_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
x2 = float(box_element.find(xmax_tag).text) x2 = float(box_element.find(xmax_tag).text)
pattern = re.compile('<ymax>', re.IGNORECASE) pattern = re.compile('<ymax>', re.IGNORECASE)
ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1] ymax_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
y2 = float(box_element.find(ymax_tag).text) y2 = float(box_element.find(ymax_tag).text)
x1 = max(0, x1) x1 = max(0, x1)
y1 = max(0, y1) y1 = max(0, y1)
...@@ -176,6 +187,7 @@ class VOCDetection(Dataset): ...@@ -176,6 +187,7 @@ class VOCDetection(Dataset):
'gt_class': gt_class, 'gt_class': gt_class,
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_score': gt_score, 'gt_score': gt_score,
'gt_poly': [],
'difficult': difficult 'difficult': difficult
} }
voc_rec = (im_info, label_info) voc_rec = (im_info, label_info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册