未验证 提交 db4b84cf 编写于 作者: J Jason 提交者: GitHub

Merge pull request #35 from PaddlePaddle/develop_slim

modify prune notice and docs
...@@ -182,7 +182,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec ...@@ -182,7 +182,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
**参数:** **参数:**
> - **num_classes** (int): 包含了背景类的类别数。默认为81。 > - **num_classes** (int): 包含了背景类的类别数。默认为81。
> - **backbone** (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd']。默认为'ResNet50'。 > - **backbone** (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd']。默认为'ResNet50'。
> - **with_fpn** (bool): 是否使用FPN结构。默认为True。 > - **with_fpn** (bool): 是否使用FPN结构。默认为True。
> - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。 > - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
> - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。 > - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
...@@ -262,7 +262,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ ...@@ -262,7 +262,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
**参数:** **参数:**
> - **num_classes** (int): 包含了背景类的类别数。默认为81。 > - **num_classes** (int): 包含了背景类的类别数。默认为81。
> - **backbone** (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd']。默认为'ResNet50'。 > - **backbone** (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd']。默认为'ResNet50'。
> - **with_fpn** (bool): 是否使用FPN结构。默认为True。 > - **with_fpn** (bool): 是否使用FPN结构。默认为True。
> - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。 > - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
> - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。 > - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
......
...@@ -38,10 +38,9 @@ except: ...@@ -38,10 +38,9 @@ except:
"[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md" "[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md"
) )
import paddlehub as hub #import paddlehub as hub
if hub.version.hub_version < '1.6.2': #if hub.version.hub_version < '1.6.2':
raise Exception("[ERROR] paddlehub >= 1.6.2 is required") # raise Exception("[ERROR] paddlehub >= 1.6.2 is required")
env_info = get_environ_info() env_info = get_environ_info()
load_model = cv.models.load_model load_model = cv.models.load_model
......
...@@ -204,13 +204,23 @@ class BaseAPI: ...@@ -204,13 +204,23 @@ class BaseAPI:
self.exe, self.train_prog, pretrain_weights, fuse_bn) self.exe, self.train_prog, pretrain_weights, fuse_bn)
# 进行裁剪 # 进行裁剪
if sensitivities_file is not None: if sensitivities_file is not None:
import paddleslim
from .slim.prune_config import get_sensitivities from .slim.prune_config import get_sensitivities
sensitivities_file = get_sensitivities(sensitivities_file, self, sensitivities_file = get_sensitivities(sensitivities_file, self,
save_dir) save_dir)
from .slim.prune import get_params_ratios, prune_program from .slim.prune import get_params_ratios, prune_program
logging.info(
"Start to prune program with eval_metric_loss = {}".format(
eval_metric_loss))
origin_flops = paddleslim.analysis.flops(self.test_prog)
prune_params_ratios = get_params_ratios( prune_params_ratios = get_params_ratios(
sensitivities_file, eval_metric_loss=eval_metric_loss) sensitivities_file, eval_metric_loss=eval_metric_loss)
prune_program(self, prune_params_ratios) prune_program(self, prune_params_ratios)
current_flops = paddleslim.analysis.flops(self.test_prog)
remaining_ratio = current_flops / origin_flops
logging.info(
"Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
.format(origin_flops, current_flops, remaining_ratio))
self.status = 'Prune' self.status = 'Prune'
def get_model_info(self): def get_model_info(self):
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import numpy as np import numpy as np
import os.path as osp import os.path as osp
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub #import paddlehub as hub
import paddlex import paddlex
sensitivities_data = { sensitivities_data = {
...@@ -105,22 +105,26 @@ def get_sensitivities(flag, model, save_dir): ...@@ -105,22 +105,26 @@ def get_sensitivities(flag, model, save_dir):
model_type) model_type)
url = sensitivities_data[model_type] url = sensitivities_data[model_type]
fname = osp.split(url)[-1] fname = osp.split(url)[-1]
try: paddlex.utils.download(url, path=save_dir)
hub.download(fname, save_path=save_dir)
except Exception as e:
if isinstance(e, hub.ResourceNotFoundError):
raise Exception(
"Resource for model {}(key='{}') not found".format(
model_type, fname))
elif isinstance(e, hub.ServerConnectionError):
raise Exception(
"Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
.format(model_type, fname))
else:
raise Exception(
"Unexpected error, please make sure paddlehub >= 1.6.2 {}".
format(str(e)))
return osp.join(save_dir, fname) return osp.join(save_dir, fname)
# try:
# hub.download(fname, save_path=save_dir)
# except Exception as e:
# if isinstance(e, hub.ResourceNotFoundError):
# raise Exception(
# "Resource for model {}(key='{}') not found".format(
# model_type, fname))
# elif isinstance(e, hub.ServerConnectionError):
# raise Exception(
# "Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
# .format(model_type, fname))
# else:
# raise Exception(
# "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
# format(str(e)))
# return osp.join(save_dir, fname)
else: else:
raise Exception( raise Exception(
"sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)." "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
......
import paddlex import paddlex
import paddlehub as hub #import paddlehub as hub
import os import os
import os.path as osp import os.path as osp
...@@ -85,40 +85,53 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): ...@@ -85,40 +85,53 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
backbone = 'DetResNet50' backbone = 'DetResNet50'
assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format( assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
backbone) backbone)
try: url = image_pretrain[backbone]
hub.download(backbone, save_path=new_save_dir) fname = osp.split(url)[-1].split('.')[0]
except Exception as e: paddlex.utils.download_and_decompress(url, path=new_save_dir)
if isinstance(e, hub.ResourceNotFoundError): return osp.join(new_save_dir, fname)
raise Exception( # try:
"Resource for backbone {} not found".format(backbone)) # hub.download(backbone, save_path=new_save_dir)
elif isinstance(e, hub.ServerConnectionError): # except Exception as e:
raise Exception( # if isinstance(e, hub.ResourceNotFoundError):
"Cannot get reource for backbone {}, please check your internet connecgtion" # raise Exception(
.format(backbone)) # "Resource for backbone {} not found".format(backbone))
else: # elif isinstance(e, hub.ServerConnectionError):
raise Exception( # raise Exception(
"Unexpected error, please make sure paddlehub >= 1.6.2") # "Cannot get reource for backbone {}, please check your internet connecgtion"
return osp.join(new_save_dir, backbone) # .format(backbone))
# else:
# raise Exception(
# "Unexpected error, please make sure paddlehub >= 1.6.2")
# return osp.join(new_save_dir, backbone)
elif flag == 'COCO': elif flag == 'COCO':
new_save_dir = save_dir new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'): if hasattr(paddlex, 'pretrain_dir'):
new_save_dir = paddlex.pretrain_dir new_save_dir = paddlex.pretrain_dir
assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format( url = coco_pretrain[backbone]
backbone) fname = osp.split(url)[-1].split('.')[0]
try: paddlex.utils.download_and_decompress(url, path=new_save_dir)
hub.download(backbone, save_path=new_save_dir) return osp.join(new_save_dir, fname)
except Exception as e:
if isinstance(hub.ResourceNotFoundError):
raise Exception( # new_save_dir = save_dir
"Resource for backbone {} not found".format(backbone)) # if hasattr(paddlex, 'pretrain_dir'):
elif isinstance(hub.ServerConnectionError): # new_save_dir = paddlex.pretrain_dir
raise Exception( # assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
"Cannot get reource for backbone {}, please check your internet connecgtion" # backbone)
.format(backbone)) # try:
else: # hub.download(backbone, save_path=new_save_dir)
raise Exception( # except Exception as e:
"Unexpected error, please make sure paddlehub >= 1.6.2") # if isinstance(hub.ResourceNotFoundError):
return osp.join(new_save_dir, backbone) # raise Exception(
# "Resource for backbone {} not found".format(backbone))
# elif isinstance(hub.ServerConnectionError):
# raise Exception(
# "Cannot get reource for backbone {}, please check your internet connecgtion"
# .format(backbone))
# else:
# raise Exception(
# "Unexpected error, please make sure paddlehub >= 1.6.2")
# return osp.join(new_save_dir, backbone)
else: else:
raise Exception( raise Exception(
"pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)." "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
......
...@@ -29,9 +29,8 @@ setuptools.setup( ...@@ -29,9 +29,8 @@ setuptools.setup(
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
setup_requires=['cython', 'numpy', 'sklearn'], setup_requires=['cython', 'numpy', 'sklearn'],
install_requires=[ install_requires=[
"pycocotools;platform_system!='Windows'", "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0', 'visualdl==1.3.0', 'paddleslim==1.0.1'
'paddleslim==1.0.1', 'paddlehub>=1.6.2'
], ],
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册