You need to sign in or sign up before continuing.
未验证 提交 9a3bb366 编写于 作者: W wangguanzhong 提交者: GitHub

support no_labeling training on voc (#3668)

上级 ba7e185c
......@@ -98,3 +98,7 @@ TestDataset:
**Q:** 如何打印网络FLOPs?
**A:**`configs/runtime.yml`中设置`print_flops: true`,同时需要安装PaddleSlim(比如:pip install paddleslim),即可打印模型的FLOPs。
**Q:** 如何使用无标注框进行训练?
**A:**`configs/dataset/coco.py` 或者`configs/dataset/voc.py`中的TrainDataset下设置`allow_empty: true`, 此时允许数据集加载无标注框进行训练。该功能支持coco,voc数据格式,RCNN系列和YOLO系列模型验证能够正常训练。另外,如果无标注框数据过多,会影响模型收敛,在TrainDataset下可以设置`empty_ratio: 0.1`对无标注框数据进行随机采样,控制无标注框的数据量占总数据量的比例,默认值为1.,即使用全部无标注框
......@@ -38,7 +38,7 @@ class COCODataSet(DetDataset):
allow_empty (bool): whether to load empty entry. False as default
empty_ratio (float): the ratio of empty record number to total
record's, if empty_ratio is out of [0. ,1.), do not sample the
records. 1. as default
records and use all the empty entries. 1. as default
"""
def __init__(self,
......@@ -63,7 +63,8 @@ class COCODataSet(DetDataset):
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
return records
import random
sample_num = int(num * self.empty_ratio / (1 - self.empty_ratio))
sample_num = min(
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
records = random.sample(records, sample_num)
return records
......
......@@ -42,6 +42,10 @@ class VOCDataSet(DetDataset):
sample_num (int): number of samples to load, -1 means all.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
allow_empty (bool): whether to load empty entry. False as default
empty_ratio (float): the ratio of empty record number to total
record's, if empty_ratio is out of [0. ,1.), do not sample the
records and use all the empty entries. 1. as default
"""
def __init__(self,
......@@ -50,7 +54,9 @@ class VOCDataSet(DetDataset):
anno_path=None,
data_fields=['image'],
sample_num=-1,
label_list=None):
label_list=None,
allow_empty=False,
empty_ratio=1.):
super(VOCDataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
......@@ -58,6 +64,18 @@ class VOCDataSet(DetDataset):
data_fields=data_fields,
sample_num=sample_num)
self.label_list = label_list
self.allow_empty = allow_empty
self.empty_ratio = empty_ratio
def _sample_empty(self, records, num):
# if empty_ratio is out of [0. ,1.), do not sample the records
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
return records
import random
sample_num = min(
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
records = random.sample(records, sample_num)
return records
def parse_dataset(self, ):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
......@@ -66,6 +84,7 @@ class VOCDataSet(DetDataset):
# mapping category name to class id
# first_class:0, second_class:1, ...
records = []
empty_records = []
ct = 0
cname2cid = {}
if self.label_list:
......@@ -164,15 +183,19 @@ class VOCDataSet(DetDataset):
if k in self.data_fields:
voc_rec[k] = v
if len(objs) != 0:
if len(objs) == 0:
empty_records.append(voc_rec)
else:
records.append(voc_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert len(records) > 0, 'not found any voc record in %s' % (
self.anno_path)
assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
if len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs, self.cname2cid = records, cname2cid
def get_label_list(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册