未验证 提交 4ce9ead4 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add fine grained yolov3 loss (#109)

* split yolov3 loss
上级 46abc77d
......@@ -8,6 +8,7 @@ metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/yolov3_r50vd_dcn/model_final
num_classes: 80
use_fine_grained_loss: false
YOLOv3:
backbone: ResNet
......@@ -29,8 +30,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -39,6 +39,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -8,6 +8,7 @@ metric: COCO
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_obj365_pretrained.tar
weights: output/yolov3_r50vd_dcn_obj365_pretrained_coco/model_final
num_classes: 80
use_fine_grained_loss: false
YOLOv3:
backbone: ResNet
......@@ -29,8 +30,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -39,6 +39,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -8,6 +8,7 @@ metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
weights: output/yolov3_darknet/model_final
num_classes: 80
use_fine_grained_loss: false
YOLOv3:
backbone: DarkNet
......@@ -24,8 +25,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -34,6 +34,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -9,6 +9,7 @@ map_type: 11point
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
weights: output/yolov3_darknet_voc/model_final
num_classes: 20
use_fine_grained_loss: false
YOLOv3:
backbone: DarkNet
......@@ -25,8 +26,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: false
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -35,6 +35,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: true
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -8,6 +8,7 @@ metric: COCO
pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
weights: output/yolov3_mobilenet_v1/model_final
num_classes: 80
use_fine_grained_loss: false
YOLOv3:
backbone: MobileNet
......@@ -25,8 +26,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -35,6 +35,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -10,6 +10,7 @@ pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mob
weights: output/yolov3_mobilenet_v1_fruit/best_model
num_classes: 3
finetune_exclude_pretrained_params: ['yolo_output']
use_fine_grained_loss: false
YOLOv3:
backbone: MobileNet
......@@ -27,8 +28,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -37,6 +37,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: true
LearningRate:
base_lr: 0.00001
schedulers:
......
......@@ -9,6 +9,7 @@ map_type: 11point
pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
weights: output/yolov3_mobilenet_v1_voc/model_final
num_classes: 20
use_fine_grained_loss: false
YOLOv3:
backbone: MobileNet
......@@ -26,8 +27,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: false
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -36,6 +36,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -8,6 +8,7 @@ metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar
weights: output/yolov3_r34/model_final
num_classes: 80
use_fine_grained_loss: false
YOLOv3:
backbone: ResNet
......@@ -27,8 +28,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -37,6 +37,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -9,6 +9,7 @@ map_type: 11point
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar
weights: output/yolov3_r34_voc/model_final
num_classes: 20
use_fine_grained_loss: false
YOLOv3:
backbone: ResNet
......@@ -28,8 +29,7 @@ YOLOv3Head:
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
ignore_thresh: 0.7
label_smooth: false
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
......@@ -38,6 +38,11 @@ YOLOv3Head:
normalized: false
score_threshold: 0.01
YOLOv3Loss:
batch_size: 8
ignore_thresh: 0.7
label_smooth: false
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -37,6 +37,15 @@ TrainReader:
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss is set to true;
# this operator is removed automatically when use_fine_grained_loss
# is set to false.
- !Gt2YoloTarget
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
downsample_ratios: [32, 16, 8]
batch_size: 8
shuffle: true
mixup_epoch: 250
......
......@@ -74,6 +74,19 @@ list below can be viewed by `--help`
finetune_exclude_pretrained_params = ['cls_score','bbox_pred']
```
- Training YOLOv3 with fine grained YOLOv3 loss built by Paddle OPs in python
To make it easier to redesign the YOLOv3 loss function, we also provide a fine grained YOLOv3 loss function built in Python code from common Paddle OPs instead of using `fluid.layers.yolov3_loss`.
Train YOLOv3 with the Python loss function as follows:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u tools/train.py -c configs/yolov3_darknet.yml \
-o use_fine_grained_loss=true
```
Fine grained YOLOv3 loss code is defined in `ppdet/modeling/losses/yolo_loss.py`.
##### NOTES
- `CUDA_VISIBLE_DEVICES` specifies which GPUs are used, for example: `export CUDA_VISIBLE_DEVICES=0,1,2,3`. For the GPU calculation rules, refer to the [FAQ](#faq)
......
......@@ -74,6 +74,19 @@ python tools/infer.py -c configs/faster_rcnn_r50_1x.yml --infer_img=demo/0000005
详细说明请参考[Transfer Learning](TRANSFER_LEARNING_cn.md)
- 使用Paddle OP组建的YOLOv3损失函数训练YOLOv3
为了便于用户重新设计修改YOLOv3的损失函数,我们也提供了不使用`fluid.layers.yolov3_loss`接口而是在python代码中使用Paddle OP的方式组建YOLOv3损失函数,
可通过如下命令用Paddle OP组建YOLOv3损失函数版本的YOLOv3模型:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u tools/train.py -c configs/yolov3_darknet.yml \
-o use_fine_grained_loss=true
```
Paddle OP组建YOLOv3损失函数代码位于`ppdet/modeling/losses/yolo_loss.py`
#### 提示
- `CUDA_VISIBLE_DEVICES` 参数可以指定不同的GPU。例如: `export CUDA_VISIBLE_DEVICES=0,1,2,3`. GPU计算规则可以参考 [FAQ](#faq)
......
......@@ -36,6 +36,11 @@ weights: output/yolov3_darknet/model_final
# Number of classes, 80 for COCO and 20 for VOC.
num_classes: 80
# Whether to use the fine grained YOLOv3 loss. If true, the YOLOv3 loss is built in Python code from common OPs;
# if false, the fluid.layers.yolov3_loss OP is used to calculate the YOLOv3 loss. The former is better
# for redesigning the YOLOv3 loss, the latter is better for training with the original YOLOv3 loss.
use_fine_grained_loss: false
# YOLOv3 architecture, see https://arxiv.org/abs/1804.02767
YOLOv3:
......@@ -63,12 +68,8 @@ YOLOv3Head:
[116, 90], [156, 198], [373, 326]]
# L2 weight decay factor of batch normalization layer
norm_decay: 0.
# Ignore threshold for yolo_loss layer, 0.7 by default.
# Objectness loss will be ignored if a predcition bbox overlap a gtbox over ignore_thresh.
ignore_thresh: 0.7
# Whether use label smooth in yolo_loss layer
# It is recommended to set as true when only num_classes is very big
label_smooth: true
# Use YOLOv3Loss, which is configured in the YOLOv3Loss section below.
yolo_loss: YOLOv3Loss
# fluid.layers.multiclass_nms
# Non-max suppress for output prediction boxes, see multiclass_nms for following parameters.
# 1. Select detection bounding boxes with high scores larger than score_threshold.
......@@ -89,6 +90,18 @@ YOLOv3Head:
# Threshold to filter out bounding boxes with low confidence score.
score_threshold: 0.01
YOLOv3Loss:
# Training batch size; this is used when use_fine_grained_loss is set to true.
# ATTENTION: this should be the same as the batch size defined in YoloTrainFeed in fine
# grained loss mode.
batch_size: 8
# Ignore threshold for yolo_loss layer, 0.7 by default.
# Objectness loss will be ignored if a prediction bbox overlaps a gt box by more than ignore_thresh.
ignore_thresh: 0.7
# Whether to use label smoothing in the yolo_loss layer.
# It is recommended to set this to true only when num_classes is very large.
label_smooth: false
# Learning rate configuration
LearningRate:
# Base learning rate for training, 1e-3 by default.
......
......@@ -26,6 +26,7 @@ import logging
from ppdet.core.workspace import register, serializable
from .parallel_map import ParallelMap
from .transform.batch_operators import Gt2YoloTarget
__all__ = ['Reader', 'create_reader']
......@@ -192,6 +193,8 @@ class Reader(object):
class_aware_sampling=False,
worker_num=-1,
use_process=False,
use_fine_grained_loss=False,
num_classes=80,
bufsize=100,
memsize='3G',
inputs_def=None):
......@@ -204,6 +207,17 @@ class Reader(object):
self._sample_transforms = Compose(sample_transforms,
{'fields': self._fields})
self._batch_transforms = None
if use_fine_grained_loss:
for bt in batch_transforms:
if isinstance(bt, Gt2YoloTarget):
bt.num_classes = num_classes
elif batch_transforms:
batch_transforms = [
bt for bt in batch_transforms
if not isinstance(bt, Gt2YoloTarget)
]
if batch_transforms:
self._batch_transforms = Compose(batch_transforms,
{'fields': self._fields})
......@@ -376,7 +390,7 @@ class Reader(object):
self._parallel.stop()
def create_reader(cfg, max_iter=0):
def create_reader(cfg, max_iter=0, global_cfg=None):
"""
Return iterable data reader.
......@@ -386,6 +400,11 @@ def create_reader(cfg, max_iter=0):
if not isinstance(cfg, dict):
raise TypeError("The config should be a dict when creating reader.")
# synchronize use_fine_grained_loss/num_classes from global_cfg to reader cfg
if global_cfg:
cfg['use_fine_grained_loss'] = getattr(global_cfg,
'use_fine_grained_loss', False)
cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80)
reader = Reader(**cfg)()
def _reader():
......
......@@ -26,9 +26,12 @@ import cv2
import numpy as np
from .operators import register_op, BaseOperator
from .op_helper import jaccard_overlap
logger = logging.getLogger(__name__)
__all__ = ['PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget']
@register_op
class PadBatch(BaseOperator):
......@@ -164,3 +167,81 @@ class PadMultiScaleTest(BaseOperator):
if not batch_input:
samples = samples[0]
return samples
@register_op
class Gt2YoloTarget(BaseOperator):
    """
    Generate YOLOv3 targets from ground truth data, this operator is only
    used in fine grained YOLOv3 loss mode.

    Args:
        anchors (list): [width, height] pairs of all anchors. They are
            divided by the image width/height before matching, so they are
            expected in input-image pixels.
        anchor_masks (list): for each output layer, the indices into
            `anchors` handled by that layer.
        downsample_ratios (list): for each output layer, the ratio between
            the image size and the target grid size.
        num_classes (int): number of classes, 80 by default.
    """

    def __init__(self, anchors, anchor_masks, downsample_ratios,
                 num_classes=80):
        super(Gt2YoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.num_classes = num_classes

    def __call__(self, samples, context=None):
        """
        Attach one `target{i}` array per output layer to every sample.

        Each target has shape
        [len(anchor_mask), 6 + num_classes, grid_h, grid_w]; along axis 1
        it stores x, y, w, h, scale, objectness (the gt score) and a
        one-hot class vector.

        Args:
            samples (list): sample dicts holding 'gt_bbox', 'gt_class' and
                'gt_score'; the 'image' of the first sample supplies the
                height and width used for the whole batch.
            context: not used by this operator.

        Returns:
            list: the same `samples`, updated in place.
        """
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have same length."

        # axes 1 and 2 of the image are read as (h, w); all samples of the
        # batch are assumed to share the size of the first one
        h, w = samples[0]['image'].shape[1:3]
        # anchors relative to the image size, comparable with gw / gh below
        an_hw = np.array(self.anchors) / np.array([[w, h]])
        for sample in samples:
            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']
            gt_score = sample['gt_score']

            for i, (mask, downsample_ratio) in enumerate(
                    zip(self.anchor_masks, self.downsample_ratios)):
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)
                for b in range(gt_bbox.shape[0]):
                    # NOTE(review): gx, gy are scaled by the grid size below,
                    # so boxes are presumably normalized (x, y, w, h)
                    gx, gy, gw, gh = gt_bbox[b, :]
                    cls = gt_class[b]
                    score = gt_score[b]
                    # skip empty boxes and boxes with non-positive score
                    if gw <= 0. or gh <= 0. or score <= 0.:
                        continue

                    # find best match anchor index, comparing shapes only
                    # (both boxes are anchored at the origin)
                    best_iou = 0.
                    best_idx = -1
                    for an_idx in range(an_hw.shape[0]):
                        iou = jaccard_overlap(
                            [0., 0., gw, gh],
                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                        if iou > best_iou:
                            best_iou = iou
                            best_idx = an_idx

                    # gtbox should be regressed in this layer only if its
                    # best match anchor index is in the anchor mask of
                    # this layer
                    if best_idx in mask:
                        best_n = mask.index(best_idx)
                        gi = int(gx * grid_w)
                        gj = int(gy * grid_h)

                        # x, y, w, h, scale
                        target[best_n, 0, gj, gi] = gx * grid_w - gi
                        target[best_n, 1, gj, gi] = gy * grid_h - gj
                        target[best_n, 2, gj, gi] = np.log(
                            gw * w / self.anchors[best_idx][0])
                        target[best_n, 3, gj, gi] = np.log(
                            gh * h / self.anchors[best_idx][1])
                        target[best_n, 4, gj, gi] = 2.0 - gw * gh

                        # objectness record gt_score
                        target[best_n, 5, gj, gi] = score

                        # classification
                        target[best_n, 6 + cls, gj, gi] = 1.
                sample['target{}'.format(i)] = target
        return samples
......@@ -21,6 +21,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.core.workspace import register
__all__ = ['YOLOv3Head']
......@@ -34,23 +35,20 @@ class YOLOv3Head(object):
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
__inject__ = ['nms']
__inject__ = ['yolo_loss', 'nms']
__shared__ = ['num_classes', 'weight_prefix_name']
def __init__(self,
norm_decay=0.,
num_classes=80,
ignore_thresh=0.7,
label_smooth=True,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
yolo_loss="YOLOv3Loss",
nms=MultiClassNMS(
score_threshold=0.01,
nms_top_k=1000,
......@@ -60,10 +58,9 @@ class YOLOv3Head(object):
weight_prefix_name=''):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.ignore_thresh = ignore_thresh
self.label_smooth = label_smooth
self.anchor_masks = anchor_masks
self._parse_anchors(anchors)
self.yolo_loss = yolo_loss
self.nms = nms
self.prefix_name = weight_prefix_name
if isinstance(nms, dict):
......@@ -234,7 +231,7 @@ class YOLOv3Head(object):
return outputs
def get_loss(self, input, gt_box, gt_label, gt_score):
def get_loss(self, input, gt_box, gt_label, gt_score, targets):
"""
Get final loss of network of YOLOv3.
......@@ -243,6 +240,8 @@ class YOLOv3Head(object):
gt_box (Variable): The ground-truth boudding boxes.
gt_label (Variable): The ground-truth class labels.
gt_score (Variable): The ground-truth boudding boxes mixup scores.
targets ([Variables]): List of Variables, the targets for yolo
loss calculatation.
Returns:
loss (Variable): The loss Variable of YOLOv3 network.
......@@ -250,26 +249,10 @@ class YOLOv3Head(object):
"""
outputs = self._get_outputs(input, is_train=True)
losses = []
downsample = 32
for i, output in enumerate(outputs):
anchor_mask = self.anchor_masks[i]
loss = fluid.layers.yolov3_loss(
x=output,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=self.anchors,
anchor_mask=anchor_mask,
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
downsample_ratio=downsample,
use_label_smooth=self.label_smooth,
name=self.prefix_name + "yolo_loss" + str(i))
losses.append(fluid.layers.reduce_mean(loss))
downsample //= 2
return sum(losses)
return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets,
self.anchors, self.anchor_masks,
self.mask_anchors, self.num_classes,
self.prefix_name)
def get_prediction(self, input, im_size):
"""
......
......@@ -38,11 +38,16 @@ class YOLOv3(object):
__category__ = 'architecture'
__inject__ = ['backbone', 'yolo_head']
__shared__ = ['use_fine_grained_loss']
def __init__(self, backbone, yolo_head='YOLOv3Head'):
def __init__(self,
backbone,
yolo_head='YOLOv3Head',
use_fine_grained_loss=False):
super(YOLOv3, self).__init__()
self.backbone = backbone
self.yolo_head = yolo_head
self.use_fine_grained_loss = use_fine_grained_loss
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
......@@ -68,10 +73,19 @@ class YOLOv3(object):
gt_class = feed_vars['gt_class']
gt_score = feed_vars['gt_score']
return {
'loss': self.yolo_head.get_loss(body_feats, gt_bbox, gt_class,
gt_score)
}
# Get targets for splited yolo loss calculation
# YOLOv3 supports up to 3 output layers currently
targets = []
for i in range(3):
k = 'target{}'.format(i)
if k in feed_vars:
targets.append(feed_vars[k])
loss = self.yolo_head.get_loss(body_feats, gt_bbox, gt_class,
gt_score, targets)
total_loss = fluid.layers.sum(list(loss.values()))
loss.update({'loss': total_loss})
return loss
else:
im_size = feed_vars['im_size']
return self.yolo_head.get_prediction(body_feats, im_size)
......@@ -89,6 +103,28 @@ class YOLOv3(object):
'is_difficult': {'shape': [None, num_max_boxes],'dtype': 'int32', 'lod_level': 0},
}
# yapf: enable
if self.use_fine_grained_loss:
# yapf: disable
targets_def = {
'target0': {'shape': [None, 3, 86, 19, 19], 'dtype': 'float32', 'lod_level': 0},
'target1': {'shape': [None, 3, 86, 38, 38], 'dtype': 'float32', 'lod_level': 0},
'target2': {'shape': [None, 3, 86, 76, 76], 'dtype': 'float32', 'lod_level': 0},
}
# yapf: enable
downsample = 32
# Pair target0/target1/target2 with the anchor masks in a deterministic
# order; iterating the sorted key list does not depend on dict ordering.
for k, mask in zip(sorted(targets_def), self.yolo_head.anchor_masks):
    targets_def[k]['shape'][1] = len(mask)
    targets_def[k]['shape'][2] = 6 + self.yolo_head.num_classes
    # spatial dims stay None (dynamic) when the image size is not fixed
    targets_def[k]['shape'][3] = image_shape[
        -2] // downsample if image_shape[-2] else None
    targets_def[k]['shape'][4] = image_shape[
        -1] // downsample if image_shape[-1] else None
    # halve the ratio for the next output layer (32 -> 16 -> 8); a bare
    # `downsample // 2` would discard the result and leave every target
    # with the stride-32 grid shape
    downsample //= 2
inputs_def.update(targets_def)
return inputs_def
def build_inputs(
......@@ -99,6 +135,8 @@ class YOLOv3(object):
use_dataloader=True,
iterable=False):
inputs_def = self._inputs_def(image_shape, num_max_boxes)
if self.use_fine_grained_loss:
fields.extend(['target0', 'target1', 'target2'])
feed_vars = OrderedDict([(key, fluid.data(
name=key,
shape=inputs_def[key]['shape'],
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import yolo_loss
from .yolo_loss import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from ppdet.core.workspace import register
__all__ = ['YOLOv3Loss']
@register
class YOLOv3Loss(object):
    """
    Combined loss for YOLOv3 network

    Args:
        batch_size (int): training batch size, only used by the fine
                          grained loss
        ignore_thresh (float): threshold to ignore confidence loss
        label_smooth (bool): whether to use label smoothing, only passed to
                             fluid.layers.yolov3_loss
        use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss
                                      instead of fluid.layers.yolov3_loss
    """
    __shared__ = ['use_fine_grained_loss']

    def __init__(self,
                 batch_size=8,
                 ignore_thresh=0.7,
                 label_smooth=True,
                 use_fine_grained_loss=False):
        self._batch_size = batch_size
        self._ignore_thresh = ignore_thresh
        self._label_smooth = label_smooth
        self._use_fine_grained_loss = use_fine_grained_loss

    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
                 anchor_masks, mask_anchors, num_classes, prefix_name):
        """
        Build the YOLOv3 loss.

        Args:
            outputs ([Variables]): List of Variables, one per output layer
            gt_box (Variable): The ground-truth bounding boxes.
            gt_label (Variable): The ground-truth class labels.
            gt_score (Variable): The ground-truth bounding boxes mixup scores.
            targets ([Variables]): List of Variables, the targets for the
                                   fine grained loss, one per output layer.
            anchors (list): all anchors, passed to fluid.layers.yolov3_loss
            anchor_masks (list): anchor mask of each output layer
            mask_anchors (list): anchors of each output layer, used by the
                                 fine grained loss
            num_classes (int): class num of dataset
            prefix_name (str): name prefix for the yolov3_loss layers

        Returns:
            dict: {'loss': ...} when fluid.layers.yolov3_loss is used,
                  otherwise the dict returned by `_get_fine_grained_loss`.
        """
        if self._use_fine_grained_loss:
            return self._get_fine_grained_loss(
                outputs, targets, gt_box, self._batch_size, num_classes,
                mask_anchors, self._ignore_thresh)

        losses = []
        # the first output layer is handled with ratio 32, each following
        # layer with half the ratio of the previous one
        downsample = 32
        for i, output in enumerate(outputs):
            anchor_mask = anchor_masks[i]
            loss = fluid.layers.yolov3_loss(
                x=output,
                gt_box=gt_box,
                gt_label=gt_label,
                gt_score=gt_score,
                anchors=anchors,
                anchor_mask=anchor_mask,
                class_num=num_classes,
                ignore_thresh=self._ignore_thresh,
                downsample_ratio=downsample,
                use_label_smooth=self._label_smooth,
                name=prefix_name + "yolo_loss" + str(i))
            losses.append(fluid.layers.reduce_mean(loss))
            downsample //= 2

        return {'loss': sum(losses)}

    def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size,
                               num_classes, mask_anchors, ignore_thresh):
        """
        Calculate fine grained YOLOv3 loss

        Args:
            outputs ([Variables]): List of Variables, output of backbone stages
            targets ([Variables]): List of Variables, The targets for yolo
                                   loss calculation.
            gt_box (Variable): The ground-truth bounding boxes.
            batch_size (int): The training batch size
            num_classes (int): class num of dataset
            mask_anchors ([[float]]): list of anchors in each output layer
            ignore_thresh (float): prediction bbox overlap any gt_box greater
                                   than ignore_thresh, objectness loss will
                                   be ignored.

        Returns:
            Type: dict
                loss_xy (Variable): YOLOv3 (x, y) coordinates loss
                loss_wh (Variable): YOLOv3 (w, h) coordinates loss
                loss_obj (Variable): YOLOv3 objectness score loss
                loss_cls (Variable): YOLOv3 classification loss
        """

        assert len(outputs) == len(targets), \
            "YOLOv3 output layer number not equal target number"

        downsample = 32
        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
        for output, target, anchors in zip(outputs, targets, mask_anchors):
            # anchors is a flat [w0, h0, w1, h1, ...] list
            an_num = len(anchors) // 2
            x, y, w, h, obj, cls = self._split_output(output, an_num,
                                                      num_classes)
            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)

            # coordinate losses are weighted by box scale and gt score
            tscale_tobj = tscale * tobj
            loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
                x, tx) * tscale_tobj
            loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
            loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
                y, ty) * tscale_tobj
            loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
            # NOTE: we refined loss function of (w, h) as L1Loss
            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
            loss_h = fluid.layers.abs(h - th) * tscale_tobj
            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])

            # use the batch_size / ignore_thresh handed to this method so
            # that the arguments are honoured instead of silently ignored
            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                output, obj, tobj, gt_box, batch_size, anchors,
                num_classes, downsample, ignore_thresh)

            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])

            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
            loss_objs.append(
                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
            loss_clss.append(fluid.layers.reduce_mean(loss_cls))

            downsample //= 2

        return {
            "loss_xy": fluid.layers.sum(loss_xys),
            "loss_wh": fluid.layers.sum(loss_whs),
            "loss_obj": fluid.layers.sum(loss_objs),
            "loss_cls": fluid.layers.sum(loss_clss),
        }

    def _split_output(self, output, an_num, num_classes):
        """
        Split output feature map to x, y, w, h, objectness, classification
        along channel dimension

        Returns:
            tuple: (x, y, w, h, obj, cls); cls is stacked per anchor and
                   transposed with perm [0, 1, 3, 4, 2] so that the class
                   axis comes last.
        """

        def _pick_channel(offset):
            # every (5 + num_classes)-th channel starting at `offset`,
            # i.e. the same field of each anchor
            return fluid.layers.strided_slice(
                output,
                axes=[1],
                starts=[offset],
                ends=[output.shape[1]],
                strides=[5 + num_classes])

        x = _pick_channel(0)
        y = _pick_channel(1)
        w = _pick_channel(2)
        h = _pick_channel(3)
        obj = _pick_channel(4)

        clss = []
        # number of channels belonging to a single anchor
        stride = output.shape[1] // an_num
        for m in range(an_num):
            clss.append(
                fluid.layers.slice(
                    output,
                    axes=[1],
                    starts=[stride * m + 5],
                    ends=[stride * m + 5 + num_classes]))
        cls = fluid.layers.transpose(
            fluid.layers.stack(
                clss, axis=1), perm=[0, 1, 3, 4, 2])

        return (x, y, w, h, obj, cls)

    def _split_target(self, target):
        """
        split target to x, y, w, h, scale, objectness, classification
        along dimension 2

        target is in shape [N, an_num, 6 + class_num, H, W]
        """
        tx = target[:, :, 0, :, :]
        ty = target[:, :, 1, :, :]
        tw = target[:, :, 2, :, :]
        th = target[:, :, 3, :, :]

        tscale = target[:, :, 4, :, :]
        tobj = target[:, :, 5, :, :]

        # move the class axis last to match the layout of the predicted cls
        tcls = fluid.layers.transpose(
            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
        tcls.stop_gradient = True

        return (tx, ty, tw, th, tscale, tobj, tcls)

    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
                       num_classes, downsample, ignore_thresh):
        """
        Calculate the objectness loss of one output layer.

        Args:
            output (Variable): raw output of this layer
            obj (Variable): objectness logits split from `output`
            tobj (Variable): objectness target, holds the gt score
            gt_box (Variable): The ground-truth bounding boxes.
            batch_size (int): The training batch size
            anchors ([float]): anchors of this output layer
            num_classes (int): class num of dataset
            downsample (int): downsample ratio of this output layer
            ignore_thresh (float): IoU threshold above which the objectness
                                   loss of a negative grid is ignored

        Returns:
            tuple: (loss_obj_pos, loss_obj_neg), each reduced over
                   dims [1, 2, 3]
        """
        # A prediction bbox overlap any gt_bbox over ignore_thresh,
        # objectness loss will be ignored, process as follows:

        # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
        # NOTE: img_size is set as 1.0 to get normalized pred bbox
        bbox, _ = fluid.layers.yolo_box(
            x=output,
            img_size=fluid.layers.ones(
                shape=[batch_size, 2], dtype="int32"),
            anchors=anchors,
            class_num=num_classes,
            conf_thresh=0.,
            downsample_ratio=downsample,
            clip_bbox=False)

        # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
        #    and gt bbox in each sample
        if batch_size > 1:
            preds = fluid.layers.split(bbox, batch_size, dim=0)
            gts = fluid.layers.split(gt_box, batch_size, dim=0)
        else:
            preds = [bbox]
            gts = [gt_box]

        def box_xywh2xyxy(box):
            # center (x, y, w, h) -> corner (x1, y1, x2, y2)
            x = box[:, 0]
            y = box[:, 1]
            w = box[:, 2]
            h = box[:, 3]
            return fluid.layers.stack(
                [
                    x - w / 2.,
                    y - h / 2.,
                    x + w / 2.,
                    y + h / 2.,
                ], axis=1)

        ious = []
        for pred, gt in zip(preds, gts):
            pred = fluid.layers.squeeze(pred, axes=[0])
            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
            ious.append(fluid.layers.iou_similarity(pred, gt))

        iou = fluid.layers.stack(ious, axis=0)

        # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
        #    Get obj_mask by tobj(holds gt_score), calculate objectness loss
        max_iou = fluid.layers.reduce_max(iou, dim=-1)
        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
        output_shape = fluid.layers.shape(output)
        an_num = len(anchors) // 2
        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
                                                   output_shape[3]))
        iou_mask.stop_gradient = True

        # NOTE: tobj holds gt_score, obj_mask holds object existence mask
        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True

        # For positive objectness grids, objectness loss should be calculated
        # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
        loss_obj_neg = fluid.layers.reduce_sum(
            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])

        return loss_obj_pos, loss_obj_neg
......@@ -194,8 +194,8 @@ def main():
checkpoint.load_params(
exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)
train_reader = create_reader(cfg.TrainReader,
(cfg.max_iters - start_iter) * devices_num)
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册