未验证 提交 91bc8c5b 编写于 作者: K Kaipeng Deng 提交者: GitHub

add param with_background in dataset voc (#2579)

上级 c5a2c0d2
...@@ -123,10 +123,12 @@ def load(fname, ...@@ -123,10 +123,12 @@ def load(fname,
elif os.path.isfile(fname): elif os.path.isfile(fname):
from . import voc_loader from . import voc_loader
if use_default_label is None or cname2cid is not None: if use_default_label is None or cname2cid is not None:
records, cname2cid = voc_loader.get_roidb(fname, samples, cname2cid) records, cname2cid = voc_loader.get_roidb(fname, samples, cname2cid,
with_background=with_background)
else: else:
records, cname2cid = voc_loader.load(fname, samples, records, cname2cid = voc_loader.load(fname, samples,
use_default_label) use_default_label,
with_background=with_background)
else: else:
raise ValueError('invalid file type when load data from file[%s]' % raise ValueError('invalid file type when load data from file[%s]' %
(fname)) (fname))
......
...@@ -18,7 +18,10 @@ import numpy as np ...@@ -18,7 +18,10 @@ import numpy as np
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
def get_roidb(anno_path, sample_num=-1, cname2cid=None): def get_roidb(anno_path,
sample_num=-1,
cname2cid=None,
with_background=True):
""" """
Load VOC records with annotations in xml directory 'anno_path' Load VOC records with annotations in xml directory 'anno_path'
...@@ -30,6 +33,9 @@ def get_roidb(anno_path, sample_num=-1, cname2cid=None): ...@@ -30,6 +33,9 @@ def get_roidb(anno_path, sample_num=-1, cname2cid=None):
anno_path (str): root directory for voc annotation data anno_path (str): root directory for voc annotation data
sample_num (int): number of samples to load, -1 means all sample_num (int): number of samples to load, -1 means all
cname2cid (dict): the label name to id dictionary cname2cid (dict): the label name to id dictionary
with_background (bool): whether load background as a class.
if True, total class number will
be 81. default True
Returns: Returns:
(records, catname2clsid) (records, catname2clsid)
...@@ -89,7 +95,7 @@ def get_roidb(anno_path, sample_num=-1, cname2cid=None): ...@@ -89,7 +95,7 @@ def get_roidb(anno_path, sample_num=-1, cname2cid=None):
cname = obj.find('name').text cname = obj.find('name').text
if not existence and cname not in cname2cid: if not existence and cname not in cname2cid:
# the background's id is 0, so need to add 1. # the background's id is 0, so need to add 1.
cname2cid[cname] = len(cname2cid) + 1 cname2cid[cname] = len(cname2cid) + int(with_background)
elif existence and cname not in cname2cid: elif existence and cname not in cname2cid:
raise KeyError( raise KeyError(
'Not found cname[%s] in cname2cid when map it to cid.' % 'Not found cname[%s] in cname2cid when map it to cid.' %
...@@ -129,7 +135,10 @@ def get_roidb(anno_path, sample_num=-1, cname2cid=None): ...@@ -129,7 +135,10 @@ def get_roidb(anno_path, sample_num=-1, cname2cid=None):
return [records, cname2cid] return [records, cname2cid]
def load(anno_path, sample_num=-1, use_default_label=True): def load(anno_path,
sample_num=-1,
use_default_label=True,
with_background=True):
""" """
Load VOC records with annotations in Load VOC records with annotations in
xml directory 'anno_path' xml directory 'anno_path'
...@@ -142,6 +151,9 @@ def load(anno_path, sample_num=-1, use_default_label=True): ...@@ -142,6 +151,9 @@ def load(anno_path, sample_num=-1, use_default_label=True):
@anno_path (str): root directory for voc annotation data @anno_path (str): root directory for voc annotation data
@sample_num (int): number of samples to load, -1 means all @sample_num (int): number of samples to load, -1 means all
@use_default_label (bool): whether use the default mapping of label to id @use_default_label (bool): whether use the default mapping of label to id
@with_background (bool): whether load background as a class.
if True, total class number will
be 81. default True
Returns: Returns:
(records, catname2clsid) (records, catname2clsid)
...@@ -165,21 +177,24 @@ def load(anno_path, sample_num=-1, use_default_label=True): ...@@ -165,21 +177,24 @@ def load(anno_path, sample_num=-1, use_default_label=True):
assert os.path.isfile(txt_file) and \ assert os.path.isfile(txt_file) and \
os.path.isdir(xml_path), 'invalid xml path' os.path.isdir(xml_path), 'invalid xml path'
# mapping category name to class id
# if with_background is True:
# background:0, first_class:1, second_class:2, ...
# if with_background is False:
# first_class:0, second_class:1, ...
records = [] records = []
ct = 0 ct = 0
cname2cid = {} cname2cid = {}
if not use_default_label: if not use_default_label:
label_path = os.path.join(part[0], 'ImageSets/Main/label_list.txt') label_path = os.path.join(part[0], 'ImageSets/Main/label_list.txt')
with open(label_path, 'r') as fr: with open(label_path, 'r') as fr:
label_id = 1 label_id = int(with_background)
for line in fr.readlines(): for line in fr.readlines():
cname2cid[line.strip()] = label_id cname2cid[line.strip()] = label_id
label_id += 1 label_id += 1
else: else:
cname2cid = pascalvoc_label() cname2cid = pascalvoc_label(with_background)
# mapping category name to class id
# background:0, first_class:1, second_class:2, ...
with open(txt_file, 'r') as fr: with open(txt_file, 'r') as fr:
while True: while True:
line = fr.readline() line = fr.readline()
...@@ -241,7 +256,7 @@ def load(anno_path, sample_num=-1, use_default_label=True): ...@@ -241,7 +256,7 @@ def load(anno_path, sample_num=-1, use_default_label=True):
return [records, cname2cid] return [records, cname2cid]
def pascalvoc_label(): def pascalvoc_label(with_background=True):
labels_map = { labels_map = {
'aeroplane': 1, 'aeroplane': 1,
'bicycle': 2, 'bicycle': 2,
...@@ -264,4 +279,6 @@ def pascalvoc_label(): ...@@ -264,4 +279,6 @@ def pascalvoc_label():
'train': 19, 'train': 19,
'tvmonitor': 20 'tvmonitor': 20
} }
if not with_background:
labels_map = {k: v - 1 for k, v in labels_map.items()}
return labels_map return labels_map
...@@ -80,7 +80,7 @@ def vocall_category_info(with_background=True): ...@@ -80,7 +80,7 @@ def vocall_category_info(with_background=True):
with_background (bool, default True): with_background (bool, default True):
whether load background as class 0. whether load background as class 0.
""" """
label_map = pascalvoc_label() label_map = pascalvoc_label(with_background)
label_map = sorted(label_map.items(), key=lambda x: x[1]) label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map] cats = [l[0] for l in label_map]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册