Commit 60f23e73 authored by: S sunyanfang01

add blazeface

Parent 17f8b256
......@@ -18,4 +18,5 @@ from .coco import CocoDetection
from .seg_dataset import SegDataset
from .easydata_cls import EasyDataCls
from .easydata_det import EasyDataDet
from .easydata_seg import EasyDataSeg
\ No newline at end of file
from .easydata_seg import EasyDataSeg
from .widerface import WIDERFACEDetection
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import copy
import os.path as osp
import random
import cv2
import numpy as np
from collections import OrderedDict
import xml.etree.ElementTree as ET
import paddlex.utils.logging as logging
from .voc import VOCDetection
from .dataset import is_pic
from .dataset import get_encoding
class WIDERFACEDetection(VOCDetection):
"""读取WIDER Face格式的检测数据集,并对样本进行相应的处理。
Args:
data_dir (str): 数据集所在的目录路径。
ann_file (str): 数据集的标注文件,为一个独立的txt格式文件。
transforms (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子。
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
"""
def __init__(self,
data_dir,
ann_file,
transforms=None,
num_workers='auto',
buffer_size=100,
parallel_method='process',
shuffle=False):
super(VOCDetection, self).__init__(
transforms=transforms,
num_workers=num_workers,
buffer_size=buffer_size,
parallel_method=parallel_method,
shuffle=shuffle)
self.file_list = list()
self.labels = list()
self._epoch = 0
self.labels.append('face')
valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
from pycocotools.coco import COCO
annotations = {}
annotations['images'] = []
annotations['categories'] = []
annotations['annotations'] = []
annotations['categories'].append({
'supercategory': 'component',
'id': 1,
'name': 'face'
})
logging.info("Starting to read file list from dataset...")
im_ct = 0
ann_ct = 0
# Despite its name, is_discard is True while the current image is valid
# and its following annotation lines should be parsed.
is_discard = False
with open(ann_file, 'r', encoding=get_encoding(ann_file)) as fr:
lines_txt = fr.readlines()
for line in lines_txt:
line = line.strip('\n\t\r')
if any(suffix in line for suffix in valid_suffix):
img_file = osp.join(data_dir, line)
if not is_pic(img_file):
is_discard = False
continue
else:
is_discard = True
im = cv2.imread(img_file)
im_w = im.shape[1]
im_h = im.shape[0]
im_info = {
'im_id': np.array([im_ct]),
'image_shape': np.array([im_h, im_w]).astype('int32'),
}
bbox_id = 0
annotations['images'].append({
'height':
im_h,
'width':
im_w,
'id':
im_ct,
'file_name':
osp.split(img_file)[1]
})
elif ' ' not in line:
if not is_discard:
continue
bbox_ct = int(line)
if bbox_ct == 0:
is_discard = False
continue
gt_bbox = np.zeros((bbox_ct, 4), dtype=np.float32)
gt_class = np.ones((bbox_ct, 1), dtype=np.int32)
difficult = np.zeros((bbox_ct, 1), dtype=np.int32)
else:
if not is_discard:
continue
split_str = line.split(' ')
xmin = float(split_str[0])
ymin = float(split_str[1])
w = float(split_str[2])
h = float(split_str[3])
# Filter out wrong labels
if w < 0 or h < 0:
logging.warning('Illegal box with w: {}, h: {} in '
'img: {}, and it will be ignored'.format(
w, h, img_file))
gt_class[bbox_id, 0] = 0
bbox_id += 1
continue
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = xmin + w
ymax = ymin + h
gt_bbox[bbox_id] = [xmin, ymin, xmax, ymax]
bbox_id += 1
annotations['annotations'].append({
'iscrowd': 0,
'image_id': im_ct,
'bbox': [xmin, ymin, w, h],
'area': float(w * h),
'category_id': 1,
'id': ann_ct,
'difficult': 0
})
ann_ct += 1
if bbox_id == bbox_ct:
label_info = {
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'difficult': difficult
}
voc_rec = (im_info, label_info)
self.file_list.append([img_file, voc_rec])
im_ct += 1
self.coco_gt = COCO()
self.coco_gt.dataset = annotations
self.coco_gt.createIndex()
\ No newline at end of file
......@@ -39,6 +39,7 @@ from .base import BaseAPI
from .yolo_v3 import YOLOv3
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
from .blazeface import BlazeFace
from .unet import UNet
from .deeplabv3p import DeepLabv3p
from .hrnet import HRNet
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import math
import tqdm
import os.path as osp
import numpy as np
import paddle.fluid as fluid
import paddlex.utils.logging as logging
import paddlex
from .base import BaseAPI
from collections import OrderedDict
import copy
from .utils.detection_eval import eval_results, bbox2out
class BlazeFace(BaseAPI):
"""构建BlazeFace,并实现其训练、评估、预测和模型导出。
Args:
num_classes (int): 类别数。默认为2。
backbone (str): YOLOv3的backbone网络,取值范围为['BlazeNet']。默认为'BlazeNet'。
nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.3。
nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为5000。
nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为750。
nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
min_sizes (list): 候选框最小size组成的列表,当use_density_prior_box为False时使用。
densities (list): 生成候选框的密度,当use_density_prior_box为True时使用。
use_density_prior_box (bool): 是否使用密度方式获取候选框。默认为False。
"""
def __init__(self,
num_classes=2,
backbone='BlazeNet',
nms_iou_threshold=0.3,
nms_topk=5000,
nms_keep_topk=750,
nms_score_threshold=0.01,
min_sizes=[[16., 24.], [32., 48., 64., 80., 96., 128.]],
densities=[[2, 2], [2, 1, 1, 1, 1, 1]],
use_density_prior_box=False):
self.init_params = locals()
super(BlazeFace, self).__init__('detector')
backbones = [
'BlazeNet'
]
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
self.backbone = backbone
self.num_classes = num_classes
self.nms_iou_threshold = nms_iou_threshold
self.nms_topk = nms_topk
self.nms_keep_topk = nms_keep_topk
self.nms_score_threshold = nms_score_threshold
self.min_sizes = min_sizes
self.densities = densities
self.use_density_prior_box = use_density_prior_box
self.fixed_input_shape = None
def _get_backbone(self, backbone_name):
if backbone_name == 'BlazeNet':
backbone = paddlex.cv.nets.BlazeNet()
return backbone
def build_net(self, mode='train'):
model = paddlex.cv.nets.detection.BlazeFace(
backbone=self._get_backbone(self.backbone),
mode=mode,
min_sizes=self.min_sizes,
num_classes=self.num_classes,
use_density_prior_box=self.use_density_prior_box,
densities=self.densities,
nms_threshold=self.nms_iou_threshold,
nms_topk=self.nms_topk,
nms_keep_topk=self.nms_keep_topk,
score_threshold=self.nms_score_threshold,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)])
if mode == 'train':
self.optimizer.minimize(model_out)
outputs = OrderedDict([('loss', model_out)])
return inputs, outputs
def default_optimizer(self, learning_rate,
lr_decay_epochs, lr_decay_gamma,
num_steps_each_epoch):
boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
values = [(lr_decay_gamma**i) * learning_rate
for i in range(len(lr_decay_epochs) + 1)]
lr_decay = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=lr_decay,
momentum=0.0,
regularization=fluid.regularizer.L2DecayRegularizer(5e-04))
return optimizer
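# Worked example of the schedule built above: with learning_rate=0.0025,
# lr_decay_epochs=[597, 746] and lr_decay_gamma=0.1, piecewise_decay uses
# values [0.0025, 0.00025, 0.000025] and switches at steps
# 597 * num_steps_each_epoch and 746 * num_steps_each_epoch.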
def train(self,
num_epochs,
train_dataset,
train_batch_size=2,
eval_dataset=None,
save_interval_epochs=1,
log_interval_steps=20,
save_dir='output',
pretrain_weights=None,
optimizer=None,
learning_rate=0.0025,
lr_decay_epochs=[597, 746],
lr_decay_gamma=0.1,
metric='COCO',
use_vdl=False,
early_stop=False,
early_stop_patience=5,
resume_checkpoint=None):
"""训练。
Args:
num_epochs (int): 训练迭代轮数。
train_dataset (paddlex.datasets): 训练数据读取器。
train_batch_size (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与
显卡数量之商为验证数据batch大小。默认为2。
eval_dataset (paddlex.datasets): 验证数据读取器。
save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为20。
save_dir (str): 模型保存路径。默认值为'output'。
pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为None,则不使用预训练模型。默认为None。
optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
learning_rate (float): 默认优化器的初始学习率。默认为0.001。
lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[597, 746]。
lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 评估类型不在指定列表中。
ValueError: 模型从inference model进行加载。
"""
if metric is None:
if isinstance(train_dataset, paddlex.datasets.WIDERFACEDetection):
metric = 'WIDERFACE'
elif isinstance(train_dataset, paddlex.datasets.CocoDetection):
metric = 'COCO'
elif isinstance(train_dataset, paddlex.datasets.VOCDetection) or \
isinstance(train_dataset, paddlex.datasets.EasyDataDet):
metric = 'VOC'
else:
raise ValueError(
"train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
)
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric
if not self.trainable:
raise ValueError("Model is not trainable from load_model method.")
self.labels = copy.deepcopy(train_dataset.labels)
self.labels.insert(0, 'background')
# build the training network
if optimizer is None:
# build the default optimization strategy
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
optimizer = self.default_optimizer(
learning_rate, lr_decay_epochs,
lr_decay_gamma, num_steps_each_epoch)
self.optimizer = optimizer
# build the train/eval/test programs
self.build_program()
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
save_dir=save_dir,
resume_checkpoint=resume_checkpoint)
# start training
self.train_loop(
num_epochs=num_epochs,
train_dataset=train_dataset,
train_batch_size=train_batch_size,
eval_dataset=eval_dataset,
save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps,
save_dir=save_dir,
use_vdl=use_vdl,
early_stop=early_stop,
early_stop_patience=early_stop_patience)
def evaluate(self,
eval_dataset,
batch_size=1,
epoch_id=None,
metric=None,
return_details=False):
"""评估。
Args:
eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
epoch_id (int): 当前评估模型所在的训练轮数。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
如为COCODetection,则metric为'COCO'。
return_details (bool): 是否返回详细信息。默认值为False。
Returns:
tuple (metrics, eval_details) /dict (metrics): 当return_details为True时,返回(metrics, eval_details),
当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,
分别表示平均准确率平均值在各个阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。
eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
"""
self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
if metric is None:
if hasattr(self, 'metric') and self.metric is not None:
metric = self.metric
else:
if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
metric = 'COCO'
elif isinstance(eval_dataset, paddlex.datasets.VOCDetection):
metric = 'VOC'
elif isinstance(eval_dataset, paddlex.datasets.WIDERFACEDetection):
metric = 'WIDERFACE'
logging.info("The WIDERFACE metric is not supported yet and will be implemented soon. " \
+ "Only 'VOC' and 'COCO' are supported for now.")
exit(0)
else:
raise Exception(
"eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
)
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
dataset = eval_dataset.generator(batch_size=batch_size, drop_last=False)
total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
results = list()
logging.info("Start to evaluating(total_samples={}, total_steps={})...".
format(eval_dataset.num_samples, total_steps))
for step, data in tqdm.tqdm(enumerate(dataset()), total=total_steps):
images = np.array([d[0] for d in data]).astype('float32')
feed_data = {
'image': images,
}
outputs = self.exe.run(self.test_prog,
feed=[feed_data],
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
res = {
'bbox': (np.array(outputs[0]),
outputs[0].recursive_sequence_lengths())
}
res_im_id = [d[4] for d in data]
res['im_id'] = (np.array(res_im_id), [])
if metric == 'VOC':
res_gt_box = []
res_gt_label = []
res_is_difficult = []
for d in data:
res_gt_box.extend(d[1])
res_gt_label.extend(d[2])
res_is_difficult.extend(d[3])
res_gt_box_lod = [d[1].shape[0] for d in data]
res_gt_label_lod = [d[2].shape[0] for d in data]
res_is_difficult_lod = [d[3].shape[0] for d in data]
res['gt_box'] = (np.array(res_gt_box), [res_gt_box_lod])
res['gt_label'] = (np.array(res_gt_label), [res_gt_label_lod])
res['is_difficult'] = (np.array(res_is_difficult),
[res_is_difficult_lod])
results.append(res)
logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
1, total_steps))
box_ap_stats, eval_details = eval_results(
results, metric, eval_dataset.coco_gt, with_background=True, is_bbox_normalized=True)
metrics = OrderedDict(
zip(['bbox_mmap'
if metric == 'COCO' else 'bbox_map'], box_ap_stats))
if return_details:
return metrics, eval_details
return metrics
def predict(self, img_file, transforms=None):
"""预测。
Args:
img_file (str): 预测图像路径。
transforms (paddlex.det.transforms): 数据预处理操作。
Returns:
list: 预测结果列表,每个预测结果由预测框类别标签、
预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
预测框得分组成。
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
if transforms is not None:
self.arrange_transforms(transforms=transforms, mode='test')
im, im_shape = transforms(img_file)
else:
self.arrange_transforms(
transforms=self.test_transforms, mode='test')
im, im_shape = self.test_transforms(img_file)
im = np.expand_dims(im, axis=0)
im_shape = np.expand_dims(im_shape, axis=0)
outputs = self.exe.run(self.test_prog,
feed={
'image': im,
'im_size': im_shape
},
fetch_list=list(self.test_outputs.values()),
return_numpy=False,
use_program_cache=True)
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs)
}
res['im_id'] = (np.array([[0]]).astype('int32'), [])
clsid2catid = dict({i: i for i in range(self.num_classes)})
xywh_results = bbox2out([res], clsid2catid, is_bbox_normalized=True)
results = list()
for xywh_res in xywh_results:
del xywh_res['image_id']
xywh_res['category'] = self.labels[xywh_res['category_id']]
results.append(xywh_res)
return results
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from ppdet.experimental import mixed_precision_global_state
class BlazeNet(object):
"""
BlazeFace, see https://arxiv.org/abs/1907.05047
Args:
blaze_filters (list): number of filter for each blaze block
double_blaze_filters (list): number of filter for each double_blaze block
with_extra_blocks (bool): whether or not extra blocks should be added
lite_edition (bool): whether or not is blazeface-lite
use_5x5kernel (bool): whether or not filter size is 5x5 in depth-wise conv
"""
def __init__(self,
blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]],
double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
[96, 24, 96, 2], [96, 24, 96], [96, 24, 96]],
with_extra_blocks=True,
lite_edition=False,
use_5x5kernel=True):
self.blaze_filters = blaze_filters
self.double_blaze_filters = double_blaze_filters
self.with_extra_blocks = with_extra_blocks
self.lite_edition = lite_edition
self.use_5x5kernel = use_5x5kernel
def __call__(self, input):
if not self.lite_edition:
conv1_num_filters = self.blaze_filters[0][0]
conv = self._conv_norm(
input=input,
num_filters=conv1_num_filters,
filter_size=3,
stride=2,
padding=1,
act='relu',
name="conv1")
for k, v in enumerate(self.blaze_filters):
assert len(v) in [2, 3], \
"blaze_filters {} not in [2, 3]"
if len(v) == 2:
conv = self.BlazeBlock(
conv,
v[0],
v[1],
use_5x5kernel=self.use_5x5kernel,
name='blaze_{}'.format(k))
elif len(v) == 3:
conv = self.BlazeBlock(
conv,
v[0],
v[1],
stride=v[2],
use_5x5kernel=self.use_5x5kernel,
name='blaze_{}'.format(k))
layers = []
for k, v in enumerate(self.double_blaze_filters):
assert len(v) in [3, 4], \
"blaze_filters {} not in [3, 4]"
if len(v) == 3:
conv = self.BlazeBlock(
conv,
v[0],
v[1],
double_channels=v[2],
use_5x5kernel=self.use_5x5kernel,
name='double_blaze_{}'.format(k))
elif len(v) == 4:
layers.append(conv)
conv = self.BlazeBlock(
conv,
v[0],
v[1],
double_channels=v[2],
stride=v[3],
use_5x5kernel=self.use_5x5kernel,
name='double_blaze_{}'.format(k))
layers.append(conv)
if not self.with_extra_blocks:
return layers[-1]
return layers[-2], layers[-1]
else:
conv1 = self._conv_norm(
input=input,
num_filters=24,
filter_size=5,
stride=2,
padding=2,
act='relu',
name="conv1")
conv2 = self.Blaze_lite(conv1, 24, 24, 1, 'conv2')
conv3 = self.Blaze_lite(conv2, 24, 28, 1, 'conv3')
conv4 = self.Blaze_lite(conv3, 28, 32, 2, 'conv4')
conv5 = self.Blaze_lite(conv4, 32, 36, 1, 'conv5')
conv6 = self.Blaze_lite(conv5, 36, 42, 1, 'conv6')
conv7 = self.Blaze_lite(conv6, 42, 48, 2, 'conv7')
in_ch = 48
for i in range(5):
conv7 = self.Blaze_lite(conv7, in_ch, in_ch + 8, 1,
'conv{}'.format(8 + i))
in_ch += 8
assert in_ch == 88
conv13 = self.Blaze_lite(conv7, 88, 96, 2, 'conv13')
for i in range(4):
conv13 = self.Blaze_lite(conv13, 96, 96, 1,
'conv{}'.format(14 + i))
return conv7, conv13
def BlazeBlock(self,
input,
in_channels,
out_channels,
double_channels=None,
stride=1,
use_5x5kernel=True,
name=None):
assert stride in [1, 2]
use_pool = not stride == 1
use_double_block = double_channels is not None
act = 'relu' if use_double_block else None
mixed_precision_enabled = mixed_precision_global_state() is not None
if use_5x5kernel:
conv_dw = self._conv_norm(
input=input,
filter_size=5,
num_filters=in_channels,
stride=stride,
padding=2,
num_groups=in_channels,
use_cudnn=mixed_precision_enabled,
name=name + "1_dw")
else:
conv_dw_1 = self._conv_norm(
input=input,
filter_size=3,
num_filters=in_channels,
stride=1,
padding=1,
num_groups=in_channels,
use_cudnn=mixed_precision_enabled,
name=name + "1_dw_1")
conv_dw = self._conv_norm(
input=conv_dw_1,
filter_size=3,
num_filters=in_channels,
stride=stride,
padding=1,
num_groups=in_channels,
use_cudnn=mixed_precision_enabled,
name=name + "1_dw_2")
conv_pw = self._conv_norm(
input=conv_dw,
filter_size=1,
num_filters=out_channels,
stride=1,
padding=0,
act=act,
name=name + "1_sep")
if use_double_block:
if use_5x5kernel:
conv_dw = self._conv_norm(
input=conv_pw,
filter_size=5,
num_filters=out_channels,
stride=1,
padding=2,
use_cudnn=mixed_precision_enabled,
name=name + "2_dw")
else:
conv_dw_1 = self._conv_norm(
input=conv_pw,
filter_size=3,
num_filters=out_channels,
stride=1,
padding=1,
num_groups=out_channels,
use_cudnn=mixed_precision_enabled,
name=name + "2_dw_1")
conv_dw = self._conv_norm(
input=conv_dw_1,
filter_size=3,
num_filters=out_channels,
stride=1,
padding=1,
num_groups=out_channels,
use_cudnn=mixed_precision_enabled,
name=name + "2_dw_2")
conv_pw = self._conv_norm(
input=conv_dw,
filter_size=1,
num_filters=double_channels,
stride=1,
padding=0,
name=name + "2_sep")
# shortcut
if use_pool:
shortcut_channel = double_channels or out_channels
shortcut_pool = self._pooling_block(input, stride, stride)
channel_pad = self._conv_norm(
input=shortcut_pool,
filter_size=1,
num_filters=shortcut_channel,
stride=1,
padding=0,
name="shortcut" + name)
return fluid.layers.elementwise_add(
x=channel_pad, y=conv_pw, act='relu')
return fluid.layers.elementwise_add(x=input, y=conv_pw, act='relu')
def Blaze_lite(self, input, in_channels, out_channels, stride=1, name=None):
assert stride in [1, 2]
use_pool = not stride == 1
use_pad = not in_channels == out_channels
conv_dw = self._conv_norm(
input=input,
filter_size=3,
num_filters=in_channels,
stride=stride,
padding=1,
num_groups=in_channels,
name=name + "_dw")
conv_pw = self._conv_norm(
input=conv_dw,
filter_size=1,
num_filters=out_channels,
stride=1,
padding=0,
name=name + "_sep")
if use_pool:
shortcut_pool = self._pooling_block(input, stride, stride)
if use_pad:
conv_pad = shortcut_pool if use_pool else input
channel_pad = self._conv_norm(
input=conv_pad,
filter_size=1,
num_filters=out_channels,
stride=1,
padding=0,
name="shortcut" + name)
return fluid.layers.elementwise_add(
x=channel_pad, y=conv_pw, act='relu')
return fluid.layers.elementwise_add(x=input, y=conv_pw, act='relu')
def _conv_norm(
self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
parameter_attr = ParamAttr(
learning_rate=0.1,
initializer=fluid.initializer.MSRA(),
name=name + "_weights")
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=parameter_attr,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def _pooling_block(self,
conv,
pool_size,
pool_stride,
pool_padding=0,
ceil_mode=True):
pool = fluid.layers.pool2d(
input=conv,
pool_size=pool_size,
pool_type='max',
pool_stride=pool_stride,
pool_padding=pool_padding,
ceil_mode=ceil_mode)
return pool
......@@ -15,3 +15,4 @@
from .yolo_v3 import YOLOv3
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
from .blazeface import BlazeFace
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from collections import OrderedDict
class BlazeFace:
def __init__(self,
backbone,
min_sizes=[[16., 24.], [32., 48., 64., 80., 96., 128.]],
max_sizes=None,
steps=[8., 16.],
num_classes=2,
use_density_prior_box=False,
densities=[[2, 2], [2, 1, 1, 1, 1, 1]],
nms_threshold=0.3,
nms_topk=5000,
nms_keep_topk=750,
score_threshold=0.01,
nms_eta=1.0,
mode='train',
fixed_input_shape=None):
self.backbone = backbone
self.mode = mode
self.num_classes = num_classes
self.min_sizes = min_sizes
self.max_sizes = max_sizes
self.steps = steps
self.use_density_prior_box = use_density_prior_box
self.densities = densities
self.fixed_input_shape = fixed_input_shape
self.nms_threshold = nms_threshold
self.nms_topk = nms_topk
self.nms_keep_topk = nms_keep_topk
self.score_threshold = score_threshold
self.nms_eta = nms_eta
self.background_label = 0
def _multi_box_head(self,
inputs,
image,
num_classes=2,
use_density_prior_box=False):
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
compile_shape = [0, -1, last_dim]
return fluid.layers.reshape(trans, shape=compile_shape)
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
locs, confs = [], []
boxes, vars = [], []
b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.))
for i, input in enumerate(inputs):
min_size = self.min_sizes[i]
if use_density_prior_box:
densities = self.densities[i]
box, var = fluid.layers.density_prior_box(
input,
image,
densities=densities,
fixed_sizes=min_size,
fixed_ratios=[1.],
clip=False,
offset=0.5,
steps=[self.steps[i]] * 2)
else:
box, var = fluid.layers.prior_box(
input,
image,
min_sizes=min_size,
max_sizes=None,
steps=[self.steps[i]] * 2,
aspect_ratios=[1.],
clip=False,
flip=False,
offset=0.5)
num_boxes = box.shape[2]
box = fluid.layers.reshape(box, shape=[-1, 4])
var = fluid.layers.reshape(var, shape=[-1, 4])
num_loc_output = num_boxes * 4
num_conf_output = num_boxes * num_classes
# get loc
mbox_loc = fluid.layers.conv2d(
input, num_loc_output, 3, 1, 1, bias_attr=b_attr)
loc = permute_and_reshape(mbox_loc, 4)
# get conf
mbox_conf = fluid.layers.conv2d(
input, num_conf_output, 3, 1, 1, bias_attr=b_attr)
conf = permute_and_reshape(mbox_conf, 2)
locs.append(loc)
confs.append(conf)
boxes.append(box)
vars.append(var)
face_mbox_loc = fluid.layers.concat(locs, axis=1)
face_mbox_conf = fluid.layers.concat(confs, axis=1)
prior_boxes = fluid.layers.concat(boxes)
box_vars = fluid.layers.concat(vars)
return face_mbox_loc, face_mbox_conf, prior_boxes, box_vars
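# Prior-count sketch (default min_sizes, use_density_prior_box=False,
# 640x640 input): the stride-8 map (80x80) gets 2 priors per cell and the
# stride-16 map (40x40) gets 6, i.e. 80*80*2 + 40*40*6 = 22400 priors.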
def generate_inputs(self):
inputs = OrderedDict()
if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
inputs['gt_box'] = fluid.data(
dtype='float32', shape=[None, None, 4], lod_level=1, name='gt_box')
inputs['gt_label'] = fluid.data(
dtype='int32', shape=[None, None], lod_level=1, name='gt_label')
inputs['im_size'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_size')
elif self.mode == 'eval':
inputs['gt_box'] = fluid.data(
dtype='float32', shape=[None, None, 4], lod_level=1, name='gt_box')
inputs['gt_label'] = fluid.data(
dtype='int32', shape=[None, None], lod_level=1, name='gt_label')
inputs['is_difficult'] = fluid.data(
dtype='int32', shape=[None, 1], lod_level=1, name='is_difficult')
inputs['im_id'] = fluid.data(
dtype='int32', shape=[None, 1], name='im_id')
elif self.mode == 'test':
inputs['im_size'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_size')
return inputs
def build_net(self, inputs):
image = inputs['image']
if self.mode == 'train':
gt_bbox = inputs['gt_box']
gt_label = inputs['gt_label']
im_size = inputs['im_size']
num_boxes = fluid.layers.shape(gt_bbox)[1]
im_size_wh = fluid.layers.reverse(im_size, axis=1)
whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
whwh = fluid.layers.unsqueeze(whwh, axes=[1])
whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
whwh = fluid.layers.cast(whwh, dtype='float32')
whwh.stop_gradient = True
normalized_box = fluid.layers.elementwise_div(gt_bbox, whwh)
body_feats = self.backbone(image)
locs, confs, box, box_var = self._multi_box_head(
inputs=body_feats,
image=image,
num_classes=self.num_classes,
use_density_prior_box=self.use_density_prior_box)
if self.mode == 'train':
loss = fluid.layers.ssd_loss(
locs,
confs,
normalized_box,
gt_label,
box,
box_var,
overlap_threshold=0.35,
neg_overlap=0.35)
loss = fluid.layers.reduce_sum(loss)
loss.persistable = True
return loss
else:
pred = fluid.layers.detection_output(
locs,
confs,
box,
box_var,
background_label=self.background_label,
nms_threshold=self.nms_threshold,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_keep_topk,
score_threshold=self.score_threshold,
nms_eta=self.nms_eta)
return pred
\ No newline at end of file
......@@ -221,3 +221,242 @@ def segms_horizontal_flip(segms, height, width):
import pycocotools.mask as mask_util
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
resize_width):
num_gt = len(bbox_labels)
# np.random.randint range: [low, high)
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
if num_gt != 0:
norm_xmin = bbox_labels[rand_idx][0]
norm_ymin = bbox_labels[rand_idx][1]
norm_xmax = bbox_labels[rand_idx][2]
norm_ymax = bbox_labels[rand_idx][3]
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
wid = image_width * (norm_xmax - norm_xmin)
hei = image_height * (norm_ymax - norm_ymin)
range_size = 0
area = wid * hei
for scale_ind in range(0, len(scale_array) - 1):
if area > scale_array[scale_ind] ** 2 and area < \
scale_array[scale_ind + 1] ** 2:
range_size = scale_ind + 1
break
if area > scale_array[len(scale_array) - 2]**2:
range_size = len(scale_array) - 2
scale_choose = 0.0
if range_size == 0:
rand_idx_size = 0
else:
# np.random.randint range: [low, high)
rng_rand_size = np.random.randint(0, range_size + 1)
rand_idx_size = rng_rand_size % (range_size + 1)
if rand_idx_size == range_size:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = min(2.0 * scale_array[rand_idx_size],
2 * math.sqrt(wid * hei))
scale_choose = random.uniform(min_resize_val, max_resize_val)
else:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = 2.0 * scale_array[rand_idx_size]
scale_choose = random.uniform(min_resize_val, max_resize_val)
sample_bbox_size = wid * resize_width / scale_choose
w_off_orig = 0.0
h_off_orig = 0.0
if sample_bbox_size < max(image_height, image_width):
if wid <= sample_bbox_size:
w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
xmin)
else:
w_off_orig = np.random.uniform(xmin,
xmin + wid - sample_bbox_size)
if hei <= sample_bbox_size:
h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
ymin)
else:
h_off_orig = np.random.uniform(ymin,
ymin + hei - sample_bbox_size)
else:
w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
h_off_orig = np.random.uniform(image_height - sample_bbox_size, 0.0)
w_off_orig = math.floor(w_off_orig)
h_off_orig = math.floor(h_off_orig)
# Figure out top left coordinates.
w_off = float(w_off_orig / image_width)
h_off = float(h_off_orig / image_height)
sampled_bbox = [
w_off, h_off, w_off + float(sample_bbox_size / image_width),
h_off + float(sample_bbox_size / image_height)
]
return sampled_bbox
else:
return 0
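# Worked example: for a chosen gt box of 40x40 pixels and the default
# scale_array [16, 32, 64, 128], area = 1600 lies between 32**2 and 64**2,
# so range_size = 2; if rand_idx_size = 1 is drawn, scale_choose is uniform
# in [16, 64] and the square crop side is wid * resize_width / scale_choose.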
def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
new_bboxes = []
new_labels = []
new_scores = []
for i, bbox in enumerate(bboxes):
w = float((bbox[2] - bbox[0]) * target_size)
h = float((bbox[3] - bbox[1]) * target_size)
if w * h < float(min_size * min_size):
continue
else:
new_bboxes.append(bbox)
new_labels.append(labels[i])
if scores is not None and scores.size != 0:
new_scores.append(scores[i])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
if sampler[6] == 0 and sampler[7] == 0:
has_jaccard_overlap = False
else:
has_jaccard_overlap = True
if sampler[8] == 0 and sampler[9] == 0:
has_object_coverage = False
else:
has_object_coverage = True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(gt_bboxes)):
object_bbox = [
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
]
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler[6] != 0 and \
overlap < sampler[6]:
continue
if sampler[7] != 0 and \
overlap > sampler[7]:
continue
found = True
if has_object_coverage:
object_coverage = bbox_coverage(object_bbox, sample_bbox)
if sampler[8] != 0 and \
object_coverage < sampler[8]:
continue
if sampler[9] != 0 and \
object_coverage > sampler[9]:
continue
found = True
if found:
return True
return found
def filter_and_process(sample_bbox, bboxes, labels, scores=None):
new_bboxes = []
new_labels = []
new_scores = []
for i in range(len(bboxes)):
new_bbox = [0, 0, 0, 0]
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
if not meet_emit_constraint(obj_bbox, sample_bbox):
continue
if not is_overlap(obj_bbox, sample_bbox):
continue
sample_width = sample_bbox[2] - sample_bbox[0]
sample_height = sample_bbox[3] - sample_bbox[1]
new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width
new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height
new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width
new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height
new_bbox = clip_bbox(new_bbox)
if bbox_area(new_bbox) > 0:
new_bboxes.append(new_bbox)
new_labels.append([labels[i][0]])
if scores is not None:
new_scores.append([scores[i][0]])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def crop_image_sampling(img, sample_bbox, image_width, image_height,
target_size):
# no clipping here
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
w_off = xmin
h_off = ymin
width = xmax - xmin
height = ymax - ymin
cross_xmin = max(0.0, float(w_off))
cross_ymin = max(0.0, float(h_off))
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
cross_width = cross_xmax - cross_xmin
cross_height = cross_ymax - cross_ymin
roi_xmin = 0 if w_off >= 0 else abs(w_off)
roi_ymin = 0 if h_off >= 0 else abs(h_off)
roi_width = cross_width
roi_height = cross_height
roi_y1 = int(roi_ymin)
roi_y2 = int(roi_ymin + roi_height)
roi_x1 = int(roi_xmin)
roi_x2 = int(roi_xmin + roi_width)
cross_y1 = int(cross_ymin)
cross_y2 = int(cross_ymin + cross_height)
cross_x1 = int(cross_xmin)
cross_x2 = int(cross_xmin + cross_width)
sample_img = np.zeros((height, width, 3))
sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
img[cross_y1: cross_y2, cross_x1: cross_x2]
sample_img = cv2.resize(
sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
return sample_img
def generate_sample_bbox_square(sampler, image_width, image_height):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
aspect_ratio = max(aspect_ratio, (scale**2.0))
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
if image_height < image_width:
bbox_width = bbox_height * image_height / image_width
else:
bbox_height = bbox_width * image_width / image_height
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = np.random.uniform(0, xmin_bound)
ymin = np.random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox
\ No newline at end of file
......@@ -503,7 +503,7 @@ class Normalize(DetTransform):
TypeError: Raised when argument types do not meet requirements.
"""
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], is_scale=True):
self.mean = mean
self.std = std
if not (isinstance(self.mean, list) and isinstance(self.std, list)):
......@@ -511,6 +511,7 @@ class Normalize(DetTransform):
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise TypeError('NormalizeImage: std is invalid!')
self.is_scale = is_scale
def __call__(self, im, im_info=None, label_info=None):
"""
......@@ -526,7 +527,7 @@ class Normalize(DetTransform):
"""
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = normalize(im, mean, std)
im = normalize(im, mean, std, self.is_scale)
if label_info is None:
return (im, im_info)
else:
......@@ -558,7 +559,8 @@ class RandomDistort(DetTransform):
saturation_range=0.5,
saturation_prob=0.5,
hue_range=18,
hue_prob=0.5):
hue_prob=0.5,
is_order=False):
self.brightness_range = brightness_range
self.brightness_prob = brightness_prob
self.contrast_range = contrast_range
......@@ -567,6 +569,7 @@ class RandomDistort(DetTransform):
self.saturation_prob = saturation_prob
self.hue_range = hue_range
self.hue_prob = hue_prob
self.is_order = is_order
def __call__(self, im, im_info=None, label_info=None):
"""
......@@ -589,7 +592,8 @@ class RandomDistort(DetTransform):
hue_lower = -self.hue_range
hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue]
random.shuffle(ops)
if not self.is_order:
random.shuffle(ops)
params_dict = {
'brightness': {
'brightness_lower': brightness_lower,
......@@ -767,12 +771,14 @@ class RandomExpand(DetTransform):
ratio (float): Maximum expansion ratio of the image. Defaults to 4.0.
prob (float): Probability of random expansion. Defaults to 0.5.
fill_value (list): Initial fill values (0-255) of the expanded image. Defaults to [123.675, 116.28, 103.53].
filter_bbox (bool): Whether to filter the new boxes. Defaults to False.
"""
def __init__(self,
ratio=4.,
prob=0.5,
fill_value=[123.675, 116.28, 103.53]):
fill_value=[123.675, 116.28, 103.53],
filter_bbox=False):
super(RandomExpand, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01"
self.ratio = ratio
......@@ -782,6 +788,7 @@ class RandomExpand(DetTransform):
if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value)
self.fill_value = fill_value
self.filter_bbox = filter_bbox
def __call__(self, im, im_info=None, label_info=None):
"""
......@@ -831,7 +838,35 @@ class RandomExpand(DetTransform):
im_info['image_shape'] = np.array([h, w]).astype('int32')
if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
if self.filter_bbox:
expand_bbox = [
-x / width, -y / height,
(w - x) / width, (h - y) / height
]
gt_bbox = label_info['gt_bbox']
gt_class = label_info['gt_class']
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / width
gt_bbox[i][1] = gt_bbox[i][1] / height
gt_bbox[i][2] = gt_bbox[i][2] / width
gt_bbox[i][3] = gt_bbox[i][3] / height
if 'gt_score' in label_info:
gt_score = label_info['gt_score']
gt_bbox, gt_class, gt_score = filter_and_process(
expand_bbox, gt_bbox, gt_class, gt_score)
label_info['gt_score'] = gt_score
else:
gt_bbox, gt_class, _ = filter_and_process(
expand_bbox, gt_bbox, gt_class)
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] * w
gt_bbox[i][1] = gt_bbox[i][1] * h
gt_bbox[i][2] = gt_bbox[i][2] * w
gt_bbox[i][3] = gt_bbox[i][3] * h
label_info['gt_bbox'] = gt_bbox
label_info['gt_class'] = gt_class
else:
label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
label_info['gt_poly'] = expand_segms(label_info['gt_poly'], x, y,
height, width, expand_ratio)
......@@ -990,6 +1025,195 @@ class RandomCrop(DetTransform):
return (im, im_info, label_info)
return (im, im_info, label_info)
class CropImageWithDataAchorSampling(DetTransform):
def __init__(self,
anchor_sampler=[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]],
batch_sampler=[[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]],
target_size=None,
das_anchor_scales=[16, 32, 64, 128],
sampling_prob=0.5,
min_size=8.,
avoid_no_bbox=True):
"""裁剪图像并修改对应标注框。
1. 缩放图像的高和宽。
2. 根据随机采样裁剪图像。
3. 缩放标注框。
4. 确认新的标注框是否在新的图像内。
Args:
anchor_sampler (list): 根据anchor采样的裁剪参数列表所组成的集合。
batch_sampler (list): 裁剪参数列表所组成的集合。
- max sample (int):满足当前组合的裁剪区域的个数上限。
- max trial (int): 查找满足当前组合的次数。
- min scale (float): 裁剪面积相对原面积,每条边缩短比例的最小限制。
- max scale (float): 裁剪面积相对原面积,每条边缩短比例的最大限制。
- min aspect ratio (float): 裁剪后短边缩放比例的最小限制。
- max aspect ratio (float): 裁剪后短边缩放比例的最大限制。
- min overlap (float): 真实标注框与裁剪图像重叠面积的最小限制。
- max overlap (float): 真实标注框与裁剪图像重叠面积的最大限制。
e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
或者
[[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap, min coverage, max coverage]
target_size (bool): target image size.
das_anchor_scales (list[float]): anchor采样的尺度列表。默认为[16, 32, 64, 128]。
min_size (float): 采样的标注框的最小面积为min_size*min_size。默认为8.。
avoid_no_bbox (bool): 裁剪后的图如果无标注框是否抛弃。默认为True。
"""
self.anchor_sampler = anchor_sampler
self.batch_sampler = batch_sampler
self.target_size = target_size
self.sampling_prob = sampling_prob
self.min_size = min_size
self.avoid_no_bbox = avoid_no_bbox
self.das_anchor_scales = np.array(das_anchor_scales)
def __call__(self, im, im_info=None, label_info=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict, 可选): 存储与图像相关的信息。
label_info (dict, 可选): 存储与标注框相关的信息。
Returns:
tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
存储与标注框相关信息的字典。
其中,im_info更新字段为:
- image_shape (np.ndarray): 扩裁剪的图像高、宽二者组成的np.ndarray,形状为(2,)。
label_info更新字段为:
- gt_bbox (np.ndarray): 随机裁剪后真实标注框坐标,形状为(n, 4),
其中n代表真实标注框的个数。
- gt_class (np.ndarray): 随机裁剪后每个真实标注框对应的类别序号,形状为(n, 1),
其中n代表真实标注框的个数。
- gt_score (np.ndarray): 随机裁剪后每个真实标注框对应的混合得分,形状为(n, 1),
其中n代表真实标注框的个数。
Raises:
TypeError: 形参数据类型不满足需求。
"""
image_shape = im_info['image_shape']
image_width = image_shape[1]
image_height = image_shape[0]
gt_bbox = label_info['gt_bbox']
gt_bbox_tmp = gt_bbox.copy()
for i in range(gt_bbox_tmp.shape[0]):
gt_bbox_tmp[i][0] = gt_bbox[i][0] / image_width
gt_bbox_tmp[i][1] = gt_bbox[i][1] / image_height
gt_bbox_tmp[i][2] = gt_bbox[i][2] / image_width
gt_bbox_tmp[i][3] = gt_bbox[i][3] / image_height
gt_class = label_info['gt_class']
gt_score = None
if 'gt_score' in label_info:
gt_score = label_info['gt_score']
sampled_bbox = []
gt_bbox_tmp = gt_bbox_tmp.tolist()
prob = np.random.uniform(0., 1.)
if prob > self.sampling_prob: # anchor sampling
assert self.anchor_sampler
for sampler in self.anchor_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = data_anchor_sampling(
gt_bbox_tmp, image_width, image_height,
self.das_anchor_scales, self.target_size)
if sample_bbox == 0:
break
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
gt_bbox_tmp):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox_tmp, gt_class, gt_score)
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
im = crop_image_sampling(im, sample_bbox, image_width,
image_height, self.target_size)
for i in range(crop_bbox.shape[0]):
crop_bbox[i][0] = crop_bbox[i][0] * im.shape[1]
crop_bbox[i][1] = crop_bbox[i][1] * im.shape[0]
crop_bbox[i][2] = crop_bbox[i][2] * im.shape[1]
crop_bbox[i][3] = crop_bbox[i][3] * im.shape[0]
label_info['gt_bbox'] = crop_bbox
label_info['gt_class'] = crop_class
label_info['gt_score'] = crop_score
im_info['image_shape'] = np.array(
[im.shape[0],
im.shape[1]]).astype('int32')
return (im, im_info, label_info)
return (im, im_info, label_info)
else:
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox_square(
sampler, image_width, image_height)
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
gt_bbox_tmp):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox_tmp, gt_class, gt_score)
# sample bboxes according to their area
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
im = im[ymin:ymax, xmin:xmax]
for i in range(crop_bbox.shape[0]):
crop_bbox[i][0] = crop_bbox[i][0] * (xmax - xmin)
crop_bbox[i][1] = crop_bbox[i][1] * (ymax - ymin)
crop_bbox[i][2] = crop_bbox[i][2] * (xmax - xmin)
crop_bbox[i][3] = crop_bbox[i][3] * (ymax - ymin)
label_info['gt_bbox'] = crop_bbox
label_info['gt_class'] = crop_class
label_info['gt_score'] = crop_score
im_info['image_shape'] = np.array(
[im.shape[0],
im.shape[1]]).astype('int32')
return (im, im_info, label_info)
return (im, im_info, label_info)
class ArrangeFasterRCNN(DetTransform):
......@@ -1238,6 +1462,72 @@ class ArrangeYOLOv3(DetTransform):
im_shape = im_info['image_shape']
outputs = (im, im_shape)
return outputs
class ArrangeBlazeFace(DetTransform):
"""获取ArrangeBlazeFace模型训练/验证/预测所需信息。
Args:
mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
Raises:
ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
"""
def __init__(self, mode=None):
if mode not in ['train', 'eval', 'test', 'quant']:
raise ValueError(
"mode must be in ['train', 'eval', 'test', 'quant']!")
self.mode = mode
def __call__(self, im, im_info=None, label_info=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict, 可选): 存储与图像相关的信息。
label_info (dict, 可选): 存储与标注框相关的信息。
Returns:
tuple: 当mode为'train'时,返回(im, gt_bbox, gt_class, im_shape),分别对应
图像np.ndarray数据、真实标注框、真实标注框对应的类别、图像大小信息;
当mode为'eval'时,返回(im, im_id),分别对应图像np.ndarray数据、图像id;
当mode为'test'或'quant'时,返回(im, im_shape),分别对应图像np.ndarray数据、图像大小信息。
Raises:
TypeError: 形参数据类型不满足需求。
ValueError: 数据长度不匹配。
"""
im = permute(im, True)
if self.mode == 'train':
if im_info is None or label_info is None:
raise TypeError(
'Cannot do ArrangeBlazeFace! ' +
'Because the im_info and label_info can not be None!')
if len(label_info['gt_bbox']) != len(label_info['gt_class']):
raise ValueError("gt num mismatch: bbox and class.")
outputs = (im, label_info['gt_bbox'], label_info['gt_class'], im_info['image_shape'])
elif self.mode == 'eval':
if im_info is None or label_info is None:
raise TypeError(
'Cannot do ArrangeBlazeFace! ' +
'Because the im_info and label_info can not be None!')
gt_bbox = label_info['gt_bbox']
im_shape = im_info['image_shape']
im_height = im_shape[0]
im_width = im_shape[1]
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / im_width
gt_bbox[i][1] = gt_bbox[i][1] / im_height
gt_bbox[i][2] = gt_bbox[i][2] / im_width
gt_bbox[i][3] = gt_bbox[i][3] / im_height
outputs = (im, gt_bbox, label_info['gt_class'],
label_info['difficult'], im_info['im_id'])
else:
if im_info is None:
raise TypeError('Cannot do ArrangeBlazeFace! ' +
'Because the im_info can not be None!')
outputs = (im, im_info['image_shape'])
return outputs
class ComposedRCNNTransforms(Compose):
......
......@@ -18,6 +18,7 @@ from . import cv
FasterRCNN = cv.models.FasterRCNN
YOLOv3 = cv.models.YOLOv3
MaskRCNN = cv.models.MaskRCNN
BlazeFace = cv.models.BlazeFace
transforms = cv.transforms.det_transforms
visualize = cv.models.utils.visualize.visualize_detection
draw_pr_curve = cv.models.utils.visualize.draw_pr_curve