未验证 提交 bae72f97 编写于 作者: J Jason 提交者: GitHub

Merge pull request #267 from FlyingQianMM/develop_qh

add ppyolo in docs
# Object Detection
## paddlex.det.PPYOLO
```python
paddlex.det.PPYOLO(num_classes=80, backbone='ResNet50_vd_ssld', with_dcn_v2=True, anchors=None, anchor_masks=None, use_coord_conv=True, use_iou_aware=True, use_spp=True, use_drop_block=True, scale_x_y=1.05, ignore_threshold=0.7, label_smooth=False, use_iou_loss=True, use_matrix_nms=True, nms_score_threshold=0.01, nms_topk=1000, nms_keep_topk=100, nms_iou_threshold=0.45, train_random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608])
```
> 构建PPYOLO检测器。**注意在PPYOLO,num_classes不需要包含背景类,如目标包括human、dog两种,则num_classes设为2即可,这里与FasterRCNN/MaskRCNN有差别**
> **参数**
>
> > - **num_classes** (int): 类别数。默认为80。
> > - **backbone** (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。
> > - **with_dcn_v2** (bool): Backbone是否使用DCNv2结构。默认为True。
> > - **anchors** (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
> > [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
> [59, 119], [116, 90], [156, 198], [373, 326]]。
> > - **anchor_masks** (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
> > [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
> > - **use_coord_conv** (bool): 是否使用CoordConv。默认值为True。
> > - **use_iou_aware** (bool): 是否使用IoU Aware分支。默认值为True。
> > - **use_spp** (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
> > - **use_drop_block** (bool): 是否使用Drop Block。默认值为True。
> > - **scale_x_y** (float): 调整中心点位置时的系数因子。默认值为1.05。
> > - **use_iou_loss** (bool): 是否使用IoU loss。默认值为True。
> > - **use_matrix_nms** (bool): 是否使用Matrix NMS。默认值为True。
> > - **ignore_threshold** (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
> > - **nms_score_threshold** (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
> > - **nms_topk** (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
> > - **nms_keep_topk** (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
> > - **nms_iou_threshold** (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
> > - **label_smooth** (bool): 是否使用label smooth。默认值为False。
> > - **train_random_shapes** (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
### train
```python
train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None, use_ema=True, ema_decay=0.9998)
```
> PPYOLO模型的训练接口,函数内置了`piecewise`学习率衰减策略和`momentum`优化器。
> **参数**
>
> > - **num_epochs** (int): 训练迭代轮数。
> > - **train_dataset** (paddlex.datasets): 训练数据读取器。
> > - **train_batch_size** (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与显卡数量之商为验证数据batch大小。默认值为8。
> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
> > - **save_interval_epochs** (int): 模型保存间隔(单位:迭代轮数)。默认为20。
> > - **log_interval_steps** (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
> > - **save_dir** (str): 模型保存路径。默认值为'output'。
> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为None。
> > - **optimizer** (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
> > - **learning_rate** (float): 默认优化器的学习率。默认为1.0/8000。
> > - **warmup_steps** (int): 默认优化器进行warmup过程的步数。默认为1000。
> > - **warmup_start_lr** (int): 默认优化器warmup的起始学习率。默认为0.0。
> > - **lr_decay_epochs** (list): 默认优化器的学习率衰减轮数。默认为[213, 240]。
> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
> > - **use_ema** (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
> > - **ema_decay** (float): 指数衰减率。默认值为0.9998。
### evaluate
```python
evaluate(self, eval_dataset, batch_size=1, epoch_id=None, metric=None, return_details=False)
```
> PPYOLO模型的评估接口,模型评估后会返回在验证集上的指标`box_map`(metric指定为'VOC'时)或`box_mmap`(metric指定为`COCO`时)。
> **参数**
>
> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
> > - **batch_size** (int): 验证数据批大小。默认为1。
> > - **epoch_id** (int): 当前评估模型所在的训练轮数。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC';如为COCODetection,则`metric`为'COCO'默认为None, 如为EasyData类型数据集,同时也会使用'VOC'。
> > - **return_details** (bool): 是否返回详细信息。默认值为False。
> >
> **返回值**
>
> > - **tuple** (metrics, eval_details) | **dict** (metrics): 当`return_details`为True时,返回(metrics, eval_details),当`return_details`为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,分别表示平均准确率平均值在各个阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
### predict
```python
predict(self, img_file, transforms=None)
```
> PPYOLO模型预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`predict`接口时,用户需要再重新定义`test_transforms`传入给`predict`接口
> **参数**
>
> > - **img_file** (str|np.ndarray): 预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
>
> **返回值**
>
> > - **list**: 预测结果列表,列表中每个元素均为一个dict,key包括'bbox', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、类别、类别id、置信度,其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。
### batch_predict
```python
batch_predict(self, img_file_list, transforms=None, thread_num=2)
```
> PPYOLO模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义`test_transforms`传入给`batch_predict`接口
> **参数**
>
> > - **img_file_list** (str|np.ndarray): 对列表(或元组)中的图像同时进行预测,列表中的元素是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
>
> **返回值**
>
> > - **list**: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个元素均为一个dict,key包括'bbox', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、类别、类别id、置信度,其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。
## paddlex.det.YOLOv3
```python
......
......@@ -45,6 +45,7 @@
|[FasterRCNN-ResNet101-FPN](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_1x.tar)| 244.2MB | 119.788 | 38.7 |
|[FasterRCNN-ResNet101_vd-FPN](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar) |244.3MB | 156.097 | 40.5 |
|[FasterRCNN-HRNet_W18-FPN](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_1x.tar) |115.5MB | 81.592 | 36 |
|[PPYOLO](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | 329.1MB | - |45.9 |
|[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar)|249.2MB | 42.672 | 38.9 |
|[YOLOv3-MobileNetV1](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar) |99.2MB | 15.442 | 29.3 |
|[YOLOv3-MobileNetV3_large](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams)|100.7MB | 143.322 | 31.6 |
......
......@@ -42,6 +42,7 @@ PaddleX针对图像分类、目标检测、实例分割和语义分割4种视觉
| YOLOv3-MobileNetV3_larget | 适用于追求高速预测的移动端场景 | 100.7MB | 143.322 | - | - | 31.6 |
| YOLOv3-MobileNetV1 | 精度相对偏低,适用于追求高速预测的服务器端场景 | 99.2MB| 15.422 | - | - | 29.3 |
| YOLOv3-DarkNet53 | 在预测速度和模型精度上都有较好的表现,适用于大多数的服务器端场景| 249.2MB | 42.672 | - | - | 38.9 |
| PPYOLO | 预测速度和模型精度都比YOLOv3-DarkNet53优异,适用于大多数的服务器端场景 | 329.1MB | - | - | - | 45.9 |
| FasterRCNN-ResNet50-FPN | 经典的二阶段检测器,预测速度相对较慢,适用于重视模型精度的服务器端场景 | 167.MB | 83.189 | - | -| 37.2 |
| FasterRCNN-HRNet_W18-FPN | 适用于对图像分辨率较为敏感、对目标细节预测要求更高的服务器端场景 | 115.5MB | 81.592 | - | - | 36 |
| FasterRCNN-ResNet101_vd-FPN | 超高精度模型,预测时间更长,在处理较大数据量时有较高的精度,适用于服务器端场景 | 244.3MB | 156.097 | - | - | 40.5 |
......
......@@ -13,6 +13,7 @@ PaddleX目前提供了FasterRCNN和YOLOv3两种检测结构,多种backbone模
| [YOLOv3-MobileNetV1](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_mobilenetv1.py) | 29.3% | 99.2MB | 15.442ms | - | 模型小,预测速度快,适用于低性能或移动端设备 |
| [YOLOv3-MobileNetV3](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_mobilenetv3.py) | 31.6% | 100.7MB | 143.322ms | - | 模型小,移动端上预测速度有优势 |
| [YOLOv3-DarkNet53](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_darknet53.py) | 38.9% | 249.2MB | 42.672ms | - | 模型较大,预测速度快,适用于服务端 |
| [PPYOLO](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/ppyolo.py) | 45.9% | 329.1MB | - | - | 模型较大,预测速度比YOLOv3-DarkNet53更快,适用于服务端 |
| [FasterRCNN-ResNet50-FPN](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_r50_fpn.py) | 37.2% | 167.7MB | 197.715ms | - | 模型精度高,适用于服务端部署 |
| [FasterRCNN-ResNet18-FPN](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_r18_fpn.py) | 32.6% | 173.2MB | - | - | 模型精度高,适用于服务端部署 |
| [FasterRCNN-HRNet-FPN](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_hrnet_fpn.py) | 36.0% | 115.MB | 81.592ms | - | 模型精度高,预测速度快,适用于服务端部署 |
......
......@@ -548,7 +548,7 @@ class BaseAPI:
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
if not osp.isdir(current_save_dir):
os.makedirs(current_save_dir)
if hasattr(self, 'use_ema'):
if getattr(self, 'use_ema', False):
self.exe.run(self.ema.apply_program)
if eval_dataset is not None and eval_dataset.num_samples > 0:
self.eval_metrics, self.eval_details = self.evaluate(
......@@ -576,7 +576,7 @@ class BaseAPI:
log_writer.add_scalar(
"Metrics/Eval(Epoch): {}".format(k), v, i + 1)
self.save_model(save_dir=current_save_dir)
if hasattr(self, 'use_ema'):
if getattr(self, 'use_ema', False):
self.exe.run(self.ema.restore_program)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
......
......@@ -37,11 +37,19 @@ class PPYOLO(BaseAPI):
Args:
num_classes (int): 类别数。默认为80。
backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。
with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]。
anchor_masks (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
[[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
use_coord_conv (bool): 是否使用CoordConv。默认值为True。
use_iou_aware (bool): 是否使用IoU Aware分支。默认值为True。
use_spp (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
use_drop_block (bool): 是否使用Drop Block。默认值为True。
scale_x_y (float): 调整中心点位置时的系数因子。默认值为1.05。
use_iou_loss (bool): 是否使用IoU loss。默认值为True。
use_matrix_nms (bool): 是否使用Matrix NMS。默认值为True。
ignore_threshold (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
......@@ -54,7 +62,7 @@ class PPYOLO(BaseAPI):
def __init__(
self,
num_classes=80,
backbone='ResNet50_vd',
backbone='ResNet50_vd_ssld',
with_dcn_v2=True,
# YOLO Head
anchors=None,
......@@ -79,7 +87,7 @@ class PPYOLO(BaseAPI):
]):
self.init_params = locals()
super(PPYOLO, self).__init__('detector')
backbones = ['ResNet50_vd']
backbones = ['ResNet50_vd_ssld']
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
self.backbone = backbone
......@@ -116,7 +124,7 @@ class PPYOLO(BaseAPI):
self.with_dcn_v2 = with_dcn_v2
def _get_backbone(self, backbone_name):
if backbone_name == 'ResNet50_vd':
if backbone_name.startswith('ResNet50_vd'):
backbone = paddlex.cv.nets.ResNet(
norm_type='sync_bn',
layers=50,
......@@ -252,6 +260,8 @@ class PPYOLO(BaseAPI):
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
use_ema (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
ema_decay (float): 指数衰减率。默认值为0.9998。
Raises:
ValueError: 评估类型不在指定列表中。
......
......@@ -116,7 +116,9 @@ coco_pretrain = {
'DeepLabv3p_MobileNetV2_x1.0_COCO':
'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz',
'DeepLabv3p_Xception65_COCO':
'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz'
'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz',
'PPYOLO_ResNet50_vd_ssld_COCO':
'https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams'
}
cityscapes_pretrain = {
......@@ -226,7 +228,9 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'):
new_save_dir = paddlex.pretrain_dir
if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p']:
if class_name in [
'YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p', 'PPYOLO'
]:
backbone = '{}_{}'.format(class_name, backbone)
backbone = "{}_{}".format(backbone, flag)
if flag == 'COCO':
......
......@@ -60,7 +60,7 @@ class YOLOv3(PPYOLO):
]
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
super(YOLOv3, self).__init__('detector')
super(PPYOLO, self).__init__('detector')
self.backbone = backbone
self.num_classes = num_classes
self.anchors = anchors
......
......@@ -12,6 +12,7 @@
|object_detection/faster_rcnn_hrnet_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|object_detection/faster_rcnn_r18_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|object_detection/faster_rcnn_r50_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 |
|object_detection/yolov3_darknet53.py | 目标检测YOLOv3 | 昆虫检测 |
|object_detection/yolov3_mobilenetv1.py | 目标检测YOLOv3 | 昆虫检测 |
|object_detection/yolov3_mobilenetv3.py | 目标检测YOLOv3 | 昆虫检测 |
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册