提交 19ec2a55 编写于 作者: W wanghaoshuang

Merge branch 'code_doc_review' into 'develop'

Small update

See merge request !81
...@@ -24,7 +24,7 @@ pruner = Pruner() ...@@ -24,7 +24,7 @@ pruner = Pruner()
--- ---
>prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=None, param_shape_backup=None) >prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=False, param_shape_backup=False)
对目标网络的一组卷积层的权重进行裁剪。 对目标网络的一组卷积层的权重进行裁剪。
...@@ -43,7 +43,7 @@ for block in program.blocks: ...@@ -43,7 +43,7 @@ for block in program.blocks:
- **ratios(list<float>):** 用于裁剪`params`的剪切率,类型为列表。该列表长度必须与`params`的长度一致。 - **ratios(list<float>):** 用于裁剪`params`的剪切率,类型为列表。该列表长度必须与`params`的长度一致。
- **place(paddle.fluid.Place):** 待裁剪参数所在的设备位置,可以是`CUDAPlace``CPUPLace`[Place概念介绍]() - **place(paddle.fluid.Place):** 待裁剪参数所在的设备位置,可以是`CUDAPlace``CPUPlace`[Place概念介绍]()
- **lazy(bool):** `lazy`为True时,通过将指定通道的参数置零达到裁剪的目的,参数的`shape保持不变``lazy`为False时,直接将要裁的通道的参数删除,参数的`shape`会发生变化。 - **lazy(bool):** `lazy`为True时,通过将指定通道的参数置零达到裁剪的目的,参数的`shape保持不变``lazy`为False时,直接将要裁的通道的参数删除,参数的`shape`会发生变化。
...@@ -51,7 +51,7 @@ for block in program.blocks: ...@@ -51,7 +51,7 @@ for block in program.blocks:
- **param_backup(bool):** 是否返回对参数值的备份。默认为False。 - **param_backup(bool):** 是否返回对参数值的备份。默认为False。
- **param_shape_backup(bool):** 是否返回对参数`shape`的备份。 - **param_shape_backup(bool):** 是否返回对参数`shape`的备份。默认为False。
**返回:** **返回:**
...@@ -131,8 +131,8 @@ main_program, _, _ = pruner.prune( ...@@ -131,8 +131,8 @@ main_program, _, _ = pruner.prune(
place=place, place=place,
lazy=False, lazy=False,
only_graph=False, only_graph=False,
param_backup=None, param_backup=False,
param_shape_backup=None) param_shape_backup=False)
for param in main_program.global_block().all_parameters(): for param in main_program.global_block().all_parameters():
if "weights" in param.name: if "weights" in param.name:
...@@ -153,7 +153,7 @@ for param in main_program.global_block().all_parameters(): ...@@ -153,7 +153,7 @@ for param in main_program.global_block().all_parameters():
- **program(paddle.fluid.Program):** 待评估的目标网络。更多关于Program的介绍请参考:[Program概念介绍](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program) - **program(paddle.fluid.Program):** 待评估的目标网络。更多关于Program的介绍请参考:[Program概念介绍](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program)
- **place(paddle.fluid.Place):** 待分析的参数所在的设备位置,可以是`CUDAPlace``CPUPLace`[Place概念介绍]() - **place(paddle.fluid.Place):** 待分析的参数所在的设备位置,可以是`CUDAPlace``CPUPlace`[Place概念介绍]()
- **param_names(list<str>):** 待分析的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称: - **param_names(list<str>):** 待分析的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:
......
...@@ -57,10 +57,10 @@ class AutoPruner(object): ...@@ -57,10 +57,10 @@ class AutoPruner(object):
place(fluid.Place): The device place of parameters. place(fluid.Place): The device place of parameters.
params(list<str>): The names of parameters to be pruned. params(list<str>): The names of parameters to be pruned.
init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`. init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`.
List means ratios used for pruning each parameter in `params`. List means ratios used for pruning each parameter in `params`.
The length of `init_ratios` should be equal to length of params when `init_ratios` is a list. The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
If it is a scalar, all the parameters in `params` will be pruned by uniform ratio. If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None. None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None.
pruned_flops(float): The percent of FLOPS to be pruned. Default: None. pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
pruned_latency(float): The percent of latency to be pruned. Default: None. pruned_latency(float): The percent of latency to be pruned. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
...@@ -69,12 +69,14 @@ class AutoPruner(object): ...@@ -69,12 +69,14 @@ class AutoPruner(object):
max_try_times(int): The max number of trying to generate legal tokens. max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server. max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching. search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`. max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`.
The length of `max_ratios` should be equal to length of params when `max_ratios` is a list. List means max ratios for each parameter in `params`.
If it is a scalar, it will used for all the parameters in `params`. The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`. List means min ratios for each parameter in `params`. If it is a scalar, it will used for all the parameters in `params`.
The length of `min_ratios` should be equal to length of params when `min_ratios` is a list. min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`.
If it is a scalar, it will used for all the parameters in `params`. List means min ratios for each parameter in `params`.
The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
key(str): Identity used in communication between controller server and clients. key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True. is_server(bool): Whether current host is controller server. Default: True.
""" """
......
...@@ -32,10 +32,14 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -32,10 +32,14 @@ _logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object): class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None, checkpoints=None): def __init__(self, place, eval_func, scope=None, checkpoints=None):
""" """
Pruner used to prune parameters iteratively according to sensitivities of parameters in each step. Pruner used to prune parameters iteratively according to sensitivities
of parameters in each step.
Args: Args:
place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute. place(fluid.CUDAPlace | fluid.CPUPlace): The device place where
eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program. program execute.
eval_func(function): A callback function used to evaluate pruned
program. The argument of this function is pruned program.
And it return a score of given program.
scope(fluid.scope): The scope used to execute program. scope(fluid.scope): The scope used to execute program.
""" """
self._eval_func = eval_func self._eval_func = eval_func
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册