diff --git a/dygraph/configs/datasets/roadsign_voc.yml b/dygraph/configs/datasets/roadsign_voc.yml new file mode 100644 index 0000000000000000000000000000000000000000..10ce3090ed067a0387c32d384c13e4fd987cb1c4 --- /dev/null +++ b/dygraph/configs/datasets/roadsign_voc.yml @@ -0,0 +1,21 @@ +metric: VOC +map_type: 11point +num_classes: 4 + +TrainDataset: + !VOCDataSet + dataset_dir: dataset/roadsign_voc + anno_path: train.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +EvalDataset: + !VOCDataSet + dataset_dir: dataset/roadsign_voc + anno_path: valid.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +TestDataset: + !ImageFolder + anno_path: dataset/roadsign_voc/label_list.txt diff --git a/dygraph/configs/yolov3/yolov3_mobilenet_v1_roadsign.yml b/dygraph/configs/yolov3/yolov3_mobilenet_v1_roadsign.yml new file mode 100644 index 0000000000000000000000000000000000000000..efbd05d5647d6a08e4b12fa467b25c1893967f14 --- /dev/null +++ b/dygraph/configs/yolov3/yolov3_mobilenet_v1_roadsign.yml @@ -0,0 +1,66 @@ +_BASE_: [ + '../datasets/roadsign_voc.yml', + '../runtime.yml', + '_base_/yolov3_mobilenet_v1.yml', + '_base_/yolov3_reader.yml', +] +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams +use_fine_grained_loss: false +load_static_weights: false +norm_type: sync_bn +weights: output/yolov3_mobilenet_v1_roadsign/model_final + +YOLOv3Loss: + ignore_thresh: 0.7 + label_smooth: true + +TrainReader: + inputs_def: + num_max_boxes: 50 + sample_transforms: + - DecodeOp: {} + - MixupOp: {alpha: 1.5, beta: 1.5} + - RandomDistortOp: {} + - RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]} + - RandomCropOp: {} + - RandomFlipOp: {} + batch_transforms: + - BatchRandomResizeOp: + target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + random_size: True + random_interp: True + keep_ratio: False + - NormalizeBoxOp: {} + - PadBoxOp: {num_max_boxes: 50} + - BboxXYXY2XYWHOp: {} + - NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - PermuteOp: {} + - Gt2YoloTargetOp: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]] + downsample_ratios: [32, 16, 8] + num_classes: 4 + batch_size: 8 + shuffle: true + drop_last: true + +snapshot_epoch: 5 +epoch: 40 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [32, 36] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 100 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 diff --git a/dygraph/dataset/roadsign_voc/download_roadsign_voc.py b/dygraph/dataset/roadsign_voc/download_roadsign_voc.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb517d3cf362e3ad2ec7b4ebf3bff54acb244d4 --- /dev/null +++ b/dygraph/dataset/roadsign_voc/download_roadsign_voc.py @@ -0,0 +1,28 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os.path as osp +import logging +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.utils.download import download_dataset + +logging.basicConfig(level=logging.INFO) + +download_path = osp.split(osp.realpath(sys.argv[0]))[0] +download_dataset(download_path, 'roadsign_voc') diff --git a/dygraph/dataset/roadsign_voc/label_list.txt b/dygraph/dataset/roadsign_voc/label_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..1be460f457a2fdbec91d3a69377c232ae4a6beb0 --- /dev/null +++ b/dygraph/dataset/roadsign_voc/label_list.txt @@ -0,0 +1,4 @@ +speedlimit +crosswalk +trafficlight +stop \ No newline at end of file diff --git a/dygraph/docs/images/road554.png b/dygraph/docs/images/road554.png new file mode 100644 index 0000000000000000000000000000000000000000..1ecd45d9403897aa048417a9b69ad06e7ce41016 Binary files /dev/null and b/dygraph/docs/images/road554.png differ diff --git a/dygraph/docs/tutorials/QUICK_STARTED_cn.md b/dygraph/docs/tutorials/QUICK_STARTED_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..3c6c84264057e16ca1df740bd32291943682b20e --- /dev/null +++ b/dygraph/docs/tutorials/QUICK_STARTED_cn.md @@ -0,0 +1,82 @@ +# 快速开始 +为了使得用户能够在很短时间内快速产出模型,掌握PaddleDetection的使用方式,这篇教程通过一个预训练检测模型对小数据集进行finetune。在较短时间内即可产出一个效果不错的模型。实际业务中,建议用户根据需要选择合适模型配置文件进行适配。 + +- **设置显卡** +```bash +export CUDA_VISIBLE_DEVICES=0 +``` + +## 一、快速体验 +``` +# 用PP-YOLO算法在COCO数据集上预训练模型预测一张图片 +python tools/infer.py -c configs/ppyolo/ppyolo_r50vd_dcn_1x_coco.yml -o use_gpu=true weights=https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ppyolo_r50vd_dcn_1x_coco.pdparams --infer_img=demo/000000014439.jpg +``` +结果如下图: + +![](../images/000000014439.jpg) + + +## 二、准备数据 +数据集参考[Kaggle数据集](https://www.kaggle.com/andrewmvd/road-sign-detection) ,包含877张图像,数据类别4类:crosswalk,speedlimit,stop,trafficlight。 +将数据划分为训练集701张图和测试集176张图,[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar). + +``` +# 注意:可跳过这步下载,后面训练会自动下载 +python dataset/roadsign_voc/download_roadsign_voc.py +``` + +## 三、训练、评估、预测 +### 1、训练 +``` +# 边训练边测试 CPU需要约1小时(use_gpu=false),1080Ti GPU需要约10分钟。 +# -c 参数表示指定使用哪个配置文件 +# -o 参数表示指定配置文件中的全局变量(覆盖配置文件中的设置),这里设置使用gpu, +# --eval 参数表示边训练边评估,会自动保存一个评估结果最的名为model_final.pdmodel的模型 + + +python tools/train.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml --eval -o use_gpu=true --weight_type finetune +``` + +如果想通过VisualDL实时观察loss变化曲线,在训练命令中添加--use_vdl=true,以及通过--vdl_log_dir设置日志保存路径。 + +**但注意VisualDL需Python>=3.5** + +首先安装[VisualDL](https://github.com/PaddlePaddle/VisualDL) +``` +python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple +``` + +``` +python -u tools/train.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml \ + --use_vdl=true \ + --vdl_log_dir=vdl_dir/scalar \ + --eval +``` +通过visualdl命令实时查看变化曲线: +``` +visualdl --logdir vdl_dir/scalar/ --host --port +``` + +### 2、评估 +``` +# 评估 默认使用训练过程中保存的model_final +# -c 参数表示指定使用哪个配置文件 +# -o 参数表示指定配置文件中的全局变量(覆盖配置文件中的设置),需使用单卡评估 + +python tools/eval.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml -o use_gpu=true +``` + + +### 3、预测 +``` +# -c 参数表示指定使用哪个配置文件 +# -o 参数表示指定配置文件中的全局变量(覆盖配置文件中的设置) +# --infer_img 参数指定预测图像路径 +# 预测结束后会在output文件夹中生成一张画有预测结果的同名图像 + +python tools/infer.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml -o use_gpu=true --infer_img=demo/road554.png +``` + +结果如下图: + +![](../images/road554.png) diff --git a/dygraph/ppdet/utils/checkpoint.py b/dygraph/ppdet/utils/checkpoint.py index 38a7dc01ae92c542b3f9f335d259cc475a2029d9..fd25bd1ad6ee1d2521472782706578348e358241 100644 --- a/dygraph/ppdet/utils/checkpoint.py +++ b/dygraph/ppdet/utils/checkpoint.py @@ -164,9 +164,16 @@ def load_pretrain_weight(model, else: ignore_set = set() for name, weight in model_dict.items(): - if name in param_state_dict: - if weight.shape != param_state_dict[name].shape: + if name in param_state_dict.keys(): + if weight.shape != list(param_state_dict[name].shape): + logger.info( + '{} not used, shape {} unmatched with {} in model.'. + format(name, + list(param_state_dict[name].shape), + weight.shape)) param_state_dict.pop(name, None) + else: + logger.info('Lack weight: {}'.format(name)) model.set_dict(param_state_dict) return diff --git a/dygraph/ppdet/utils/download.py b/dygraph/ppdet/utils/download.py index 7d234fe39d8c2f04b180acc86fc96cbfdc0944c4..858ce035ddff8479ad06c1e900ec42de3956827a 100644 --- a/dygraph/ppdet/utils/download.py +++ b/dygraph/ppdet/utils/download.py @@ -81,6 +81,12 @@ DATASETS = { 'https://dataset.bj.bcebos.com/PaddleDetection_demo/fruit.tar', 'baa8806617a54ccf3685fa7153388ae6', ), ], ['Annotations', 'JPEGImages']), + 'roadsign_voc': ([( + 'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar', + '8d629c0f880dd8b48de9aeff44bf1f3e', ), ], ['annotations', 'images']), + 'roadsign_coco': ([( + 'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_coco.tar', + '49ce5a9b5ad0d6266163cd01de4b018e', ), ], ['annotations', 'images']), 'objects365': (), } @@ -173,7 +179,7 @@ def get_dataset_path(path, annotation, image_dir): "https://www.objects365.org/download.html".format(name)) data_dir = osp.join(DATASET_HOME, name) # For voc, only check dir VOCdevkit/VOC2012, VOCdevkit/VOC2007 - if name == 'voc' or name == 'fruit': + if name in ['voc', 'fruit', 'roadsign_voc']: exists = True for sub_dir in dataset[1]: check_dir = osp.join(data_dir, sub_dir) @@ -185,7 +191,7 @@ def get_dataset_path(path, annotation, image_dir): return data_dir # voc exist is checked above, voc is not exist here - check_exist = name != 'voc' and name != 'fruit' + check_exist = name != 'voc' and name != 'fruit' and name != 'roadsign_voc' for url, md5sum in dataset[0]: get_path(url, data_dir, md5sum, check_exist) @@ -195,10 +201,11 @@ def get_dataset_path(path, annotation, image_dir): return data_dir # not match any dataset in DATASETS - raise ValueError("Dataset {} is not valid and cannot parse dataset type " - "'{}' for automaticly downloading, which only supports " - "'voc' , 'coco', 'wider_face' and 'fruit' currently". - format(path, osp.split(path)[-1])) + raise ValueError( + "Dataset {} is not valid and cannot parse dataset type " + "'{}' for automaticly downloading, which only supports " + "'voc' , 'coco', 'wider_face', 'fruit' and 'roadsign_voc' currently". + format(path, osp.split(path)[-1])) def create_voc_list(data_dir, devkit_subdir='VOCdevkit'):