From 2ee6dd972c9254da3e83494cd8c47acd85069a45 Mon Sep 17 00:00:00 2001 From: Yuan Gao Date: Mon, 14 Oct 2019 15:49:05 +0800 Subject: [PATCH] add class aware sampling strategy (#3104) * add class aware sampling strategy * remove redundancy code --- ppdet/data/data_feed.py | 20 ++- ppdet/data/source/__init__.py | 10 +- .../class_aware_sampling_roidb_source.py | 132 ++++++++++++++++++ ppdet/data/source/coco_loader.py | 3 +- ppdet/utils/coco_eval.py | 2 +- 5 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 ppdet/data/source/class_aware_sampling_roidb_source.py diff --git a/ppdet/data/data_feed.py b/ppdet/data/data_feed.py index c79c5e949..7fab79154 100644 --- a/ppdet/data/data_feed.py +++ b/ppdet/data/data_feed.py @@ -70,6 +70,10 @@ def _prepare_data_config(feed, args_path): 'TYPE': type(feed.dataset).__source__ } + if feed.mode == 'TRAIN': + data_config['CLASS_AWARE_SAMPLING'] = getattr( + feed, 'class_aware_sampling', False) + if len(getattr(feed.dataset, 'images', [])) > 0: data_config['IMAGES'] = feed.dataset.images @@ -301,7 +305,8 @@ class DataFeed(object): bufsize=10, use_process=False, memsize=None, - use_padded_im_info=False): + use_padded_im_info=False, + class_aware_sampling=False): super(DataFeed, self).__init__() self.fields = fields self.image_shape = image_shape @@ -318,6 +323,7 @@ class DataFeed(object): self.memsize = memsize self.dataset = dataset self.use_padded_im_info = use_padded_im_info + self.class_aware_sampling = class_aware_sampling if isinstance(dataset, dict): self.dataset = DataSet(**dataset) @@ -447,7 +453,8 @@ class FasterRCNNTrainFeed(DataFeed): bufsize=10, num_workers=2, use_process=False, - memsize=None): + memsize=None, + class_aware_sampling=False): # XXX this should be handled by the data loader, since `fields` is # given, just collect them sample_transforms.append(ArrangeRCNN()) @@ -464,7 +471,8 @@ class FasterRCNNTrainFeed(DataFeed): bufsize=bufsize, num_workers=num_workers, use_process=use_process, - memsize=memsize) + memsize=memsize, + class_aware_sampling=class_aware_sampling) # XXX these modes should be unified self.mode = 'TRAIN' @@ -891,7 +899,8 @@ class YoloTrainFeed(DataFeed): use_process=True, memsize=None, num_max_boxes=50, - mixup_epoch=250): + mixup_epoch=250, + class_aware_sampling=False): sample_transforms.append(ArrangeYOLO()) super(YoloTrainFeed, self).__init__( dataset, @@ -907,7 +916,8 @@ class YoloTrainFeed(DataFeed): num_workers=num_workers, bufsize=bufsize, use_process=use_process, - memsize=memsize) + memsize=memsize, + class_aware_sampling=class_aware_sampling) self.num_max_boxes = num_max_boxes self.mixup_epoch = mixup_epoch self.mode = 'TRAIN' diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index ca0d5c833..e55df6962 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -21,6 +21,7 @@ import copy from .roidb_source import RoiDbSource from .simple_source import SimpleSource from .iterator_source import IteratorSource +from .class_aware_sampling_roidb_source import ClassAwareSamplingRoiDbSource def build_source(config): @@ -53,7 +54,12 @@ def build_source(config): source_type = 'RoiDbSource' if 'type' in data_cf: if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource']: - source_type = 'RoiDbSource' + if 'class_aware_sampling' in args and args['class_aware_sampling']: + source_type = 'ClassAwareSamplingRoiDbSource' + else: + source_type = 'RoiDbSource' + if 'class_aware_sampling' in args: + del args['class_aware_sampling'] else: source_type = data_cf['type'] del args['type'] @@ -61,5 +67,7 @@ def build_source(config): return RoiDbSource(**args) elif source_type == 'SimpleSource': return SimpleSource(**args) + elif source_type == 'ClassAwareSamplingRoiDbSource': + return ClassAwareSamplingRoiDbSource(**args) else: raise ValueError('source type not supported: ' + source_type) diff --git a/ppdet/data/source/class_aware_sampling_roidb_source.py b/ppdet/data/source/class_aware_sampling_roidb_source.py new file mode 100644 index 000000000..0175037c3 --- /dev/null +++ b/ppdet/data/source/class_aware_sampling_roidb_source.py @@ -0,0 +1,132 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#function: +# interface to load data from local files and parse it for samples, +# eg: roidb data in pickled files + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import random + +import copy +import collections +import pickle as pkl +import numpy as np +from .roidb_source import RoiDbSource + + +class ClassAwareSamplingRoiDbSource(RoiDbSource): + """ interface to load class aware sampling roidb data from files + """ + + def __init__(self, + anno_file, + image_dir=None, + samples=-1, + is_shuffle=True, + load_img=False, + cname2cid=None, + use_default_label=None, + mixup_epoch=-1, + with_background=True): + """ Init + + Args: + fname (str): label file path + image_dir (str): root dir for images + samples (int): samples to load, -1 means all + is_shuffle (bool): whether to shuffle samples + load_img (bool): whether load data in this class + cname2cid (dict): the label name to id dictionary + use_default_label (bool):whether use the default mapping of label to id + mixup_epoch (int): parse mixup in first n epoch + with_background (bool): whether load background + as a class + """ + super(ClassAwareSamplingRoiDbSource, self).__init__( + anno_file=anno_file, + image_dir=image_dir, + samples=samples, + is_shuffle=is_shuffle, + load_img=load_img, + cname2cid=cname2cid, + use_default_label=use_default_label, + mixup_epoch=mixup_epoch, + with_background=with_background) + self._img_weights = None + + def __str__(self): + return 'ClassAwareSamplingRoidbSource(fname:%s,epoch:%d,size:%d)' \ + % (self._fname, self._epoch, self.size()) + + def next(self): + """ load next sample + """ + if self._epoch < 0: + self.reset() + + _pos = np.random.choice( + self._samples, 1, replace=False, p=self._img_weights)[0] + sample = copy.deepcopy(self._roidb[_pos]) + + if self._load_img: + sample['image'] = self._load_image(sample['im_file']) + else: + sample['im_file'] = os.path.join(self._image_dir, sample['im_file']) + + return sample + + def _calc_img_weights(self): + """ calculate the probabilities of each sample + """ + imgs_cls = [] + num_per_cls = {} + img_weights = [] + for i, roidb in enumerate(self._roidb): + img_cls = set( + [k for cls in self._roidb[i]['gt_class'] for k in cls]) + imgs_cls.append(img_cls) + for c in img_cls: + if c not in num_per_cls: + num_per_cls[c] = 1 + else: + num_per_cls[c] += 1 + + for i in range(len(self._roidb)): + weights = 0 + for c in imgs_cls[i]: + weights += 1 / num_per_cls[c] + img_weights.append(weights) + # Probabilities sum to 1 + img_weights = img_weights / np.sum(img_weights) + return img_weights + + def reset(self): + """ implementation of Dataset.reset + """ + if self._roidb is None: + self._roidb = self._load() + + if self._img_weights is None: + self._img_weights = self._calc_img_weights() + + self._samples = len(self._roidb) + + if self._epoch < 0: + self._epoch = 0 diff --git a/ppdet/data/source/coco_loader.py b/ppdet/data/source/coco_loader.py index ad62d8290..db1849890 100644 --- a/ppdet/data/source/coco_loader.py +++ b/ppdet/data/source/coco_loader.py @@ -101,7 +101,8 @@ def load(anno_path, sample_num=-1, with_background=True): gt_class[i][0] = catid2clsid[catid] gt_bbox[i, :] = box['clean_bbox'] is_crowd[i][0] = box['iscrowd'] - gt_poly[i] = box['segmentation'] + if 'segmentation' in box: + gt_poly[i] = box['segmentation'] coco_rec = { 'im_file': im_fname, diff --git a/ppdet/utils/coco_eval.py b/ppdet/utils/coco_eval.py index 7640145dd..70a128e18 100644 --- a/ppdet/utils/coco_eval.py +++ b/ppdet/utils/coco_eval.py @@ -213,7 +213,7 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False): for j in range(num): dt = bboxes[k] clsid, score, xmin, ymin, xmax, ymax = dt.tolist() - catid = clsid2catid[clsid] + catid = (clsid2catid[int(clsid)]) if is_bbox_normalized: xmin, ymin, xmax, ymax = \ -- GitLab