未验证 提交 88d1f7dc 编写于 作者: J jayhenry 提交者: GitHub

Add Detection Finetune Task (#510)

* add object detection dataset, reader, task and demo
上级 13fc90ba
#coding:utf-8
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
import numpy as np
from paddlehub.reader.cv_reader import ObjectDetectionReader
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset
from paddlehub.contrib.ppdet.utils.coco_eval import bbox2out
from paddlehub.common.detection_config import get_model_type, get_feed_list, get_mid_feature
from paddlehub.common import detection_config as dconf
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
# Command-line options, one row per flag: (flag, value parser, default, help).
for _flag, _parse, _default, _help in (
        ("--use_gpu", ast.literal_eval, False, "Whether use GPU for predict."),
        ("--checkpoint_dir", str, "paddlehub_finetune_ckpt", "Path to save log data."),
        ("--batch_size", int, 2, "Total examples' number in batch for training."),
        ("--module", str, "ssd", "Module used as a feature extractor."),
        ("--dataset", str, "coco10", "Dataset to finetune."),
):
    parser.add_argument(_flag, type=_parse, default=_default, help=_help)
del _flag, _parse, _default, _help
# yapf: enable.

# Short aliases accepted by --module, mapped to the full module names that are
# handed to hub.Module.
module_map = dict(
    yolov3="yolov3_darknet53_coco2017",
    ssd="ssd_vgg16_512_coco2017",
    faster_rcnn="faster_rcnn_resnet50_coco2017",
)
def predict(args):
    """Run detection inference on two sample images and print the boxes.

    Args:
        args: parsed command-line namespace. ``module`` (already translated to
            a full hub module name), ``use_gpu``, ``batch_size`` and
            ``checkpoint_dir`` are read.
    """
    hub_module_name = args.module  # e.g. 'yolov3_darknet53_coco2017'
    model_type = get_model_type(hub_module_name)  # e.g. 'yolo'

    # Dataset and the reader that feeds it to the task.
    ds = hub.dataset.Coco10(model_type)
    print("ds.num_labels", ds.num_labels)
    data_reader = ObjectDetectionReader(dataset=ds, model_type=model_type)

    # Pretrained module; only 'rcnn' requests a second, non-trainable context
    # whose inputs/outputs are used for prediction.
    module = hub.Module(name=hub_module_name)
    if model_type != 'rcnn':
        input_dict, output_dict, _ = module.context(trainable=True)
        input_dict_pred = None
        output_dict_pred = None
    else:
        input_dict, output_dict, _ = module.context(
            trainable=True, phase='train')
        input_dict_pred, output_dict_pred, _ = module.context(trainable=False)

    feed_list, pred_feed_list = get_feed_list(hub_module_name, input_dict,
                                              input_dict_pred)
    feature, pred_feature = get_mid_feature(hub_module_name, output_dict,
                                            output_dict_pred)

    run_config = hub.RunConfig(
        use_data_parallel=False,
        use_pyreader=True,
        use_cuda=args.use_gpu,
        batch_size=args.batch_size,
        enable_memory_optim=False,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    task = hub.DetectionTask(
        data_reader=data_reader,
        num_classes=ds.num_labels,
        feed_list=feed_list,
        feature=feature,
        predict_feed_list=pred_feed_list,
        predict_feature=pred_feature,
        model_type=model_type,
        config=run_config)

    image_paths = ["./test/test_img_bird.jpg", "./test/test_img_cat.jpg"]
    label_map = ds.label_dict()
    run_states = task.predict(data=image_paths, accelerate_mode=False)

    batch_outputs = []
    for state in run_states:
        batch_outputs.append(state.run_results)

    for outs in batch_outputs:
        # Pair each output with its name; every value becomes
        # (array, recursive sequence lengths).
        res = {}
        for name, tensor in zip(['im_shape', 'im_id', 'bbox'], outs):
            res[name] = (np.array(tensor), tensor.recursive_sequence_lengths())
        print("im_id", res['im_id'])
        is_bbox_normalized = dconf.conf[model_type]['is_bbox_normalized']
        # Identity mapping: every key of the label dict maps to itself.
        clsid2catid = {label_id: label_id for label_id in label_map}
        bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
        print(bbox_results)
if __name__ == "__main__":
    import sys

    args = parser.parse_args()
    # --module takes a short alias; reject unknown ones before loading anything.
    if args.module not in module_map:
        hub.logger.error("module should in %s" % module_map.keys())
        # sys.exit rather than the bare exit() helper: the latter is injected
        # by the `site` module and is missing under `python -S` / frozen apps.
        sys.exit(1)
    # Translate the alias into the full hub module name expected by predict().
    args.module = module_map[args.module]
    predict(args)
# -*- coding:utf8 -*-
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.reader.cv_reader import ObjectDetectionReader
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset
import numpy as np
from paddlehub.common.detection_config import get_model_type, get_feed_list, get_mid_feature
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
# Command-line options, one row per flag: (flag, value parser, default, help).
for _flag, _parse, _default, _help in (
        ("--num_epoch", int, 50, "Number of epoches for fine-tuning."),
        ("--use_gpu", ast.literal_eval, False, "Whether use GPU for fine-tuning."),
        ("--checkpoint_dir", str, "paddlehub_finetune_ckpt", "Path to save log data."),
        ("--batch_size", int, 8, "Total examples' number in batch for training."),
        ("--module", str, "ssd", "Module used as feature extractor."),
        ("--dataset", str, "coco_10", "Dataset to finetune."),
        ("--use_data_parallel", ast.literal_eval, False, "Whether use data parallel."),
):
    parser.add_argument(_flag, type=_parse, default=_default, help=_help)
del _flag, _parse, _default, _help
# yapf: enable.

# Short aliases accepted by --module, mapped to the full module names that are
# handed to hub.Module.
module_map = dict(
    yolov3="yolov3_darknet53_coco2017",
    ssd="ssd_vgg16_512_coco2017",
    faster_rcnn="faster_rcnn_resnet50_coco2017",
)
def finetune(args):
    """Fine-tune and evaluate a detection module on the Coco10 dataset.

    Args:
        args: parsed command-line namespace. ``module`` (already translated to
            a full hub module name), ``use_gpu``, ``use_data_parallel``,
            ``num_epoch``, ``batch_size`` and ``checkpoint_dir`` are read.
    """
    hub_module_name = args.module  # e.g. 'yolov3_darknet53_coco2017'
    model_type = get_model_type(hub_module_name)  # e.g. 'yolo'

    # Dataset. ObjectDetectionDataset (imported by this script) is the class
    # to reach for when training on a custom COCO-style dataset instead.
    ds = hub.dataset.Coco10(model_type)
    print("ds.num_labels", ds.num_labels)

    # Batch reader.
    data_reader = ObjectDetectionReader(dataset=ds, model_type=model_type)

    # Pretrained module; only 'rcnn' requests a second, non-trainable context
    # whose inputs/outputs are used for prediction.
    module = hub.Module(name=hub_module_name)
    if model_type != 'rcnn':
        input_dict, output_dict, _ = module.context(trainable=True)
        input_dict_pred = None
        output_dict_pred = None
    else:
        input_dict, output_dict, _ = module.context(
            trainable=True, phase='train')
        input_dict_pred, output_dict_pred, _ = module.context(trainable=False)

    print("input_dict keys", input_dict.keys())
    print("output_dict keys", output_dict.keys())
    feed_list, pred_feed_list = get_feed_list(hub_module_name, input_dict,
                                              input_dict_pred)
    print("output_dict length:", len(output_dict))
    print(output_dict.keys())
    if output_dict_pred is not None:
        print(output_dict_pred.keys())
    feature, pred_feature = get_mid_feature(hub_module_name, output_dict,
                                            output_dict_pred)

    strategy = hub.finetune.strategy.DefaultFinetuneStrategy(
        learning_rate=0.00025, optimizer_name="adam")
    run_config = hub.RunConfig(
        log_interval=10,
        eval_interval=100,
        use_data_parallel=args.use_data_parallel,
        use_pyreader=True,
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        enable_memory_optim=False,
        checkpoint_dir=args.checkpoint_dir,
        strategy=strategy)

    task = hub.DetectionTask(
        data_reader=data_reader,
        num_classes=ds.num_labels,
        feed_list=feed_list,
        feature=feature,
        predict_feed_list=pred_feed_list,
        predict_feature=pred_feature,
        model_type=model_type,
        config=run_config)
    task.finetune_and_eval()
if __name__ == "__main__":
    import sys

    args = parser.parse_args()
    # --module takes a short alias; reject unknown ones before loading anything.
    if args.module not in module_map:
        hub.logger.error("module should in %s" % module_map.keys())
        # sys.exit rather than the bare exit() helper: the latter is injected
        # by the `site` module and is missing under `python -S` / frozen apps.
        sys.exit(1)
    # Translate the alias into the full hub module name expected by finetune().
    args.module = module_map[args.module]
    finetune(args)
......@@ -48,6 +48,7 @@ from .io.type import DataType
from .finetune.task import BaseTask
from .finetune.task import ClassifierTask
from .finetune.task import DetectionTask
from .finetune.task import TextClassifierTask
from .finetune.task import ImageClassifierTask
from .finetune.task import SequenceLabelTask
......
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Per-model-family switches, keyed by the family names that get_model_type
# returns ('ssd', 'yolo', 'rcnn').
conf = dict(
    ssd=dict(
        with_background=True,
        is_bbox_normalized=True,
        # norm_type="bn",
    ),
    yolo=dict(
        with_background=False,
        is_bbox_normalized=False,
        # norm_type="sync_bn",
        mixup_epoch=10,
        num_max_boxes=50,
    ),
    rcnn=dict(
        with_background=True,
        is_bbox_normalized=False,
        # norm_type="affine_channel",
    ),
)
# SSD transform pipelines. Each entry names an op through its 'op' key; the
# remaining keys are that op's settings.
ssd_train_ops = [
    {'op': 'DecodeImage', 'to_rgb': True, 'with_mixup': False},
    {'op': 'NormalizeBox'},
    {
        'op': 'RandomDistort',
        'brightness_lower': 0.875,
        'brightness_upper': 1.125,
        'is_order': True,
    },
    {'op': 'ExpandImage', 'max_ratio': 4, 'prob': 0.5},
    {
        'op': 'CropImage',
        'batch_sampler': [
            [1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
            [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0],
            [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0],
            [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0],
            [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0],
            [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0],
            [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0],
        ],
        'satisfy_all': False,
        'avoid_no_bbox': False,
    },
    {'op': 'ResizeImage', 'target_size': 512, 'use_cv2': False, 'interp': 1},
    {'op': 'RandomFlipImage', 'is_normalized': True},
    {'op': 'Permute'},
    {
        'op': 'NormalizeImage',
        'mean': [104, 117, 123],
        'std': [1, 1, 1],
        'is_scale': False,
    },
    {'op': 'ArrangeSSD'},
]

# Shared by the eval pipeline below (the same list object is passed on).
ssd_eval_fields = [
    'image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'
]

ssd_eval_ops = [
    {'op': 'DecodeImage', 'to_rgb': True, 'with_mixup': False},
    {'op': 'NormalizeBox'},
    {'op': 'ResizeImage', 'target_size': 512, 'use_cv2': False, 'interp': 1},
    {'op': 'Permute'},
    {
        'op': 'NormalizeImage',
        'mean': [104, 117, 123],
        'std': [1, 1, 1],
        'is_scale': False,
    },
    {'op': 'ArrangeEvalSSD', 'fields': ssd_eval_fields},
]

ssd_predict_ops = [
    {'op': 'DecodeImage', 'to_rgb': True, 'with_mixup': False},
    {'op': 'ResizeImage', 'target_size': 512, 'use_cv2': False, 'interp': 1},
    {'op': 'Permute'},
    {
        'op': 'NormalizeImage',
        'mean': [104, 117, 123],
        'std': [1, 1, 1],
        'is_scale': False,
    },
    {'op': 'ArrangeTestSSD'},
]
# RCNN transform pipelines. Train adds a random flip; the three lists
# otherwise differ only in the final Arrange* op.
rcnn_train_ops = [
    {'op': 'DecodeImage', 'to_rgb': True},
    {'op': 'RandomFlipImage', 'prob': 0.5},
    {
        'op': 'NormalizeImage',
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'is_scale': True,
        'is_channel_first': False,
    },
    {'op': 'ResizeImage', 'target_size': 800, 'max_size': 1333, 'interp': 1},
    {'op': 'Permute', 'to_bgr': False},
    {'op': 'ArrangeRCNN'},
]

rcnn_eval_ops = [
    {'op': 'DecodeImage', 'to_rgb': True},
    {
        'op': 'NormalizeImage',
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'is_scale': True,
        'is_channel_first': False,
    },
    {'op': 'ResizeImage', 'target_size': 800, 'max_size': 1333, 'interp': 1},
    {'op': 'Permute', 'to_bgr': False},
    {'op': 'ArrangeEvalRCNN'},
]

rcnn_predict_ops = [
    {'op': 'DecodeImage', 'to_rgb': True},
    {
        'op': 'NormalizeImage',
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'is_scale': True,
        'is_channel_first': False,
    },
    {'op': 'ResizeImage', 'target_size': 800, 'max_size': 1333, 'interp': 1},
    {'op': 'Permute', 'to_bgr': False},
    {'op': 'ArrangeTestRCNN'},
]
# YOLO transform pipelines. Eval and predict differ only in the final
# Arrange* op.
yolo_train_ops = [
    {'op': 'DecodeImage', 'to_rgb': True, 'with_mixup': True},
    {'op': 'MixupImage', 'alpha': 1.5, 'beta': 1.5},
    {'op': 'ColorDistort'},
    {'op': 'RandomExpand', 'fill_value': [123.675, 116.28, 103.53]},
    {'op': 'RandomCrop'},
    {'op': 'RandomFlipImage', 'is_normalized': False},
    {'op': 'Resize', 'target_dim': 608, 'interp': 'random'},
    {
        'op': 'NormalizePermute',
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.120, 57.375],
    },
    {'op': 'NormalizeBox'},
    {'op': 'ArrangeYOLO'},
]

yolo_eval_ops = [
    {'op': 'DecodeImage', 'to_rgb': True},
    {'op': 'ResizeImage', 'target_size': 608, 'interp': 2},
    {
        'op': 'NormalizeImage',
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'is_scale': True,
        'is_channel_first': False,
    },
    {'op': 'Permute', 'to_bgr': False},
    {'op': 'ArrangeEvalYOLO'},
]

yolo_predict_ops = [
    {'op': 'DecodeImage', 'to_rgb': True},
    {'op': 'ResizeImage', 'target_size': 608, 'interp': 2},
    {
        'op': 'NormalizeImage',
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'is_scale': True,
        'is_channel_first': False,
    },
    {'op': 'Permute', 'to_bgr': False},
    {'op': 'ArrangeTestYOLO'},
]
# Feed settings per model family ('ssd' / 'rcnn' / 'yolo') and per phase
# ('train' / 'dev' / 'predict'): the field names, the transform op list, and
# any extra batch-level switches for that phase.
feed_config = dict(
    ssd=dict(
        train=dict(
            fields=['image', 'gt_box', 'gt_label'],
            OPS=ssd_train_ops,
            IS_PADDING=False,
        ),
        dev=dict(
            # Same list object that the ArrangeEvalSSD op receives.
            fields=ssd_eval_fields,
            OPS=ssd_eval_ops,
            IS_PADDING=False,
        ),
        predict=dict(
            fields=['image', 'im_id', 'im_shape'],
            OPS=ssd_predict_ops,
            IS_PADDING=False,
        ),
    ),
    rcnn=dict(
        train=dict(
            fields=[
                'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'
            ],
            OPS=rcnn_train_ops,
            IS_PADDING=True,
            COARSEST_STRIDE=32,
        ),
        dev=dict(
            fields=[
                'image', 'im_info', 'im_id', 'im_shape', 'gt_box', 'gt_label',
                'is_difficult'
            ],
            OPS=rcnn_eval_ops,
            IS_PADDING=True,
            COARSEST_STRIDE=32,
            USE_PADDED_IM_INFO=True,
        ),
        predict=dict(
            fields=['image', 'im_info', 'im_id', 'im_shape'],
            OPS=rcnn_predict_ops,
            IS_PADDING=True,
            COARSEST_STRIDE=32,
            USE_PADDED_IM_INFO=True,
        ),
    ),
    yolo=dict(
        train=dict(
            fields=['image', 'gt_box', 'gt_label', 'gt_score'],
            OPS=yolo_train_ops,
            RANDOM_SHAPES=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
        ),
        dev=dict(
            fields=[
                'image', 'im_size', 'im_id', 'gt_box', 'gt_label',
                'is_difficult'
            ],
            OPS=yolo_eval_ops,
        ),
        predict=dict(
            fields=['image', 'im_size', 'im_id'],
            OPS=yolo_predict_ops,
        ),
    ),
)
def get_model_type(module_name):
    """Return the detector family ('yolo', 'ssd' or 'rcnn') for a module name.

    The families are tried in that order and the first one found as a
    substring of ``module_name`` wins.

    Raises:
        ValueError: if none of the family names occurs in ``module_name``.
    """
    for family in ('yolo', 'ssd', 'rcnn'):
        if family in module_name:
            return family
    raise ValueError("module {} not supported".format(module_name))
def get_feed_list(module_name, input_dict, input_dict_pred=None):
    """Collect the input variable names to feed for a detection module.

    Args:
        module_name (str): hub module name; the family is chosen by checking
            for the substrings 'yolo', 'ssd', 'rcnn', in that order.
        input_dict (dict): training-context inputs, keyed by input name; each
            value must expose a ``.name`` attribute.
        input_dict_pred (dict): predict-context inputs; required for 'rcnn'
            and ignored by the other families.

    Returns:
        tuple: ``(feed_list, pred_feed_list)``. ``pred_feed_list`` is None
        unless the module is an 'rcnn' one.

    Raises:
        NotImplementedError: if the module name matches no known family.
        AssertionError: if an 'rcnn' module is given no ``input_dict_pred``.
    """
    pred_feed_list = None
    if 'yolo' in module_name:
        feed_list = [input_dict["image"].name, input_dict["im_size"].name]
    elif 'ssd' in module_name:
        # The SSD module exposes the image shape under the key "im_size".
        feed_list = [input_dict["image"].name, input_dict["im_size"].name]
    elif 'rcnn' in module_name:
        train_keys = ('image', 'im_info', 'gt_bbox', 'gt_class', 'is_crowd')
        feed_list = [input_dict[key].name for key in train_keys]
        assert input_dict_pred is not None, \
            "input_dict_pred is required for rcnn modules"
        image = input_dict_pred['image']
        im_info = input_dict_pred['im_info']
        # The predict feed list must be built from the predict context. This
        # used to read 'im_shape' from the training dict; that dict is kept
        # only as a fallback for contexts that do not expose the key.
        try:
            im_shape = input_dict_pred['im_shape']
        except KeyError:
            im_shape = input_dict['im_shape']
        pred_feed_list = [image.name, im_info.name, im_shape.name]
    else:
        raise NotImplementedError
    return feed_list, pred_feed_list
def get_mid_feature(module_name, output_dict, output_dict_pred=None):
    """Pick the intermediate outputs a detection task is built on.

    Args:
        module_name (str): hub module name; the family is chosen by checking
            for the substrings 'yolo', 'ssd', 'rcnn', in that order.
        output_dict (dict): training-context outputs keyed by name.
        output_dict_pred (dict): predict-context outputs; required for 'rcnn'.

    Returns:
        tuple: ``(feature, feature_pred)``. For 'yolo' and 'ssd' ``feature`` is
        a single entry of ``output_dict`` and ``feature_pred`` is None; for
        'rcnn' both are lists.

    Raises:
        NotImplementedError: if the module name matches no known family.
    """
    if 'yolo' in module_name:
        return output_dict['head_features'], None
    if 'ssd' in module_name:
        return output_dict['body_features'], None
    if 'rcnn' not in module_name:
        raise NotImplementedError

    train_keys = ('head_feat', 'rpn_cls_loss', 'rpn_reg_loss',
                  'generate_proposal_labels')
    feature = [output_dict[key] for key in train_keys]
    assert output_dict_pred is not None
    feature_pred = [output_dict_pred['head_feat'], output_dict_pred['rois']]
    return feature, feature_pred
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# interface for accessing data samples in stream
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Dataset(object):
    """Interface to access a stream of data samples.

    Subclasses implement ``next``, ``reset``, ``size``, ``drained`` and
    ``epoch_id``; every one of them raises NotImplementedError here. The
    iterator protocol (``__iter__`` / ``__next__``) is routed to ``next``.
    """

    def __init__(self):
        # -1 until a subclass starts its first epoch.
        self._epoch = -1

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self

    def __str__(self):
        # `_fname` and `_pos` are set by subclasses only, and not by all of
        # them, so fall back to placeholders instead of raising
        # AttributeError while formatting.
        return "{}(fname:{}, epoch:{:d}, size:{:d}, pos:{:d})".format(
            type(self).__name__, getattr(self, '_fname', None), self._epoch,
            self.size(), getattr(self, '_pos', -1))

    def next(self):
        """get next sample"""
        raise NotImplementedError(
            '%s.next not available' % (self.__class__.__name__))

    def reset(self):
        """reset to initial status and begins a new epoch"""
        raise NotImplementedError(
            '%s.reset not available' % (self.__class__.__name__))

    def size(self):
        """get number of samples in this dataset"""
        raise NotImplementedError(
            '%s.size not available' % (self.__class__.__name__))

    def drained(self):
        """whether all samples have been read out for this epoch"""
        raise NotImplementedError(
            '%s.drained not available' % (self.__class__.__name__))

    def epoch_id(self):
        """return epoch id for latest sample"""
        raise NotImplementedError(
            '%s.epoch_id not available' % (self.__class__.__name__))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# Interface to build readers for detection data like COCO or VOC
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from numbers import Integral
import logging
from .source import build_source
from .transform import build_mapper, map, batch, batch_map
logger = logging.getLogger(__name__)
class Reader(object):
    """Interface to make readers for training or evaluation.

    Args:
        data_cf (dict): data-source config per mode ('TRAIN', 'VAL', 'TEST').
        trans_conf (dict): transform config per mode; 'OPS' and 'BATCH_SIZE'
            are required, 'DROP_LAST' and 'WORKER_CONF' are optional.
        maxiter (int): number of batches a reader yields before stopping;
            values <= 0 mean a single pass over the data.
    """

    def __init__(self, data_cf, trans_conf, maxiter=-1):
        self._data_cf = data_cf
        self._trans_conf = trans_conf
        self._maxiter = maxiter
        self._cname2cid = None
        assert isinstance(self._maxiter, Integral), "maxiter should be int"

    def _make_reader(self, mode, my_source=None):
        """Build the reader function for `mode`.

        Args:
            mode (str): key into the data / transform configs.
            my_source: ready-made source to use instead of building one from
                the data config.

        Returns:
            A zero-argument generator function yielding batches.
        """
        if my_source is None:
            file_conf = self._data_cf[mode]
            # 1, Build data source
            sc_conf = {'data_cf': file_conf, 'cname2cid': self._cname2cid}
            sc = build_source(sc_conf)
        else:
            sc = my_source

        # 2, Build a transformed dataset
        mode_conf = self._trans_conf[mode]
        ops = mode_conf['OPS']
        batchsize = mode_conf['BATCH_SIZE']
        drop_last = mode_conf['DROP_LAST'] if 'DROP_LAST' in mode_conf \
            else False

        mapper = build_mapper(ops, {'is_train': mode == 'TRAIN'})

        worker_args = None
        if 'WORKER_CONF' in mode_conf:
            worker_args = {
                k.lower(): v
                for k, v in mode_conf['WORKER_CONF'].items()
            }

        # NOTE: `map` is the one imported from .transform, not the builtin.
        mapped_ds = map(sc, mapper, worker_args)
        # In VAL mode, gt_bbox, gt_label can be empty, and should
        # not be dropped
        batched_ds = batch(
            mapped_ds, batchsize, drop_last, drop_empty=(mode != "VAL"))

        # Only these (lower-cased) settings are forwarded to batch_map.
        trans_conf = {k.lower(): v for k, v in mode_conf.items()}
        need_keys = {
            'is_padding',
            'coarsest_stride',
            'random_shapes',
            'multi_scales',
            'use_padded_im_info',
            'enable_multiscale_test',
            'num_scale',
        }
        bm_config = {
            key: value
            for key, value in trans_conf.items() if key in need_keys
        }

        batched_ds = batch_map(batched_ds, bm_config)

        batched_ds.reset()
        if mode.lower() == 'train':
            if self._cname2cid is not None:
                # Logger.warn is a deprecated alias that was removed in
                # Python 3.13; Logger.warning is the supported spelling.
                logger.warning('cname2cid already set, it will be overridden')
            self._cname2cid = getattr(sc, 'cname2cid', None)

        # 3, Build a reader
        maxit = -1 if self._maxiter <= 0 else self._maxiter

        def _reader():
            n = 0
            while True:
                for _batch in batched_ds:
                    yield _batch
                    n += 1
                    if maxit > 0 and n == maxit:
                        return
                batched_ds.reset()
                # Without an iteration cap, one pass over the data is enough.
                if maxit <= 0:
                    return

        # Expose source metadata on the reader function itself.
        _reader._fname = None
        if hasattr(sc, '_fname'):
            _reader.annotation = sc._fname
        if hasattr(sc, 'get_imid2path'):
            _reader.imid2path = sc.get_imid2path()

        return _reader

    def train(self):
        """Build reader for training"""
        return self._make_reader('TRAIN')

    def val(self):
        """Build reader for validation"""
        return self._make_reader('VAL')

    def test(self):
        """Build reader for inference"""
        return self._make_reader('TEST')

    @classmethod
    def create(cls,
               mode,
               data_config,
               transform_config,
               max_iter=-1,
               my_source=None,
               ret_iter=True):
        """Create a reader for a single mode.

        Returns the reader function when `ret_iter` is true, otherwise the
        Reader instance itself.
        """
        reader = Reader({mode: data_config}, {mode: transform_config}, max_iter)
        if ret_iter:
            return reader._make_reader(mode, my_source)
        else:
            return reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from .roidb_source import RoiDbSource
from .simple_source import SimpleSource
from .iterator_source import IteratorSource
from .class_aware_sampling_roidb_source import ClassAwareSamplingRoiDbSource
def build_source(config):
    """
    Build dataset from source data, default source type is 'RoiDbSource'
    Args:
        config (dict): should have following structure:
            {
                data_cf (dict):
                    anno_file (str): label file or image list file path
                    image_dir (str): root directory for images
                    samples (int): number of samples to load, -1 means all
                    is_shuffle (bool): should samples be shuffled
                    load_img (bool): should images be loaded
                    mixup_epoch (int): parse mixup in first n epoch
                    with_background (bool): whether load background as a class
                cname2cid (dict): the label name to id dictionary
            }
            A flat dict without the 'data_cf' key is used as data_cf itself.
    Returns:
        the constructed source object
    Raises:
        ValueError: if the resolved source type is not supported
    """
    if 'data_cf' in config:
        # Shallow-copy first: writing 'cname2cid' straight into
        # config['data_cf'] would leak the extra key into the caller's dict.
        data_cf = dict(config['data_cf'])
        data_cf['cname2cid'] = config['cname2cid']
    else:
        data_cf = config

    # Keys are matched case-insensitively.
    data_cf = {k.lower(): v for k, v in data_cf.items()}

    args = copy.deepcopy(data_cf)
    # default type is 'RoiDbSource'
    source_type = 'RoiDbSource'
    if 'type' in data_cf:
        if data_cf['type'] in ('VOCSource', 'COCOSource', 'RoiDbSource'):
            # These three names share the roidb loader; class-aware sampling
            # selects the sampling variant of it.
            if args.get('class_aware_sampling'):
                source_type = 'ClassAwareSamplingRoiDbSource'
            # The flag only picks the class; it is not a constructor argument.
            args.pop('class_aware_sampling', None)
        else:
            source_type = data_cf['type']
        del args['type']

    if source_type == 'RoiDbSource':
        return RoiDbSource(**args)
    elif source_type == 'SimpleSource':
        return SimpleSource(**args)
    elif source_type == 'ClassAwareSamplingRoiDbSource':
        return ClassAwareSamplingRoiDbSource(**args)
    else:
        raise ValueError('source type not supported: ' + source_type)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#function:
# interface to load data from local files and parse it for samples,
# eg: roidb data in pickled files
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import random
import copy
import collections
import pickle as pkl
import numpy as np
from .roidb_source import RoiDbSource
class ClassAwareSamplingRoiDbSource(RoiDbSource):
    """Roidb source that draws every sample at random with per-image weights.

    An image's weight is the sum, over the classes it contains, of
    1 / (number of images containing that class); the weights are then
    normalised to sum to 1.
    """

    def __init__(self,
                 anno_file,
                 image_dir=None,
                 samples=-1,
                 is_shuffle=True,
                 load_img=False,
                 cname2cid=None,
                 use_default_label=None,
                 mixup_epoch=-1,
                 with_background=True):
        """ Init

        Args:
            anno_file (str): label file path
            image_dir (str): root dir for images
            samples (int): samples to load, -1 means all
            is_shuffle (bool): whether to shuffle samples
            load_img (bool): whether load data in this class
            cname2cid (dict): the label name to id dictionary
            use_default_label (bool): whether use the default mapping of
                label to id
            mixup_epoch (int): parse mixup in first n epoch
            with_background (bool): whether load background as a class
        """
        base_kwargs = dict(
            anno_file=anno_file,
            image_dir=image_dir,
            samples=samples,
            is_shuffle=is_shuffle,
            load_img=load_img,
            cname2cid=cname2cid,
            use_default_label=use_default_label,
            mixup_epoch=mixup_epoch,
            with_background=with_background)
        super(ClassAwareSamplingRoiDbSource, self).__init__(**base_kwargs)
        # Computed lazily by reset().
        self._img_weights = None

    def __str__(self):
        details = (self._fname, self._epoch, self.size())
        return 'ClassAwareSamplingRoidbSource(fname:%s,epoch:%d,size:%d)' \
            % details

    def next(self):
        """ load next sample
        """
        if self._epoch < 0:
            self.reset()

        # One index drawn according to the per-image weights.
        drawn = np.random.choice(
            self._samples, 1, replace=False, p=self._img_weights)
        sample = copy.deepcopy(self._roidb[drawn[0]])

        if not self._load_img:
            sample['im_file'] = os.path.join(self._image_dir, sample['im_file'])
        else:
            sample['image'] = self._load_image(sample['im_file'])
        return sample

    def _calc_img_weights(self):
        """ calculate the probabilities of each sample
        """
        # Distinct class ids present in each image.
        classes_per_img = [{k
                            for row in rec['gt_class'] for k in row}
                           for rec in self._roidb]

        # Number of images each class occurs in.
        num_per_cls = {}
        for classes in classes_per_img:
            for c in classes:
                num_per_cls[c] = num_per_cls.get(c, 0) + 1

        img_weights = [
            sum(1 / num_per_cls[c] for c in classes)
            for classes in classes_per_img
        ]
        # Probabilities sum to 1
        return img_weights / np.sum(img_weights)

    def reset(self):
        """ implementation of Dataset.reset
        """
        if self._roidb is None:
            self._roidb = self._load()

        if self._img_weights is None:
            self._img_weights = self._calc_img_weights()

        self._samples = len(self._roidb)

        # The epoch counter only leaves its "not started" state here.
        if self._epoch < 0:
            self._epoch = 0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from pycocotools.coco import COCO
import logging
logger = logging.getLogger(__name__)
def load(anno_path, sample_num=-1, with_background=True):
    """
    Load COCO records with annotations in json file 'anno_path'

    Args:
        anno_path (str): json file path
        sample_num (int): number of samples to load, -1 means all
        with_background (bool): whether load background as a class.
                                if True, total class number will
                                be 81. default True

    Returns:
        (records, cname2cid)
        'records' is list of dict whose structure is:
        {
            'im_file': im_fname, # image file name
            'im_id': img_id, # image id
            'h': im_h, # height of image
            'w': im_w, # width
            'is_crowd': is_crowd,
            'gt_score': gt_score,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_poly': gt_poly,
            'difficult': difficult,
        }
        'cname2cid' is a dict used to map category name to class id

    Raises:
        AssertionError: if 'anno_path' is not a '.json' path or no record
            could be loaded from it
    """
    assert anno_path.endswith('.json'), 'invalid coco annotation file: ' \
        + anno_path

    coco = COCO(anno_path)
    img_ids = coco.getImgIds()
    cat_ids = coco.getCatIds()
    records = []
    ct = 0

    # when with_background = True, mapping category to classid, like:
    #   background:0, first_class:1, second_class:2, ...
    class_offset = int(with_background)
    catid2clsid = {
        catid: i + class_offset
        for i, catid in enumerate(cat_ids)
    }
    cname2cid = {
        coco.loadCats(catid)[0]['name']: clsid
        for catid, clsid in catid2clsid.items()
    }

    for img_id in img_ids:
        img_anno = coco.loadImgs(img_id)[0]
        im_fname = img_anno['file_name']
        im_w = float(img_anno['width'])
        im_h = float(img_anno['height'])

        ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        instances = coco.loadAnns(ins_anno_ids)

        # Convert [x, y, w, h] boxes to [x1, y1, x2, y2] clipped to the image
        # and keep only the valid ones.
        bboxes = []
        for inst in instances:
            x, y, box_w, box_h = inst['bbox']
            x1 = max(0, x)
            y1 = max(0, y)
            x2 = min(im_w - 1, x1 + max(0, box_w - 1))
            y2 = min(im_h - 1, y1 + max(0, box_h - 1))
            if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
                inst['clean_bbox'] = [x1, y1, x2, y2]
                bboxes.append(inst)
            else:
                # Logger.warn is a deprecated alias that was removed in
                # Python 3.13; Logger.warning is the supported spelling.
                logger.warning(
                    'Found an invalid bbox in annotations: im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
                    .format(img_id, float(inst['area']), x1, y1, x2, y2))
        num_bbox = len(bboxes)

        gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
        gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
        gt_score = np.ones((num_bbox, 1), dtype=np.float32)
        is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
        difficult = np.zeros((num_bbox, 1), dtype=np.int32)
        gt_poly = [None] * num_bbox

        for i, box in enumerate(bboxes):
            catid = box['category_id']
            gt_class[i][0] = catid2clsid[catid]
            gt_bbox[i, :] = box['clean_bbox']
            is_crowd[i][0] = box['iscrowd']
            if 'segmentation' in box:
                gt_poly[i] = box['segmentation']

        coco_rec = {
            'im_file': im_fname,
            'im_id': np.array([img_id]),
            'h': im_h,
            'w': im_w,
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_score': gt_score,
            'gt_poly': gt_poly,
            'difficult': difficult
        }

        logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
            im_fname, img_id, im_h, im_w))
        records.append(coco_rec)
        ct += 1
        if sample_num > 0 and ct >= sample_num:
            break
    assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
    logger.info('{} samples in file {}'.format(ct, anno_path))
    return records, cname2cid
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import copy
import logging
logger = logging.getLogger(__name__)
from ..dataset import Dataset
class IteratorSource(Dataset):
    """
    Load data samples from an iterator in stream mode.

    Args:
        iter_maker (callable): zero-argument callable that returns a fresh
            iterator over samples; it is invoked by `reset` whenever the
            previous iterator has been released.
        samples (int): maximum number of samples to return per epoch;
            values <= 0 (default -1) mean "no limit".
    """

    def __init__(self, iter_maker, samples=-1, **kwargs):
        super(IteratorSource, self).__init__()
        self._epoch = -1
        self._iter_maker = iter_maker
        self._data_iter = None
        self._pos = -1
        self._drained = False
        self._samples = samples
        # Number of samples counted the last time the underlying iterator
        # was exhausted; stays -1 until that has happened once.
        self._sample_num = -1

    def next(self):
        """Return the next sample of the current epoch.

        Raises:
            StopIteration: when the underlying iterator is exhausted, when
                `samples` samples have already been returned in this epoch,
                or when the iterator was released and `reset` has not been
                called since.
        """
        if self._epoch < 0:
            self.reset()

        if self._data_iter is None:
            raise StopIteration("no more data in " + str(self))

        # Enforce the cap *before* pulling from the iterator, so that exactly
        # `samples` samples are returned. Checking after the fetch (as was
        # done previously) discarded the last allowed sample.
        if self._samples > 0 and self._pos >= self._samples:
            self._data_iter = None
            self._drained = True
            raise StopIteration("no more data in " + str(self))

        try:
            sample = next(self._data_iter)
        except StopIteration:
            # Record (or refresh) the observed epoch length.
            if self._sample_num <= 0:
                self._sample_num = self._pos
            elif self._sample_num != self._pos:
                logger.info('num of loaded samples is different '
                            'with previouse setting[prev:%d,now:%d]' %
                            (self._sample_num, self._pos))
                self._sample_num = self._pos
            self._data_iter = None
            self._drained = True
            # Bare raise keeps the original traceback.
            raise

        self._pos += 1
        return sample

    def reset(self):
        """Start a new epoch, creating a new iterator if none is active."""
        if self._data_iter is None:
            self._data_iter = self._iter_maker()

        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self._pos = 0
        self._drained = False

    def size(self):
        """Return the sample count observed at the last exhaustion, or -1
        if the underlying iterator has never been exhausted."""
        return self._sample_num

    def drained(self):
        """Return whether the position has reached `size()`.

        NOTE(review): `size()` is -1 until the iterator has been exhausted
        once, so this returns True throughout the first epoch -- confirm
        callers expect that.
        """
        assert self._epoch >= 0, "the first epoch has not started yet"
        return self._pos >= self.size()

    def epoch_id(self):
        """Return the zero-based id of the current epoch (-1 before start)."""
        return self._epoch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# load data records from local files(maybe in COCO or VOC data formats)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import numpy as np
import logging
import pickle as pkl
logger = logging.getLogger(__name__)
def check_records(records):
    """ Verify that every record contains the mandatory annotation fields.

    Args:
        records (iterable of dict): records to validate

    Raises:
        AssertionError: naming the first missing field (in the order listed
            below) and the index of the offending record.
    """
    required = (
        'im_file',
        'im_id',
        'h',
        'w',
        'is_crowd',
        'gt_class',
        'gt_bbox',
        'gt_poly',
    )
    for index, record in enumerate(records):
        absent = [name for name in required if name not in record]
        assert not absent, \
            'not found field[%s] in record[%d]' % (absent[0], index)
def load_roidb(anno_file, sample_num=-1):
    """ Load normalized data records from a pickled '.roidb' file.

    The file must unpickle to a pair `(records, cname2cid)`. Each record is
    expected (not verified here) to have the structure:
    {
        'im_file': str, # image file name
        'im_id': int, # image id
        'h': int, # height of image
        'w': int, # width of image
        'is_crowd': bool,
        'gt_class': list of np.ndarray, # classids info
        'gt_bbox': list of np.ndarray, # bounding box info
        'gt_poly': list of int, # poly info
    }

    Args:
        anno_file (str): file name of the pickled records; must end with
            '.roidb'
        sample_num (int): keep at most this many records; values <= 0 keep
            all of them

    Returns:
        tuple: (records, cname2cid) where 'records' is a list of records for
            detection model training

    Raises:
        AssertionError: if the file name does not end with '.roidb' or the
            unpickled records are not a list
    """
    assert anno_file.endswith('.roidb'), 'invalid roidb file[%s]' % (anno_file)
    with open(anno_file, 'rb') as f:
        roidb = f.read()

    # NOTE(security): unpickling can execute arbitrary code; only load roidb
    # files that come from a trusted source.
    try:
        # Python 3: decode Python 2 pickles without mangling byte strings.
        records, cname2cid = pkl.loads(roidb, encoding='bytes')
    except TypeError:
        # Python 2: pickle.loads() does not accept the 'encoding' keyword.
        # Catching only TypeError (instead of a bare except) avoids
        # swallowing KeyboardInterrupt/SystemExit.
        records, cname2cid = pkl.loads(roidb)

    assert isinstance(records, list), 'invalid data type from roidb'
    if 0 < sample_num < len(records):
        records = records[:sample_num]

    return records, cname2cid
def load(fname,
         samples=-1,
         with_background=True,
         with_cat2id=False,
         use_default_label=None,
         cname2cid=None):
    """ Load data records from 'fname', dispatching on the file name.

    Dispatch order: '*.roidb' (pickled records), '*.json' (COCO),
    any name containing 'wider_face' (WiderFace), any other existing file
    (VOC-style list file).

    Args:
        fname (str): file name for data record, eg:
            instances_val2017.json or COCO17_val2017.roidb
        samples (int): number of samples to load, default to all
        with_background (bool): whether load background as a class.
            default True.
        with_cat2id (bool): whether return cname2cid info out
        use_default_label (bool): whether use the default mapping of label
            to id (VOC list files only)
        cname2cid (dict): the mapping of category name to id (VOC list
            files only)

    Returns:
        The list of loaded records, or the pair (records, cname2cid) when
        'with_cat2id' is True. Non-WiderFace records have the structure:
        {
            'im_file': str, # image file name
            'im_id': int, # image id
            'h': int, # height of image
            'w': int, # width of image
            'is_crowd': bool,
            'gt_class': list of np.ndarray, # classids info
            'gt_bbox': list of np.ndarray, # bounding box info
            'gt_poly': list of int, # poly info
        }

    Raises:
        ValueError: if 'fname' matches none of the supported kinds
    """
    if fname.endswith('.roidb'):
        records, cname2cid = load_roidb(fname, samples)
    elif fname.endswith('.json'):
        from . import coco_loader
        records, cname2cid = coco_loader.load(fname, samples, with_background)
    elif "wider_face" in fname:
        from . import widerface_loader
        # The WiderFace loader returns a (records, cname2cid) pair. It used
        # to be returned as-is under the name 'records', which handed a
        # tuple to callers that asked for records only (with_cat2id=False).
        records, cname2cid = widerface_loader.load(fname, samples)
        # check_records is intentionally skipped for this branch, as before.
        # NOTE(review): presumably because WiderFace records lack fields
        # such as 'h'/'w' -- confirm.
        if with_cat2id:
            return records, cname2cid
        return records
    elif os.path.isfile(fname):
        from . import voc_loader
        if use_default_label is None or cname2cid is not None:
            records, cname2cid = voc_loader.get_roidb(
                fname, samples, cname2cid, with_background=with_background)
        else:
            records, cname2cid = voc_loader.load(
                fname,
                samples,
                use_default_label,
                with_background=with_background)
    else:
        raise ValueError(
            'invalid file type when load data from file[%s]' % (fname))

    check_records(records)
    if with_cat2id:
        return records, cname2cid
    else:
        return records
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#function:
# interface to load data from local files and parse it for samples,
# eg: roidb data in pickled files
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import random
import copy
import pickle as pkl
from ..dataset import Dataset
class RoiDbSource(Dataset):
    """ interface to load roidb data from files
    """

    def __init__(self,
                 anno_file,
                 image_dir=None,
                 samples=-1,
                 is_shuffle=True,
                 load_img=False,
                 cname2cid=None,
                 use_default_label=None,
                 mixup_epoch=-1,
                 with_background=True):
        """ Init

        Args:
            anno_file (str): label file path (an existing file or directory)
            image_dir (str): root dir for images
            samples (int): samples to load, -1 means all
            is_shuffle (bool): whether to shuffle samples
            load_img (bool): whether load data in this class
            cname2cid (dict): the label name to id dictionary
            use_default_label (bool):whether use the default mapping of
                label to id
            mixup_epoch (int): parse mixup in first n epoch
            with_background (bool): whether load background
                as a class
        """
        super(RoiDbSource, self).__init__()
        self._epoch = -1
        assert os.path.isfile(anno_file) or os.path.isdir(anno_file), \
            'anno_file {} is not a file or a directory'.format(anno_file)
        self._fname = anno_file
        self._image_dir = image_dir if image_dir is not None else ''
        if image_dir is not None:
            assert os.path.isdir(image_dir), \
                'image_dir {} is not a directory'.format(image_dir)
        # Records are loaded lazily on the first reset().
        self._roidb = None
        self._pos = -1
        self._drained = False
        self._samples = samples
        self._is_shuffle = is_shuffle
        self._load_img = load_img
        self.use_default_label = use_default_label
        self._mixup_epoch = mixup_epoch
        self._with_background = with_background
        self.cname2cid = cname2cid
        self._imid2path = None

    def __str__(self):
        # size() needs the records; report 0 before the first load instead
        # of failing with a TypeError on len(None).
        size = self.size() if self._roidb is not None else 0
        return 'RoiDbSource(fname:%s,epoch:%d,size:%d,pos:%d)' \
            % (self._fname, self._epoch, size, self._pos)

    def next(self):
        """ load next sample

        Returns:
            dict: a deep copy of the current record. 'im_file' is joined
            with the image dir, or 'image' holds the raw file bytes when
            'load_img' is set. During the first 'mixup_epoch' epochs a
            second record is attached under 'mixup'.

        Raises:
            StopIteration: when all samples of this epoch were returned
        """
        if self._epoch < 0:
            self.reset()
        if self._pos >= self._samples:
            self._drained = True
            raise StopIteration('%s no more data' % (str(self)))
        sample = copy.deepcopy(self._roidb[self._pos])
        if self._load_img:
            sample['image'] = self._load_image(sample['im_file'])
        else:
            sample['im_file'] = os.path.join(self._image_dir, sample['im_file'])

        # Mixup needs a partner different from the current sample, so it is
        # skipped for a single-sample dataset (randint(1, 0) used to raise).
        # NOTE(review): assumes downstream ops tolerate a missing 'mixup'
        # key -- confirm.
        if self._epoch < self._mixup_epoch and self._samples > 1:
            # A non-zero offset modulo the size never selects the sample
            # itself.
            mix_idx = random.randint(1, self._samples - 1)
            mix_pos = (mix_idx + self._pos) % self._samples
            sample['mixup'] = copy.deepcopy(self._roidb[mix_pos])
            if self._load_img:
                sample['mixup']['image'] = \
                    self._load_image(sample['mixup']['im_file'])
            else:
                sample['mixup']['im_file'] = \
                    os.path.join(self._image_dir, sample['mixup']['im_file'])
        self._pos += 1
        return sample

    def _load(self):
        """ load data from file
        """
        from . import loader
        records, cname2cid = loader.load(self._fname, self._samples,
                                         self._with_background, True,
                                         self.use_default_label, self.cname2cid)
        self.cname2cid = cname2cid
        return records

    def _load_image(self, where):
        """ read the raw bytes of image 'where' (relative to the image dir)
        """
        fn = os.path.join(self._image_dir, where)
        with open(fn, 'rb') as f:
            return f.read()

    def reset(self):
        """ implementation of Dataset.reset
        """
        if self._roidb is None:
            self._roidb = self._load()

        # From here on '_samples' is the number of loaded records.
        self._samples = len(self._roidb)
        if self._is_shuffle:
            random.shuffle(self._roidb)

        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self._pos = 0
        self._drained = False

    def size(self):
        """ implementation of Dataset.size
        """
        return len(self._roidb)

    def drained(self):
        """ implementation of Dataset.drained
        """
        assert self._epoch >= 0, 'The first epoch has not begin!'
        return self._pos >= self.size()

    def epoch_id(self):
        """ return epoch id for latest sample
        """
        return self._epoch

    def get_imid2path(self):
        """return image id to image path map"""
        if self._imid2path is None:
            self._imid2path = {}
            for record in self._roidb:
                im_id = record['im_id']
                # 'im_id' is either a plain int or an indexable holding it.
                im_id = im_id if isinstance(im_id, int) else im_id[0]
                im_path = os.path.join(self._image_dir, record['im_file'])
                self._imid2path[im_id] = im_path
        return self._imid2path
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# interface to load data from txt file.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import copy
from ..dataset import Dataset
class SimpleSource(Dataset):
    """
    Load image files for testing purpose

    Args:
        images (list): list of path of images; None (default) means an
            empty list
        samples (int): number of samples to load, -1 means all
        load_img (bool): should images be loaded
    """

    def __init__(self, images=None, samples=-1, load_img=True, **kwargs):
        super(SimpleSource, self).__init__()
        self._epoch = -1
        # A None default replaces the former mutable `[]` default, which
        # was stored on the instance and therefore shared between sources.
        if images is None:
            images = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
        self._images = images
        self._fname = None
        # Records are built lazily on the first reset().
        self._simple = None
        self._pos = -1
        self._drained = False
        self._samples = samples
        self._load_img = load_img
        self._imid2path = {}

    def next(self):
        """Return a copy of the next record, with the raw file bytes under
        'image' when 'load_img' is set.

        Raises:
            StopIteration: when every record of this epoch was returned
        """
        if self._epoch < 0:
            self.reset()

        if self._pos >= self.size():
            self._drained = True
            raise StopIteration("no more data in " + str(self))
        else:
            sample = copy.deepcopy(self._simple[self._pos])
            if self._load_img:
                sample['image'] = self._load_image(sample['im_file'])

            self._pos += 1
            return sample

    def _load(self):
        """Build one record per image (at most 'samples' of them when
        'samples' > 0) and fill the image-id-to-path map."""
        ct = 0
        records = []
        for image in self._images:
            if self._samples > 0 and ct >= self._samples:
                break
            # The running index doubles as the image id.
            rec = {'im_id': np.array([ct]), 'im_file': image}
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "no image file found"
        return records

    def _load_image(self, where):
        """Read the raw bytes of the file at path 'where'."""
        with open(where, 'rb') as f:
            return f.read()

    def reset(self):
        """Start a new epoch, building the records on first use."""
        if self._simple is None:
            self._simple = self._load()

        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self._pos = 0
        self._drained = False

    def size(self):
        """Return the number of records (valid after the first reset)."""
        return len(self._simple)

    def drained(self):
        """Return whether every record of the current epoch was returned."""
        assert self._epoch >= 0, "the first epoch has not started yet"
        return self._pos >= self.size()

    def epoch_id(self):
        """Return the zero-based id of the current epoch (-1 before start)."""
        return self._epoch

    def get_imid2path(self):
        """return image id to image path map"""
        return self._imid2path
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import xml.etree.ElementTree as ET
def get_roidb(anno_path, sample_num=-1, cname2cid=None, with_background=True):
    """
    Load VOC records listed in the text file 'anno_path'.

    Notes:
        Each non-blank line of ${anno_path} holds an image path and an xml
        annotation path, both relative to the directory of ${anno_path}.
        Blank lines are skipped, as are lines whose xml file is missing.

    Args:
        anno_path (str): path of the list file
        sample_num (int): number of listed samples to process, -1 means all
        cname2cid (dict): the label name to id dictionary. When given, an
            unknown class name raises KeyError; when None, the mapping is
            built on the fly in order of first appearance.
        with_background (bool): whether background is a class. If True,
            id 0 is reserved for it and generated class ids start at 1.
            default True

    Returns:
        [records, cname2cid]
        'records' is list of dict whose structure is:
        {
            'im_file': im_fname, # image file name
            'im_id': im_id, # image id
            'h': im_h, # height of image
            'w': im_w, # width
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_score': gt_score,
            'gt_bbox': gt_bbox,
            'gt_poly': gt_poly,
            'difficult': difficult,
        }
        'cname2cid' is a dict to map category name to class id

    Raises:
        ValueError: if a non-blank line has fewer than two fields
        KeyError: if 'cname2cid' was given and lacks an annotated class
        AssertionError: if no record with at least one object was found
    """
    data_dir = os.path.dirname(anno_path)

    records = []
    ct = 0
    # A caller-supplied mapping is authoritative and is never extended.
    existence = cname2cid is not None
    if cname2cid is None:
        cname2cid = {}

    with open(anno_path, 'r') as fr:
        for line in fr:
            fields = line.strip().split()
            if not fields:
                # Blank lines (typically a trailing newline) used to crash
                # the tuple unpacking below.
                continue
            if len(fields) < 2:
                raise ValueError(
                    'invalid line[%s] in %s, expected "<image> <xml>"' %
                    (line.strip(), anno_path))
            img_file, xml_file = [
                os.path.join(data_dir, x) for x in fields[:2]
            ]
            if not os.path.isfile(xml_file):
                continue
            tree = ET.parse(xml_file)
            id_node = tree.find('id')
            if id_node is None:
                # Fall back to the running index when the xml has no <id>.
                im_id = np.array([ct])
            else:
                im_id = np.array([int(id_node.text)])

            objs = tree.findall('object')
            im_w = float(tree.find('size').find('width').text)
            im_h = float(tree.find('size').find('height').text)
            gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
            gt_class = np.zeros((len(objs), 1), dtype=np.int32)
            gt_score = np.ones((len(objs), 1), dtype=np.float32)
            is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
            difficult = np.zeros((len(objs), 1), dtype=np.int32)
            for i, obj in enumerate(objs):
                cname = obj.find('name').text
                if cname not in cname2cid:
                    if existence:
                        raise KeyError(
                            'Not found cname[%s] in cname2cid when map it to cid.'
                            % (cname))
                    # the background's id is 0, so need to add 1.
                    cname2cid[cname] = len(cname2cid) + int(with_background)
                gt_class[i][0] = cname2cid[cname]
                # <difficult> is optional; treat a missing tag as 0.
                difficult_node = obj.find('difficult')
                _difficult = 0 if difficult_node is None \
                    else int(difficult_node.text)
                bndbox = obj.find('bndbox')
                x1 = float(bndbox.find('xmin').text)
                y1 = float(bndbox.find('ymin').text)
                x2 = float(bndbox.find('xmax').text)
                y2 = float(bndbox.find('ymax').text)
                # Clip the box to the image.
                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(im_w - 1, x2)
                y2 = min(im_h - 1, y2)
                gt_bbox[i] = [x1, y1, x2, y2]
                is_crowd[i][0] = 0
                difficult[i][0] = _difficult
            voc_rec = {
                'im_file': img_file,
                'im_id': im_id,
                'h': im_h,
                'w': im_w,
                'is_crowd': is_crowd,
                'gt_class': gt_class,
                'gt_score': gt_score,
                'gt_bbox': gt_bbox,
                'gt_poly': [],
                'difficult': difficult
            }
            # Images without any object are dropped but still counted.
            if len(objs) != 0:
                records.append(voc_rec)

            ct += 1
            if sample_num > 0 and ct >= sample_num:
                break
    assert len(records) > 0, 'not found any voc record in %s' % (anno_path)
    return [records, cname2cid]
def load(anno_path, sample_num=-1, use_default_label=True,
         with_background=True):
    """
    Load VOC records listed in the text file 'anno_path', using either the
    default Pascal VOC label map or the 'label_list.txt' next to it.

    Notes:
        Each non-blank line of ${anno_path} holds an image path and an xml
        annotation path, both relative to the directory of ${anno_path}.
        Blank lines are skipped, as are lines whose xml file is missing.

    Args:
        @anno_path (str): path of the list file
        @sample_num (int): number of listed samples to process, -1 means all
        @use_default_label (bool): whether use the default mapping of label
            to id; if False the labels are read, one per line, from
            'label_list.txt' in the directory of ${anno_path}
        @with_background (bool): whether background is a class. If True,
            id 0 is reserved for it and class ids start at 1, otherwise
            they start at 0. default True

    Returns:
        [records, cname2cid]
        'records' is list of dict whose structure is:
        {
            'im_file': im_fname, # image file name
            'im_id': im_id, # image id
            'h': im_h, # height of image
            'w': im_w, # width
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_score': gt_score,
            'gt_bbox': gt_bbox,
            'gt_poly': gt_poly,
            'difficult': difficult,
        }
        'cname2cid' is a dict to map category name to class id

    Raises:
        ValueError: if a non-blank line has fewer than two fields
        KeyError: if an annotated class name is not in the label map
        AssertionError: if no record with at least one object was found
    """
    data_dir = os.path.dirname(anno_path)

    # mapping category name to class id
    # if with_background is True:
    #   background:0, first_class:1, second_class:2, ...
    # if with_background is False:
    #   first_class:0, second_class:1, ...
    records = []
    ct = 0
    cname2cid = {}
    if not use_default_label:
        label_path = os.path.join(data_dir, 'label_list.txt')
        with open(label_path, 'r') as fr:
            label_id = int(with_background)
            for line in fr.readlines():
                cname2cid[line.strip()] = label_id
                label_id += 1
    else:
        cname2cid = pascalvoc_label(with_background)

    with open(anno_path, 'r') as fr:
        for line in fr:
            fields = line.strip().split()
            if not fields:
                # Blank lines (typically a trailing newline) used to crash
                # the tuple unpacking below.
                continue
            if len(fields) < 2:
                raise ValueError(
                    'invalid line[%s] in %s, expected "<image> <xml>"' %
                    (line.strip(), anno_path))
            img_file, xml_file = [
                os.path.join(data_dir, x) for x in fields[:2]
            ]
            if not os.path.isfile(xml_file):
                continue
            tree = ET.parse(xml_file)
            id_node = tree.find('id')
            if id_node is None:
                # Fall back to the running index when the xml has no <id>.
                im_id = np.array([ct])
            else:
                im_id = np.array([int(id_node.text)])

            objs = tree.findall('object')
            im_w = float(tree.find('size').find('width').text)
            im_h = float(tree.find('size').find('height').text)
            gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
            gt_class = np.zeros((len(objs), 1), dtype=np.int32)
            gt_score = np.ones((len(objs), 1), dtype=np.float32)
            is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
            difficult = np.zeros((len(objs), 1), dtype=np.int32)
            for i, obj in enumerate(objs):
                cname = obj.find('name').text
                gt_class[i][0] = cname2cid[cname]
                # <difficult> is optional; treat a missing tag as 0.
                difficult_node = obj.find('difficult')
                _difficult = 0 if difficult_node is None \
                    else int(difficult_node.text)
                bndbox = obj.find('bndbox')
                x1 = float(bndbox.find('xmin').text)
                y1 = float(bndbox.find('ymin').text)
                x2 = float(bndbox.find('xmax').text)
                y2 = float(bndbox.find('ymax').text)
                # Clip the box to the image.
                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(im_w - 1, x2)
                y2 = min(im_h - 1, y2)
                gt_bbox[i] = [x1, y1, x2, y2]
                is_crowd[i][0] = 0
                difficult[i][0] = _difficult
            voc_rec = {
                'im_file': img_file,
                'im_id': im_id,
                'h': im_h,
                'w': im_w,
                'is_crowd': is_crowd,
                'gt_class': gt_class,
                'gt_score': gt_score,
                'gt_bbox': gt_bbox,
                'gt_poly': [],
                'difficult': difficult
            }
            # Images without any object are dropped but still counted.
            if len(objs) != 0:
                records.append(voc_rec)

            ct += 1
            if sample_num > 0 and ct >= sample_num:
                break
    assert len(records) > 0, 'not found any voc record in %s' % (anno_path)
    return [records, cname2cid]
def pascalvoc_label(with_background=True):
    """
    Build the default Pascal VOC mapping of category name to class id.

    Args:
        with_background (bool): if True ids run from 1 to 20 (0 is left for
            background), otherwise from 0 to 19. default True

    Returns:
        dict mapping each of the 20 category names to its class id
    """
    class_names = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    )
    first_id = 1 if with_background else 0
    return {
        cname: first_id + offset
        for offset, cname in enumerate(class_names)
    }
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import logging
logger = logging.getLogger(__name__)
def load(anno_path, sample_num=-1, cname2cid=None, with_background=True):
    """
    Load WiderFace records from the annotation text file 'anno_path'.

    Args:
        anno_path (str): path of the WiderFace annotation text file
        sample_num (int): number of samples to load, -1 means all
        cname2cid (dict): accepted but ignored; the mapping is always
            rebuilt from 'with_background'
        with_background (bool): whether background is a class, which
            decides the id given to 'face' in the returned mapping.
            default True

    Returns:
        (records, cname2cid)
        'records' is list of dict whose structure is:
        {
            'im_file': im_fname, # image file name
            'im_id': im_id, # image id
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
        }
        'cname2cid' is a dict to map category name to class id
    """
    entries = _load_file_list(anno_path)
    cname2cid = widerface_label(with_background)

    records = []
    ct = 0
    for item in entries:
        # item[0] is the image name; box lines start at index 2.
        im_fname = item[0]
        im_id = np.array([ct])
        num_boxes = len(item) - 2
        gt_bbox = np.zeros((num_boxes, 4), dtype=np.float32)
        gt_class = np.ones((num_boxes, 1), dtype=np.int32)

        for row, box_text in enumerate(item[2:]):
            parts = box_text.split(' ')
            xmin = float(parts[0])
            ymin = float(parts[1])
            w = float(parts[2])
            h = float(parts[3])
            # Filter out wrong labels; their rows stay all-zero.
            if w < 0 or h < 0:
                continue
            xmin = max(0, xmin)
            ymin = max(0, ymin)
            gt_bbox[row] = [xmin, ymin, xmin + w, ymin + h]

        if len(item) != 0:
            records.append({
                'im_file': im_fname,
                'im_id': im_id,
                'gt_bbox': gt_bbox,
                'gt_class': gt_class,
            })

        ct += 1
        if sample_num > 0 and ct >= sample_num:
            break

    assert len(records) > 0, 'not found any widerface in %s' % (anno_path)
    logger.info('{} samples in file {}'.format(ct, anno_path))
    return records, cname2cid
def _load_file_list(input_txt):
with open(input_txt, 'r') as f_dir:
lines_input_txt = f_dir.readlines()
file_dict = {}
num_class = 0
for i in range(len(lines_input_txt)):
line_txt = lines_input_txt[i].strip('\n\t\r')
if '.jpg' in line_txt:
if i != 0:
num_class += 1
file_dict[num_class] = []
file_dict[num_class].append(line_txt)
if '.jpg' not in line_txt:
if len(line_txt) > 6:
split_str = line_txt.split(' ')
x1_min = float(split_str[0])
y1_min = float(split_str[1])
x2_max = float(split_str[2])
y2_max = float(split_str[3])
line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
x2_max) + ' ' + str(y2_max)
file_dict[num_class].append(line_txt)
else:
file_dict[num_class].append(line_txt)
return list(file_dict.values())
def widerface_label(with_background=True):
    """
    Build the WiderFace mapping of category name to class id.

    Args:
        with_background (bool): if True 'face' gets id 1 (0 is left for
            background), otherwise id 0. default True

    Returns:
        dict with the single key 'face'
    """
    face_id = 1 if with_background else 0
    return {'face': face_id}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
import copy
import logging
import traceback
from .transformer import MappedDataset, BatchedDataset
from .post_map import build_post_map
from .parallel_map import ParallelMappedDataset
from .operators import BaseOperator, registered_ops
__all__ = ['build_mapper', 'map', 'batch', 'batch_map']
logger = logging.getLogger(__name__)
def build_mapper(ops, context=None):
    """
    Build a mapper for operators in 'ops'

    Args:
        ops (list of operator.BaseOperator or list of op dict):
            configs for oprators, eg:
            [{'name': 'DecodeImage', 'params': {'to_rgb': True}}, {xxx}]
            or, in the flat form, [{'op': 'DecodeImage', 'to_rgb': True}].
            Keys of an op dict are matched case-insensitively.
        context (dict): a context object for mapper; every call of the
            mapper works on its own deep copy

    Returns:
        a mapper function which accept one argument 'sample' and
        return the processed result (the sample itself when 'ops' is
        empty). Its 'ops' attribute holds a printable description of the
        operators.
    """
    # Lower-case the keys of op configs. Operator instances are passed
    # through untouched: calling .items() on them used to raise, which made
    # the documented "list of BaseOperator" input unusable.
    normalized = []
    for op in ops:
        if not isinstance(op, BaseOperator):
            op = {key.lower(): value for key, value in op.items()}
        normalized.append(op)

    op_funcs = []
    op_repr = []
    for op in normalized:
        if isinstance(op, BaseOperator):
            o = op
        elif 'op' in op:
            # Flat form: every key other than 'op' is a constructor param.
            op_func = getattr(BaseOperator, op['op'])
            params = copy.deepcopy(op)
            del params['op']
            o = op_func(**params)
        else:
            # Nested form: {'name': ..., 'params': {...}}.
            op_func = getattr(BaseOperator, op['name'])
            params = {} if 'params' not in op else op['params']
            o = op_func(**params)

        op_funcs.append(o)
        op_repr.append('{{{}}}'.format(str(o)))
    op_repr = '[{}]'.format(','.join(op_repr))

    def _mapper(sample):
        ctx = {} if context is None else copy.deepcopy(context)
        for f in op_funcs:
            try:
                sample = f(sample, ctx)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".format(
                        f, e, str(stack_info)))
                # Bare raise keeps the original traceback.
                raise
        # Returning 'sample' (not a loop-local) also covers an empty op
        # list, which used to fail with an unbound local variable.
        return sample

    _mapper.ops = op_repr
    return _mapper
def map(ds, mapper, worker_args=None):
    """
    Wrap 'ds' so that 'mapper' is applied to its samples.

    Args:
        ds (instance of Dataset): dataset to be mapped
        mapper (function): action to be executed for every data sample
        worker_args (dict): configs for concurrent mapper; when None the
            non-parallel wrapper is used

    Returns:
        a ParallelMappedDataset if 'worker_args' is given, otherwise a
        MappedDataset
    """
    if worker_args is None:
        return MappedDataset(ds, mapper)
    return ParallelMappedDataset(ds, mapper, worker_args)
def batch(ds, batchsize, drop_last=False, drop_empty=True):
    """
    Wrap 'ds' in a BatchedDataset.

    Args:
        ds (instance of Dataset): dataset to be batched
        batchsize (int): number of samples for a batch
        drop_last (bool): drop last few samples if not enough for a batch
        drop_empty (bool): forwarded unchanged to BatchedDataset

    Returns:
        a batched dataset
    """
    options = {'drop_last': drop_last, 'drop_empty': drop_empty}
    return BatchedDataset(ds, batchsize, **options)
def batch_map(ds, config):
    """
    Post process the batches.

    Args:
        ds (instance of Dataset): dataset to be mapped
        config (dict): keyword arguments passed to build_post_map to
            create the mapper applied to 'ds'

    Returns:
        a MappedDataset wrapping 'ds' with the post-processing mapper
    """
    return MappedDataset(ds, build_post_map(**config))
# Publish every registered operator as a module-level name and list it in
# __all__, so operators can be imported directly from this module.
for nm in registered_ops:
    op = getattr(BaseOperator, nm)
    # At module level globals() is the same namespace locals() refers to.
    globals()[nm] = op
__all__.extend(registered_ops)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# operators to process sample,
# eg: decode/resize/crop image
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
from .operators import BaseOperator, register_op
logger = logging.getLogger(__name__)
@register_op
class ArrangeRCNN(BaseOperator):
    """
    Arrange a sample dict into the tuple layout used for training.

    Args:
        is_mask (bool): whether to append mask (polygon) data
    """

    def __init__(self, is_mask=False):
        super(ArrangeRCNN, self).__init__()
        self.is_mask = is_mask
        assert isinstance(self.is_mask, bool), "wrong type for is_mask"

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image
                    info and annotation info.
            context: a dict which contains additional info.
        Returns:
            sample: a tuple containing following items
                (image, im_info, im_id, gt_bbox, gt_class, is_crowd),
                followed by gt_masks when 'is_mask' is set and the sample
                has a non-empty 'gt_poly'.
        Raises:
            KeyError: if 'is_crowd' or 'im_info' is missing
        """
        im = sample['image']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
        keys = list(sample.keys())
        if 'is_crowd' not in keys:
            raise KeyError("The dataset doesn't have 'is_crowd' key.")
        is_crowd = sample['is_crowd']
        if 'im_info' not in keys:
            raise KeyError("The dataset doesn't have 'im_info' key.")
        im_info = sample['im_info']
        im_id = sample['im_id']

        outs = (im, im_info, im_id, gt_bbox, gt_class, is_crowd)
        wants_masks = self.is_mask and len(sample['gt_poly']) != 0 \
            and 'is_crowd' in keys
        if not wants_masks:
            return outs

        segms = sample['gt_poly']
        assert len(segms) == is_crowd.shape[0]
        gt_masks = []
        valid = True
        for idx in range(len(segms)):
            segm, iscrowd = segms[idx], is_crowd[idx]
            if iscrowd:
                # Crowd instances get a single placeholder polygon.
                gt_segm = [[[0, 0]]]
            else:
                gt_segm = []
                for poly in segm:
                    if len(poly) == 0:
                        valid = False
                        break
                    gt_segm.append(np.array(poly).reshape(-1, 2))
            # Stop collecting at the first invalid or empty instance.
            if not valid or not gt_segm:
                break
            gt_masks.append(gt_segm)
        return outs + (gt_masks, )
@register_op
class ArrangeEvalRCNN(BaseOperator):
    """
    Arrange a sample dict into the tuple layout used for evaluation.
    """

    def __init__(self):
        super(ArrangeEvalRCNN, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image
                    info and annotation info.
            context: a dict which contains additional info.
        Returns:
            sample: a tuple containing the following items:
                    (image, im_info, im_id, im_shape, gt_bbox,
                    gt_class, difficult)
                    where 'image' stands for the values of every key that
                    contains 'image', in sorted key order.
        Raises:
            KeyError: if 'im_info' is missing
        """
        keys = sorted(list(sample.keys()))
        ims = [sample[key] for key in keys if 'image' in key]
        if 'im_info' not in keys:
            raise KeyError("The dataset doesn't have 'im_info' key.")
        im_info = sample['im_info']
        im_id = sample['im_id']
        # For rcnn models in eval and infer stage, original image size
        # is needed to clip the bounding boxes. And box clip op in
        # bbox prediction needs im_info as input in format of [N, 3],
        # so im_shape is appended by 1 to match dimension.
        im_shape = np.array((sample['h'], sample['w'], 1), dtype=np.float32)
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
        difficult = sample['difficult']
        ims += [im_info, im_id, im_shape, gt_bbox, gt_class, difficult]
        return tuple(ims)
@register_op
class ArrangeTestRCNN(BaseOperator):
    """
    Arrange a sample dict into the tuple layout used for inference.
    """

    def __init__(self):
        super(ArrangeTestRCNN, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image
                    info and annotation info.
            context: a dict which contains additional info.
        Returns:
            sample: a tuple containing the following items:
                    (image, im_info, im_id, im_shape)
                    where 'image' stands for the values of every key that
                    contains 'image', in sorted key order.
        Raises:
            KeyError: if 'im_info' is missing
        """
        keys = sorted(list(sample.keys()))
        ims = [sample[key] for key in keys if 'image' in key]
        if 'im_info' not in keys:
            raise KeyError("The dataset doesn't have 'im_info' key.")
        im_info = sample['im_info']
        im_id = sample['im_id']
        # For rcnn models in eval and infer stage, original image size
        # is needed to clip the bounding boxes. And box clip op in
        # bbox prediction needs im_info as input in format of [N, 3],
        # so im_shape is appended by 1 to match dimension.
        im_shape = np.array((sample['h'], sample['w'], 1), dtype=np.float32)
        ims += [im_info, im_id, im_shape]
        return tuple(ims)
@register_op
class ArrangeSSD(BaseOperator):
    """
    Arrange a sample dict into the tuple layout used for training.
    """

    def __init__(self):
        super(ArrangeSSD, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image
                    info and annotation info.
            context: a dict which contains additional info.
        Returns:
            sample: a tuple containing the following items:
                    (image, gt_bbox, gt_class)
        """
        return (sample['image'], sample['gt_bbox'], sample['gt_class'])
@register_op
class ArrangeEvalSSD(BaseOperator):
    """
    Arrange a sample dict into the tuple layout used for evaluation.

    Args:
        fields (list of str): names of the items to emit, in output order
    """

    def __init__(self, fields):
        super(ArrangeEvalSSD, self).__init__()
        self.fields = fields

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image
                    info and annotation info.
            context: a dict which contains additional info.
        Returns:
            sample: a tuple with one item per entry of 'fields'.
                    'im_shape' yields np.array((h, w)); 'is_difficult',
                    'gt_box' and 'gt_label' read the sample keys
                    'difficult', 'gt_bbox' and 'gt_class'; any other name
                    is read from the sample as-is.
        Raises:
            ValueError: if 'gt_bbox' and 'gt_class' differ in length
        """
        if len(sample['gt_bbox']) != len(sample['gt_class']):
            raise ValueError("gt num mismatch: bbox and class.")

        # Field names that differ from the key used inside the sample.
        aliases = {
            'is_difficult': 'difficult',
            'gt_box': 'gt_bbox',
            'gt_label': 'gt_class',
        }
        outs = []
        for field in self.fields:
            if field == 'im_shape':
                outs.append(np.array((sample['h'], sample['w'])))
            else:
                outs.append(sample[aliases.get(field, field)])
        return tuple(outs)
@register_op
class ArrangeTestSSD(BaseOperator):
    """
    Pack a sample dict into the tuple layout used for SSD inference.
    """

    def __init__(self):
        super(ArrangeTestSSD, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image info and annotation info.
            context: a dict which contains additional info (not read here).
        Returns:
            a tuple containing the following items:
                (image, im_id, im_shape), with im_shape = np.array((h, w))
        """
        image = sample['image']
        image_id = sample['im_id']
        height = sample['h']
        width = sample['w']
        return (image, image_id, np.array((height, width)))
@register_op
class ArrangeYOLO(BaseOperator):
    """
    Pack a sample dict into the fixed-size tuple layout used for YOLO
    training.
    """

    def __init__(self):
        super(ArrangeYOLO, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image info and annotation info.
            context: a dict which contains additional info (not read here).
        Returns:
            a tuple containing the following items:
                (image, gt_bbox, gt_class, gt_score)
            where the three gt arrays are zero-padded (or truncated) to 50
            rows and gt_bbox is converted to [center_x, center_y, w, h].
        Raises:
            ValueError: if gt_bbox differs in length from gt_class or
                gt_score.
        """
        image = sample['image']
        num_boxes = len(sample['gt_bbox'])
        if num_boxes != len(sample['gt_class']):
            raise ValueError("gt num mismatch: bbox and class.")
        if num_boxes != len(sample['gt_score']):
            raise ValueError("gt num mismatch: bbox and score.")

        max_boxes = 50
        boxes = np.zeros((max_boxes, 4), dtype=image.dtype)
        classes = np.zeros((max_boxes, ), dtype=np.int32)
        scores = np.zeros((max_boxes, ), dtype=image.dtype)
        keep = min(max_boxes, num_boxes)
        if keep > 0:
            boxes[:keep, :] = sample['gt_bbox'][:keep, :]
            classes[:keep] = sample['gt_class'][:keep, 0]
            scores[:keep] = sample['gt_score'][:keep, 0]

        # [x1, y1, x2, y2] -> [center_x, center_y, w, h]; the size columns
        # must be computed first because the centers are derived from them.
        boxes[:, 2:4] = boxes[:, 2:4] - boxes[:, :2]
        boxes[:, :2] = boxes[:, :2] + boxes[:, 2:4] / 2.
        return (image, boxes, classes, scores)
@register_op
class ArrangeEvalYOLO(BaseOperator):
    """
    Pack a sample dict into the fixed-size tuple layout used for YOLO
    evaluation.
    """

    def __init__(self):
        super(ArrangeEvalYOLO, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image info and annotation info.
            context: a dict which contains additional info (not read here).
        Returns:
            a tuple containing the following items:
                (image, im_shape, im_id, gt_bbox, gt_class, difficult)
            where im_shape = np.array((h, w)) and the three gt arrays are
            zero-padded (or truncated) to 50 rows.
        Raises:
            ValueError: if gt_bbox and gt_class differ in length.
        """
        image = sample['image']
        if len(sample['gt_bbox']) != len(sample['gt_class']):
            raise ValueError("gt num mismatch: bbox and class.")
        image_id = sample['im_id']
        height = sample['h']
        width = sample['w']
        original_hw = np.array((height, width))

        max_boxes = 50
        boxes = np.zeros((max_boxes, 4), dtype=image.dtype)
        classes = np.zeros((max_boxes, ), dtype=np.int32)
        hard_flags = np.zeros((max_boxes, ), dtype=np.int32)
        keep = min(max_boxes, len(sample['gt_bbox']))
        if keep > 0:
            boxes[:keep, :] = sample['gt_bbox'][:keep, :]
            classes[:keep] = sample['gt_class'][:keep, 0]
            hard_flags[:keep] = sample['difficult'][:keep, 0]
        return (image, original_hw, image_id, boxes, classes, hard_flags)
@register_op
class ArrangeTestYOLO(BaseOperator):
    """
    Pack a sample dict into the tuple layout used for YOLO inference.
    """

    def __init__(self):
        super(ArrangeTestYOLO, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict which contains image info and annotation info.
            context: a dict which contains additional info (not read here).
        Returns:
            a tuple containing the following items:
                (image, im_shape, im_id), with im_shape = np.array((h, w))
        """
        image = sample['image']
        image_id = sample['im_id']
        height = sample['h']
        width = sample['w']
        return (image, np.array((height, width)), image_id)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this file contains helper methods for BBOX processing
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import math
import cv2
def meet_emit_constraint(src_bbox, sample_bbox):
    """Return True when the center of ``src_bbox`` lies inside ``sample_bbox``.

    Both boxes are indexed as [xmin, ymin, xmax, ymax]; the sample box
    borders count as inside.
    """
    center_x = (src_bbox[2] + src_bbox[0]) / 2
    center_y = (src_bbox[3] + src_bbox[1]) / 2
    inside_x = sample_bbox[0] <= center_x <= sample_bbox[2]
    inside_y = sample_bbox[1] <= center_y <= sample_bbox[3]
    return bool(inside_x and inside_y)
def clip_bbox(src_bbox):
    """Clamp the first four entries of ``src_bbox`` to [0.0, 1.0] in place.

    The same (mutated) object is returned.
    """
    for idx in range(4):
        src_bbox[idx] = max(min(src_bbox[idx], 1.0), 0.0)
    return src_bbox
def bbox_area(src_bbox):
    """Return the area of an [xmin, ymin, xmax, ymax] box.

    A box whose max coordinate is below its min coordinate on either axis
    has area 0.
    """
    xmin, ymin, xmax, ymax = (src_bbox[0], src_bbox[1], src_bbox[2],
                              src_bbox[3])
    if xmax < xmin or ymax < ymin:
        return 0.
    return (xmax - xmin) * (ymax - ymin)
def is_overlap(object_bbox, sample_bbox):
    """Return True when the two [xmin, ymin, xmax, ymax] boxes overlap.

    Boxes that only touch along an edge are treated as not overlapping.
    """
    separated = (object_bbox[0] >= sample_bbox[2]
                 or object_bbox[2] <= sample_bbox[0]
                 or object_bbox[1] >= sample_bbox[3]
                 or object_bbox[3] <= sample_bbox[1])
    return not separated
def filter_and_process(sample_bbox, bboxes, labels, scores=None):
    """Keep the boxes that fall into ``sample_bbox`` and re-express them in
    its coordinate frame.

    A box is kept when its center is inside ``sample_bbox``, it overlaps
    ``sample_bbox``, and it still has positive area after being rescaled
    to the sample frame and clipped.

    Args:
        sample_bbox: the crop window, [xmin, ymin, xmax, ymax].
        bboxes: boxes indexed as bboxes[i][0..3].
        labels: labels indexed as labels[i][0].
        scores: optional scores indexed as scores[i][0].
    Returns:
        (bboxes, labels, scores) as numpy arrays; ``scores`` is an empty
        array when no scores were given.
    """
    kept_boxes = []
    kept_labels = []
    kept_scores = []
    for idx, box in enumerate(bboxes):
        obj_bbox = [box[0], box[1], box[2], box[3]]
        if not (meet_emit_constraint(obj_bbox, sample_bbox)
                and is_overlap(obj_bbox, sample_bbox)):
            continue

        span_w = sample_bbox[2] - sample_bbox[0]
        span_h = sample_bbox[3] - sample_bbox[1]
        # Shift to the sample origin and normalize by the sample size.
        relative = clip_bbox([
            (obj_bbox[0] - sample_bbox[0]) / span_w,
            (obj_bbox[1] - sample_bbox[1]) / span_h,
            (obj_bbox[2] - sample_bbox[0]) / span_w,
            (obj_bbox[3] - sample_bbox[1]) / span_h,
        ])
        if bbox_area(relative) > 0:
            kept_boxes.append(relative)
            kept_labels.append([labels[idx][0]])
            if scores is not None:
                kept_scores.append([scores[idx][0]])
    return np.array(kept_boxes), np.array(kept_labels), np.array(kept_scores)
def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
    """Drop boxes whose area, once scaled by ``target_size``, is below
    ``min_size * min_size``.

    Args:
        bboxes: boxes indexed as [xmin, ymin, xmax, ymax].
        labels: labels aligned with ``bboxes``.
        scores: None, or an array aligned with ``bboxes`` (its ``size`` is
            read, so it is expected to be array-like).
        target_size: factor applied to both box width and height.
        min_size: minimum side length; compared as an area.
    Returns:
        (bboxes, labels, scores) as numpy arrays holding the kept entries.
    """
    kept_boxes = []
    kept_labels = []
    kept_scores = []
    for idx, box in enumerate(bboxes):
        scaled_w = float((box[2] - box[0]) * target_size)
        scaled_h = float((box[3] - box[1]) * target_size)
        if not (scaled_w * scaled_h < float(min_size * min_size)):
            kept_boxes.append(box)
            kept_labels.append(labels[idx])
            if scores is not None and scores.size != 0:
                kept_scores.append(scores[idx])
    return np.array(kept_boxes), np.array(kept_labels), np.array(kept_scores)
def generate_sample_bbox(sampler):
    """Draw a random [xmin, ymin, xmax, ymax] box from a sampler spec.

    The scale is drawn uniformly from [sampler[2], sampler[3]] and the
    aspect ratio from [sampler[4], sampler[5]]; the ratio is then limited
    to [scale**2, 1 / scale**2]. The top-left corner is drawn uniformly
    between 0 and (1 - side length) on each axis.
    """
    scale = np.random.uniform(sampler[2], sampler[3])
    ratio = np.random.uniform(sampler[4], sampler[5])
    ratio = min(max(ratio, (scale**2.0)), 1 / (scale**2.0))

    ratio_root = ratio**0.5
    box_w = scale * ratio_root
    box_h = scale / ratio_root

    # Draw order (x first, then y) is kept so seeded runs are unchanged.
    left = np.random.uniform(0, 1 - box_w)
    top = np.random.uniform(0, 1 - box_h)
    return [left, top, left + box_w, top + box_h]
def generate_sample_bbox_square(sampler, image_width, image_height):
    """Draw a random box as ``generate_sample_bbox`` does, then rescale one
    side by the image aspect ratio.

    When the image is wider than tall the width is recomputed from the
    height (height * image_height / image_width); otherwise the height is
    recomputed from the width (width * image_width / image_height).

    Returns:
        [xmin, ymin, xmax, ymax]
    """
    scale = np.random.uniform(sampler[2], sampler[3])
    ratio = np.random.uniform(sampler[4], sampler[5])
    ratio = min(max(ratio, (scale**2.0)), 1 / (scale**2.0))

    ratio_root = ratio**0.5
    box_w = scale * ratio_root
    box_h = scale / ratio_root
    if image_height < image_width:
        box_w = box_h * image_height / image_width
    else:
        box_h = box_w * image_width / image_height

    # Draw order (x first, then y) is kept so seeded runs are unchanged.
    left = np.random.uniform(0, 1 - box_w)
    top = np.random.uniform(0, 1 - box_h)
    return [left, top, left + box_w, top + box_h]
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
                         resize_width):
    """Sample a square crop window around one randomly chosen box.

    Args:
        bbox_labels: sequence of boxes; entries 0..3 of each box are read
            as normalized xmin, ymin, xmax, ymax (they are multiplied by
            the image width/height below).
        image_width: image width used to de-normalize and re-normalize.
        image_height: image height used to de-normalize and re-normalize.
        scale_array: indexable sequence of anchor scales; the box area is
            compared against the squares of consecutive entries.
            # presumably sorted ascending - TODO confirm against callers
        resize_width: multiplier applied to the box width when deriving
            the crop side length.
    Returns:
        [xmin, ymin, xmax, ymax] of the crop window, normalized by the
        image size, or the integer 0 when ``bbox_labels`` is empty.
    """
    num_gt = len(bbox_labels)
    # np.random.randint range: [low, high)
    rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
    if num_gt != 0:
        norm_xmin = bbox_labels[rand_idx][0]
        norm_ymin = bbox_labels[rand_idx][1]
        norm_xmax = bbox_labels[rand_idx][2]
        norm_ymax = bbox_labels[rand_idx][3]
        # Chosen box in pixel units.
        xmin = norm_xmin * image_width
        ymin = norm_ymin * image_height
        wid = image_width * (norm_xmax - norm_xmin)
        hei = image_height * (norm_ymax - norm_ymin)
        range_size = 0
        area = wid * hei
        # Find the first pair of consecutive scales whose squares bracket
        # the box area.
        for scale_ind in range(0, len(scale_array) - 1):
            if area > scale_array[scale_ind] ** 2 and area < \
                    scale_array[scale_ind + 1] ** 2:
                range_size = scale_ind + 1
                break
        # A box larger than the second-to-last scale overrides the result
        # of the search above.
        if area > scale_array[len(scale_array) - 2]**2:
            range_size = len(scale_array) - 2
        scale_choose = 0.0
        if range_size == 0:
            rand_idx_size = 0
        else:
            # np.random.randint range: [low, high)
            rng_rand_size = np.random.randint(0, range_size + 1)
            # The draw is already below range_size + 1, so this modulo
            # leaves the value unchanged.
            rand_idx_size = rng_rand_size % (range_size + 1)
        if rand_idx_size == range_size:
            # For the top candidate scale the upper bound is also capped
            # at twice the box's geometric-mean side length.
            min_resize_val = scale_array[rand_idx_size] / 2.0
            max_resize_val = min(2.0 * scale_array[rand_idx_size],
                                 2 * math.sqrt(wid * hei))
            scale_choose = random.uniform(min_resize_val, max_resize_val)
        else:
            min_resize_val = scale_array[rand_idx_size] / 2.0
            max_resize_val = 2.0 * scale_array[rand_idx_size]
            scale_choose = random.uniform(min_resize_val, max_resize_val)
        # Side length of the square crop, in pixels.
        sample_bbox_size = wid * resize_width / scale_choose
        w_off_orig = 0.0
        h_off_orig = 0.0
        if sample_bbox_size < max(image_height, image_width):
            # Per axis: when the box fits in the crop, pick an offset that
            # keeps the whole box inside it; otherwise pick an offset that
            # keeps the crop inside the box.
            if wid <= sample_bbox_size:
                w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
                                               xmin)
            else:
                w_off_orig = np.random.uniform(xmin,
                                               xmin + wid - sample_bbox_size)
            if hei <= sample_bbox_size:
                h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
                                               ymin)
            else:
                h_off_orig = np.random.uniform(ymin,
                                               ymin + hei - sample_bbox_size)
        else:
            # Crop is at least as large as the longer image side: draw the
            # offset between (image size - crop size) and 0.
            w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
            h_off_orig = np.random.uniform(image_height - sample_bbox_size, 0.0)
        w_off_orig = math.floor(w_off_orig)
        h_off_orig = math.floor(h_off_orig)
        # Figure out top left coordinates.
        w_off = float(w_off_orig / image_width)
        h_off = float(h_off_orig / image_height)
        sampled_bbox = [
            w_off, h_off, w_off + float(sample_bbox_size / image_width),
            h_off + float(sample_bbox_size / image_height)
        ]
        return sampled_bbox
    else:
        # No ground-truth boxes: callers receive 0 instead of a box.
        return 0
def jaccard_overlap(sample_bbox, object_bbox):
    """Return the intersection-over-union of two [xmin, ymin, xmax, ymax]
    boxes.

    Boxes that are disjoint, or that only touch along an edge, yield 0.
    """
    disjoint = (sample_bbox[0] >= object_bbox[2]
                or sample_bbox[2] <= object_bbox[0]
                or sample_bbox[1] >= object_bbox[3]
                or sample_bbox[3] <= object_bbox[1])
    if disjoint:
        return 0

    left = max(sample_bbox[0], object_bbox[0])
    top = max(sample_bbox[1], object_bbox[1])
    right = min(sample_bbox[2], object_bbox[2])
    bottom = min(sample_bbox[3], object_bbox[3])
    shared = (right - left) * (bottom - top)
    union = bbox_area(sample_bbox) + bbox_area(object_bbox) - shared
    return shared / union
def intersect_bbox(bbox1, bbox2):
    """Return the intersection box of two [xmin, ymin, xmax, ymax] boxes.

    Strictly separated boxes give [0.0, 0.0, 0.0, 0.0]; boxes that merely
    touch give a degenerate (zero-width or zero-height) box.
    """
    separated = (bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0]
                 or bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1])
    if separated:
        return [0.0, 0.0, 0.0, 0.0]
    return [
        max(bbox1[0], bbox2[0]),
        max(bbox1[1], bbox2[1]),
        min(bbox1[2], bbox2[2]),
        min(bbox1[3], bbox2[3]),
    ]
def bbox_coverage(bbox1, bbox2):
    """Return the fraction of ``bbox1``'s area that is covered by ``bbox2``.

    Yields 0. when the intersection has no positive area.
    """
    shared_area = bbox_area(intersect_bbox(bbox1, bbox2))
    if not shared_area > 0:
        return 0.
    return shared_area / bbox_area(bbox1)
def satisfy_sample_constraint(sampler,
                              sample_bbox,
                              gt_bboxes,
                              satisfy_all=False):
    """Check a sampled box against the sampler's jaccard-overlap limits.

    ``sampler[6]`` and ``sampler[7]`` are the minimum and maximum allowed
    overlap; a value of 0 disables that limit, and when both are 0 every
    sample is accepted.

    Args:
        sampler: sampler spec; only entries 6 and 7 are read here.
        sample_bbox: candidate box, [xmin, ymin, xmax, ymax].
        gt_bboxes: ground-truth boxes indexed as gt_bboxes[i][0..3].
        satisfy_all: when False, one ground-truth box within the limits is
            enough; when True, every ground-truth box must be within them.
    Returns:
        True/False; with ``satisfy_all`` set the result comes from
        ``np.all`` over the per-box verdicts.
    """
    min_overlap = sampler[6]
    max_overlap = sampler[7]
    if min_overlap == 0 and max_overlap == 0:
        return True

    verdicts = []
    for gt_box in gt_bboxes:
        object_bbox = [gt_box[0], gt_box[1], gt_box[2], gt_box[3]]
        overlap = jaccard_overlap(sample_bbox, object_bbox)
        too_low = min_overlap != 0 and overlap < min_overlap
        too_high = max_overlap != 0 and overlap > max_overlap
        within_limits = not (too_low or too_high)
        verdicts.append(within_limits)
        if within_limits and not satisfy_all:
            return True

    return np.all(verdicts) if satisfy_all else False
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
    """Check a sampled box against jaccard-overlap and object-coverage limits.

    ``sampler[6]``/``sampler[7]`` are the min/max jaccard overlap and
    ``sampler[8]``/``sampler[9]`` the min/max object coverage; a value of 0
    disables that limit. A constraint group is active when at least one of
    its two limits is non-zero.

    Args:
        sampler: sampler spec; entries 6 to 9 are read here.
        sample_bbox: candidate box, [xmin, ymin, xmax, ymax].
        gt_bboxes: ground-truth boxes indexed as gt_bboxes[i][0..3].
    Returns:
        True when no constraint group is active, or when some ground-truth
        box passes the checks below; otherwise the final value of the
        ``found`` flag.
    """
    if sampler[6] == 0 and sampler[7] == 0:
        has_jaccard_overlap = False
    else:
        has_jaccard_overlap = True
    if sampler[8] == 0 and sampler[9] == 0:
        has_object_coverage = False
    else:
        has_object_coverage = True
    # Nothing to enforce: accept every sample.
    if not has_jaccard_overlap and not has_object_coverage:
        return True
    found = False
    for i in range(len(gt_bboxes)):
        object_bbox = [
            gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
        ]
        if has_jaccard_overlap:
            overlap = jaccard_overlap(sample_bbox, object_bbox)
            if sampler[6] != 0 and \
                    overlap < sampler[6]:
                continue
            if sampler[7] != 0 and \
                    overlap > sampler[7]:
                continue
            found = True
        if has_object_coverage:
            # Coverage is measured relative to the object box (first
            # argument), not the sample box.
            object_coverage = bbox_coverage(object_bbox, sample_bbox)
            # NOTE(review): when both groups are active and this box passed
            # the jaccard test, `found` is already True here; the
            # `continue`s below skip the early return but do not reset the
            # flag, so the final `return found` can still report True for
            # a box that failed the coverage test. Confirm whether that is
            # intended before changing it.
            if sampler[8] != 0 and \
                    object_coverage < sampler[8]:
                continue
            if sampler[9] != 0 and \
                    object_coverage > sampler[9]:
                continue
            found = True
        if found:
            return True
    return found
def crop_image_sampling(img, sample_bbox, image_width, image_height,
                        target_size):
    """Cut the (possibly out-of-bounds) sample window out of ``img`` and
    resize it to a ``target_size`` x ``target_size`` image.

    The window is pasted onto a zero-filled canvas of the window's own
    size, so any part that falls outside the image stays zero.

    Args:
        img: source image, indexed as img[y, x].
        sample_bbox: window [xmin, ymin, xmax, ymax] normalized by the
            image size; it is not clipped to the image.
        image_width: image width in pixels.
        image_height: image height in pixels.
        target_size: side length of the returned square image.
    Returns:
        the resized crop (cv2.INTER_AREA interpolation).
    """
    # no clipping here
    left = int(sample_bbox[0] * image_width)
    right = int(sample_bbox[2] * image_width)
    top = int(sample_bbox[1] * image_height)
    bottom = int(sample_bbox[3] * image_height)
    window_w = right - left
    window_h = bottom - top

    # Part of the window that lies inside the image (source region).
    src_x0 = max(0.0, float(left))
    src_y0 = max(0.0, float(top))
    src_x1 = min(float(left + window_w - 1.0), float(image_width))
    src_y1 = min(float(top + window_h - 1.0), float(image_height))
    span_w = src_x1 - src_x0
    span_h = src_y1 - src_y0

    # Where that part lands on the canvas (destination region).
    dst_x0 = 0 if left >= 0 else abs(left)
    dst_y0 = 0 if top >= 0 else abs(top)

    canvas = np.zeros((window_h, window_w, 3))
    canvas[int(dst_y0):int(dst_y0 + span_h),
           int(dst_x0):int(dst_x0 + span_w)] = \
        img[int(src_y0):int(src_y0 + span_h),
            int(src_x0):int(src_x0 + span_w)]
    return cv2.resize(
        canvas, (target_size, target_size), interpolation=cv2.INTER_AREA)
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# transform samples in 'source' using 'mapper'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
import uuid
import logging
import signal
import threading
from .transformer import ProxiedDataset
logger = logging.getLogger(__name__)
class EndSignal(object):
    """Marker object carrying an error number and an error message.

    ``errno`` defaults to 0 and ``errmsg`` to an empty string.
    """

    def __init__(self, errno=0, errmsg=''):
        self.errno, self.errmsg = errno, errmsg
class ParallelMappedDataset(ProxiedDataset):
    """
    Transform samples to mapped samples which is similar to
    'basic.MappedDataset', but multiple workers (threads or processes)
    will be used

    Notes:
        this class is not thread-safe
    """

    def __init__(self, source, mapper, worker_args):
        """
        Args:
            source: dataset providing ``next()``, ``reset()`` and
                ``drained()``.
            mapper: callable applied to every sample by the workers.
            worker_args (dict): overrides for 'bufsize', 'worker_num',
                'use_process' and 'memsize'; keys are matched
                case-insensitively.
        """
        super(ParallelMappedDataset, self).__init__(source)
        worker_args = {k.lower(): v for k, v in worker_args.items()}

        args = {
            'bufsize': 100,
            'worker_num': 8,
            'use_process': False,
            'memsize': '3G'
        }
        args.update(worker_args)
        if args['use_process'] and isinstance(args['memsize'], str):
            # A string memsize such as '3G' is converted to bytes.
            assert args['memsize'][-1].lower() == 'g', \
                "invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
            gb = args['memsize'][:-1]
            args['memsize'] = int(gb) * 1024**3

        self._worker_args = args
        self._started = False
        self._source = source
        self._mapper = mapper
        self._exit = False

        self._setup()

    def _setup(self):
        """setup input/output queues and workers """
        use_process = self._worker_args.get('use_process', False)
        if use_process and sys.platform == "win32":
            logger.info("Use multi-thread reader instead of "
                        "multi-process reader on Windows.")
            use_process = False

        bufsize = self._worker_args['bufsize']
        if use_process:
            from .shared_queue import SharedQueue as Queue
            from multiprocessing import Process as Worker
            from multiprocessing import Event
            memsize = self._worker_args['memsize']
            self._inq = Queue(bufsize, memsize=memsize)
            self._outq = Queue(bufsize, memsize=memsize)
        else:
            if six.PY3:
                from queue import Queue
            else:
                from Queue import Queue
            from threading import Thread as Worker
            from threading import Event
            self._inq = Queue(bufsize)
            self._outq = Queue(bufsize)

        consumer_num = self._worker_args['worker_num']
        # Short random tag so log lines of one dataset can be told apart.
        uid = str(uuid.uuid4())[-3:]
        # The producer is always a thread, even with process workers.
        self._producer = threading.Thread(
            target=self._produce,
            args=('producer-' + uid, self._source, self._inq))
        self._producer.daemon = True

        self._consumers = []
        for i in range(consumer_num):
            p = Worker(
                target=self._consume,
                args=('consumer-' + uid + '_' + str(i), self._inq, self._outq,
                      self._mapper))
            self._consumers.append(p)
            p.daemon = True

        self._epoch = -1
        self._feeding_ev = Event()
        self._produced = 0  # produced sample in self._produce
        self._consumed = 0  # consumed sample in self.next
        self._stopped_consumers = 0

    def _produce(self, id, source, inq):
        """Fetch data from source and feed it to 'inq' queue"""
        while True:
            self._feeding_ev.wait()
            if self._exit:
                break
            try:
                inq.put(source.next())
                self._produced += 1
            except StopIteration:
                # Epoch exhausted: park until reset() sets the event again.
                self._feeding_ev.clear()
                self._feeding_ev.wait()
                logger.debug("producer[{}] starts new epoch".format(id))
            except Exception as e:
                msg = "producer[{}] failed with error: {}".format(id, str(e))
                inq.put(EndSignal(-1, msg))
                break

        logger.debug("producer[{}] exits".format(id))

    def _consume(self, id, inq, outq, mapper):
        """Fetch data from 'inq', process it and put result to 'outq'"""
        while True:
            sample = inq.get()
            if isinstance(sample, EndSignal):
                sample.errmsg += "[consumer[{}] exits]".format(id)
                outq.put(sample)
                logger.debug("end signal received, " +
                             "consumer[{}] exits".format(id))
                break

            try:
                result = mapper(sample)
                outq.put(result)
            except Exception as e:
                # The original mixed a '%s' placeholder with str.format and
                # passed the arguments in the wrong order, so the consumer
                # id never appeared in the message.
                msg = 'failed to map consumer[{}], error: {}'.format(
                    id, str(e))
                outq.put(EndSignal(-1, msg))
                break

    def drained(self):
        """Return True once the source is drained and every produced
        sample has been consumed."""
        assert self._epoch >= 0, "first epoch has not started yet"
        return self._source.drained() and self._produced == self._consumed

    def stop(self):
        """ notify to exit
        """
        self._exit = True
        self._feeding_ev.set()
        for _ in range(len(self._consumers)):
            self._inq.put(EndSignal(0, "notify consumers to exit"))

    def next(self):
        """ get next transformed sample
        """
        if self._epoch < 0:
            self.reset()

        if self.drained():
            raise StopIteration()

        while True:
            sample = self._outq.get()
            if isinstance(sample, EndSignal):
                self._stopped_consumers += 1
                if sample.errno != 0:
                    logger.warning("consumer failed with error: {}".format(
                        sample.errmsg))

                if self._stopped_consumers < len(self._consumers):
                    # Hand the signal back so the remaining consumers also
                    # receive it and stop.
                    self._inq.put(sample)
                else:
                    raise ValueError("all consumers exited, no more samples")
            else:
                self._consumed += 1
                return sample

    def reset(self):
        """ reset for a new epoch of samples
        """
        if self._epoch < 0:
            # First call: start the workers.
            self._epoch = 0
            for p in self._consumers:
                p.start()
            self._producer.start()
        else:
            if not self.drained():
                # The original used a '%d' placeholder with str.format, so
                # the epoch number was never substituted.
                logger.warning("do not reset before epoch[{}] finishes".format(
                    self._epoch))
                # Samples still in flight count towards the next epoch.
                self._produced = self._produced - self._consumed
            else:
                self._produced = 0

            self._epoch += 1

        assert self._stopped_consumers == 0, "some consumers already exited," \
            + " cannot start another epoch"

        self._source.reset()

        self._consumed = 0
        # Wake the producer up.
        self._feeding_ev.set()
# FIXME(dengkaipeng): replace this if there is a better implementation.
# Handle termination of the reader process: exit without printing a
# stack frame.
def _reader_exit(signum, frame):
    """Signal handler: log at debug level, then exit via ``sys.exit()``."""
    logger.debug("Reader process exit.")
    sys.exit()


# Installed at import time for SIGTERM.
signal.signal(signal.SIGTERM, _reader_exit)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import cv2
import numpy as np
logger = logging.getLogger(__name__)
def build_post_map(coarsest_stride=1,
                   is_padding=False,
                   random_shapes=[],
                   multi_scales=[],
                   use_padded_im_info=False,
                   enable_multiscale_test=False,
                   num_scale=1):
    """
    Build a mapper for post-processing batches

    Args:
        coarsest_stride (int): stride of the coarsest FPN level
        is_padding (bool): whether to padding in minibatch
        random_shapes (list of int): resize to image to random shapes,
            [] for not resize. The default list is only read, never
            mutated.
        multi_scales (list of int): resize image by random scales,
            [] for not resize. The default list is only read, never
            mutated.
        use_padded_im_info (bool): whether to update im_info after padding
        enable_multiscale_test (bool): whether to use multiscale test.
        num_scale (int) : the number of scales for multiscale test.
    Returns:
        a mapper function which accept one argument 'batch' and
        return the processed result
    """

    def padding_minibatch(batch_data):
        """Zero-pad every image (data[0], read as C x H x W) to the
        largest height/width in the batch, rounded up to a multiple of
        ``coarsest_stride``."""
        if len(batch_data) == 1 and coarsest_stride == 1:
            return batch_data

        max_shape = np.array([data[0].shape for data in batch_data]).max(axis=0)
        if coarsest_stride > 1:
            max_shape[1] = int(
                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
            max_shape[2] = int(
                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)

        padding_batch = []
        for data in batch_data:
            im_c, im_h, im_w = data[0].shape[:]
            padding_im = np.zeros((im_c, max_shape[1], max_shape[2]),
                                  dtype=np.float32)
            padding_im[:, :im_h, :im_w] = data[0]
            if use_padded_im_info:
                # Overwrites the first two entries of data[1] in place.
                data[1][:2] = max_shape[1:3]
            padding_batch.append((padding_im, ) + data[1:])
        return padding_batch

    def padding_multiscale_test(batch_data):
        """Pad the first ``num_scale`` entries of the single sample to a
        multiple of ``coarsest_stride`` and record the padded sizes in
        the entry at index ``num_scale``."""
        if len(batch_data) != 1:
            raise NotImplementedError(
                "Batch size must be 1 when using multiscale test, but now batch size is {}"
                .format(len(batch_data)))
        if coarsest_stride > 1:
            padding_batch = []
            data = batch_data[0]
            for i, item in enumerate(data):
                if i < num_scale:
                    im_c, im_h, im_w = item.shape
                    max_h = int(
                        np.ceil(im_h / coarsest_stride) * coarsest_stride)
                    max_w = int(
                        np.ceil(im_w / coarsest_stride) * coarsest_stride)
                    padding_im = np.zeros((im_c, max_h, max_w),
                                          dtype=np.float32)
                    padding_im[:, :im_h, :im_w] = item
                    # Three slots per scale; the first two get the padded
                    # height and width.
                    data[num_scale][3 * i:3 * i + 2] = [max_h, max_w]
                    padding_batch.append(padding_im)
                else:
                    padding_batch.append(item)
            return [tuple(padding_batch)]
        # no need to padding
        return batch_data

    def random_shape(batch_data):
        """Resize every image in the batch to one randomly chosen square
        size."""
        # For YOLO: gt_bbox is normalized, is scale invariant.
        shape = np.random.choice(random_shapes)
        scaled_batch = []
        # The scale factors come from the first image and are applied to
        # the whole batch.
        h, w = batch_data[0][0].shape[1:3]
        scale_x = float(shape) / w
        scale_y = float(shape) / h
        for data in batch_data:
            im = cv2.resize(
                data[0].transpose((1, 2, 0)),
                None,
                None,
                fx=scale_x,
                fy=scale_y,
                interpolation=cv2.INTER_NEAREST)
            scaled_batch.append((im.transpose(2, 0, 1), ) + data[1:])
        return scaled_batch

    def multi_scale_resize(batch_data):
        """Resize every image by one randomly chosen scale factor and
        replace the second entry of each sample with a new im_info."""
        # For RCNN: image shape in record in im_info.
        scale = np.random.choice(multi_scales)
        scaled_batch = []
        for data in batch_data:
            im = cv2.resize(
                data[0].transpose((1, 2, 0)),
                None,
                None,
                fx=scale,
                fy=scale,
                interpolation=cv2.INTER_NEAREST)
            # NOTE(review): this yields [(h, w), scale], a nested value
            # rather than a flat [h, w, scale] - confirm what the
            # consumers of im_info expect.
            im_info = [im.shape[:2], scale]
            scaled_batch.append((im.transpose(2, 0, 1), im_info) + data[2:])
        return scaled_batch

    def _mapper(batch_data):
        """Apply the enabled post-processing steps in a fixed order."""
        try:
            if is_padding:
                batch_data = padding_minibatch(batch_data)
            if len(random_shapes) > 0:
                batch_data = random_shape(batch_data)
            if len(multi_scales) > 0:
                batch_data = multi_scale_resize(batch_data)
            if enable_multiscale_test:
                batch_data = padding_multiscale_test(batch_data)
        except Exception as e:
            errmsg = "post-process failed with error: " + str(e)
            logger.warning(errmsg)
            # Bare raise keeps the original traceback.
            raise

        return batch_data

    return _mapper
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
__all__ = ['SharedBuffer', 'SharedMemoryMgr', 'SharedQueue']
from .sharedmemory import SharedBuffer
from .sharedmemory import SharedMemoryMgr
from .sharedmemory import SharedMemoryError
from .queue import SharedQueue
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
if six.PY3:
import pickle
from io import BytesIO as StringIO
else:
import cPickle as pickle
from cStringIO import StringIO
import logging
import traceback
import multiprocessing as mp
from multiprocessing.queues import Queue
from .sharedmemory import SharedMemoryMgr
logger = logging.getLogger(__name__)
class SharedQueueError(ValueError):
    """Error type for SharedQueue; a plain ``ValueError`` subclass."""
class SharedQueue(Queue):
    """ a Queue based on shared memory to communicate data between Process,
        and it's interface is compatible with 'multiprocessing.queues.Queue'
    """

    def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None):
        """ init

        Args:
            maxsize (int): passed to the underlying multiprocessing queue.
            mem_mgr: memory manager to use; when None a new
                ``SharedMemoryMgr`` is created from ``memsize`` and
                ``pagesize``.
            memsize: capacity handed to the new ``SharedMemoryMgr``.
            pagesize: page size handed to the new ``SharedMemoryMgr``.
        """
        if six.PY3:
            # On Python 3 the base Queue requires an explicit context.
            super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context())
        else:
            super(SharedQueue, self).__init__(maxsize)

        if mem_mgr is not None:
            self._shared_mem = mem_mgr
        else:
            self._shared_mem = SharedMemoryMgr(
                capacity=memsize, pagesize=pagesize)

    def put(self, obj, **kwargs):
        """ put an object to this queue

        The object is pickled into a buffer allocated from the memory
        manager, and the buffer is what travels through the base queue.
        On failure the buffer is freed and the exception is re-raised.
        """
        obj = pickle.dumps(obj, -1)
        buff = None
        try:
            buff = self._shared_mem.malloc(len(obj))
            buff.put(obj)
            super(SharedQueue, self).put(buff, **kwargs)
        except Exception:
            stack_info = traceback.format_exc()
            err_msg = 'failed to put a element to SharedQueue '\
                'with stack info[%s]' % (stack_info)
            logger.warning(err_msg)

            if buff is not None:
                buff.free()
            # Bare raise keeps the original traceback.
            raise

    def get(self, **kwargs):
        """ get an object from this queue

        The buffer taken from the base queue is always freed, whether or
        not unpickling succeeds.
        """
        buff = None
        try:
            buff = super(SharedQueue, self).get(**kwargs)
            data = buff.get()
            return pickle.loads(data)
        except Exception:
            stack_info = traceback.format_exc()
            err_msg = 'failed to get element from SharedQueue '\
                'with stack info[%s]' % (stack_info)
            logger.warning(err_msg)
            raise
        finally:
            if buff is not None:
                buff.free()

    def release(self):
        """Release the memory manager and drop the reference to it."""
        self._shared_mem.release()
        self._shared_mem = None
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import functools
import collections
from ..dataset import Dataset
class ProxiedDataset(Dataset):
    """Dataset wrapper that falls back to the wrapped dataset ``self._ds``
    for every public ``Dataset`` method that raises NotImplementedError."""

    def __init__(self, ds):
        super(ProxiedDataset, self).__init__()
        self._ds = ds
        public_names = [
            name for name in Dataset.__dict__.keys()
            if not name.startswith('_')
        ]
        for name in public_names:
            # Shadow the attribute on the instance with a wrapper that
            # tries this object's own implementation first.
            wrapped = functools.partial(self._proxy_method,
                                        getattr(self, name))
            setattr(self, name, wrapped)

    def _proxy_method(self, func, *args, **kwargs):
        """
        Call 'func'; if it raises NotImplementedError, call the attribute
        of 'self._ds' that has the same name as func.__name__ instead.
        """
        method = func.__name__
        try:
            return func(*args, **kwargs)
        except NotImplementedError:
            fallback = getattr(self._ds, method)
            return fallback(*args, **kwargs)
class MappedDataset(ProxiedDataset):
    """Dataset wrapper that applies ``mapper`` to each sample of ``ds``."""

    def __init__(self, ds, mapper):
        super(MappedDataset, self).__init__(ds)
        self._ds = ds
        self._mapper = mapper

    def next(self):
        """Fetch the next sample from the wrapped dataset and map it."""
        return self._mapper(self._ds.next())
class BatchedDataset(ProxiedDataset):
    """
    Batching samples

    Args:
        ds (instance of Dataset): dataset to be batched
        batchsize (int): sample number for each batch
        drop_last (bool): drop last samples when not enough for one batch
        drop_empty (bool): drop samples which have empty field
    """

    def __init__(self, ds, batchsize, drop_last=False, drop_empty=True):
        super(BatchedDataset, self).__init__(ds)
        self._batchsz = batchsize
        self._drop_last = drop_last
        self._drop_empty = drop_empty

    def next(self):
        """Collect up to ``batchsize`` samples from ``self._ds``.

        Returns:
            a list of samples. A short final batch is returned only when
            ``drop_last`` is False and at least one sample was collected.
        Raises:
            StopIteration: when the source is exhausted and there is no
                batch to return.
        """
        # `collections.Sequence` was removed in Python 3.10; the ABC lives
        # in `collections.abc`. The fallback keeps Python 2 working.
        try:
            from collections.abc import Sequence
        except ImportError:
            from collections import Sequence

        def empty(x):
            """True for a zero-size ndarray or a zero-length sequence."""
            if isinstance(x, np.ndarray) and x.size == 0:
                return True
            elif isinstance(x, Sequence) and len(x) == 0:
                return True
            else:
                return False

        def has_empty(items):
            """True when any field of the sample is None or empty."""
            if any(x is None for x in items):
                return True
            if any(empty(x) for x in items):
                return True
            return False

        batch = []
        for _ in range(self._batchsz):
            try:
                out = self._ds.next()
                # Skip over samples with empty fields when requested.
                while self._drop_empty and has_empty(out):
                    out = self._ds.next()
                batch.append(out)
            except StopIteration:
                if not self._drop_last and len(batch) > 0:
                    return batch
                else:
                    raise StopIteration
        return batch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import re
import random
import shutil
__all__ = ['create_list']
def create_list(devkit_dir, years, output_dir):
    """
    Create the following list files inside 'output_dir':
        1. trainval.txt
        2. test.txt

    Each line of a list file is '<image path> <annotation path>'.
    The trainval entries are shuffled before being written; the test
    entries keep the order in which they were collected.

    Args:
        devkit_dir (str): directory passed through to '_walk_voc_dir'
        years (iterable): years passed one by one to '_walk_voc_dir'
        output_dir (str): directory in which the two list files are written
    """
    trainval_list = []
    test_list = []
    for year in years:
        trainval, test = _walk_voc_dir(devkit_dir, year, output_dir)
        trainval_list.extend(trainval)
        test_list.extend(test)

    random.shuffle(trainval_list)

    # Both files share the same line format, so write them with one loop.
    outputs = (('trainval.txt', trainval_list), ('test.txt', test_list))
    for list_name, items in outputs:
        with open(osp.join(output_dir, list_name), 'w') as flist:
            flist.writelines(
                img_path + ' ' + ann_path + '\n'
                for img_path, ann_path in items)
def _get_voc_dir(devkit_dir, year, type):
    """Return the path '<devkit_dir>/VOC<year>/<type>'."""
    year_folder = 'VOC' + year
    return osp.join(devkit_dir, year_folder, type)
def _walk_voc_dir(devkit_dir, year, output_dir):
    """
    Collect (image path, annotation path) pairs for one VOC year.

    Files under 'ImageSets/Main' whose names match '<letters>_trainval.txt'
    or '<letters>_test.txt' are read; the first whitespace-separated token
    of every line is used as the image name. Each image name is recorded
    only once, in whichever list first mentions it.

    Args:
        devkit_dir (str): directory containing the 'VOC<year>' folders
        year (str): year suffix of the 'VOC<year>' folder
        output_dir (str): directory the returned paths are made relative to

    Returns:
        tuple: (trainval_list, test_list), each a list of
            (image path, annotation path) tuples
    """
    filelist_dir = _get_voc_dir(devkit_dir, year, 'ImageSets/Main')
    annotation_dir = _get_voc_dir(devkit_dir, year, 'Annotations')
    img_dir = _get_voc_dir(devkit_dir, year, 'JPEGImages')

    # Loop-invariant values, computed once instead of once per line.
    ann_rel_dir = osp.relpath(annotation_dir, output_dir)
    img_rel_dir = osp.relpath(img_dir, output_dir)
    # Raw strings keep '\.' a regex escape rather than an invalid string
    # escape; the pattern text is unchanged.
    trainval_pattern = re.compile(r'[a-z]+_trainval\.txt')
    test_pattern = re.compile(r'[a-z]+_test\.txt')

    trainval_list = []
    test_list = []
    added = set()

    for root, _, files in os.walk(filelist_dir):
        for fname in files:
            if trainval_pattern.match(fname):
                img_ann_list = trainval_list
            elif test_pattern.match(fname):
                img_ann_list = test_list
            else:
                continue
            # Join with the directory currently being walked so files found
            # in sub-directories are opened from where they actually live.
            fpath = osp.join(root, fname)
            # 'with' guarantees the list file is closed after reading.
            with open(fpath) as flist:
                for line in flist:
                    fields = line.split()
                    if not fields:
                        # Blank line: nothing to record.
                        continue
                    name_prefix = fields[0]
                    if name_prefix in added:
                        continue
                    added.add(name_prefix)
                    ann_path = osp.join(ann_rel_dir, name_prefix + '.xml')
                    img_path = osp.join(img_rel_dir, name_prefix + '.jpg')
                    img_ann_list.append((img_path, ann_path))

    return trainval_list, test_list
......@@ -37,3 +37,4 @@ from .flowers import FlowersDataset as Flowers
from .stanford_dogs import StanfordDogsDataset as StanfordDogs
from .food101 import Food101Dataset as Food101
from .indoor67 import Indoor67Dataset as Indoor67
from .coco10 import Coco10
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册