提交 6fda0a8c 编写于 作者: F FlyingQianMM

fix bug in coco loader

上级 be77e22f
......@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection):
gt_score = np.ones((num_bbox, 1), dtype=np.float32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = None
gt_poly = [None] * num_bbox
for i, box in enumerate(bboxes):
catid = box['category_id']
......@@ -108,8 +108,6 @@ class CocoDetection(VOCDetection):
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
if 'segmentation' in box:
if gt_poly is None:
gt_poly = [None] * num_bbox
gt_poly[i] = box['segmentation']
im_info = {
......@@ -121,10 +119,9 @@ class CocoDetection(VOCDetection):
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_score': gt_score,
'gt_poly': gt_poly,
'difficult': difficult
}
if gt_poly is not None:
label_info['gt_poly'] = gt_poly
coco_rec = (im_info, label_info)
self.file_list.append([im_fname, coco_rec])
......
......@@ -106,16 +106,20 @@ class VOCDetection(Dataset):
ct = int(tree.find('id').text)
im_id = np.array([int(tree.find('id').text)])
pattern = re.compile('<object>', re.IGNORECASE)
obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
obj_tag = pattern.findall(
str(ET.tostringlist(tree.getroot())))[0][1:-1]
objs = tree.findall(obj_tag)
pattern = re.compile('<size>', re.IGNORECASE)
size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
size_tag = pattern.findall(
str(ET.tostringlist(tree.getroot())))[0][1:-1]
size_element = tree.find(size_tag)
pattern = re.compile('<width>', re.IGNORECASE)
width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
width_tag = pattern.findall(
str(ET.tostringlist(size_element)))[0][1:-1]
im_w = float(size_element.find(width_tag).text)
pattern = re.compile('<height>', re.IGNORECASE)
height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
height_tag = pattern.findall(
str(ET.tostringlist(size_element)))[0][1:-1]
im_h = float(size_element.find(height_tag).text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32)
......@@ -124,29 +128,36 @@ class VOCDetection(Dataset):
difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs):
pattern = re.compile('<name>', re.IGNORECASE)
name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
1:-1]
cname = obj.find(name_tag).text.strip()
gt_class[i][0] = cname2cid[cname]
pattern = re.compile('<difficult>', re.IGNORECASE)
diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
1:-1]
try:
_difficult = int(obj.find(diff_tag).text)
except Exception:
_difficult = 0
pattern = re.compile('<bndbox>', re.IGNORECASE)
box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:
-1]
box_element = obj.find(box_tag)
pattern = re.compile('<xmin>', re.IGNORECASE)
xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
xmin_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
x1 = float(box_element.find(xmin_tag).text)
pattern = re.compile('<ymin>', re.IGNORECASE)
ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
ymin_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
y1 = float(box_element.find(ymin_tag).text)
pattern = re.compile('<xmax>', re.IGNORECASE)
xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
xmax_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
x2 = float(box_element.find(xmax_tag).text)
pattern = re.compile('<ymax>', re.IGNORECASE)
ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
ymax_tag = pattern.findall(
str(ET.tostringlist(box_element)))[0][1:-1]
y2 = float(box_element.find(ymax_tag).text)
x1 = max(0, x1)
y1 = max(0, y1)
......@@ -176,6 +187,7 @@ class VOCDetection(Dataset):
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_score': gt_score,
'gt_poly': [],
'difficult': difficult
}
voc_rec = (im_info, label_info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册