Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
bae72f97
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bae72f97
编写于
8月 10, 2020
作者:
J
Jason
提交者:
GitHub
8月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #267 from FlyingQianMM/develop_qh
add ppyolo in docs
上级
5c4646d1
333b1b48
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
150 addition
and
8 deletion
+150
-8
docs/apis/models/detection.md
docs/apis/models/detection.md
+124
-0
docs/appendix/model_zoo.md
docs/appendix/model_zoo.md
+1
-0
docs/examples/solutions.md
docs/examples/solutions.md
+1
-0
docs/train/object_detection.md
docs/train/object_detection.md
+1
-0
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+2
-2
paddlex/cv/models/ppyolo.py
paddlex/cv/models/ppyolo.py
+13
-3
paddlex/cv/models/utils/pretrain_weights.py
paddlex/cv/models/utils/pretrain_weights.py
+6
-2
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+1
-1
tutorials/train/README.md
tutorials/train/README.md
+1
-0
未找到文件。
docs/apis/models/detection.md
浏览文件 @
bae72f97
# Object Detection
## paddlex.det.PPYOLO
```
python
paddlex
.
det
.
PPYOLO
(
num_classes
=
80
,
backbone
=
'ResNet50_vd_ssld'
,
with_dcn_v2
=
True
,
anchors
=
None
,
anchor_masks
=
None
,
use_coord_conv
=
True
,
use_iou_aware
=
True
,
use_spp
=
True
,
use_drop_block
=
True
,
scale_x_y
=
1.05
,
ignore_threshold
=
0.7
,
label_smooth
=
False
,
use_iou_loss
=
True
,
use_matrix_nms
=
True
,
nms_score_threshold
=
0.01
,
nms_topk
=
1000
,
nms_keep_topk
=
100
,
nms_iou_threshold
=
0.45
,
train_random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
])
```
> 构建PPYOLO检测器。**注意在PPYOLO,num_classes不需要包含背景类,如目标包括human、dog两种,则num_classes设为2即可,这里与FasterRCNN/MaskRCNN有差别**
> **参数**
>
> > - **num_classes** (int): 类别数。默认为80。
> > - **backbone** (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。
> > - **with_dcn_v2** (bool): Backbone是否使用DCNv2结构。默认为True。
> > - **anchors** (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
> > [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
> [59, 119], [116, 90], [156, 198], [373, 326]]。
> > - **anchor_masks** (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
> > [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
> > - **use_coord_conv** (bool): 是否使用CoordConv。默认值为True。
> > - **use_iou_aware** (bool): 是否使用IoU Aware分支。默认值为True。
> > - **use_spp** (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
> > - **use_drop_block** (bool): 是否使用Drop Block。默认值为True。
> > - **scale_x_y** (float): 调整中心点位置时的系数因子。默认值为1.05。
> > - **use_iou_loss** (bool): 是否使用IoU loss。默认值为True。
> > - **use_matrix_nms** (bool): 是否使用Matrix NMS。默认值为True。
> > - **ignore_threshold** (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
> > - **nms_score_threshold** (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
> > - **nms_topk** (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
> > - **nms_keep_topk** (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
> > - **nms_iou_threshold** (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
> > - **label_smooth** (bool): 是否使用label smooth。默认值为False。
> > - **train_random_shapes** (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
### train
```
python
train
(
self
,
num_epochs
,
train_dataset
,
train_batch_size
=
8
,
eval_dataset
=
None
,
save_interval_epochs
=
20
,
log_interval_steps
=
2
,
save_dir
=
'output'
,
pretrain_weights
=
'IMAGENET'
,
optimizer
=
None
,
learning_rate
=
1.0
/
8000
,
warmup_steps
=
1000
,
warmup_start_lr
=
0.0
,
lr_decay_epochs
=
[
213
,
240
],
lr_decay_gamma
=
0.1
,
metric
=
None
,
use_vdl
=
False
,
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
,
resume_checkpoint
=
None
,
use_ema
=
True
,
ema_decay
=
0.9998
)
```
> PPYOLO模型的训练接口,函数内置了`piecewise`学习率衰减策略和`momentum`优化器。
> **参数**
>
> > - **num_epochs** (int): 训练迭代轮数。
> > - **train_dataset** (paddlex.datasets): 训练数据读取器。
> > - **train_batch_size** (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与显卡数量之商为验证数据batch大小。默认值为8。
> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
> > - **save_interval_epochs** (int): 模型保存间隔(单位:迭代轮数)。默认为20。
> > - **log_interval_steps** (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
> > - **save_dir** (str): 模型保存路径。默认值为'output'。
> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为None。
> > - **optimizer** (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
> > - **learning_rate** (float): 默认优化器的学习率。默认为1.0/8000。
> > - **warmup_steps** (int): 默认优化器进行warmup过程的步数。默认为1000。
> > - **warmup_start_lr** (int): 默认优化器warmup的起始学习率。默认为0.0。
> > - **lr_decay_epochs** (list): 默认优化器的学习率衰减轮数。默认为[213, 240]。
> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
> > - **use_ema** (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
> > - **ema_decay** (float): 指数衰减率。默认值为0.9998。
### evaluate
```
python
evaluate
(
self
,
eval_dataset
,
batch_size
=
1
,
epoch_id
=
None
,
metric
=
None
,
return_details
=
False
)
```
> PPYOLO模型的评估接口,模型评估后会返回在验证集上的指标`box_map`(metric指定为'VOC'时)或`box_mmap`(metric指定为`COCO`时)。
> **参数**
>
> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
> > - **batch_size** (int): 验证数据批大小。默认为1。
> > - **epoch_id** (int): 当前评估模型所在的训练轮数。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC';如为COCODetection,则`metric`为'COCO'默认为None, 如为EasyData类型数据集,同时也会使用'VOC'。
> > - **return_details** (bool): 是否返回详细信息。默认值为False。
> >
> **返回值**
>
> > - **tuple** (metrics, eval_details) | **dict** (metrics): 当`return_details`为True时,返回(metrics, eval_details),当`return_details`为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,分别表示平均准确率平均值在各个阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
### predict
```
python
predict
(
self
,
img_file
,
transforms
=
None
)
```
> PPYOLO模型预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`predict`接口时,用户需要再重新定义`test_transforms`传入给`predict`接口
> **参数**
>
> > - **img_file** (str|np.ndarray): 预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
>
> **返回值**
>
> > - **list**: 预测结果列表,列表中每个元素均为一个dict,key包括'bbox', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、类别、类别id、置信度,其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。
### batch_predict
```
python
batch_predict
(
self
,
img_file_list
,
transforms
=
None
,
thread_num
=
2
)
```
> PPYOLO模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义`test_transforms`传入给`batch_predict`接口
> **参数**
>
> > - **img_file_list** (str|np.ndarray): 对列表(或元组)中的图像同时进行预测,列表中的元素是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
>
> **返回值**
>
> > - **list**: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个元素均为一个dict,key包括'bbox', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、类别、类别id、置信度,其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。
## paddlex.det.YOLOv3
```
python
...
...
docs/appendix/model_zoo.md
浏览文件 @
bae72f97
...
...
@@ -45,6 +45,7 @@
|
[
FasterRCNN-ResNet101-FPN
](
https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_1x.tar
)
| 244.2MB | 119.788 | 38.7 |
|
[
FasterRCNN-ResNet101_vd-FPN
](
https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar
)
|244.3MB | 156.097 | 40.5 |
|
[
FasterRCNN-HRNet_W18-FPN
](
https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_1x.tar
)
|115.5MB | 81.592 | 36 |
|
[
PPYOLO
](
https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams
)
| 329.1MB | - |45.9 |
|
[
YOLOv3-DarkNet53
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar
)
|249.2MB | 42.672 | 38.9 |
|
[
YOLOv3-MobileNetV1
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar
)
|99.2MB | 15.442 | 29.3 |
|
[
YOLOv3-MobileNetV3_large
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams
)
|100.7MB | 143.322 | 31.6 |
...
...
docs/examples/solutions.md
浏览文件 @
bae72f97
...
...
@@ -42,6 +42,7 @@ PaddleX针对图像分类、目标检测、实例分割和语义分割4种视觉
| YOLOv3-MobileNetV3_larget | 适用于追求高速预测的移动端场景 | 100.7MB | 143.322 | - | - | 31.6 |
| YOLOv3-MobileNetV1 | 精度相对偏低,适用于追求高速预测的服务器端场景 | 99.2MB| 15.422 | - | - | 29.3 |
| YOLOv3-DarkNet53 | 在预测速度和模型精度上都有较好的表现,适用于大多数的服务器端场景| 249.2MB | 42.672 | - | - | 38.9 |
| PPYOLO | 预测速度和模型精度都比YOLOv3-DarkNet53优异,适用于大多数的服务器端场景 | 329.1MB | - | - | - | 45.9 |
| FasterRCNN-ResNet50-FPN | 经典的二阶段检测器,预测速度相对较慢,适用于重视模型精度的服务器端场景 | 167.MB | 83.189 | - | -| 37.2 |
| FasterRCNN-HRNet_W18-FPN | 适用于对图像分辨率较为敏感、对目标细节预测要求更高的服务器端场景 | 115.5MB | 81.592 | - | - | 36 |
| FasterRCNN-ResNet101_vd-FPN | 超高精度模型,预测时间更长,在处理较大数据量时有较高的精度,适用于服务器端场景 | 244.3MB | 156.097 | - | - | 40.5 |
...
...
docs/train/object_detection.md
浏览文件 @
bae72f97
...
...
@@ -13,6 +13,7 @@ PaddleX目前提供了FasterRCNN和YOLOv3两种检测结构,多种backbone模
|
[
YOLOv3-MobileNetV1
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_mobilenetv1.py
)
| 29.3% | 99.2MB | 15.442ms | - | 模型小,预测速度快,适用于低性能或移动端设备 |
|
[
YOLOv3-MobileNetV3
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_mobilenetv3.py
)
| 31.6% | 100.7MB | 143.322ms | - | 模型小,移动端上预测速度有优势 |
|
[
YOLOv3-DarkNet53
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_darknet53.py
)
| 38.9% | 249.2MB | 42.672ms | - | 模型较大,预测速度快,适用于服务端 |
|
[
PPYOLO
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/ppyolo.py
)
| 45.9% | 329.1MB | - | - | 模型较大,预测速度比YOLOv3-DarkNet53更快,适用于服务端 |
|
[
FasterRCNN-ResNet50-FPN
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_r50_fpn.py
)
| 37.2% | 167.7MB | 197.715ms | - | 模型精度高,适用于服务端部署 |
|
[
FasterRCNN-ResNet18-FPN
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_r18_fpn.py
)
| 32.6% | 173.2MB | - | - | 模型精度高,适用于服务端部署 |
|
[
FasterRCNN-HRNet-FPN
](
https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_hrnet_fpn.py
)
| 36.0% | 115.MB | 81.592ms | - | 模型精度高,预测速度快,适用于服务端部署 |
...
...
paddlex/cv/models/base.py
浏览文件 @
bae72f97
...
...
@@ -548,7 +548,7 @@ class BaseAPI:
current_save_dir
=
osp
.
join
(
save_dir
,
"epoch_{}"
.
format
(
i
+
1
))
if
not
osp
.
isdir
(
current_save_dir
):
os
.
makedirs
(
current_save_dir
)
if
hasattr
(
self
,
'use_ema'
):
if
getattr
(
self
,
'use_ema'
,
False
):
self
.
exe
.
run
(
self
.
ema
.
apply_program
)
if
eval_dataset
is
not
None
and
eval_dataset
.
num_samples
>
0
:
self
.
eval_metrics
,
self
.
eval_details
=
self
.
evaluate
(
...
...
@@ -576,7 +576,7 @@ class BaseAPI:
log_writer
.
add_scalar
(
"Metrics/Eval(Epoch): {}"
.
format
(
k
),
v
,
i
+
1
)
self
.
save_model
(
save_dir
=
current_save_dir
)
if
hasattr
(
self
,
'use_ema'
):
if
getattr
(
self
,
'use_ema'
,
False
):
self
.
exe
.
run
(
self
.
ema
.
restore_program
)
time_eval_one_epoch
=
time
.
time
()
-
eval_epoch_start_time
eval_epoch_start_time
=
time
.
time
()
...
...
paddlex/cv/models/ppyolo.py
浏览文件 @
bae72f97
...
...
@@ -37,11 +37,19 @@ class PPYOLO(BaseAPI):
Args:
num_classes (int): 类别数。默认为80。
backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。
with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]。
anchor_masks (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
[[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
use_coord_conv (bool): 是否使用CoordConv。默认值为True。
use_iou_aware (bool): 是否使用IoU Aware分支。默认值为True。
use_spp (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
use_drop_block (bool): 是否使用Drop Block。默认值为True。
scale_x_y (float): 调整中心点位置时的系数因子。默认值为1.05。
use_iou_loss (bool): 是否使用IoU loss。默认值为True。
use_matrix_nms (bool): 是否使用Matrix NMS。默认值为True。
ignore_threshold (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
...
...
@@ -54,7 +62,7 @@ class PPYOLO(BaseAPI):
def
__init__
(
self
,
num_classes
=
80
,
backbone
=
'ResNet50_vd'
,
backbone
=
'ResNet50_vd
_ssld
'
,
with_dcn_v2
=
True
,
# YOLO Head
anchors
=
None
,
...
...
@@ -79,7 +87,7 @@ class PPYOLO(BaseAPI):
]):
self
.
init_params
=
locals
()
super
(
PPYOLO
,
self
).
__init__
(
'detector'
)
backbones
=
[
'ResNet50_vd'
]
backbones
=
[
'ResNet50_vd
_ssld
'
]
assert
backbone
in
backbones
,
"backbone should be one of {}"
.
format
(
backbones
)
self
.
backbone
=
backbone
...
...
@@ -116,7 +124,7 @@ class PPYOLO(BaseAPI):
self
.
with_dcn_v2
=
with_dcn_v2
def
_get_backbone
(
self
,
backbone_name
):
if
backbone_name
==
'ResNet50_vd'
:
if
backbone_name
.
startswith
(
'ResNet50_vd'
)
:
backbone
=
paddlex
.
cv
.
nets
.
ResNet
(
norm_type
=
'sync_bn'
,
layers
=
50
,
...
...
@@ -252,6 +260,8 @@ class PPYOLO(BaseAPI):
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
use_ema (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
ema_decay (float): 指数衰减率。默认值为0.9998。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
paddlex/cv/models/utils/pretrain_weights.py
浏览文件 @
bae72f97
...
...
@@ -116,7 +116,9 @@ coco_pretrain = {
'DeepLabv3p_MobileNetV2_x1.0_COCO'
:
'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz'
,
'DeepLabv3p_Xception65_COCO'
:
'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz'
'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz'
,
'PPYOLO_ResNet50_vd_ssld_COCO'
:
'https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams'
}
cityscapes_pretrain
=
{
...
...
@@ -226,7 +228,9 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
new_save_dir
=
paddlex
.
pretrain_dir
if
class_name
in
[
'YOLOv3'
,
'FasterRCNN'
,
'MaskRCNN'
,
'DeepLabv3p'
]:
if
class_name
in
[
'YOLOv3'
,
'FasterRCNN'
,
'MaskRCNN'
,
'DeepLabv3p'
,
'PPYOLO'
]:
backbone
=
'{}_{}'
.
format
(
class_name
,
backbone
)
backbone
=
"{}_{}"
.
format
(
backbone
,
flag
)
if
flag
==
'COCO'
:
...
...
paddlex/cv/models/yolo_v3.py
浏览文件 @
bae72f97
...
...
@@ -60,7 +60,7 @@ class YOLOv3(PPYOLO):
]
assert
backbone
in
backbones
,
"backbone should be one of {}"
.
format
(
backbones
)
super
(
YOLOv3
,
self
).
__init__
(
'detector'
)
super
(
PPYOLO
,
self
).
__init__
(
'detector'
)
self
.
backbone
=
backbone
self
.
num_classes
=
num_classes
self
.
anchors
=
anchors
...
...
tutorials/train/README.md
浏览文件 @
bae72f97
...
...
@@ -12,6 +12,7 @@
|object_detection/faster_rcnn_hrnet_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|object_detection/faster_rcnn_r18_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|object_detection/faster_rcnn_r50_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 |
|object_detection/yolov3_darknet53.py | 目标检测YOLOv3 | 昆虫检测 |
|object_detection/yolov3_mobilenetv1.py | 目标检测YOLOv3 | 昆虫检测 |
|object_detection/yolov3_mobilenetv3.py | 目标检测YOLOv3 | 昆虫检测 |
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录