Commit ed8de1ae authored by F FlyingQianMM

add ppyolo

Parent f1465e6f
@@ -115,7 +115,7 @@ def multithread_reader(mapper,
         while not isinstance(sample, EndSignal):
             batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
             sample = out_queue.get()
@@ -127,11 +127,11 @@ def multithread_reader(mapper,
             else:
                 batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
         if not drop_last and len(batch_data) != 0:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
             else:
                 batch_data.append(sample)
                 if len(batch_data) == batch_size:
-                    batch_data = generate_minibatch(batch_data)
+                    batch_data = generate_minibatch(batch_data, mapper=mapper)
                     yield batch_data
                     batch_data = []
         if len(batch_data) != 0 and not drop_last:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
     return queue_reader


-def generate_minibatch(batch_data, label_padding_value=255):
+def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
+    if mapper is not None and mapper.batch_transforms is not None:
+        for op in mapper.batch_transforms:
+            batch_data = op(batch_data)
     # if batch_size is 1, do not pad the image
     if len(batch_data) == 1:
         return batch_data
@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
         if len(data) > 2:
             # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
             if len(data[1]) == 0 or 'padding' not in [
                     data[1][i][0] for i in range(len(data[1]))
             ]:
                 data[1].append(('padding', [im_h, im_w]))
             padding_batch.append((padding_im, data[1], data[2]))
         elif len(data) > 1:
             if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                 # padding the image and label of segmentation during the training
......
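With this change, batch-level transforms hang off the mapper and run on the whole sample list before padding. A minimal sketch of that contract, with stand-in classes (FakeMapper and FakeBatchRandomShape are illustrative placeholders, not the real paddlex.det.transforms implementations):

import numpy as np

class FakeBatchRandomShape(object):
    """Stand-in for paddlex.det.transforms.BatchRandomShape."""

    def __init__(self, random_shapes):
        self.random_shapes = random_shapes

    def __call__(self, batch_data):
        # one target size is drawn per minibatch, so every sample in the
        # batch ends up with the same spatial shape
        size = int(np.random.choice(self.random_shapes))
        return [(np.zeros((3, size, size), np.float32),) + sample[1:]
                for sample in batch_data]

class FakeMapper(object):
    batch_transforms = [FakeBatchRandomShape([320, 416, 608])]

batch_data = [(np.zeros((3, 224, 224), np.float32), 0)] * 4
mapper = FakeMapper()
# same hook as in generate_minibatch above
if mapper is not None and mapper.batch_transforms is not None:
    for op in mapper.batch_transforms:
        batch_data = op(batch_data)
print(batch_data[0][0].shape)  # e.g. (3, 416, 416) for all four samples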
@@ -94,6 +94,8 @@ class BaseAPI:
         self.train_inputs, self.train_outputs = self.build_net(mode='train')
         self.train_prog = fluid.default_main_program()
         startup_prog = fluid.default_startup_program()
+        self.train_prog.random_seed = 1000
+        startup_prog.random_seed = 1000
         # build the inference network
         self.test_prog = fluid.Program()
@@ -246,8 +248,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
-        paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
-                                                  pretrain_weights, fuse_bn)
+        paddlex.utils.utils.load_pretrain_weights(
+            self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # perform pruning
         if sensitivities_file is not None:
             import paddleslim
@@ -351,7 +353,9 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))

     def export_inference_model(self, save_dir):
-        test_input_names = [var.name for var in list(self.test_inputs.values())]
+        test_input_names = [
+            var.name for var in list(self.test_inputs.values())
+        ]
         test_outputs = list(self.test_outputs.values())
         with fluid.scope_guard(self.scope):
             if self.__class__.__name__ == 'MaskRCNN':
@@ -389,7 +393,8 @@ class BaseAPI:
         # marker file indicating the model was saved successfully
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(save_dir))
+        logging.info("Model for inference deploy saved in {}.".format(
+            save_dir))

     def train_loop(self,
                    num_epochs,
@@ -516,11 +521,13 @@ class BaseAPI:
                     eta = ((num_epochs - i) * total_num_steps - step - 1
                            ) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * time_eval_one_epoch
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * time_eval_one_epoch
                     else:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * total_num_steps_eval * avg_step_time
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
                     logging.info(
@@ -543,6 +550,8 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.apply_program)
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
@@ -569,6 +578,8 @@ class BaseAPI:
                         log_writer.add_scalar(
                             "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
                 self.save_model(save_dir=current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.restore_program)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
                 if best_model_epoch > 0:
......
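The EMA hook above swaps in the smoothed weights (ema.apply_program) before evaluation and checkpointing, then swaps the raw weights back (ema.restore_program) so training continues unaffected. A minimal numpy sketch of that apply/restore cycle, not the fluid implementation:

import numpy as np

class SimpleEMA(object):
    def __init__(self, decay=0.9998):
        self.decay = decay
        self.shadow = None   # smoothed copy of the weights
        self.backup = None   # raw weights stashed during apply()

    def update(self, params):
        # called once per optimizer step
        if self.shadow is None:
            self.shadow = params.copy()
        self.shadow = self.decay * self.shadow + (1 - self.decay) * params

    def apply(self, params):
        # analogous to running ema.apply_program
        self.backup = params.copy()
        return self.shadow.copy()

    def restore(self):
        # analogous to running ema.restore_program
        return self.backup.copy()

w = np.zeros(3)
ema = SimpleEMA(decay=0.5)
for step in range(5):
    w = w + 1.0          # pretend optimizer step
    ema.update(w)
eval_w = ema.apply(w)    # evaluate and save with smoothed weights
w = ema.restore()        # training resumes from the raw weights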
@@ -19,6 +19,8 @@ import os.path as osp
 import numpy as np
 from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.optimizer import ExponentialMovingAverage
 import paddlex.utils.logging as logging
 import paddlex
 import copy
@@ -28,6 +30,10 @@ from .base import BaseAPI
 from collections import OrderedDict
 from .utils.detection_eval import eval_results, bbox2out

+import random
+random.seed(0)
+np.random.seed(0)
+

 class YOLOv3(BaseAPI):
     """Builds YOLOv3 and implements its training, evaluation, prediction and model export.
@@ -50,24 +56,37 @@ class YOLOv3(BaseAPI):
         train_random_shapes (list|tuple): image sizes randomly chosen from this list during
             training. Defaults to [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
     """

-    def __init__(self,
-                 num_classes=80,
-                 backbone='MobileNetV1',
-                 anchors=None,
-                 anchor_masks=None,
-                 ignore_threshold=0.7,
-                 nms_score_threshold=0.01,
-                 nms_topk=1000,
-                 nms_keep_topk=100,
-                 nms_iou_threshold=0.45,
-                 label_smooth=False,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ]):
+    def __init__(
+            self,
+            num_classes=80,
+            backbone='MobileNetV1',
+            with_dcn_v2=False,
+            # YOLO Head
+            anchors=None,
+            anchor_masks=None,
+            use_coord_conv=False,
+            use_iou_aware=False,
+            use_spp=False,
+            use_drop_block=False,
+            scale_x_y=1.0,
+            # YOLOv3 Loss
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_iou_loss=False,
+            # NMS
+            use_matrix_nms=False,
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            train_random_shapes=[
+                320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+            ]):
         self.init_params = locals()
         super(YOLOv3, self).__init__('detector')
         backbones = [
-            'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
+            'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large',
+            'ResNet50_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
@@ -75,6 +94,11 @@ class YOLOv3(BaseAPI):
         self.num_classes = num_classes
         self.anchors = anchors
         self.anchor_masks = anchor_masks
+        if anchors is None:
+            self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+                            [59, 119], [116, 90], [156, 198], [373, 326]]
+        if anchor_masks is None:
+            self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
         self.ignore_threshold = ignore_threshold
         self.nms_score_threshold = nms_score_threshold
         self.nms_topk = nms_topk
@@ -84,6 +108,20 @@ class YOLOv3(BaseAPI):
         self.sync_bn = True
         self.train_random_shapes = train_random_shapes
         self.fixed_input_shape = None
+        self.use_fine_grained_loss = False
+        if use_coord_conv or use_iou_aware or use_spp or use_drop_block or use_iou_loss:
+            self.use_fine_grained_loss = True
+        self.use_coord_conv = use_coord_conv
+        self.use_iou_aware = use_iou_aware
+        self.use_spp = use_spp
+        self.use_drop_block = use_drop_block
+        self.use_iou_loss = use_iou_loss
+        self.scale_x_y = scale_x_y
+        self.max_height = 608
+        self.max_width = 608
+        self.use_matrix_nms = use_matrix_nms
+        self.use_ema = False
+        self.with_dcn_v2 = with_dcn_v2

     def _get_backbone(self, backbone_name):
         if backbone_name == 'DarkNet53':
@@ -102,6 +140,16 @@ class YOLOv3(BaseAPI):
             model_name = backbone_name.split('_')[1]
             backbone = paddlex.cv.nets.MobileNetV3(
                 norm_type='sync_bn', model_name=model_name)
+        elif backbone_name == 'ResNet50_vd':
+            backbone = paddlex.cv.nets.ResNet(
+                norm_type='sync_bn',
+                layers=50,
+                freeze_norm=False,
+                norm_decay=0.,
+                feature_maps=[3, 4, 5],
+                freeze_at=0,
+                variant='d',
+                dcn_v2_stages=[5] if self.with_dcn_v2 else [])
         return backbone

     def build_net(self, mode='train'):
@@ -117,14 +165,31 @@ class YOLOv3(BaseAPI):
             nms_topk=self.nms_topk,
             nms_keep_topk=self.nms_keep_topk,
             nms_iou_threshold=self.nms_iou_threshold,
-            train_random_shapes=self.train_random_shapes,
-            fixed_input_shape=self.fixed_input_shape)
+            fixed_input_shape=self.fixed_input_shape,
+            coord_conv=self.use_coord_conv,
+            iou_aware=self.use_iou_aware,
+            scale_x_y=self.scale_x_y,
+            spp=self.use_spp,
+            drop_block=self.use_drop_block,
+            use_matrix_nms=self.use_matrix_nms,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            use_iou_loss=self.use_iou_loss,
+            batch_size=self.batch_size_per_gpu
+            if hasattr(self, 'batch_size_per_gpu') else 8)
+        if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
+            model.max_height = self.max_height
+            model.max_width = self.max_width
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
-        outputs = OrderedDict([('bbox', model_out)])
+        outputs = OrderedDict([('bbox', model_out[0])])
         if mode == 'train':
             self.optimizer.minimize(model_out)
             outputs = OrderedDict([('loss', model_out)])
+            if self.use_ema:
+                global_steps = _decay_step_counter()
+                self.ema = ExponentialMovingAverage(
+                    self.ema_decay, thres_steps=global_steps)
+                self.ema.update()
         return inputs, outputs

     def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
@@ -172,6 +237,8 @@ class YOLOv3(BaseAPI):
              warmup_start_lr=0.0,
              lr_decay_epochs=[213, 240],
              lr_decay_gamma=0.1,
+             use_ema=False,
+             ema_decay=0.9998,
              metric=None,
              use_vdl=False,
              sensitivities_file=None,
@@ -242,6 +309,46 @@ class YOLOv3(BaseAPI):
                 lr_decay_gamma=lr_decay_gamma,
                 num_steps_each_epoch=num_steps_each_epoch)
         self.optimizer = optimizer
+        self.use_ema = use_ema
+        self.ema_decay = ema_decay
+        self.batch_size_per_gpu = int(train_batch_size /
+                                      paddlex.env_info['num'])
+        if self.use_fine_grained_loss:
+            for transform in train_dataset.transforms.transforms:
+                if isinstance(transform, paddlex.det.transforms.Resize):
+                    self.max_height = transform.target_size
+                    self.max_width = transform.target_size
+                    break
+        if train_dataset.transforms.batch_transforms is None:
+            train_dataset.transforms.batch_transforms = list()
+        define_random_shape = False
+        for bt in train_dataset.transforms.batch_transforms:
+            if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
+                define_random_shape = True
+        if not define_random_shape:
+            if isinstance(self.train_random_shapes,
+                          (list, tuple)) and len(self.train_random_shapes) > 0:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.BatchRandomShape(
+                        random_shapes=self.train_random_shapes))
+                if self.use_fine_grained_loss:
+                    self.max_height = max(self.max_height,
                                          max(self.train_random_shapes))
+                    self.max_width = max(self.max_width,
+                                         max(self.train_random_shapes))
+        if self.use_fine_grained_loss:
+            define_generate_target = False
+            for bt in train_dataset.transforms.batch_transforms:
+                if isinstance(bt, paddlex.det.transforms.GenerateYoloTarget):
+                    define_generate_target = True
+            if not define_generate_target:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.GenerateYoloTarget(
+                        anchors=self.anchors,
+                        anchor_masks=self.anchor_masks,
+                        num_classes=self.num_classes,
+                        downsample_ratios=[32, 16, 8]))
         # build the training, evaluation and inference networks
         self.build_program()
         # initialize network weights
......
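Taken together, these options assemble PP-YOLO on top of the existing YOLOv3 trainer. A hedged usage sketch built only from parameters visible in this diff; dataset construction is elided, so the train() call is left commented out:

import paddlex as pdx

# the constructor alone exercises the new PP-YOLO switches
model = pdx.det.YOLOv3(
    num_classes=80,
    backbone='ResNet50_vd',
    with_dcn_v2=True,
    use_coord_conv=True,
    use_iou_aware=True,
    use_spp=True,
    use_drop_block=True,
    use_iou_loss=True,
    use_matrix_nms=True,
    scale_x_y=1.05)

# model.train(num_epochs=270,
#             train_dataset=train_dataset,   # a paddlex detection dataset
#             eval_dataset=eval_dataset,
#             use_ema=True,
#             ema_decay=0.9998)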
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
def _split_ioup(output, an_num, num_classes):
"""
Split new output feature map to output, predicted iou
along channel dimension
"""
ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
ioup = fluid.layers.sigmoid(ioup)
oriout = fluid.layers.slice(
output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
return (ioup, oriout)
def _de_sigmoid(x, eps=1e-7):
x = fluid.layers.clip(x, eps, 1 / eps)
one = fluid.layers.fill_constant(
shape=[1, 1, 1, 1], dtype=x.dtype, value=1.)
x = fluid.layers.clip((one / x - 1.0), eps, 1 / eps)
x = -fluid.layers.log(x)
return x
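Sanity check: _de_sigmoid is a clipped inverse sigmoid (logit), since -log(1/x - 1) = log(x / (1 - x)). A small numpy sketch verifying the round trip:

import numpy as np

def de_sigmoid(x, eps=1e-7):
    # same arithmetic as _de_sigmoid above, in numpy
    x = np.clip(x, eps, 1.0 / eps)
    x = np.clip(1.0 / x - 1.0, eps, 1.0 / eps)
    return -np.log(x)

logits = np.array([-3.0, 0.0, 2.5])
probs = 1.0 / (1.0 + np.exp(-logits))   # sigmoid
print(de_sigmoid(probs))                # ~ [-3.  0.  2.5]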
def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
"""
post process output objectness score
"""
tensors = []
stride = output.shape[1] // an_num
for m in range(an_num):
tensors.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 0],
ends=[stride * m + 4]))
obj = fluid.layers.slice(
output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5])
obj = fluid.layers.sigmoid(obj)
ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1])
new_obj = fluid.layers.pow(obj, (
1 - iou_aware_factor)) * fluid.layers.pow(ip, iou_aware_factor)
new_obj = _de_sigmoid(new_obj)
tensors.append(new_obj)
tensors.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 5],
ends=[stride * m + 5 + num_classes]))
output = fluid.layers.concat(tensors, axis=1)
return output
def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
ioup, output = _split_ioup(output, an_num, num_classes)
output = _postprocess_output(ioup, output, an_num, num_classes,
iou_aware_factor)
return output
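The fused objectness here follows the IoU-aware formulation: new_obj = obj^(1 - factor) * iou^factor, computed in probability space and mapped back to a logit with _de_sigmoid so the downstream yolo_box can re-apply sigmoid. A numpy illustration (0.4 matching the default iou_aware_factor used elsewhere in this commit):

import numpy as np

alpha = 0.4
obj = 0.9   # raw objectness probability
iou = 0.5   # predicted IoU for the same anchor
fused = obj ** (1 - alpha) * iou ** alpha
print(fused)  # ~ 0.71: a confident box with mediocre IoU is demoted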
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid
from .iou_loss import IouLoss
class IouAwareLoss(IouLoss):
"""
iou aware loss, see https://arxiv.org/abs/1912.05992
Args:
loss_weight (float): iou aware loss weight, default is 1.0
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
"""
def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
super(IouAwareLoss, self).__init__(
loss_weight=loss_weight,
max_height=max_height,
max_width=max_width)
def __call__(self,
ioup,
x,
y,
w,
h,
tx,
ty,
tw,
th,
anchors,
downsample_ratio,
batch_size,
scale_x_y,
eps=1.e-10):
'''
Args:
ioup ([Variables]): the predicted iou
x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
anchors ([float]): list of anchors for current output layer
downsample_ratio (float): the downsample ratio for current output layer
batch_size (int): training batch size
eps (float): small constant to keep the denominator from being zero
'''
pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
batch_size, False, scale_x_y, eps)
gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
batch_size, True, scale_x_y, eps)
iouk = self._iou(pred, gt, ioup, eps)
iouk.stop_gradient = True
loss_iou_aware = fluid.layers.cross_entropy(
ioup, iouk, soft_label=True)
loss_iou_aware = loss_iou_aware * self._loss_weight
return loss_iou_aware
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid
class IouLoss(object):
"""
iou loss, see https://arxiv.org/abs/1908.03851
loss = 1.0 - iou * iou
Args:
loss_weight (float): iou loss weight, default is 2.5
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
ciou_term (bool): whether to add ciou_term
loss_square (bool): whether to square the iou term
"""
def __init__(self,
loss_weight=2.5,
max_height=608,
max_width=608,
ciou_term=False,
loss_square=True):
self._loss_weight = loss_weight
self._MAX_HI = max_height
self._MAX_WI = max_width
self.ciou_term = ciou_term
self.loss_square = loss_square
def __call__(self,
x,
y,
w,
h,
tx,
ty,
tw,
th,
anchors,
downsample_ratio,
batch_size,
scale_x_y=1.,
ioup=None,
eps=1.e-10):
'''
Args:
x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
anchors ([float]): list of anchors for current output layer
downsample_ratio (float): the downsample ratio for current output layer
batch_size (int): training batch size
eps (float): small constant to keep the denominator from being zero
'''
pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
batch_size, False, scale_x_y, eps)
gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
batch_size, True, scale_x_y, eps)
iouk = self._iou(pred, gt, ioup, eps)
if self.loss_square:
loss_iou = 1. - iouk * iouk
else:
loss_iou = 1. - iouk
loss_iou = loss_iou * self._loss_weight
return loss_iou
def _iou(self, pred, gt, ioup=None, eps=1.e-10):
x1, y1, x2, y2 = pred
x1g, y1g, x2g, y2g = gt
x2 = fluid.layers.elementwise_max(x1, x2)
y2 = fluid.layers.elementwise_max(y1, y2)
xkis1 = fluid.layers.elementwise_max(x1, x1g)
ykis1 = fluid.layers.elementwise_max(y1, y1g)
xkis2 = fluid.layers.elementwise_min(x2, x2g)
ykis2 = fluid.layers.elementwise_min(y2, y2g)
intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
intsctk = intsctk * fluid.layers.greater_than(
xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
) - intsctk + eps
iouk = intsctk / unionk
if self.ciou_term:
ciou = self.get_ciou_term(pred, gt, iouk, eps)
iouk = iouk - ciou
return iouk
def get_ciou_term(self, pred, gt, iouk, eps):
x1, y1, x2, y2 = pred
x1g, y1g, x2g, y2g = gt
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')
cxg = (x1g + x2g) / 2
cyg = (y1g + y2g) / 2
wg = x2g - x1g
hg = y2g - y1g
# A or B
xc1 = fluid.layers.elementwise_min(x1, x1g)
yc1 = fluid.layers.elementwise_min(y1, y1g)
xc2 = fluid.layers.elementwise_max(x2, x2g)
yc2 = fluid.layers.elementwise_max(y2, y2g)
# DIOU term
dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
diou_term = (dist_intersection + eps) / (dist_union + eps)
# CIOU term
ciou_term = 0
ar_gt = wg / hg
ar_pred = w / h
arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
ar_loss = 4. / np.pi / np.pi * arctan * arctan
alpha = ar_loss / (1 - iouk + ar_loss + eps)
alpha.stop_gradient = True
ciou_term = alpha * ar_loss
return diou_term + ciou_term
def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
batch_size, is_gt, scale_x_y, eps):
grid_x = int(self._MAX_WI / downsample_ratio)
grid_y = int(self._MAX_HI / downsample_ratio)
an_num = len(anchors) // 2
shape_fmp = fluid.layers.shape(dcx)
shape_fmp.stop_gradient = True
# generate the grid_w x grid_h center of feature map
idx_i = np.array([[i for i in range(grid_x)]])
idx_j = np.array([[j for j in range(grid_y)]]).transpose()
gi_np = np.repeat(idx_i, grid_y, axis=0)
gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
gj_np = np.repeat(idx_j, grid_x, axis=1)
gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x])
gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
gi = fluid.layers.crop(x=gi_max, shape=dcx)
gi.stop_gradient = True
gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
gj = fluid.layers.crop(x=gj_max, shape=dcx)
gj.stop_gradient = True
grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
grid_x_act.stop_gradient = True
grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
grid_y_act.stop_gradient = True
if is_gt:
cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
cx.gradient = True
cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
cy.gradient = True
else:
dcx_sig = fluid.layers.sigmoid(dcx)
dcy_sig = fluid.layers.sigmoid(dcy)
if (abs(scale_x_y - 1.0) > eps):
dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1)
dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1)
cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act
anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0]
anchor_w_np = np.array(anchor_w_)
anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1])
anchor_w_np = np.tile(
anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
anchor_w_max = self._create_tensor_from_numpy(
anchor_w_np.astype(np.float32))
anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
anchor_w.stop_gradient = True
anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1]
anchor_h_np = np.array(anchor_h_)
anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1])
anchor_h_np = np.tile(
anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
anchor_h_max = self._create_tensor_from_numpy(
anchor_h_np.astype(np.float32))
anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
anchor_h.stop_gradient = True
# e^tw e^th
exp_dw = fluid.layers.exp(dw)
exp_dh = fluid.layers.exp(dh)
pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
(grid_x_act * downsample_ratio)
ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
(grid_y_act * downsample_ratio)
if is_gt:
exp_dw.stop_gradient = True
exp_dh.stop_gradient = True
pw.stop_gradient = True
ph.stop_gradient = True
x1 = cx - 0.5 * pw
y1 = cy - 0.5 * ph
x2 = cx + 0.5 * pw
y2 = cy + 0.5 * ph
if is_gt:
x1.stop_gradient = True
y1.stop_gradient = True
x2.stop_gradient = True
y2.stop_gradient = True
return x1, y1, x2, y2
def _create_tensor_from_numpy(self, numpy_array):
paddle_array = fluid.layers.create_parameter(
attr=ParamAttr(),
shape=numpy_array.shape,
dtype=numpy_array.dtype,
default_initializer=NumpyArrayInitializer(numpy_array))
paddle_array.stop_gradient = True
return paddle_array
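End-of-file note: after all the grid and anchor plumbing, the loss bottoms out in plain IoU between decoded prediction and target boxes, with the intersection zeroed for non-overlapping pairs exactly as in _iou above. A small numpy check of that core arithmetic (ciou_term off):

import numpy as np

def iou_xyxy(pred, gt, eps=1e-10):
    x1, y1, x2, y2 = pred
    x1g, y1g, x2g, y2g = gt
    xi1, yi1 = max(x1, x1g), max(y1, y1g)
    xi2, yi2 = min(x2, x2g), min(y2, y2g)
    inter = (xi2 - xi1) * (yi2 - yi1)
    inter *= float(xi2 > xi1) * float(yi2 > yi1)  # mask non-overlapping boxes
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - inter + eps
    return inter / union

iou = iou_xyxy((0., 0., 2., 2.), (1., 1., 3., 3.))
print(iou)               # 1 / 7, about 0.143
print(1.0 - iou * iou)   # the loss_square=True form of the loss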
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
class YOLOv3Loss(object):
"""
Combined loss for YOLOv3 network
Args:
batch_size (int): training batch size
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss
instead of fluid.layers.yolov3_loss
"""
def __init__(self,
batch_size=8,
ignore_thresh=0.7,
label_smooth=True,
use_fine_grained_loss=False,
iou_loss=None,
iou_aware_loss=None,
downsample=[32, 16, 8],
scale_x_y=1.,
match_score=False):
self._batch_size = batch_size
self._ignore_thresh = ignore_thresh
self._label_smooth = label_smooth
self._use_fine_grained_loss = use_fine_grained_loss
self._iou_loss = iou_loss
self._iou_aware_loss = iou_aware_loss
self.downsample = downsample
self.scale_x_y = scale_x_y
self.match_score = match_score
def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
anchor_masks, mask_anchors, num_classes, prefix_name):
if self._use_fine_grained_loss:
return self._get_fine_grained_loss(
outputs, targets, gt_box, self._batch_size, num_classes,
mask_anchors, self._ignore_thresh)
else:
losses = []
for i, output in enumerate(outputs):
scale_x_y = self.scale_x_y if not isinstance(
self.scale_x_y, Sequence) else self.scale_x_y[i]
anchor_mask = anchor_masks[i]
loss = fluid.layers.yolov3_loss(
x=output,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=anchors,
anchor_mask=anchor_mask,
class_num=num_classes,
ignore_thresh=self._ignore_thresh,
downsample_ratio=self.downsample[i],
use_label_smooth=self._label_smooth,
scale_x_y=scale_x_y,
name=prefix_name + "yolo_loss" + str(i))
losses.append(fluid.layers.reduce_mean(loss))
return {'loss': sum(losses)}
def _get_fine_grained_loss(self,
outputs,
targets,
gt_box,
batch_size,
num_classes,
mask_anchors,
ignore_thresh,
eps=1.e-10):
"""
Calculate fine grained YOLOv3 loss
Args:
outputs ([Variables]): List of Variables, output of backbone stages
targets ([Variables]): List of Variables, the targets for yolo
loss calculation.
gt_box (Variable): The ground-truth bounding boxes.
batch_size (int): The training batch size
num_classes (int): class num of dataset
mask_anchors ([[float]]): list of anchors in each output layer
ignore_thresh (float): if a predicted bbox overlaps any gt_box with
IoU greater than ignore_thresh, its objectness loss will
be ignored.
Returns:
Type: dict
xy_loss (Variable): YOLOv3 (x, y) coordinates loss
wh_loss (Variable): YOLOv3 (w, h) coordinates loss
obj_loss (Variable): YOLOv3 objectness score loss
cls_loss (Variable): YOLOv3 classification loss
"""
assert len(outputs) == len(targets), \
"YOLOv3 output layer number not equal target number"
loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
if self._iou_loss is not None:
loss_ious = []
if self._iou_aware_loss is not None:
loss_iou_awares = []
for i, (output, target,
anchors) in enumerate(zip(outputs, targets, mask_anchors)):
downsample = self.downsample[i]
an_num = len(anchors) // 2
if self._iou_aware_loss is not None:
ioup, output = self._split_ioup(output, an_num, num_classes)
x, y, w, h, obj, cls = self._split_output(output, an_num,
num_classes)
tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
tscale_tobj = tscale * tobj
scale_x_y = self.scale_x_y if not isinstance(
self.scale_x_y, Sequence) else self.scale_x_y[i]
if (abs(scale_x_y - 1.0) < eps):
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
x, tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
y, ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
else:
dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y -
1.0)
dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y -
1.0)
loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
# NOTE: we refined loss function of (w, h) as L1Loss
loss_w = fluid.layers.abs(w - tw) * tscale_tobj
loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
loss_h = fluid.layers.abs(h - th) * tscale_tobj
loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
if self._iou_loss is not None:
loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
downsample, self._batch_size,
scale_x_y)
loss_iou = loss_iou * tscale_tobj
loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
loss_ious.append(fluid.layers.reduce_mean(loss_iou))
if self._iou_aware_loss is not None:
loss_iou_aware = self._iou_aware_loss(
ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
self._batch_size, scale_x_y)
loss_iou_aware = loss_iou_aware * tobj
loss_iou_aware = fluid.layers.reduce_sum(
loss_iou_aware, dim=[1, 2, 3])
loss_iou_awares.append(
fluid.layers.reduce_mean(loss_iou_aware))
loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
output, obj, tobj, gt_box, self._batch_size, anchors,
num_classes, downsample, self._ignore_thresh, scale_x_y)
loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
tcls)
loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
loss_objs.append(
fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
loss_clss.append(fluid.layers.reduce_mean(loss_cls))
losses_all = {
"loss_xy": fluid.layers.sum(loss_xys),
"loss_wh": fluid.layers.sum(loss_whs),
"loss_obj": fluid.layers.sum(loss_objs),
"loss_cls": fluid.layers.sum(loss_clss),
}
if self._iou_loss is not None:
losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
if self._iou_aware_loss is not None:
losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
return losses_all
def _split_ioup(self, output, an_num, num_classes):
"""
Split output feature map to output, predicted iou
along channel dimension
"""
ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
ioup = fluid.layers.sigmoid(ioup)
oriout = fluid.layers.slice(
output,
axes=[1],
starts=[an_num],
ends=[an_num * (num_classes + 6)])
return (ioup, oriout)
def _split_output(self, output, an_num, num_classes):
"""
Split output feature map to x, y, w, h, objectness, classification
along channel dimension
"""
x = fluid.layers.strided_slice(
output,
axes=[1],
starts=[0],
ends=[output.shape[1]],
strides=[5 + num_classes])
y = fluid.layers.strided_slice(
output,
axes=[1],
starts=[1],
ends=[output.shape[1]],
strides=[5 + num_classes])
w = fluid.layers.strided_slice(
output,
axes=[1],
starts=[2],
ends=[output.shape[1]],
strides=[5 + num_classes])
h = fluid.layers.strided_slice(
output,
axes=[1],
starts=[3],
ends=[output.shape[1]],
strides=[5 + num_classes])
obj = fluid.layers.strided_slice(
output,
axes=[1],
starts=[4],
ends=[output.shape[1]],
strides=[5 + num_classes])
clss = []
stride = output.shape[1] // an_num
for m in range(an_num):
clss.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 5],
ends=[stride * m + 5 + num_classes]))
cls = fluid.layers.transpose(
fluid.layers.stack(
clss, axis=1), perm=[0, 1, 3, 4, 2])
return (x, y, w, h, obj, cls)
def _split_target(self, target):
"""
split target to x, y, w, h, objectness, classification
along dimension 2
target is in shape [N, an_num, 6 + class_num, H, W]
"""
tx = target[:, :, 0, :, :]
ty = target[:, :, 1, :, :]
tw = target[:, :, 2, :, :]
th = target[:, :, 3, :, :]
tscale = target[:, :, 4, :, :]
tobj = target[:, :, 5, :, :]
tcls = fluid.layers.transpose(
target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
tcls.stop_gradient = True
return (tx, ty, tw, th, tscale, tobj, tcls)
def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
num_classes, downsample, ignore_thresh, scale_x_y):
# If a prediction bbox overlaps any gt_bbox with IoU over ignore_thresh,
# its objectness loss will be ignored. Process as follows:
# 1. get pred bbox, same as in YOLOv3 infer mode; use yolo_box here
# NOTE: img_size is set as 1.0 to get normalized pred bbox
bbox, prob = fluid.layers.yolo_box(
x=output,
img_size=fluid.layers.ones(
shape=[batch_size, 2], dtype="int32"),
anchors=anchors,
class_num=num_classes,
conf_thresh=0.,
downsample_ratio=downsample,
clip_bbox=False,
scale_x_y=scale_x_y)
# 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
# and gt bbox in each sample
if batch_size > 1:
preds = fluid.layers.split(bbox, batch_size, dim=0)
gts = fluid.layers.split(gt_box, batch_size, dim=0)
else:
preds = [bbox]
gts = [gt_box]
probs = [prob]
ious = []
for pred, gt in zip(preds, gts):
def box_xywh2xyxy(box):
x = box[:, 0]
y = box[:, 1]
w = box[:, 2]
h = box[:, 3]
return fluid.layers.stack(
[
x - w / 2.,
y - h / 2.,
x + w / 2.,
y + h / 2.,
], axis=1)
pred = fluid.layers.squeeze(pred, axes=[0])
gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
ious.append(fluid.layers.iou_similarity(pred, gt))
iou = fluid.layers.stack(ious, axis=0)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou = fluid.layers.reduce_max(iou, dim=-1)
iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
if self.match_score:
max_prob = fluid.layers.reduce_max(prob, dim=-1)
iou_mask = iou_mask * fluid.layers.cast(
max_prob <= 0.25, dtype="float32")
output_shape = fluid.layers.shape(output)
an_num = len(anchors) // 2
iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
output_shape[3]))
iou_mask.stop_gradient = True
# NOTE: tobj holds gt_score, obj_mask holds object existence mask
obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
# For positive objectness grids, objectness loss should be calculated
# For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
obj_mask)
loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
loss_obj_neg = fluid.layers.reduce_sum(
loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
return loss_obj_pos, loss_obj_neg
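The positive/negative split in _calc_obj_loss reduces to two masks: tobj (the gt score, nonzero only at positive cells) weights the positive term, while (1 - obj_mask) * iou_mask selects negatives whose best IoU against any gt stays under ignore_thresh. A numpy sketch of that bookkeeping with made-up per-cell values:

import numpy as np

ignore_thresh = 0.7
tobj    = np.array([0.9, 0.0, 0.0, 0.0])   # gt score, > 0 only at positive cells
max_iou = np.array([0.9, 0.8, 0.3, 0.1])   # best IoU of each predicted box vs any gt
loss    = np.array([0.2, 0.5, 0.4, 0.6])   # per-cell sigmoid CE against obj_mask

obj_mask = (tobj > 0).astype(np.float32)
iou_mask = (max_iou <= ignore_thresh).astype(np.float32)

loss_pos = np.sum(loss * tobj)                        # positives, weighted by gt score
loss_neg = np.sum(loss * (1 - obj_mask) * iou_mask)   # negatives that are not ignored
print(loss_pos, loss_neg)  # 0.18 1.0: the 0.8-IoU cell contributes nothing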
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numbers import Integral
import math
import six
import paddle
from paddle import fluid
def DropBlock(input, block_size, keep_prob, is_test):
if is_test:
return input
def CalculateGamma(input, block_size, keep_prob):
input_shape = fluid.layers.shape(input)
feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
feat_area = fluid.layers.pow(feat_shape_t, factor=2)
block_shape_t = fluid.layers.fill_constant(
shape=[1, 1, 1, 1], value=block_size, dtype='float32')
block_area = fluid.layers.pow(block_shape_t, factor=2)
useful_shape_t = feat_shape_t - block_shape_t + 1
useful_area = fluid.layers.pow(useful_shape_t, factor=2)
upper_t = feat_area * (1 - keep_prob)
bottom_t = block_area * useful_area
output = upper_t / bottom_t
return output
gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
input_shape = fluid.layers.shape(input)
p = fluid.layers.expand_as(gamma, input)
input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
random_matrix = fluid.layers.uniform_random(
input_shape_tmp, dtype='float32', min=0.0, max=1.0, seed=1000)
one_zero_m = fluid.layers.less_than(random_matrix, p)
one_zero_m.stop_gradient = True
one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")
mask_flag = fluid.layers.pool2d(
one_zero_m,
pool_size=block_size,
pool_type='max',
pool_stride=1,
pool_padding=block_size // 2)
mask = 1.0 - mask_flag
elem_numel = fluid.layers.reduce_prod(input_shape)
elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
elem_numel_m.stop_gradient = True
elem_sum = fluid.layers.reduce_sum(mask)
elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
elem_sum_m.stop_gradient = True
output = input * mask * elem_numel_m / elem_sum_m
return output
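CalculateGamma above implements the DropBlock seed rate gamma = (1 - keep_prob) / block_size^2 * feat_size^2 / (feat_size - block_size + 1)^2, chosen so the expected dropped fraction lands near 1 - keep_prob once each Bernoulli seed is grown to a block by the max-pool. A worked numpy example of the same formula:

import numpy as np

def dropblock_gamma(feat_size, block_size, keep_prob):
    # seed probability for the Bernoulli mask centers
    feat_area = feat_size ** 2
    block_area = block_size ** 2
    useful_area = (feat_size - block_size + 1) ** 2
    return feat_area * (1 - keep_prob) / (block_area * useful_area)

gamma = dropblock_gamma(feat_size=13, block_size=3, keep_prob=0.9)
print(gamma)  # ~ 0.0155: about 2.6 seeds on a 13x13 map, each grown to a 3x3 block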
class MultiClassNMS(object):
def __init__(self,
score_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.normalized = normalized
self.nms_eta = nms_eta
self.background_label = background_label
def __call__(self, bboxes, scores):
return fluid.layers.multiclass_nms(
bboxes=bboxes,
scores=scores,
score_threshold=self.score_threshold,
nms_top_k=self.nms_top_k,
keep_top_k=self.keep_top_k,
normalized=self.normalized,
nms_threshold=self.nms_threshold,
nms_eta=self.nms_eta,
background_label=self.background_label)
class MatrixNMS(object):
def __init__(self,
score_threshold=.05,
post_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
use_gaussian=False,
gaussian_sigma=2.,
normalized=False,
background_label=0):
super(MatrixNMS, self).__init__()
self.score_threshold = score_threshold
self.post_threshold = post_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.normalized = normalized
self.use_gaussian = use_gaussian
self.gaussian_sigma = gaussian_sigma
self.background_label = background_label
def __call__(self, bboxes, scores):
return paddle.fluid.layers.matrix_nms(
bboxes=bboxes,
scores=scores,
score_threshold=self.score_threshold,
post_threshold=self.post_threshold,
nms_top_k=self.nms_top_k,
keep_top_k=self.keep_top_k,
normalized=self.normalized,
use_gaussian=self.use_gaussian,
gaussian_sigma=self.gaussian_sigma,
background_label=self.background_label)
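Matrix NMS replaces hard suppression with a parallel score decay: each box keeps score * min over higher-scored boxes i of f(iou_ij) / f(comp_i), where comp_i is how overlapped the suppressor itself is and f is linear or gaussian. A compact numpy sketch of the gaussian variant for a single class with boxes pre-sorted by score (a simplified illustration, not the fluid kernel):

import numpy as np

def matrix_nms(scores, ious, sigma=2.0):
    # scores sorted descending; ious[i, j] = IoU between boxes i and j
    n = len(scores)
    iou = np.triu(ious, k=1)        # keep pairs i < j: suppressor vs suppressed
    comp = iou.max(axis=0)          # how overlapped each suppressor is itself
    decay = np.exp(-(iou ** 2 - comp[:, None] ** 2) / sigma)
    decay[np.tril_indices(n)] = 1.0  # only higher-scored boxes decay others
    return scores * decay.min(axis=0)

scores = np.array([0.9, 0.8])
ious = np.array([[1.0, 0.9], [0.9, 1.0]])
print(matrix_nms(scores, ious))  # ~ [0.9, 0.53]: the duplicate is decayed, not removed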
class MultiClassSoftNMS(object):
def __init__(
self,
score_threshold=0.01,
keep_top_k=300,
softnms_sigma=0.5,
normalized=False,
background_label=0, ):
super(MultiClassSoftNMS, self).__init__()
self.score_threshold = score_threshold
self.keep_top_k = keep_top_k
self.softnms_sigma = softnms_sigma
self.normalized = normalized
self.background_label = background_label
def __call__(self, bboxes, scores):
def create_tmp_var(program, name, dtype, shape, lod_level):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _soft_nms_for_cls(dets, sigma, thres):
"""soft_nms_for_cls"""
dets_final = []
while len(dets) > 0:
maxpos = np.argmax(dets[:, 0])
dets_final.append(dets[maxpos].copy())
ts, tx1, ty1, tx2, ty2 = dets[maxpos]
scores = dets[:, 0]
# force remove bbox at maxpos
scores[maxpos] = -1
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas + areas[maxpos] - inter)
weight = np.exp(-(ovr * ovr) / sigma)
scores = scores * weight
idx_keep = np.where(scores >= thres)
dets[:, 0] = scores
dets = dets[idx_keep]
dets_final = np.array(dets_final).reshape(-1, 5)
return dets_final
def _soft_nms(bboxes, scores):
class_nums = scores.shape[-1]
softnms_thres = self.score_threshold
softnms_sigma = self.softnms_sigma
keep_top_k = self.keep_top_k
cls_boxes = [[] for _ in range(class_nums)]
cls_ids = [[] for _ in range(class_nums)]
start_idx = 1 if self.background_label == 0 else 0
for j in range(start_idx, class_nums):
inds = np.where(scores[:, j] >= softnms_thres)[0]
scores_j = scores[inds, j]
rois_j = bboxes[inds, j, :] if len(
bboxes.shape) > 2 else bboxes[inds, :]
dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
np.float32, copy=False)
cls_rank = np.argsort(-dets_j[:, 0])
dets_j = dets_j[cls_rank]
cls_boxes[j] = _soft_nms_for_cls(
dets_j, sigma=softnms_sigma, thres=softnms_thres)
cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
1)
cls_boxes = np.vstack(cls_boxes[start_idx:])
cls_ids = np.vstack(cls_ids[start_idx:])
pred_result = np.hstack([cls_ids, cls_boxes])
# Limit to max_per_image detections **over all classes**
image_scores = cls_boxes[:, 0]
if len(image_scores) > keep_top_k:
image_thresh = np.sort(image_scores)[-keep_top_k]
keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
pred_result = pred_result[keep, :]
return pred_result
def _batch_softnms(bboxes, scores):
batch_offsets = bboxes.lod()
bboxes = np.array(bboxes)
scores = np.array(scores)
out_offsets = [0]
pred_res = []
if len(batch_offsets) > 0:
batch_offset = batch_offsets[0]
for i in range(len(batch_offset) - 1):
s, e = batch_offset[i], batch_offset[i + 1]
pred = _soft_nms(bboxes[s:e], scores[s:e])
out_offsets.append(pred.shape[0] + out_offsets[-1])
pred_res.append(pred)
else:
assert len(bboxes.shape) == 3
assert len(scores.shape) == 3
for i in range(bboxes.shape[0]):
pred = _soft_nms(bboxes[i], scores[i])
out_offsets.append(pred.shape[0] + out_offsets[-1])
pred_res.append(pred)
res = fluid.LoDTensor()
res.set_lod([out_offsets])
if len(pred_res) == 0:
pred_res = np.array([[1]], dtype=np.float32)
res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
return res
pred_result = create_tmp_var(
fluid.default_main_program(),
name='softnms_pred_result',
dtype='float32',
shape=[-1, 6],
lod_level=1)
fluid.layers.py_func(
func=_batch_softnms, x=[bboxes, scores], out=pred_result)
return pred_result
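For contrast, _soft_nms_for_cls above uses the classic sequential gaussian re-weighting: after picking the current top box, every remaining score is multiplied by exp(-iou^2 / sigma). A worked numpy example of one such iteration:

import numpy as np

sigma = 0.5
ious = np.array([0.0, 0.3, 0.8])    # IoU of remaining boxes vs the picked box
scores = np.array([0.6, 0.5, 0.9])
weights = np.exp(-(ious ** 2) / sigma)
print(scores * weights)  # ~ [0.6, 0.42, 0.25]: heavy overlap is damped, not discarded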
@@ -16,25 +16,50 @@ from paddle import fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
 from collections import OrderedDict
+from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
+from .ops import DropBlock
+from .loss.yolo_loss import YOLOv3Loss
+from .loss.iou_loss import IouLoss
+from .loss.iou_aware_loss import IouAwareLoss
+from .iou_aware import get_iou_aware_score
+try:
+    from collections.abc import Sequence
+except Exception:
+    from collections import Sequence


 class YOLOv3:
-    def __init__(self,
-                 backbone,
-                 num_classes,
-                 mode='train',
-                 anchors=None,
-                 anchor_masks=None,
-                 ignore_threshold=0.7,
-                 label_smooth=False,
-                 nms_score_threshold=0.01,
-                 nms_topk=1000,
-                 nms_keep_topk=100,
-                 nms_iou_threshold=0.45,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ],
-                 fixed_input_shape=None):
+    def __init__(
+            self,
+            backbone,
+            mode='train',
+            # YOLOv3Head
+            num_classes=80,
+            anchors=None,
+            anchor_masks=None,
+            coord_conv=False,
+            iou_aware=False,
+            iou_aware_factor=0.4,
+            scale_x_y=1.,
+            spp=False,
+            drop_block=False,
+            use_matrix_nms=False,
+            # YOLOv3Loss
+            batch_size=8,
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_fine_grained_loss=False,
+            use_iou_loss=False,
+            iou_loss_weight=2.5,
+            iou_aware_loss_weight=1.0,
+            max_height=608,
+            max_width=608,
+            # NMS
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            fixed_input_shape=None):
         if anchors is None:
             anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
@@ -46,56 +71,114 @@ class YOLOv3:
         self.mode = mode
         self.num_classes = num_classes
         self.backbone = backbone
-        self.ignore_thresh = ignore_threshold
-        self.label_smooth = label_smooth
-        self.nms_score_threshold = nms_score_threshold
-        self.nms_topk = nms_topk
-        self.nms_keep_topk = nms_keep_topk
-        self.nms_iou_threshold = nms_iou_threshold
         self.norm_decay = 0.0
         self.prefix_name = ''
-        self.train_random_shapes = train_random_shapes
+        self.use_fine_grained_loss = use_fine_grained_loss
         self.fixed_input_shape = fixed_input_shape
+        self.coord_conv = coord_conv
+        self.iou_aware = iou_aware
+        self.iou_aware_factor = iou_aware_factor
+        self.scale_x_y = scale_x_y
+        self.use_spp = spp
+        self.drop_block = drop_block
+        if use_matrix_nms:
+            self.nms = MatrixNMS(
+                background_label=-1,
+                keep_top_k=nms_keep_topk,
+                normalized=False,
+                score_threshold=nms_score_threshold,
+                post_threshold=0.01)
+        else:
+            self.nms = MultiClassNMS(
+                background_label=-1,
+                keep_top_k=nms_keep_topk,
+                nms_threshold=nms_iou_threshold,
+                nms_top_k=nms_topk,
+                normalized=False,
+                score_threshold=nms_score_threshold)
+        self.iou_loss = None
+        self.iou_aware_loss = None
+        if use_iou_loss:
+            self.iou_loss = IouLoss(
+                loss_weight=iou_loss_weight,
+                max_height=max_height,
+                max_width=max_width)
+        if iou_aware:
+            self.iou_aware_loss = IouAwareLoss(
+                loss_weight=iou_aware_loss_weight,
+                max_height=max_height,
+                max_width=max_width)
+        self.yolo_loss = YOLOv3Loss(
+            batch_size=batch_size,
+            ignore_thresh=ignore_threshold,
+            scale_x_y=scale_x_y,
+            label_smooth=label_smooth,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            iou_loss=self.iou_loss,
+            iou_aware_loss=self.iou_aware_loss)
+        self.conv_block_num = 2
+        self.block_size = 3
+        self.keep_prob = 0.9
+        self.downsample = [32, 16, 8]
+        self.clip_bbox = True

-    def _head(self, feats):
+    def _head(self, input, is_train=True):
         outputs = []
+        # get last out_layer_num blocks in reverse order
         out_layer_num = len(self.anchor_masks)
-        blocks = feats[-1:-out_layer_num - 1:-1]
+        blocks = input[-1:-out_layer_num - 1:-1]
         route = None
         for i, block in enumerate(blocks):
-            if i > 0:
+            if i > 0:  # perform concat in first 2 detection_block
                 block = fluid.layers.concat(input=[route, block], axis=1)
             route, tip = self._detection_block(
                 block,
-                channel=512 // (2**i),
-                name=self.prefix_name + 'yolo_block.{}'.format(i))
+                channel=64 * (2**out_layer_num) // (2**i),
+                is_first=i == 0,
+                is_test=(not is_train),
+                conv_block_num=self.conv_block_num,
+                name=self.prefix_name + "yolo_block.{}".format(i))
-            num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
-            block_out = fluid.layers.conv2d(
-                input=tip,
-                num_filters=num_filters,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                act=None,
-                param_attr=ParamAttr(name=self.prefix_name +
-                                     'yolo_output.{}.conv.weights'.format(i)),
-                bias_attr=ParamAttr(
-                    regularizer=L2Decay(0.0),
-                    name=self.prefix_name +
-                    'yolo_output.{}.conv.bias'.format(i)))
-            outputs.append(block_out)
+            # out channel number = mask_num * (5 + class_num)
+            if self.iou_aware:
+                num_filters = len(self.anchor_masks[i]) * (
+                    self.num_classes + 6)
+            else:
+                num_filters = len(self.anchor_masks[i]) * (
+                    self.num_classes + 5)
+            with fluid.name_scope('yolo_output'):
+                block_out = fluid.layers.conv2d(
+                    input=tip,
+                    num_filters=num_filters,
+                    filter_size=1,
+                    stride=1,
+                    padding=0,
+                    act=None,
+                    param_attr=ParamAttr(
+                        name=self.prefix_name +
+                        "yolo_output.{}.conv.weights".format(i)),
+                    bias_attr=ParamAttr(
+                        regularizer=L2Decay(0.),
+                        name=self.prefix_name +
+                        "yolo_output.{}.conv.bias".format(i)))
+                outputs.append(block_out)
             if i < len(blocks) - 1:
+                # do not perform upsample in the last detection_block
                 route = self._conv_bn(
                     input=route,
                     ch_out=256 // (2**i),
                     filter_size=1,
                     stride=1,
                     padding=0,
-                    name=self.prefix_name + 'yolo_transition.{}'.format(i))
+                    is_test=(not is_train),
+                    name=self.prefix_name + "yolo_transition.{}".format(i))
+                # upsample
                 route = self._upsample(route)
         return outputs

     def _parse_anchors(self, anchors):
@@ -116,6 +199,54 @@ class YOLOv3:
             assert mask < anchor_num, "anchor mask index overflow"
             self.mask_anchors[-1].extend(anchors[mask])

+    def _create_tensor_from_numpy(self, numpy_array):
+        paddle_array = fluid.layers.create_global_var(
+            shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
+        fluid.layers.assign(numpy_array, paddle_array)
+        return paddle_array
+
+    def _add_coord(self, input, is_test=True):
+        if not self.coord_conv:
+            return input
+        # NOTE: here is used for exporting model for TensorRT inference,
+        # only support batch_size=1 for input shape should be fixed,
+        # and we create tensor with fixed shape from numpy array
+        if is_test and input.shape[2] > 0 and input.shape[3] > 0:
+            batch_size = 1
+            grid_x = int(input.shape[3])
+            grid_y = int(input.shape[2])
+            idx_i = np.array(
+                [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
+                dtype='float32')
+            gi_np = np.repeat(idx_i, grid_y, axis=0)
+            gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
+            gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
+            x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
+            x_range.stop_gradient = True
+            y_range = self._create_tensor_from_numpy(
+                gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
+            y_range.stop_gradient = True
+        # NOTE: in training mode, H and W is variable for random shape,
+        # implement add_coord with shape as Variable
+        else:
+            input_shape = fluid.layers.shape(input)
+            b = input_shape[0]
+            h = input_shape[2]
+            w = input_shape[3]
+            x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
+            x_range = x_range - 1.
+            x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
+            x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
+            x_range.stop_gradient = True
+            y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
+            y_range.stop_gradient = True
+        return fluid.layers.concat([input, x_range, y_range], axis=1)
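The training-mode branch of _add_coord appends two channels holding x and y coordinates normalized to [-1, 1], the CoordConv idea. A numpy sketch of the same construction for a fixed shape:

import numpy as np

def add_coord(feat):
    # feat: [N, C, H, W]; append normalized x / y coordinate channels
    n, _, h, w = feat.shape
    x_range = np.linspace(-1, 1, w, dtype=np.float32)
    y_range = np.linspace(-1, 1, h, dtype=np.float32)
    x_chan = np.tile(x_range, (n, 1, h, 1))
    y_chan = np.tile(y_range[None, None, :, None], (n, 1, 1, w))
    return np.concatenate([feat, x_chan, y_chan], axis=1)

feat = np.zeros((2, 8, 4, 4), np.float32)
print(add_coord(feat).shape)  # (2, 10, 4, 4): two extra coordinate channels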
     def _conv_bn(self,
                  input,
                  ch_out,
...@@ -151,18 +282,52 @@ class YOLOv3: ...@@ -151,18 +282,52 @@ class YOLOv3:
out = fluid.layers.leaky_relu(x=out, alpha=0.1) out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out return out
def _spp_module(self, input, is_test=True, name=""):
output1 = input
output2 = fluid.layers.pool2d(
input=output1,
pool_size=5,
pool_stride=1,
pool_padding=2,
ceil_mode=False,
pool_type='max')
output3 = fluid.layers.pool2d(
input=output1,
pool_size=9,
pool_stride=1,
pool_padding=4,
ceil_mode=False,
pool_type='max')
output4 = fluid.layers.pool2d(
input=output1,
pool_size=13,
pool_stride=1,
pool_padding=6,
ceil_mode=False,
pool_type='max')
output = fluid.layers.concat(
input=[output1, output2, output3, output4], axis=1)
return output
def _upsample(self, input, scale=2, name=None): def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest( out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name) input=input, scale=float(scale), name=name)
return out return out
    def _detection_block(self,
                         input,
                         channel,
                         conv_block_num=2,
                         is_first=False,
                         is_test=True,
                         name=None):
        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2 in detection block {}" \
            .format(channel, name)
        is_test = False if self.mode == 'train' else True

        conv = input
        for j in range(conv_block_num):
            conv = self._add_coord(conv, is_test=is_test)
            conv = self._conv_bn(
                conv,
                channel,
...@@ -170,7 +335,17 @@ class YOLOv3:
                stride=1,
                padding=0,
                is_test=is_test,
                name='{}.{}.0'.format(name, j))
            if self.use_spp and is_first and j == 1:
                conv = self._spp_module(conv, is_test=is_test, name="spp")
                conv = self._conv_bn(
                    conv,
                    512,
                    filter_size=1,
                    stride=1,
                    padding=0,
                    is_test=is_test,
                    name='{}.{}.spp.conv'.format(name, j))
            conv = self._conv_bn(
                conv,
                channel * 2,
...@@ -178,7 +353,21 @@ class YOLOv3:
                stride=1,
                padding=1,
                is_test=is_test,
                name='{}.{}.1'.format(name, j))
            if self.drop_block and j == 0 and not is_first:
                conv = DropBlock(
                    conv,
                    block_size=self.block_size,
                    keep_prob=self.keep_prob,
                    is_test=is_test)

        if self.drop_block and is_first:
            conv = DropBlock(
                conv,
                block_size=self.block_size,
                keep_prob=self.keep_prob,
                is_test=is_test)

        conv = self._add_coord(conv, is_test=is_test)
        route = self._conv_bn(
            conv,
            channel,
...@@ -187,8 +376,9 @@ class YOLOv3:
            padding=0,
            is_test=is_test,
            name='{}.2'.format(name))
        new_route = self._add_coord(route, is_test=is_test)
        tip = self._conv_bn(
            new_route,
            channel * 2,
            filter_size=3,
            stride=1,
...@@ -197,54 +387,44 @@ class YOLOv3:
            name='{}.tip'.format(name))
        return route, tip
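`DropBlock`, used above, zeroes out contiguous `block_size` x `block_size` regions of the feature map rather than independent activations. A hedged single-sample NumPy sketch, assuming the helper follows the DropBlock paper's algorithm (`drop_block_np` is illustrative, not the library implementation):

import numpy as np
from scipy.ndimage import maximum_filter

def drop_block_np(x, block_size=3, keep_prob=0.9):
    # x: (C, H, W). Sample block centers with probability gamma, grow each
    # seed into a block_size x block_size square, then rescale what is kept.
    c, h, w = x.shape
    gamma = ((1. - keep_prob) / block_size**2) * (
        (h * w) / float((h - block_size + 1) * (w - block_size + 1)))
    seeds = (np.random.rand(c, h, w) < gamma).astype('float32')
    dropped = maximum_filter(seeds, size=(1, block_size, block_size))
    mask = 1. - dropped
    return x * mask * (mask.size / max(mask.sum(), 1.))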
    def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
        loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
                              self.anchors, self.anchor_masks,
                              self.mask_anchors, self.num_classes,
                              self.prefix_name)
        total_loss = fluid.layers.sum(list(loss.values()))
        return total_loss
    def _get_prediction(self, inputs, im_size):
        boxes = []
        scores = []
        for i, input in enumerate(inputs):
            if self.iou_aware:
                input = get_iou_aware_score(input,
                                            len(self.anchor_masks[i]),
                                            self.num_classes,
                                            self.iou_aware_factor)
            scale_x_y = self.scale_x_y if not isinstance(
                self.scale_x_y, Sequence) else self.scale_x_y[i]
            box, score = fluid.layers.yolo_box(
                x=input,
                img_size=im_size,
                anchors=self.mask_anchors[i],
                class_num=self.num_classes,
                conf_thresh=self.nms.score_threshold,
                downsample_ratio=self.downsample[i],
                name=self.prefix_name + 'yolo_box' + str(i),
                clip_bbox=self.clip_bbox,
                scale_x_y=scale_x_y)
            boxes.append(box)
            scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
        yolo_boxes = fluid.layers.concat(boxes, axis=1)
        yolo_scores = fluid.layers.concat(scores, axis=2)
        if type(self.nms) is MultiClassSoftNMS:
            yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
        pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
        return pred
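`get_iou_aware_score` fuses the extra IoU-prediction branch with the objectness score before `yolo_box` decodes the boxes. Assuming it follows the IoU-aware formulation used by PPYOLO (an assumption, not verified against this commit), the re-scoring is simply:

import numpy as np

def iou_aware_score(obj, iou_logit, factor=0.4):
    # score = obj^(1 - factor) * sigmoid(iou_logit)^factor
    pred_iou = 1. / (1. + np.exp(-iou_logit))
    return obj ** (1. - factor) * pred_iou ** factor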
    def generate_inputs(self):
...@@ -267,6 +447,25 @@ class YOLOv3:
                dtype='float32', shape=[None, None], name='gt_score')
            inputs['im_size'] = fluid.data(
                dtype='int32', shape=[None, 2], name='im_size')
            if self.use_fine_grained_loss:
                downsample = 32
                for i, mask in enumerate(self.anchor_masks):
                    if self.fixed_input_shape is not None:
                        target_shape = [
                            self.fixed_input_shape[1] // downsample,
                            self.fixed_input_shape[0] // downsample
                        ]
                    else:
                        target_shape = [None, None]
                    inputs['target{}'.format(i)] = fluid.data(
                        dtype='float32',
                        lod_level=0,
                        shape=[
                            None, len(mask), 6 + self.num_classes,
                            target_shape[0], target_shape[1]
                        ],
                        name='target{}'.format(i))
                    downsample //= 2
        elif self.mode == 'eval':
            inputs['im_size'] = fluid.data(
                dtype='int32', shape=[None, 2], name='im_size')
...@@ -284,44 +483,37 @@ class YOLOv3:
        return inputs
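To make the target shapes concrete: with `fixed_input_shape=[608, 608]`, three anchor masks of three anchors each, and `num_classes=80` (illustrative numbers, not taken from the commit), the placeholders created above would be:

# target0: [None, 3, 86, 19, 19]   # 6 + 80 = 86 channels, 608 // 32 = 19
# target1: [None, 3, 86, 38, 38]   # 608 // 16 = 38
# target2: [None, 3, 86, 76, 76]   # 608 // 8  = 76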
    def build_net(self, inputs):
        image = inputs['image']
        feats = self.backbone(image)
        if isinstance(feats, OrderedDict):
            feat_names = list(feats.keys())
            feats = [feats[name] for name in feat_names]
        head_outputs = self._head(feats, self.mode == 'train')
        if self.mode == 'train':
            gt_box = inputs['gt_box']
            gt_label = inputs['gt_label']
            gt_score = inputs['gt_score']
            im_size = inputs['im_size']
            #num_boxes = fluid.layers.shape(gt_box)[1]
            #im_size_wh = fluid.layers.reverse(im_size, axis=1)
            #whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
            #whwh = fluid.layers.unsqueeze(whwh, axes=[1])
            #whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
            #whwh = fluid.layers.cast(whwh, dtype='float32')
            #whwh.stop_gradient = True
            #normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
            normalized_box = gt_box
            targets = []
            if self.use_fine_grained_loss:
                for i, mask in enumerate(self.anchor_masks):
                    k = 'target{}'.format(i)
                    if k in inputs:
                        targets.append(inputs[k])
            return self._get_loss(head_outputs, normalized_box, gt_label,
                                  gt_score, targets)
        else:
            im_size = inputs['im_size']
            return self._get_prediction(head_outputs, im_size)
...@@ -55,6 +55,7 @@ class Compose(DetTransform):
            raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
        self.transforms = transforms
        self.batch_transforms = None
        self.use_mixup = False
        for t in self.transforms:
            if type(t).__name__ == 'MixupImage':
...@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
                mean=mean, std=std)
        ]
        super(ComposedYOLOv3Transforms, self).__init__(transforms)
class BatchRandomShape(DetTransform):
    """Resize all images in a batch.

    Every image in the batch is resized to one size randomly chosen from
    random_shapes. Note: when the interpolation mode is 'RANDOM', an
    interpolation method is picked at random for the resize.

    Args:
        random_shapes (list): candidate sizes to resize to. Defaults to
            [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
        interp (str): interpolation mode for the resize, matching OpenCV's
            modes; one of ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4',
            'RANDOM']. Defaults to 'RANDOM'.

    Raises:
        ValueError: if the interpolation mode is not in ['NEAREST', 'LINEAR',
            'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'].
    """
    # The interpolation modes
    interp_dict = {
        'NEAREST': cv2.INTER_NEAREST,
        'LINEAR': cv2.INTER_LINEAR,
        'CUBIC': cv2.INTER_CUBIC,
        'AREA': cv2.INTER_AREA,
        'LANCZOS4': cv2.INTER_LANCZOS4
    }

    def __init__(
            self,
            random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
            interp='RANDOM'):
        if not (interp == "RANDOM" or interp in self.interp_dict):
            raise ValueError("interp should be one of {}".format(
                self.interp_dict.keys()))
        self.random_shapes = random_shapes
        self.interp = interp

    def __call__(self, batch_data):
        """
        Args:
            batch_data (list): a batch of samples, each holding an image and
                its associated annotation fields.

        Returns:
            list: the batch with every image resized to the sampled size.
        """
        shape = np.random.choice(self.random_shapes)
        if self.interp == "RANDOM":
            interp = random.choice(list(self.interp_dict.keys()))
        else:
            interp = self.interp
        for data_id, data in enumerate(batch_data):
            data_list = list(data)
            im = data_list[0]
            # CHW -> HWC for the resize, then back to CHW
            im = np.swapaxes(im, 1, 0)
            im = np.swapaxes(im, 1, 2)
            im = resize(im, shape, self.interp_dict[interp])
            im = np.swapaxes(im, 1, 2)
            im = np.swapaxes(im, 1, 0)
            data_list[0] = im
            batch_data[data_id] = tuple(data_list)
        return batch_data
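A hedged usage sketch, assuming the transform is wired in through the `batch_transforms` field added to `Compose` above so that `generate_minibatch` applies it per batch; the values are illustrative:

batch_random_shape = BatchRandomShape(
    random_shapes=[320, 416, 608], interp='RANDOM')
# inside the reader, after a batch has been collected:
# batch_data = batch_random_shape(batch_data)  # every image -> same new HxW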
class GenerateYoloTarget(object):
    """Generate the per-feature-level position encodings of YOLOv3 ground
    truth boxes.

    This transform is only used when YOLOv3 computes its fine-grained loss.

    Args:
        anchors (list|tuple): widths and heights of the anchor boxes.
        anchor_masks (list|tuple): indices of the anchors used by each
            feature level when computing the loss.
        downsample_ratios (list|tuple): downsample ratio of each feature
            level relative to the input image.
        num_classes (int): number of classes. Defaults to 80.
        iou_thresh (float): IoU threshold; an anchor whose IoU with a ground
            truth box exceeds this threshold is also counted into the target.
            Defaults to 1.0.
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 num_classes=80,
                 iou_thresh=1.):
        super(GenerateYoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh

    def __call__(self, batch_data):
        """
        Args:
            batch_data (list): a batch of samples, each holding an image and
                its associated annotation fields.

        Returns:
            list: the batch with the following fields appended to every
                sample:
                - target0 (np.ndarray): position encoding of the ground truth
                  at feature level 0, with shape (number of anchors at level
                  0, 6 + number of classes, h of level 0, w of level 0).
                - target1 (np.ndarray): position encoding of the ground truth
                  at feature level 1, with shape (number of anchors at level
                  1, 6 + number of classes, h of level 1, w of level 1).
                - ...
                - targetn (np.ndarray): position encoding of the ground truth
                  at feature level n, with shape (number of anchors at level
                  n, 6 + number of classes, h of level n, w of level n).
                n is determined by the length of anchor_masks.
        """
        im = batch_data[0][0]
        h = im.shape[1]
        w = im.shape[2]
        an_hw = np.array(self.anchors) / np.array([[w, h]])
        for data_id, data in enumerate(batch_data):
            gt_bbox = data[1]
            gt_class = data[2]
            gt_score = data[3]
            im_shape = data[4]
            origin_h = float(im_shape[0])
            origin_w = float(im_shape[1])
            data_list = list(data)
            for i, (mask, downsample_ratio) in enumerate(
                    zip(self.anchor_masks, self.downsample_ratios)):
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)
                for b in range(gt_bbox.shape[0]):
                    gx = gt_bbox[b, 0] / float(origin_w)
                    gy = gt_bbox[b, 1] / float(origin_h)
                    gw = gt_bbox[b, 2] / float(origin_w)
                    gh = gt_bbox[b, 3] / float(origin_h)
                    cls = gt_class[b]
                    score = gt_score[b]
                    if gw <= 0. or gh <= 0. or score <= 0.:
                        continue
                    # find the best matching anchor index
                    best_iou = 0.
                    best_idx = -1
                    for an_idx in range(an_hw.shape[0]):
                        iou = jaccard_overlap(
                            [0., 0., gw, gh],
                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                        if iou > best_iou:
                            best_iou = iou
                            best_idx = an_idx
                    gi = int(gx * grid_w)
                    gj = int(gy * grid_h)
                    # the gt box should be regressed at this level if the
                    # best matching anchor index is in this level's mask
                    if best_idx in mask:
                        best_n = mask.index(best_idx)
                        # x, y, w, h, scale
                        target[best_n, 0, gj, gi] = gx * grid_w - gi
                        target[best_n, 1, gj, gi] = gy * grid_h - gj
                        target[best_n, 2, gj, gi] = np.log(
                            gw * w / self.anchors[best_idx][0])
                        target[best_n, 3, gj, gi] = np.log(
                            gh * h / self.anchors[best_idx][1])
                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
                        # objectness records gt_score
                        target[best_n, 5, gj, gi] = score
                        # classification
                        target[best_n, 6 + cls, gj, gi] = 1.
                    # for non-matched anchors, also calculate the target if
                    # the IoU between anchor and gt is larger than iou_thresh
                    if self.iou_thresh < 1:
                        for idx, mask_i in enumerate(mask):
                            if mask_i == best_idx:
                                continue
                            iou = jaccard_overlap(
                                [0., 0., gw, gh],
                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
                            if iou > self.iou_thresh:
                                # x, y, w, h, scale
                                target[idx, 0, gj, gi] = gx * grid_w - gi
                                target[idx, 1, gj, gi] = gy * grid_h - gj
                                target[idx, 2, gj, gi] = np.log(
                                    gw * w / self.anchors[mask_i][0])
                                target[idx, 3, gj, gi] = np.log(
                                    gh * h / self.anchors[mask_i][1])
                                target[idx, 4, gj, gi] = 2.0 - gw * gh
                                # objectness records gt_score
                                target[idx, 5, gj, gi] = score
                                # classification
                                target[idx, 6 + cls, gj, gi] = 1.
                data_list.append(target)
            batch_data[data_id] = tuple(data_list)
        return batch_data
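An illustrative instantiation for a standard three-level YOLOv3 head (the anchor values are the usual COCO defaults, assumed here rather than taken from this commit); running it appends target0..target2 to every sample, matching the `target{i}` placeholders declared in `generate_inputs`:

gen_target = GenerateYoloTarget(
    anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
             [59, 119], [116, 90], [156, 198], [373, 326]],
    anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
    downsample_ratios=[32, 16, 8],
    num_classes=80,
    iou_thresh=1.)
# batch_data = gen_target(batch_data)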
# Environment configuration, used to control whether the GPU is used
# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from paddlex.det import transforms
import paddlex as pdx

# Download and extract the insect detection dataset
insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx.utils.download_and_decompress(insect_dataset, path='./')

# Define the transforms for training and validation
# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
train_transforms = transforms.Compose([
    transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
    transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
        target_size=608, interp='RANDOM'), transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])

eval_transforms = transforms.Compose([
    transforms.Resize(
        target_size=608, interp='CUBIC'), transforms.Normalize()
])

# Define the datasets used for training and validation
# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
train_dataset = pdx.datasets.VOCDetection(
    data_dir='insect_det',
    file_list='insect_det/train_list.txt',
    label_list='insect_det/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.VOCDetection(
    data_dir='insect_det',
    file_list='insect_det/val_list.txt',
    label_list='insect_det/labels.txt',
    transforms=eval_transforms)

# Initialize the model and start training
# Training metrics can be inspected with VisualDL, see
# https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)

# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
model = pdx.det.YOLOv3(
    num_classes=num_classes,
    backbone='ResNet50_vd',
    with_dcn_v2=True,
    use_coord_conv=True,
    use_iou_aware=True,
    use_spp=True,
    use_drop_block=True,
    scale_x_y=1.05,
    use_iou_loss=True,
    use_matrix_nms=True)

# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
# Parameter descriptions and tuning notes:
# https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=270,
    train_dataset=train_dataset,
    train_batch_size=8,
    eval_dataset=eval_dataset,
    learning_rate=0.000125,
    lr_decay_epochs=[210, 240],
    use_ema=True,
    save_dir='output/ppyolo',
    use_vdl=True)
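After training finishes, the saved model can be loaded back for inference; a minimal hedged sketch (not part of the commit; 'path/to/insect.jpg' is a placeholder, not a dataset file):

eval_model = pdx.load_model('output/ppyolo/best_model')
result = eval_model.predict('path/to/insect.jpg')
pdx.det.visualize('path/to/insect.jpg', result, threshold=0.3, save_dir='./')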