From 5438acae066cc04aeee68fa0652a38c1473fd050 Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 29 Jun 2021 18:37:35 +0800 Subject: [PATCH] Fix docs (#827) --- .../dygraph/pruners/l1norm_filter_pruner.rst | 4 +- .../dygraph/self_defined_filter_pruning.md | 76 +++++++++++++------ 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst index b03122c8..d80097c5 100644 --- a/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst +++ b/docs/zh_cn/api_cn/dygraph/pruners/l1norm_filter_pruner.rst @@ -79,7 +79,7 @@ L1NormFilterPruner from paddleslim import L1NormFilterPruner net = mobilenet_v1(pretrained=False) pruner = L1NormFilterPruner(net, [1, 3, 224, 224]) - plan = pruner.prun_var("conv2d_26.w_0", [0]) + plan = pruner.prune_var("conv2d_26.w_0", [0]) print(f"plan: {plan}") paddle.summary(net, (1, 3, 224, 224)) .. @@ -111,7 +111,7 @@ L1NormFilterPruner from paddleslim import L1NormFilterPruner net = mobilenet_v1(pretrained=False) pruner = L1NormFilterPruner(net, [1, 3, 224, 224]) - plan = pruner.prun_vars({"conv2d_26.w_0": 0.5}, [0]) + plan = pruner.prune_vars({"conv2d_26.w_0": 0.5}, [0]) print(f"plan: {plan}") paddle.summary(net, (1, 3, 224, 224)) .. diff --git a/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md b/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md index 1fdf522b..67701066 100644 --- a/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md +++ b/docs/zh_cn/tutorials/pruning/dygraph/self_defined_filter_pruning.md @@ -69,27 +69,41 @@ import numpy as np from paddleslim.dygraph import FilterPruner class L2NormFilterPruner(FilterPruner): - - def __init__(self, model, input_shape, sen_file=None, opt=None): + def __init__(self, model, inputs, sen_file=None, opt=None): super(L2NormFilterPruner, self).__init__( - model, input_shape, sen_file=sen_file, opt=opt) - - def cal_mask(self, var_name, pruned_ratio, group): - value = group[var_name]['value'] - pruned_dims = group[var_name]['pruned_dims'] - reduce_dims = [ - i for i in range(len(value.shape)) if i not in pruned_dims - ] - - # scores = np.mean(np.abs(value), axis=tuple(reduce_dims)) + model, inputs, sen_file=sen_file, opt=opt) + + def cal_mask(self, pruned_ratio, collection): + var_name = collection.master_name + pruned_axis = collection.master_axis + value = collection.values[var_name] + groups = 1 + for _detail in collection.all_pruning_details(): + assert (isinstance(_detail.axis, int)) + if _detail.axis == 1: + _groups = _detail.op.attr('groups') + if _groups is not None and _groups > 1: + groups = _groups + break + + reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis] scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims))) + if groups > 1: + scores = scores.reshape([groups, -1]) + scores = np.mean(scores, axis=1) + sorted_idx = scores.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] - mask_shape = [value.shape[i] for i in pruned_dims] + + mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") + if groups > 1: + mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 - return mask + return mask.reshape(mask_shape) + + ``` 如上述代码所示,我们重载了`FilterPruner`基类的`cal_mask`方法,并在`L1NormFilterPruner`代码基础上,修改了计算通道重要性的语句,将其修改为了计算L2Norm的逻辑: @@ -147,15 +161,22 @@ import numpy as np from paddleslim.dygraph import FilterPruner class FPGMFilterPruner(FilterPruner): - - def __init__(self, model, input_shape, sen_file=None, opt=None): + def __init__(self, model, inputs, sen_file=None, opt=None): super(FPGMFilterPruner, self).__init__( - model, input_shape, sen_file=sen_file, opt=opt) - - def cal_mask(self, var_name, pruned_ratio, group): - value = group[var_name]['value'] - pruned_dims = group[var_name]['pruned_dims'] - assert(pruned_dims == [0]) + model, inputs, sen_file=sen_file, opt=opt) + + def cal_mask(self, pruned_ratio, collection): + var_name = collection.master_name + pruned_axis = collection.master_axis + value = collection.values[var_name] + groups = 1 + for _detail in collection.all_pruning_details(): + assert (isinstance(_detail.axis, int)) + if _detail.axis == 1: + _groups = _detail.op.attr('groups') + if _groups is not None and _groups > 1: + groups = _groups + break dist_sum_list = [] for out_i in range(value.shape[0]): @@ -163,13 +184,19 @@ class FPGMFilterPruner(FilterPruner): dist_sum_list.append(dist_sum) scores = np.array(dist_sum_list) + if groups > 1: + scores = scores.reshape([groups, -1]) + scores = np.mean(scores, axis=1) + sorted_idx = scores.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] - mask_shape = [value.shape[i] for i in pruned_dims] + mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") + if groups > 1: + mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 - return mask + return mask.reshape(mask_shape) def get_distance_sum(self, value, out_idx): w = value.view() @@ -210,6 +237,7 @@ optimizer = paddle.optimizer.Momentum( inputs = [Input([None, 3, 32, 32], 'float32', name='image')] labels = [Input([None, 1], 'int64', name='label')] +net = mobilenet_v1(pretrained=False) model = paddle.Model(net, inputs, labels) model.prepare( optimizer, -- GitLab