未验证 提交 88d1f7dc 编写于 作者: J jayhenry 提交者: GitHub

Add Detection Finetune Task (#510)

* add object detection dataset, reader, task and demo
上级 13fc90ba
#coding:utf-8
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
import numpy as np
from paddlehub.reader.cv_reader import ObjectDetectionReader
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset
from paddlehub.contrib.ppdet.utils.coco_eval import bbox2out
from paddlehub.common.detection_config import get_model_type, get_feed_list, get_mid_feature
from paddlehub.common import detection_config as dconf
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for predict.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=2, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="ssd", help="Module used as a feature extractor.")
parser.add_argument("--dataset", type=str, default="coco10", help="Dataset to finetune.")
# yapf: enable.
module_map = {
"yolov3": "yolov3_darknet53_coco2017",
"ssd": "ssd_vgg16_512_coco2017",
"faster_rcnn": "faster_rcnn_resnet50_coco2017",
}
def predict(args):
module_name = args.module # 'yolov3_darknet53_coco2017'
model_type = get_model_type(module_name) # 'yolo'
# define data
ds = hub.dataset.Coco10(model_type)
print("ds.num_labels", ds.num_labels)
data_reader = ObjectDetectionReader(dataset=ds, model_type=model_type)
# define model(program)
module = hub.Module(name=module_name)
if model_type == 'rcnn':
input_dict, output_dict, program = module.context(trainable=True, phase='train')
input_dict_pred, output_dict_pred, program_pred = module.context(trainable=False)
else:
input_dict, output_dict, program = module.context(trainable=True)
input_dict_pred = output_dict_pred = None
feed_list, pred_feed_list = get_feed_list(module_name, input_dict, input_dict_pred)
feature, pred_feature = get_mid_feature(module_name, output_dict, output_dict_pred)
config = hub.RunConfig(
use_data_parallel=False,
use_pyreader=True,
use_cuda=args.use_gpu,
batch_size=args.batch_size,
enable_memory_optim=False,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
task = hub.DetectionTask(
data_reader=data_reader,
num_classes=ds.num_labels,
feed_list=feed_list,
feature=feature,
predict_feed_list=pred_feed_list,
predict_feature=pred_feature,
model_type=model_type,
config=config)
data = ["./test/test_img_bird.jpg", "./test/test_img_cat.jpg",]
label_map = ds.label_dict()
run_states = task.predict(data=data, accelerate_mode=False)
results = [run_state.run_results for run_state in run_states]
for outs in results:
keys = ['im_shape', 'im_id', 'bbox']
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs)
}
print("im_id", res['im_id'])
is_bbox_normalized = dconf.conf[model_type]['is_bbox_normalized']
clsid2catid = {}
for k in label_map:
clsid2catid[k] = k
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
print(bbox_results)
if __name__ == "__main__":
args = parser.parse_args()
if not args.module in module_map:
hub.logger.error("module should in %s" % module_map.keys())
exit(1)
args.module = module_map[args.module]
predict(args)
# -*- coding:utf8 -*-
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.reader.cv_reader import ObjectDetectionReader
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset
import numpy as np
from paddlehub.common.detection_config import get_model_type, get_feed_list, get_mid_feature
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=50, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=8, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="ssd", help="Module used as feature extractor.")
parser.add_argument("--dataset", type=str, default="coco_10", help="Dataset to finetune.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
# yapf: enable.
module_map = {
"yolov3": "yolov3_darknet53_coco2017",
"ssd": "ssd_vgg16_512_coco2017",
"faster_rcnn": "faster_rcnn_resnet50_coco2017",
}
def finetune(args):
module_name = args.module # 'yolov3_darknet53_coco2017'
model_type = get_model_type(module_name) # 'yolo'
# define dataset
ds = hub.dataset.Coco10(model_type)
# base_path = '/home/local3/zhaopenghao/data/detect/paddle-job-84942-0'
# train_dir = 'train_data/images'
# train_list = 'train_data/coco/instances_coco.json'
# val_dir = 'eval_data/images'
# val_list = 'eval_data/coco/instances_coco.json'
# ds = ObjectDetectionDataset(base_path, train_dir, train_list, val_dir, val_list, val_dir, val_list, model_type=model_type)
# print(ds.label_dict())
print("ds.num_labels", ds.num_labels)
# define batch reader
data_reader = ObjectDetectionReader(dataset=ds, model_type=model_type)
# define model(program)
module = hub.Module(name=module_name)
if model_type == 'rcnn':
input_dict, output_dict, program = module.context(trainable=True, phase='train')
input_dict_pred, output_dict_pred, program_pred = module.context(trainable=False)
else:
input_dict, output_dict, program = module.context(trainable=True)
input_dict_pred = output_dict_pred = None
print("input_dict keys", input_dict.keys())
print("output_dict keys", output_dict.keys())
feed_list, pred_feed_list = get_feed_list(module_name, input_dict, input_dict_pred)
print("output_dict length:", len(output_dict))
print(output_dict.keys())
if output_dict_pred is not None:
print(output_dict_pred.keys())
feature, pred_feature = get_mid_feature(module_name, output_dict, output_dict_pred)
config = hub.RunConfig(
log_interval=10,
eval_interval=100,
use_data_parallel=args.use_data_parallel,
use_pyreader=True,
use_cuda=args.use_gpu,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
enable_memory_optim=False,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy(learning_rate=0.00025, optimizer_name="adam"))
task = hub.DetectionTask(
data_reader=data_reader,
num_classes=ds.num_labels,
feed_list=feed_list,
feature=feature,
predict_feed_list=pred_feed_list,
predict_feature=pred_feature,
model_type=model_type,
config=config)
task.finetune_and_eval()
if __name__ == "__main__":
args = parser.parse_args()
if not args.module in module_map:
hub.logger.error("module should in %s" % module_map.keys())
exit(1)
args.module = module_map[args.module]
finetune(args)
......@@ -48,6 +48,7 @@ from .io.type import DataType
from .finetune.task import BaseTask
from .finetune.task import ClassifierTask
from .finetune.task import DetectionTask
from .finetune.task import TextClassifierTask
from .finetune.task import ImageClassifierTask
from .finetune.task import SequenceLabelTask
......
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
conf = {
"ssd": {
"with_background": True,
"is_bbox_normalized": True,
# "norm_type": "bn",
},
"yolo": {
"with_background": False,
"is_bbox_normalized": False,
# "norm_type": "sync_bn",
"mixup_epoch": 10,
"num_max_boxes": 50,
},
"rcnn": {
"with_background": True,
"is_bbox_normalized": False,
# "norm_type": "affine_channel",
}
}
ssd_train_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=False),
dict(op='NormalizeBox'),
dict(
op='RandomDistort',
brightness_lower=0.875,
brightness_upper=1.125,
is_order=True),
dict(op='ExpandImage', max_ratio=4, prob=0.5),
dict(
op='CropImage',
batch_sampler=[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]],
satisfy_all=False,
avoid_no_bbox=False),
dict(op='ResizeImage', target_size=512, use_cv2=False, interp=1),
dict(op='RandomFlipImage', is_normalized=True),
dict(op='Permute'),
dict(
op='NormalizeImage',
mean=[104, 117, 123],
std=[1, 1, 1],
is_scale=False),
dict(op='ArrangeSSD')
]
ssd_eval_fields = ['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult']
ssd_eval_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=False),
dict(op='NormalizeBox'),
dict(op='ResizeImage', target_size=512, use_cv2=False, interp=1),
dict(op='Permute'),
dict(
op='NormalizeImage',
mean=[104, 117, 123],
std=[1, 1, 1],
is_scale=False),
dict(op='ArrangeEvalSSD', fields=ssd_eval_fields)
]
ssd_predict_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=False),
dict(op='ResizeImage', target_size=512, use_cv2=False, interp=1),
dict(op='Permute'),
dict(
op='NormalizeImage',
mean=[104, 117, 123],
std=[1, 1, 1],
is_scale=False),
dict(op='ArrangeTestSSD')
]
rcnn_train_ops = [
dict(op='DecodeImage', to_rgb=True),
dict(op='RandomFlipImage', prob=0.5),
dict(
op='NormalizeImage',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
dict(op='ResizeImage', target_size=800, max_size=1333, interp=1),
dict(op='Permute', to_bgr=False),
dict(op='ArrangeRCNN'),
]
rcnn_eval_ops = [
dict(op='DecodeImage', to_rgb=True),
dict(
op='NormalizeImage',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
dict(op='ResizeImage', target_size=800, max_size=1333, interp=1),
dict(op='Permute', to_bgr=False),
dict(op='ArrangeEvalRCNN'),
]
rcnn_predict_ops = [
dict(op='DecodeImage', to_rgb=True),
dict(
op='NormalizeImage',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
dict(op='ResizeImage', target_size=800, max_size=1333, interp=1),
dict(op='Permute', to_bgr=False),
dict(op='ArrangeTestRCNN'),
]
yolo_train_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=True),
dict(op='MixupImage', alpha=1.5, beta=1.5),
dict(op='ColorDistort'),
dict(op='RandomExpand', fill_value=[123.675, 116.28, 103.53]),
dict(op='RandomCrop'),
dict(op='RandomFlipImage', is_normalized=False),
dict(op='Resize', target_dim=608, interp='random'),
dict(op='NormalizePermute',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
dict(op='NormalizeBox'),
dict(op='ArrangeYOLO'),
]
yolo_eval_ops = [
dict(op='DecodeImage', to_rgb=True),
dict(op='ResizeImage', target_size=608, interp=2),
dict(
op='NormalizeImage',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
dict(op='Permute', to_bgr=False),
dict(op='ArrangeEvalYOLO'),
]
yolo_predict_ops = [
dict(op='DecodeImage', to_rgb=True),
dict(op='ResizeImage', target_size=608, interp=2),
dict(
op='NormalizeImage',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
dict(op='Permute', to_bgr=False),
dict(op='ArrangeTestYOLO'),
]
feed_config = {
"ssd": {
"train": {
"fields": ['image', 'gt_box', 'gt_label'],
"OPS": ssd_train_ops,
"IS_PADDING": False,
},
"dev": {
# ['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult']
"fields": ssd_eval_fields,
"OPS": ssd_eval_ops,
"IS_PADDING": False,
},
"predict": {
"fields": ['image', 'im_id', 'im_shape'],
# "fields": ['image', 'im_id'],
"OPS": ssd_predict_ops,
"IS_PADDING": False,
},
},
"rcnn": {
"train": {
"fields": ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'],
"OPS": rcnn_train_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
},
"dev": {
"fields": ['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'],
"OPS": rcnn_eval_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
"USE_PADDED_IM_INFO": True,
},
"predict": {
"fields": ['image', 'im_info', 'im_id', 'im_shape'],
"OPS": rcnn_predict_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
"USE_PADDED_IM_INFO": True,
},
},
"yolo": {
"train": {
"fields": ['image', 'gt_box', 'gt_label', 'gt_score'],
"OPS": yolo_train_ops,
"RANDOM_SHAPES": [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
},
"dev": {
"fields": ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'],
"OPS": yolo_eval_ops,
},
"predict": {
"fields": ['image', 'im_size', 'im_id'],
"OPS": yolo_predict_ops,
},
},
}
def get_model_type(module_name):
if 'yolo' in module_name:
return 'yolo'
elif 'ssd' in module_name:
return 'ssd'
elif 'rcnn' in module_name:
return 'rcnn'
else:
raise ValueError("module {} not supported".format(module_name))
def get_feed_list(module_name, input_dict, input_dict_pred=None):
pred_feed_list = None
if 'yolo' in module_name:
img = input_dict["image"]
im_size = input_dict["im_size"]
feed_list = [img.name, im_size.name]
elif 'ssd' in module_name:
image = input_dict["image"]
# image_shape = input_dict["im_shape"]
image_shape = input_dict["im_size"]
feed_list = [image.name, image_shape.name]
elif 'rcnn' in module_name:
image = input_dict['image']
im_info = input_dict['im_info']
gt_bbox = input_dict['gt_bbox']
gt_class = input_dict['gt_class']
is_crowd = input_dict['is_crowd']
feed_list = [image.name, im_info.name, gt_bbox.name, gt_class.name, is_crowd.name]
assert input_dict_pred is not None
image = input_dict_pred['image']
im_info = input_dict_pred['im_info']
im_shape = input_dict['im_shape']
pred_feed_list = [image.name, im_info.name, im_shape.name]
else:
raise NotImplementedError
return feed_list, pred_feed_list
def get_mid_feature(module_name, output_dict, output_dict_pred=None):
feature_pred = None
if 'yolo' in module_name:
feature = output_dict['head_features']
elif 'ssd' in module_name:
feature = output_dict['body_features']
elif 'rcnn' in module_name:
head_feat = output_dict['head_feat']
rpn_cls_loss = output_dict['rpn_cls_loss']
rpn_reg_loss = output_dict['rpn_reg_loss']
generate_proposal_labels = output_dict['generate_proposal_labels']
feature = [head_feat, rpn_cls_loss, rpn_reg_loss, generate_proposal_labels]
assert output_dict_pred is not None
head_feat = output_dict_pred['head_feat']
rois = output_dict_pred['rois']
feature_pred = [head_feat, rois]
else:
raise NotImplementedError
return feature, feature_pred
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# interface for accessing data samples in stream
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Dataset(object):
"""interface to access a stream of data samples"""
def __init__(self):
self._epoch = -1
def __next__(self):
return self.next()
def __iter__(self):
return self
def __str__(self):
return "{}(fname:{}, epoch:{:d}, size:{:d}, pos:{:d})".format(
type(self).__name__, self._fname, self._epoch, self.size(),
self._pos)
def next(self):
"""get next sample"""
raise NotImplementedError(
'%s.next not available' % (self.__class__.__name__))
def reset(self):
"""reset to initial status and begins a new epoch"""
raise NotImplementedError(
'%s.reset not available' % (self.__class__.__name__))
def size(self):
"""get number of samples in this dataset"""
raise NotImplementedError(
'%s.size not available' % (self.__class__.__name__))
def drained(self):
"""whether all sampled has been readed out for this epoch"""
raise NotImplementedError(
'%s.drained not available' % (self.__class__.__name__))
def epoch_id(self):
"""return epoch id for latest sample"""
raise NotImplementedError(
'%s.epoch_id not available' % (self.__class__.__name__))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# Interface to build readers for detection data like COCO or VOC
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from numbers import Integral
import logging
from .source import build_source
from .transform import build_mapper, map, batch, batch_map
logger = logging.getLogger(__name__)
class Reader(object):
"""Interface to make readers for training or evaluation"""
def __init__(self, data_cf, trans_conf, maxiter=-1):
self._data_cf = data_cf
self._trans_conf = trans_conf
self._maxiter = maxiter
self._cname2cid = None
assert isinstance(self._maxiter, Integral), "maxiter should be int"
def _make_reader(self, mode, my_source=None):
"""Build reader for training or validation"""
if my_source is None:
file_conf = self._data_cf[mode]
# 1, Build data source
sc_conf = {'data_cf': file_conf, 'cname2cid': self._cname2cid}
sc = build_source(sc_conf)
else:
sc = my_source
# 2, Buid a transformed dataset
ops = self._trans_conf[mode]['OPS']
batchsize = self._trans_conf[mode]['BATCH_SIZE']
drop_last = False if 'DROP_LAST' not in \
self._trans_conf[mode] else self._trans_conf[mode]['DROP_LAST']
mapper = build_mapper(ops, {'is_train': mode == 'TRAIN'})
worker_args = None
if 'WORKER_CONF' in self._trans_conf[mode]:
worker_args = self._trans_conf[mode]['WORKER_CONF']
worker_args = {k.lower(): v for k, v in worker_args.items()}
mapped_ds = map(sc, mapper, worker_args)
# In VAL mode, gt_bbox, gt_label can be empty, and should
# not be dropped
batched_ds = batch(
mapped_ds, batchsize, drop_last, drop_empty=(mode != "VAL"))
trans_conf = {k.lower(): v for k, v in self._trans_conf[mode].items()}
need_keys = {
'is_padding',
'coarsest_stride',
'random_shapes',
'multi_scales',
'use_padded_im_info',
'enable_multiscale_test',
'num_scale',
}
bm_config = {
key: value
for key, value in trans_conf.items() if key in need_keys
}
batched_ds = batch_map(batched_ds, bm_config)
batched_ds.reset()
if mode.lower() == 'train':
if self._cname2cid is not None:
logger.warn('cname2cid already set, it will be overridden')
self._cname2cid = getattr(sc, 'cname2cid', None)
# 3, Build a reader
maxit = -1 if self._maxiter <= 0 else self._maxiter
def _reader():
n = 0
while True:
for _batch in batched_ds:
yield _batch
n += 1
if maxit > 0 and n == maxit:
return
batched_ds.reset()
if maxit <= 0:
return
_reader._fname = None
if hasattr(sc, '_fname'):
_reader.annotation = sc._fname
if hasattr(sc, 'get_imid2path'):
_reader.imid2path = sc.get_imid2path()
return _reader
def train(self):
"""Build reader for training"""
return self._make_reader('TRAIN')
def val(self):
"""Build reader for validation"""
return self._make_reader('VAL')
def test(self):
"""Build reader for inference"""
return self._make_reader('TEST')
@classmethod
def create(cls,
mode,
data_config,
transform_config,
max_iter=-1,
my_source=None,
ret_iter=True):
""" create a specific reader """
reader = Reader({mode: data_config}, {mode: transform_config}, max_iter)
if ret_iter:
return reader._make_reader(mode, my_source)
else:
return reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from .roidb_source import RoiDbSource
from .simple_source import SimpleSource
from .iterator_source import IteratorSource
from .class_aware_sampling_roidb_source import ClassAwareSamplingRoiDbSource
def build_source(config):
"""
Build dataset from source data, default source type is 'RoiDbSource'
Args:
config (dict): should have following structure:
{
data_cf (dict):
anno_file (str): label file or image list file path
image_dir (str): root directory for images
samples (int): number of samples to load, -1 means all
is_shuffle (bool): should samples be shuffled
load_img (bool): should images be loaded
mixup_epoch (int): parse mixup in first n epoch
with_background (bool): whether load background as a class
cname2cid (dict): the label name to id dictionary
}
"""
if 'data_cf' in config:
data_cf = config['data_cf']
data_cf['cname2cid'] = config['cname2cid']
else:
data_cf = config
data_cf = {k.lower(): v for k, v in data_cf.items()}
args = copy.deepcopy(data_cf)
# defaut type is 'RoiDbSource'
source_type = 'RoiDbSource'
if 'type' in data_cf:
if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource']:
if 'class_aware_sampling' in args and args['class_aware_sampling']:
source_type = 'ClassAwareSamplingRoiDbSource'
else:
source_type = 'RoiDbSource'
if 'class_aware_sampling' in args:
del args['class_aware_sampling']
else:
source_type = data_cf['type']
del args['type']
if source_type == 'RoiDbSource':
return RoiDbSource(**args)
elif source_type == 'SimpleSource':
return SimpleSource(**args)
elif source_type == 'ClassAwareSamplingRoiDbSource':
return ClassAwareSamplingRoiDbSource(**args)
else:
raise ValueError('source type not supported: ' + source_type)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#function:
# interface to load data from local files and parse it for samples,
# eg: roidb data in pickled files
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import random
import copy
import collections
import pickle as pkl
import numpy as np
from .roidb_source import RoiDbSource
class ClassAwareSamplingRoiDbSource(RoiDbSource):
""" interface to load class aware sampling roidb data from files
"""
def __init__(self,
anno_file,
image_dir=None,
samples=-1,
is_shuffle=True,
load_img=False,
cname2cid=None,
use_default_label=None,
mixup_epoch=-1,
with_background=True):
""" Init
Args:
fname (str): label file path
image_dir (str): root dir for images
samples (int): samples to load, -1 means all
is_shuffle (bool): whether to shuffle samples
load_img (bool): whether load data in this class
cname2cid (dict): the label name to id dictionary
use_default_label (bool):whether use the default mapping of label to id
mixup_epoch (int): parse mixup in first n epoch
with_background (bool): whether load background
as a class
"""
super(ClassAwareSamplingRoiDbSource, self).__init__(
anno_file=anno_file,
image_dir=image_dir,
samples=samples,
is_shuffle=is_shuffle,
load_img=load_img,
cname2cid=cname2cid,
use_default_label=use_default_label,
mixup_epoch=mixup_epoch,
with_background=with_background)
self._img_weights = None
def __str__(self):
return 'ClassAwareSamplingRoidbSource(fname:%s,epoch:%d,size:%d)' \
% (self._fname, self._epoch, self.size())
def next(self):
""" load next sample
"""
if self._epoch < 0:
self.reset()
_pos = np.random.choice(
self._samples, 1, replace=False, p=self._img_weights)[0]
sample = copy.deepcopy(self._roidb[_pos])
if self._load_img:
sample['image'] = self._load_image(sample['im_file'])
else:
sample['im_file'] = os.path.join(self._image_dir, sample['im_file'])
return sample
def _calc_img_weights(self):
""" calculate the probabilities of each sample
"""
imgs_cls = []
num_per_cls = {}
img_weights = []
for i, roidb in enumerate(self._roidb):
img_cls = set(
[k for cls in self._roidb[i]['gt_class'] for k in cls])
imgs_cls.append(img_cls)
for c in img_cls:
if c not in num_per_cls:
num_per_cls[c] = 1
else:
num_per_cls[c] += 1
for i in range(len(self._roidb)):
weights = 0
for c in imgs_cls[i]:
weights += 1 / num_per_cls[c]
img_weights.append(weights)
# Probabilities sum to 1
img_weights = img_weights / np.sum(img_weights)
return img_weights
def reset(self):
""" implementation of Dataset.reset
"""
if self._roidb is None:
self._roidb = self._load()
if self._img_weights is None:
self._img_weights = self._calc_img_weights()
self._samples = len(self._roidb)
if self._epoch < 0:
self._epoch = 0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from pycocotools.coco import COCO
import logging
logger = logging.getLogger(__name__)
def load(anno_path, sample_num=-1, with_background=True):
"""
Load COCO records with annotations in json file 'anno_path'
Args:
anno_path (str): json file path
sample_num (int): number of samples to load, -1 means all
with_background (bool): whether load background as a class.
if True, total class number will
be 81. default True
Returns:
(records, cname2cid)
'records' is list of dict whose structure is:
{
'im_file': im_fname, # image file name
'im_id': img_id, # image id
'h': im_h, # height of image
'w': im_w, # width
'is_crowd': is_crowd,
'gt_score': gt_score,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
'cname2cid' is a dict used to map category name to class id
"""
assert anno_path.endswith('.json'), 'invalid coco annotation file: ' \
+ anno_path
coco = COCO(anno_path)
img_ids = coco.getImgIds()
cat_ids = coco.getCatIds()
records = []
ct = 0
# when with_background = True, mapping category to classid, like:
# background:0, first_class:1, second_class:2, ...
catid2clsid = dict(
{catid: i + int(with_background)
for i, catid in enumerate(cat_ids)})
cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in catid2clsid.items()
})
for img_id in img_ids:
img_anno = coco.loadImgs(img_id)[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
for inst in instances:
x, y, box_w, box_h = inst['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(im_w - 1, x1 + max(0, box_w - 1))
y2 = min(im_h - 1, y1 + max(0, box_h - 1))
if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
inst['clean_bbox'] = [x1, y1, x2, y2]
bboxes.append(inst)
else:
logger.warn(
'Found an invalid bbox in annotations: im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
.format(img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
gt_score = np.ones((num_bbox, 1), dtype=np.float32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
if 'segmentation' in box:
gt_poly[i] = box['segmentation']
coco_rec = {
'im_file': im_fname,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_score': gt_score,
'gt_poly': gt_poly,
'difficult': difficult
}
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
im_fname, img_id, im_h, im_w))
records.append(coco_rec)
ct += 1
if sample_num > 0 and ct >= sample_num:
break
assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
logger.info('{} samples in file {}'.format(ct, anno_path))
return records, cname2cid
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import copy
import logging
logger = logging.getLogger(__name__)
from ..dataset import Dataset
class IteratorSource(Dataset):
"""
Load data samples from iterator in stream mode
Args:
iter_maker (callable): callable function to generate a iter
samples (int): number of samples to load, -1 means all
"""
def __init__(self, iter_maker, samples=-1, **kwargs):
super(IteratorSource, self).__init__()
self._epoch = -1
self._iter_maker = iter_maker
self._data_iter = None
self._pos = -1
self._drained = False
self._samples = samples
self._sample_num = -1
def next(self):
if self._epoch < 0:
self.reset()
if self._data_iter is not None:
try:
sample = next(self._data_iter)
self._pos += 1
ret = sample
except StopIteration as e:
if self._sample_num <= 0:
self._sample_num = self._pos
elif self._sample_num != self._pos:
logger.info('num of loaded samples is different '
'with previouse setting[prev:%d,now:%d]' %
(self._sample_num, self._pos))
self._sample_num = self._pos
self._data_iter = None
self._drained = True
raise e
else:
raise StopIteration("no more data in " + str(self))
if self._samples > 0 and self._pos >= self._samples:
self._data_iter = None
self._drained = True
raise StopIteration("no more data in " + str(self))
else:
return ret
def reset(self):
if self._data_iter is None:
self._data_iter = self._iter_maker()
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self._pos = 0
self._drained = False
def size(self):
return self._sample_num
def drained(self):
assert self._epoch >= 0, "the first epoch has not started yet"
return self._pos >= self.size()
def epoch_id(self):
return self._epoch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# load data records from local files(maybe in COCO or VOC data formats)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import numpy as np
import logging
import pickle as pkl
logger = logging.getLogger(__name__)
def check_records(records):
""" check the fields of 'records' must contains some keys
"""
needed_fields = [
'im_file', 'im_id', 'h', 'w', 'is_crowd', 'gt_class', 'gt_bbox',
'gt_poly'
]
for i, rec in enumerate(records):
for k in needed_fields:
assert k in rec, 'not found field[%s] in record[%d]' % (k, i)
def load_roidb(anno_file, sample_num=-1):
""" load normalized data records from file
'anno_file' which is a pickled file.
And the records should has a structure:
{
'im_file': str, # image file name
'im_id': int, # image id
'h': int, # height of image
'w': int, # width of image
'is_crowd': bool,
'gt_class': list of np.ndarray, # classids info
'gt_bbox': list of np.ndarray, # bounding box info
'gt_poly': list of int, # poly info
}
Args:
anno_file (str): file name for picked records
sample_num (int): number of samples to load
Returns:
list of records for detection model training
"""
assert anno_file.endswith('.roidb'), 'invalid roidb file[%s]' % (anno_file)
with open(anno_file, 'rb') as f:
roidb = f.read()
# for support python3 and python2
try:
records, cname2cid = pkl.loads(roidb, encoding='bytes')
except:
records, cname2cid = pkl.loads(roidb)
assert type(records) is list, 'invalid data type from roidb'
if sample_num > 0 and sample_num < len(records):
records = records[:sample_num]
return records, cname2cid
def load(fname,
samples=-1,
with_background=True,
with_cat2id=False,
use_default_label=None,
cname2cid=None):
""" Load data records from 'fnames'
Args:
fnames (str): file name for data record, eg:
instances_val2017.json or COCO17_val2017.roidb
samples (int): number of samples to load, default to all
with_background (bool): whether load background as a class.
default True.
with_cat2id (bool): whether return cname2cid info out
use_default_label (bool): whether use the default mapping of label to id
cname2cid (dict): the mapping of category name to id
Returns:
list of loaded records whose structure is:
{
'im_file': str, # image file name
'im_id': int, # image id
'h': int, # height of image
'w': int, # width of image
'is_crowd': bool,
'gt_class': list of np.ndarray, # classids info
'gt_bbox': list of np.ndarray, # bounding box info
'gt_poly': list of int, # poly info
}
"""
if fname.endswith('.roidb'):
records, cname2cid = load_roidb(fname, samples)
elif fname.endswith('.json'):
from . import coco_loader
records, cname2cid = coco_loader.load(fname, samples, with_background)
elif "wider_face" in fname:
from . import widerface_loader
records = widerface_loader.load(fname, samples)
return records
elif os.path.isfile(fname):
from . import voc_loader
if use_default_label is None or cname2cid is not None:
records, cname2cid = voc_loader.get_roidb(
fname, samples, cname2cid, with_background=with_background)
else:
records, cname2cid = voc_loader.load(
fname,
samples,
use_default_label,
with_background=with_background)
else:
raise ValueError(
'invalid file type when load data from file[%s]' % (fname))
check_records(records)
if with_cat2id:
return records, cname2cid
else:
return records
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#function:
# interface to load data from local files and parse it for samples,
# eg: roidb data in pickled files
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import random
import copy
import pickle as pkl
from ..dataset import Dataset
class RoiDbSource(Dataset):
""" interface to load roidb data from files
"""
def __init__(self,
anno_file,
image_dir=None,
samples=-1,
is_shuffle=True,
load_img=False,
cname2cid=None,
use_default_label=None,
mixup_epoch=-1,
with_background=True):
""" Init
Args:
fname (str): label file path
image_dir (str): root dir for images
samples (int): samples to load, -1 means all
is_shuffle (bool): whether to shuffle samples
load_img (bool): whether load data in this class
cname2cid (dict): the label name to id dictionary
use_default_label (bool):whether use the default mapping of label to id
mixup_epoch (int): parse mixup in first n epoch
with_background (bool): whether load background
as a class
"""
super(RoiDbSource, self).__init__()
self._epoch = -1
assert os.path.isfile(anno_file) or os.path.isdir(anno_file), \
'anno_file {} is not a file or a directory'.format(anno_file)
self._fname = anno_file
self._image_dir = image_dir if image_dir is not None else ''
if image_dir is not None:
assert os.path.isdir(image_dir), \
'image_dir {} is not a directory'.format(image_dir)
self._roidb = None
self._pos = -1
self._drained = False
self._samples = samples
self._is_shuffle = is_shuffle
self._load_img = load_img
self.use_default_label = use_default_label
self._mixup_epoch = mixup_epoch
self._with_background = with_background
self.cname2cid = cname2cid
self._imid2path = None
def __str__(self):
return 'RoiDbSource(fname:%s,epoch:%d,size:%d,pos:%d)' \
% (self._fname, self._epoch, self.size(), self._pos)
def next(self):
""" load next sample
"""
if self._epoch < 0:
self.reset()
if self._pos >= self._samples:
self._drained = True
raise StopIteration('%s no more data' % (str(self)))
sample = copy.deepcopy(self._roidb[self._pos])
if self._load_img:
sample['image'] = self._load_image(sample['im_file'])
else:
sample['im_file'] = os.path.join(self._image_dir, sample['im_file'])
if self._epoch < self._mixup_epoch:
mix_idx = random.randint(1, self._samples - 1)
mix_pos = (mix_idx + self._pos) % self._samples
sample['mixup'] = copy.deepcopy(self._roidb[mix_pos])
if self._load_img:
sample['mixup']['image'] = \
self._load_image(sample['mixup']['im_file'])
else:
sample['mixup']['im_file'] = \
os.path.join(self._image_dir, sample['mixup']['im_file'])
self._pos += 1
return sample
def _load(self):
""" load data from file
"""
from . import loader
records, cname2cid = loader.load(self._fname, self._samples,
self._with_background, True,
self.use_default_label, self.cname2cid)
self.cname2cid = cname2cid
return records
def _load_image(self, where):
fn = os.path.join(self._image_dir, where)
with open(fn, 'rb') as f:
return f.read()
def reset(self):
""" implementation of Dataset.reset
"""
if self._roidb is None:
self._roidb = self._load()
self._samples = len(self._roidb)
if self._is_shuffle:
random.shuffle(self._roidb)
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self._pos = 0
self._drained = False
def size(self):
""" implementation of Dataset.size
"""
return len(self._roidb)
def drained(self):
""" implementation of Dataset.drained
"""
assert self._epoch >= 0, 'The first epoch has not begin!'
return self._pos >= self.size()
def epoch_id(self):
""" return epoch id for latest sample
"""
return self._epoch
def get_imid2path(self):
"""return image id to image path map"""
if self._imid2path is None:
self._imid2path = {}
for record in self._roidb:
im_id = record['im_id']
im_id = im_id if isinstance(im_id, int) else im_id[0]
im_path = os.path.join(self._image_dir, record['im_file'])
self._imid2path[im_id] = im_path
return self._imid2path
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# interface to load data from txt file.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import copy
from ..dataset import Dataset
class SimpleSource(Dataset):
"""
Load image files for testing purpose
Args:
images (list): list of path of images
samples (int): number of samples to load, -1 means all
load_img (bool): should images be loaded
"""
def __init__(self, images=[], samples=-1, load_img=True, **kwargs):
super(SimpleSource, self).__init__()
self._epoch = -1
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
self._images = images
self._fname = None
self._simple = None
self._pos = -1
self._drained = False
self._samples = samples
self._load_img = load_img
self._imid2path = {}
def next(self):
if self._epoch < 0:
self.reset()
if self._pos >= self.size():
self._drained = True
raise StopIteration("no more data in " + str(self))
else:
sample = copy.deepcopy(self._simple[self._pos])
if self._load_img:
sample['image'] = self._load_image(sample['im_file'])
self._pos += 1
return sample
def _load(self):
ct = 0
records = []
for image in self._images:
if self._samples > 0 and ct >= self._samples:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "no image file found"
return records
def _load_image(self, where):
with open(where, 'rb') as f:
return f.read()
def reset(self):
if self._simple is None:
self._simple = self._load()
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self._pos = 0
self._drained = False
def size(self):
return len(self._simple)
def drained(self):
assert self._epoch >= 0, "the first epoch has not started yet"
return self._pos >= self.size()
def epoch_id(self):
return self._epoch
def get_imid2path(self):
"""return image id to image path map"""
return self._imid2path
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import xml.etree.ElementTree as ET
def get_roidb(anno_path, sample_num=-1, cname2cid=None, with_background=True):
"""
Load VOC records with annotations in xml directory 'anno_path'
Notes:
${anno_path} must contains xml file and image file path for annotations
Args:
anno_path (str): root directory for voc annotation data
sample_num (int): number of samples to load, -1 means all
cname2cid (dict): the label name to id dictionary
with_background (bool): whether load background as a class.
if True, total class number will
be 81. default True
Returns:
(records, catname2clsid)
'records' is list of dict whose structure is:
{
'im_file': im_fname, # image file name
'im_id': im_id, # image id
'h': im_h, # height of image
'w': im_w, # width
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
'cname2id' is a dict to map category name to class id
"""
data_dir = os.path.dirname(anno_path)
records = []
ct = 0
existence = False if cname2cid is None else True
if cname2cid is None:
cname2cid = {}
# mapping category name to class id
# background:0, first_class:1, second_class:2, ...
with open(anno_path, 'r') as fr:
while True:
line = fr.readline()
if not line:
break
img_file, xml_file = [os.path.join(data_dir, x) \
for x in line.strip().split()[:2]]
if not os.path.isfile(xml_file):
continue
tree = ET.parse(xml_file)
if tree.find('id') is None:
im_id = np.array([ct])
else:
im_id = np.array([int(tree.find('id').text)])
objs = tree.findall('object')
im_w = float(tree.find('size').find('width').text)
im_h = float(tree.find('size').find('height').text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32)
gt_score = np.ones((len(objs), 1), dtype=np.float32)
is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs):
cname = obj.find('name').text
if not existence and cname not in cname2cid:
# the background's id is 0, so need to add 1.
cname2cid[cname] = len(cname2cid) + int(with_background)
elif existence and cname not in cname2cid:
raise KeyError(
'Not found cname[%s] in cname2cid when map it to cid.' %
(cname))
gt_class[i][0] = cname2cid[cname]
_difficult = int(obj.find('difficult').text)
x1 = float(obj.find('bndbox').find('xmin').text)
y1 = float(obj.find('bndbox').find('ymin').text)
x2 = float(obj.find('bndbox').find('xmax').text)
y2 = float(obj.find('bndbox').find('ymax').text)
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(im_w - 1, x2)
y2 = min(im_h - 1, y2)
gt_bbox[i] = [x1, y1, x2, y2]
is_crowd[i][0] = 0
difficult[i][0] = _difficult
voc_rec = {
'im_file': img_file,
'im_id': im_id,
'h': im_h,
'w': im_w,
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_poly': [],
'difficult': difficult
}
if len(objs) != 0:
records.append(voc_rec)
ct += 1
if sample_num > 0 and ct >= sample_num:
break
assert len(records) > 0, 'not found any voc record in %s' % (anno_path)
return [records, cname2cid]
def load(anno_path, sample_num=-1, use_default_label=True,
with_background=True):
"""
Load VOC records with annotations in
xml directory 'anno_path'
Notes:
${anno_path} must contains xml file and image file path for annotations
Args:
@anno_path (str): root directory for voc annotation data
@sample_num (int): number of samples to load, -1 means all
@use_default_label (bool): whether use the default mapping of label to id
@with_background (bool): whether load background as a class.
if True, total class number will
be 81. default True
Returns:
(records, catname2clsid)
'records' is list of dict whose structure is:
{
'im_file': im_fname, # image file name
'im_id': im_id, # image id
'h': im_h, # height of image
'w': im_w, # width
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
'cname2id' is a dict to map category name to class id
"""
data_dir = os.path.dirname(anno_path)
# mapping category name to class id
# if with_background is True:
# background:0, first_class:1, second_class:2, ...
# if with_background is False:
# first_class:0, second_class:1, ...
records = []
ct = 0
cname2cid = {}
if not use_default_label:
label_path = os.path.join(data_dir, 'label_list.txt')
with open(label_path, 'r') as fr:
label_id = int(with_background)
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else:
cname2cid = pascalvoc_label(with_background)
with open(anno_path, 'r') as fr:
while True:
line = fr.readline()
if not line:
break
img_file, xml_file = [os.path.join(data_dir, x) \
for x in line.strip().split()[:2]]
if not os.path.isfile(xml_file):
continue
tree = ET.parse(xml_file)
if tree.find('id') is None:
im_id = np.array([ct])
else:
im_id = np.array([int(tree.find('id').text)])
objs = tree.findall('object')
im_w = float(tree.find('size').find('width').text)
im_h = float(tree.find('size').find('height').text)
gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
gt_class = np.zeros((len(objs), 1), dtype=np.int32)
gt_score = np.ones((len(objs), 1), dtype=np.float32)
is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
difficult = np.zeros((len(objs), 1), dtype=np.int32)
for i, obj in enumerate(objs):
cname = obj.find('name').text
gt_class[i][0] = cname2cid[cname]
_difficult = int(obj.find('difficult').text)
x1 = float(obj.find('bndbox').find('xmin').text)
y1 = float(obj.find('bndbox').find('ymin').text)
x2 = float(obj.find('bndbox').find('xmax').text)
y2 = float(obj.find('bndbox').find('ymax').text)
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(im_w - 1, x2)
y2 = min(im_h - 1, y2)
gt_bbox[i] = [x1, y1, x2, y2]
is_crowd[i][0] = 0
difficult[i][0] = _difficult
voc_rec = {
'im_file': img_file,
'im_id': im_id,
'h': im_h,
'w': im_w,
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_poly': [],
'difficult': difficult
}
if len(objs) != 0:
records.append(voc_rec)
ct += 1
if sample_num > 0 and ct >= sample_num:
break
assert len(records) > 0, 'not found any voc record in %s' % (anno_path)
return [records, cname2cid]
def pascalvoc_label(with_background=True):
labels_map = {
'aeroplane': 1,
'bicycle': 2,
'bird': 3,
'boat': 4,
'bottle': 5,
'bus': 6,
'car': 7,
'cat': 8,
'chair': 9,
'cow': 10,
'diningtable': 11,
'dog': 12,
'horse': 13,
'motorbike': 14,
'person': 15,
'pottedplant': 16,
'sheep': 17,
'sofa': 18,
'train': 19,
'tvmonitor': 20
}
if not with_background:
labels_map = {k: v - 1 for k, v in labels_map.items()}
return labels_map
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import logging
logger = logging.getLogger(__name__)
def load(anno_path, sample_num=-1, cname2cid=None, with_background=True):
"""
Load WiderFace records with 'anno_path'
Args:
anno_path (str): root directory for voc annotation data
sample_num (int): number of samples to load, -1 means all
with_background (bool): whether load background as a class.
if True, total class number will
be 2. default True
Returns:
(records, catname2clsid)
'records' is list of dict whose structure is:
{
'im_file': im_fname, # image file name
'im_id': im_id, # image id
'gt_class': gt_class,
'gt_bbox': gt_bbox,
}
'cname2id' is a dict to map category name to class id
"""
txt_file = anno_path
records = []
ct = 0
file_lists = _load_file_list(txt_file)
cname2cid = widerface_label(with_background)
for item in file_lists:
im_fname = item[0]
im_id = np.array([ct])
gt_bbox = np.zeros((len(item) - 2, 4), dtype=np.float32)
gt_class = np.ones((len(item) - 2, 1), dtype=np.int32)
for index_box in range(len(item)):
if index_box >= 2:
temp_info_box = item[index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
# Filter out wrong labels
if w < 0 or h < 0:
continue
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = xmin + w
ymax = ymin + h
gt_bbox[index_box - 2] = [xmin, ymin, xmax, ymax]
widerface_rec = {
'im_file': im_fname,
'im_id': im_id,
'gt_bbox': gt_bbox,
'gt_class': gt_class,
}
# logger.debug
if len(item) != 0:
records.append(widerface_rec)
ct += 1
if sample_num > 0 and ct >= sample_num:
break
assert len(records) > 0, 'not found any widerface in %s' % (anno_path)
logger.info('{} samples in file {}'.format(ct, anno_path))
return records, cname2cid
def _load_file_list(input_txt):
with open(input_txt, 'r') as f_dir:
lines_input_txt = f_dir.readlines()
file_dict = {}
num_class = 0
for i in range(len(lines_input_txt)):
line_txt = lines_input_txt[i].strip('\n\t\r')
if '.jpg' in line_txt:
if i != 0:
num_class += 1
file_dict[num_class] = []
file_dict[num_class].append(line_txt)
if '.jpg' not in line_txt:
if len(line_txt) > 6:
split_str = line_txt.split(' ')
x1_min = float(split_str[0])
y1_min = float(split_str[1])
x2_max = float(split_str[2])
y2_max = float(split_str[3])
line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
x2_max) + ' ' + str(y2_max)
file_dict[num_class].append(line_txt)
else:
file_dict[num_class].append(line_txt)
return list(file_dict.values())
def widerface_label(with_background=True):
labels_map = {'face': 1}
if not with_background:
labels_map = {k: v - 1 for k, v in labels_map.items()}
return labels_map
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
import copy
import logging
import traceback
from .transformer import MappedDataset, BatchedDataset
from .post_map import build_post_map
from .parallel_map import ParallelMappedDataset
from .operators import BaseOperator, registered_ops
__all__ = ['build_mapper', 'map', 'batch', 'batch_map']
logger = logging.getLogger(__name__)
def build_mapper(ops, context=None):
"""
Build a mapper for operators in 'ops'
Args:
ops (list of operator.BaseOperator or list of op dict):
configs for oprators, eg:
[{'name': 'DecodeImage', 'params': {'to_rgb': True}}, {xxx}]
context (dict): a context object for mapper
Returns:
a mapper function which accept one argument 'sample' and
return the processed result
"""
new_ops = []
for _dict in ops:
new_dict = {}
for i, j in _dict.items():
new_dict[i.lower()] = j
new_ops.append(new_dict)
ops = new_ops
op_funcs = []
op_repr = []
for op in ops:
if type(op) is dict and 'op' in op:
op_func = getattr(BaseOperator, op['op'])
params = copy.deepcopy(op)
del params['op']
o = op_func(**params)
elif not isinstance(op, BaseOperator):
op_func = getattr(BaseOperator, op['name'])
params = {} if 'params' not in op else op['params']
o = op_func(**params)
else:
assert isinstance(op, BaseOperator), \
"invalid operator when build ops"
o = op
op_funcs.append(o)
op_repr.append('{{{}}}'.format(str(o)))
op_repr = '[{}]'.format(','.join(op_repr))
def _mapper(sample):
ctx = {} if context is None else copy.deepcopy(context)
for f in op_funcs:
try:
out = f(sample, ctx)
sample = out
except Exception as e:
stack_info = traceback.format_exc()
logger.warn(
"fail to map op [{}] with error: {} and stack:\n{}".format(
f, e, str(stack_info)))
raise e
return out
_mapper.ops = op_repr
return _mapper
def map(ds, mapper, worker_args=None):
"""
Apply 'mapper' to 'ds'
Args:
ds (instance of Dataset): dataset to be mapped
mapper (function): action to be executed for every data sample
worker_args (dict): configs for concurrent mapper
Returns:
a mapped dataset
"""
if worker_args is not None:
return ParallelMappedDataset(ds, mapper, worker_args)
else:
return MappedDataset(ds, mapper)
def batch(ds, batchsize, drop_last=False, drop_empty=True):
"""
Batch data samples to batches
Args:
batchsize (int): number of samples for a batch
drop_last (bool): drop last few samples if not enough for a batch
Returns:
a batched dataset
"""
return BatchedDataset(
ds, batchsize, drop_last=drop_last, drop_empty=drop_empty)
def batch_map(ds, config):
"""
Post process the batches.
Args:
ds (instance of Dataset): dataset to be mapped
mapper (function): action to be executed for every batch
Returns:
a batched dataset which is processed
"""
mapper = build_post_map(**config)
return MappedDataset(ds, mapper)
for nm in registered_ops:
op = getattr(BaseOperator, nm)
locals()[nm] = op
__all__ += registered_ops
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# operators to process sample,
# eg: decode/resize/crop image
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
from .operators import BaseOperator, register_op
logger = logging.getLogger(__name__)
@register_op
class ArrangeRCNN(BaseOperator):
"""
Transform dict to tuple format needed for training.
Args:
is_mask (bool): whether to use include mask data
"""
def __init__(self, is_mask=False):
super(ArrangeRCNN, self).__init__()
self.is_mask = is_mask
assert isinstance(self.is_mask, bool), "wrong type for is_mask"
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing following items
(image, im_info, im_id, gt_bbox, gt_class, is_crowd, gt_masks)
"""
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
keys = list(sample.keys())
if 'is_crowd' in keys:
is_crowd = sample['is_crowd']
else:
raise KeyError("The dataset doesn't have 'is_crowd' key.")
if 'im_info' in keys:
im_info = sample['im_info']
else:
raise KeyError("The dataset doesn't have 'im_info' key.")
im_id = sample['im_id']
outs = (im, im_info, im_id, gt_bbox, gt_class, is_crowd)
gt_masks = []
if self.is_mask and len(sample['gt_poly']) != 0 \
and 'is_crowd' in keys:
valid = True
segms = sample['gt_poly']
assert len(segms) == is_crowd.shape[0]
for i in range(len(sample['gt_poly'])):
segm, iscrowd = segms[i], is_crowd[i]
gt_segm = []
if iscrowd:
gt_segm.append([[0, 0]])
else:
for poly in segm:
if len(poly) == 0:
valid = False
break
gt_segm.append(np.array(poly).reshape(-1, 2))
if (not valid) or len(gt_segm) == 0:
break
gt_masks.append(gt_segm)
outs = outs + (gt_masks, )
return outs
@register_op
class ArrangeEvalRCNN(BaseOperator):
"""
Transform dict to the tuple format needed for evaluation.
"""
def __init__(self):
super(ArrangeEvalRCNN, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, im_info, im_id, im_shape, gt_bbox,
gt_class, difficult)
"""
ims = []
keys = sorted(list(sample.keys()))
for k in keys:
if 'image' in k:
ims.append(sample[k])
if 'im_info' in keys:
im_info = sample['im_info']
else:
raise KeyError("The dataset doesn't have 'im_info' key.")
im_id = sample['im_id']
h = sample['h']
w = sample['w']
# For rcnn models in eval and infer stage, original image size
# is needed to clip the bounding boxes. And box clip op in
# bbox prediction needs im_info as input in format of [N, 3],
# so im_shape is appended by 1 to match dimension.
im_shape = np.array((h, w, 1), dtype=np.float32)
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
difficult = sample['difficult']
remain_list = [im_info, im_id, im_shape, gt_bbox, gt_class, difficult]
ims.extend(remain_list)
outs = tuple(ims)
return outs
@register_op
class ArrangeTestRCNN(BaseOperator):
"""
Transform dict to the tuple format needed for training.
"""
def __init__(self):
super(ArrangeTestRCNN, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, im_info, im_id, im_shape)
"""
ims = []
keys = sorted(list(sample.keys()))
for k in keys:
if 'image' in k:
ims.append(sample[k])
if 'im_info' in keys:
im_info = sample['im_info']
else:
raise KeyError("The dataset doesn't have 'im_info' key.")
im_id = sample['im_id']
h = sample['h']
w = sample['w']
# For rcnn models in eval and infer stage, original image size
# is needed to clip the bounding boxes. And box clip op in
# bbox prediction needs im_info as input in format of [N, 3],
# so im_shape is appended by 1 to match dimension.
im_shape = np.array((h, w, 1), dtype=np.float32)
remain_list = [im_info, im_id, im_shape]
ims.extend(remain_list)
outs = tuple(ims)
return outs
@register_op
class ArrangeSSD(BaseOperator):
"""
Transform dict to tuple format needed for training.
"""
def __init__(self):
super(ArrangeSSD, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, gt_bbox, gt_class, difficult)
"""
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
outs = (im, gt_bbox, gt_class)
return outs
@register_op
class ArrangeEvalSSD(BaseOperator):
"""
Transform dict to tuple format needed for training.
"""
def __init__(self, fields):
super(ArrangeEvalSSD, self).__init__()
self.fields = fields
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items: (image)
"""
outs = []
if len(sample['gt_bbox']) != len(sample['gt_class']):
raise ValueError("gt num mismatch: bbox and class.")
for field in self.fields:
if field == 'im_shape':
h = sample['h']
w = sample['w']
im_shape = np.array((h, w))
outs.append(im_shape)
elif field == 'is_difficult':
outs.append(sample['difficult'])
elif field == 'gt_box':
outs.append(sample['gt_bbox'])
elif field == 'gt_label':
outs.append(sample['gt_class'])
else:
outs.append(sample[field])
outs = tuple(outs)
return outs
@register_op
class ArrangeTestSSD(BaseOperator):
"""
Transform dict to tuple format needed for training.
Args:
is_mask (bool): whether to use include mask data
"""
def __init__(self):
super(ArrangeTestSSD, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items: (image)
"""
im = sample['image']
im_id = sample['im_id']
h = sample['h']
w = sample['w']
im_shape = np.array((h, w))
outs = (im, im_id, im_shape)
return outs
@register_op
class ArrangeYOLO(BaseOperator):
"""
Transform dict to the tuple format needed for training.
"""
def __init__(self):
super(ArrangeYOLO, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, gt_bbox, gt_class, gt_score,
is_crowd, im_info, gt_masks)
"""
im = sample['image']
if len(sample['gt_bbox']) != len(sample['gt_class']):
raise ValueError("gt num mismatch: bbox and class.")
if len(sample['gt_bbox']) != len(sample['gt_score']):
raise ValueError("gt num mismatch: bbox and score.")
gt_bbox = np.zeros((50, 4), dtype=im.dtype)
gt_class = np.zeros((50, ), dtype=np.int32)
gt_score = np.zeros((50, ), dtype=im.dtype)
gt_num = min(50, len(sample['gt_bbox']))
if gt_num > 0:
gt_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
gt_class[:gt_num] = sample['gt_class'][:gt_num, 0]
gt_score[:gt_num] = sample['gt_score'][:gt_num, 0]
# parse [x1, y1, x2, y2] to [x, y, w, h]
gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
outs = (im, gt_bbox, gt_class, gt_score)
return outs
@register_op
class ArrangeEvalYOLO(BaseOperator):
"""
Transform dict to the tuple format needed for evaluation.
"""
def __init__(self):
super(ArrangeEvalYOLO, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, im_shape, im_id, gt_bbox, gt_class,
difficult)
"""
im = sample['image']
if len(sample['gt_bbox']) != len(sample['gt_class']):
raise ValueError("gt num mismatch: bbox and class.")
im_id = sample['im_id']
h = sample['h']
w = sample['w']
im_shape = np.array((h, w))
gt_bbox = np.zeros((50, 4), dtype=im.dtype)
gt_class = np.zeros((50, ), dtype=np.int32)
difficult = np.zeros((50, ), dtype=np.int32)
gt_num = min(50, len(sample['gt_bbox']))
if gt_num > 0:
gt_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
gt_class[:gt_num] = sample['gt_class'][:gt_num, 0]
difficult[:gt_num] = sample['difficult'][:gt_num, 0]
outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult)
return outs
@register_op
class ArrangeTestYOLO(BaseOperator):
"""
Transform dict to the tuple format needed for inference.
"""
def __init__(self):
super(ArrangeTestYOLO, self).__init__()
def __call__(self, sample, context=None):
"""
Args:
sample: a dict which contains image
info and annotation info.
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, gt_bbox, gt_class, gt_score, is_crowd,
im_info, gt_masks)
"""
im = sample['image']
im_id = sample['im_id']
h = sample['h']
w = sample['w']
im_shape = np.array((h, w))
outs = (im, im_shape, im_id)
return outs
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this file contains helper methods for BBOX processing
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import math
import cv2
def meet_emit_constraint(src_bbox, sample_bbox):
center_x = (src_bbox[2] + src_bbox[0]) / 2
center_y = (src_bbox[3] + src_bbox[1]) / 2
if center_x >= sample_bbox[0] and \
center_x <= sample_bbox[2] and \
center_y >= sample_bbox[1] and \
center_y <= sample_bbox[3]:
return True
return False
def clip_bbox(src_bbox):
src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0)
src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0)
src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0)
src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0)
return src_bbox
def bbox_area(src_bbox):
if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
return 0.
else:
width = src_bbox[2] - src_bbox[0]
height = src_bbox[3] - src_bbox[1]
return width * height
def is_overlap(object_bbox, sample_bbox):
if object_bbox[0] >= sample_bbox[2] or \
object_bbox[2] <= sample_bbox[0] or \
object_bbox[1] >= sample_bbox[3] or \
object_bbox[3] <= sample_bbox[1]:
return False
else:
return True
def filter_and_process(sample_bbox, bboxes, labels, scores=None):
new_bboxes = []
new_labels = []
new_scores = []
for i in range(len(bboxes)):
new_bbox = [0, 0, 0, 0]
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
if not meet_emit_constraint(obj_bbox, sample_bbox):
continue
if not is_overlap(obj_bbox, sample_bbox):
continue
sample_width = sample_bbox[2] - sample_bbox[0]
sample_height = sample_bbox[3] - sample_bbox[1]
new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width
new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height
new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width
new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height
new_bbox = clip_bbox(new_bbox)
if bbox_area(new_bbox) > 0:
new_bboxes.append(new_bbox)
new_labels.append([labels[i][0]])
if scores is not None:
new_scores.append([scores[i][0]])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
new_bboxes = []
new_labels = []
new_scores = []
for i, bbox in enumerate(bboxes):
w = float((bbox[2] - bbox[0]) * target_size)
h = float((bbox[3] - bbox[1]) * target_size)
if w * h < float(min_size * min_size):
continue
else:
new_bboxes.append(bbox)
new_labels.append(labels[i])
if scores is not None and scores.size != 0:
new_scores.append(scores[i])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def generate_sample_bbox(sampler):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
aspect_ratio = max(aspect_ratio, (scale**2.0))
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = np.random.uniform(0, xmin_bound)
ymin = np.random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox
def generate_sample_bbox_square(sampler, image_width, image_height):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
aspect_ratio = max(aspect_ratio, (scale**2.0))
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
if image_height < image_width:
bbox_width = bbox_height * image_height / image_width
else:
bbox_height = bbox_width * image_width / image_height
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = np.random.uniform(0, xmin_bound)
ymin = np.random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
resize_width):
num_gt = len(bbox_labels)
# np.random.randint range: [low, high)
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
if num_gt != 0:
norm_xmin = bbox_labels[rand_idx][0]
norm_ymin = bbox_labels[rand_idx][1]
norm_xmax = bbox_labels[rand_idx][2]
norm_ymax = bbox_labels[rand_idx][3]
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
wid = image_width * (norm_xmax - norm_xmin)
hei = image_height * (norm_ymax - norm_ymin)
range_size = 0
area = wid * hei
for scale_ind in range(0, len(scale_array) - 1):
if area > scale_array[scale_ind] ** 2 and area < \
scale_array[scale_ind + 1] ** 2:
range_size = scale_ind + 1
break
if area > scale_array[len(scale_array) - 2]**2:
range_size = len(scale_array) - 2
scale_choose = 0.0
if range_size == 0:
rand_idx_size = 0
else:
# np.random.randint range: [low, high)
rng_rand_size = np.random.randint(0, range_size + 1)
rand_idx_size = rng_rand_size % (range_size + 1)
if rand_idx_size == range_size:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = min(2.0 * scale_array[rand_idx_size],
2 * math.sqrt(wid * hei))
scale_choose = random.uniform(min_resize_val, max_resize_val)
else:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = 2.0 * scale_array[rand_idx_size]
scale_choose = random.uniform(min_resize_val, max_resize_val)
sample_bbox_size = wid * resize_width / scale_choose
w_off_orig = 0.0
h_off_orig = 0.0
if sample_bbox_size < max(image_height, image_width):
if wid <= sample_bbox_size:
w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
xmin)
else:
w_off_orig = np.random.uniform(xmin,
xmin + wid - sample_bbox_size)
if hei <= sample_bbox_size:
h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
ymin)
else:
h_off_orig = np.random.uniform(ymin,
ymin + hei - sample_bbox_size)
else:
w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
h_off_orig = np.random.uniform(image_height - sample_bbox_size, 0.0)
w_off_orig = math.floor(w_off_orig)
h_off_orig = math.floor(h_off_orig)
# Figure out top left coordinates.
w_off = float(w_off_orig / image_width)
h_off = float(h_off_orig / image_height)
sampled_bbox = [
w_off, h_off, w_off + float(sample_bbox_size / image_width),
h_off + float(sample_bbox_size / image_height)
]
return sampled_bbox
else:
return 0
def jaccard_overlap(sample_bbox, object_bbox):
if sample_bbox[0] >= object_bbox[2] or \
sample_bbox[2] <= object_bbox[0] or \
sample_bbox[1] >= object_bbox[3] or \
sample_bbox[3] <= object_bbox[1]:
return 0
intersect_xmin = max(sample_bbox[0], object_bbox[0])
intersect_ymin = max(sample_bbox[1], object_bbox[1])
intersect_xmax = min(sample_bbox[2], object_bbox[2])
intersect_ymax = min(sample_bbox[3], object_bbox[3])
intersect_size = (intersect_xmax - intersect_xmin) * (
intersect_ymax - intersect_ymin)
sample_bbox_size = bbox_area(sample_bbox)
object_bbox_size = bbox_area(object_bbox)
overlap = intersect_size / (
sample_bbox_size + object_bbox_size - intersect_size)
return overlap
def intersect_bbox(bbox1, bbox2):
if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
intersection_box = [0.0, 0.0, 0.0, 0.0]
else:
intersection_box = [
max(bbox1[0], bbox2[0]),
max(bbox1[1], bbox2[1]),
min(bbox1[2], bbox2[2]),
min(bbox1[3], bbox2[3])
]
return intersection_box
def bbox_coverage(bbox1, bbox2):
inter_box = intersect_bbox(bbox1, bbox2)
intersect_size = bbox_area(inter_box)
if intersect_size > 0:
bbox1_size = bbox_area(bbox1)
return intersect_size / bbox1_size
else:
return 0.
def satisfy_sample_constraint(sampler,
sample_bbox,
gt_bboxes,
satisfy_all=False):
if sampler[6] == 0 and sampler[7] == 0:
return True
satisfied = []
for i in range(len(gt_bboxes)):
object_bbox = [
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
]
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler[6] != 0 and \
overlap < sampler[6]:
satisfied.append(False)
continue
if sampler[7] != 0 and \
overlap > sampler[7]:
satisfied.append(False)
continue
satisfied.append(True)
if not satisfy_all:
return True
if satisfy_all:
return np.all(satisfied)
else:
return False
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
if sampler[6] == 0 and sampler[7] == 0:
has_jaccard_overlap = False
else:
has_jaccard_overlap = True
if sampler[8] == 0 and sampler[9] == 0:
has_object_coverage = False
else:
has_object_coverage = True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(gt_bboxes)):
object_bbox = [
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
]
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler[6] != 0 and \
overlap < sampler[6]:
continue
if sampler[7] != 0 and \
overlap > sampler[7]:
continue
found = True
if has_object_coverage:
object_coverage = bbox_coverage(object_bbox, sample_bbox)
if sampler[8] != 0 and \
object_coverage < sampler[8]:
continue
if sampler[9] != 0 and \
object_coverage > sampler[9]:
continue
found = True
if found:
return True
return found
def crop_image_sampling(img, sample_bbox, image_width, image_height,
target_size):
# no clipping here
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
w_off = xmin
h_off = ymin
width = xmax - xmin
height = ymax - ymin
cross_xmin = max(0.0, float(w_off))
cross_ymin = max(0.0, float(h_off))
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
cross_width = cross_xmax - cross_xmin
cross_height = cross_ymax - cross_ymin
roi_xmin = 0 if w_off >= 0 else abs(w_off)
roi_ymin = 0 if h_off >= 0 else abs(h_off)
roi_width = cross_width
roi_height = cross_height
roi_y1 = int(roi_ymin)
roi_y2 = int(roi_ymin + roi_height)
roi_x1 = int(roi_xmin)
roi_x2 = int(roi_xmin + roi_width)
cross_y1 = int(cross_ymin)
cross_y2 = int(cross_ymin + cross_height)
cross_x1 = int(cross_xmin)
cross_x2 = int(cross_xmin + cross_width)
sample_img = np.zeros((height, width, 3))
sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
img[cross_y1: cross_y2, cross_x1: cross_x2]
sample_img = cv2.resize(
sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
return sample_img
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# operators to process sample,
# eg: decode/resize/crop image
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from numbers import Number
import uuid
import logging
import random
import math
import numpy as np
import cv2
from PIL import Image, ImageEnhance
# from ppdet.core.workspace import serializable
from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling)
logger = logging.getLogger(__name__)
registered_ops = []
def register_op(cls):
registered_ops.append(cls.__name__)
if not hasattr(BaseOperator, cls.__name__):
setattr(BaseOperator, cls.__name__, cls)
else:
raise KeyError("The {} class has been registered.".format(cls.__name__))
# return serializable(cls)
return cls
class BboxError(ValueError):
pass
class ImageError(ValueError):
pass
class BaseOperator(object):
def __init__(self, name=None):
if name is None:
name = self.__class__.__name__
self._id = name + '_' + str(uuid.uuid4())[-6:]
def __call__(self, sample, context=None):
""" Process a sample.
Args:
sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
context (dict): info about this sample processing
Returns:
result (dict): a processed sample
"""
return sample
def __str__(self):
return str(self._id)
@register_op
class DecodeImage(BaseOperator):
def __init__(self, to_rgb=True, with_mixup=False):
""" Transform the image data to numpy format.
Args:
to_rgb (bool): whether to convert BGR to RGB
with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score
"""
super(DecodeImage, self).__init__()
self.to_rgb = to_rgb
self.with_mixup = with_mixup
if not isinstance(self.to_rgb, bool):
raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_mixup, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
""" load image if 'im_file' field is not empty but 'image' is"""
if 'image' not in sample:
with open(sample['im_file'], 'rb') as f:
sample['image'] = f.read()
im = sample['image']
data = np.frombuffer(im, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
sample['image'] = im
if 'h' not in sample:
sample['h'] = im.shape[0]
if 'w' not in sample:
sample['w'] = im.shape[1]
# make default im_info with [h, w, 1]
sample['im_info'] = np.array([im.shape[0], im.shape[1], 1.],
dtype=np.float32)
# decode mixup image
if self.with_mixup and 'mixup' in sample:
self.__call__(sample['mixup'], context)
return sample
@register_op
class MultiscaleTestResize(BaseOperator):
def __init__(self,
origin_target_size=800,
origin_max_size=1333,
target_size=[],
max_size=2000,
interp=cv2.INTER_LINEAR,
use_flip=True):
"""
Rescale image to the each size in target size, and capped at max_size.
Args:
origin_target_size(int): original target size of image's short side.
origin_max_size(int): original max size of image.
target_size (list): A list of target sizes of image's short side.
max_size (int): the max size of image.
interp (int): the interpolation method.
use_flip (bool): whether use flip augmentation.
"""
super(MultiscaleTestResize, self).__init__()
self.origin_target_size = int(origin_target_size)
self.origin_max_size = int(origin_max_size)
self.max_size = int(max_size)
self.interp = int(interp)
self.use_flip = use_flip
if not isinstance(target_size, list):
raise TypeError(
"Type of target_size is invalid. Must be List, now is {}".
format(type(target_size)))
self.target_size = target_size
if not (isinstance(self.origin_target_size, int) and isinstance(
self.origin_max_size, int) and isinstance(self.max_size, int)
and isinstance(self.interp, int)):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
""" Resize the image numpy for multi-scale test.
"""
origin_ims = {}
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image type is not numpy.".format(self))
if len(im.shape) != 3:
raise ImageError('{}: image is not 3-dimensional.'.format(self))
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
if float(im_size_min) == 0:
raise ZeroDivisionError('{}: min size of image is 0'.format(self))
base_name_list = ['image']
origin_ims['image'] = im
if self.use_flip:
sample['flip_image'] = im[:, ::-1, :]
base_name_list.append('flip_image')
origin_ims['flip_image'] = sample['flip_image']
im_info = []
for base_name in base_name_list:
im_scale = float(self.origin_target_size) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > self.origin_max_size:
im_scale = float(self.origin_max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
resize_w = np.round(im_scale_x * float(im_shape[1]))
resize_h = np.round(im_scale_y * float(im_shape[0]))
im_resize = cv2.resize(
origin_ims[base_name],
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info.extend([resize_h, resize_w, im_scale])
sample[base_name] = im_resize
for i, size in enumerate(self.target_size):
im_scale = float(size) / float(im_size_min)
if np.round(im_scale * im_size_max) > self.max_size:
im_scale = float(self.max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
resize_w = np.round(im_scale_x * float(im_shape[1]))
resize_h = np.round(im_scale_y * float(im_shape[0]))
im_resize = cv2.resize(
origin_ims[base_name],
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info.extend([resize_h, resize_w, im_scale])
name = base_name + '_scale_' + str(i)
sample[name] = im_resize
sample['im_info'] = np.array(im_info, dtype=np.float32)
return sample
@register_op
class ResizeImage(BaseOperator):
def __init__(self,
target_size=0,
max_size=0,
interp=cv2.INTER_LINEAR,
use_cv2=True):
"""
Rescale image to the specified target size, and capped at max_size
if max_size != 0.
If target_size is list, selected a scale randomly as the specified
target size.
Args:
target_size (int|list): the target size of image's short side,
multi-scale training is adopted when type is list.
max_size (int): the max size of image
interp (int): the interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL
interpolation method
"""
super(ResizeImage, self).__init__()
self.max_size = int(max_size)
self.interp = int(interp)
self.use_cv2 = use_cv2
if not (isinstance(target_size, int) or isinstance(target_size, list)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List, now is {}"
.format(type(target_size)))
self.target_size = target_size
if not (isinstance(self.max_size, int)
and isinstance(self.interp, int)):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
""" Resize the image numpy.
"""
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image type is not numpy.".format(self))
if len(im.shape) != 3:
raise ImageError('{}: image is not 3-dimensional.'.format(self))
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
if isinstance(self.target_size, list):
# Case for multi-scale training
selected_size = random.choice(self.target_size)
else:
selected_size = self.target_size
if float(im_size_min) == 0:
raise ZeroDivisionError('{}: min size of image is 0'.format(self))
if self.max_size != 0:
im_scale = float(selected_size) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > self.max_size:
im_scale = float(self.max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
resize_w = im_scale_x * float(im_shape[1])
resize_h = im_scale_y * float(im_shape[0])
im_info = [resize_h, resize_w, im_scale]
if 'im_info' in sample and sample['im_info'][2] != 1.:
sample['im_info'] = np.append(list(sample['im_info']),
im_info).astype(np.float32)
else:
sample['im_info'] = np.array(im_info).astype(np.float32)
else:
im_scale_x = float(selected_size) / float(im_shape[1])
im_scale_y = float(selected_size) / float(im_shape[0])
resize_w = selected_size
resize_h = selected_size
if self.use_cv2:
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
else:
if self.max_size != 0:
raise TypeError(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.')
im = Image.fromarray(im)
im = im.resize((int(resize_w), int(resize_h)), self.interp)
im = np.array(im)
sample['image'] = im
return sample
@register_op
class RandomFlipImage(BaseOperator):
def __init__(self, prob=0.5, is_normalized=False, is_mask_flip=False):
"""
Args:
prob (float): the probability of flipping image
is_normalized (bool): whether the bbox scale to [0,1]
is_mask_flip (bool): whether flip the segmentation
"""
super(RandomFlipImage, self).__init__()
self.prob = prob
self.is_normalized = is_normalized
self.is_mask_flip = is_mask_flip
if not (isinstance(self.prob, float)
and isinstance(self.is_normalized, bool)
and isinstance(self.is_mask_flip, bool)):
raise TypeError("{}: input type is invalid.".format(self))
def flip_segms(self, segms, height, width):
def _flip_poly(poly, width):
flipped_poly = np.array(poly)
flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
return flipped_poly.tolist()
def _flip_rle(rle, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects([rle], height, width)
mask = mask_util.decode(rle)
mask = mask[:, ::-1, :]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
def is_poly(segm):
assert isinstance(segm, (list, dict)), \
"Invalid segm type: {}".format(type(segm))
return isinstance(segm, list)
flipped_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
flipped_segms.append([_flip_poly(poly, width) for poly in segm])
else:
# RLE format
import pycocotools.mask as mask_util
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
def __call__(self, sample, context=None):
"""Filp the image and bounding box.
Operators:
1. Flip the image numpy.
2. Transform the bboxes' x coordinates.
(Must judge whether the coordinates are normalized!)
3. Transform the segmentations' x coordinates.
(Must judge whether the coordinates are normalized!)
Output:
sample: the image, bounding box and segmentation part
in sample are flipped.
"""
gt_bbox = sample['gt_bbox']
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image is not a numpy array.".format(self))
if len(im.shape) != 3:
raise ImageError("{}: image is not 3-dimensional.".format(self))
height, width, _ = im.shape
if np.random.uniform(0, 1) < self.prob:
im = im[:, ::-1, :]
if gt_bbox.shape[0] == 0:
return sample
oldx1 = gt_bbox[:, 0].copy()
oldx2 = gt_bbox[:, 2].copy()
if self.is_normalized:
gt_bbox[:, 0] = 1 - oldx2
gt_bbox[:, 2] = 1 - oldx1
else:
gt_bbox[:, 0] = width - oldx2 - 1
gt_bbox[:, 2] = width - oldx1 - 1
if gt_bbox.shape[0] != 0 and (gt_bbox[:, 2] < gt_bbox[:, 0]).all():
m = "{}: invalid box, x2 should be greater than x1".format(self)
raise BboxError(m)
sample['gt_bbox'] = gt_bbox
if self.is_mask_flip and len(sample['gt_poly']) != 0:
sample['gt_poly'] = self.flip_segms(sample['gt_poly'], height,
width)
sample['flipped'] = True
sample['image'] = im
return sample
@register_op
class NormalizeImage(BaseOperator):
def __init__(self,
mean=[0.485, 0.456, 0.406],
std=[1, 1, 1],
is_scale=True,
is_channel_first=True):
"""
Args:
mean (list): the pixel mean
std (list): the pixel variance
"""
super(NormalizeImage, self).__init__()
self.mean = mean
self.std = std
self.is_scale = is_scale
self.is_channel_first = is_channel_first
if not (isinstance(self.mean, list) and isinstance(self.std, list)
and isinstance(self.is_scale, bool)):
raise TypeError("{}: input type is invalid.".format(self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, sample, context=None):
"""Normalize the image.
Operators:
1.(optional) Scale the image to [0,1]
2. Each pixel minus mean and is divided by std
"""
for k in sample.keys():
if 'image' in k:
im = sample[k]
im = im.astype(np.float32, copy=False)
if self.is_channel_first:
mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
std = np.array(self.std)[:, np.newaxis, np.newaxis]
else:
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
sample[k] = im
return sample
@register_op
class RandomDistort(BaseOperator):
def __init__(self,
brightness_lower=0.5,
brightness_upper=1.5,
contrast_lower=0.5,
contrast_upper=1.5,
saturation_lower=0.5,
saturation_upper=1.5,
hue_lower=-18,
hue_upper=18,
brightness_prob=0.5,
contrast_prob=0.5,
saturation_prob=0.5,
hue_prob=0.5,
count=4,
is_order=False):
"""
Args:
brightness_lower/ brightness_upper (float): the brightness
between brightness_lower and brightness_upper
contrast_lower/ contrast_upper (float): the contrast between
contrast_lower and contrast_lower
saturation_lower/ saturation_upper (float): the saturation
between saturation_lower and saturation_upper
hue_lower/ hue_upper (float): the hue between
hue_lower and hue_upper
brightness_prob (float): the probability of changing brightness
contrast_prob (float): the probability of changing contrast
saturation_prob (float): the probability of changing saturation
hue_prob (float): the probability of changing hue
count (int): the kinds of doing distrot
is_order (bool): whether determine the order of distortion
"""
super(RandomDistort, self).__init__()
self.brightness_lower = brightness_lower
self.brightness_upper = brightness_upper
self.contrast_lower = contrast_lower
self.contrast_upper = contrast_upper
self.saturation_lower = saturation_lower
self.saturation_upper = saturation_upper
self.hue_lower = hue_lower
self.hue_upper = hue_upper
self.brightness_prob = brightness_prob
self.contrast_prob = contrast_prob
self.saturation_prob = saturation_prob
self.hue_prob = hue_prob
self.count = count
self.is_order = is_order
def random_brightness(self, img):
brightness_delta = np.random.uniform(self.brightness_lower,
self.brightness_upper)
prob = np.random.uniform(0, 1)
if prob < self.brightness_prob:
img = ImageEnhance.Brightness(img).enhance(brightness_delta)
return img
def random_contrast(self, img):
contrast_delta = np.random.uniform(self.contrast_lower,
self.contrast_upper)
prob = np.random.uniform(0, 1)
if prob < self.contrast_prob:
img = ImageEnhance.Contrast(img).enhance(contrast_delta)
return img
def random_saturation(self, img):
saturation_delta = np.random.uniform(self.saturation_lower,
self.saturation_upper)
prob = np.random.uniform(0, 1)
if prob < self.saturation_prob:
img = ImageEnhance.Color(img).enhance(saturation_delta)
return img
def random_hue(self, img):
hue_delta = np.random.uniform(self.hue_lower, self.hue_upper)
prob = np.random.uniform(0, 1)
if prob < self.hue_prob:
img = np.array(img.convert('HSV'))
img[:, :, 0] = img[:, :, 0] + hue_delta
img = Image.fromarray(img, mode='HSV').convert('RGB')
return img
def __call__(self, sample, context):
"""random distort the image"""
ops = [
self.random_brightness, self.random_contrast,
self.random_saturation, self.random_hue
]
if self.is_order:
prob = np.random.uniform(0, 1)
if prob < 0.5:
ops = [
self.random_brightness,
self.random_saturation,
self.random_hue,
self.random_contrast,
]
else:
ops = random.sample(ops, self.count)
assert 'image' in sample, "image data not found"
im = sample['image']
im = Image.fromarray(im)
for id in range(self.count):
im = ops[id](im)
im = np.asarray(im)
sample['image'] = im
return sample
@register_op
class ExpandImage(BaseOperator):
def __init__(self, max_ratio, prob, mean=[127.5, 127.5, 127.5]):
"""
Args:
max_ratio (float): the ratio of expanding
prob (float): the probability of expanding image
mean (list): the pixel mean
"""
super(ExpandImage, self).__init__()
self.max_ratio = max_ratio
self.mean = mean
self.prob = prob
def __call__(self, sample, context):
"""
Expand the image and modify bounding box.
Operators:
1. Scale the image width and height.
2. Construct new images with new height and width.
3. Fill the new image with the mean.
4. Put original imge into new image.
5. Rescale the bounding box.
6. Determine if the new bbox is satisfied in the new image.
Returns:
sample: the image, bounding box are replaced.
"""
prob = np.random.uniform(0, 1)
assert 'image' in sample, 'not found image data'
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
im_width = sample['w']
im_height = sample['h']
if prob < self.prob:
if self.max_ratio - 1 >= 0.01:
expand_ratio = np.random.uniform(1, self.max_ratio)
height = int(im_height * expand_ratio)
width = int(im_width * expand_ratio)
h_off = math.floor(np.random.uniform(0, height - im_height))
w_off = math.floor(np.random.uniform(0, width - im_width))
expand_bbox = [
-w_off / im_width, -h_off / im_height,
(width - w_off) / im_width, (height - h_off) / im_height
]
expand_im = np.ones((height, width, 3))
expand_im = np.uint8(expand_im * np.squeeze(self.mean))
expand_im = Image.fromarray(expand_im)
im = Image.fromarray(im)
expand_im.paste(im, (int(w_off), int(h_off)))
expand_im = np.asarray(expand_im)
gt_bbox, gt_class, _ = filter_and_process(
expand_bbox, gt_bbox, gt_class)
sample['image'] = expand_im
sample['gt_bbox'] = gt_bbox
sample['gt_class'] = gt_class
sample['w'] = width
sample['h'] = height
return sample
@register_op
class CropImage(BaseOperator):
def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True):
"""
Args:
batch_sampler (list): Multiple sets of different
parameters for cropping.
satisfy_all (bool): whether all boxes must satisfy.
e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap]
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropImage, self).__init__()
self.batch_sampler = batch_sampler
self.satisfy_all = satisfy_all
self.avoid_no_bbox = avoid_no_bbox
def __call__(self, sample, context):
"""
Crop the image and modify bounding box.
Operators:
1. Scale the image width and height.
2. Crop the image according to a radom sample.
3. Rescale the bounding box.
4. Determine if the new bbox is satisfied in the new image.
Returns:
sample: the image, bounding box are replaced.
"""
assert 'image' in sample, "image data not found"
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
im_width = sample['w']
im_height = sample['h']
gt_score = None
if 'gt_score' in sample:
gt_score = sample['gt_score']
sampled_bbox = []
gt_bbox = gt_bbox.tolist()
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox(sampler)
if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox,
self.satisfy_all):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = \
filter_and_process(sample_bbox, gt_bbox, gt_class, gt_score)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
xmin = int(sample_bbox[0] * im_width)
xmax = int(sample_bbox[2] * im_width)
ymin = int(sample_bbox[1] * im_height)
ymax = int(sample_bbox[3] * im_height)
im = im[ymin:ymax, xmin:xmax]
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
return sample
return sample
@register_op
class CropImageWithDataAchorSampling(BaseOperator):
def __init__(self,
batch_sampler,
anchor_sampler=None,
target_size=None,
das_anchor_scales=[16, 32, 64, 128],
sampling_prob=0.5,
min_size=8.,
avoid_no_bbox=True):
"""
Args:
anchor_sampler (list): anchor_sampling sets of different
parameters for cropping.
batch_sampler (list): Multiple sets of different
parameters for cropping.
e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
[[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap, min coverage, max coverage]
target_size (bool): target image size.
das_anchor_scales (list[float]): a list of anchor scales in data
anchor smapling.
min_size (float): minimum size of sampled bbox.
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropImageWithDataAchorSampling, self).__init__()
self.anchor_sampler = anchor_sampler
self.batch_sampler = batch_sampler
self.target_size = target_size
self.sampling_prob = sampling_prob
self.min_size = min_size
self.avoid_no_bbox = avoid_no_bbox
self.das_anchor_scales = np.array(das_anchor_scales)
def __call__(self, sample, context):
"""
Crop the image and modify bounding box.
Operators:
1. Scale the image width and height.
2. Crop the image according to a radom sample.
3. Rescale the bounding box.
4. Determine if the new bbox is satisfied in the new image.
Returns:
sample: the image, bounding box are replaced.
"""
assert 'image' in sample, "image data not found"
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
image_width = sample['w']
image_height = sample['h']
gt_score = None
if 'gt_score' in sample:
gt_score = sample['gt_score']
sampled_bbox = []
gt_bbox = gt_bbox.tolist()
prob = np.random.uniform(0., 1.)
if prob > self.sampling_prob: # anchor sampling
assert self.anchor_sampler
for sampler in self.anchor_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = data_anchor_sampling(
gt_bbox, image_width, image_height,
self.das_anchor_scales, self.target_size)
if sample_bbox == 0:
break
if satisfy_sample_constraint_coverage(
sampler, sample_bbox, gt_bbox):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, gt_score)
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
im = crop_image_sampling(im, sample_bbox, image_width,
image_height, self.target_size)
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
return sample
return sample
else:
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox_square(
sampler, image_width, image_height)
if satisfy_sample_constraint_coverage(
sampler, sample_bbox, gt_bbox):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, gt_score)
# sampling bbox according the bbox area
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
im = im[ymin:ymax, xmin:xmax]
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
return sample
return sample
@register_op
class NormalizeBox(BaseOperator):
"""Transform the bounding box's coornidates to [0,1]."""
def __init__(self):
super(NormalizeBox, self).__init__()
def __call__(self, sample, context):
gt_bbox = sample['gt_bbox']
width = sample['w']
height = sample['h']
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / width
gt_bbox[i][1] = gt_bbox[i][1] / height
gt_bbox[i][2] = gt_bbox[i][2] / width
gt_bbox[i][3] = gt_bbox[i][3] / height
sample['gt_bbox'] = gt_bbox
return sample
@register_op
class Permute(BaseOperator):
def __init__(self, to_bgr=True, channel_first=True):
"""
Change the channel.
Args:
to_bgr (bool): confirm whether to convert RGB to BGR
channel_first (bool): confirm whether to change channel
"""
super(Permute, self).__init__()
self.to_bgr = to_bgr
self.channel_first = channel_first
if not (isinstance(self.to_bgr, bool)
and isinstance(self.channel_first, bool)):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
assert 'image' in sample, "image data not found"
for k in sample.keys():
if 'image' in k:
im = sample[k]
if self.channel_first:
im = np.swapaxes(im, 1, 2)
im = np.swapaxes(im, 1, 0)
if self.to_bgr:
im = im[[2, 1, 0], :, :]
sample[k] = im
return sample
@register_op
class MixupImage(BaseOperator):
def __init__(self, alpha=1.5, beta=1.5):
""" Mixup image and gt_bbbox/gt_score
Args:
alpha (float): alpha parameter of beta distribute
beta (float): beta parameter of beta distribute
"""
super(MixupImage, self).__init__()
self.alpha = alpha
self.beta = beta
if self.alpha <= 0.0:
raise ValueError("alpha shold be positive in {}".format(self))
if self.beta <= 0.0:
raise ValueError("beta shold be positive in {}".format(self))
def _mixup_img(self, img1, img2, factor):
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += \
img2.astype('float32') * (1.0 - factor)
return img.astype('uint8')
def __call__(self, sample, context=None):
if 'mixup' not in sample:
return sample
factor = np.random.beta(self.alpha, self.beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
sample.pop('mixup')
return sample
if factor <= 0.0:
return sample['mixup']
im = self._mixup_img(sample['image'], sample['mixup']['image'], factor)
gt_bbox1 = sample['gt_bbox']
gt_bbox2 = sample['mixup']['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = sample['gt_class']
gt_class2 = sample['mixup']['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = sample['gt_score']
gt_score2 = sample['mixup']['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
sample['image'] = im
sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score
sample['gt_class'] = gt_class
sample['h'] = im.shape[0]
sample['w'] = im.shape[1]
sample.pop('mixup')
return sample
@register_op
class RandomInterpImage(BaseOperator):
def __init__(self, target_size=0, max_size=0):
"""
Random reisze image by multiply interpolate method.
Args:
target_size (int): the taregt size of image's short side
max_size (int): the max size of image
"""
super(RandomInterpImage, self).__init__()
self.target_size = target_size
self.max_size = max_size
if not (isinstance(self.target_size, int)
and isinstance(self.max_size, int)):
raise TypeError('{}: input type is invalid.'.format(self))
interps = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
self.resizers = []
for interp in interps:
self.resizers.append(ResizeImage(target_size, max_size, interp))
def __call__(self, sample, context=None):
"""Resise the image numpy by random resizer."""
resizer = random.choice(self.resizers)
return resizer(sample, context)
@register_op
class Resize(BaseOperator):
"""Resize image and bbox.
Args:
target_dim (int or list): target size, can be a single number or a list
(for random shape).
interp (int or str): interpolation method, can be an integer or
'random' (for randomized interpolation).
default to `cv2.INTER_LINEAR`.
"""
def __init__(self, target_dim=[], interp=cv2.INTER_LINEAR):
super(Resize, self).__init__()
self.target_dim = target_dim
self.interp = interp # 'random' for yolov3
def __call__(self, sample, context=None):
w = sample['w']
h = sample['h']
interp = self.interp
if interp == 'random':
interp = np.random.choice(range(5))
if isinstance(self.target_dim, Sequence):
dim = np.random.choice(self.target_dim)
else:
dim = self.target_dim
resize_w = resize_h = dim
scale_x = dim / w
scale_y = dim / h
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0,
dim - 1)
sample['h'] = resize_h
sample['w'] = resize_w
sample['image'] = cv2.resize(
sample['image'], (resize_w, resize_h), interpolation=interp)
return sample
@register_op
class ColorDistort(BaseOperator):
"""Random color distortion.
Args:
hue (list): hue settings.
in [lower, upper, probability] format.
saturation (list): saturation settings.
in [lower, upper, probability] format.
contrast (list): contrast settings.
in [lower, upper, probability] format.
brightness (list): brightness settings.
in [lower, upper, probability] format.
random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
order.
"""
def __init__(self,
hue=[-18, 18, 0.5],
saturation=[0.5, 1.5, 0.5],
contrast=[0.5, 1.5, 0.5],
brightness=[0.5, 1.5, 0.5],
random_apply=True):
super(ColorDistort, self).__init__()
self.hue = hue
self.saturation = saturation
self.contrast = contrast
self.brightness = brightness
self.random_apply = random_apply
def apply_hue(self, img):
low, high, prob = self.hue
if np.random.uniform(0., 1.) < prob:
return img
img = img.astype(np.float32)
# XXX works, but result differ from HSV version
delta = np.random.uniform(low, high)
u = np.cos(delta * np.pi)
w = np.sin(delta * np.pi)
bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
[0.211, -0.523, 0.311]])
ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
[1.0, -1.107, 1.705]])
t = np.dot(np.dot(ityiq, bt), tyiq).T
img = np.dot(img, t)
return img
def apply_saturation(self, img):
low, high, prob = self.saturation
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
gray = gray.sum(axis=2, keepdims=True)
gray *= (1.0 - delta)
img *= delta
img += gray
return img
def apply_contrast(self, img):
low, high, prob = self.contrast
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img *= delta
return img
def apply_brightness(self, img):
low, high, prob = self.brightness
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img += delta
return img
def __call__(self, sample, context=None):
img = sample['image']
if self.random_apply:
distortions = np.random.permutation([
self.apply_brightness, self.apply_contrast,
self.apply_saturation, self.apply_hue
])
for func in distortions:
img = func(img)
sample['image'] = img
return sample
img = self.apply_brightness(img)
if np.random.randint(0, 2):
img = self.apply_contrast(img)
img = self.apply_saturation(img)
img = self.apply_hue(img)
else:
img = self.apply_saturation(img)
img = self.apply_hue(img)
img = self.apply_contrast(img)
sample['image'] = img
return sample
@register_op
class NormalizePermute(BaseOperator):
"""Normalize and permute channel order.
Args:
mean (list): mean values in RGB order.
std (list): std values in RGB order.
"""
def __init__(self,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]):
super(NormalizePermute, self).__init__()
self.mean = mean
self.std = std
def __call__(self, sample, context=None):
img = sample['image']
img = img.astype(np.float32)
img = img.transpose((2, 0, 1))
mean = np.array(self.mean, dtype=np.float32)
std = np.array(self.std, dtype=np.float32)
invstd = 1. / std
for v, m, s in zip(img, mean, invstd):
v.__isub__(m).__imul__(s)
sample['image'] = img
return sample
@register_op
class RandomExpand(BaseOperator):
"""Random expand the canvas.
Args:
ratio (float): maximum expansion ratio.
prob (float): probability to expand.
fill_value (list): color value used to fill the canvas. in RGB order.
"""
def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, ) * 3):
super(RandomExpand, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01"
self.ratio = ratio
self.prob = prob
assert isinstance(fill_value, (Number, Sequence)), \
"fill value must be either float or sequence"
if isinstance(fill_value, Number):
fill_value = (fill_value, ) * 3
if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value)
self.fill_value = fill_value
def __call__(self, sample, context=None):
if np.random.uniform(0., 1.) < self.prob:
return sample
img = sample['image']
height = int(sample['h'])
width = int(sample['w'])
expand_ratio = np.random.uniform(1., self.ratio)
h = int(height * expand_ratio)
w = int(width * expand_ratio)
if not h > height or not w > width:
return sample
y = np.random.randint(0, h - height)
x = np.random.randint(0, w - width)
canvas = np.ones((h, w, 3), dtype=np.uint8)
canvas *= np.array(self.fill_value, dtype=np.uint8)
canvas[y:y + height, x:x + width, :] = img.astype(np.uint8)
sample['h'] = h
sample['w'] = w
sample['image'] = canvas
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
return sample
@register_op
class RandomCrop(BaseOperator):
"""Random crop image and bboxes.
Args:
aspect_ratio (list): aspect ratio of cropped region.
in [min, max] format.
thresholds (list): iou thresholds for decide a valid bbox crop.
scaling (list): ratio between a cropped region and the original image.
in [min, max] format.
num_attempts (int): number of tries before giving up.
allow_no_crop (bool): allow return without actually cropping them.
cover_all_box (bool): ensure all bboxes are covered in the final crop.
"""
def __init__(self,
aspect_ratio=[.5, 2.],
thresholds=[.0, .1, .3, .5, .7, .9],
scaling=[.3, 1.],
num_attempts=50,
allow_no_crop=True,
cover_all_box=False):
super(RandomCrop, self).__init__()
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
self.scaling = scaling
self.num_attempts = num_attempts
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box
def __call__(self, sample, context=None):
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
h = sample['h']
w = sample['w']
gt_bbox = sample['gt_bbox']
# NOTE Original method attempts to generate one candidate for each
# threshold then randomly sample one from the resulting list.
# Here a short circuit approach is taken, i.e., randomly choose a
# threshold and attempt to find a valid crop, and simply return the
# first one found.
# The probability is not exactly the same, kinda resembling the
# "Monty Hall" problem. Actually carrying out the attempts will affect
# observability (just like opening doors in the "Monty Hall" game).
thresholds = list(self.thresholds)
if self.allow_no_crop:
thresholds.append('no_crop')
np.random.shuffle(thresholds)
for thresh in thresholds:
if thresh == 'no_crop':
return sample
found = False
for i in range(self.num_attempts):
scale = np.random.uniform(*self.scaling)
min_ar, max_ar = self.aspect_ratio
aspect_ratio = np.random.uniform(
max(min_ar, scale**2), min(max_ar, scale**-2))
crop_h = int(h * scale / np.sqrt(aspect_ratio))
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = self._iou_matrix(gt_bbox,
np.array([crop_box], dtype=np.float32))
if iou.max() < thresh:
continue
if self.cover_all_box and iou.min() < thresh:
continue
cropped_box, valid_ids = self._crop_box_with_center_constraint(
gt_bbox, np.array(crop_box, dtype=np.float32))
if valid_ids.size > 0:
found = True
break
if found:
sample['image'] = self._crop_image(sample['image'], crop_box)
sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
sample['gt_class'] = np.take(
sample['gt_class'], valid_ids, axis=0)
sample['w'] = crop_box[2] - crop_box[0]
sample['h'] = crop_box[3] - crop_box[1]
if 'gt_score' in sample:
sample['gt_score'] = np.take(
sample['gt_score'], valid_ids, axis=0)
return sample
return sample
def _iou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def _crop_box_with_center_constraint(self, box, crop):
cropped_box = box.copy()
cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
cropped_box[:, :2] -= crop[:2]
cropped_box[:, 2:] -= crop[:2]
centers = (box[:, :2] + box[:, 2:]) / 2
valid = np.logical_and(crop[:2] <= centers,
centers < crop[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return cropped_box, np.where(valid)[0]
def _crop_image(self, img, crop):
x1, y1, x2, y2 = crop
return img[y1:y2, x1:x2, :]
# import Arrange OPs to register them to BaseOperator
from .arrange_sample import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# transform samples in 'source' using 'mapper'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
import uuid
import logging
import signal
import threading
from .transformer import ProxiedDataset
logger = logging.getLogger(__name__)
class EndSignal(object):
def __init__(self, errno=0, errmsg=''):
self.errno = errno
self.errmsg = errmsg
class ParallelMappedDataset(ProxiedDataset):
"""
Transform samples to mapped samples which is similar to 'basic.MappedDataset',
but multiple workers (threads or processes) will be used
Notes:
this class is not thread-safe
"""
def __init__(self, source, mapper, worker_args):
super(ParallelMappedDataset, self).__init__(source)
worker_args = {k.lower(): v for k, v in worker_args.items()}
args = {
'bufsize': 100,
'worker_num': 8,
'use_process': False,
'memsize': '3G'
}
args.update(worker_args)
if args['use_process'] and type(args['memsize']) is str:
assert args['memsize'][-1].lower() == 'g', \
"invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
gb = args['memsize'][:-1]
args['memsize'] = int(gb) * 1024**3
self._worker_args = args
self._started = False
self._source = source
self._mapper = mapper
self._exit = False
self._setup()
def _setup(self):
"""setup input/output queues and workers """
use_process = self._worker_args.get('use_process', False)
if use_process and sys.platform == "win32":
logger.info("Use multi-thread reader instead of "
"multi-process reader on Windows.")
use_process = False
bufsize = self._worker_args['bufsize']
if use_process:
from .shared_queue import SharedQueue as Queue
from multiprocessing import Process as Worker
from multiprocessing import Event
memsize = self._worker_args['memsize']
self._inq = Queue(bufsize, memsize=memsize)
self._outq = Queue(bufsize, memsize=memsize)
else:
if six.PY3:
from queue import Queue
else:
from Queue import Queue
from threading import Thread as Worker
from threading import Event
self._inq = Queue(bufsize)
self._outq = Queue(bufsize)
consumer_num = self._worker_args['worker_num']
id = str(uuid.uuid4())[-3:]
self._producer = threading.Thread(
target=self._produce,
args=('producer-' + id, self._source, self._inq))
self._producer.daemon = True
self._consumers = []
for i in range(consumer_num):
p = Worker(
target=self._consume,
args=('consumer-' + id + '_' + str(i), self._inq, self._outq,
self._mapper))
self._consumers.append(p)
p.daemon = True
self._epoch = -1
self._feeding_ev = Event()
self._produced = 0 # produced sample in self._produce
self._consumed = 0 # consumed sample in self.next
self._stopped_consumers = 0
def _produce(self, id, source, inq):
"""Fetch data from source and feed it to 'inq' queue"""
while True:
self._feeding_ev.wait()
if self._exit:
break
try:
inq.put(source.next())
self._produced += 1
except StopIteration:
self._feeding_ev.clear()
self._feeding_ev.wait() # wait other guy to wake up me
logger.debug("producer[{}] starts new epoch".format(id))
except Exception as e:
msg = "producer[{}] failed with error: {}".format(id, str(e))
inq.put(EndSignal(-1, msg))
break
logger.debug("producer[{}] exits".format(id))
def _consume(self, id, inq, outq, mapper):
"""Fetch data from 'inq', process it and put result to 'outq'"""
while True:
sample = inq.get()
if isinstance(sample, EndSignal):
sample.errmsg += "[consumer[{}] exits]".format(id)
outq.put(sample)
logger.debug("end signal received, " +
"consumer[{}] exits".format(id))
break
try:
result = mapper(sample)
outq.put(result)
except Exception as e:
msg = 'failed to map consumer[%s], error: {}'.format(str(e), id)
outq.put(EndSignal(-1, msg))
break
def drained(self):
assert self._epoch >= 0, "first epoch has not started yet"
return self._source.drained() and self._produced == self._consumed
def stop(self):
""" notify to exit
"""
self._exit = True
self._feeding_ev.set()
for _ in range(len(self._consumers)):
self._inq.put(EndSignal(0, "notify consumers to exit"))
def next(self):
""" get next transformed sample
"""
if self._epoch < 0:
self.reset()
if self.drained():
raise StopIteration()
while True:
sample = self._outq.get()
if isinstance(sample, EndSignal):
self._stopped_consumers += 1
if sample.errno != 0:
logger.warn("consumer failed with error: {}".format(
sample.errmsg))
if self._stopped_consumers < len(self._consumers):
self._inq.put(sample)
else:
raise ValueError("all consumers exited, no more samples")
else:
self._consumed += 1
return sample
def reset(self):
""" reset for a new epoch of samples
"""
if self._epoch < 0:
self._epoch = 0
for p in self._consumers:
p.start()
self._producer.start()
else:
if not self.drained():
logger.warn("do not reset before epoch[%d] finishes".format(
self._epoch))
self._produced = self._produced - self._consumed
else:
self._produced = 0
self._epoch += 1
assert self._stopped_consumers == 0, "some consumers already exited," \
+ " cannot start another epoch"
self._source.reset()
self._consumed = 0
self._feeding_ev.set()
# FIXME(dengkaipeng): fix me if you have better impliment
# handle terminate reader process, do not print stack frame
def _reader_exit(signum, frame):
logger.debug("Reader process exit.")
sys.exit()
signal.signal(signal.SIGTERM, _reader_exit)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import cv2
import numpy as np
logger = logging.getLogger(__name__)
def build_post_map(coarsest_stride=1,
is_padding=False,
random_shapes=[],
multi_scales=[],
use_padded_im_info=False,
enable_multiscale_test=False,
num_scale=1):
"""
Build a mapper for post-processing batches
Args:
config (dict of parameters):
{
coarsest_stride (int): stride of the coarsest FPN level
is_padding (bool): whether to padding in minibatch
random_shapes (list of int): resize to image to random shapes,
[] for not resize.
multi_scales (list of int): resize image by random scales,
[] for not resize.
use_padded_im_info (bool): whether to update im_info after padding
enable_multiscale_test (bool): whether to use multiscale test.
num_scale (int) : the number of scales for multiscale test.
}
Returns:
a mapper function which accept one argument 'batch' and
return the processed result
"""
def padding_minibatch(batch_data):
if len(batch_data) == 1 and coarsest_stride == 1:
return batch_data
max_shape = np.array([data[0].shape for data in batch_data]).max(axis=0)
if coarsest_stride > 1:
max_shape[1] = int(
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
max_shape[2] = int(
np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
padding_batch = []
for data in batch_data:
im_c, im_h, im_w = data[0].shape[:]
padding_im = np.zeros((im_c, max_shape[1], max_shape[2]),
dtype=np.float32)
padding_im[:, :im_h, :im_w] = data[0]
if use_padded_im_info:
data[1][:2] = max_shape[1:3]
padding_batch.append((padding_im, ) + data[1:])
return padding_batch
def padding_multiscale_test(batch_data):
if len(batch_data) != 1:
raise NotImplementedError(
"Batch size must be 1 when using multiscale test, but now batch size is {}"
.format(len(batch_data)))
if coarsest_stride > 1:
padding_batch = []
padding_images = []
data = batch_data[0]
for i, input in enumerate(data):
if i < num_scale:
im_c, im_h, im_w = input.shape
max_h = int(
np.ceil(im_h / coarsest_stride) * coarsest_stride)
max_w = int(
np.ceil(im_w / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, max_h, max_w),
dtype=np.float32)
padding_im[:, :im_h, :im_w] = input
data[num_scale][3 * i:3 * i + 2] = [max_h, max_w]
padding_batch.append(padding_im)
else:
padding_batch.append(input)
return [tuple(padding_batch)]
# no need to padding
return batch_data
def random_shape(batch_data):
# For YOLO: gt_bbox is normalized, is scale invariant.
shape = np.random.choice(random_shapes)
scaled_batch = []
h, w = batch_data[0][0].shape[1:3]
scale_x = float(shape) / w
scale_y = float(shape) / h
for data in batch_data:
im = cv2.resize(
data[0].transpose((1, 2, 0)),
None,
None,
fx=scale_x,
fy=scale_y,
interpolation=cv2.INTER_NEAREST)
scaled_batch.append((im.transpose(2, 0, 1), ) + data[1:])
return scaled_batch
def multi_scale_resize(batch_data):
# For RCNN: image shape in record in im_info.
scale = np.random.choice(multi_scales)
scaled_batch = []
for data in batch_data:
im = cv2.resize(
data[0].transpose((1, 2, 0)),
None,
None,
fx=scale,
fy=scale,
interpolation=cv2.INTER_NEAREST)
im_info = [im.shape[:2], scale]
scaled_batch.append((im.transpose(2, 0, 1), im_info) + data[2:])
return scaled_batch
def _mapper(batch_data):
try:
if is_padding:
batch_data = padding_minibatch(batch_data)
if len(random_shapes) > 0:
batch_data = random_shape(batch_data)
if len(multi_scales) > 0:
batch_data = multi_scale_resize(batch_data)
if enable_multiscale_test:
batch_data = padding_multiscale_test(batch_data)
except Exception as e:
errmsg = "post-process failed with error: " + str(e)
logger.warn(errmsg)
raise e
return batch_data
return _mapper
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
__all__ = ['SharedBuffer', 'SharedMemoryMgr', 'SharedQueue']
from .sharedmemory import SharedBuffer
from .sharedmemory import SharedMemoryMgr
from .sharedmemory import SharedMemoryError
from .queue import SharedQueue
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
if six.PY3:
import pickle
from io import BytesIO as StringIO
else:
import cPickle as pickle
from cStringIO import StringIO
import logging
import traceback
import multiprocessing as mp
from multiprocessing.queues import Queue
from .sharedmemory import SharedMemoryMgr
logger = logging.getLogger(__name__)
class SharedQueueError(ValueError):
""" SharedQueueError
"""
pass
class SharedQueue(Queue):
""" a Queue based on shared memory to communicate data between Process,
and it's interface is compatible with 'multiprocessing.queues.Queue'
"""
def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None):
""" init
"""
if six.PY3:
super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context())
else:
super(SharedQueue, self).__init__(maxsize)
if mem_mgr is not None:
self._shared_mem = mem_mgr
else:
self._shared_mem = SharedMemoryMgr(
capacity=memsize, pagesize=pagesize)
def put(self, obj, **kwargs):
""" put an object to this queue
"""
obj = pickle.dumps(obj, -1)
buff = None
try:
buff = self._shared_mem.malloc(len(obj))
buff.put(obj)
super(SharedQueue, self).put(buff, **kwargs)
except Exception as e:
stack_info = traceback.format_exc()
err_msg = 'failed to put a element to SharedQueue '\
'with stack info[%s]' % (stack_info)
logger.warn(err_msg)
if buff is not None:
buff.free()
raise e
def get(self, **kwargs):
""" get an object from this queue
"""
buff = None
try:
buff = super(SharedQueue, self).get(**kwargs)
data = buff.get()
return pickle.load(StringIO(data))
except Exception as e:
stack_info = traceback.format_exc()
err_msg = 'failed to get element from SharedQueue '\
'with stack info[%s]' % (stack_info)
logger.warn(err_msg)
raise e
finally:
if buff is not None:
buff.free()
def release(self):
self._shared_mem.release()
self._shared_mem = None
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import time
import math
import struct
import sys
import six
if six.PY3:
import pickle
else:
import cPickle as pickle
import json
import uuid
import random
import numpy as np
import weakref
import logging
from multiprocessing import Lock
from multiprocessing import RawArray
logger = logging.getLogger(__name__)
class SharedMemoryError(ValueError):
""" SharedMemoryError
"""
pass
class SharedBufferError(SharedMemoryError):
""" SharedBufferError
"""
pass
class MemoryFullError(SharedMemoryError):
""" MemoryFullError
"""
def __init__(self, errmsg=''):
super(MemoryFullError, self).__init__()
self.errmsg = errmsg
def memcopy(dst, src, offset=0, length=None):
""" copy data from 'src' to 'dst' in bytes
"""
length = length if length is not None else len(src)
assert type(dst) == np.ndarray, 'invalid type for "dst" in memcopy'
if type(src) is not np.ndarray:
if type(src) is str and six.PY3:
src = src.encode()
src = np.frombuffer(src, dtype='uint8', count=len(src))
dst[:] = src[offset:offset + length]
class SharedBuffer(object):
""" Buffer allocated from SharedMemoryMgr, and it stores data on shared memory
note that:
every instance of this should be freed explicitely by calling 'self.free'
"""
def __init__(self, owner, capacity, pos, size=0, alloc_status=''):
""" Init
Args:
owner (str): manager to own this buffer
capacity (int): capacity in bytes for this buffer
pos (int): page position in shared memory
size (int): bytes already used
alloc_status (str): debug info about allocator when allocate this
"""
self._owner = owner
self._cap = capacity
self._pos = pos
self._size = size
self._alloc_status = alloc_status
assert self._pos >= 0 and self._cap > 0, \
"invalid params[%d:%d] to construct SharedBuffer" \
% (self._pos, self._cap)
def owner(self):
""" get owner
"""
return SharedMemoryMgr.get_mgr(self._owner)
def put(self, data, override=False):
""" put data to this buffer
Args:
data (str): data to be stored in this buffer
Returns:
None
Raises:
SharedMemoryError when not enough space in this buffer
"""
assert type(data) in [str, bytes], \
'invalid type[%s] for SharedBuffer::put' % (str(type(data)))
if self._size > 0 and not override:
raise SharedBufferError('already has already been setted before')
if self.capacity() < len(data):
raise SharedBufferError('data[%d] is larger than size of buffer[%s]'\
% (len(data), str(self)))
self.owner().put_data(self, data)
self._size = len(data)
def get(self, offset=0, size=None, no_copy=True):
""" get the data stored this buffer
Args:
offset (int): position for the start point to 'get'
size (int): size to get
Returns:
data (np.ndarray('uint8')): user's data in numpy
which is passed in by 'put'
None: if no data stored in
"""
offset = offset if offset >= 0 else self._size + offset
if self._size <= 0:
return None
size = self._size if size is None else size
assert offset + size <= self._cap, 'invalid offset[%d] '\
'or size[%d] for capacity[%d]' % (offset, size, self._cap)
return self.owner().get_data(self, offset, size, no_copy=no_copy)
def size(self):
""" bytes of used memory
"""
return self._size
def resize(self, size):
""" resize the used memory to 'size', should not be greater than capacity
"""
assert size >= 0 and size <= self._cap, \
"invalid size[%d] for resize" % (size)
self._size = size
def capacity(self):
""" size of allocated memory
"""
return self._cap
def __str__(self):
""" human readable format
"""
return "SharedBuffer(owner:%s, pos:%d, size:%d, "\
"capacity:%d, alloc_status:[%s], pid:%d)" \
% (str(self._owner), self._pos, self._size, \
self._cap, self._alloc_status, os.getpid())
def free(self):
""" free this buffer to it's owner
"""
if self._owner is not None:
self.owner().free(self)
self._owner = None
self._cap = 0
self._pos = -1
self._size = 0
return True
else:
return False
class PageAllocator(object):
""" allocator used to malloc and free shared memory which
is split into pages
"""
s_allocator_header = 12
def __init__(self, base, total_pages, page_size):
""" init
"""
self._magic_num = 1234321000 + random.randint(100, 999)
self._base = base
self._total_pages = total_pages
self._page_size = page_size
header_pages = int(
math.ceil((total_pages + self.s_allocator_header) / page_size))
self._header_pages = header_pages
self._free_pages = total_pages - header_pages
self._header_size = self._header_pages * page_size
self._reset()
def _dump_alloc_info(self, fname):
hpages, tpages, pos, used = self.header()
start = self.s_allocator_header
end = start + self._page_size * hpages
alloc_flags = self._base[start:end].tostring()
info = {
'magic_num': self._magic_num,
'header_pages': hpages,
'total_pages': tpages,
'pos': pos,
'used': used
}
info['alloc_flags'] = alloc_flags
fname = fname + '.' + str(uuid.uuid4())[:6]
with open(fname, 'wb') as f:
f.write(pickle.dumps(info, -1))
logger.warn('dump alloc info to file[%s]' % (fname))
def _reset(self):
alloc_page_pos = self._header_pages
used_pages = self._header_pages
header_info = struct.pack(
str('III'), self._magic_num, alloc_page_pos, used_pages)
assert len(header_info) == self.s_allocator_header, \
'invalid size of header_info'
memcopy(self._base[0:self.s_allocator_header], header_info)
self.set_page_status(0, self._header_pages, '1')
self.set_page_status(self._header_pages, self._free_pages, '0')
def header(self):
""" get header info of this allocator
"""
header_str = self._base[0:self.s_allocator_header].tostring()
magic, pos, used = struct.unpack(str('III'), header_str)
assert magic == self._magic_num, \
'invalid header magic[%d] in shared memory' % (magic)
return self._header_pages, self._total_pages, pos, used
def empty(self):
""" are all allocatable pages available
"""
header_pages, pages, pos, used = self.header()
return header_pages == used
def full(self):
""" are all allocatable pages used
"""
header_pages, pages, pos, used = self.header()
return header_pages + used == pages
def __str__(self):
header_pages, pages, pos, used = self.header()
desc = '{page_info[magic:%d,total:%d,used:%d,header:%d,alloc_pos:%d,pagesize:%d]}' \
% (self._magic_num, pages, used, header_pages, pos, self._page_size)
return 'PageAllocator:%s' % (desc)
def set_alloc_info(self, alloc_pos, used_pages):
""" set allocating position to new value
"""
memcopy(self._base[4:12], struct.pack(str('II'), alloc_pos, used_pages))
def set_page_status(self, start, page_num, status):
""" set pages from 'start' to 'end' with new same status 'status'
"""
assert status in ['0', '1'], 'invalid status[%s] for page status '\
'in allocator[%s]' % (status, str(self))
start += self.s_allocator_header
end = start + page_num
assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\
'in allocator[%s]' % (end, str(self))
memcopy(self._base[start:end], str(status * page_num))
def get_page_status(self, start, page_num, ret_flag=False):
start += self.s_allocator_header
end = start + page_num
assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\
'in allocator[%s]' % (end, str(self))
status = self._base[start:end].tostring().decode()
if ret_flag:
return status
zero_num = status.count('0')
if zero_num == 0:
return (page_num, 1)
else:
return (zero_num, 0)
def malloc_page(self, page_num):
header_pages, pages, pos, used = self.header()
end = pos + page_num
if end > pages:
pos = self._header_pages
end = pos + page_num
start_pos = pos
flags = ''
while True:
# maybe flags already has some '0' pages,
# so just check 'page_num - len(flags)' pages
flags = self.get_page_status(pos, page_num, ret_flag=True)
if flags.count('0') == page_num:
break
# not found enough pages, so shift to next few pages
free_pos = flags.rfind('1') + 1
pos += free_pos
end = pos + page_num
if end > pages:
pos = self._header_pages
end = pos + page_num
flags = ''
# not found available pages after scan all pages
if pos <= start_pos and end >= start_pos:
logger.debug('not found available pages after scan all pages')
break
page_status = (flags.count('0'), 0)
if page_status != (page_num, 0):
free_pages = self._total_pages - used
if free_pages == 0:
err_msg = 'all pages have been used:%s' % (str(self))
else:
err_msg = 'not found available pages with page_status[%s] '\
'and %d free pages' % (str(page_status), free_pages)
err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] and allocator status[%s]' \
% (page_num, pos, err_msg, str(self))
raise MemoryFullError(err_msg)
self.set_page_status(pos, page_num, '1')
used += page_num
self.set_alloc_info(end, used)
return pos
def free_page(self, start, page_num):
""" free 'page_num' pages start from 'start'
"""
page_status = self.get_page_status(start, page_num)
assert page_status == (page_num, 1), \
'invalid status[%s] when free [%d, %d]' \
% (str(page_status), start, page_num)
self.set_page_status(start, page_num, '0')
_, _, pos, used = self.header()
used -= page_num
self.set_alloc_info(pos, used)
DEFAULT_SHARED_MEMORY_SIZE = 1024 * 1024 * 1024
class SharedMemoryMgr(object):
""" manage a continouse block of memory, provide
'malloc' to allocate new buffer, and 'free' to free buffer
"""
s_memory_mgrs = weakref.WeakValueDictionary()
s_mgr_num = 0
s_log_statis = False
@classmethod
def get_mgr(cls, id):
""" get a SharedMemoryMgr with size of 'capacity'
"""
assert id in cls.s_memory_mgrs, 'invalid id[%s] for memory managers' % (
id)
return cls.s_memory_mgrs[id]
def __init__(self, capacity=None, pagesize=None):
""" init
"""
logger.debug('create SharedMemoryMgr')
pagesize = 64 * 1024 if pagesize is None else pagesize
assert type(pagesize) is int, "invalid type of pagesize[%s]" \
% (str(pagesize))
capacity = DEFAULT_SHARED_MEMORY_SIZE if capacity is None else capacity
assert type(capacity) is int, "invalid type of capacity[%s]" \
% (str(capacity))
assert capacity > 0, '"size of shared memory should be greater than 0'
self._released = False
self._cap = capacity
self._page_size = pagesize
assert self._cap % self._page_size == 0, \
"capacity[%d] and pagesize[%d] are not consistent" \
% (self._cap, self._page_size)
self._total_pages = self._cap // self._page_size
self._pid = os.getpid()
SharedMemoryMgr.s_mgr_num += 1
self._id = self._pid * 100 + SharedMemoryMgr.s_mgr_num
SharedMemoryMgr.s_memory_mgrs[self._id] = self
self._locker = Lock()
self._setup()
def _setup(self):
self._shared_mem = RawArray('c', self._cap)
self._base = np.frombuffer(
self._shared_mem, dtype='uint8', count=self._cap)
self._locker.acquire()
try:
self._allocator = PageAllocator(self._base, self._total_pages,
self._page_size)
finally:
self._locker.release()
def malloc(self, size, wait=True):
""" malloc a new SharedBuffer
Args:
size (int): buffer size to be malloc
wait (bool): whether to wait when no enough memory
Returns:
SharedBuffer
Raises:
SharedMemoryError when not found available memory
"""
page_num = int(math.ceil(size / self._page_size))
size = page_num * self._page_size
start = None
ct = 0
errmsg = ''
while True:
self._locker.acquire()
try:
start = self._allocator.malloc_page(page_num)
alloc_status = str(self._allocator)
except MemoryFullError as e:
start = None
errmsg = e.errmsg
if not wait:
raise e
finally:
self._locker.release()
if start is None:
time.sleep(0.1)
if ct % 100 == 0:
logger.warn('not enough space for reason[%s]' % (errmsg))
ct += 1
else:
break
return SharedBuffer(self._id, size, start, alloc_status=alloc_status)
def free(self, shared_buf):
""" free a SharedBuffer
Args:
shared_buf (SharedBuffer): buffer to be freed
Returns:
None
Raises:
SharedMemoryError when failed to release this buffer
"""
assert shared_buf._owner == self._id, "invalid shared_buf[%s] "\
"for it's not allocated from me[%s]" % (str(shared_buf), str(self))
cap = shared_buf.capacity()
start_page = shared_buf._pos
page_num = cap // self._page_size
#maybe we don't need this lock here
self._locker.acquire()
try:
self._allocator.free_page(start_page, page_num)
finally:
self._locker.release()
def put_data(self, shared_buf, data):
""" fill 'data' into 'shared_buf'
"""
assert len(data) <= shared_buf.capacity(), 'too large data[%d] '\
'for this buffer[%s]' % (len(data), str(shared_buf))
start = shared_buf._pos * self._page_size
end = start + len(data)
assert start >= 0 and end <= self._cap, "invalid start "\
"position[%d] when put data to buff:%s" % (start, str(shared_buf))
self._base[start:end] = np.frombuffer(data, 'uint8', len(data))
def get_data(self, shared_buf, offset, size, no_copy=True):
""" extract 'data' from 'shared_buf' in range [offset, offset + size)
"""
start = shared_buf._pos * self._page_size
start += offset
if no_copy:
return self._base[start:start + size]
else:
return self._base[start:start + size].tostring()
def __str__(self):
return 'SharedMemoryMgr:{id:%d, %s}' % (self._id, str(self._allocator))
def __del__(self):
if SharedMemoryMgr.s_log_statis:
logger.info('destroy [%s]' % (self))
if not self._released and not self._allocator.empty():
logger.debug(
'not empty when delete this SharedMemoryMgr[%s]' % (self))
else:
self._released = True
if self._id in SharedMemoryMgr.s_memory_mgrs:
del SharedMemoryMgr.s_memory_mgrs[self._id]
SharedMemoryMgr.s_mgr_num -= 1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import functools
import collections
from ..dataset import Dataset
class ProxiedDataset(Dataset):
"""proxy method called to 'self._ds' when if not defined"""
def __init__(self, ds):
super(ProxiedDataset, self).__init__()
self._ds = ds
methods = filter(lambda k: not k.startswith('_'),
Dataset.__dict__.keys())
for m in methods:
func = functools.partial(self._proxy_method, getattr(self, m))
setattr(self, m, func)
def _proxy_method(self, func, *args, **kwargs):
"""
proxy call to 'func', if not available then call self._ds.xxx
whose name is the same with func.__name__
"""
method = func.__name__
try:
return func(*args, **kwargs)
except NotImplementedError:
ds_func = getattr(self._ds, method)
return ds_func(*args, **kwargs)
class MappedDataset(ProxiedDataset):
def __init__(self, ds, mapper):
super(MappedDataset, self).__init__(ds)
self._ds = ds
self._mapper = mapper
def next(self):
sample = self._ds.next()
return self._mapper(sample)
class BatchedDataset(ProxiedDataset):
"""
Batching samples
Args:
ds (instance of Dataset): dataset to be batched
batchsize (int): sample number for each batch
drop_last (bool): drop last samples when not enough for one batch
drop_empty (bool): drop samples which have empty field
"""
def __init__(self, ds, batchsize, drop_last=False, drop_empty=True):
super(BatchedDataset, self).__init__(ds)
self._batchsz = batchsize
self._drop_last = drop_last
self._drop_empty = drop_empty
def next(self):
"""proxy to self._ds.next"""
def empty(x):
if isinstance(x, np.ndarray) and x.size == 0:
return True
elif isinstance(x, collections.Sequence) and len(x) == 0:
return True
else:
return False
def has_empty(items):
if any(x is None for x in items):
return True
if any(empty(x) for x in items):
return True
return False
batch = []
for _ in range(self._batchsz):
try:
out = self._ds.next()
while self._drop_empty and has_empty(out):
out = self._ds.next()
batch.append(out)
except StopIteration:
if not self._drop_last and len(batch) > 0:
return batch
else:
raise StopIteration
return batch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import json
import cv2
import numpy as np
import matplotlib
matplotlib.use('Agg')
import logging
logger = logging.getLogger(__name__)
__all__ = [
'bbox_eval',
'mask_eval',
'bbox2out',
'mask2out',
'get_category_info',
'proposal_eval',
'cocoapi_eval',
]
def clip_bbox(bbox):
xmin = max(min(bbox[0], 1.), 0.)
ymin = max(min(bbox[1], 1.), 0.)
xmax = max(min(bbox[2], 1.), 0.)
ymax = max(min(bbox[3], 1.), 0.)
return xmin, ymin, xmax, ymax
def proposal_eval(results, anno_file, outfile, max_dets=(100, 300, 1000)):
assert 'proposal' in results[0]
assert outfile.endswith('.json')
xywh_results = proposal2out(results)
assert len(
xywh_results) > 0, "The number of valid proposal detected is zero.\n \
Please use reasonable model and check input data."
with open(outfile, 'w') as f:
json.dump(xywh_results, f)
cocoapi_eval(outfile, 'proposal', anno_file=anno_file, max_dets=max_dets)
# flush coco evaluation result
sys.stdout.flush()
def bbox_eval(results,
anno_file,
outfile,
with_background=True,
is_bbox_normalized=False):
assert 'bbox' in results[0]
assert outfile.endswith('.json')
from pycocotools.coco import COCO
coco_gt = COCO(anno_file)
cat_ids = coco_gt.getCatIds()
# when with_background = True, mapping category to classid, like:
# background:0, first_class:1, second_class:2, ...
clsid2catid = dict(
{i + int(with_background): catid
for i, catid in enumerate(cat_ids)})
xywh_results = bbox2out(
results, clsid2catid, is_bbox_normalized=is_bbox_normalized)
if len(xywh_results) == 0:
logger.warning("The number of valid bbox detected is zero.\n \
Please use reasonable model and check input data.\n \
stop eval!")
return [0.0]
with open(outfile, 'w') as f:
json.dump(xywh_results, f)
map_stats = cocoapi_eval(outfile, 'bbox', coco_gt=coco_gt)
# flush coco evaluation result
sys.stdout.flush()
return map_stats
def mask_eval(results, anno_file, outfile, resolution, thresh_binarize=0.5):
assert 'mask' in results[0]
assert outfile.endswith('.json')
from pycocotools.coco import COCO
coco_gt = COCO(anno_file)
clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}
segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
if len(segm_results) == 0:
logger.warning("The number of valid mask detected is zero.\n \
Please use reasonable model and check input data.")
return
with open(outfile, 'w') as f:
json.dump(segm_results, f)
cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)
def cocoapi_eval(jsonfile,
style,
coco_gt=None,
anno_file=None,
max_dets=(100, 300, 1000)):
"""
Args:
jsonfile: Evaluation json file, eg: bbox.json, mask.json.
style: COCOeval style, can be `bbox` , `segm` and `proposal`.
coco_gt: Whether to load COCOAPI through anno_file,
eg: coco_gt = COCO(anno_file)
anno_file: COCO annotations file.
max_dets: COCO evaluation maxDets.
"""
assert coco_gt != None or anno_file != None
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
if coco_gt == None:
coco_gt = COCO(anno_file)
logger.info("Start evaluate...")
coco_dt = coco_gt.loadRes(jsonfile)
if style == 'proposal':
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.params.useCats = 0
coco_eval.params.maxDets = list(max_dets)
else:
coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval.stats
def proposal2out(results, is_bbox_normalized=False):
xywh_res = []
for t in results:
bboxes = t['proposal'][0]
lengths = t['proposal'][1][0]
im_ids = np.array(t['im_id'][0])
if bboxes.shape == (1, 1) or bboxes is None:
continue
k = 0
for i in range(len(lengths)):
num = lengths[i]
im_id = int(im_ids[i][0])
for j in range(num):
dt = bboxes[k]
xmin, ymin, xmax, ymax = dt.tolist()
if is_bbox_normalized:
xmin, ymin, xmax, ymax = \
clip_bbox([xmin, ymin, xmax, ymax])
w = xmax - xmin
h = ymax - ymin
else:
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'category_id': 1,
'bbox': bbox,
'score': 1.0
}
xywh_res.append(coco_res)
k += 1
return xywh_res
def bbox2out(results, clsid2catid, is_bbox_normalized=False):
"""
Args:
results: request a dict, should include: `bbox`, `im_id`,
if is_bbox_normalized=True, also need `im_shape`.
clsid2catid: class id to category id map of COCO2017 dataset.
is_bbox_normalized: whether or not bbox is normalized.
"""
xywh_res = []
for t in results:
bboxes = t['bbox'][0]
lengths = t['bbox'][1][0]
im_ids = np.array(t['im_id'][0])
if bboxes.shape == (1, 1) or bboxes is None:
continue
k = 0
for i in range(len(lengths)):
num = lengths[i]
im_id = int(im_ids[i][0])
for j in range(num):
dt = bboxes[k]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = (clsid2catid[int(clsid)])
if is_bbox_normalized:
xmin, ymin, xmax, ymax = \
clip_bbox([xmin, ymin, xmax, ymax])
w = xmax - xmin
h = ymax - ymin
im_height, im_width = t['im_shape'][0][i].tolist()
xmin *= im_width
ymin *= im_height
w *= im_width
h *= im_height
else:
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'category_id': catid,
'bbox': bbox,
'score': score
}
xywh_res.append(coco_res)
k += 1
return xywh_res
def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
import pycocotools.mask as mask_util
scale = (resolution + 2.0) / resolution
segm_res = []
# for each batch
for t in results:
bboxes = t['bbox'][0]
lengths = t['bbox'][1][0]
im_ids = np.array(t['im_id'][0])
if bboxes.shape == (1, 1) or bboxes is None:
continue
if len(bboxes.tolist()) == 0:
continue
masks = t['mask'][0]
s = 0
# for each sample
for i in range(len(lengths)):
num = lengths[i]
im_id = int(im_ids[i][0])
im_shape = t['im_shape'][0][i]
bbox = bboxes[s:s + num][:, 2:]
clsid_scores = bboxes[s:s + num][:, 0:2]
mask = masks[s:s + num]
s += num
im_h = int(im_shape[0])
im_w = int(im_shape[1])
expand_bbox = expand_boxes(bbox, scale)
expand_bbox = expand_bbox.astype(np.int32)
padded_mask = np.zeros((resolution + 2, resolution + 2),
dtype=np.float32)
for j in range(num):
xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
clsid, score = clsid_scores[j].tolist()
clsid = int(clsid)
padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
catid = clsid2catid[clsid]
w = xmax - xmin + 1
h = ymax - ymin + 1
w = np.maximum(w, 1)
h = np.maximum(h, 1)
resized_mask = cv2.resize(padded_mask, (w, h))
resized_mask = np.array(
resized_mask > thresh_binarize, dtype=np.uint8)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
x0 = min(max(xmin, 0), im_w)
x1 = min(max(xmax + 1, 0), im_w)
y0 = min(max(ymin, 0), im_h)
y1 = min(max(ymax + 1, 0), im_h)
im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
x0 - xmin):(x1 - xmin)]
segm = mask_util.encode(
np.array(im_mask[:, :, np.newaxis], order='F'))[0]
catid = clsid2catid[clsid]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': catid,
'segmentation': segm,
'score': score
}
segm_res.append(coco_res)
return segm_res
def expand_boxes(boxes, scale):
"""
Expand an array of boxes by a given scale.
"""
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half *= scale
h_half *= scale
boxes_exp = np.zeros(boxes.shape)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp
def get_category_info(anno_file=None,
with_background=True,
use_default_label=False):
if use_default_label or anno_file is None \
or not os.path.exists(anno_file):
logger.info("Not found annotation file {}, load "
"coco17 categories.".format(anno_file))
return coco17_category_info(with_background)
else:
logger.info("Load categories from {}".format(anno_file))
return get_category_info_from_anno(anno_file, with_background)
def get_category_info_from_anno(anno_file, with_background=True):
"""
Get class id to category id map and category id
to category name map from annotation file.
Args:
anno_file (str): annotation file path
with_background (bool, default True):
whether load background as class 0.
"""
from pycocotools.coco import COCO
coco = COCO(anno_file)
cats = coco.loadCats(coco.getCatIds())
clsid2catid = {
i + int(with_background): cat['id']
for i, cat in enumerate(cats)
}
catid2name = {cat['id']: cat['name'] for cat in cats}
return clsid2catid, catid2name
def coco17_category_info(with_background=True):
"""
Get class id to category id map and category id
to category name map of COCO2017 dataset
Args:
with_background (bool, default True):
whether load background as class 0.
"""
clsid2catid = {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 13,
13: 14,
14: 15,
15: 16,
16: 17,
17: 18,
18: 19,
19: 20,
20: 21,
21: 22,
22: 23,
23: 24,
24: 25,
25: 27,
26: 28,
27: 31,
28: 32,
29: 33,
30: 34,
31: 35,
32: 36,
33: 37,
34: 38,
35: 39,
36: 40,
37: 41,
38: 42,
39: 43,
40: 44,
41: 46,
42: 47,
43: 48,
44: 49,
45: 50,
46: 51,
47: 52,
48: 53,
49: 54,
50: 55,
51: 56,
52: 57,
53: 58,
54: 59,
55: 60,
56: 61,
57: 62,
58: 63,
59: 64,
60: 65,
61: 67,
62: 70,
63: 72,
64: 73,
65: 74,
66: 75,
67: 76,
68: 77,
69: 78,
70: 79,
71: 80,
72: 81,
73: 82,
74: 84,
75: 85,
76: 86,
77: 87,
78: 88,
79: 89,
80: 90
}
catid2name = {
0: 'background',
1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
13: 'stop sign',
14: 'parking meter',
15: 'bench',
16: 'bird',
17: 'cat',
18: 'dog',
19: 'horse',
20: 'sheep',
21: 'cow',
22: 'elephant',
23: 'bear',
24: 'zebra',
25: 'giraffe',
27: 'backpack',
28: 'umbrella',
31: 'handbag',
32: 'tie',
33: 'suitcase',
34: 'frisbee',
35: 'skis',
36: 'snowboard',
37: 'sports ball',
38: 'kite',
39: 'baseball bat',
40: 'baseball glove',
41: 'skateboard',
42: 'surfboard',
43: 'tennis racket',
44: 'bottle',
46: 'wine glass',
47: 'cup',
48: 'fork',
49: 'knife',
50: 'spoon',
51: 'bowl',
52: 'banana',
53: 'apple',
54: 'sandwich',
55: 'orange',
56: 'broccoli',
57: 'carrot',
58: 'hot dog',
59: 'pizza',
60: 'donut',
61: 'cake',
62: 'chair',
63: 'couch',
64: 'potted plant',
65: 'bed',
67: 'dining table',
70: 'toilet',
72: 'tv',
73: 'laptop',
74: 'mouse',
75: 'remote',
76: 'keyboard',
77: 'cell phone',
78: 'microwave',
79: 'oven',
80: 'toaster',
81: 'sink',
82: 'refrigerator',
84: 'book',
85: 'clock',
86: 'vase',
87: 'scissors',
88: 'teddy bear',
89: 'hair drier',
90: 'toothbrush'
}
if not with_background:
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
return clsid2catid, catid2name
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import os
import time
import paddle.fluid as fluid
from ..utils.voc_eval import bbox_eval as voc_bbox_eval
from ..utils.post_process import mstest_box_post_process, mstest_mask_post_process, box_flip
__all__ = ['parse_fetches', 'eval_run', 'eval_results', 'json_eval_results']
logger = logging.getLogger(__name__)
def parse_fetches(fetches, prog=None, extra_keys=None):
"""
Parse fetch variable infos from model fetches,
values for fetch_list and keys for stat
"""
keys, values = [], []
cls = []
for k, v in fetches.items():
if hasattr(v, 'name'):
keys.append(k)
v.persistable = True
values.append(v.name)
else:
cls.append(v)
if prog is not None and extra_keys is not None:
for k in extra_keys:
try:
v = fluid.framework._get_var(k, prog)
keys.append(k)
values.append(v.name)
except Exception:
pass
return keys, values, cls
def length2lod(length_lod):
offset_lod = [0]
for i in length_lod:
offset_lod.append(offset_lod[-1] + i)
return [offset_lod]
def get_sub_feed(input, place):
new_dict = {}
res_feed = {}
key_name = ['bbox', 'im_info', 'im_id', 'im_shape', 'bbox_flip']
for k in key_name:
if k in input.keys():
new_dict[k] = input[k]
for k in input.keys():
if 'image' in k:
new_dict[k] = input[k]
for k, v in new_dict.items():
data_t = fluid.LoDTensor()
data_t.set(v[0], place)
if 'bbox' in k:
lod = length2lod(v[1][0])
data_t.set_lod(lod)
res_feed[k] = data_t
return res_feed
def clean_res(result, keep_name_list):
clean_result = {}
for k in result.keys():
if k in keep_name_list:
clean_result[k] = result[k]
result.clear()
return clean_result
def eval_run(exe,
compile_program,
loader,
keys,
values,
cls,
cfg=None,
sub_prog=None,
sub_keys=None,
sub_values=None):
"""
Run evaluation program, return program outputs.
"""
iter_id = 0
results = []
if len(cls) != 0:
values = []
for i in range(len(cls)):
_, accum_map = cls[i].get_map_var()
cls[i].reset(exe)
values.append(accum_map)
images_num = 0
start_time = time.time()
has_bbox = 'bbox' in keys
try:
loader.start()
while True:
outs = exe.run(
compile_program, fetch_list=values, return_numpy=False)
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs)
}
multi_scale_test = getattr(cfg, 'MultiScaleTEST', None)
mask_multi_scale_test = multi_scale_test and 'Mask' in cfg.architecture
if multi_scale_test:
post_res = mstest_box_post_process(res, cfg)
res.update(post_res)
if mask_multi_scale_test:
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
sub_feed = get_sub_feed(res, place)
sub_prog_outs = exe.run(
sub_prog,
feed=sub_feed,
fetch_list=sub_values,
return_numpy=False)
sub_prog_res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(sub_keys, sub_prog_outs)
}
post_res = mstest_mask_post_process(sub_prog_res, cfg)
res.update(post_res)
if multi_scale_test:
res = clean_res(
res, ['im_info', 'bbox', 'im_id', 'im_shape', 'mask'])
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
iter_id += 1
images_num += len(res['bbox'][1][0]) if has_bbox else 1
except (StopIteration, fluid.core.EOFException):
loader.reset()
logger.info('Test finish iter {}'.format(iter_id))
end_time = time.time()
fps = images_num / (end_time - start_time)
if has_bbox:
logger.info(
'Total number of images: {}, inference time: {} fps.'.format(
images_num, fps))
else:
logger.info('Total iteration: {}, inference time: {} batch/s.'.format(
images_num, fps))
return results
def eval_results(results,
feed,
metric,
num_classes,
resolution=None,
is_bbox_normalized=False,
output_directory=None,
map_type='11point'):
"""Evaluation for evaluation program results"""
box_ap_stats = []
if metric == 'COCO':
from ..utils.coco_eval import proposal_eval, bbox_eval, mask_eval
anno_file = getattr(feed.dataset, 'annotation', None)
assert anno_file is not None
with_background = getattr(feed, 'with_background', True)
if 'proposal' in results[0]:
output = 'proposal.json'
if output_directory:
output = os.path.join(output_directory, 'proposal.json')
proposal_eval(results, anno_file, output)
if 'bbox' in results[0]:
output = 'bbox.json'
if output_directory:
output = os.path.join(output_directory, 'bbox.json')
box_ap_stats = bbox_eval(
results,
anno_file,
output,
with_background,
is_bbox_normalized=is_bbox_normalized)
if 'mask' in results[0]:
output = 'mask.json'
if output_directory:
output = os.path.join(output_directory, 'mask.json')
mask_eval(results, anno_file, output, resolution)
else:
if 'accum_map' in results[-1]:
res = np.mean(results[-1]['accum_map'][0])
logger.info('mAP: {:.2f}'.format(res * 100.))
box_ap_stats.append(res * 100.)
elif 'bbox' in results[0]:
box_ap = voc_bbox_eval(
results,
num_classes,
is_bbox_normalized=is_bbox_normalized,
map_type=map_type)
box_ap_stats.append(box_ap)
return box_ap_stats
def json_eval_results(feed, metric, json_directory=None):
"""
cocoapi eval with already exists proposal.json, bbox.json or mask.json
"""
assert metric == 'COCO'
from ppdet.utils.coco_eval import cocoapi_eval
anno_file = getattr(feed.dataset, 'annotation', None)
json_file_list = ['proposal.json', 'bbox.json', 'mask.json']
if json_directory:
assert os.path.exists(
json_directory), "The json directory:{} does not exist".format(
json_directory)
for k, v in enumerate(json_file_list):
json_file_list[k] = os.path.join(str(json_directory), v)
coco_eval_style = ['proposal', 'bbox', 'segm']
for i, v_json in enumerate(json_file_list):
if os.path.exists(v_json):
cocoapi_eval(v_json, coco_eval_style[i], anno_file=anno_file)
else:
logger.info("{} not exists!".format(v_json))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import numpy as np
import logging
logger = logging.getLogger(__name__)
__all__ = ['bbox_area', 'jaccard_overlap', 'DetectionMAP']
def bbox_area(bbox, is_bbox_normalized):
"""
Calculate area of a bounding box
"""
norm = 1. - float(is_bbox_normalized)
width = bbox[2] - bbox[0] + norm
height = bbox[3] - bbox[1] + norm
return width * height
def jaccard_overlap(pred, gt, is_bbox_normalized=False):
"""
Calculate jaccard overlap ratio between two bounding box
"""
if pred[0] >= gt[2] or pred[2] <= gt[0] or \
pred[1] >= gt[3] or pred[3] <= gt[1]:
return 0.
inter_xmin = max(pred[0], gt[0])
inter_ymin = max(pred[1], gt[1])
inter_xmax = min(pred[2], gt[2])
inter_ymax = min(pred[3], gt[3])
inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax],
is_bbox_normalized)
pred_size = bbox_area(pred, is_bbox_normalized)
gt_size = bbox_area(gt, is_bbox_normalized)
overlap = float(inter_size) / (pred_size + gt_size - inter_size)
return overlap
class DetectionMAP(object):
"""
Calculate detection mean average precision.
Currently support two types: 11point and integral
Args:
class_num (int): the class number.
overlap_thresh (float): The threshold of overlap
ratio between prediction bounding box and
ground truth bounding box for deciding
true/false positive. Default 0.5.
map_type (str): calculation method of mean average
precision, currently support '11point' and
'integral'. Default '11point'.
is_bbox_normalized (bool): whther bounding boxes
is normalized to range[0, 1]. Default False.
evaluate_difficult (bool): whether to evaluate
difficult bounding boxes. Default False.
"""
def __init__(self,
class_num,
overlap_thresh=0.5,
map_type='11point',
is_bbox_normalized=False,
evaluate_difficult=False):
self.class_num = class_num
self.overlap_thresh = overlap_thresh
assert map_type in ['11point', 'integral'], \
"map_type currently only support '11point' "\
"and 'integral'"
self.map_type = map_type
self.is_bbox_normalized = is_bbox_normalized
self.evaluate_difficult = evaluate_difficult
self.reset()
def update(self, bbox, gt_box, gt_label, difficult=None):
"""
Update metric statics from given prediction and ground
truth infomations.
"""
if difficult is None:
difficult = np.zeros_like(gt_label)
# record class gt count
for gtl, diff in zip(gt_label, difficult):
if self.evaluate_difficult or int(diff) == 0:
self.class_gt_counts[int(np.array(gtl))] += 1
# record class score positive
visited = [False] * len(gt_label)
for b in bbox:
label, score, xmin, ymin, xmax, ymax = b.tolist()
pred = [xmin, ymin, xmax, ymax]
max_idx = -1
max_overlap = -1.0
for i, gl in enumerate(gt_label):
if int(gl) == int(label):
overlap = jaccard_overlap(pred, gt_box[i],
self.is_bbox_normalized)
if overlap > max_overlap:
max_overlap = overlap
max_idx = i
if max_overlap > self.overlap_thresh:
if self.evaluate_difficult or \
int(np.array(difficult[max_idx])) == 0:
if not visited[max_idx]:
self.class_score_poss[int(label)].append([score, 1.0])
visited[max_idx] = True
else:
self.class_score_poss[int(label)].append([score, 0.0])
else:
self.class_score_poss[int(label)].append([score, 0.0])
def reset(self):
"""
Reset metric statics
"""
self.class_score_poss = [[] for _ in range(self.class_num)]
self.class_gt_counts = [0] * self.class_num
self.mAP = None
def accumulate(self):
"""
Accumulate metric results and calculate mAP
"""
mAP = 0.
valid_cnt = 0
for score_pos, count in zip(self.class_score_poss,
self.class_gt_counts):
if count == 0 or len(score_pos) == 0:
continue
accum_tp_list, accum_fp_list = \
self._get_tp_fp_accum(score_pos)
precision = []
recall = []
for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
precision.append(float(ac_tp) / (ac_tp + ac_fp))
recall.append(float(ac_tp) / count)
if self.map_type == '11point':
max_precisions = [0.] * 11
start_idx = len(precision) - 1
for j in range(10, -1, -1):
for i in range(start_idx, -1, -1):
if recall[i] < float(j) / 10.:
start_idx = i
if j > 0:
max_precisions[j - 1] = max_precisions[j]
break
else:
if max_precisions[j] < precision[i]:
max_precisions[j] = precision[i]
mAP += sum(max_precisions) / 11.
valid_cnt += 1
elif self.map_type == 'integral':
import math
ap = 0.
prev_recall = 0.
for i in range(len(precision)):
recall_gap = math.fabs(recall[i] - prev_recall)
if recall_gap > 1e-6:
ap += precision[i] * recall_gap
prev_recall = recall[i]
mAP += ap
valid_cnt += 1
else:
logger.error("Unspported mAP type {}".format(self.map_type))
sys.exit(1)
self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
def get_map(self):
"""
Get mAP result
"""
if self.mAP is None:
logger.error("mAP is not calculated.")
return self.mAP
def _get_tp_fp_accum(self, score_pos_list):
"""
Calculate accumulating true/false positive results from
[score, pos] records
"""
sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
accum_tp = 0
accum_fp = 0
accum_tp_list = []
accum_fp_list = []
for (score, pos) in sorted_list:
accum_tp += int(pos)
accum_tp_list.append(accum_tp)
accum_fp += 1 - int(pos)
accum_fp_list.append(accum_fp)
return accum_tp_list, accum_fp_list
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import paddle.fluid as fluid
__all__ = ['nms']
logger = logging.getLogger(__name__)
def box_flip(boxes, im_shape):
im_width = im_shape[0][1]
flipped_boxes = boxes.copy()
flipped_boxes[:, 0::4] = im_width - boxes[:, 2::4] - 1
flipped_boxes[:, 2::4] = im_width - boxes[:, 0::4] - 1
return flipped_boxes
def nms(dets, thresh):
"""Apply classic DPM-style greedy NMS."""
if dets.shape[0] == 0:
return []
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int)
# nominal indices
# _i, _j
# sorted indices
# i, j
# temp variables for box i's (the box currently under consideration)
# ix1, iy1, ix2, iy2, iarea
# variables for computing overlap with box j (lower scoring box)
# xx1, yy1, xx2, yy2
# w, h
# inter, ovr
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
return np.where(suppressed == 0)[0]
def bbox_area(box):
w = box[2] - box[0] + 1
h = box[3] - box[1] + 1
return w * h
def bbox_overlaps(x, y):
N = x.shape[0]
K = y.shape[0]
overlaps = np.zeros((N, K), dtype=np.float32)
for k in range(K):
y_area = bbox_area(y[k])
for n in range(N):
iw = min(x[n, 2], y[k, 2]) - max(x[n, 0], y[k, 0]) + 1
if iw > 0:
ih = min(x[n, 3], y[k, 3]) - max(x[n, 1], y[k, 1]) + 1
if ih > 0:
x_area = bbox_area(x[n])
ua = x_area + y_area - iw * ih
overlaps[n, k] = iw * ih / ua
return overlaps
def box_voting(nms_dets, dets, vote_thresh):
top_dets = nms_dets.copy()
top_boxes = nms_dets[:, 1:]
all_boxes = dets[:, 1:]
all_scores = dets[:, 0]
top_to_all_overlaps = bbox_overlaps(top_boxes, all_boxes)
for k in range(nms_dets.shape[0]):
inds_to_vote = np.where(top_to_all_overlaps[k] >= vote_thresh)[0]
boxes_to_vote = all_boxes[inds_to_vote, :]
ws = all_scores[inds_to_vote]
top_dets[k, 1:] = np.average(boxes_to_vote, axis=0, weights=ws)
return top_dets
def get_nms_result(boxes, scores, cfg):
cls_boxes = [[] for _ in range(cfg.num_classes)]
for j in range(1, cfg.num_classes):
inds = np.where(scores[:, j] > cfg.MultiScaleTEST['score_thresh'])[0]
scores_j = scores[inds, j]
boxes_j = boxes[inds, j * 4:(j + 1) * 4]
dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype(
np.float32, copy=False)
keep = nms(dets_j, cfg.MultiScaleTEST['nms_thresh'])
nms_dets = dets_j[keep, :]
if cfg.MultiScaleTEST['enable_voting']:
nms_dets = box_voting(nms_dets, dets_j,
cfg.MultiScaleTEST['vote_thresh'])
#add labels
label = np.array([j for _ in range(len(keep))])
nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype(
np.float32, copy=False)
cls_boxes[j] = nms_dets
# Limit to max_per_image detections **over all classes**
image_scores = np.hstack(
[cls_boxes[j][:, 1] for j in range(1, cfg.num_classes)])
if len(image_scores) > cfg.MultiScaleTEST['detections_per_im']:
image_thresh = np.sort(
image_scores)[-cfg.MultiScaleTEST['detections_per_im']]
for j in range(1, cfg.num_classes):
keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0]
cls_boxes[j] = cls_boxes[j][keep, :]
im_results = np.vstack([cls_boxes[j] for j in range(1, cfg.num_classes)])
return im_results
def mstest_box_post_process(result, cfg):
"""
Multi-scale Test
Only available for batch_size=1 now.
"""
post_bbox = {}
use_flip = False
ms_boxes = []
ms_scores = []
im_shape = result['im_shape'][0]
for k in result.keys():
if 'bbox' in k:
boxes = result[k][0]
boxes = np.reshape(boxes, (-1, 4 * cfg.num_classes))
scores = result['score' + k[4:]][0]
if 'flip' in k:
boxes = box_flip(boxes, im_shape)
use_flip = True
ms_boxes.append(boxes)
ms_scores.append(scores)
ms_boxes = np.concatenate(ms_boxes)
ms_scores = np.concatenate(ms_scores)
bbox_pred = get_nms_result(ms_boxes, ms_scores, cfg)
post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])})
if use_flip:
bbox = bbox_pred[:, 2:]
bbox_flip = np.append(
bbox_pred[:, :2], box_flip(bbox, im_shape), axis=1)
post_bbox.update({'bbox_flip': (bbox_flip, [[len(bbox_flip)]])})
return post_bbox
def mstest_mask_post_process(result, cfg):
mask_list = []
im_shape = result['im_shape'][0]
M = cfg.FPNRoIAlign['mask_resolution']
for k in result.keys():
if 'mask' in k:
masks = result[k][0]
if len(masks.shape) != 4:
masks = np.zeros((0, M, M))
mask_list.append(masks)
continue
if 'flip' in k:
masks = masks[:, :, :, ::-1]
mask_list.append(masks)
mask_pred = np.mean(mask_list, axis=0)
return {'mask': (mask_pred, [[len(mask_pred)]])}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import numpy as np
from ..data.source.voc_loader import pascalvoc_label
from .map_utils import DetectionMAP
from .coco_eval import bbox2out
import logging
logger = logging.getLogger(__name__)
__all__ = ['bbox_eval', 'bbox2out', 'get_category_info']
def bbox_eval(results,
class_num,
overlap_thresh=0.5,
map_type='11point',
is_bbox_normalized=False,
evaluate_difficult=False):
"""
Bounding box evaluation for VOC dataset
Args:
results (list): prediction bounding box results.
class_num (int): evaluation class number.
overlap_thresh (float): the postive threshold of
bbox overlap
map_type (string): method for mAP calcualtion,
can only be '11point' or 'integral'
is_bbox_normalized (bool): whether bbox is normalized
to range [0, 1].
evaluate_difficult (bool): whether to evaluate
difficult gt bbox.
"""
assert 'bbox' in results[0]
logger.info("Start evaluate...")
detection_map = DetectionMAP(
class_num=class_num,
overlap_thresh=overlap_thresh,
map_type=map_type,
is_bbox_normalized=is_bbox_normalized,
evaluate_difficult=evaluate_difficult)
for t in results:
bboxes = t['bbox'][0]
bbox_lengths = t['bbox'][1][0]
if bboxes.shape == (1, 1) or bboxes is None:
continue
gt_boxes = t['gt_box'][0]
gt_labels = t['gt_label'][0]
difficults = t['is_difficult'][0] if not evaluate_difficult \
else None
if len(t['gt_box'][1]) == 0:
# gt_box, gt_label, difficult read as zero padded Tensor
bbox_idx = 0
for i in range(len(gt_boxes)):
gt_box = gt_boxes[i]
gt_label = gt_labels[i]
difficult = None if difficults is None \
else difficults[i]
bbox_num = bbox_lengths[i]
bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
gt_box, gt_label, difficult = prune_zero_padding(
gt_box, gt_label, difficult)
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
else:
# gt_box, gt_label, difficult read as LoDTensor
gt_box_lengths = t['gt_box'][1][0]
bbox_idx = 0
gt_box_idx = 0
for i in range(len(bbox_lengths)):
bbox_num = bbox_lengths[i]
gt_box_num = gt_box_lengths[i]
bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num]
gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num]
difficult = None if difficults is None else \
difficults[gt_box_idx: gt_box_idx + gt_box_num]
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
gt_box_idx += gt_box_num
logger.info("Accumulating evaluatation results...")
detection_map.accumulate()
map_stat = 100. * detection_map.get_map()
logger.info("mAP({:.2f}, {}) = {:.2f}".format(overlap_thresh, map_type,
map_stat))
return map_stat
def prune_zero_padding(gt_box, gt_label, difficult=None):
valid_cnt = 0
for i in range(len(gt_box)):
if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
break
valid_cnt += 1
return (gt_box[:valid_cnt], gt_label[:valid_cnt],
difficult[:valid_cnt] if difficult is not None else None)
def get_category_info(anno_file=None,
with_background=True,
use_default_label=False):
if use_default_label or anno_file is None \
or not os.path.exists(anno_file):
logger.info("Not found annotation file {}, load "
"voc2012 categories.".format(anno_file))
return vocall_category_info(with_background)
else:
logger.info("Load categories from {}".format(anno_file))
return get_category_info_from_anno(anno_file, with_background)
def get_category_info_from_anno(anno_file, with_background=True):
"""
Get class id to category id map and category id
to category name map from annotation file.
Args:
anno_file (str): annotation file path
with_background (bool, default True):
whether load background as class 0.
"""
cats = []
with open(anno_file) as f:
for line in f.readlines():
cats.append(line.strip())
if cats[0] != 'background' and with_background:
cats.insert(0, 'background')
if cats[0] == 'background' and not with_background:
cats = cats[1:]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
def vocall_category_info(with_background=True):
"""
Get class id to category id map and category id
to category name map of mixup voc dataset
Args:
with_background (bool, default True):
whether load background as class 0.
"""
label_map = pascalvoc_label(with_background)
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
if with_background:
cats.insert(0, 'background')
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import re
import random
import shutil
__all__ = ['create_list']
def create_list(devkit_dir, years, output_dir):
"""
create following list:
1. trainval.txt
2. test.txt
"""
trainval_list = []
test_list = []
for year in years:
trainval, test = _walk_voc_dir(devkit_dir, year, output_dir)
trainval_list.extend(trainval)
test_list.extend(test)
random.shuffle(trainval_list)
with open(osp.join(output_dir, 'trainval.txt'), 'w') as ftrainval:
for item in trainval_list:
ftrainval.write(item[0] + ' ' + item[1] + '\n')
with open(osp.join(output_dir, 'test.txt'), 'w') as fval:
ct = 0
for item in test_list:
ct += 1
fval.write(item[0] + ' ' + item[1] + '\n')
def _get_voc_dir(devkit_dir, year, type):
return osp.join(devkit_dir, 'VOC' + year, type)
def _walk_voc_dir(devkit_dir, year, output_dir):
filelist_dir = _get_voc_dir(devkit_dir, year, 'ImageSets/Main')
annotation_dir = _get_voc_dir(devkit_dir, year, 'Annotations')
img_dir = _get_voc_dir(devkit_dir, year, 'JPEGImages')
trainval_list = []
test_list = []
added = set()
for _, _, files in os.walk(filelist_dir):
for fname in files:
img_ann_list = []
if re.match('[a-z]+_trainval\.txt', fname):
img_ann_list = trainval_list
elif re.match('[a-z]+_test\.txt', fname):
img_ann_list = test_list
else:
continue
fpath = osp.join(filelist_dir, fname)
for line in open(fpath):
name_prefix = line.strip().split()[0]
if name_prefix in added:
continue
added.add(name_prefix)
ann_path = osp.join(
osp.relpath(annotation_dir, output_dir),
name_prefix + '.xml')
img_path = osp.join(
osp.relpath(img_dir, output_dir), name_prefix + '.jpg')
img_ann_list.append((img_path, ann_path))
return trainval_list, test_list
......@@ -37,3 +37,4 @@ from .flowers import FlowersDataset as Flowers
from .stanford_dogs import StanfordDogsDataset as StanfordDogs
from .food101 import Food101Dataset as Food101
from .indoor67 import Indoor67Dataset as Indoor67
from .coco10 import Coco10
......@@ -25,6 +25,8 @@ import paddlehub as hub
from paddlehub.common.downloader import default_downloader
from paddlehub.common.logger import logger
from ..contrib.ppdet.data.source import build_source
from ..common import detection_config as dconf
class BaseCVDataset(BaseDataset):
def __init__(self,
......@@ -160,3 +162,86 @@ class ImageClassificationDataset(object):
def get_test_examples(self):
return self.test_examples
class ObjectDetectionDataset(ImageClassificationDataset):
def __init__(self, base_path, train_image_dir, train_list_file, validate_image_dir, validate_list_file,
test_image_dir, test_list_file, model_type='ssd'):
super(ObjectDetectionDataset, self).__init__()
self.base_path = base_path
self.train_image_dir = train_image_dir
self.train_list_file = train_list_file
self.validate_image_dir = validate_image_dir
self.validate_list_file = validate_list_file
self.test_image_dir = test_image_dir
self.test_list_file = test_list_file
self.model_type = model_type
self._dsc = None
self.cid2cname = None
self.label_dict() # refresh cid2cname and num_labels
assert self.cid2cname is not None
assert self.num_labels > 0
def label_dict(self):
if self.cid2cname is not None:
return self.cid2cname
# get label info from train data json
_ = self.train_data()
return self.cid2cname
def _parse_data(self, data_path, image_dir, shuffle=False, phase=None):
with_background = dconf.conf[self.model_type]['with_background']
mixup_epoch = -1
if phase == 'train':
mixup_epoch = dconf.conf[self.model_type].get('mixup_epoch', -1)
file_conf = {
'ANNO_FILE': data_path,
'IMAGE_DIR': image_dir,
# 'USE_DEFAULT_LABEL': feed.dataset.use_default_label,
'IS_SHUFFLE': shuffle,
'SAMPLES': -1,
'WITH_BACKGROUND': with_background,
'MIXUP_EPOCH': mixup_epoch,
'TYPE': 'RoiDbSource',
}
sc_conf = {'data_cf': file_conf, 'cname2cid': None}
data_source = build_source(sc_conf)
self._dsc = data_source
data_source.reset()
data = data_source._roidb
if self.cid2cname is None:
cname2cid = data_source.cname2cid
cid2cname = {v: k for k, v in cname2cid.items()}
self.cid2cname = cid2cname
if with_background:
self.num_labels = len(cid2cname) + 1
else:
self.num_labels = len(cid2cname)
if phase == 'train':
self.train_examples = data
elif phase == 'dev':
self.dev_examples = data
elif phase == 'test':
self.test_examples = data
return data_source
def train_data(self, shuffle=True):
train_data_path = os.path.join(self.base_path, self.train_list_file)
train_image_dir = os.path.join(self.base_path, self.train_image_dir)
return self._parse_data(
train_data_path, train_image_dir, shuffle, phase='train')
def test_data(self, shuffle=False):
test_data_path = os.path.join(self.base_path, self.test_list_file)
test_image_dir = os.path.join(self.base_path, self.test_image_dir)
return self._parse_data(
test_data_path, test_image_dir, shuffle, phase='dev')
def validate_data(self, shuffle=False):
validate_data_path = os.path.join(self.base_path,
self.validate_list_file)
validate_image_dir = os.path.join(self.base_path,
self.validate_image_dir)
return self._parse_data(
validate_data_path, validate_image_dir, shuffle, phase='test')
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddlehub as hub
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset
class Coco10(ObjectDetectionDataset):
def __init__(self, model_type='ssd'):
dataset_path = os.path.join(hub.common.dir.DATA_HOME, "coco_10")
# self.base_path = self._download_dataset(
# dataset_path=dataset_path,
# url="https://bj.bcebos.com/paddlehub-dataset/dog-cat.tar.gz")
base_path = dataset_path
train_image_dir = 'val'
train_list_file = 'annotations/val.json'
validate_image_dir = 'val'
validate_list_file = 'annotations/val.json'
test_image_dir = 'val'
test_list_file = 'annotations/val.json'
super(Coco10, self).__init__(base_path, train_image_dir, train_list_file, validate_image_dir, validate_list_file,
test_image_dir, test_list_file, model_type)
......@@ -15,6 +15,7 @@
from .base_task import BaseTask, RunEnv, RunState
from .classifier_task import ClassifierTask, ImageClassifierTask, TextClassifierTask, MultiLabelClassifierTask
from .detection_task import DetectionTask
from .reading_comprehension_task import ReadingComprehensionTask
from .regression_task import RegressionTask
from .sequence_task import SequenceLabelTask
......@@ -344,6 +344,10 @@ class BaseTask(object):
# set default phase
self.enter_phase("train")
@property
def base_main_program(self):
return self._base_main_program
@contextlib.contextmanager
def phase_guard(self, phase):
self.enter_phase(phase)
......@@ -393,7 +397,7 @@ class BaseTask(object):
self._build_env_start_event()
self.env.is_inititalized = True
self.env.main_program = clone_program(
self._base_main_program, for_test=False)
self.base_main_program, for_test=False)
self.env.startup_program = fluid.Program()
with fluid.program_guard(self.env.main_program,
......@@ -406,8 +410,9 @@ class BaseTask(object):
self.env.metrics = self._add_metrics()
if self.is_predict_phase or self.is_test_phase:
self.env.main_program = clone_program(
self.env.main_program, for_test=True)
# Todo: paddle.fluid.core_avx.EnforceNotMet: Getting 'tensor_desc' is not supported by the type of var kCUDNNFwdAlgoCache. at
# self.env.main_program = clone_program(
# self.env.main_program, for_test=True)
hub.common.paddle_helper.set_op_attr(
self.env.main_program, is_test=True)
......@@ -1058,8 +1063,9 @@ class BaseTask(object):
capacity=64,
use_double_buffer=True,
iterable=True)
data_reader = data_loader.set_batch_generator(
self.reader, places=self.places)
data_reader = data_loader.set_sample_list_generator(self.reader, self.places[0])
# data_reader = data_loader.set_batch_generator(
# self.reader, places=self.places)
else:
data_feeder = fluid.DataFeeder(
feed_list=self.feed_list, place=self.place)
......@@ -1076,12 +1082,29 @@ class BaseTask(object):
step_run_state.run_step = 1
num_batch_examples = len(batch)
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list,
return_numpy=self.return_numpy)
if not self.return_numpy:
if self.return_numpy == 2:
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list,
return_numpy=False)
# fetch_result = [x if isinstance(x,fluid.LoDTensor) else np.array(x) for x in fetch_result]
fetch_result = [
x
if hasattr(x, 'recursive_sequence_lengths') else np.array(x)
for x in fetch_result
]
elif self.return_numpy:
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list)
else:
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list,
return_numpy=False)
fetch_result = [np.array(x) for x in fetch_result]
for index, result in enumerate(fetch_result):
......
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from collections import OrderedDict
import numpy as np
import six
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA
from .base_task import BaseTask
from ...contrib.ppdet.utils.eval_utils import eval_results
from ...common import detection_config as dconf
from paddlehub.common.paddle_helper import clone_program
feed_var_def = [
{
'name': 'im_info',
'shape': [3],
'dtype': 'float32',
'lod_level': 0
},
{
'name': 'im_id',
'shape': [1],
'dtype': 'int32',
'lod_level': 0
},
{
'name': 'gt_box',
'shape': [4],
'dtype': 'float32',
'lod_level': 1
},
{
'name': 'gt_label',
'shape': [1],
'dtype': 'int32',
'lod_level': 1
},
{
'name': 'is_crowd',
'shape': [1],
'dtype': 'int32',
'lod_level': 1
},
{
'name': 'gt_mask',
'shape': [2],
'dtype': 'float32',
'lod_level': 3
},
{
'name': 'is_difficult',
'shape': [1],
'dtype': 'int32',
'lod_level': 1
},
{
'name': 'gt_score',
'shape': [1],
'dtype': 'float32',
'lod_level': 0
},
{
'name': 'im_shape',
'shape': [3],
'dtype': 'float32',
'lod_level': 0
},
{
'name': 'im_size',
'shape': [2],
'dtype': 'int32',
'lod_level': 0
},
]
class Feed(object):
def __init__(self):
self.dataset = None
self.with_background = True
class DetectionTask(BaseTask):
def __init__(self,
data_reader,
num_classes,
feed_list,
feature,
model_type='ssd',
predict_feature=None,
predict_feed_list=None,
startup_program=None,
config=None,
hidden_units=None,
metrics_choices="default"):
if metrics_choices == "default":
metrics_choices = ["ap"]
main_program = feature[0].block.program
super(DetectionTask, self).__init__(
data_reader=data_reader,
main_program=main_program,
feed_list=feed_list,
startup_program=startup_program,
config=config,
metrics_choices=metrics_choices)
if predict_feature is not None:
main_program = predict_feature[0].block.program
self._predict_base_main_program = clone_program(
main_program, for_test=False)
else:
self._predict_base_main_program = None
self._predict_base_feed_list = predict_feed_list
self.feature = feature
self.predict_feature = predict_feature
self.num_classes = num_classes
self.hidden_units = hidden_units
self.model_type = model_type
@property
def base_main_program(self):
if not self.is_train_phase and self._predict_base_main_program is not None:
return self._predict_base_main_program
return self._base_main_program
@property
def base_feed_list(self):
if not self.is_train_phase and self._predict_base_feed_list is not None:
return self._predict_base_feed_list
return self._base_feed_list
@property
def base_feed_var_list(self):
vars = self.main_program.global_block().vars
return [vars[varname] for varname in self.base_feed_list]
@property
def return_numpy(self):
return 2 # return lod tensor
def _add_label_by_fields(self, idx_list):
feed_var_map = {var['name']: var for var in feed_var_def}
# tensor padding with 0 is used instead of LoD tensor when
# num_max_boxes is set
num_max_boxes = dconf.conf[self.model_type].get('num_max_boxes', None)
if num_max_boxes is not None:
feed_var_map['gt_label']['shape'] = [num_max_boxes]
feed_var_map['gt_score']['shape'] = [num_max_boxes]
feed_var_map['gt_box']['shape'] = [num_max_boxes, 4]
feed_var_map['is_difficult']['shape'] = [num_max_boxes]
feed_var_map['gt_label']['lod_level'] = 0
feed_var_map['gt_score']['lod_level'] = 0
feed_var_map['gt_box']['lod_level'] = 0
feed_var_map['is_difficult']['lod_level'] = 0
if self.is_train_phase:
fields = dconf.feed_config[self.model_type]['train']['fields']
elif self.is_test_phase:
fields = dconf.feed_config[self.model_type]['dev']['fields']
else: # Cannot go to here
# raise RuntimeError("Cannot go to _add_label in predict phase")
fields = dconf.feed_config[self.model_type]['predict']['fields']
labels = []
for i in idx_list:
key = fields[i]
l = fluid.layers.data(
name=feed_var_map[key]['name'],
shape=feed_var_map[key]['shape'],
dtype=feed_var_map[key]['dtype'],
lod_level=feed_var_map[key]['lod_level'])
labels.append(l)
return labels
def _ssd_build_net(self):
feature_list = self.feature
image = self.base_feed_var_list[0]
# fix input size according to its module
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=feature_list,
image=image,
num_classes=self.num_classes,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=512, # 300,
min_sizes=[20.0, 51.0, 133.0, 215.0, 296.0, 378.0, 460.0], # [60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
max_sizes=[51.0, 133.0, 215.0, 296.0, 378.0, 460.0, 542.0], # [[], 150.0, 195.0, 240.0, 285.0, 300.0],
steps=[8, 16, 32, 64, 128, 256, 512],
min_ratio=15,
max_ratio=90,
kernel_size=3,
offset=0.5,
flip=True,
pad=1,
)
self.env.mid_vars = [mbox_locs, mbox_confs, box, box_var]
nmsed_out = fluid.layers.detection_output(
mbox_locs,
mbox_confs,
box,
box_var,
background_label=0,
nms_threshold=0.45,
nms_top_k=400,
keep_top_k=200,
score_threshold=0.01,
nms_eta=1.0)
if self.is_predict_phase: # add im_id
self.env.labels = self._ssd_add_label()
return [nmsed_out]
def _ssd_add_label(self):
# train: 'gt_box', 'gt_label'
# dev: 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'
if self.is_train_phase:
idx_list = [1, 2] # 'gt_box', 'gt_label'
elif self.is_test_phase:
# xTodo: remove 'im_shape' when using new module
idx_list = [2, 3, 4, 5] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
else:
idx_list = [1] # im_id
return self._add_label_by_fields(idx_list)
def _ssd_add_loss(self):
if self.is_train_phase:
gt_box = self.labels[0]
gt_label = self.labels[1]
else: # xTodo: update here when using new module
gt_box = self.labels[1]
gt_label = self.labels[2]
mbox_locs, mbox_confs, box, box_var = self.env.mid_vars
loss = fluid.layers.ssd_loss(
location=mbox_locs,
confidence=mbox_confs,
gt_box=gt_box,
gt_label=gt_label,
prior_box=box,
prior_box_var=box_var)
loss = fluid.layers.reduce_sum(loss)
loss.persistable = True
return loss
def _ssd_feed_list(self, for_export=False):
# xTodo: update when using new module
feed_list = [varname for varname in self.base_feed_list]
if self.is_train_phase:
feed_list = feed_list[:1] + [label.name for label in self.labels]
elif self.is_test_phase:
feed_list = feed_list + [label.name for label in self.labels]
else: # self.is_predict_phase:
if for_export:
feed_list = [feed_list[0]]
else:
# 'image', 'im_id', 'im_shape'
feed_list = [feed_list[0], self.labels[0].name, feed_list[1]]
return feed_list
def _ssd_fetch_list(self, for_export=False):
# ensure fetch 'im_shape', 'im_id', 'bbox' at first three elements in test phase
if self.is_train_phase:
return [self.loss.name]
elif self.is_test_phase:
# xTodo: update when using new module
# im_id, bbox, dets, loss
return [
self.base_feed_list[1], self.labels[0].name, self.outputs[0].name,
self.loss.name]
# im_shape, im_id, bbox
if for_export:
return [self.outputs[0].name]
else:
return [self.base_feed_list[1], self.labels[0].name, self.outputs[0].name]
def _rcnn_build_net(self):
if self.is_train_phase:
head_feat = self.feature[0]
else:
head_feat = self.predict_feature[0]
# Rename following layers for: ValueError: Variable cls_score_w has been created before.
# the previous shape is (2048, 81); the new shape is (100352, 81).
# They are not matched.
cls_score = fluid.layers.fc(input=head_feat,
size=self.num_classes,
act=None,
name='my_cls_score',
param_attr=ParamAttr(
name='my_cls_score_w',
initializer=Normal(
loc=0.0, scale=0.01)),
bias_attr=ParamAttr(
name='my_cls_score_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
bbox_pred = fluid.layers.fc(input=head_feat,
size=4 * self.num_classes,
act=None,
name='my_bbox_pred',
param_attr=ParamAttr(
name='my_bbox_pred_w',
initializer=Normal(
loc=0.0, scale=0.001)),
bias_attr=ParamAttr(
name='my_bbox_pred_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
if self.is_train_phase:
rpn_cls_loss, rpn_reg_loss, outs = self.feature[1:]
labels_int32 = outs[1]
bbox_targets = outs[2]
bbox_inside_weights = outs[3]
bbox_outside_weights = outs[4]
labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
labels_int64.stop_gradient = True
loss_cls = fluid.layers.softmax_with_cross_entropy(
logits=cls_score, label=labels_int64, numeric_stable_mode=True)
loss_cls = fluid.layers.reduce_mean(loss_cls)
loss_bbox = fluid.layers.smooth_l1(
x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights,
sigma=1.0)
loss_bbox = fluid.layers.reduce_mean(loss_bbox)
total_loss = fluid.layers.sum([loss_bbox, loss_cls, rpn_cls_loss, rpn_reg_loss])
return [total_loss]
else:
rois = self.predict_feature[1]
im_info = self.base_feed_var_list[1]
im_shape = self.base_feed_var_list[2]
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, rois)
boxes = rois / im_scale
cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
bbox_pred = fluid.layers.reshape(bbox_pred, (-1, self.num_classes, 4))
# decoded_box = self.box_coder(prior_box=boxes, target_box=bbox_pred)
decoded_box = fluid.layers.box_coder(
prior_box=boxes, prior_box_var=[0.1, 0.1, 0.2, 0.2],
target_box=bbox_pred, code_type='decode_center_size',
box_normalized=False, axis=1)
cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
# pred_result = self.nms(bboxes=cliped_box, scores=cls_prob)
pred_result = fluid.layers.multiclass_nms(
bboxes=cliped_box, scores=cls_prob,
score_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0
)
if self.is_predict_phase:
self.env.labels = self._rcnn_add_label()
return [pred_result]
def _rcnn_add_label(self):
if self.is_train_phase:
idx_list = [2,] # 'im_id'
elif self.is_test_phase:
idx_list = [2, 4, 5, 6] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
else: # predict
idx_list = [2]
return self._add_label_by_fields(idx_list)
def _rcnn_add_loss(self):
if self.is_train_phase:
loss = self.env.outputs[-1]
else:
loss = fluid.layers.fill_constant(shape=[1], value=-1, dtype='float32')
return loss
def _rcnn_feed_list(self, for_export=False):
feed_list = [varname for varname in self.base_feed_list]
if self.is_train_phase:
# feed_list is ['image', 'im_info', 'gt_box', 'gt_label', 'is_crowd']
return feed_list[:2] + [self.labels[0].name] + feed_list[2:]
elif self.is_test_phase:
# feed list is ['image', 'im_info', 'im_shape']
return feed_list[:2] + [self.labels[0].name] + feed_list[2:] + \
[label.name for label in self.labels[1:]]
if for_export:
# skip im_id
return feed_list[:2] + feed_list[3:]
else:
return feed_list[:2] + [self.labels[0].name] + feed_list[2:]
def _rcnn_fetch_list(self, for_export=False):
# ensure fetch 'im_shape', 'im_id', 'bbox' at first three elements in test phase
if self.is_train_phase:
return [self.loss.name]
elif self.is_test_phase:
# im_shape, im_id, bbox
return [self.feed_list[2], self.labels[0].name, self.outputs[0].name, self.loss.name]
# im_shape, im_id, bbox
if for_export:
return [self.outputs[0].name]
else:
return [self.feed_list[2], self.labels[0].name, self.outputs[0].name]
def _yolo_parse_anchors(self, anchors):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self.anchors = []
self.mask_anchors = []
assert len(anchors) > 0, "ANCHORS not set."
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
for anchor in anchors:
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
self.anchors.extend(anchor)
anchor_num = len(anchors)
for masks in self.anchor_masks:
self.mask_anchors.append([])
for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
def _yolo_build_net(self):
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
self._yolo_parse_anchors(anchors)
tip_list = self.feature
outputs = []
for i, tip in enumerate(tip_list):
# out channel number = mask_num * (5 + class_num)
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
# Rename for: conflict with module pretrain weights
param_attr=ParamAttr(name="ft_yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name="ft_yolo_output.{}.conv.bias".format(i)))
outputs.append(block_out)
if self.is_train_phase:
return outputs
im_size = self.base_feed_var_list[1]
boxes = []
scores = []
downsample = 32
for i, output in enumerate(outputs):
box, score = fluid.layers.yolo_box(
x=output,
img_size=im_size,
anchors=self.mask_anchors[i],
class_num=self.num_classes,
conf_thresh=0.01,
downsample_ratio=downsample,
name="yolo_box" + str(i))
boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
downsample //= 2
yolo_boxes = fluid.layers.concat(boxes, axis=1)
yolo_scores = fluid.layers.concat(scores, axis=2)
# pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes, scores=yolo_scores,
score_threshold=.01,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.45,
normalized=False,
nms_eta=1.0,
background_label=-1)
if self.is_predict_phase:
self.env.labels = self._yolo_add_label()
return [pred]
def _yolo_add_label(self):
if self.is_train_phase:
idx_list = [1, 2, 3] # 'gt_box', 'gt_label', 'gt_score'
elif self.is_test_phase:
idx_list = [2, 3, 4, 5] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
else: # predict
idx_list = [2]
return self._add_label_by_fields(idx_list)
def _yolo_add_loss(self):
if self.is_train_phase:
gt_box, gt_label, gt_score = self.labels
outputs = self.outputs
losses = []
downsample = 32
for i, output in enumerate(outputs):
anchor_mask = self.anchor_masks[i]
loss = fluid.layers.yolov3_loss(
x=output,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=self.anchors,
anchor_mask=anchor_mask,
class_num=self.num_classes,
ignore_thresh=0.7,
downsample_ratio=downsample,
use_label_smooth=True,
name="yolo_loss" + str(i))
losses.append(fluid.layers.reduce_mean(loss))
downsample //= 2
loss = sum(losses)
else:
loss = fluid.layers.fill_constant(shape=[1], value=-1, dtype='float32')
return loss
def _yolo_feed_list(self, for_export=False):
feed_list = [varname for varname in self.base_feed_list]
if self.is_train_phase:
return [feed_list[0]] + [label.name for label in self.labels]
elif self.is_test_phase:
return feed_list + [label.name for label in self.labels]
if for_export:
return feed_list[:2]
else:
return feed_list + [self.labels[0].name]
def _yolo_fetch_list(self, for_export=False):
# ensure fetch 'im_shape', 'im_id', 'bbox' at first three elements in test phase
if self.is_train_phase:
return [self.loss.name]
elif self.is_test_phase:
# im_shape, im_id, bbox
return [self.feed_list[1], self.labels[0].name, self.outputs[0].name, self.loss.name]
# im_shape, im_id, bbox
if for_export:
return [self.outputs[0].name]
else:
return [self.feed_list[1], self.labels[0].name, self.outputs[0].name]
def _build_net(self):
if self.model_type == 'ssd':
outputs = self._ssd_build_net()
elif self.model_type == 'rcnn':
outputs = self._rcnn_build_net()
elif self.model_type == 'yolo':
outputs = self._yolo_build_net()
else:
raise NotImplementedError
return outputs
def _add_label(self):
if self.model_type == 'ssd':
labels = self._ssd_add_label()
elif self.model_type == 'rcnn':
labels = self._rcnn_add_label()
elif self.model_type == 'yolo':
labels = self._yolo_add_label()
else:
raise NotImplementedError
return labels
def _add_loss(self):
if self.model_type == 'ssd':
loss = self._ssd_add_loss()
elif self.model_type == 'rcnn':
loss = self._rcnn_add_loss()
elif self.model_type == 'yolo':
loss = self._yolo_add_loss()
else:
raise NotImplementedError
return loss
def _add_metrics(self):
return []
@property
def feed_list(self):
return self._feed_list(False)
def _feed_list(self, for_export=False):
if self.model_type == 'ssd':
return self._ssd_feed_list(for_export)
elif self.model_type == 'rcnn':
return self._rcnn_feed_list(for_export)
elif self.model_type == 'yolo':
return self._yolo_feed_list(for_export)
else:
raise NotImplementedError
@property
def fetch_list(self):
# ensure fetch loss at last element in train/test phase
# ensure fetch 'im_shape', 'im_id', 'bbox' at first three elements in test phase
return self._fetch_list(False)
def _fetch_list(self, for_export=False):
if self.model_type == 'ssd':
return self._ssd_fetch_list(for_export)
elif self.model_type == 'rcnn':
return self._rcnn_fetch_list(for_export)
elif self.model_type == 'yolo':
return self._yolo_fetch_list(for_export)
else:
raise NotImplementedError
@property
def fetch_var_list(self):
fetch_list = self._fetch_list(True)
vars = self.main_program.global_block().vars
return [vars[varname] for varname in fetch_list]
@property
def labels(self):
if not self.env.is_inititalized:
self._build_env()
return self.env.labels
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None):
with self.phase_guard("predict"):
fluid.io.save_inference_model(
dirname=dirname,
executor=self.exe,
feeded_var_names=self._feed_list(for_export=True),
target_vars=self.fetch_var_list,
main_program=self.main_program,
model_filename=model_filename,
params_filename=params_filename)
def _calculate_metrics(self, run_states):
loss_sum = run_examples = 0
run_step = run_time_used = 0
for run_state in run_states:
run_examples += run_state.run_examples
run_step += run_state.run_step
loss_sum += np.mean(np.array(
run_state.run_results[-1])) * run_state.run_examples
run_time_used = time.time() - run_states[0].run_time_begin
avg_loss = loss_sum / run_examples
run_speed = run_step / run_time_used
# The first key will be used as main metrics to update the best model
scores = OrderedDict()
if self.is_train_phase:
return scores, avg_loss, run_speed
keys = ['im_shape', 'im_id', 'bbox']
results = []
for run_state in run_states:
outs = [
run_state.run_results[0], run_state.run_results[1],
run_state.run_results[2]
]
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs)
}
results.append(res)
is_bbox_normalized = dconf.conf[self.model_type]['is_bbox_normalized']
eval_feed = Feed()
eval_feed.with_background = dconf.conf[self.model_type]['with_background']
eval_feed.dataset = self.reader
for metric in self.metrics_choices:
if metric == "ap":
box_ap_stats = eval_results(results, eval_feed, 'COCO',
self.num_classes, None,
is_bbox_normalized, None, None)
print("box_ap_stats", box_ap_stats)
scores["ap"] = box_ap_stats[0]
else:
raise ValueError("Not Support Metric: \"%s\"" % metric)
return scores, avg_loss, run_speed
......@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo',
syntax='proto3',
serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires\x12\x1b\n\x13module_code_version\x18\x06 \x01(\t*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -35,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=522,
serialized_end=552,
serialized_start=551,
serialized_end=581,
)
_sym_db.RegisterEnumDescriptor(_FILE_TYPE)
......@@ -60,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=554,
serialized_end=645,
serialized_start=583,
serialized_end=674,
)
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
......@@ -346,6 +346,22 @@ _CHECKINFO = _descriptor.Descriptor(
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=5,
number=6,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
......@@ -356,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=320,
serialized_end=520,
serialized_end=549,
)
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
......@@ -23,6 +23,8 @@ from PIL import Image
import paddlehub.io.augmentation as image_augmentation
from .base_reader import BaseReader
from ..contrib.ppdet.data.reader import Reader
from ..common import detection_config as dconf
channel_order_dict = {
"RGB": [0, 1, 2],
......@@ -181,3 +183,90 @@ class ImageClassificationReader(BaseReader):
labels = []
return _data_reader
class ObjectDetectionReader(ImageClassificationReader):
def __init__(self,
dataset=None,
model_type='ssd',
channel_order="RGB",
worker_num=1,
use_process=False,
):
super(ObjectDetectionReader,
self).__init__(1, 1, dataset, channel_order,
None, None, None)
self.model_type = model_type
self.worker_num = worker_num
self.use_process = use_process
def data_generator(self,
batch_size,
phase="train",
shuffle=False,
data=None,
return_list=False
):
if phase != 'predict' and not self.dataset:
raise ValueError("The dataset is none and it's not allowed!")
drop_last = False
if phase == "train":
data_src = self.dataset.train_data(shuffle)
self.num_examples['train'] = len(self.get_train_examples())
drop_last = True
elif phase == "test":
shuffle = False
data_src = self.dataset.test_data(shuffle)
self.num_examples['test'] = len(self.get_test_examples())
elif phase == "val" or phase == "dev":
shuffle = False
data_src = self.dataset.validate_data(shuffle)
self.num_examples['dev'] = len(self.get_dev_examples())
else: # phase == "predict":
from ..contrib.ppdet.data.source import build_source
data_config = {
"IMAGES": data,
"TYPE": "SimpleSource"
}
data_src = build_source(data_config)
data_cf = {}
transform_config = {
'WORKER_CONF': {
'bufsize': 20,
'worker_num': self.worker_num,
'use_process': self.use_process,
'memsize': '3G'
},
'BATCH_SIZE': batch_size,
'DROP_LAST': drop_last,
'USE_PADDED_IM_INFO': False,
}
phase_trans = {
"val": "dev",
"test": "dev",
"inference": "predict"
}
if phase in phase_trans:
phase = phase_trans[phase]
assert phase in ('train', 'dev', 'predict')
feed_config = dconf.feed_config[self.model_type][phase]
transform_config.update(feed_config) # add 'OPS' etc.
ppdet_mode = 'VAL' if phase != 'train' else 'TRAIN'
_batch_reader = Reader.create(
ppdet_mode, data_cf, transform_config, my_source=data_src)
# return itr
# When call `_batch_reader()`, then return generator(or iterator)
def batch_reader():
"""batch reader"""
for b in _batch_reader():
if return_list:
yield [b]
else:
yield b
batch_reader.annotation = _batch_reader.annotation
return batch_reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册