提交 19ec2a55 编写于 作者: W wanghaoshuang

Merge branch 'code_doc_review' into 'develop'

Small update

See merge request !81
......@@ -24,7 +24,7 @@ pruner = Pruner()
---
>prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=None, param_shape_backup=None)
>prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=False, param_shape_backup=False)
对目标网络的一组卷积层的权重进行裁剪。
......@@ -43,7 +43,7 @@ for block in program.blocks:
- **ratios(list<float>):** 用于裁剪`params`的剪切率,类型为列表。该列表长度必须与`params`的长度一致。
- **place(paddle.fluid.Place):** 待裁剪参数所在的设备位置,可以是`CUDAPlace``CPUPLace`[Place概念介绍]()
- **place(paddle.fluid.Place):** 待裁剪参数所在的设备位置,可以是`CUDAPlace``CPUPlace`[Place概念介绍]()
- **lazy(bool):** `lazy`为True时,通过将指定通道的参数置零达到裁剪的目的,参数的`shape保持不变``lazy`为False时,直接将要裁的通道的参数删除,参数的`shape`会发生变化。
......@@ -51,7 +51,7 @@ for block in program.blocks:
- **param_backup(bool):** 是否返回对参数值的备份。默认为False。
- **param_shape_backup(bool):** 是否返回对参数`shape`的备份。
- **param_shape_backup(bool):** 是否返回对参数`shape`的备份。默认为False。
**返回:**
......@@ -131,8 +131,8 @@ main_program, _, _ = pruner.prune(
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
param_backup=False,
param_shape_backup=False)
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
......@@ -153,7 +153,7 @@ for param in main_program.global_block().all_parameters():
- **program(paddle.fluid.Program):** 待评估的目标网络。更多关于Program的介绍请参考:[Program概念介绍](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program)
- **place(paddle.fluid.Place):** 待分析的参数所在的设备位置,可以是`CUDAPlace``CPUPLace`[Place概念介绍]()
- **place(paddle.fluid.Place):** 待分析的参数所在的设备位置,可以是`CUDAPlace``CPUPlace`[Place概念介绍]()
- **param_names(list<str>):** 待分析的卷积层的参数的名称列表。可以通过以下方式查看模型中所有参数的名称:
......
......@@ -69,10 +69,12 @@ class AutoPruner(object):
max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`.
List means max ratios for each parameter in `params`.
The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`. List means min ratios for each parameter in `params`.
min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`.
List means min ratios for each parameter in `params`.
The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
key(str): Identity used in communication between controller server and clients.
......
......@@ -32,10 +32,14 @@ _logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None, checkpoints=None):
"""
Pruner used to prune parameters iteratively according to sensitivities of parameters in each step.
Pruner used to prune parameters iteratively according to sensitivities
of parameters in each step.
Args:
place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute.
eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program.
place(fluid.CUDAPlace | fluid.CPUPlace): The device place where
program execute.
eval_func(function): A callback function used to evaluate pruned
program. The argument of this function is pruned program.
And it return a score of given program.
scope(fluid.scope): The scope used to execute program.
"""
self._eval_func = eval_func
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册