diff --git a/configs/dcn/yolov3_r50vd_dcn.yml b/configs/dcn/yolov3_r50vd_dcn.yml
index 6a1c9471a85c523cb7a234d5e68f83b0b3436bbc..e16aac604e51c9a2f14ab8b6f92018ff3c1dfaff 100755
--- a/configs/dcn/yolov3_r50vd_dcn.yml
+++ b/configs/dcn/yolov3_r50vd_dcn.yml
@@ -8,6 +8,7 @@ metric: COCO
 pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
 weights: output/yolov3_r50vd_dcn/model_final
 num_classes: 80
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: ResNet
@@ -29,8 +30,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: true
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -39,6 +39,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml
index 32929c51f302217a3832b749f61f229b744d3c2c..707e928df2b7a2e60b3b7899e80abb620fe6e5bd 100755
--- a/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml
+++ b/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml
@@ -8,6 +8,7 @@ metric: COCO
 pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_obj365_pretrained.tar
 weights: output/yolov3_r50vd_dcn_obj365_pretrained_coco/model_final
 num_classes: 80
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: ResNet
@@ -29,8 +30,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: true
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -39,6 +39,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_darknet.yml b/configs/yolov3_darknet.yml
index bbb4df712bc658bff8277c66ab405fc39386c9a1..22ae6d91429dc8101abcc60736f0d065cc04efee 100644
--- a/configs/yolov3_darknet.yml
+++ b/configs/yolov3_darknet.yml
@@ -8,6 +8,7 @@ metric: COCO
 pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
 weights: output/yolov3_darknet/model_final
 num_classes: 80
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: DarkNet
@@ -24,8 +25,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: true
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -34,6 +34,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_darknet_voc.yml b/configs/yolov3_darknet_voc.yml
index df74e03605a1cb4a43cfd49ba575bccc520ee4ff..f8ff843bb28a0afdf5188616233515cb0e8cfdc7 100644
--- a/configs/yolov3_darknet_voc.yml
+++ b/configs/yolov3_darknet_voc.yml
@@ -9,6 +9,7 @@ map_type: 11point
 pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
 weights: output/yolov3_darknet_voc/model_final
 num_classes: 20
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: DarkNet
@@ -25,8 +26,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: false
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -35,6 +35,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: true
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_mobilenet_v1.yml b/configs/yolov3_mobilenet_v1.yml
index 1e9bcddd0e0aa592e06813c6cf79124ad50707ac..e27900b8bae728954ad0be35f6eefba082f1e194 100644
--- a/configs/yolov3_mobilenet_v1.yml
+++ b/configs/yolov3_mobilenet_v1.yml
@@ -8,6 +8,7 @@ metric: COCO
 pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
 weights: output/yolov3_mobilenet_v1/model_final
 num_classes: 80
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: MobileNet
@@ -25,8 +26,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: true
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -35,6 +35,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_mobilenet_v1_fruit.yml b/configs/yolov3_mobilenet_v1_fruit.yml
index 84bc71b9fddcd6e64f4093026f4781093e8bab97..e89a600bbe7c149846a2c6695a62119ef24a0a9a 100644
--- a/configs/yolov3_mobilenet_v1_fruit.yml
+++ b/configs/yolov3_mobilenet_v1_fruit.yml
@@ -10,6 +10,7 @@ pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mob
 weights: output/yolov3_mobilenet_v1_fruit/best_model
 num_classes: 3
 finetune_exclude_pretrained_params: ['yolo_output']
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: MobileNet
@@ -27,8 +28,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: true
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -37,6 +37,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: true
+
 LearningRate:
   base_lr: 0.00001
   schedulers:
diff --git a/configs/yolov3_mobilenet_v1_voc.yml b/configs/yolov3_mobilenet_v1_voc.yml
index d8734f0b7feabf02e9ab903293ab0af392609b90..44cc9ed6fc941b046ed2c765772ec303984289d5 100644
--- a/configs/yolov3_mobilenet_v1_voc.yml
+++ b/configs/yolov3_mobilenet_v1_voc.yml
@@ -9,6 +9,7 @@ map_type: 11point
 pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
 weights: output/yolov3_mobilenet_v1_voc/model_final
 num_classes: 20
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: MobileNet
@@ -26,8 +27,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: false
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -36,6 +36,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_r34.yml b/configs/yolov3_r34.yml
index db7fa056c0bc70e60b41abf7d752c628ecdb729f..765190802f3f85ac172ab6a8e3ba54ab54da347c 100644
--- a/configs/yolov3_r34.yml
+++ b/configs/yolov3_r34.yml
@@ -8,6 +8,7 @@ metric: COCO
 pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar
 weights: output/yolov3_r34/model_final
 num_classes: 80
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: ResNet
@@ -27,8 +28,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: true
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -37,6 +37,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_r34_voc.yml b/configs/yolov3_r34_voc.yml
index b8db29bb7ffd52ab352186c4bd54c110ab33a2eb..8022ffd1605e593996d6ff3cb8e2d28c71c749cf 100644
--- a/configs/yolov3_r34_voc.yml
+++ b/configs/yolov3_r34_voc.yml
@@ -9,6 +9,7 @@ map_type: 11point
 pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar
 weights: output/yolov3_r34_voc/model_final
 num_classes: 20
+use_fine_grained_loss: false
 
 YOLOv3:
   backbone: ResNet
@@ -28,8 +29,7 @@ YOLOv3Head:
             [30, 61], [62, 45], [59, 119],
             [116, 90], [156, 198], [373, 326]]
   norm_decay: 0.
-  ignore_thresh: 0.7
-  label_smooth: false
+  yolo_loss: YOLOv3Loss
   nms:
     background_label: -1
     keep_top_k: 100
@@ -38,6 +38,11 @@ YOLOv3Head:
     normalized: false
     score_threshold: 0.01
 
+YOLOv3Loss:
+  batch_size: 8
+  ignore_thresh: 0.7
+  label_smooth: false
+
 LearningRate:
   base_lr: 0.001
   schedulers:
diff --git a/configs/yolov3_reader.yml b/configs/yolov3_reader.yml
index 70941aaf27d7f2992f6ea4c4b6e4ba4a6ff9fbf0..e539408d13c199c3bd92bce56167f832c7a77d8f 100644
--- a/configs/yolov3_reader.yml
+++ b/configs/yolov3_reader.yml
@@ -37,6 +37,15 @@ TrainReader:
   - !Permute
     to_bgr: false
     channel_first: True
+  # Gt2YoloTarget is only used when use_fine_grained_loss is set as true,
+  # this operator will be removed automatically if use_fine_grained_loss
+  # is set as false
+  - !Gt2YoloTarget
+    anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+    anchors: [[10, 13], [16, 30], [33, 23],
+              [30, 61], [62, 45], [59, 119],
+              [116, 90], [156, 198], [373, 326]]
+    downsample_ratios: [32, 16, 8]
   batch_size: 8
   shuffle: true
   mixup_epoch: 250
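Across all of the configs above, the loss hyperparameters move out of `YOLOv3Head` into a standalone `YOLOv3Loss` block, which the head picks up via the `yolo_loss: YOLOv3Loss` key. Below is a minimal sketch of how such name-based injection can be wired up; it uses a toy registry, not PaddleDetection's actual `ppdet.core.workspace` machinery.

```python
# Toy registry sketch: a string-valued config key ('YOLOv3Loss') is resolved
# to a registered class, whose own config block overrides its defaults.
REGISTRY = {}

def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls

@register
class YOLOv3Loss(object):
    def __init__(self, batch_size=8, ignore_thresh=0.7, label_smooth=True):
        self.batch_size = batch_size
        self.ignore_thresh = ignore_thresh
        self.label_smooth = label_smooth

def create(name, cfg):
    # Instantiate the class registered under `name` with its config block.
    return REGISTRY[name](**cfg.get(name, {}))

cfg = {'YOLOv3Loss': {'batch_size': 8, 'ignore_thresh': 0.7, 'label_smooth': False}}
loss = create('YOLOv3Loss', cfg)
print(loss.label_smooth)  # False
```

Keeping the loss parameters in their own block means a redesigned loss class can expose new knobs without touching `YOLOv3Head`.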
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index 15b2a9a0f3bc51ac636ae2b51206f4f08f4739c0..953dc490384d153513ab5ccb435f882d3cbf5fad 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -74,6 +74,19 @@ list below can be viewed by `--help`
       finetune_exclude_pretrained_params = ['cls_score','bbox_pred']
   ```
 
+- Training YOLOv3 with a fine grained YOLOv3 loss built from Paddle OPs in Python
+
+  To make it easier to redesign the YOLOv3 loss function, we also provide a fine grained YOLOv3 loss built in Python code from common Paddle OPs instead of `fluid.layers.yolov3_loss`.
+  YOLOv3 can be trained with this Python loss function as follows:
+
+  ```bash
+  export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+  python -u tools/train.py -c configs/yolov3_darknet.yml \
+                           -o use_fine_grained_loss=true
+  ```
+
+  The fine grained YOLOv3 loss code is defined in `ppdet/modeling/losses/yolo_loss.py`.
+
 ##### NOTES
 
 - `CUDA_VISIBLE_DEVICES` can specify different gpu numbers. Such as: `export CUDA_VISIBLE_DEVICES=0,1,2,3`. GPU calculation rules can refer [FAQ](#faq)
diff --git a/docs/GETTING_STARTED_cn.md b/docs/GETTING_STARTED_cn.md
index 916eccc7986c97a1800e7791c1bc55760451a498..914019376d3f05be31a96047c3e34b30c34e5275 100644
--- a/docs/GETTING_STARTED_cn.md
+++ b/docs/GETTING_STARTED_cn.md
@@ -74,6 +74,19 @@ python tools/infer.py -c configs/faster_rcnn_r50_1x.yml --infer_img=demo/0000005
 For details, see [Transfer Learning](TRANSFER_LEARNING_cn.md)
 
+- Training YOLOv3 with a YOLOv3 loss assembled from Paddle OPs
+
+  To make it easier to redesign and modify the YOLOv3 loss function, we also provide a way to assemble the YOLOv3 loss from Paddle OPs in Python code instead of calling the `fluid.layers.yolov3_loss` interface.
+  A YOLOv3 model using this loss can be trained with the following command:
+
+  ```bash
+  export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+  python -u tools/train.py -c configs/yolov3_darknet.yml \
+                           -o use_fine_grained_loss=true
+  ```
+
+  The Paddle OP implementation of the YOLOv3 loss lives in `ppdet/modeling/losses/yolo_loss.py`
+
 #### Notes
 
 - The `CUDA_VISIBLE_DEVICES` parameter can specify different GPUs, e.g. `export CUDA_VISIBLE_DEVICES=0,1,2,3`. See the [FAQ](#faq) for GPU usage rules
diff --git a/docs/config_example/yolov3_darknet.yml b/docs/config_example/yolov3_darknet.yml
index 65f479b72bb2504f84bbed5fed20afd78bad0ea9..0bd11214fd2092054334fe8a5644ded848d2e09d 100644
--- a/docs/config_example/yolov3_darknet.yml
+++ b/docs/config_example/yolov3_darknet.yml
@@ -36,6 +36,11 @@ weights: output/yolov3_darknet/model_final
 
 # Number of classes, 80 for COCO and 20 for VOC.
 num_classes: 80
 
+# Whether to use the fine grained YOLOv3 loss: if true, the YOLOv3 loss is built
+# in Python code from common OPs, which suits redesigning the loss; if false,
+# the fluid.layers.yolov3_loss OP computes the original YOLOv3 loss.
+use_fine_grained_loss: false
+
 # YOLOv3 architecture, see https://arxiv.org/abs/1804.02767
 YOLOv3:
@@ -63,12 +68,8 @@ YOLOv3Head:
             [116, 90], [156, 198], [373, 326]]
   # L2 weight decay factor of batch normalization layer
   norm_decay: 0.
-  # Ignore threshold for yolo_loss layer, 0.7 by default.
-  # Objectness loss will be ignored if a predcition bbox overlap a gtbox over ignore_thresh.
-  ignore_thresh: 0.7
-  # Whether use label smooth in yolo_loss layer
-  # It is recommended to set as true when only num_classes is very big
-  label_smooth: true
+  # Use YOLOv3Loss, which is defined in the YOLOv3Loss section below.
+  yolo_loss: YOLOv3Loss
   # fluid.layers.multiclass_nms
   # Non-max suppress for output prediction boxes, see multiclass_nms for following parameters.
   # 1. Select detection bounding boxes with high scores larger than score_threshold.
@@ -89,6 +90,18 @@ YOLOv3Head:
   # Threshold to filter out bounding boxes with low confidence score.
   score_threshold: 0.01
 
+YOLOv3Loss:
+  # Training batch size, used when use_fine_grained_loss is set as true.
+  # ATTENTION: this should be the same as the batch size defined in
+  # YoloTrainFeed in fine grained loss mode.
+  batch_size: 8
+  # Ignore threshold for yolo_loss layer, 0.7 by default. Objectness loss is
+  # ignored if a prediction bbox overlaps a gt box by more than ignore_thresh.
+  ignore_thresh: 0.7
+  # Whether to use label smoothing in the yolo_loss layer.
+  # It is recommended to set this to true only when num_classes is very large.
+  label_smooth: false
+
 # Learning rate configuration
 LearningRate:
   # Base learning rate for training, 1e-3 by default.
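The `ATTENTION` comment above is the easiest thing to get wrong when enabling the fine grained loss: `YOLOv3Loss.batch_size` must match the reader's batch size. A hypothetical guard along these lines could catch the mismatch early; `check_batch_size` and the flat dict layout are illustrative, not part of this patch.

```python
# Hypothetical sanity check (not in this patch): in fine grained loss mode,
# YOLOv3Loss.batch_size must equal the TrainReader batch size.
def check_batch_size(cfg):
    if not cfg.get('use_fine_grained_loss', False):
        return
    reader_bs = cfg['TrainReader']['batch_size']
    loss_bs = cfg['YOLOv3Loss']['batch_size']
    assert reader_bs == loss_bs, (
        'YOLOv3Loss.batch_size ({}) must match TrainReader batch_size ({}) '
        'when use_fine_grained_loss is true'.format(loss_bs, reader_bs))

check_batch_size({
    'use_fine_grained_loss': True,
    'TrainReader': {'batch_size': 8},
    'YOLOv3Loss': {'batch_size': 8},
})  # passes silently
```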
diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py
index f905bbdbab963af27cc7a4022d0518aa6444a009..5f0c97a60d759082cd2298c11dce0905e047d6dc 100644
--- a/ppdet/data/reader.py
+++ b/ppdet/data/reader.py
@@ -26,6 +26,7 @@ import logging
 
 from ppdet.core.workspace import register, serializable
 from .parallel_map import ParallelMap
+from .transform.batch_operators import Gt2YoloTarget
 
 __all__ = ['Reader', 'create_reader']
 
@@ -192,6 +193,8 @@
                  class_aware_sampling=False,
                  worker_num=-1,
                  use_process=False,
+                 use_fine_grained_loss=False,
+                 num_classes=80,
                  bufsize=100,
                  memsize='3G',
                  inputs_def=None):
@@ -204,6 +207,17 @@
         self._sample_transforms = Compose(sample_transforms,
                                           {'fields': self._fields})
         self._batch_transforms = None
+
+        if use_fine_grained_loss:
+            for bt in batch_transforms:
+                if isinstance(bt, Gt2YoloTarget):
+                    bt.num_classes = num_classes
+        elif batch_transforms:
+            batch_transforms = [
+                bt for bt in batch_transforms
+                if not isinstance(bt, Gt2YoloTarget)
+            ]
+
         if batch_transforms:
             self._batch_transforms = Compose(batch_transforms,
                                              {'fields': self._fields})
@@ -376,7 +390,7 @@
             self._parallel.stop()
 
 
-def create_reader(cfg, max_iter=0):
+def create_reader(cfg, max_iter=0, global_cfg=None):
     """
     Return iterable data reader.
 
@@ -386,6 +400,11 @@
     if not isinstance(cfg, dict):
         raise TypeError("The config should be a dict when creating reader.")
 
+    # synchronize use_fine_grained_loss/num_classes from global_cfg to reader cfg
+    if global_cfg:
+        cfg['use_fine_grained_loss'] = getattr(global_cfg,
+                                               'use_fine_grained_loss', False)
+        cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80)
     reader = Reader(**cfg)()
 
     def _reader():
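The heart of the reader change is the batch-transform synchronization in `Reader.__init__`. A standalone sketch of that behavior, with a stub standing in for the real `Gt2YoloTarget` operator:

```python
# Stub operator; only the attribute touched by the sync logic is modeled.
class Gt2YoloTarget(object):
    def __init__(self, num_classes=80):
        self.num_classes = num_classes

def sync_batch_transforms(batch_transforms, use_fine_grained_loss, num_classes):
    # Mirrors Reader.__init__: propagate num_classes into Gt2YoloTarget when
    # the fine grained loss is on, drop the operator entirely when it is off.
    if use_fine_grained_loss:
        for bt in batch_transforms:
            if isinstance(bt, Gt2YoloTarget):
                bt.num_classes = num_classes
        return batch_transforms
    return [bt for bt in batch_transforms if not isinstance(bt, Gt2YoloTarget)]

ops = [Gt2YoloTarget()]
assert sync_batch_transforms(list(ops), False, 20) == []
assert sync_batch_transforms(list(ops), True, 20)[0].num_classes == 20
```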
diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py
index 9cc88d5207601e6777d03a5086e1d813f39fe4e7..ff1035f0f7c4f529fed4a09592664ddb87f244f2 100644
--- a/ppdet/data/transform/batch_operators.py
+++ b/ppdet/data/transform/batch_operators.py
@@ -26,9 +26,12 @@ import cv2
 import numpy as np
 
 from .operators import register_op, BaseOperator
+from .op_helper import jaccard_overlap
 
 logger = logging.getLogger(__name__)
 
+__all__ = ['PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget']
+
 
 @register_op
 class PadBatch(BaseOperator):
@@ -164,3 +167,81 @@
         if not batch_input:
             samples = samples[0]
         return samples
+
+
+@register_op
+class Gt2YoloTarget(BaseOperator):
+    """
+    Generate YOLOv3 targets from ground truth data, this operator is only
+    used in fine grained YOLOv3 loss mode
+    """
+
+    def __init__(self, anchors, anchor_masks, downsample_ratios,
+                 num_classes=80):
+        super(Gt2YoloTarget, self).__init__()
+        self.anchors = anchors
+        self.anchor_masks = anchor_masks
+        self.downsample_ratios = downsample_ratios
+        self.num_classes = num_classes
+
+    def __call__(self, samples, context=None):
+        assert len(self.anchor_masks) == len(self.downsample_ratios), \
+            "'anchor_masks' and 'downsample_ratios' should have same length."
+
+        h, w = samples[0]['image'].shape[1:3]
+        an_hw = np.array(self.anchors) / np.array([[w, h]])
+        for sample in samples:
+            # im, gt_bbox, gt_class, gt_score = sample
+            im = sample['image']
+            gt_bbox = sample['gt_bbox']
+            gt_class = sample['gt_class']
+            gt_score = sample['gt_score']
+            for i, (
+                    mask, downsample_ratio
+            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
+                grid_h = int(h / downsample_ratio)
+                grid_w = int(w / downsample_ratio)
+                target = np.zeros(
+                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
+                    dtype=np.float32)
+                for b in range(gt_bbox.shape[0]):
+                    gx, gy, gw, gh = gt_bbox[b, :]
+                    cls = gt_class[b]
+                    score = gt_score[b]
+                    if gw <= 0. or gh <= 0. or score <= 0.:
+                        continue
+
+                    # find best match anchor index
+                    best_iou = 0.
+                    best_idx = -1
+                    for an_idx in range(an_hw.shape[0]):
+                        iou = jaccard_overlap(
+                            [0., 0., gw, gh],
+                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
+                        if iou > best_iou:
+                            best_iou = iou
+                            best_idx = an_idx
+
+                    # gt bbox should be regressed in this layer if the best
+                    # matching anchor index is in this layer's anchor mask
+                    if best_idx in mask:
+                        best_n = mask.index(best_idx)
+                        gi = int(gx * grid_w)
+                        gj = int(gy * grid_h)
+
+                        # x, y, w, h, scale
+                        target[best_n, 0, gj, gi] = gx * grid_w - gi
+                        target[best_n, 1, gj, gi] = gy * grid_h - gj
+                        target[best_n, 2, gj, gi] = np.log(
+                            gw * w / self.anchors[best_idx][0])
+                        target[best_n, 3, gj, gi] = np.log(
+                            gh * h / self.anchors[best_idx][1])
+                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
+
+                        # objectness records gt_score
+                        target[best_n, 5, gj, gi] = score
+
+                        # classification
+                        target[best_n, 6 + cls, gj, gi] = 1.
+                sample['target{}'.format(i)] = target
+        return samples
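For concreteness, here is the per-layer target tensor `Gt2YoloTarget` fills in for a single ground-truth box; the box, class, and matched anchor below are made-up illustrative values.

```python
import numpy as np

# Target layout: [len(mask), 6 + num_classes, grid_h, grid_w] with channels
# [tx, ty, tw, th, tscale, tobj, one-hot class scores].
num_classes = 80
grid_h = grid_w = 13                       # 416 input at downsample_ratio 32
mask = [6, 7, 8]
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
           [59, 119], [116, 90], [156, 198], [373, 326]]
target = np.zeros((len(mask), 6 + num_classes, grid_h, grid_w), np.float32)

w = h = 416
gx, gy, gw, gh = 0.5, 0.5, 0.4, 0.3        # normalized gt box (center x/y, w/h)
cls, score = 7, 1.0
best_idx = 7                               # suppose anchor 7 won the IoU match
best_n = mask.index(best_idx)
gi, gj = int(gx * grid_w), int(gy * grid_h)

target[best_n, 0, gj, gi] = gx * grid_w - gi                       # x offset within cell
target[best_n, 1, gj, gi] = gy * grid_h - gj                       # y offset within cell
target[best_n, 2, gj, gi] = np.log(gw * w / anchors[best_idx][0])  # w relative to anchor
target[best_n, 3, gj, gi] = np.log(gh * h / anchors[best_idx][1])  # h relative to anchor
target[best_n, 4, gj, gi] = 2.0 - gw * gh                          # small boxes weighted up
target[best_n, 5, gj, gi] = score                                  # objectness carries gt_score
target[best_n, 6 + cls, gj, gi] = 1.0                              # one-hot class
```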
diff --git a/ppdet/modeling/anchor_heads/yolo_head.py b/ppdet/modeling/anchor_heads/yolo_head.py
index 7e756f267762827b3666e8143dce9a695fc526e2..a140f45c4e0def3eea317ec92de2c4715622c788 100644
--- a/ppdet/modeling/anchor_heads/yolo_head.py
+++ b/ppdet/modeling/anchor_heads/yolo_head.py
@@ -21,6 +21,7 @@ from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
 
 from ppdet.modeling.ops import MultiClassNMS
+from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
 from ppdet.core.workspace import register
 
 __all__ = ['YOLOv3Head']
@@ -34,23 +35,20 @@ class YOLOv3Head(object):
     Args:
         norm_decay (float): weight decay for normalization layer weights
         num_classes (int): number of output classes
-        ignore_thresh (float): threshold to ignore confidence loss
-        label_smooth (bool): whether to use label smoothing
         anchors (list): anchors
         anchor_masks (list): anchor masks
         nms (object): an instance of `MultiClassNMS`
     """
-    __inject__ = ['nms']
+    __inject__ = ['yolo_loss', 'nms']
     __shared__ = ['num_classes', 'weight_prefix_name']
 
     def __init__(self,
                  norm_decay=0.,
                  num_classes=80,
-                 ignore_thresh=0.7,
-                 label_smooth=True,
                  anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                           [59, 119], [116, 90], [156, 198], [373, 326]],
                  anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
+                 yolo_loss="YOLOv3Loss",
                  nms=MultiClassNMS(
                      score_threshold=0.01,
                      nms_top_k=1000,
@@ -60,10 +58,9 @@ class YOLOv3Head(object):
                  weight_prefix_name=''):
         self.norm_decay = norm_decay
         self.num_classes = num_classes
-        self.ignore_thresh = ignore_thresh
-        self.label_smooth = label_smooth
         self.anchor_masks = anchor_masks
         self._parse_anchors(anchors)
+        self.yolo_loss = yolo_loss
         self.nms = nms
         self.prefix_name = weight_prefix_name
         if isinstance(nms, dict):
@@ -234,7 +231,7 @@ class YOLOv3Head(object):
 
         return outputs
 
-    def get_loss(self, input, gt_box, gt_label, gt_score):
+    def get_loss(self, input, gt_box, gt_label, gt_score, targets):
         """
         Get final loss of network of YOLOv3.
 
@@ -243,6 +240,8 @@ class YOLOv3Head(object):
             gt_box (Variable): The ground-truth bounding boxes.
             gt_label (Variable): The ground-truth class labels.
             gt_score (Variable): The ground-truth bounding boxes mixup scores.
+            targets ([Variables]): List of Variables, the targets for YOLOv3
+                loss calculation.
 
         Returns:
             loss (Variable): The loss Variable of YOLOv3 network.
@@ -250,26 +249,10 @@ class YOLOv3Head(object):
         """
         outputs = self._get_outputs(input, is_train=True)
 
-        losses = []
-        downsample = 32
-        for i, output in enumerate(outputs):
-            anchor_mask = self.anchor_masks[i]
-            loss = fluid.layers.yolov3_loss(
-                x=output,
-                gt_box=gt_box,
-                gt_label=gt_label,
-                gt_score=gt_score,
-                anchors=self.anchors,
-                anchor_mask=anchor_mask,
-                class_num=self.num_classes,
-                ignore_thresh=self.ignore_thresh,
-                downsample_ratio=downsample,
-                use_label_smooth=self.label_smooth,
-                name=self.prefix_name + "yolo_loss" + str(i))
-            losses.append(fluid.layers.reduce_mean(loss))
-            downsample //= 2
-
-        return sum(losses)
+        return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets,
+                              self.anchors, self.anchor_masks,
+                              self.mask_anchors, self.num_classes,
+                              self.prefix_name)
 
     def get_prediction(self, input, im_size):
         """
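`get_loss` now also forwards `self.mask_anchors`, which `_parse_anchors` builds outside this diff. Judging from `an_num = len(anchors) // 2` in `_get_fine_grained_loss` below, each entry is the flat `[w0, h0, w1, h1, ...]` list of that layer's masked anchors; a sketch of that derivation under this assumption:

```python
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
           [59, 119], [116, 90], [156, 198], [373, 326]]
anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

# Flatten each layer's masked anchors into [w0, h0, w1, h1, ...].
mask_anchors = []
for mask in anchor_masks:
    flat = []
    for idx in mask:
        flat.extend(anchors[idx])
    mask_anchors.append(flat)

print(mask_anchors[0])  # [116, 90, 156, 198, 373, 326] -> an_num == 3
```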
diff --git a/ppdet/modeling/architectures/yolov3.py b/ppdet/modeling/architectures/yolov3.py
index d7e3948fe3f0eddcebbfc8d5a7b7f569e6cb267a..e82f3b1c04999d4d8ac324e06cd95cca46e44e2f 100644
--- a/ppdet/modeling/architectures/yolov3.py
+++ b/ppdet/modeling/architectures/yolov3.py
@@ -38,11 +38,16 @@ class YOLOv3(object):
 
     __category__ = 'architecture'
     __inject__ = ['backbone', 'yolo_head']
+    __shared__ = ['use_fine_grained_loss']
 
-    def __init__(self, backbone, yolo_head='YOLOv3Head'):
+    def __init__(self,
+                 backbone,
+                 yolo_head='YOLOv3Head',
+                 use_fine_grained_loss=False):
         super(YOLOv3, self).__init__()
         self.backbone = backbone
         self.yolo_head = yolo_head
+        self.use_fine_grained_loss = use_fine_grained_loss
 
     def build(self, feed_vars, mode='train'):
         im = feed_vars['image']
@@ -68,10 +73,19 @@ class YOLOv3(object):
             gt_class = feed_vars['gt_class']
             gt_score = feed_vars['gt_score']
 
-            return {
-                'loss': self.yolo_head.get_loss(body_feats, gt_bbox, gt_class,
-                                                gt_score)
-            }
+            # Get targets for split yolo loss calculation
+            # YOLOv3 supports up to 3 output layers currently
+            targets = []
+            for i in range(3):
+                k = 'target{}'.format(i)
+                if k in feed_vars:
+                    targets.append(feed_vars[k])
+
+            loss = self.yolo_head.get_loss(body_feats, gt_bbox, gt_class,
+                                           gt_score, targets)
+            total_loss = fluid.layers.sum(list(loss.values()))
+            loss.update({'loss': total_loss})
+            return loss
         else:
             im_size = feed_vars['im_size']
             return self.yolo_head.get_prediction(body_feats, im_size)
@@ -89,6 +103,28 @@ class YOLOv3(object):
             'is_difficult': {'shape': [None, num_max_boxes],'dtype': 'int32', 'lod_level': 0},
         }
         # yapf: enable
+
+        if self.use_fine_grained_loss:
+            # yapf: disable
+            targets_def = {
+                'target0': {'shape': [None, 3, 86, 19, 19], 'dtype': 'float32', 'lod_level': 0},
+                'target1': {'shape': [None, 3, 86, 38, 38], 'dtype': 'float32', 'lod_level': 0},
+                'target2': {'shape': [None, 3, 86, 76, 76], 'dtype': 'float32', 'lod_level': 0},
+            }
+            # yapf: enable
+
+            downsample = 32
+            for k, mask in zip(targets_def.keys(), self.yolo_head.anchor_masks):
+                targets_def[k]['shape'][1] = len(mask)
+                targets_def[k]['shape'][2] = 6 + self.yolo_head.num_classes
+                targets_def[k]['shape'][3] = image_shape[
+                    -2] // downsample if image_shape[-2] else None
+                targets_def[k]['shape'][4] = image_shape[
+                    -1] // downsample if image_shape[-1] else None
+                downsample //= 2
+
+            inputs_def.update(targets_def)
+
         return inputs_def
 
     def build_inputs(
@@ -99,6 +135,8 @@ class YOLOv3(object):
             use_dataloader=True,
             iterable=False):
         inputs_def = self._inputs_def(image_shape, num_max_boxes)
+        if self.use_fine_grained_loss:
+            fields.extend(['target0', 'target1', 'target2'])
         feed_vars = OrderedDict([(key, fluid.data(
             name=key,
             shape=inputs_def[key]['shape'],
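A quick worked check of the placeholder shapes `_inputs_def` registers: with 80 classes and a 608x608 input, the three YOLO layers are downsampled by 32, 16 and 8, reproducing the `[None, 3, 86, 19, 19]`, `[None, 3, 86, 38, 38]` and `[None, 3, 86, 76, 76]` defaults above.

```python
num_classes = 80
anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
image_size = 608

downsample = 32
for i, mask in enumerate(anchor_masks):
    grid = image_size // downsample
    # channel dim is 6 + num_classes: x, y, w, h, scale, objectness, classes
    print('target{}: {}'.format(i, [None, len(mask), 6 + num_classes, grid, grid]))
    downsample //= 2   # stride halves at each successive YOLO layer
```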
diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a45eeb24fb09ce49b713997d67d3756e9d3e5f
--- /dev/null
+++ b/ppdet/modeling/losses/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+from . import yolo_loss
+
+from .yolo_loss import *
diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad7e766a2fbc503934a395be0c0dec02eb14eed6
--- /dev/null
+++ b/ppdet/modeling/losses/yolo_loss.py
@@ -0,0 +1,289 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+
+from ppdet.core.workspace import register
+
+__all__ = ['YOLOv3Loss']
+
+
+@register
+class YOLOv3Loss(object):
+    """
+    Combined loss for YOLOv3 network
+
+    Args:
+        batch_size (int): training batch size
+        ignore_thresh (float): threshold to ignore confidence loss
+        label_smooth (bool): whether to use label smoothing
+        use_fine_grained_loss (bool): whether to use the fine grained YOLOv3
+                                      loss instead of fluid.layers.yolov3_loss
+    """
+    __shared__ = ['use_fine_grained_loss']
+
+    def __init__(self,
+                 batch_size=8,
+                 ignore_thresh=0.7,
+                 label_smooth=True,
+                 use_fine_grained_loss=False):
+        self._batch_size = batch_size
+        self._ignore_thresh = ignore_thresh
+        self._label_smooth = label_smooth
+        self._use_fine_grained_loss = use_fine_grained_loss
+
+    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
+                 anchor_masks, mask_anchors, num_classes, prefix_name):
+        if self._use_fine_grained_loss:
+            return self._get_fine_grained_loss(
+                outputs, targets, gt_box, self._batch_size, num_classes,
+                mask_anchors, self._ignore_thresh)
+        else:
+            losses = []
+            downsample = 32
+            for i, output in enumerate(outputs):
+                anchor_mask = anchor_masks[i]
+                loss = fluid.layers.yolov3_loss(
+                    x=output,
+                    gt_box=gt_box,
+                    gt_label=gt_label,
+                    gt_score=gt_score,
+                    anchors=anchors,
+                    anchor_mask=anchor_mask,
+                    class_num=num_classes,
+                    ignore_thresh=self._ignore_thresh,
+                    downsample_ratio=downsample,
+                    use_label_smooth=self._label_smooth,
+                    name=prefix_name + "yolo_loss" + str(i))
+                losses.append(fluid.layers.reduce_mean(loss))
+                downsample //= 2
+
+            return {'loss': sum(losses)}
+
+    def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size,
+                               num_classes, mask_anchors, ignore_thresh):
+        """
+        Calculate fine grained YOLOv3 loss
+
+        Args:
+            outputs ([Variables]): List of Variables, output of backbone stages
+            targets ([Variables]): List of Variables, the targets for YOLOv3
+                loss calculation.
+            gt_box (Variable): The ground-truth bounding boxes.
+            batch_size (int): The training batch size
+            num_classes (int): class num of dataset
+            mask_anchors ([[float]]): list of anchors in each output layer
+            ignore_thresh (float): if a prediction bbox overlaps any gt_box by
+                                   more than ignore_thresh, its objectness loss
+                                   will be ignored.
+
+        Returns:
+            Type: dict
+                xy_loss (Variable): YOLOv3 (x, y) coordinates loss
+                wh_loss (Variable): YOLOv3 (w, h) coordinates loss
+                obj_loss (Variable): YOLOv3 objectness score loss
+                cls_loss (Variable): YOLOv3 classification loss
+
+        """
+
+        assert len(outputs) == len(targets), \
+            "YOLOv3 output layer number is not equal to target number"
+
+        downsample = 32
+        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
+        for i, (output, target,
+                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
+            an_num = len(anchors) // 2
+            x, y, w, h, obj, cls = self._split_output(output, an_num,
+                                                      num_classes)
+            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
+
+            tscale_tobj = tscale * tobj
+            loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
+                x, tx) * tscale_tobj
+            loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+            loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
+                y, ty) * tscale_tobj
+            loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+            # NOTE: we refined the (w, h) loss into an L1 loss
+            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
+            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
+            loss_h = fluid.layers.abs(h - th) * tscale_tobj
+            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
+
+            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
+                output, obj, tobj, gt_box, self._batch_size, anchors,
+                num_classes, downsample, self._ignore_thresh)
+
+            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
+            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
+            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
+
+            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
+            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
+            loss_objs.append(
+                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
+            loss_clss.append(fluid.layers.reduce_mean(loss_cls))
+
+            downsample //= 2
+
+        return {
+            "loss_xy": fluid.layers.sum(loss_xys),
+            "loss_wh": fluid.layers.sum(loss_whs),
+            "loss_obj": fluid.layers.sum(loss_objs),
+            "loss_cls": fluid.layers.sum(loss_clss),
+        }
+
+    def _split_output(self, output, an_num, num_classes):
+        """
+        Split output feature map to x, y, w, h, objectness, classification
+        along channel dimension
+        """
+        x = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[0],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        y = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[1],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        w = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[2],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        h = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[3],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        obj = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[4],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        clss = []
+        stride = output.shape[1] // an_num
+        for m in range(an_num):
+            clss.append(
+                fluid.layers.slice(
+                    output,
+                    axes=[1],
+                    starts=[stride * m + 5],
+                    ends=[stride * m + 5 + num_classes]))
+        cls = fluid.layers.transpose(
+            fluid.layers.stack(
+                clss, axis=1), perm=[0, 1, 3, 4, 2])
+
+        return (x, y, w, h, obj, cls)
+
+    def _split_target(self, target):
+        """
+        Split target to x, y, w, h, objectness, classification
+        along dimension 2
+
+        target is in shape [N, an_num, 6 + class_num, H, W]
+        """
+        tx = target[:, :, 0, :, :]
+        ty = target[:, :, 1, :, :]
+        tw = target[:, :, 2, :, :]
+        th = target[:, :, 3, :, :]
+
+        tscale = target[:, :, 4, :, :]
+        tobj = target[:, :, 5, :, :]
+
+        tcls = fluid.layers.transpose(
+            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
+        tcls.stop_gradient = True
+
+        return (tx, ty, tw, th, tscale, tobj, tcls)
+
+    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
+                       num_classes, downsample, ignore_thresh):
+        # If a prediction bbox overlaps any gt_bbox by more than ignore_thresh,
+        # its objectness loss will be ignored. Proceed as follows:
+
+        # 1. get pred bbox, same as in YOLOv3 infer mode; use yolo_box here
+        # NOTE: img_size is set as 1.0 to get a normalized pred bbox
+        bbox, _ = fluid.layers.yolo_box(
+            x=output,
+            img_size=fluid.layers.ones(
+                shape=[batch_size, 2], dtype="int32"),
+            anchors=anchors,
+            class_num=num_classes,
+            conf_thresh=0.,
+            downsample_ratio=downsample,
+            clip_bbox=False)
+
+        # 2. split pred bbox and gt bbox by sample, calculate IoU between pred
+        #    bbox and gt bbox in each sample
+        if batch_size > 1:
+            preds = fluid.layers.split(bbox, batch_size, dim=0)
+            gts = fluid.layers.split(gt_box, batch_size, dim=0)
+        else:
+            preds = [bbox]
+            gts = [gt_box]
+        ious = []
+        for pred, gt in zip(preds, gts):
+
+            def box_xywh2xyxy(box):
+                x = box[:, 0]
+                y = box[:, 1]
+                w = box[:, 2]
+                h = box[:, 3]
+                return fluid.layers.stack(
+                    [
+                        x - w / 2.,
+                        y - h / 2.,
+                        x + w / 2.,
+                        y + h / 2.,
+                    ], axis=1)
+
+            pred = fluid.layers.squeeze(pred, axes=[0])
+            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
+            ious.append(fluid.layers.iou_similarity(pred, gt))
+        iou = fluid.layers.stack(ious, axis=0)
+
+        # 3. get iou_mask from the IoU between gt bbox and prediction bbox, get
+        #    obj_mask from tobj (which holds gt_score), calculate objectness loss
+        max_iou = fluid.layers.reduce_max(iou, dim=-1)
+        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
+        output_shape = fluid.layers.shape(output)
+        an_num = len(anchors) // 2
+        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
+                                                   output_shape[3]))
+        iou_mask.stop_gradient = True
+
+        # NOTE: tobj holds gt_score, obj_mask holds the object existence mask
+        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
+        obj_mask.stop_gradient = True
+
+        # For positive objectness grids, the objectness loss is always
+        # calculated; for negative objectness grids, it is calculated only
+        # where iou_mask == 1.0
+        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
+        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
+        loss_obj_neg = fluid.layers.reduce_sum(
+            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
+
+        return loss_obj_pos, loss_obj_neg
diff --git a/tools/train.py b/tools/train.py
index 1e464e2e944b77c1e659895c214efcdd4480db30..ebdaf03ab75c43099a8320c75ff244687224f767 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -194,8 +194,8 @@ def main():
         checkpoint.load_params(
             exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)
 
-    train_reader = create_reader(cfg.TrainReader,
-                                 (cfg.max_iters - start_iter) * devices_num)
+    train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
+                                 devices_num, cfg)
     train_loader.set_sample_list_generator(train_reader, place)
 
     # whether output bbox is normalized in model output layer
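To close, a numpy illustration of the channel packing `_split_output` relies on: the head output lays out `an_num` blocks of `[x, y, w, h, obj, cls_0..cls_C-1]` along the channel axis, so a strided slice with stride `5 + num_classes` picks one component per anchor. Shapes only; the random tensor is a stand-in for a real head output.

```python
import numpy as np

an_num, num_classes, grid = 3, 80, 13
stride = 5 + num_classes
output = np.random.rand(2, an_num * stride, grid, grid).astype('float32')

x = output[:, 0::stride]    # [N, an_num, H, W], like strided_slice(starts=[0])
obj = output[:, 4::stride]  # [N, an_num, H, W], like strided_slice(starts=[4])
cls = np.stack([output[:, m * stride + 5:m * stride + 5 + num_classes]
                for m in range(an_num)], axis=1)  # [N, an_num, C, H, W]
cls = cls.transpose(0, 1, 3, 4, 2)                # [N, an_num, H, W, C]

assert x.shape == (2, an_num, grid, grid)
assert cls.shape == (2, an_num, grid, grid, num_classes)
```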