未验证 提交 c2cdf8b2 编写于 作者: J Jason 提交者: GitHub

Merge pull request #96 from SunAhong1993/syf_docs

Fix the VOC dataset: look up annotation XML tags case-insensitively and default a missing "difficult" value to 0.
......@@ -17,6 +17,7 @@ import copy
import os
import os.path as osp
import random
import re
import numpy as np
from collections import OrderedDict
import xml.etree.ElementTree as ET
......@@ -104,23 +105,49 @@ class VOCDetection(Dataset):
else:
ct = int(tree.find('id').text)
im_id = np.array([int(tree.find('id').text)])
objs = tree.findall('object')
im_w = float(tree.find('size').find('width').text)
im_h = float(tree.find('size').find('height').text)
pattern = re.compile('<object>', re.IGNORECASE)
obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
objs = tree.findall(obj_tag)
pattern = re.compile('<size>', re.IGNORECASE)
size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
size_element = tree.find(size_tag)
pattern = re.compile('<width>', re.IGNORECASE)
width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
im_w = float(size_element.find(width_tag).text)
pattern = re.compile('<height>', re.IGNORECASE)
height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
im_h = float(size_element.find(height_tag).text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32)
gt_score = np.ones((len(objs), 1), dtype=np.float32)
is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs):
cname = obj.find('name').text.strip()
pattern = re.compile('<name>', re.IGNORECASE)
name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
cname = obj.find(name_tag).text.strip()
gt_class[i][0] = cname2cid[cname]
_difficult = int(obj.find('difficult').text)
x1 = float(obj.find('bndbox').find('xmin').text)
y1 = float(obj.find('bndbox').find('ymin').text)
x2 = float(obj.find('bndbox').find('xmax').text)
y2 = float(obj.find('bndbox').find('ymax').text)
pattern = re.compile('<difficult>', re.IGNORECASE)
diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
try:
_difficult = int(obj.find(diff_tag).text)
except Exception:
_difficult = 0
pattern = re.compile('<bndbox>', re.IGNORECASE)
box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
box_element = obj.find(box_tag)
pattern = re.compile('<xmin>', re.IGNORECASE)
xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
x1 = float(box_element.find(xmin_tag).text)
pattern = re.compile('<ymin>', re.IGNORECASE)
ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
y1 = float(box_element.find(ymin_tag).text)
pattern = re.compile('<xmax>', re.IGNORECASE)
xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
x2 = float(box_element.find(xmax_tag).text)
pattern = re.compile('<ymax>', re.IGNORECASE)
ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
y2 = float(box_element.find(ymax_tag).text)
x1 = max(0, x1)
y1 = max(0, y1)
if im_w > 0.5 and im_h > 0.5:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册