You need to sign in or sign up before continuing.
提交 2ee6dd97 编写于 作者: Y Yuan Gao 提交者: qingqing01

add class aware sampling strategy (#3104)

* add class aware sampling strategy
* remove redundancy code
上级 e968c137
......@@ -70,6 +70,10 @@ def _prepare_data_config(feed, args_path):
'TYPE': type(feed.dataset).__source__
}
if feed.mode == 'TRAIN':
data_config['CLASS_AWARE_SAMPLING'] = getattr(
feed, 'class_aware_sampling', False)
if len(getattr(feed.dataset, 'images', [])) > 0:
data_config['IMAGES'] = feed.dataset.images
......@@ -301,7 +305,8 @@ class DataFeed(object):
bufsize=10,
use_process=False,
memsize=None,
use_padded_im_info=False):
use_padded_im_info=False,
class_aware_sampling=False):
super(DataFeed, self).__init__()
self.fields = fields
self.image_shape = image_shape
......@@ -318,6 +323,7 @@ class DataFeed(object):
self.memsize = memsize
self.dataset = dataset
self.use_padded_im_info = use_padded_im_info
self.class_aware_sampling = class_aware_sampling
if isinstance(dataset, dict):
self.dataset = DataSet(**dataset)
......@@ -447,7 +453,8 @@ class FasterRCNNTrainFeed(DataFeed):
bufsize=10,
num_workers=2,
use_process=False,
memsize=None):
memsize=None,
class_aware_sampling=False):
# XXX this should be handled by the data loader, since `fields` is
# given, just collect them
sample_transforms.append(ArrangeRCNN())
......@@ -464,7 +471,8 @@ class FasterRCNNTrainFeed(DataFeed):
bufsize=bufsize,
num_workers=num_workers,
use_process=use_process,
memsize=memsize)
memsize=memsize,
class_aware_sampling=class_aware_sampling)
# XXX these modes should be unified
self.mode = 'TRAIN'
......@@ -891,7 +899,8 @@ class YoloTrainFeed(DataFeed):
use_process=True,
memsize=None,
num_max_boxes=50,
mixup_epoch=250):
mixup_epoch=250,
class_aware_sampling=False):
sample_transforms.append(ArrangeYOLO())
super(YoloTrainFeed, self).__init__(
dataset,
......@@ -907,7 +916,8 @@ class YoloTrainFeed(DataFeed):
num_workers=num_workers,
bufsize=bufsize,
use_process=use_process,
memsize=memsize)
memsize=memsize,
class_aware_sampling=class_aware_sampling)
self.num_max_boxes = num_max_boxes
self.mixup_epoch = mixup_epoch
self.mode = 'TRAIN'
......
......@@ -21,6 +21,7 @@ import copy
from .roidb_source import RoiDbSource
from .simple_source import SimpleSource
from .iterator_source import IteratorSource
from .class_aware_sampling_roidb_source import ClassAwareSamplingRoiDbSource
def build_source(config):
......@@ -53,7 +54,12 @@ def build_source(config):
source_type = 'RoiDbSource'
if 'type' in data_cf:
if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource']:
source_type = 'RoiDbSource'
if 'class_aware_sampling' in args and args['class_aware_sampling']:
source_type = 'ClassAwareSamplingRoiDbSource'
else:
source_type = 'RoiDbSource'
if 'class_aware_sampling' in args:
del args['class_aware_sampling']
else:
source_type = data_cf['type']
del args['type']
......@@ -61,5 +67,7 @@ def build_source(config):
return RoiDbSource(**args)
elif source_type == 'SimpleSource':
return SimpleSource(**args)
elif source_type == 'ClassAwareSamplingRoiDbSource':
return ClassAwareSamplingRoiDbSource(**args)
else:
raise ValueError('source type not supported: ' + source_type)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#function:
# interface to load data from local files and parse it for samples,
# eg: roidb data in pickled files
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import random
import copy
import collections
import pickle as pkl
import numpy as np
from .roidb_source import RoiDbSource
class ClassAwareSamplingRoiDbSource(RoiDbSource):
""" interface to load class aware sampling roidb data from files
"""
def __init__(self,
anno_file,
image_dir=None,
samples=-1,
is_shuffle=True,
load_img=False,
cname2cid=None,
use_default_label=None,
mixup_epoch=-1,
with_background=True):
""" Init
Args:
fname (str): label file path
image_dir (str): root dir for images
samples (int): samples to load, -1 means all
is_shuffle (bool): whether to shuffle samples
load_img (bool): whether load data in this class
cname2cid (dict): the label name to id dictionary
use_default_label (bool):whether use the default mapping of label to id
mixup_epoch (int): parse mixup in first n epoch
with_background (bool): whether load background
as a class
"""
super(ClassAwareSamplingRoiDbSource, self).__init__(
anno_file=anno_file,
image_dir=image_dir,
samples=samples,
is_shuffle=is_shuffle,
load_img=load_img,
cname2cid=cname2cid,
use_default_label=use_default_label,
mixup_epoch=mixup_epoch,
with_background=with_background)
self._img_weights = None
def __str__(self):
return 'ClassAwareSamplingRoidbSource(fname:%s,epoch:%d,size:%d)' \
% (self._fname, self._epoch, self.size())
def next(self):
""" load next sample
"""
if self._epoch < 0:
self.reset()
_pos = np.random.choice(
self._samples, 1, replace=False, p=self._img_weights)[0]
sample = copy.deepcopy(self._roidb[_pos])
if self._load_img:
sample['image'] = self._load_image(sample['im_file'])
else:
sample['im_file'] = os.path.join(self._image_dir, sample['im_file'])
return sample
def _calc_img_weights(self):
""" calculate the probabilities of each sample
"""
imgs_cls = []
num_per_cls = {}
img_weights = []
for i, roidb in enumerate(self._roidb):
img_cls = set(
[k for cls in self._roidb[i]['gt_class'] for k in cls])
imgs_cls.append(img_cls)
for c in img_cls:
if c not in num_per_cls:
num_per_cls[c] = 1
else:
num_per_cls[c] += 1
for i in range(len(self._roidb)):
weights = 0
for c in imgs_cls[i]:
weights += 1 / num_per_cls[c]
img_weights.append(weights)
# Probabilities sum to 1
img_weights = img_weights / np.sum(img_weights)
return img_weights
def reset(self):
""" implementation of Dataset.reset
"""
if self._roidb is None:
self._roidb = self._load()
if self._img_weights is None:
self._img_weights = self._calc_img_weights()
self._samples = len(self._roidb)
if self._epoch < 0:
self._epoch = 0
......@@ -101,7 +101,8 @@ def load(anno_path, sample_num=-1, with_background=True):
gt_class[i][0] = catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
gt_poly[i] = box['segmentation']
if 'segmentation' in box:
gt_poly[i] = box['segmentation']
coco_rec = {
'im_file': im_fname,
......
......@@ -213,7 +213,7 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
for j in range(num):
dt = bboxes[k]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = clsid2catid[clsid]
catid = (clsid2catid[int(clsid)])
if is_bbox_normalized:
xmin, ymin, xmax, ymax = \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册