提交 41464e18 作者: Megvii Engine Team 提交者: Xinran Xu

feat(mge/data): voc dataset supports detection

GitOrigin-RevId: f78bef3cd2ba895c2ec0ca1e0132737cf5a5cbd5
上级 26738d99
......
@@ -29,43 +29,47 @@ class PascalVOC(VisionDataset):
supported_order = (
"image",
# "boxes",
# "boxes_category",
"boxes",
"boxes_category",
"mask",
"info",
)
def __init__(self, root, image_set, *, order=None):
if ("boxes" in order or "boxes_category" in order) and "mask" in order:
raise ValueError("PascalVOC only supports boxes & boxes_category or mask, not both.")
super().__init__(root, order=order, supported_order=self.supported_order)
voc_root = self.root
if not os.path.isdir(voc_root):
if not os.path.isdir(self.root):
raise RuntimeError("Dataset not found or corrupted.")
self.image_set = image_set
image_dir = os.path.join(voc_root, "JPEGImages")
# for segmentation
if "aug" in image_set:
mask_dir = os.path.join(voc_root, "SegmentationClass_aug")
image_dir = os.path.join(self.root, "JPEGImages")
if "boxes" in order or "boxes_category" in order:
annotation_dir = os.path.join(self.root, 'Annotations')
splitdet_dir = os.path.join(self.root, "ImageSets/Main")
split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as f:
self.file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in self.file_names]
assert len(self.images) == len(self.annotations)
elif "mask" in order:
if "aug" in image_set:
mask_dir = os.path.join(self.root, "SegmentationClass_aug")
else:
mask_dir = os.path.join(self.root, "SegmentationClass")
splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation")
split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as f:
self.file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
assert len(self.images) == len(self.masks)
else:
mask_dir = os.path.join(voc_root, "SegmentationClass")
splitmask_dir = os.path.join(voc_root, "ImageSets/Segmentation")
split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as f:
self.file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
# TODO: for detection
# splitdet_dir = os.path.join(voc_root, "ImageSets/Main")
# split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
# with open(os.path.join(split_f), "r") as f:
# self.file_names = [x.strip() for x in f.readlines()]
# self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
# self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in self.file_names]
# assert (len(self.images) == len(self.masks)) and (len(self.images) == len(self.annotations))
raise NotImplementedError
def __getitem__(self, index):
target = []
......
@@ -73,6 +77,19 @@ class PascalVOC(VisionDataset):
if k == "image":
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
target.append(image)
elif k == "boxes":
anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]]
# boxes type xyxy
boxes = [(bb['xmin'], bb['ymin'], bb['xmax'], bb['ymax']) for bb in boxes]
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
target.append(boxes)
elif k == "boxes_category":
anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
boxes_category = [obj["name"] for obj in anno["annotation"]["object"]]
boxes_category = [self.class_names.index(bc)-1 for bc in boxes_category]
boxes_category = np.array(boxes_category, dtype=np.int32)
target.append(boxes_category)
elif k == "mask":
if "aug" in self.image_set:
mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
......
@@ -81,9 +98,6 @@ class PascalVOC(VisionDataset):
mask = self._trans_mask(mask)
mask = mask[:, :, np.newaxis]
target.append(mask)
elif k == "boxes":
boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
target.append(boxes)
elif k == "info":
if image is None:
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
......
@@ -128,7 +142,6 @@ class PascalVOC(VisionDataset):
return voc_dict
class_names = (
"background",
"aeroplane",
"bicycle",
"bird",
......
@@ -151,7 +164,6 @@ class PascalVOC(VisionDataset):
"tvmonitor",
)
class_colors = [
[0, 0, 0],
[0, 0, 128],
[0, 128, 0],
[0, 128, 128],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册