Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
ee2d40d4
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ee2d40d4
编写于
6月 07, 2020
作者:
F
FlyingQianMM
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add coco pretrained weights for detection
上级
dede0136
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
49 addition
and
4 deletion
+49
-4
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+1
-1
paddlex/cv/models/utils/pretrain_weights.py
paddlex/cv/models/utils/pretrain_weights.py
+48
-3
未找到文件。
paddlex/cv/models/base.py
浏览文件 @
ee2d40d4
...
@@ -201,7 +201,7 @@ class BaseAPI:
...
@@ -201,7 +201,7 @@ class BaseAPI:
if
backbone
==
"HRNet"
:
if
backbone
==
"HRNet"
:
backbone
=
backbone
+
"_W{}"
.
format
(
self
.
width
)
backbone
=
backbone
+
"_W{}"
.
format
(
self
.
width
)
pretrain_weights
=
get_pretrain_weights
(
pretrain_weights
=
get_pretrain_weights
(
pretrain_weights
,
self
.
model_typ
e
,
backbone
,
pretrain_dir
)
pretrain_weights
,
class_nam
e
,
backbone
,
pretrain_dir
)
if
startup_prog
is
None
:
if
startup_prog
is
None
:
startup_prog
=
fluid
.
default_startup_program
()
startup_prog
=
fluid
.
default_startup_program
()
self
.
exe
.
run
(
startup_prog
)
self
.
exe
.
run
(
startup_prog
)
...
...
paddlex/cv/models/utils/pretrain_weights.py
浏览文件 @
ee2d40d4
import
paddlex
import
paddlex
import
paddlex.utils.logging
as
logging
import
paddlehub
as
hub
import
paddlehub
as
hub
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
...
@@ -73,16 +74,58 @@ image_pretrain = {
...
@@ -73,16 +74,58 @@ image_pretrain = {
}
}
coco_pretrain
=
{
coco_pretrain
=
{
'YOLOv3_DarkNet53'
:
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar'
,
'YOLOv3_MobileNetV1'
:
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar'
,
'YOLOv3_MobileNetV3_large'
:
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams'
,
'YOLOv3_ResNet34'
:
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar'
,
'YOLOv3_ResNet50_vd'
:
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar'
,
'FasterRCNN_ResNet50'
:
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar'
,
'FasterRCNN_ResNet50_vd'
:
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar'
,
'FasterRCNN_ResNet101'
:
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar'
,
'FasterRCNN_ResNet101_vd'
:
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar'
,
'FasterRCNN_HRNet_W18'
:
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar'
,
'MaskRCNN_ResNet50'
:
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar'
,
'MaskRCNN_ResNet50_vd'
:
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar'
,
'MaskRCNN_ResNet101'
:
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar'
,
'MaskRCNN_ResNet101_vd'
:
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar'
,
'UNet'
:
'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'
'UNet'
:
'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'
}
}
def
get_pretrain_weights
(
flag
,
model_typ
e
,
backbone
,
save_dir
):
def
get_pretrain_weights
(
flag
,
class_nam
e
,
backbone
,
save_dir
):
if
flag
is
None
:
if
flag
is
None
:
return
None
return
None
elif
osp
.
isdir
(
flag
):
elif
osp
.
isdir
(
flag
):
return
flag
return
flag
elif
flag
==
'IMAGENET'
:
warning_info
=
"{} supports to be finetuned with weights pretrained on the IMAGENET dataset only, so pretrain_weights is forced to be set to IMAGENET"
if
flag
==
'COCO'
:
if
class_name
==
"FasterRCNN"
and
backbone
in
[
'ResNet18'
]
or
\
class_name
==
"MaskRCNN"
and
backbone
in
[
'ResNet18'
,
'HRNet_W18'
]
or
\
class_name
==
'DeepLabv3p'
and
backbone
in
[
'Xception41'
,
'MobileNetV2_x0.25'
,
'MobileNetV2_x0.5'
,
'MobileNetV2_x1.5'
,
'MobileNetV2_x2.0'
]:
model_name
=
'{}_{}'
.
format
(
class_name
,
backbone
)
logging
.
warning
(
warning_info
.
format
(
model_name
))
flag
=
'IMAGENET'
elif
class_name
==
'HRNet'
:
logging
.
warning
(
warning_info
.
format
(
class_name
))
flag
=
'IMAGENET'
if
flag
==
'CITYSCAPES'
:
model_name
=
'{}_{}'
.
format
(
class_name
,
backbone
)
if
flag
==
'IMAGENET'
:
new_save_dir
=
save_dir
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
if
hasattr
(
paddlex
,
'pretrain_dir'
):
new_save_dir
=
paddlex
.
pretrain_dir
new_save_dir
=
paddlex
.
pretrain_dir
...
@@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
...
@@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
backbone
=
'MobileNetV3_small_x1_0_ssld'
backbone
=
'MobileNetV3_small_x1_0_ssld'
elif
backbone
==
'MobileNetV3_large_ssld'
:
elif
backbone
==
'MobileNetV3_large_ssld'
:
backbone
=
'MobileNetV3_large_x1_0_ssld'
backbone
=
'MobileNetV3_large_x1_0_ssld'
if
model_type
==
'detector'
:
if
class_name
in
[
'YOLOv3'
,
'FasterRCNN'
,
'MaskRCNN'
]
:
if
backbone
==
'ResNet50'
:
if
backbone
==
'ResNet50'
:
backbone
=
'DetResNet50'
backbone
=
'DetResNet50'
assert
backbone
in
image_pretrain
,
"There is not ImageNet pretrain weights for {}, you may try COCO."
.
format
(
assert
backbone
in
image_pretrain
,
"There is not ImageNet pretrain weights for {}, you may try COCO."
.
format
(
...
@@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
...
@@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
new_save_dir
=
save_dir
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
if
hasattr
(
paddlex
,
'pretrain_dir'
):
new_save_dir
=
paddlex
.
pretrain_dir
new_save_dir
=
paddlex
.
pretrain_dir
if
class_name
in
[
'YOLOv3'
,
'FasterRCNN'
,
'MaskRCNN'
]:
backbone
=
'{}_{}'
.
format
(
class_name
,
backbone
)
url
=
coco_pretrain
[
backbone
]
url
=
coco_pretrain
[
backbone
]
fname
=
osp
.
split
(
url
)[
-
1
].
split
(
'.'
)[
0
]
fname
=
osp
.
split
(
url
)[
-
1
].
split
(
'.'
)[
0
]
# paddlex.utils.download_and_decompress(url, path=new_save_dir)
# paddlex.utils.download_and_decompress(url, path=new_save_dir)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录