提交 3db718af 编写于 作者: S sunyanfang01

fix the voc dataset

上级 02b1a99e
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
import copy import copy
import os.path as osp import os.path as osp
import random import random
import re
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
...@@ -103,23 +104,46 @@ class VOCDetection(Dataset): ...@@ -103,23 +104,46 @@ class VOCDetection(Dataset):
else: else:
ct = int(tree.find('id').text) ct = int(tree.find('id').text)
im_id = np.array([int(tree.find('id').text)]) im_id = np.array([int(tree.find('id').text)])
pattern = re.compile('<object>', re.IGNORECASE)
objs = tree.findall('object') obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
im_w = float(tree.find('size').find('width').text) objs = tree.findall(obj_tag)
im_h = float(tree.find('size').find('height').text) pattern = re.compile('<size>', re.IGNORECASE)
size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
size_element = tree.find(size_tag)
pattern = re.compile('<width>', re.IGNORECASE)
width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
im_w = float(size_element.find(width_tag).text)
pattern = re.compile('<height>', re.IGNORECASE)
height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
im_h = float(size_element.find(height_tag).text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32) gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32) gt_class = np.zeros((len(objs), 1), dtype=np.int32)
gt_score = np.ones((len(objs), 1), dtype=np.float32) gt_score = np.ones((len(objs), 1), dtype=np.float32)
is_crowd = np.zeros((len(objs), 1), dtype=np.int32) is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
difficult = np.zeros((len(objs), 1), dtype=np.int32) difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs): for i, obj in enumerate(objs):
cname = obj.find('name').text pattern = re.compile('<name>', re.IGNORECASE)
name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
cname = obj.find(name_tag).text
gt_class[i][0] = cname2cid[cname] gt_class[i][0] = cname2cid[cname]
_difficult = int(obj.find('difficult').text) pattern = re.compile('<difficult>', re.IGNORECASE)
x1 = float(obj.find('bndbox').find('xmin').text) diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
y1 = float(obj.find('bndbox').find('ymin').text) _difficult = int(obj.find(diff_tag).text)
x2 = float(obj.find('bndbox').find('xmax').text) pattern = re.compile('<bndbox>', re.IGNORECASE)
y2 = float(obj.find('bndbox').find('ymax').text) box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
box_element = obj.find(box_tag)
pattern = re.compile('<xmin>', re.IGNORECASE)
xmin_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
x1 = float(box_element.find(xmin_element).text)
pattern = re.compile('<ymin>', re.IGNORECASE)
ymin_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
y1 = float(box_element.find(ymin_element).text)
pattern = re.compile('<xmax>', re.IGNORECASE)
xmax_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
x2 = float(box_element.find(xmax_element).text)
pattern = re.compile('<ymax>', re.IGNORECASE)
ymax_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
y2 = float(box_element.find(ymax_element).text)
x1 = max(0, x1) x1 = max(0, x1)
y1 = max(0, y1) y1 = max(0, y1)
if im_w > 0.5 and im_h > 0.5: if im_w > 0.5 and im_h > 0.5:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册