From 59613880577485be49f42ecd6638c2ab25f3bb4e Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Fri, 24 Jul 2020 21:16:48 +0800 Subject: [PATCH] add COCO test-dev eval config and doc (#1099) * add COCO test-dev eval config and doc --- configs/ppyolo/README.md | 20 ++++- configs/ppyolo/ppyolo_test.yml | 140 +++++++++++++++++++++++++++++ ppdet/modeling/losses/iou_loss.py | 4 +- ppdet/modeling/losses/yolo_loss.py | 28 ++++-- 4 files changed, 178 insertions(+), 14 deletions(-) create mode 100644 configs/ppyolo/ppyolo_test.yml diff --git a/configs/ppyolo/README.md b/configs/ppyolo/README.md index aa809d3e3..c2aa81d99 100644 --- a/configs/ppyolo/README.md +++ b/configs/ppyolo/README.md @@ -11,7 +11,7 @@ [PP-YOLO](https://arxiv.org/abs/2007.12099)的PaddleDetection优化和改进的YOLOv3的模型,其精度(COCO数据集mAP)和推理速度均优于[YOLOv4](https://arxiv.org/abs/2004.10934)模型,要求使用PaddlePaddle 1.8.4(2020年8月中旬发布)或适当的[develop版本](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev)。 -PP-YOLO在[COCO](http://cocodataset.org) test2019数据集上精度达到45.2%,在单卡V100上FP32推理速度为72.9 FPS, V100上开启TensorRT下FP16推理速度为155.6 FPS。 +PP-YOLO在[COCO](http://cocodataset.org) test-dev2019数据集上精度达到45.2%,在单卡V100上FP32推理速度为72.9 FPS, V100上开启TensorRT下FP16推理速度为155.6 FPS。
@@ -45,7 +45,7 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度: **注意:** -- PP-YOLO模型使用COCO数据集中train2017作为训练集,使用test2019左右测试集。 +- PP-YOLO模型使用COCO数据集中train2017作为训练集,使用test-dev2019左右测试集。 - PP-YOLO模型训练过程中使用8GPU,每GPU batch size为24进行训练,如训练GPU数和batch size不使用上述配置,须参考[FAQ](../../docs/FAQ.md)调整学习率和迭代次数。 - PP-YOLO模型推理速度测试采用单卡V100,batch size=1进行测试,使用CUDA 10.2, CUDNN 7.5.1,TensorRT推理速度测试使用TensorRT 5.1.2.2。 - PP-YOLO模型推理速度测试数据为使用`tools/export_model.py`脚本导出模型后,使用`deploy/python/infer.py`脚本中的`--run_benchnark`参数使用Paddle预测库进行推理速度benchmark测试结果, 且测试的均为不包含数据预处理和模型输出后处理(NMS)的数据(与[YOLOv4(AlexyAB)](https://github.com/AlexeyAB/darknet)测试方法一致)。 @@ -66,7 +66,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tools/train.py -c configs/ppyolo/ppy ### 2. 评估 -使用单GPU通过如下命令一键式评估模型效果 +使用单GPU通过如下命令一键式评估模型在COCO val2017数据集效果 ```bash # 使用PaddleDetection发布的权重 @@ -76,6 +76,20 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo.yml -o weig CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo.yml -o weights=output/ppyolo/best_model ``` +我们提供了`configs/ppyolo/ppyolo_test.yml`用于评估COCO test-dev2019数据集的效果,评估COCO test-dev2019数据集的效果须先从[COCO数据集下载页](https://cocodataset.org/#download)下载test-dev2019数据集,解压到`configs/ppyolo/ppyolo_test.yml`中`EvalReader.dataset`中配置的路径,并使用如下命令进行评估 + +```bash +# 使用PaddleDetection发布的权重 +CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_test.yml -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams + +# 使用训练保存的checkpoint +CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyolo/ppyolo_test.yml -o weights=output/ppyolo/best_model +``` + +评估结果保存于`bbox.json`中,将其压缩为zip包后通过[COCO数据集评估页](https://competitions.codalab.org/competitions/20794#participate)提交评估。 + +**注意:** `configs/ppyolo/ppyolo_test.yml`仅用于评估COCO test-dev数据集,不用于训练和评估COCO val2017数据集。 + ### 3. 推理 使用单GPU通过如下命令一键式推理图像,通过`--infer_img`指定图像路径,或通过`--infer_dir`指定目录并推理目录下所有图像 diff --git a/configs/ppyolo/ppyolo_test.yml b/configs/ppyolo/ppyolo_test.yml new file mode 100644 index 000000000..840865a0b --- /dev/null +++ b/configs/ppyolo/ppyolo_test.yml @@ -0,0 +1,140 @@ +# NOTE: this config file is only used for evaluation on COCO test2019 set, +# for training or evaluationg on COCO val2017, please use ppyolo.yml +architecture: YOLOv3 +use_gpu: true +max_iters: 500000 +log_smooth_window: 100 +log_iter: 100 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar +weights: output/ppyolo/model_final +num_classes: 80 +use_fine_grained_loss: true +use_ema: true +ema_decay: 0.9998 +save_prediction_only: True + +YOLOv3: + backbone: ResNet + yolo_head: YOLOv3Head + use_fine_grained_loss: true + +ResNet: + norm_type: sync_bn + freeze_at: 0 + freeze_norm: false + norm_decay: 0. + depth: 50 + feature_maps: [3, 4, 5] + variant: d + dcn_v2_stages: [5] + +YOLOv3Head: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + norm_decay: 0. + coord_conv: true + iou_aware: true + iou_aware_factor: 0.4 + scale_x_y: 1.05 + spp: true + yolo_loss: YOLOv3Loss + nms: MatrixNMS + drop_block: true + +YOLOv3Loss: + batch_size: 24 + ignore_thresh: 0.7 + scale_x_y: 1.05 + label_smooth: false + use_fine_grained_loss: true + iou_loss: IouLoss + iou_aware_loss: IouAwareLoss + +IouLoss: + loss_weight: 2.5 + max_height: 608 + max_width: 608 + +IouAwareLoss: + loss_weight: 1.0 + max_height: 608 + max_width: 608 + +MatrixNMS: + background_label: -1 + keep_top_k: 100 + normalized: false + score_threshold: 0.01 + post_threshold: 0.01 + +LearningRate: + base_lr: 0.00333 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 400000 + - 450000 + - !LinearWarmup + start_factor: 0. + steps: 4000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: 'ppyolo_reader.yml' +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id'] + num_max_boxes: 90 + dataset: + !COCODataSet + image_dir: test2017 + anno_path: annotations/image_info_test-dev2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 1 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 + +TestReader: + dataset: + !ImageFolder + use_default_label: true + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 1 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True diff --git a/ppdet/modeling/losses/iou_loss.py b/ppdet/modeling/losses/iou_loss.py index 15a5f229c..624cd47c0 100644 --- a/ppdet/modeling/losses/iou_loss.py +++ b/ppdet/modeling/losses/iou_loss.py @@ -182,8 +182,8 @@ class IouLoss(object): dcx_sig = fluid.layers.sigmoid(dcx) dcy_sig = fluid.layers.sigmoid(dcy) if (abs(scale_x_y - 1.0) > eps): - dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1) - dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1) + dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1) + dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1) cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index e97198fbc..6823c024b 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -91,8 +91,15 @@ class YOLOv3Loss(object): return {'loss': sum(losses)} - def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size, - num_classes, mask_anchors, ignore_thresh, eps=1.e-10): + def _get_fine_grained_loss(self, + outputs, + targets, + gt_box, + batch_size, + num_classes, + mask_anchors, + ignore_thresh, + eps=1.e-10): """ Calculate fine grained YOLOv3 loss @@ -148,8 +155,10 @@ class YOLOv3Loss(object): y, ty) * tscale_tobj loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) else: - dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - 1.0) - dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - 1.0) + dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - + 1.0) + dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - + 1.0) loss_x = fluid.layers.abs(dx - tx) * tscale_tobj loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) loss_y = fluid.layers.abs(dy - ty) * tscale_tobj @@ -162,7 +171,8 @@ class YOLOv3Loss(object): loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) if self._iou_loss is not None: loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors, - downsample, self._batch_size, scale_x_y) + downsample, self._batch_size, + scale_x_y) loss_iou = loss_iou * tscale_tobj loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3]) loss_ious.append(fluid.layers.reduce_mean(loss_iou)) @@ -304,7 +314,7 @@ class YOLOv3Loss(object): downsample_ratio=downsample, clip_bbox=False, scale_x_y=scale_x_y) - + # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox # and gt bbox in each sample if batch_size > 1: @@ -333,17 +343,17 @@ class YOLOv3Loss(object): pred = fluid.layers.squeeze(pred, axes=[0]) gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0])) ious.append(fluid.layers.iou_similarity(pred, gt)) - + iou = fluid.layers.stack(ious, axis=0) # 3. Get iou_mask by IoU between gt bbox and prediction bbox, # Get obj_mask by tobj(holds gt_score), calculate objectness loss - + max_iou = fluid.layers.reduce_max(iou, dim=-1) iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32") if self.match_score: max_prob = fluid.layers.reduce_max(prob, dim=-1) iou_mask = iou_mask * fluid.layers.cast( - max_prob <= 0.25, dtype="float32") + max_prob <= 0.25, dtype="float32") output_shape = fluid.layers.shape(output) an_num = len(anchors) // 2 iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2], -- GitLab