提交 91a601e2 编写于 作者: S sunyanfang01

add yolo with iou aware loss

上级 948032a7
......@@ -114,7 +114,7 @@ def multithread_reader(mapper,
while not isinstance(sample, EndSignal):
batch_data.append(sample)
if len(batch_data) == batch_size:
batch_data = GenerateMiniBatch(batch_data)
batch_data = GenerateMiniBatch(batch_data, mapper)
yield batch_data
batch_data = []
sample = out_queue.get()
......@@ -126,11 +126,11 @@ def multithread_reader(mapper,
else:
batch_data.append(sample)
if len(batch_data) == batch_size:
batch_data = GenerateMiniBatch(batch_data)
batch_data = GenerateMiniBatch(batch_data, mapper)
yield batch_data
batch_data = []
if not drop_last and len(batch_data) != 0:
batch_data = GenerateMiniBatch(batch_data)
batch_data = GenerateMiniBatch(batch_data, mapper)
yield batch_data
batch_data = []
......@@ -187,18 +187,21 @@ def multiprocess_reader(mapper,
else:
batch_data.append(sample)
if len(batch_data) == batch_size:
batch_data = GenerateMiniBatch(batch_data)
batch_data = GenerateMiniBatch(batch_data, mapper)
yield batch_data
batch_data = []
if len(batch_data) != 0 and not drop_last:
batch_data = GenerateMiniBatch(batch_data)
batch_data = GenerateMiniBatch(batch_data, mapper)
yield batch_data
batch_data = []
return queue_reader
def GenerateMiniBatch(batch_data):
def GenerateMiniBatch(batch_data, mapper):
if mapper.batch_transforms is not None:
for op in mapper.batch_transforms:
batch_data = op(batch_data)
if len(batch_data) == 1:
return batch_data
width = [data[0].shape[2] for data in batch_data]
......@@ -209,8 +212,8 @@ def GenerateMiniBatch(batch_data):
padding_batch = []
for data in batch_data:
im_c, im_h, im_w = data[0].shape[:]
padding_im = np.zeros(
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im = np.zeros((im_c, max_shape[1], max_shape[2]),
dtype=np.float32)
padding_im[:, :im_h, :im_w] = data[0]
padding_batch.append((padding_im, ) + data[1:])
return padding_batch
......@@ -226,8 +229,8 @@ class Dataset:
if num_workers == 'auto':
import multiprocessing as mp
num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
if platform.platform().startswith("Darwin") or platform.platform(
).startswith("Windows"):
if platform.platform().startswith(
"Darwin") or platform.platform().startswith("Windows"):
parallel_method = 'thread'
if transforms is None:
raise Exception("transform should be defined.")
......
......@@ -56,20 +56,11 @@ image_pretrain = {
'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
'ShuffleNetV2':
'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
'HRNet_W18':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
'HRNet_W30':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
'HRNet_W32':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
'HRNet_W40':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
'HRNet_W48':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
'HRNet_W60':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
'HRNet_W64':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
}
obj365_pretrain = {
'ResNet50_vd_dcn_db_obj365':
'https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar',
}
coco_pretrain = {
......@@ -117,6 +108,18 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
raise Exception(
"Unexpected error, please make sure paddlehub >= 1.6.2")
return osp.join(new_save_dir, backbone)
elif flag == 'Object365':
new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'):
new_save_dir = paddlex.pretrain_dir
if backbone == 'ResNet50_vd':
backbone = 'ResNet50_vd_dcn_db_obj365'
assert backbone in obj365_pretrain, "There is not Object365 pretrain weights for {}, you may try ImageNet.".format(
backbone)
url = obj365_pretrain[backbone]
fname = osp.split(url)[-1].split('.')[0]
paddlex.utils.download_and_decompress(url, path=new_save_dir)
return osp.join(new_save_dir, fname)
elif flag == 'COCO':
new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'):
......@@ -144,5 +147,5 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
return osp.join(new_save_dir, backbone)
else:
raise Exception(
"pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
"pretrain_weights need to be defined as directory path or `IMAGENET` or `Object365` or 'COCO' (download pretrain weights automatically)."
)
......@@ -58,13 +58,18 @@ class YOLOv3(BaseAPI):
nms_keep_topk=100,
nms_iou_threshold=0.45,
label_smooth=False,
use_iou_loss=False,
use_iou_aware_loss=False,
iou_aware_factor=0.4,
use_drop_block=False,
use_dcn_v2=False,
train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
]):
self.init_params = locals()
super(YOLOv3, self).__init__('detector')
backbones = [
'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large', 'ResNet50_vd'
]
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
......@@ -79,8 +84,18 @@ class YOLOv3(BaseAPI):
self.nms_iou_threshold = nms_iou_threshold
self.label_smooth = label_smooth
self.sync_bn = True
self.use_iou_loss = use_iou_loss
self.use_iou_aware_loss = use_iou_aware_loss
self.iou_aware_factor = iou_aware_factor
self.use_drop_block = use_drop_block
self.use_dcn_v2 = use_dcn_v2
self.train_random_shapes = train_random_shapes
self.fixed_input_shape = None
if self.anchors is None:
self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
if self.anchor_masks is None:
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
def _get_backbone(self, backbone_name):
if backbone_name == 'DarkNet53':
......@@ -93,6 +108,16 @@ class YOLOv3(BaseAPI):
norm_decay=0.,
feature_maps=[3, 4, 5],
freeze_at=0)
elif backbone_name == 'ResNet50_vd':
backbone = paddlex.cv.nets.ResNet(
norm_type='sync_bn',
layers=50,
variant='d',
freeze_norm=False,
norm_decay=0.,
feature_maps=[3, 4, 5],
freeze_at=0,
dcn_v2_stages=[5] if self.use_dcn_v2 else [])
elif backbone_name == 'MobileNetV1':
backbone = paddlex.cv.nets.MobileNetV1(norm_type='sync_bn')
elif backbone_name.startswith('MobileNetV3'):
......@@ -115,7 +140,12 @@ class YOLOv3(BaseAPI):
nms_keep_topk=self.nms_keep_topk,
nms_iou_threshold=self.nms_iou_threshold,
train_random_shapes=self.train_random_shapes,
fixed_input_shape=self.fixed_input_shape)
fixed_input_shape=self.fixed_input_shape,
use_iou_loss=self.use_iou_loss,
use_iou_aware_loss=self.use_iou_aware_loss,
iou_aware_factor=self.iou_aware_factor,
use_drop_block=self.use_drop_block,
batch_size=self.train_batch_size if hasattr(self, 'train_batch_size') else 8)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)])
......@@ -217,7 +247,22 @@ class YOLOv3(BaseAPI):
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric
self.train_batch_size = train_batch_size
self.labels = train_dataset.labels
if self.use_iou_loss or self.use_iou_aware_loss:
if self.train_random_shapes is None or len(self.train_random_shapes) == 0:
for transform in train_dataset.transforms.transforms:
if isinstance(transform, paddlex.det.transforms.Resize):
self.train_random_shapes = [transform.target_size]
break
train_dataset.transforms.batch_transforms = []
reshape_bt = paddlex.det.transforms.RandomShape
train_dataset.transforms.batch_transforms.append(reshape_bt(
random_shapes=self.train_random_shapes))
iou_bt = paddlex.det.transforms.GenerateYoloTarget
train_dataset.transforms.batch_transforms.append(iou_bt(anchors=self.anchors,
anchor_masks=self.anchor_masks,
num_classes=self.num_classes))
# 构建训练网络
if optimizer is None:
# 构建默认的优化策略
......@@ -306,10 +351,11 @@ class YOLOv3(BaseAPI):
images = np.array([d[0] for d in data])
im_sizes = np.array([d[1] for d in data])
feed_data = {'image': images, 'im_size': im_sizes}
outputs = self.exe.run(self.test_prog,
feed=[feed_data],
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
outputs = self.exe.run(
self.test_prog,
feed=[feed_data],
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
res = {
'bbox': (np.array(outputs[0]),
outputs[0].recursive_sequence_lengths())
......@@ -325,13 +371,13 @@ class YOLOv3(BaseAPI):
res['gt_label'] = (res_gt_label, [])
res['is_difficult'] = (res_is_difficult, [])
results.append(res)
logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
1, total_steps))
logging.debug("[EVAL] Epoch={}, Step={}/{}".format(
epoch_id, step + 1, total_steps))
box_ap_stats, eval_details = eval_results(
results, metric, eval_dataset.coco_gt, with_background=False)
evaluate_metrics = OrderedDict(
zip(['bbox_mmap'
if metric == 'COCO' else 'bbox_map'], box_ap_stats))
zip(['bbox_mmap' if metric == 'COCO' else 'bbox_map'],
box_ap_stats))
if return_details:
return evaluate_metrics, eval_details
return evaluate_metrics
......@@ -345,8 +391,7 @@ class YOLOv3(BaseAPI):
Returns:
list: 预测结果列表,每个预测结果由预测框类别标签、
预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
预测框得分组成。
预测框类别名称、预测框坐标、预测框得分组成。
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
......@@ -359,11 +404,14 @@ class YOLOv3(BaseAPI):
im, im_size = self.test_transforms(img_file)
im = np.expand_dims(im, axis=0)
im_size = np.expand_dims(im_size, axis=0)
outputs = self.exe.run(self.test_prog,
feed={'image': im,
'im_size': im_size},
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
outputs = self.exe.run(
self.test_prog,
feed={
'image': im,
'im_size': im_size
},
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs)
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
def _split_ioup(output, an_num, num_classes):
    """
    Separate the predicted-IoU channels from the rest of a head output.

    The first ``an_num`` channels along axis 1 are sliced off and passed
    through a sigmoid; channels ``an_num`` up to
    ``an_num * (num_classes + 6)`` are returned unchanged.

    Returns:
        tuple: (sigmoid of the IoU channels, remaining channels)
    """
    channel_end = an_num * (num_classes + 6)
    iou_logits = fluid.layers.slice(
        output, axes=[1], starts=[0], ends=[an_num])
    iou_pred = fluid.layers.sigmoid(iou_logits)
    remainder = fluid.layers.slice(
        output, axes=[1], starts=[an_num], ends=[channel_end])
    return iou_pred, remainder
def _de_sigmoid(x, eps=1e-7):
    """
    Compute ``-log(1 / x - 1)`` for a tensor ``x``.

    Both ``x`` and ``1 / x - 1`` are clipped to ``[eps, 1 / eps]`` before
    the logarithm is taken.
    """
    upper = 1 / eps
    clipped = fluid.layers.clip(x, eps, upper)
    ones = fluid.layers.fill_constant(
        shape=[1, 1, 1, 1], dtype=clipped.dtype, value=1.)
    inv_odds = fluid.layers.clip((ones / clipped - 1.0), eps, upper)
    return -fluid.layers.log(inv_odds)
def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
    """
    Rebuild the head output with an IoU-aware objectness channel.

    For every anchor the objectness channel is replaced by
    ``_de_sigmoid(sigmoid(obj) ** (1 - iou_aware_factor)
    * ioup ** iou_aware_factor)``; the four box channels and the
    ``num_classes`` class channels are copied through. All pieces are
    concatenated back along axis 1 in the original per-anchor order.
    """

    def _channels(tensor, begin, end):
        # Slice ``tensor`` along the channel axis (axis 1).
        return fluid.layers.slice(
            tensor, axes=[1], starts=[begin], ends=[end])

    per_anchor = output.shape[1] // an_num
    pieces = []
    for anchor_idx in range(an_num):
        base = per_anchor * anchor_idx
        # box channels (4 of them) are passed through unchanged
        pieces.append(_channels(output, base + 0, base + 4))

        objectness = fluid.layers.sigmoid(
            _channels(output, base + 4, base + 5))
        iou_pred = _channels(ioup, anchor_idx, anchor_idx + 1)
        fused = fluid.layers.pow(objectness, (
            1 - iou_aware_factor)) * fluid.layers.pow(iou_pred,
                                                      iou_aware_factor)
        pieces.append(_de_sigmoid(fused))

        # class channels are passed through unchanged
        pieces.append(_channels(output, base + 5, base + 5 + num_classes))

    return fluid.layers.concat(pieces, axis=1)
def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
    """
    Split the predicted-IoU channels off ``output`` and fold them into the
    objectness channels, returning the re-assembled tensor.
    """
    iou_pred, raw_output = _split_ioup(output, an_num, num_classes)
    return _postprocess_output(iou_pred, raw_output, an_num, num_classes,
                               iou_aware_factor)
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid
from .iou_loss import IouLoss
class IouAwareLoss(IouLoss):
    """
    iou aware loss, see https://arxiv.org/abs/1912.05992

    Args:
        loss_weight (float): iou aware loss weight, default is 1.0
        max_height (int): max height of input to support random shape input
        max_width (int): max width of input to support random shape input
    """

    def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
        super(IouAwareLoss, self).__init__(
            loss_weight=loss_weight,
            max_height=max_height,
            max_width=max_width)

    def __call__(self,
                 ioup,
                 x,
                 y,
                 w,
                 h,
                 tx,
                 ty,
                 tw,
                 th,
                 anchors,
                 downsample_ratio,
                 batch_size,
                 eps=1.e-10):
        '''
        Compute the weighted soft-label cross entropy between the predicted
        IoU and the IoU of the decoded prediction / target boxes.

        Args:
            ioup ([Variables]): the predicted iou
            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
            anchors ([float]): list of anchors for current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            eps (float): small constant that keeps the denominator non-zero
        '''
        pred_boxes = self._bbox_transform(x, y, w, h, anchors,
                                          downsample_ratio, batch_size, False)
        gt_boxes = self._bbox_transform(tx, ty, tw, th, anchors,
                                        downsample_ratio, batch_size, True)

        # The measured IoU is used purely as a label, so it is frozen.
        iou_target = self._iou(pred_boxes, gt_boxes, ioup, eps)
        iou_target.stop_gradient = True

        cross_entropy = fluid.layers.cross_entropy(
            ioup, iou_target, soft_label=True)
        return cross_entropy * self._loss_weight
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid
class IouLoss(object):
    """
    iou loss, see https://arxiv.org/abs/1908.03851
    loss = 1.0 - iou * iou

    Args:
        loss_weight (float): iou loss weight, default is 2.5
        max_height (int): max height of input to support random shape input
        max_width (int): max width of input to support random shape input
        ciou_term (bool): whether to add ciou_term
        loss_square (bool): whether to square the iou term
    """

    def __init__(self,
                 loss_weight=2.5,
                 max_height=608,
                 max_width=608,
                 ciou_term=False,
                 loss_square=True):
        self._loss_weight = loss_weight
        self._MAX_HI = max_height
        self._MAX_WI = max_width
        self.ciou_term = ciou_term
        self.loss_square = loss_square

    def __call__(self,
                 x,
                 y,
                 w,
                 h,
                 tx,
                 ty,
                 tw,
                 th,
                 anchors,
                 downsample_ratio,
                 batch_size,
                 ioup=None,
                 eps=1.e-10):
        '''
        Compute the weighted IoU loss between decoded predictions and targets.

        Args:
            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
            anchors ([float]): list of anchors for current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            ioup: forwarded to ``_iou``, which does not use it
            eps (float): small constant that keeps the denominator non-zero

        Returns:
            ``(1 - iou * iou) * loss_weight`` when ``loss_square`` is set,
            otherwise ``(1 - iou) * loss_weight``.
        '''
        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
                                    batch_size, False)
        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
                                  batch_size, True)
        iouk = self._iou(pred, gt, ioup, eps)
        if self.loss_square:
            loss_iou = 1. - iouk * iouk
        else:
            loss_iou = 1. - iouk
        loss_iou = loss_iou * self._loss_weight

        return loss_iou

    def _iou(self, pred, gt, ioup=None, eps=1.e-10):
        """
        Element-wise IoU of two sets of corner boxes ``(x1, y1, x2, y2)``.

        When ``self.ciou_term`` is set, the value returned by
        ``get_ciou_term`` is subtracted from the IoU. ``ioup`` is accepted
        but not used here.
        """
        x1, y1, x2, y2 = pred
        x1g, y1g, x2g, y2g = gt
        # Force x2 >= x1 and y2 >= y1 for the predicted box.
        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        # Zero the intersection wherever the boxes do not overlap.
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + eps
        iouk = intsctk / unionk
        if self.ciou_term:
            # NOTE: the un-clamped ``pred`` tuple is passed on, as before.
            ciou = self.get_ciou_term(pred, gt, iouk, eps)
            iouk = iouk - ciou
        return iouk

    def get_ciou_term(self, pred, gt, iouk, eps):
        """
        Return the sum of the centre-distance (DIoU) term and the
        aspect-ratio (CIoU) term for the given corner boxes.
        """
        x1, y1, x2, y2 = pred
        x1g, y1g, x2g, y2g = gt

        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        # Add 1 where the width/height is exactly 0 so w / h stays finite.
        w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
        h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')

        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2
        wg = x2g - x1g
        hg = y2g - y1g

        # A or B
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        # DIOU term
        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
        diou_term = (dist_intersection + eps) / (dist_union + eps)

        # CIOU term
        ar_gt = wg / hg
        ar_pred = w / h
        arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
        ar_loss = 4. / np.pi / np.pi * arctan * arctan
        alpha = ar_loss / (1 - iouk + ar_loss + eps)
        alpha.stop_gradient = True
        ciou_term = alpha * ar_loss

        return diou_term + ciou_term

    def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
                        batch_size, is_gt):
        """
        Decode encoded boxes into corner coordinates ``(x1, y1, x2, y2)``.

        Centres are ``(offset + grid index) / feature-map size`` and sizes
        are ``exp(d) * anchor / (feature-map size * downsample_ratio)``.
        For predictions (``is_gt`` False) the centre offsets go through a
        sigmoid first; for targets every produced tensor is frozen.

        The constant grid / anchor tensors are built for the maximum input
        size and cropped to the shape of ``dcx``.
        """
        grid_x = int(self._MAX_WI / downsample_ratio)
        grid_y = int(self._MAX_HI / downsample_ratio)
        an_num = len(anchors) // 2

        shape_fmp = fluid.layers.shape(dcx)
        shape_fmp.stop_gradient = True
        # generate the grid_w x grid_h center of feature map
        idx_i = np.array([list(range(grid_x))])
        idx_j = np.array([list(range(grid_y))]).transpose()
        gi_np = np.repeat(idx_i, grid_y, axis=0)
        gi_np = np.reshape(gi_np, [1, 1, grid_y, grid_x])
        gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
        gj_np = np.repeat(idx_j, grid_x, axis=1)
        gj_np = np.reshape(gj_np, [1, 1, grid_y, grid_x])
        gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
        gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
        gi = fluid.layers.crop(x=gi_max, shape=dcx)
        gi.stop_gradient = True
        gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
        gj = fluid.layers.crop(x=gj_max, shape=dcx)
        gj.stop_gradient = True

        grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
        grid_x_act.stop_gradient = True
        grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
        grid_y_act.stop_gradient = True
        if is_gt:
            # Target centres carry no gradient. The previous code assigned
            # ``.gradient = True``, an attribute nothing else in this file
            # reads; ``stop_gradient`` is the flag used everywhere else here.
            cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
            cx.stop_gradient = True
            cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
            cy.stop_gradient = True
        else:
            dcx_sig = fluid.layers.sigmoid(dcx)
            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
            dcy_sig = fluid.layers.sigmoid(dcy)
            cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act

        # ``anchors`` is a flat [w0, h0, w1, h1, ...] list: even slots are
        # widths, odd slots are heights.
        anchor_w_np = np.array(anchors[0::2])
        anchor_w_np = np.reshape(anchor_w_np, [1, an_num, 1, 1])
        anchor_w_np = np.tile(
            anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
        anchor_w_max = self._create_tensor_from_numpy(
            anchor_w_np.astype(np.float32))
        anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
        anchor_w.stop_gradient = True

        anchor_h_np = np.array(anchors[1::2])
        anchor_h_np = np.reshape(anchor_h_np, [1, an_num, 1, 1])
        anchor_h_np = np.tile(
            anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
        anchor_h_max = self._create_tensor_from_numpy(
            anchor_h_np.astype(np.float32))
        anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
        anchor_h.stop_gradient = True

        # e^tw e^th
        exp_dw = fluid.layers.exp(dw)
        exp_dh = fluid.layers.exp(dh)
        pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
            (grid_x_act * downsample_ratio)
        ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
            (grid_y_act * downsample_ratio)
        if is_gt:
            exp_dw.stop_gradient = True
            exp_dh.stop_gradient = True
            pw.stop_gradient = True
            ph.stop_gradient = True

        x1 = cx - 0.5 * pw
        y1 = cy - 0.5 * ph
        x2 = cx + 0.5 * pw
        y2 = cy + 0.5 * ph
        if is_gt:
            x1.stop_gradient = True
            y1.stop_gradient = True
            x2.stop_gradient = True
            y2.stop_gradient = True

        return x1, y1, x2, y2

    def _create_tensor_from_numpy(self, numpy_array):
        """
        Create a frozen parameter initialised with ``numpy_array``
        (same shape and dtype).
        """
        paddle_array = fluid.layers.create_parameter(
            attr=ParamAttr(),
            shape=numpy_array.shape,
            dtype=numpy_array.dtype,
            default_initializer=NumpyArrayInitializer(numpy_array))
        paddle_array.stop_gradient = True
        return paddle_array
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from .iou_loss import IouLoss
from .iou_aware_loss import IouAwareLoss
class YOLOv3Loss(object):
    """
    Combined loss for YOLOv3 network

    Args:
        batch_size (int): training batch size
        ignore_thresh (float): threshold to ignore confidence loss
        label_smooth (bool): whether to use label smoothing
        iou_loss_weight (float|None): weight of the IoU loss term; the term
            is skipped when None
        iou_aware_loss_weight (float|None): weight of the IoU-aware loss
            term; the term is skipped when None
        scale_x_y (float): accepted but not used by this class
        match_score (bool): when True, negatives whose best class
            probability exceeds 0.25 are also excluded from the objectness
            loss
    """

    def __init__(self,
                 batch_size=8,
                 ignore_thresh=0.7,
                 label_smooth=True,
                 iou_loss_weight=None,
                 iou_aware_loss_weight=None,
                 scale_x_y=1.,
                 match_score=False):
        self._batch_size = batch_size
        self._ignore_thresh = ignore_thresh
        self._label_smooth = label_smooth
        self._iou_loss_weight = iou_loss_weight
        self._iou_aware_loss_weight = iou_aware_loss_weight
        self.match_score = match_score

    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
                 anchor_masks, mask_anchors, num_classes, prefix_name,
                 max_size):
        """
        Return the total loss as a single Variable.

        When ``targets`` is non-empty the fine-grained loss is used;
        otherwise ``fluid.layers.yolov3_loss`` is applied to each output,
        with the downsample ratio starting at 32 and halving per output.
        """
        if len(targets) != 0:
            losses_all = self._get_fine_grained_loss(
                outputs, targets, gt_box, self._batch_size, num_classes,
                mask_anchors, self._ignore_thresh, max_size)
        else:
            losses = []
            downsample = 32
            for i, output in enumerate(outputs):
                anchor_mask = anchor_masks[i]
                loss = fluid.layers.yolov3_loss(
                    x=output,
                    gt_box=gt_box,
                    gt_label=gt_label,
                    gt_score=gt_score,
                    anchors=anchors,
                    anchor_mask=anchor_mask,
                    class_num=num_classes,
                    ignore_thresh=self._ignore_thresh,
                    downsample_ratio=downsample,
                    use_label_smooth=self._label_smooth,
                    name=prefix_name + "yolo_loss" + str(i))
                losses.append(fluid.layers.reduce_mean(loss))
                downsample //= 2
            losses_all = {'loss': sum(losses)}
        total_loss = fluid.layers.sum(list(losses_all.values()))
        return total_loss

    def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size,
                               num_classes, mask_anchors, ignore_thresh,
                               max_size):
        """
        Calculate fine grained YOLOv3 loss

        Args:
            outputs ([Variables]): List of Variables, output of backbone stages
            targets ([Variables]): List of Variables, The targets for yolo
                                   loss calculation.
            gt_box (Variable): The ground-truth bounding boxes.
            batch_size (int): The training batch size. Not read here;
                                   ``self._batch_size`` is used instead.
            num_classes (int): class num of dataset
            mask_anchors ([[float]]): list of anchors in each output layer
            ignore_thresh (float): prediction bbox overlap any gt_box greater
                                   than ignore_thresh, objectness loss will
                                   be ignored. Not read here;
                                   ``self._ignore_thresh`` is used instead.
            max_size (int): passed as both max height and max width to the
                                   IoU / IoU-aware loss objects.

        Returns:
            Type: dict
                loss_xy (Variable): YOLOv3 (x, y) coordinates loss
                loss_wh (Variable): YOLOv3 (w, h) coordinates loss
                loss_obj (Variable): YOLOv3 objectness score loss
                loss_cls (Variable): YOLOv3 classification loss
                loss_iou (Variable): only when iou_loss_weight is set
                loss_iou_aware (Variable): only when iou_aware_loss_weight
                                   is set
        """

        assert len(outputs) == len(targets), \
            "YOLOv3 output layer number not equal target number"

        use_iou_loss = self._iou_loss_weight is not None
        use_iou_aware_loss = self._iou_aware_loss_weight is not None

        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
        loss_ious = []
        loss_iou_awares = []

        downsample = 32
        for i, (output, target,
                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
            an_num = len(anchors) // 2
            if use_iou_aware_loss:
                ioup, output = self._split_ioup(output, an_num, num_classes)
            x, y, w, h, obj, cls = self._split_output(output, an_num,
                                                      num_classes)
            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)

            tscale_tobj = tscale * tobj

            loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
                x, tx) * tscale_tobj
            loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
            loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
                y, ty) * tscale_tobj
            loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
            # NOTE: we refined loss function of (w, h) as L1Loss
            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
            loss_h = fluid.layers.abs(h - th) * tscale_tobj
            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
            if use_iou_loss:
                iou_loss_obj = IouLoss(self._iou_loss_weight, max_size,
                                       max_size)
                loss_iou = iou_loss_obj(x, y, w, h, tx, ty, tw, th, anchors,
                                        downsample, self._batch_size)
                loss_iou = loss_iou * tscale_tobj
                loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
                loss_ious.append(fluid.layers.reduce_mean(loss_iou))

            if use_iou_aware_loss:
                iou_aware_loss_obj = IouAwareLoss(
                    self._iou_aware_loss_weight, max_size, max_size)
                loss_iou_aware = iou_aware_loss_obj(
                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
                    self._batch_size)
                loss_iou_aware = loss_iou_aware * tobj
                loss_iou_aware = fluid.layers.reduce_sum(
                    loss_iou_aware, dim=[1, 2, 3])
                loss_iou_awares.append(
                    fluid.layers.reduce_mean(loss_iou_aware))

            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                output, obj, tobj, gt_box, self._batch_size, anchors,
                num_classes, downsample, self._ignore_thresh)

            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
                                                                      tcls)
            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
            # NOTE(review): no ``dim`` is given, so this sums over every
            # axis including the batch axis, unlike the other terms which
            # use dim=[1, 2, 3] -- confirm this is intended.
            loss_cls = fluid.layers.reduce_sum(loss_cls)

            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
            loss_objs.append(
                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
            loss_clss.append(fluid.layers.reduce_mean(loss_cls))

            downsample //= 2

        losses_all = {
            "loss_xy": fluid.layers.sum(loss_xys),
            "loss_wh": fluid.layers.sum(loss_whs),
            "loss_obj": fluid.layers.sum(loss_objs),
            "loss_cls": fluid.layers.sum(loss_clss),
        }
        if use_iou_loss:
            losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
        if use_iou_aware_loss:
            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
        return losses_all

    def _split_ioup(self, output, an_num, num_classes):
        """
        Split output feature map to output, predicted iou
        along channel dimension
        """
        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
        ioup = fluid.layers.sigmoid(ioup)
        oriout = fluid.layers.slice(
            output,
            axes=[1],
            starts=[an_num],
            ends=[an_num * (num_classes + 6)])
        return (ioup, oriout)

    def _split_output(self, output, an_num, num_classes):
        """
        Split output feature map to x, y, w, h, objectness, classification
        along channel dimension
        """

        def _every_anchor(offset):
            # Take channel ``offset`` of each anchor: one channel every
            # (5 + num_classes) channels along axis 1.
            return fluid.layers.strided_slice(
                output,
                axes=[1],
                starts=[offset],
                ends=[output.shape[1]],
                strides=[5 + num_classes])

        x = _every_anchor(0)
        y = _every_anchor(1)
        w = _every_anchor(2)
        h = _every_anchor(3)
        obj = _every_anchor(4)

        clss = []
        stride = output.shape[1] // an_num
        for m in range(an_num):
            clss.append(
                fluid.layers.slice(
                    output,
                    axes=[1],
                    starts=[stride * m + 5],
                    ends=[stride * m + 5 + num_classes]))
        # Stack per-anchor class maps and move the class axis last.
        cls = fluid.layers.transpose(
            fluid.layers.stack(
                clss, axis=1), perm=[0, 1, 3, 4, 2])

        return (x, y, w, h, obj, cls)

    def _split_target(self, target):
        """
        split target to x, y, w, h, objectness, classification
        along dimension 2

        target is in shape [N, an_num, 6 + class_num, H, W]
        """
        tx = target[:, :, 0, :, :]
        ty = target[:, :, 1, :, :]
        tw = target[:, :, 2, :, :]
        th = target[:, :, 3, :, :]

        tscale = target[:, :, 4, :, :]
        tobj = target[:, :, 5, :, :]

        tcls = fluid.layers.transpose(
            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
        tcls.stop_gradient = True

        return (tx, ty, tw, th, tscale, tobj, tcls)

    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
                       num_classes, downsample, ignore_thresh):
        """
        Compute the positive and negative objectness losses.

        Returns:
            tuple: (loss_obj_pos, loss_obj_neg), each reduced over
            dims [1, 2, 3].
        """

        def box_xywh2xyxy(box):
            # Convert (cx, cy, w, h) rows to (x1, y1, x2, y2).
            x = box[:, 0]
            y = box[:, 1]
            w = box[:, 2]
            h = box[:, 3]
            return fluid.layers.stack(
                [
                    x - w / 2.,
                    y - h / 2.,
                    x + w / 2.,
                    y + h / 2.,
                ], axis=1)

        # A prediction bbox overlap any gt_bbox over ignore_thresh,
        # objectness loss will be ignored, process as follows:

        # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
        # NOTE: img_size is set as 1.0 to get normalized pred bbox
        bbox, prob = fluid.layers.yolo_box(
            x=output,
            img_size=fluid.layers.ones(
                shape=[batch_size, 2], dtype="int32"),
            anchors=anchors,
            class_num=num_classes,
            conf_thresh=0.,
            downsample_ratio=downsample,
            clip_bbox=False)

        # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
        #    and gt bbox in each sample
        if batch_size > 1:
            preds = fluid.layers.split(bbox, batch_size, dim=0)
            gts = fluid.layers.split(gt_box, batch_size, dim=0)
        else:
            preds = [bbox]
            gts = [gt_box]
        ious = []
        for pred, gt in zip(preds, gts):
            pred = fluid.layers.squeeze(pred, axes=[0])
            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
            ious.append(fluid.layers.iou_similarity(pred, gt))

        iou = fluid.layers.stack(ious, axis=0)
        # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
        #    Get obj_mask by tobj(holds gt_score), calculate objectness loss

        max_iou = fluid.layers.reduce_max(iou, dim=-1)
        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
        if self.match_score:
            max_prob = fluid.layers.reduce_max(prob, dim=-1)
            iou_mask = iou_mask * fluid.layers.cast(
                max_prob <= 0.25, dtype="float32")
        output_shape = fluid.layers.shape(output)
        an_num = len(anchors) // 2
        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
                                                   output_shape[3]))
        iou_mask.stop_gradient = True

        # NOTE: tobj holds gt_score, obj_mask holds object existence mask
        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True

        # For positive objectness grids, objectness loss should be calculated
        # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
                                                                  obj_mask)
        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
        loss_obj_neg = fluid.layers.reduce_sum(
            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])

        return loss_obj_pos, loss_obj_neg
......@@ -13,9 +13,12 @@
# limitations under the License.
from paddle import fluid
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from collections import OrderedDict
from .loss import yolo_loss
from .iou_aware import get_iou_aware_score
class YOLOv3:
......@@ -34,7 +37,12 @@ class YOLOv3:
train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
],
fixed_input_shape=None):
fixed_input_shape=None,
use_iou_loss=False,
use_iou_aware_loss=False,
iou_aware_factor=0.4,
use_drop_block=False,
batch_size=8):
if anchors is None:
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
......@@ -56,6 +64,13 @@ class YOLOv3:
self.prefix_name = ''
self.train_random_shapes = train_random_shapes
self.fixed_input_shape = fixed_input_shape
self.use_iou_loss = use_iou_loss
self.use_iou_aware_loss = use_iou_aware_loss
self.iou_aware_factor = iou_aware_factor
self.use_drop_block = use_drop_block
self.block_size = 3
self.keep_prob = 0.9
self.batch_size = batch_size
def _head(self, feats):
outputs = []
......@@ -71,7 +86,10 @@ class YOLOv3:
channel=512 // (2**i),
name=self.prefix_name + 'yolo_block.{}'.format(i))
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
if self.use_iou_aware_loss:
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6)
else:
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
......@@ -155,6 +173,55 @@ class YOLOv3:
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
return out
def _dropblock(self, input, block_size=3, keep_prob=0.9):
    """Apply DropBlock regularization to a feature map while training.

    Seed positions are sampled with probability ``gamma``, grown into
    ``block_size`` x ``block_size`` squares by a stride-1 max pooling, zeroed
    in ``input``, and the surviving activations are rescaled by
    ``numel / kept``.

    Args:
        input: Feature map tensor. Only dimension 3 of its shape is read as
            the feature size and the mask goes through ``pool2d``, so this
            presumably expects a square NCHW tensor -- TODO confirm.
        block_size (int): Side length of each dropped square. Odd values keep
            the pooled mask the same spatial size as ``input``.
        keep_prob (float): Target fraction of activations to keep.

    Returns:
        ``input`` unchanged when ``self.mode`` is not ``'train'``; otherwise
        the masked and rescaled tensor.
    """
    # DropBlock only regularizes training; every other mode is a pass-through.
    if self.mode != 'train':
        return input

    def calculate_gamma(input, block_size, keep_prob):
        """Return the seed probability as a [1, 1, 1, 1] float32 tensor.

        gamma = feat^2 * (1 - keep_prob) / (block^2 * (feat - block + 1)^2)
        """
        input_shape = fluid.layers.shape(input)
        # Take shape[3] only: the feature map is treated as square.
        feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
        feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
        feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
        feat_area = fluid.layers.pow(feat_shape_t, factor=2)

        block_shape_t = fluid.layers.fill_constant(
            shape=[1, 1, 1, 1], value=block_size, dtype='float32')
        block_area = fluid.layers.pow(block_shape_t, factor=2)

        # Number of positions where a whole block still fits in the map.
        useful_shape_t = feat_shape_t - block_shape_t + 1
        useful_area = fluid.layers.pow(useful_shape_t, factor=2)

        upper_t = feat_area * (1 - keep_prob)
        bottom_t = block_area * useful_area
        return upper_t / bottom_t

    gamma = calculate_gamma(input, block_size=block_size, keep_prob=keep_prob)
    input_shape = fluid.layers.shape(input)
    p = fluid.layers.expand_as(gamma, input)

    input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
    # Leave the op seed at its default (0) so a fresh mask is drawn on every
    # step; a fixed non-zero seed makes the op reproduce the same numbers on
    # each run, which would drop the same positions for the whole training.
    random_matrix = fluid.layers.uniform_random(
        input_shape_tmp, dtype='float32', min=0.0, max=1.0)
    one_zero_m = fluid.layers.less_than(random_matrix, p)
    one_zero_m.stop_gradient = True
    one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")

    # Max pooling with stride 1 spreads every seed into a full block.
    mask_flag = fluid.layers.pool2d(
        one_zero_m,
        pool_size=block_size,
        pool_type='max',
        pool_stride=1,
        pool_padding=block_size // 2)
    mask = 1.0 - mask_flag

    # Rescale by (total elements / kept elements) to compensate for the
    # activations that were zeroed out.
    elem_numel = fluid.layers.reduce_prod(input_shape)
    elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
    elem_numel_m.stop_gradient = True

    elem_sum = fluid.layers.reduce_sum(mask)
    elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
    elem_sum_m.stop_gradient = True

    output = input * mask * elem_numel_m / elem_sum_m
    return output
def _detection_block(self, input, channel, name=None):
assert channel % 2 == 0, "channel({}) cannot be divided by 2 in detection block({})".format(
......@@ -179,6 +246,16 @@ class YOLOv3:
padding=1,
is_test=is_test,
name='{}.{}.1'.format(name, i))
if self.use_drop_block and i == 0 and channel != 512:
conv = self._dropblock(
conv,
block_size=self.block_size,
keep_prob=self.keep_prob)
if self.use_drop_block and channel == 512:
conv = self._dropblock(
conv,
block_size=self.block_size,
keep_prob=self.keep_prob)
route = self._conv_bn(
conv,
channel,
......@@ -197,31 +274,28 @@ class YOLOv3:
name='{}.tip'.format(name))
return route, tip
def _get_loss(self, inputs, gt_box, gt_label, gt_score):
def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
losses = []
downsample = 32
for i, input in enumerate(inputs):
loss = fluid.layers.yolov3_loss(
x=input,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=self.anchors,
anchor_mask=self.anchor_masks[i],
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
downsample_ratio=downsample,
use_label_smooth=self.label_smooth,
name=self.prefix_name + 'yolo_loss' + str(i))
losses.append(fluid.layers.reduce_mean(loss))
downsample //= 2
return sum(losses)
yolo_loss_obj = yolo_loss.YOLOv3Loss(batch_size=self.batch_size,
ignore_thresh=self.ignore_thresh,
label_smooth=self.label_smooth,
iou_loss_weight=2.5 if self.use_iou_loss else None,
iou_aware_loss_weight=1.0 if self.use_iou_aware_loss else None)
return yolo_loss_obj(inputs, gt_box, gt_label, gt_score, targets,
self.anchors, self.anchor_masks,
self.mask_anchors, self.num_classes,
self.prefix_name, max(self.train_random_shapes))
def _get_prediction(self, inputs, im_size):
boxes = []
scores = []
downsample = 32
for i, input in enumerate(inputs):
if self.use_iou_aware_loss:
input = get_iou_aware_score(input,
len(self.anchor_masks[i]),
self.num_classes,
self.iou_aware_factor)
box, score = fluid.layers.yolo_box(
x=input,
img_size=im_size,
......@@ -267,6 +341,12 @@ class YOLOv3:
dtype='float32', shape=[None, None], name='gt_score')
inputs['im_size'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_size')
if self.use_iou_loss or self.use_iou_aware_loss:
for i, mask in enumerate(self.anchor_masks):
inputs['target{}'.format(i)] = fluid.data(
dtype='float32',
shape=[None, len(mask), 6 + self.num_classes, None, None],
name='target{}'.format(i))
elif self.mode == 'eval':
inputs['im_size'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_size')
......@@ -285,22 +365,6 @@ class YOLOv3:
def build_net(self, inputs):
image = inputs['image']
if self.mode == 'train':
if isinstance(self.train_random_shapes,
(list, tuple)) and len(self.train_random_shapes) > 0:
import numpy as np
shapes = np.array(self.train_random_shapes)
shapes = np.stack([shapes, shapes], axis=1).astype('float32')
shapes_tensor = fluid.layers.assign(shapes)
index = fluid.layers.uniform_random(
shape=[1], dtype='float32', min=0.0, max=1)
index = fluid.layers.cast(
index * len(self.train_random_shapes), dtype='int32')
shape = fluid.layers.gather(shapes_tensor, index)
shape = fluid.layers.reshape(shape, [-1])
shape = fluid.layers.cast(shape, dtype='int32')
image = fluid.layers.resize_nearest(
image, out_shape=shape, align_corners=False)
feats = self.backbone(image)
if isinstance(feats, OrderedDict):
feat_names = list(feats.keys())
......@@ -320,8 +384,14 @@ class YOLOv3:
whwh = fluid.layers.cast(whwh, dtype='float32')
whwh.stop_gradient = True
normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
targets = []
if self.use_iou_loss or self.use_iou_aware_loss:
for i, mask in enumerate(self.anchor_masks):
k = 'target{}'.format(i)
if k in inputs:
targets.append(inputs[k])
return self._get_loss(head_outputs, normalized_box, gt_label,
gt_score)
gt_score, targets)
else:
im_size = inputs['im_size']
return self._get_prediction(head_outputs, im_size)
......@@ -49,13 +49,14 @@ class Compose(DetTransform):
ValueError: 数据长度不匹配。
"""
def __init__(self, transforms):
def __init__(self, transforms, batch_transforms=None):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
raise ValueError('The length of transforms ' + \
'must be equal or larger than 1!')
self.transforms = transforms
self.batch_transforms = batch_transforms
self.use_mixup = False
for t in self.transforms:
if type(t).__name__ == 'MixupImage':
......@@ -498,9 +499,10 @@ class Normalize(DetTransform):
TypeError: 形参数据类型不满足需求。
"""
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], is_scale=True):
self.mean = mean
self.std = std
self.is_scale = is_scale
if not (isinstance(self.mean, list) and isinstance(self.std, list)):
raise TypeError("NormalizeImage: input type is invalid.")
from functools import reduce
......@@ -521,7 +523,7 @@ class Normalize(DetTransform):
"""
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = normalize(im, mean, std)
im = normalize(im, mean, std, self.is_scale)
if label_info is None:
return (im, im_info)
else:
......@@ -1233,108 +1235,190 @@ class ArrangeYOLOv3(DetTransform):
im_shape = im_info['image_shape']
outputs = (im, im_shape)
return outputs
class RandomShape(DetTransform):
"""调整图像大小(resize)。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
Args:
random_shapes (list): resize大小选择列表。
默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
Raises:
ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
'AREA', 'LANCZOS4', 'RANDOM']中。
"""
# The interpolation mode
interp_dict = {
'NEAREST': cv2.INTER_NEAREST,
'LINEAR': cv2.INTER_LINEAR,
'CUBIC': cv2.INTER_CUBIC,
'AREA': cv2.INTER_AREA,
'LANCZOS4': cv2.INTER_LANCZOS4
}
def __init__(self, random_shapes=None, interp='RANDOM'):
    """Configure the candidate sizes and interpolation for batch resizing.

    Args:
        random_shapes (list): Candidate target sizes. ``None`` selects
            [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]; a new list is
            built for each instance so the default is never shared.
        interp (str): ``'RANDOM'`` or one of the keys of ``interp_dict``.

    Raises:
        ValueError: ``interp`` is neither ``'RANDOM'`` nor a key of
            ``interp_dict``.
    """
    if interp != "RANDOM" and interp not in self.interp_dict:
        raise ValueError("interp should be one of {}".format(
            self.interp_dict.keys()))
    if random_shapes is None:
        random_shapes = [
            320, 352, 384, 416, 448, 480, 512, 544, 576, 608
        ]
    self.random_shapes = random_shapes
    self.interp = interp
class ComposedRCNNTransforms(Compose):
""" RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
训练阶段:
1. 随机以0.5的概率将图像水平翻转
2. 图像归一化
3. 图像按比例Resize,scale计算方式如下
scale = min_max_size[0] / short_size_of_image
if max_size_of_image * scale > min_max_size[1]:
scale = min_max_size[1] / max_size_of_image
4. 将3步骤的长宽进行padding,使得长宽为32的倍数
验证阶段:
1. 图像归一化
2. 图像按比例Resize,scale计算方式同上训练阶段
3. 将2步骤的长宽进行padding,使得长宽为32的倍数
def __call__(self, batch_data):
"""
Args:
mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
min_max_size(list): 图像在缩放时,最小边和最大边的约束条件
mean(list): 图像均值
std(list): 图像方差
"""
batch_data (list): 由与图像相关的各种信息组成的batch数据。
def __init__(self,
mode,
min_max_size=[800, 1333],
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
if mode == 'train':
# 训练时的transforms,包含数据增强
transforms = [
RandomHorizontalFlip(prob=0.5), Normalize(
mean=mean, std=std), ResizeByShort(
short_size=min_max_size[0], max_size=min_max_size[1]),
Padding(coarsest_stride=32)
]
Returns:
list: 由与图像相关的各种信息组成的batch数据。
"""
shape = np.random.choice(self.random_shapes)
if self.interp == "RANDOM":
interp = random.choice(list(self.interp_dict.keys()))
else:
# 验证/预测时的transforms
transforms = [
Normalize(
mean=mean, std=std), ResizeByShort(
short_size=min_max_size[0], max_size=min_max_size[1]),
Padding(coarsest_stride=32)
]
super(ComposedRCNNTransforms, self).__init__(transforms)
class ComposedYOLOTransforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下,
训练阶段:
1. 在前mixup_epoch轮迭代中,使用MixupImage策略,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#mixupimage
2. 对图像进行随机扰动,包括亮度,对比度,饱和度和色调
3. 随机扩充图像,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#randomexpand
4. 随机裁剪图像
5. 将4步骤的输出图像Resize成shape参数的大小
6. 随机0.5的概率水平翻转图像
7. 图像归一化
验证/预测阶段:
1. 将图像Resize成shape参数大小
2. 图像归一化
Args:
mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
shape(list): 输入模型中图像的大小,输入模型的图像会被Resize成此大小
mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略
mean(list): 图像均值
std(list): 图像方差
interp = self.interp
for data_id, data in enumerate(batch_data):
    data_list = list(data)
    # Axis order (1, 2, 0) then (2, 0, 1) reproduces the original pairs of
    # swapaxes calls: the first array goes to `resize` with its leading axis
    # moved last (presumably CHW -> HWC -- confirm against `resize`), and the
    # result is moved back afterwards.
    im = np.transpose(data_list[0], (1, 2, 0))
    im = resize(im, shape, self.interp_dict[interp])
    data_list[0] = np.transpose(im, (2, 0, 1))
    batch_data[data_id] = tuple(data_list)
# The image is deliberately not dumped to 'im.npy' here: that per-batch
# np.save was debug residue that wrote a file on every call and raised
# NameError when the batch was empty.
return batch_data
class GenerateYoloTarget(DetTransform):
"""生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
该transform只在YOLOv3计算细粒度loss时使用。
Args:
anchors (list|tuple): anchor框的宽度和高度。
anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
num_classes (int): 类别数。默认为80。
iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
"""
def __init__(self, anchors, anchor_masks, num_classes=80, iou_thresh=1.):
    """Store the anchor configuration used to build YOLOv3 targets.

    Args:
        anchors (list|tuple): Width and height of each anchor box.
        anchor_masks (list|tuple): Per feature level, the indices of the
            anchors used when computing the loss.
        num_classes (int): Number of classes. Defaults to 80.
        iou_thresh (float): IoU threshold above which a non-best anchor is
            also written into the target. Defaults to 1.0.
    """
    super(GenerateYoloTarget, self).__init__()
    self.iou_thresh = iou_thresh
    self.num_classes = num_classes
    self.anchor_masks = anchor_masks
    self.anchors = anchors
def __call__(self, batch_data):
"""
Args:
batch_data (list): 由与图像相关的各种信息组成的batch数据。
def __init__(self,
mode,
shape=[608, 608],
mixup_epoch=250,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
width = shape
if isinstance(shape, list):
if shape[0] != shape[1]:
raise Exception(
"In YOLOv3 model, width and height should be equal")
width = shape[0]
if width % 32 != 0:
raise Exception(
"In YOLOv3 model, width and height should be multiple of 32, e.g 224、256、320...."
)
Returns:
list: 由与图像相关的各种信息组成的batch数据。
其中,每个数据新添加的字段为:
- target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
- target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
- ...
-targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
n的是大小由anchor_masks的长度决定。
"""
im = batch_data[0][0]
h = im.shape[1]
w = im.shape[2]
an_hw = np.array(self.anchors) / np.array([[w, h]])
for data_id, data in enumerate(batch_data):
gt_bbox = data[1]
gt_class = data[2]
gt_score = data[3]
im_shape = data[4]
origin_h = float(im_shape[0])
origin_w = float(im_shape[1])
data_list = list(data)
for i, mask in enumerate(self.anchor_masks):
downsample_ratio = 32 // pow(2, i)
grid_h = int(h / downsample_ratio)
grid_w = int(w / downsample_ratio)
target = np.zeros(
(len(mask), 6 + self.num_classes, grid_h, grid_w),
dtype=np.float32)
for b in range(gt_bbox.shape[0]):
gx = gt_bbox[b, 0] / float(origin_w)
gy = gt_bbox[b, 1] / float(origin_h)
gw = gt_bbox[b, 2] / float(origin_w)
gh = gt_bbox[b, 3] / float(origin_h)
cls = gt_class[b]
score = gt_score[b]
if gw <= 0. or gh <= 0. or score <= 0.:
continue
# find best match anchor index
best_iou = 0.
best_idx = -1
for an_idx in range(an_hw.shape[0]):
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
if iou > best_iou:
best_iou = iou
best_idx = an_idx
gi = int(gx * grid_w)
gj = int(gy * grid_h)
# gtbox should be regressed in this layer if the best-matching
# anchor index is in the anchor mask of this layer
if best_idx in mask:
best_n = mask.index(best_idx)
# x, y, w, h, scale
target[best_n, 0, gj, gi] = gx * grid_w - gi
target[best_n, 1, gj, gi] = gy * grid_h - gj
target[best_n, 2, gj, gi] = np.log(
gw * w / self.anchors[best_idx][0])
target[best_n, 3, gj, gi] = np.log(
gh * h / self.anchors[best_idx][1])
target[best_n, 4, gj, gi] = 2.0 - gw * gh
# objectness record gt_score
target[best_n, 5, gj, gi] = score
# classification
target[best_n, 6 + cls, gj, gi] = 1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if self.iou_thresh < 1:
for idx, mask_i in enumerate(mask):
if mask_i == best_idx: continue
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
if iou > self.iou_thresh:
# x, y, w, h, scale
target[idx, 0, gj, gi] = gx * grid_w - gi
target[idx, 1, gj, gi] = gy * grid_h - gj
target[idx, 2, gj, gi] = np.log(
gw * w / self.anchors[mask_i][0])
target[idx, 3, gj, gi] = np.log(
gh * h / self.anchors[mask_i][1])
target[idx, 4, gj, gi] = 2.0 - gw * gh
# objectness record gt_score
target[idx, 5, gj, gi] = score
# classification
target[idx, 6 + cls, gj, gi] = 1.
data_list.append(target)
batch_data[data_id] = tuple(data_list)
return batch_data
if mode == 'train':
# 训练时的transforms,包含数据增强
transforms = [
MixupImage(mixup_epoch=mixup_epoch), RandomDistort(),
RandomExpand(), RandomCrop(), Resize(
target_size=width,
interp='RANDOM'), RandomHorizontalFlip(), Normalize(
mean=mean, std=std)
]
else:
# 验证/预测时的transforms
transforms = [
Resize(
target_size=width, interp='CUBIC'), Normalize(
mean=mean, std=std)
]
super(ComposedYOLOTransforms, self).__init__(transforms)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册