Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
db4b84cf
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
db4b84cf
编写于
5月 12, 2020
作者:
J
Jason
提交者:
GitHub
5月 12, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #35 from PaddlePaddle/develop_slim
modify prune notice and docs
上级
57418d2b
aebd5798
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
81 addition
and
56 deletion
+81
-56
docs/apis/models.md
docs/apis/models.md
+2
-2
paddlex/__init__.py
paddlex/__init__.py
+3
-4
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+10
-0
paddlex/cv/models/slim/prune_config.py
paddlex/cv/models/slim/prune_config.py
+20
-16
paddlex/cv/models/utils/pretrain_weights.py
paddlex/cv/models/utils/pretrain_weights.py
+44
-31
setup.py
setup.py
+2
-3
未找到文件。
docs/apis/models.md
浏览文件 @
db4b84cf
...
@@ -182,7 +182,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
...
@@ -182,7 +182,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
**参数:**
**参数:**
> - **num_classes** (int): 包含了背景类的类别数。默认为81。
> - **num_classes** (int): 包含了背景类的类别数。默认为81。
> - **backbone** (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50
vd', 'ResNet101', 'ResNet101
vd']。默认为'ResNet50'。
> - **backbone** (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50
_vd', 'ResNet101', 'ResNet101_
vd']。默认为'ResNet50'。
> - **with_fpn** (bool): 是否使用FPN结构。默认为True。
> - **with_fpn** (bool): 是否使用FPN结构。默认为True。
> - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
> - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
> - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
> - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
...
@@ -262,7 +262,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
...
@@ -262,7 +262,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
**参数:**
**参数:**
> - **num_classes** (int): 包含了背景类的类别数。默认为81。
> - **num_classes** (int): 包含了背景类的类别数。默认为81。
> - **backbone** (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50
vd', 'ResNet101', 'ResNet101
vd']。默认为'ResNet50'。
> - **backbone** (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50
_vd', 'ResNet101', 'ResNet101_
vd']。默认为'ResNet50'。
> - **with_fpn** (bool): 是否使用FPN结构。默认为True。
> - **with_fpn** (bool): 是否使用FPN结构。默认为True。
> - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
> - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
> - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
> - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
...
...
paddlex/__init__.py
浏览文件 @
db4b84cf
...
@@ -38,10 +38,9 @@ except:
...
@@ -38,10 +38,9 @@ except:
"[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md"
"[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md"
)
)
import
paddlehub
as
hub
#import paddlehub as hub
if
hub
.
version
.
hub_version
<
'1.6.2'
:
#if hub.version.hub_version < '1.6.2':
raise
Exception
(
"[ERROR] paddlehub >= 1.6.2 is required"
)
# raise Exception("[ERROR] paddlehub >= 1.6.2 is required")
env_info
=
get_environ_info
()
env_info
=
get_environ_info
()
load_model
=
cv
.
models
.
load_model
load_model
=
cv
.
models
.
load_model
...
...
paddlex/cv/models/base.py
浏览文件 @
db4b84cf
...
@@ -204,13 +204,23 @@ class BaseAPI:
...
@@ -204,13 +204,23 @@ class BaseAPI:
self
.
exe
,
self
.
train_prog
,
pretrain_weights
,
fuse_bn
)
self
.
exe
,
self
.
train_prog
,
pretrain_weights
,
fuse_bn
)
# 进行裁剪
# 进行裁剪
if
sensitivities_file
is
not
None
:
if
sensitivities_file
is
not
None
:
import
paddleslim
from
.slim.prune_config
import
get_sensitivities
from
.slim.prune_config
import
get_sensitivities
sensitivities_file
=
get_sensitivities
(
sensitivities_file
,
self
,
sensitivities_file
=
get_sensitivities
(
sensitivities_file
,
self
,
save_dir
)
save_dir
)
from
.slim.prune
import
get_params_ratios
,
prune_program
from
.slim.prune
import
get_params_ratios
,
prune_program
logging
.
info
(
"Start to prune program with eval_metric_loss = {}"
.
format
(
eval_metric_loss
))
origin_flops
=
paddleslim
.
analysis
.
flops
(
self
.
test_prog
)
prune_params_ratios
=
get_params_ratios
(
prune_params_ratios
=
get_params_ratios
(
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
prune_program
(
self
,
prune_params_ratios
)
prune_program
(
self
,
prune_params_ratios
)
current_flops
=
paddleslim
.
analysis
.
flops
(
self
.
test_prog
)
remaining_ratio
=
current_flops
/
origin_flops
logging
.
info
(
"Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
.
format
(
origin_flops
,
current_flops
,
remaining_ratio
))
self
.
status
=
'Prune'
self
.
status
=
'Prune'
def
get_model_info
(
self
):
def
get_model_info
(
self
):
...
...
paddlex/cv/models/slim/prune_config.py
浏览文件 @
db4b84cf
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
numpy
as
np
import
numpy
as
np
import
os.path
as
osp
import
os.path
as
osp
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
#
import paddlehub as hub
import
paddlex
import
paddlex
sensitivities_data
=
{
sensitivities_data
=
{
...
@@ -105,22 +105,26 @@ def get_sensitivities(flag, model, save_dir):
...
@@ -105,22 +105,26 @@ def get_sensitivities(flag, model, save_dir):
model_type
)
model_type
)
url
=
sensitivities_data
[
model_type
]
url
=
sensitivities_data
[
model_type
]
fname
=
osp
.
split
(
url
)[
-
1
]
fname
=
osp
.
split
(
url
)[
-
1
]
try
:
paddlex
.
utils
.
download
(
url
,
path
=
save_dir
)
hub
.
download
(
fname
,
save_path
=
save_dir
)
except
Exception
as
e
:
if
isinstance
(
e
,
hub
.
ResourceNotFoundError
):
raise
Exception
(
"Resource for model {}(key='{}') not found"
.
format
(
model_type
,
fname
))
elif
isinstance
(
e
,
hub
.
ServerConnectionError
):
raise
Exception
(
"Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
.
format
(
model_type
,
fname
))
else
:
raise
Exception
(
"Unexpected error, please make sure paddlehub >= 1.6.2 {}"
.
format
(
str
(
e
)))
return
osp
.
join
(
save_dir
,
fname
)
return
osp
.
join
(
save_dir
,
fname
)
# try:
# hub.download(fname, save_path=save_dir)
# except Exception as e:
# if isinstance(e, hub.ResourceNotFoundError):
# raise Exception(
# "Resource for model {}(key='{}') not found".format(
# model_type, fname))
# elif isinstance(e, hub.ServerConnectionError):
# raise Exception(
# "Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
# .format(model_type, fname))
# else:
# raise Exception(
# "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
# format(str(e)))
# return osp.join(save_dir, fname)
else
:
else
:
raise
Exception
(
raise
Exception
(
"sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
"sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
...
...
paddlex/cv/models/utils/pretrain_weights.py
浏览文件 @
db4b84cf
import
paddlex
import
paddlex
import
paddlehub
as
hub
#
import paddlehub as hub
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
...
@@ -85,40 +85,53 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
...
@@ -85,40 +85,53 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
backbone
=
'DetResNet50'
backbone
=
'DetResNet50'
assert
backbone
in
image_pretrain
,
"There is not ImageNet pretrain weights for {}, you may try COCO."
.
format
(
assert
backbone
in
image_pretrain
,
"There is not ImageNet pretrain weights for {}, you may try COCO."
.
format
(
backbone
)
backbone
)
try
:
url
=
image_pretrain
[
backbone
]
hub
.
download
(
backbone
,
save_path
=
new_save_dir
)
fname
=
osp
.
split
(
url
)[
-
1
].
split
(
'.'
)[
0
]
except
Exception
as
e
:
paddlex
.
utils
.
download_and_decompress
(
url
,
path
=
new_save_dir
)
if
isinstance
(
e
,
hub
.
ResourceNotFoundError
):
return
osp
.
join
(
new_save_dir
,
fname
)
raise
Exception
(
# try:
"Resource for backbone {} not found"
.
format
(
backbone
))
# hub.download(backbone, save_path=new_save_dir)
elif
isinstance
(
e
,
hub
.
ServerConnectionError
):
# except Exception as e:
raise
Exception
(
# if isinstance(e, hub.ResourceNotFoundError):
"Cannot get reource for backbone {}, please check your internet connecgtion"
# raise Exception(
.
format
(
backbone
))
# "Resource for backbone {} not found".format(backbone))
else
:
# elif isinstance(e, hub.ServerConnectionError):
raise
Exception
(
# raise Exception(
"Unexpected error, please make sure paddlehub >= 1.6.2"
)
# "Cannot get reource for backbone {}, please check your internet connecgtion"
return
osp
.
join
(
new_save_dir
,
backbone
)
# .format(backbone))
# else:
# raise Exception(
# "Unexpected error, please make sure paddlehub >= 1.6.2")
# return osp.join(new_save_dir, backbone)
elif
flag
==
'COCO'
:
elif
flag
==
'COCO'
:
new_save_dir
=
save_dir
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
if
hasattr
(
paddlex
,
'pretrain_dir'
):
new_save_dir
=
paddlex
.
pretrain_dir
new_save_dir
=
paddlex
.
pretrain_dir
assert
backbone
in
coco_pretrain
,
"There is not COCO pretrain weights for {}, you may try ImageNet."
.
format
(
url
=
coco_pretrain
[
backbone
]
backbone
)
fname
=
osp
.
split
(
url
)[
-
1
].
split
(
'.'
)[
0
]
try
:
paddlex
.
utils
.
download_and_decompress
(
url
,
path
=
new_save_dir
)
hub
.
download
(
backbone
,
save_path
=
new_save_dir
)
return
osp
.
join
(
new_save_dir
,
fname
)
except
Exception
as
e
:
if
isinstance
(
hub
.
ResourceNotFoundError
):
raise
Exception
(
# new_save_dir = save_dir
"Resource for backbone {} not found"
.
format
(
backbone
))
# if hasattr(paddlex, 'pretrain_dir'):
elif
isinstance
(
hub
.
ServerConnectionError
):
# new_save_dir = paddlex.pretrain_dir
raise
Exception
(
# assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
"Cannot get reource for backbone {}, please check your internet connecgtion"
# backbone)
.
format
(
backbone
))
# try:
else
:
# hub.download(backbone, save_path=new_save_dir)
raise
Exception
(
# except Exception as e:
"Unexpected error, please make sure paddlehub >= 1.6.2"
)
# if isinstance(hub.ResourceNotFoundError):
return
osp
.
join
(
new_save_dir
,
backbone
)
# raise Exception(
# "Resource for backbone {} not found".format(backbone))
# elif isinstance(hub.ServerConnectionError):
# raise Exception(
# "Cannot get reource for backbone {}, please check your internet connecgtion"
# .format(backbone))
# else:
# raise Exception(
# "Unexpected error, please make sure paddlehub >= 1.6.2")
# return osp.join(new_save_dir, backbone)
else
:
else
:
raise
Exception
(
raise
Exception
(
"pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
"pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
...
...
setup.py
浏览文件 @
db4b84cf
...
@@ -29,9 +29,8 @@ setuptools.setup(
...
@@ -29,9 +29,8 @@ setuptools.setup(
packages
=
setuptools
.
find_packages
(),
packages
=
setuptools
.
find_packages
(),
setup_requires
=
[
'cython'
,
'numpy'
,
'sklearn'
],
setup_requires
=
[
'cython'
,
'numpy'
,
'sklearn'
],
install_requires
=
[
install_requires
=
[
"pycocotools;platform_system!='Windows'"
,
"pycocotools;platform_system!='Windows'"
,
'pyyaml'
,
'colorama'
,
'tqdm'
,
'pyyaml'
,
'colorama'
,
'tqdm'
,
'visualdl==1.3.0'
,
'visualdl==1.3.0'
,
'paddleslim==1.0.1'
'paddleslim==1.0.1'
,
'paddlehub>=1.6.2'
],
],
classifiers
=
[
classifiers
=
[
"Programming Language :: Python :: 3"
,
"Programming Language :: Python :: 3"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录