Commit ed8de1ae authored by FlyingQianMM

add ppyolo

Parent f1465e6f
......@@ -115,7 +115,7 @@ def multithread_reader(mapper,
while not isinstance(sample, EndSignal):
batch_data.append(sample)
if len(batch_data) == batch_size:
batch_data = generate_minibatch(batch_data)
batch_data = generate_minibatch(batch_data, mapper=mapper)
yield batch_data
batch_data = []
sample = out_queue.get()
......@@ -127,11 +127,11 @@ def multithread_reader(mapper,
else:
batch_data.append(sample)
if len(batch_data) == batch_size:
batch_data = generate_minibatch(batch_data)
batch_data = generate_minibatch(batch_data, mapper=mapper)
yield batch_data
batch_data = []
if not drop_last and len(batch_data) != 0:
batch_data = generate_minibatch(batch_data)
batch_data = generate_minibatch(batch_data, mapper=mapper)
yield batch_data
batch_data = []
......@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
else:
batch_data.append(sample)
if len(batch_data) == batch_size:
batch_data = generate_minibatch(batch_data)
batch_data = generate_minibatch(batch_data, mapper=mapper)
yield batch_data
batch_data = []
if len(batch_data) != 0 and not drop_last:
batch_data = generate_minibatch(batch_data)
batch_data = generate_minibatch(batch_data, mapper=mapper)
yield batch_data
batch_data = []
return queue_reader
def generate_minibatch(batch_data, label_padding_value=255):
def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
if mapper is not None and mapper.batch_transforms is not None:
for op in mapper.batch_transforms:
batch_data = op(batch_data)
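# Batch-level transforms (e.g. BatchRandomShape and GenerateYoloTarget appended in
# YOLOv3.train below) are applied to the whole minibatch before any per-image padding.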
# if batch_size is 1, do not pad the image
if len(batch_data) == 1:
return batch_data
......@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im[:, :im_h, :im_w] = data[0]
if len(data) > 2:
# pad the image and label, and insert 'padding' into `im_info` of segmentation during the evaluation phase.
if len(data[1]) == 0 or 'padding' not in [
data[1][i][0] for i in range(len(data[1]))
]:
data[1].append(('padding', [im_h, im_w]))
padding_batch.append((padding_im, data[1], data[2]))
elif len(data) > 1:
if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
# padding the image and label of segmentation during the training
......
......@@ -94,6 +94,8 @@ class BaseAPI:
self.train_inputs, self.train_outputs = self.build_net(mode='train')
self.train_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()
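# Fix the random seeds of the training and startup programs so runs are reproducible.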
self.train_prog.random_seed = 1000
startup_prog.random_seed = 1000
# Build the inference network
self.test_prog = fluid.Program()
......@@ -246,8 +248,8 @@ class BaseAPI:
logging.info(
"Load pretrain weights from {}.".format(pretrain_weights),
use_color=True)
paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
pretrain_weights, fuse_bn)
paddlex.utils.utils.load_pretrain_weights(
self.exe, self.train_prog, pretrain_weights, fuse_bn)
# Apply model pruning
if sensitivities_file is not None:
import paddleslim
......@@ -351,7 +353,9 @@ class BaseAPI:
logging.info("Model saved in {}.".format(save_dir))
def export_inference_model(self, save_dir):
test_input_names = [var.name for var in list(self.test_inputs.values())]
test_input_names = [
var.name for var in list(self.test_inputs.values())
]
test_outputs = list(self.test_outputs.values())
with fluid.scope_guard(self.scope):
if self.__class__.__name__ == 'MaskRCNN':
......@@ -389,7 +393,8 @@ class BaseAPI:
# Flag file indicating the model was saved successfully
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format(save_dir))
logging.info("Model for inference deploy saved in {}.".format(
save_dir))
def train_loop(self,
num_epochs,
......@@ -516,11 +521,13 @@ class BaseAPI:
eta = ((num_epochs - i) * total_num_steps - step - 1
) * avg_step_time
if time_eval_one_epoch is not None:
eval_eta = (total_eval_times - i // save_interval_epochs
) * time_eval_one_epoch
eval_eta = (
total_eval_times - i // save_interval_epochs
) * time_eval_one_epoch
else:
eval_eta = (total_eval_times - i // save_interval_epochs
) * total_num_steps_eval * avg_step_time
eval_eta = (
total_eval_times - i // save_interval_epochs
) * total_num_steps_eval * avg_step_time
eta_str = seconds_to_hms(eta + eval_eta)
logging.info(
......@@ -543,6 +550,8 @@ class BaseAPI:
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
if not osp.isdir(current_save_dir):
os.makedirs(current_save_dir)
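# With EMA enabled, swap in the exponential moving average of the weights before
# evaluating and saving this checkpoint.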
if hasattr(self, 'use_ema'):
self.exe.run(self.ema.apply_program)
if eval_dataset is not None and eval_dataset.num_samples > 0:
self.eval_metrics, self.eval_details = self.evaluate(
eval_dataset=eval_dataset,
......@@ -569,6 +578,8 @@ class BaseAPI:
log_writer.add_scalar(
"Metrics/Eval(Epoch): {}".format(k), v, i + 1)
self.save_model(save_dir=current_save_dir)
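# Restore the original (non-EMA) weights before resuming training.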
if hasattr(self, 'use_ema'):
self.exe.run(self.ema.restore_program)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
if best_model_epoch > 0:
......
......@@ -19,6 +19,8 @@ import os.path as osp
import numpy as np
from multiprocessing.pool import ThreadPool
import paddle.fluid as fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.optimizer import ExponentialMovingAverage
import paddlex.utils.logging as logging
import paddlex
import copy
......@@ -28,6 +30,10 @@ from .base import BaseAPI
from collections import OrderedDict
from .utils.detection_eval import eval_results, bbox2out
import random
random.seed(0)
np.random.seed(0)
class YOLOv3(BaseAPI):
"""构建YOLOv3,并实现其训练、评估、预测和模型导出。
......@@ -50,24 +56,37 @@ class YOLOv3(BaseAPI):
train_random_shapes (list|tuple): image sizes randomly sampled from this list during training. Default is [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
"""
def __init__(self,
num_classes=80,
backbone='MobileNetV1',
anchors=None,
anchor_masks=None,
ignore_threshold=0.7,
nms_score_threshold=0.01,
nms_topk=1000,
nms_keep_topk=100,
nms_iou_threshold=0.45,
label_smooth=False,
train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
]):
def __init__(
self,
num_classes=80,
backbone='MobileNetV1',
with_dcn_v2=False,
# YOLO Head
anchors=None,
anchor_masks=None,
use_coord_conv=False,
use_iou_aware=False,
use_spp=False,
use_drop_block=False,
scale_x_y=1.0,
# YOLOv3 Loss
ignore_threshold=0.7,
label_smooth=False,
use_iou_loss=False,
# NMS
use_matrix_nms=False,
nms_score_threshold=0.01,
nms_topk=1000,
nms_keep_topk=100,
nms_iou_threshold=0.45,
train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
]):
self.init_params = locals()
super(YOLOv3, self).__init__('detector')
backbones = [
'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large',
'ResNet50_vd'
]
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
......@@ -75,6 +94,11 @@ class YOLOv3(BaseAPI):
self.num_classes = num_classes
self.anchors = anchors
self.anchor_masks = anchor_masks
if anchors is None:
self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
if anchor_masks is None:
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
self.ignore_threshold = ignore_threshold
self.nms_score_threshold = nms_score_threshold
self.nms_topk = nms_topk
......@@ -84,6 +108,20 @@ class YOLOv3(BaseAPI):
self.sync_bn = True
self.train_random_shapes = train_random_shapes
self.fixed_input_shape = None
self.use_fine_grained_loss = False
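# Any of the PP-YOLO components below requires the fine-grained YOLOv3 loss path
# instead of fluid.layers.yolov3_loss.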
if use_coord_conv or use_iou_aware or use_spp or use_drop_block or use_iou_loss:
self.use_fine_grained_loss = True
self.use_coord_conv = use_coord_conv
self.use_iou_aware = use_iou_aware
self.use_spp = use_spp
self.use_drop_block = use_drop_block
self.use_iou_loss = use_iou_loss
self.scale_x_y = scale_x_y
self.max_height = 608
self.max_width = 608
self.use_matrix_nms = use_matrix_nms
self.use_ema = False
self.with_dcn_v2 = with_dcn_v2
def _get_backbone(self, backbone_name):
if backbone_name == 'DarkNet53':
......@@ -102,6 +140,16 @@ class YOLOv3(BaseAPI):
model_name = backbone_name.split('_')[1]
backbone = paddlex.cv.nets.MobileNetV3(
norm_type='sync_bn', model_name=model_name)
elif backbone_name == 'ResNet50_vd':
backbone = paddlex.cv.nets.ResNet(
norm_type='sync_bn',
layers=50,
freeze_norm=False,
norm_decay=0.,
feature_maps=[3, 4, 5],
freeze_at=0,
variant='d',
dcn_v2_stages=[5] if self.with_dcn_v2 else [])
return backbone
def build_net(self, mode='train'):
......@@ -117,14 +165,31 @@ class YOLOv3(BaseAPI):
nms_topk=self.nms_topk,
nms_keep_topk=self.nms_keep_topk,
nms_iou_threshold=self.nms_iou_threshold,
train_random_shapes=self.train_random_shapes,
fixed_input_shape=self.fixed_input_shape)
fixed_input_shape=self.fixed_input_shape,
coord_conv=self.use_coord_conv,
iou_aware=self.use_iou_aware,
scale_x_y=self.scale_x_y,
spp=self.use_spp,
drop_block=self.use_drop_block,
use_matrix_nms=self.use_matrix_nms,
use_fine_grained_loss=self.use_fine_grained_loss,
use_iou_loss=self.use_iou_loss,
batch_size=self.batch_size_per_gpu
if hasattr(self, 'batch_size_per_gpu') else 8)
if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
model.max_height = self.max_height
model.max_width = self.max_width
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)])
outputs = OrderedDict([('bbox', model_out[0])])
if mode == 'train':
self.optimizer.minimize(model_out)
outputs = OrderedDict([('loss', model_out)])
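# Track an exponential moving average of the weights; train_loop applies the EMA
# copy for evaluation/saving and restores the raw weights afterwards.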
if self.use_ema:
global_steps = _decay_step_counter()
self.ema = ExponentialMovingAverage(
self.ema_decay, thres_steps=global_steps)
self.ema.update()
return inputs, outputs
def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
......@@ -172,6 +237,8 @@ class YOLOv3(BaseAPI):
warmup_start_lr=0.0,
lr_decay_epochs=[213, 240],
lr_decay_gamma=0.1,
use_ema=False,
ema_decay=0.9998,
metric=None,
use_vdl=False,
sensitivities_file=None,
......@@ -242,6 +309,46 @@ class YOLOv3(BaseAPI):
lr_decay_gamma=lr_decay_gamma,
num_steps_each_epoch=num_steps_each_epoch)
self.optimizer = optimizer
self.use_ema = use_ema
self.ema_decay = ema_decay
self.batch_size_per_gpu = int(train_batch_size /
paddlex.env_info['num'])
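# The fine-grained loss builds fixed-size grid tensors, so record the maximum input
# size from the Resize transform (and from the random shapes below) and make sure the
# required batch transforms (BatchRandomShape, GenerateYoloTarget) are registered.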
if self.use_fine_grained_loss:
for transform in train_dataset.transforms.transforms:
if isinstance(transform, paddlex.det.transforms.Resize):
self.max_height = transform.target_size
self.max_width = transform.target_size
break
if train_dataset.transforms.batch_transforms is None:
train_dataset.transforms.batch_transforms = list()
define_random_shape = False
for bt in train_dataset.transforms.batch_transforms:
if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
define_random_shape = True
if not define_random_shape:
if isinstance(self.train_random_shapes,
(list, tuple)) and len(self.train_random_shapes) > 0:
train_dataset.transforms.batch_transforms.append(
paddlex.det.transforms.BatchRandomShape(
random_shapes=self.train_random_shapes))
if self.use_fine_grained_loss:
self.max_height = max(self.max_height,
max(self.train_random_shapes))
self.max_width = max(self.max_width,
max(self.train_random_shapes))
if self.use_fine_grained_loss:
define_generate_target = False
for bt in train_dataset.transforms.batch_transforms:
if isinstance(bt, paddlex.det.transforms.GenerateYoloTarget):
define_generate_target = True
if not define_generate_target:
train_dataset.transforms.batch_transforms.append(
paddlex.det.transforms.GenerateYoloTarget(
anchors=self.anchors,
anchor_masks=self.anchor_masks,
num_classes=self.num_classes,
downsample_ratios=[32, 16, 8]))
# Build the training, evaluation and inference networks
self.build_program()
# Initialize network weights
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
def _split_ioup(output, an_num, num_classes):
"""
Split the new output feature map into the predicted IoU channels and the remaining
output channels along the channel dimension
"""
ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
ioup = fluid.layers.sigmoid(ioup)
oriout = fluid.layers.slice(
output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
return (ioup, oriout)
def _de_sigmoid(x, eps=1e-7):
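# Inverse of the sigmoid: clip x to a valid range, then compute -log(1/x - 1) = log(x / (1 - x)).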
x = fluid.layers.clip(x, eps, 1 / eps)
one = fluid.layers.fill_constant(
shape=[1, 1, 1, 1], dtype=x.dtype, value=1.)
x = fluid.layers.clip((one / x - 1.0), eps, 1 / eps)
x = -fluid.layers.log(x)
return x
def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
"""
post process output objectness score
"""
tensors = []
stride = output.shape[1] // an_num
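# Each anchor occupies `stride` channels laid out as [x, y, w, h, obj, cls_0..cls_{C-1}].
# The objectness is re-weighted by the predicted IoU, new_obj = obj^(1 - factor) * iou^factor,
# and mapped back to logit space with _de_sigmoid so downstream code can apply sigmoid again.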
for m in range(an_num):
tensors.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 0],
ends=[stride * m + 4]))
obj = fluid.layers.slice(
output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5])
obj = fluid.layers.sigmoid(obj)
ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1])
new_obj = fluid.layers.pow(obj, (
1 - iou_aware_factor)) * fluid.layers.pow(ip, iou_aware_factor)
new_obj = _de_sigmoid(new_obj)
tensors.append(new_obj)
tensors.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 5],
ends=[stride * m + 5 + num_classes]))
output = fluid.layers.concat(tensors, axis=1)
return output
def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
ioup, output = _split_ioup(output, an_num, num_classes)
output = _postprocess_output(ioup, output, an_num, num_classes,
iou_aware_factor)
return output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid
from .iou_loss import IouLoss
class IouAwareLoss(IouLoss):
"""
iou aware loss, see https://arxiv.org/abs/1912.05992
Args:
loss_weight (float): iou aware loss weight, default is 1.0
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
"""
def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
super(IouAwareLoss, self).__init__(
loss_weight=loss_weight,
max_height=max_height,
max_width=max_width)
def __call__(self,
ioup,
x,
y,
w,
h,
tx,
ty,
tw,
th,
anchors,
downsample_ratio,
batch_size,
scale_x_y,
eps=1.e-10):
'''
Args:
ioup ([Variables]): the predicted iou
x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
anchors ([float]): list of anchors for current output layer
downsample_ratio (float): the downsample ratio for current output layer
batch_size (int): training batch size
eps (float): a small value to prevent the denominator from being zero
'''
pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
batch_size, False, scale_x_y, eps)
gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
batch_size, True, scale_x_y, eps)
iouk = self._iou(pred, gt, ioup, eps)
iouk.stop_gradient = True
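# The IoU-aware branch is trained to predict the actual IoU between the decoded
# prediction and the ground truth; the IoU target itself carries no gradient.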
loss_iou_aware = fluid.layers.cross_entropy(
ioup, iouk, soft_label=True)
loss_iou_aware = loss_iou_aware * self._loss_weight
return loss_iou_aware
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid
class IouLoss(object):
"""
iou loss, see https://arxiv.org/abs/1908.03851
loss = 1.0 - iou * iou
Args:
loss_weight (float): iou loss weight, default is 2.5
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
ciou_term (bool): whether to add ciou_term
loss_square (bool): whether to square the iou term
"""
def __init__(self,
loss_weight=2.5,
max_height=608,
max_width=608,
ciou_term=False,
loss_square=True):
self._loss_weight = loss_weight
self._MAX_HI = max_height
self._MAX_WI = max_width
self.ciou_term = ciou_term
self.loss_square = loss_square
def __call__(self,
x,
y,
w,
h,
tx,
ty,
tw,
th,
anchors,
downsample_ratio,
batch_size,
scale_x_y=1.,
ioup=None,
eps=1.e-10):
'''
Args:
x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
anchors ([float]): list of anchors for current output layer
downsample_ratio (float): the downsample ratio for current output layer
batch_size (int): training batch size
eps (float): a small value to prevent the denominator from being zero
'''
pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
batch_size, False, scale_x_y, eps)
gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
batch_size, True, scale_x_y, eps)
iouk = self._iou(pred, gt, ioup, eps)
if self.loss_square:
loss_iou = 1. - iouk * iouk
else:
loss_iou = 1. - iouk
loss_iou = loss_iou * self._loss_weight
return loss_iou
def _iou(self, pred, gt, ioup=None, eps=1.e-10):
x1, y1, x2, y2 = pred
x1g, y1g, x2g, y2g = gt
x2 = fluid.layers.elementwise_max(x1, x2)
y2 = fluid.layers.elementwise_max(y1, y2)
xkis1 = fluid.layers.elementwise_max(x1, x1g)
ykis1 = fluid.layers.elementwise_max(y1, y1g)
xkis2 = fluid.layers.elementwise_min(x2, x2g)
ykis2 = fluid.layers.elementwise_min(y2, y2g)
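# Zero out the intersection of box pairs that do not overlap; greater_than acts as a 0/1 mask.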
intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
intsctk = intsctk * fluid.layers.greater_than(
xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
) - intsctk + eps
iouk = intsctk / unionk
if self.ciou_term:
ciou = self.get_ciou_term(pred, gt, iouk, eps)
iouk = iouk - ciou
return iouk
def get_ciou_term(self, pred, gt, iouk, eps):
x1, y1, x2, y2 = pred
x1g, y1g, x2g, y2g = gt
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')
cxg = (x1g + x2g) / 2
cyg = (y1g + y2g) / 2
wg = x2g - x1g
hg = y2g - y1g
# A or B
xc1 = fluid.layers.elementwise_min(x1, x1g)
yc1 = fluid.layers.elementwise_min(y1, y1g)
xc2 = fluid.layers.elementwise_max(x2, x2g)
yc2 = fluid.layers.elementwise_max(y2, y2g)
# DIOU term
dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
diou_term = (dist_intersection + eps) / (dist_union + eps)
# CIOU term
ciou_term = 0
ar_gt = wg / hg
ar_pred = w / h
arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
ar_loss = 4. / np.pi / np.pi * arctan * arctan
alpha = ar_loss / (1 - iouk + ar_loss + eps)
alpha.stop_gradient = True
ciou_term = alpha * ar_loss
return diou_term + ciou_term
def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
batch_size, is_gt, scale_x_y, eps):
grid_x = int(self._MAX_WI / downsample_ratio)
grid_y = int(self._MAX_HI / downsample_ratio)
an_num = len(anchors) // 2
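# Grid indices and anchor tensors are materialized at the maximum supported size
# (max_height / max_width divided by downsample_ratio) and then cropped to the actual
# feature-map shape, which is how random input shapes are supported.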
shape_fmp = fluid.layers.shape(dcx)
shape_fmp.stop_gradient = True
# generate the grid_w x grid_h center of feature map
idx_i = np.array([[i for i in range(grid_x)]])
idx_j = np.array([[j for j in range(grid_y)]]).transpose()
gi_np = np.repeat(idx_i, grid_y, axis=0)
gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
gj_np = np.repeat(idx_j, grid_x, axis=1)
gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x])
gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
gi = fluid.layers.crop(x=gi_max, shape=dcx)
gi.stop_gradient = True
gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
gj = fluid.layers.crop(x=gj_max, shape=dcx)
gj.stop_gradient = True
grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
grid_x_act.stop_gradient = True
grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
grid_y_act.stop_gradient = True
if is_gt:
cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
cx.gradient = True
cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
cy.gradient = True
else:
dcx_sig = fluid.layers.sigmoid(dcx)
dcy_sig = fluid.layers.sigmoid(dcy)
if (abs(scale_x_y - 1.0) > eps):
dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1)
dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1)
cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act
anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0]
anchor_w_np = np.array(anchor_w_)
anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1])
anchor_w_np = np.tile(
anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
anchor_w_max = self._create_tensor_from_numpy(
anchor_w_np.astype(np.float32))
anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
anchor_w.stop_gradient = True
anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1]
anchor_h_np = np.array(anchor_h_)
anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1])
anchor_h_np = np.tile(
anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
anchor_h_max = self._create_tensor_from_numpy(
anchor_h_np.astype(np.float32))
anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
anchor_h.stop_gradient = True
# e^tw e^th
exp_dw = fluid.layers.exp(dw)
exp_dh = fluid.layers.exp(dh)
pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
(grid_x_act * downsample_ratio)
ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
(grid_y_act * downsample_ratio)
if is_gt:
exp_dw.stop_gradient = True
exp_dh.stop_gradient = True
pw.stop_gradient = True
ph.stop_gradient = True
x1 = cx - 0.5 * pw
y1 = cy - 0.5 * ph
x2 = cx + 0.5 * pw
y2 = cy + 0.5 * ph
if is_gt:
x1.stop_gradient = True
y1.stop_gradient = True
x2.stop_gradient = True
y2.stop_gradient = True
return x1, y1, x2, y2
def _create_tensor_from_numpy(self, numpy_array):
paddle_array = fluid.layers.create_parameter(
attr=ParamAttr(),
shape=numpy_array.shape,
dtype=numpy_array.dtype,
default_initializer=NumpyArrayInitializer(numpy_array))
paddle_array.stop_gradient = True
return paddle_array
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
class YOLOv3Loss(object):
"""
Combined loss for YOLOv3 network
Args:
batch_size (int): training batch size
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
use_fine_grained_loss (bool): whether to use the fine-grained YOLOv3 loss
instead of fluid.layers.yolov3_loss
"""
def __init__(self,
batch_size=8,
ignore_thresh=0.7,
label_smooth=True,
use_fine_grained_loss=False,
iou_loss=None,
iou_aware_loss=None,
downsample=[32, 16, 8],
scale_x_y=1.,
match_score=False):
self._batch_size = batch_size
self._ignore_thresh = ignore_thresh
self._label_smooth = label_smooth
self._use_fine_grained_loss = use_fine_grained_loss
self._iou_loss = iou_loss
self._iou_aware_loss = iou_aware_loss
self.downsample = downsample
self.scale_x_y = scale_x_y
self.match_score = match_score
def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
anchor_masks, mask_anchors, num_classes, prefix_name):
if self._use_fine_grained_loss:
return self._get_fine_grained_loss(
outputs, targets, gt_box, self._batch_size, num_classes,
mask_anchors, self._ignore_thresh)
else:
losses = []
for i, output in enumerate(outputs):
scale_x_y = self.scale_x_y if not isinstance(
self.scale_x_y, Sequence) else self.scale_x_y[i]
anchor_mask = anchor_masks[i]
loss = fluid.layers.yolov3_loss(
x=output,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=anchors,
anchor_mask=anchor_mask,
class_num=num_classes,
ignore_thresh=self._ignore_thresh,
downsample_ratio=self.downsample[i],
use_label_smooth=self._label_smooth,
scale_x_y=scale_x_y,
name=prefix_name + "yolo_loss" + str(i))
losses.append(fluid.layers.reduce_mean(loss))
return {'loss': sum(losses)}
def _get_fine_grained_loss(self,
outputs,
targets,
gt_box,
batch_size,
num_classes,
mask_anchors,
ignore_thresh,
eps=1.e-10):
"""
Calculate fine grained YOLOv3 loss
Args:
outputs ([Variables]): List of Variables, output of backbone stages
targets ([Variables]): List of Variables, The targets for yolo
loss calculation.
gt_box (Variable): The ground-truth bounding boxes.
batch_size (int): The training batch size
num_classes (int): class num of dataset
mask_anchors ([[float]]): list of anchors in each output layer
ignore_thresh (float): if a predicted bbox overlaps any gt_box with IoU greater
than ignore_thresh, its objectness loss will be ignored.
Returns:
Type: dict
xy_loss (Variable): YOLOv3 (x, y) coordinates loss
wh_loss (Variable): YOLOv3 (w, h) coordinates loss
obj_loss (Variable): YOLOv3 objectness score loss
cls_loss (Variable): YOLOv3 classification loss
"""
assert len(outputs) == len(targets), \
"YOLOv3 output layer number not equal target number"
loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
if self._iou_loss is not None:
loss_ious = []
if self._iou_aware_loss is not None:
loss_iou_awares = []
for i, (output, target,
anchors) in enumerate(zip(outputs, targets, mask_anchors)):
downsample = self.downsample[i]
an_num = len(anchors) // 2
if self._iou_aware_loss is not None:
ioup, output = self._split_ioup(output, an_num, num_classes)
x, y, w, h, obj, cls = self._split_output(output, an_num,
num_classes)
tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
tscale_tobj = tscale * tobj
scale_x_y = self.scale_x_y if not isinstance(
self.scale_x_y, Sequence) else self.scale_x_y[i]
if (abs(scale_x_y - 1.0) < eps):
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
x, tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
y, ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
else:
dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y -
1.0)
dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y -
1.0)
loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
# NOTE: the (w, h) loss is refined as an L1 loss
loss_w = fluid.layers.abs(w - tw) * tscale_tobj
loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
loss_h = fluid.layers.abs(h - th) * tscale_tobj
loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
if self._iou_loss is not None:
loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
downsample, self._batch_size,
scale_x_y)
loss_iou = loss_iou * tscale_tobj
loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
loss_ious.append(fluid.layers.reduce_mean(loss_iou))
if self._iou_aware_loss is not None:
loss_iou_aware = self._iou_aware_loss(
ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
self._batch_size, scale_x_y)
loss_iou_aware = loss_iou_aware * tobj
loss_iou_aware = fluid.layers.reduce_sum(
loss_iou_aware, dim=[1, 2, 3])
loss_iou_awares.append(
fluid.layers.reduce_mean(loss_iou_aware))
loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
output, obj, tobj, gt_box, self._batch_size, anchors,
num_classes, downsample, self._ignore_thresh, scale_x_y)
loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
tcls)
loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
loss_objs.append(
fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
loss_clss.append(fluid.layers.reduce_mean(loss_cls))
losses_all = {
"loss_xy": fluid.layers.sum(loss_xys),
"loss_wh": fluid.layers.sum(loss_whs),
"loss_obj": fluid.layers.sum(loss_objs),
"loss_cls": fluid.layers.sum(loss_clss),
}
if self._iou_loss is not None:
losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
if self._iou_aware_loss is not None:
losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
return losses_all
def _split_ioup(self, output, an_num, num_classes):
"""
Split the output feature map into the predicted IoU channels and the remaining
output channels along the channel dimension
"""
ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
ioup = fluid.layers.sigmoid(ioup)
oriout = fluid.layers.slice(
output,
axes=[1],
starts=[an_num],
ends=[an_num * (num_classes + 6)])
return (ioup, oriout)
def _split_output(self, output, an_num, num_classes):
"""
Split output feature map to x, y, w, h, objectness, classification
along channel dimension
"""
x = fluid.layers.strided_slice(
output,
axes=[1],
starts=[0],
ends=[output.shape[1]],
strides=[5 + num_classes])
y = fluid.layers.strided_slice(
output,
axes=[1],
starts=[1],
ends=[output.shape[1]],
strides=[5 + num_classes])
w = fluid.layers.strided_slice(
output,
axes=[1],
starts=[2],
ends=[output.shape[1]],
strides=[5 + num_classes])
h = fluid.layers.strided_slice(
output,
axes=[1],
starts=[3],
ends=[output.shape[1]],
strides=[5 + num_classes])
obj = fluid.layers.strided_slice(
output,
axes=[1],
starts=[4],
ends=[output.shape[1]],
strides=[5 + num_classes])
clss = []
stride = output.shape[1] // an_num
for m in range(an_num):
clss.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 5],
ends=[stride * m + 5 + num_classes]))
cls = fluid.layers.transpose(
fluid.layers.stack(
clss, axis=1), perm=[0, 1, 3, 4, 2])
return (x, y, w, h, obj, cls)
def _split_target(self, target):
"""
split target to x, y, w, h, objectness, classification
along dimension 2
target is in shape [N, an_num, 6 + class_num, H, W]
"""
tx = target[:, :, 0, :, :]
ty = target[:, :, 1, :, :]
tw = target[:, :, 2, :, :]
th = target[:, :, 3, :, :]
tscale = target[:, :, 4, :, :]
tobj = target[:, :, 5, :, :]
tcls = fluid.layers.transpose(
target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
tcls.stop_gradient = True
return (tx, ty, tw, th, tscale, tobj, tcls)
def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
num_classes, downsample, ignore_thresh, scale_x_y):
# If a predicted bbox overlaps any gt bbox with IoU over ignore_thresh,
# its objectness loss is ignored. The process is as follows:
# 1. get the predicted bboxes, same as YOLOv3 inference mode, using yolo_box here
# NOTE: img_size is set as 1.0 to get normalized predicted bboxes
bbox, prob = fluid.layers.yolo_box(
x=output,
img_size=fluid.layers.ones(
shape=[batch_size, 2], dtype="int32"),
anchors=anchors,
class_num=num_classes,
conf_thresh=0.,
downsample_ratio=downsample,
clip_bbox=False,
scale_x_y=scale_x_y)
# 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
# and gt bbox in each sample
if batch_size > 1:
preds = fluid.layers.split(bbox, batch_size, dim=0)
gts = fluid.layers.split(gt_box, batch_size, dim=0)
else:
preds = [bbox]
gts = [gt_box]
probs = [prob]
ious = []
for pred, gt in zip(preds, gts):
def box_xywh2xyxy(box):
x = box[:, 0]
y = box[:, 1]
w = box[:, 2]
h = box[:, 3]
return fluid.layers.stack(
[
x - w / 2.,
y - h / 2.,
x + w / 2.,
y + h / 2.,
], axis=1)
pred = fluid.layers.squeeze(pred, axes=[0])
gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
ious.append(fluid.layers.iou_similarity(pred, gt))
iou = fluid.layers.stack(ious, axis=0)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou = fluid.layers.reduce_max(iou, dim=-1)
iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
if self.match_score:
max_prob = fluid.layers.reduce_max(prob, dim=-1)
iou_mask = iou_mask * fluid.layers.cast(
max_prob <= 0.25, dtype="float32")
output_shape = fluid.layers.shape(output)
an_num = len(anchors) // 2
iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
output_shape[3]))
iou_mask.stop_gradient = True
# NOTE: tobj holds gt_score, obj_mask holds object existence mask
obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
# For positive objectness grids, objectness loss should be calculated
# For negative objectness grids, objectness loss is calculated only where iou_mask == 1.0
loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
obj_mask)
loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
loss_obj_neg = fluid.layers.reduce_sum(
loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
return loss_obj_pos, loss_obj_neg
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numbers import Integral
import math
import six
import paddle
from paddle import fluid
def DropBlock(input, block_size, keep_prob, is_test):
if is_test:
return input
def CalculateGamma(input, block_size, keep_prob):
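# gamma is the per-position Bernoulli drop probability chosen so that, after the
# block_size x block_size max-pool expansion below, roughly (1 - keep_prob) of the
# feature map is zeroed:
# gamma = feat_area * (1 - keep_prob) / (block_size^2 * (feat_size - block_size + 1)^2)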
input_shape = fluid.layers.shape(input)
feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
feat_area = fluid.layers.pow(feat_shape_t, factor=2)
block_shape_t = fluid.layers.fill_constant(
shape=[1, 1, 1, 1], value=block_size, dtype='float32')
block_area = fluid.layers.pow(block_shape_t, factor=2)
useful_shape_t = feat_shape_t - block_shape_t + 1
useful_area = fluid.layers.pow(useful_shape_t, factor=2)
upper_t = feat_area * (1 - keep_prob)
bottom_t = block_area * useful_area
output = upper_t / bottom_t
return output
gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
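# Sample seed points with probability gamma, grow each seed into a block_size x block_size
# zeroed region via max pooling, then rescale the kept activations so their expected
# magnitude is unchanged (multiply by element count / number of kept elements).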
input_shape = fluid.layers.shape(input)
p = fluid.layers.expand_as(gamma, input)
input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
random_matrix = fluid.layers.uniform_random(
input_shape_tmp, dtype='float32', min=0.0, max=1.0, seed=1000)
one_zero_m = fluid.layers.less_than(random_matrix, p)
one_zero_m.stop_gradient = True
one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")
mask_flag = fluid.layers.pool2d(
one_zero_m,
pool_size=block_size,
pool_type='max',
pool_stride=1,
pool_padding=block_size // 2)
mask = 1.0 - mask_flag
elem_numel = fluid.layers.reduce_prod(input_shape)
elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
elem_numel_m.stop_gradient = True
elem_sum = fluid.layers.reduce_sum(mask)
elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
elem_sum_m.stop_gradient = True
output = input * mask * elem_numel_m / elem_sum_m
return output
class MultiClassNMS(object):
def __init__(self,
score_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.normalized = normalized
self.nms_eta = nms_eta
self.background_label = background_label
def __call__(self, bboxes, scores):
return fluid.layers.multiclass_nms(
bboxes=bboxes,
scores=scores,
score_threshold=self.score_threshold,
nms_top_k=self.nms_top_k,
keep_top_k=self.keep_top_k,
normalized=self.normalized,
nms_threshold=self.nms_threshold,
nms_eta=self.nms_eta,
background_label=self.background_label)
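# Matrix NMS (introduced in SOLOv2 and used by PP-YOLO) computes a soft score decay for
# every box from the pairwise IoU matrix in parallel instead of sequentially discarding boxes.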
class MatrixNMS(object):
def __init__(self,
score_threshold=.05,
post_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
use_gaussian=False,
gaussian_sigma=2.,
normalized=False,
background_label=0):
super(MatrixNMS, self).__init__()
self.score_threshold = score_threshold
self.post_threshold = post_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.normalized = normalized
self.use_gaussian = use_gaussian
self.gaussian_sigma = gaussian_sigma
self.background_label = background_label
def __call__(self, bboxes, scores):
return paddle.fluid.layers.matrix_nms(
bboxes=bboxes,
scores=scores,
score_threshold=self.score_threshold,
post_threshold=self.post_threshold,
nms_top_k=self.nms_top_k,
keep_top_k=self.keep_top_k,
normalized=self.normalized,
use_gaussian=self.use_gaussian,
gaussian_sigma=self.gaussian_sigma,
background_label=self.background_label)
class MultiClassSoftNMS(object):
def __init__(
self,
score_threshold=0.01,
keep_top_k=300,
softnms_sigma=0.5,
normalized=False,
background_label=0, ):
super(MultiClassSoftNMS, self).__init__()
self.score_threshold = score_threshold
self.keep_top_k = keep_top_k
self.softnms_sigma = softnms_sigma
self.normalized = normalized
self.background_label = background_label
def __call__(self, bboxes, scores):
def create_tmp_var(program, name, dtype, shape, lod_level):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _soft_nms_for_cls(dets, sigma, thres):
"""soft_nms_for_cls"""
dets_final = []
while len(dets) > 0:
maxpos = np.argmax(dets[:, 0])
dets_final.append(dets[maxpos].copy())
ts, tx1, ty1, tx2, ty2 = dets[maxpos]
scores = dets[:, 0]
# force remove bbox at maxpos
scores[maxpos] = -1
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas + areas[maxpos] - inter)
weight = np.exp(-(ovr * ovr) / sigma)
scores = scores * weight
idx_keep = np.where(scores >= thres)
dets[:, 0] = scores
dets = dets[idx_keep]
dets_final = np.array(dets_final).reshape(-1, 5)
return dets_final
def _soft_nms(bboxes, scores):
class_nums = scores.shape[-1]
softnms_thres = self.score_threshold
softnms_sigma = self.softnms_sigma
keep_top_k = self.keep_top_k
cls_boxes = [[] for _ in range(class_nums)]
cls_ids = [[] for _ in range(class_nums)]
start_idx = 1 if self.background_label == 0 else 0
for j in range(start_idx, class_nums):
inds = np.where(scores[:, j] >= softnms_thres)[0]
scores_j = scores[inds, j]
rois_j = bboxes[inds, j, :] if len(
bboxes.shape) > 2 else bboxes[inds, :]
dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
np.float32, copy=False)
cls_rank = np.argsort(-dets_j[:, 0])
dets_j = dets_j[cls_rank]
cls_boxes[j] = _soft_nms_for_cls(
dets_j, sigma=softnms_sigma, thres=softnms_thres)
cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
1)
cls_boxes = np.vstack(cls_boxes[start_idx:])
cls_ids = np.vstack(cls_ids[start_idx:])
pred_result = np.hstack([cls_ids, cls_boxes])
# Limit to max_per_image detections **over all classes**
image_scores = cls_boxes[:, 0]
if len(image_scores) > keep_top_k:
image_thresh = np.sort(image_scores)[-keep_top_k]
keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
pred_result = pred_result[keep, :]
return pred_result
def _batch_softnms(bboxes, scores):
batch_offsets = bboxes.lod()
bboxes = np.array(bboxes)
scores = np.array(scores)
out_offsets = [0]
pred_res = []
if len(batch_offsets) > 0:
batch_offset = batch_offsets[0]
for i in range(len(batch_offset) - 1):
s, e = batch_offset[i], batch_offset[i + 1]
pred = _soft_nms(bboxes[s:e], scores[s:e])
out_offsets.append(pred.shape[0] + out_offsets[-1])
pred_res.append(pred)
else:
assert len(bboxes.shape) == 3
assert len(scores.shape) == 3
for i in range(bboxes.shape[0]):
pred = _soft_nms(bboxes[i], scores[i])
out_offsets.append(pred.shape[0] + out_offsets[-1])
pred_res.append(pred)
res = fluid.LoDTensor()
res.set_lod([out_offsets])
if len(pred_res) == 0:
pred_res = np.array([[1]], dtype=np.float32)
res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
return res
pred_result = create_tmp_var(
fluid.default_main_program(),
name='softnms_pred_result',
dtype='float32',
shape=[-1, 6],
lod_level=1)
fluid.layers.py_func(
func=_batch_softnms, x=[bboxes, scores], out=pred_result)
return pred_result
......@@ -55,6 +55,7 @@ class Compose(DetTransform):
raise ValueError('The length of transforms ' + \
'must be equal or larger than 1!')
self.transforms = transforms
self.batch_transforms = None
self.use_mixup = False
for t in self.transforms:
if type(t).__name__ == 'MixupImage':
......@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
mean=mean, std=std)
]
super(ComposedYOLOv3Transforms, self).__init__(transforms)
class BatchRandomShape(DetTransform):
"""调整图像大小(resize)。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
Args:
random_shapes (list): resize大小选择列表。
默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
Raises:
ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
'AREA', 'LANCZOS4', 'RANDOM']中。
"""
# The interpolation mode
interp_dict = {
'NEAREST': cv2.INTER_NEAREST,
'LINEAR': cv2.INTER_LINEAR,
'CUBIC': cv2.INTER_CUBIC,
'AREA': cv2.INTER_AREA,
'LANCZOS4': cv2.INTER_LANCZOS4
}
def __init__(
self,
random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
interp='RANDOM'):
if not (interp == "RANDOM" or interp in self.interp_dict):
raise ValueError("interp should be one of {}".format(
self.interp_dict.keys()))
self.random_shapes = random_shapes
self.interp = interp
def __call__(self, batch_data):
"""
Args:
batch_data (list): a batch of samples, each composed of image-related information.
Returns:
list: a batch of samples, each composed of image-related information.
"""
shape = np.random.choice(self.random_shapes)
if self.interp == "RANDOM":
interp = random.choice(list(self.interp_dict.keys()))
else:
interp = self.interp
for data_id, data in enumerate(batch_data):
data_list = list(data)
im = data_list[0]
im = np.swapaxes(im, 1, 0)
im = np.swapaxes(im, 1, 2)
im = resize(im, shape, self.interp_dict[interp])
im = np.swapaxes(im, 1, 2)
im = np.swapaxes(im, 1, 0)
data_list[0] = im
batch_data[data_id] = tuple(data_list)
return batch_data
class GenerateYoloTarget(object):
"""生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
该transform只在YOLOv3计算细粒度loss时使用。
Args:
anchors (list|tuple): anchor框的宽度和高度。
anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
num_classes (int): 类别数。默认为80。
iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
"""
def __init__(self,
anchors,
anchor_masks,
downsample_ratios,
num_classes=80,
iou_thresh=1.):
super(GenerateYoloTarget, self).__init__()
self.anchors = anchors
self.anchor_masks = anchor_masks
self.downsample_ratios = downsample_ratios
self.num_classes = num_classes
self.iou_thresh = iou_thresh
def __call__(self, batch_data):
"""
Args:
batch_data (list): a batch of samples, each composed of image-related information.
Returns:
list: a batch of samples, each composed of image-related information.
The new fields added to each sample are:
- target0 (np.ndarray): encoded location information of the YOLOv3 ground truth on feature level 0,
with shape (number of anchors on level 0, 6 + num_classes, h of level 0, w of level 0).
- target1 (np.ndarray): encoded location information of the YOLOv3 ground truth on feature level 1,
with shape (number of anchors on level 1, 6 + num_classes, h of level 1, w of level 1).
- ...
- targetn (np.ndarray): encoded location information of the YOLOv3 ground truth on feature level n,
with shape (number of anchors on level n, 6 + num_classes, h of level n, w of level n).
The value of n is determined by the length of anchor_masks.
"""
im = batch_data[0][0]
h = im.shape[1]
w = im.shape[2]
an_hw = np.array(self.anchors) / np.array([[w, h]])
for data_id, data in enumerate(batch_data):
gt_bbox = data[1]
gt_class = data[2]
gt_score = data[3]
im_shape = data[4]
origin_h = float(im_shape[0])
origin_w = float(im_shape[1])
data_list = list(data)
for i, (
mask, downsample_ratio
) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
grid_h = int(h / downsample_ratio)
grid_w = int(w / downsample_ratio)
target = np.zeros(
(len(mask), 6 + self.num_classes, grid_h, grid_w),
dtype=np.float32)
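# Target channels per anchor are [tx, ty, tw, th, tscale, tobj, one-hot classes],
# matching the layout expected by YOLOv3Loss._split_target.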
for b in range(gt_bbox.shape[0]):
gx = gt_bbox[b, 0] / float(origin_w)
gy = gt_bbox[b, 1] / float(origin_h)
gw = gt_bbox[b, 2] / float(origin_w)
gh = gt_bbox[b, 3] / float(origin_h)
cls = gt_class[b]
score = gt_score[b]
if gw <= 0. or gh <= 0. or score <= 0.:
continue
# find best match anchor index
best_iou = 0.
best_idx = -1
for an_idx in range(an_hw.shape[0]):
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
if iou > best_iou:
best_iou = iou
best_idx = an_idx
gi = int(gx * grid_w)
gj = int(gy * grid_h)
# the gt box should be regressed on this layer if the best matching
# anchor index is in the anchor mask of this layer
if best_idx in mask:
best_n = mask.index(best_idx)
# x, y, w, h, scale
target[best_n, 0, gj, gi] = gx * grid_w - gi
target[best_n, 1, gj, gi] = gy * grid_h - gj
target[best_n, 2, gj, gi] = np.log(
gw * w / self.anchors[best_idx][0])
target[best_n, 3, gj, gi] = np.log(
gh * h / self.anchors[best_idx][1])
target[best_n, 4, gj, gi] = 2.0 - gw * gh
# objectness record gt_score
target[best_n, 5, gj, gi] = score
# classification
target[best_n, 6 + cls, gj, gi] = 1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if self.iou_thresh < 1:
for idx, mask_i in enumerate(mask):
if mask_i == best_idx: continue
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
if iou > self.iou_thresh:
# x, y, w, h, scale
target[idx, 0, gj, gi] = gx * grid_w - gi
target[idx, 1, gj, gi] = gy * grid_h - gj
target[idx, 2, gj, gi] = np.log(
gw * w / self.anchors[mask_i][0])
target[idx, 3, gj, gi] = np.log(
gh * h / self.anchors[mask_i][1])
target[idx, 4, gj, gi] = 2.0 - gw * gh
# objectness record gt_score
target[idx, 5, gj, gi] = score
# classification
target[idx, 6 + cls, gj, gi] = 1.
data_list.append(target)
batch_data[data_id] = tuple(data_list)
return batch_data
# Environment variable configuration, used to control whether the GPU is used
# Documentation: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from paddlex.det import transforms
import paddlex as pdx
# Download and extract the insect detection dataset
insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx.utils.download_and_decompress(insect_dataset, path='./')
# Define the transforms for training and evaluation
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
train_transforms = transforms.Compose([
transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
target_size=608, interp='RANDOM'), transforms.RandomHorizontalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.Resize(
target_size=608, interp='CUBIC'), transforms.Normalize()
])
# Define the datasets used for training and evaluation
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
train_dataset = pdx.datasets.VOCDetection(
data_dir='insect_det',
file_list='insect_det/train_list.txt',
label_list='insect_det/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.VOCDetection(
data_dir='insect_det',
file_list='insect_det/val_list.txt',
label_list='insect_det/labels.txt',
transforms=eval_transforms)
# Initialize the model and start training
# Training metrics can be viewed with VisualDL; see https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
model = pdx.det.YOLOv3(
num_classes=num_classes,
backbone='ResNet50_vd',
with_dcn_v2=True,
use_coord_conv=True,
use_iou_aware=True,
use_spp=True,
use_drop_block=True,
scale_x_y=1.05,
use_iou_loss=True,
use_matrix_nms=True)
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
# Description and tuning guide for each parameter: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
num_epochs=270,
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
learning_rate=0.000125,
lr_decay_epochs=[210, 240],
use_ema=True,
save_dir='output/ppyolo',
use_vdl=True)