diff --git a/doc/prune_api.md b/doc/prune_api.md
index bb88eb468c2725fd88a12fb8b63fa6575cb8ab5c..f75520a39f9c83652c8fce44c60943533d14bcbf 100644
--- a/doc/prune_api.md
+++ b/doc/prune_api.md
@@ -24,7 +24,7 @@ pruner = Pruner()
 
 ---
 
->prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=None, param_shape_backup=None)
+>prune(program, scope, params, ratios, place=None, lazy=False, only_graph=False, param_backup=False, param_shape_backup=False)
 
 Prune the weights of a group of convolution layers in the target network.
 
@@ -43,7 +43,7 @@ for block in program.blocks:
 
 - **ratios(list):** The pruning ratios used to prune `params`, given as a list. The length of this list must equal the length of `params`.
 
-- **place(paddle.fluid.Place):** The device place of the parameters to be pruned; it can be `CUDAPlace` or `CPUPLace`. [Introduction to Place]()
+- **place(paddle.fluid.Place):** The device place of the parameters to be pruned; it can be `CUDAPlace` or `CPUPlace`. [Introduction to Place]()
 
 - **lazy(bool):** When `lazy` is True, pruning is done by zeroing out the parameters of the selected channels, so the parameter `shape` stays unchanged; when `lazy` is False, the parameters of the pruned channels are removed directly and the parameter `shape` changes.
 
@@ -51,7 +51,7 @@ for block in program.blocks:
 
 - **param_backup(bool):** Whether to return a backup of the parameter values. Default: False.
 
-- **param_shape_backup(bool):** Whether to return a backup of the parameter `shape`.
+- **param_shape_backup(bool):** Whether to return a backup of the parameter `shape`. Default: False.
 
 **Returns:**
 
@@ -131,8 +131,8 @@ main_program, _, _ = pruner.prune(
         place=place,
         lazy=False,
         only_graph=False,
-        param_backup=None,
-        param_shape_backup=None)
+        param_backup=False,
+        param_shape_backup=False)
 
 for param in main_program.global_block().all_parameters():
     if "weights" in param.name:
@@ -153,7 +153,7 @@ for param in main_program.global_block().all_parameters():
 
 - **program(paddle.fluid.Program):** The target network to be evaluated. For more about Program, please refer to [Introduction to Program](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Program_cn.html#program).
 
-- **place(paddle.fluid.Place):** The device place of the parameters to be analyzed; it can be `CUDAPlace` or `CPUPLace`. [Introduction to Place]()
+- **place(paddle.fluid.Place):** The device place of the parameters to be analyzed; it can be `CUDAPlace` or `CPUPlace`. [Introduction to Place]()
 
 - **param_names(list):** A list of the names of the convolution layer parameters to be analyzed. The names of all parameters in a model can be listed as follows:
 
diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py
index 8420d0c1b5d6ca1d0401ba249ebfa980037907d0..575d93c546e10717dd294004a8be80e55550ae4b 100644
--- a/paddleslim/prune/auto_pruner.py
+++ b/paddleslim/prune/auto_pruner.py
@@ -57,10 +57,10 @@ class AutoPruner(object):
             place(fluid.Place): The device place of parameters.
             params(list): The names of parameters to be pruned.
             init_ratios(list|float): Init ratios used to pruned parameters in `params`.
-            List means ratios used for pruning each parameter in `params`.
-            The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
-            If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
-            None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None.
+                List means ratios used for pruning each parameter in `params`.
+                The length of `init_ratios` should be equal to the length of `params` when `init_ratios` is a list.
+                If it is a scalar, all the parameters in `params` will be pruned by a uniform ratio.
+                None means a group of init ratios is derived from `pruned_flops` or `pruned_latency`. Default: None.
             pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
             pruned_latency(float): The percent of latency to be pruned. Default: None.
             server_addr(tuple): A tuple of server ip and server port for controller server.
@@ -69,12 +69,14 @@ class AutoPruner(object):
             max_try_times(int): The max number of trying to generate legal tokens.
             max_client_num(int): The max number of connections of controller server.
             search_steps(int): The steps of searching.
-            max_ratios(float|list): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
-            The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
-            If it is a scalar, it will used for all the parameters in `params`.
-            min_ratios(float|list): Min ratios used to pruned parameters in `params`. List means min ratios for each parameter in `params`.
-            The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
-            If it is a scalar, it will used for all the parameters in `params`.
+            max_ratios(float|list): Max ratios used to prune parameters in `params`.
+                List means max ratios for each parameter in `params`.
+                The length of `max_ratios` should be equal to the length of `params` when `max_ratios` is a list.
+                If it is a scalar, it will be used for all the parameters in `params`.
+            min_ratios(float|list): Min ratios used to prune parameters in `params`.
+                List means min ratios for each parameter in `params`.
+                The length of `min_ratios` should be equal to the length of `params` when `min_ratios` is a list.
+                If it is a scalar, it will be used for all the parameters in `params`.
             key(str): Identity used in communication between controller server and clients.
             is_server(bool): Whether current host is controller server. Default: True.
         """
diff --git a/paddleslim/prune/sensitive_pruner.py b/paddleslim/prune/sensitive_pruner.py
index 823b9264108055cac8604d8a351497b94591fcb4..197f54c8c5fbd009c57ae009bc876949572be574 100644
--- a/paddleslim/prune/sensitive_pruner.py
+++ b/paddleslim/prune/sensitive_pruner.py
@@ -32,10 +32,14 @@ _logger = get_logger(__name__, level=logging.INFO)
 class SensitivePruner(object):
     def __init__(self, place, eval_func, scope=None, checkpoints=None):
         """
-        Pruner used to prune parameters iteratively according to sensitivities of parameters in each step.
+        Pruner used to prune parameters iteratively according to sensitivities
+        of parameters in each step.
         Args:
-            place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute.
-            eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program.
+            place(fluid.CUDAPlace | fluid.CPUPlace): The device place where
+                the program executes.
+            eval_func(function): A callback function used to evaluate the
+                pruned program. It takes the pruned program as its only
+                argument and returns a score for that program.
             scope(fluid.scope): The scope used to execute program.
         """
         self._eval_func = eval_func
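
The `eval_func` contract documented in the SensitivePruner docstring above (receive a pruned program, return one score) is easiest to see in code. The sketch below is not part of the patch: it assumes the paddle.fluid 1.x API and that `SensitivePruner` is importable from `paddleslim.prune`, as the file path suggests; `feeder`, `test_reader`, and `acc_var` are placeholders for a user-supplied DataFeeder, data reader, and accuracy variable.

import paddle.fluid as fluid
from paddleslim.prune import SensitivePruner

place = fluid.CPUPlace()   # or fluid.CUDAPlace(0) when running on GPU
exe = fluid.Executor(place)

def make_eval_func(exe, feeder, test_reader, acc_var):
    """Build an eval_func closure over user-supplied evaluation objects."""
    def eval_func(program):
        # Run the pruned program over the test data and average the metric;
        # SensitivePruner calls this with a pruned program and uses the score.
        accs = [
            float(exe.run(program, feed=feeder.feed(data), fetch_list=[acc_var])[0])
            for data in test_reader()
        ]
        return sum(accs) / len(accs)
    return eval_func

# Typical wiring (feeder/test_reader/acc_var come from your own eval setup):
#   pruner = SensitivePruner(place, make_eval_func(exe, feeder, test_reader, acc_var))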