未验证 提交 c2cdf8b2 编写于 作者: J Jason 提交者: GitHub

Merge pull request #96 from SunAhong1993/syf_docs

fix the voc dataset
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import os import os
import os.path as osp import os.path as osp
import random import random
import re
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
...@@ -104,23 +105,49 @@ class VOCDetection(Dataset): ...@@ -104,23 +105,49 @@ class VOCDetection(Dataset):
else: else:
ct = int(tree.find('id').text) ct = int(tree.find('id').text)
im_id = np.array([int(tree.find('id').text)]) im_id = np.array([int(tree.find('id').text)])
pattern = re.compile('<object>', re.IGNORECASE)
objs = tree.findall('object') obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
im_w = float(tree.find('size').find('width').text) objs = tree.findall(obj_tag)
im_h = float(tree.find('size').find('height').text) pattern = re.compile('<size>', re.IGNORECASE)
size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
size_element = tree.find(size_tag)
pattern = re.compile('<width>', re.IGNORECASE)
width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
im_w = float(size_element.find(width_tag).text)
pattern = re.compile('<height>', re.IGNORECASE)
height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
im_h = float(size_element.find(height_tag).text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32) gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32) gt_class = np.zeros((len(objs), 1), dtype=np.int32)
gt_score = np.ones((len(objs), 1), dtype=np.float32) gt_score = np.ones((len(objs), 1), dtype=np.float32)
is_crowd = np.zeros((len(objs), 1), dtype=np.int32) is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
difficult = np.zeros((len(objs), 1), dtype=np.int32) difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs): for i, obj in enumerate(objs):
cname = obj.find('name').text.strip() pattern = re.compile('<name>', re.IGNORECASE)
name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
cname = obj.find(name_tag).text.strip()
gt_class[i][0] = cname2cid[cname] gt_class[i][0] = cname2cid[cname]
_difficult = int(obj.find('difficult').text) pattern = re.compile('<difficult>', re.IGNORECASE)
x1 = float(obj.find('bndbox').find('xmin').text) diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
y1 = float(obj.find('bndbox').find('ymin').text) try:
x2 = float(obj.find('bndbox').find('xmax').text) _difficult = int(obj.find(diff_tag).text)
y2 = float(obj.find('bndbox').find('ymax').text) except Exception:
_difficult = 0
pattern = re.compile('<bndbox>', re.IGNORECASE)
box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
box_element = obj.find(box_tag)
pattern = re.compile('<xmin>', re.IGNORECASE)
xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
x1 = float(box_element.find(xmin_tag).text)
pattern = re.compile('<ymin>', re.IGNORECASE)
ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
y1 = float(box_element.find(ymin_tag).text)
pattern = re.compile('<xmax>', re.IGNORECASE)
xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
x2 = float(box_element.find(xmax_tag).text)
pattern = re.compile('<ymax>', re.IGNORECASE)
ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
y2 = float(box_element.find(ymax_tag).text)
x1 = max(0, x1) x1 = max(0, x1)
y1 = max(0, y1) y1 = max(0, y1)
if im_w > 0.5 and im_h > 0.5: if im_w > 0.5 and im_h > 0.5:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册