# 自定义剪裁 ## 1. 概述 该教程介绍如果在PaddleSlim提供的接口基础上快速自定义`Filters`剪裁策略。 在PaddleSlim中,所有剪裁`Filters`的`Pruner`继承自基类`FilterPruner`。`FilterPruner`中自定义了一系列通用方法,用户只需要重载实现`FilterPruner`的`cal_mask`接口,`cal_mask`接口定义如下: ```python def cal_mask(self, var_name, pruned_ratio, group): raise NotImplemented() ``` `cal_mask`接口接受的参数说明如下: - **var_name:** 要剪裁的目标变量,一般为卷积层的权重参数的名称。在Paddle中,卷积层的权重参数格式为`[output_channel, input_channel, kernel_size, kernel_size]`,其中,`output_channel`为当前卷积层的输出通道数,`input_channel`为当前卷积层的输入通道数,`kernel_size`为卷积核大小。 - **pruned_ratio:** 对名称为`var_name`的变量的剪裁率。 - **group:** 与待裁目标变量相关的所有变量的信息。 ### 1.1 Group概念介绍 ![](./self_define_filter_pruning/1-1.png)
图1-1 卷积层关联关系示意图
如图1-1所示,在给定模型中有两个卷积层,第一个卷积层有3个`filters`,第二个卷积层有2个`filters`。如果删除第一个卷积绿色的`filter`,第一个卷积的输出特征图的通道数也会减1,同时需要删掉第二个卷积层绿色的`kernels`。如上所述的两个卷积共同组成一个group,表示如下: ``` group = { "conv_1.weight":{ "pruned_dims": [0], "layer": conv_layer_1, "var": var_instance_1, "value": var_value_1, }, "conv_2.weight":{ "pruned_dims": [1], "layer": conv_layer_2, "var": var_instance_2, "value": var_value_2, } } ``` 在上述表示`group`的数据结构示例中,`conv_1.weight`为第一个卷积权重参数的名称,其对应的value也是一个dict实例,存放了当前参数的一些信息,包括: - **pruned_dims:** 类型为`list`,表示当前参数在哪些维度上被裁。 - **layer:** 类型为[paddle.nn.Layer](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/dygraph_cn/Layer_cn.html#layer), 表示当前参数所在`Layer`。 - **var:** 类型为[paddle.Tensor](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Variable_cn.html#variable), 表示当前参数对应的实例。 - **value:** 类型为numpy.array类型,待裁参数所存的具体数值,方便开发者使用。 图1-2为更复杂的情况,其中,`Add`操作的所有输入的通道数需要保持一致,`Concat`操作的输出通道数的调整可能会影响到所有输入的通道数,因此`group`中可能包含多个卷积的参数或变量,可以是:卷积权重、卷积bias、`batch norm`相关参数等。 ![](./self_define_filter_pruning/1-2.png)
图1-2 复杂网络示例
## 2. 定义模型 ```python import paddle from paddle.vision.models import mobilenet_v1 net = mobilenet_v1(pretrained=False) paddle.summary(net, (1, 3, 32, 32)) ``` ## 3. L2NormFilterPruner 该小节参考`L1NormFilterPruner`实现`L2NormFilterPruner`,方式为集成`FIlterPruner`并重载`cal_mask`接口。代码如下所示: ```python import numpy as np from paddleslim.dygraph import FilterPruner class L2NormFilterPruner(FilterPruner): def __init__(self, model, inputs, sen_file=None, opt=None): super(L2NormFilterPruner, self).__init__( model, inputs, sen_file=sen_file, opt=opt) def cal_mask(self, pruned_ratio, collection): var_name = collection.master_name pruned_axis = collection.master_axis value = collection.values[var_name] groups = 1 for _detail in collection.all_pruning_details(): assert (isinstance(_detail.axis, int)) if _detail.axis == 1: _groups = _detail.op.attr('groups') if _groups is not None and _groups > 1: groups = _groups break reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis] scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims))) if groups > 1: scores = scores.reshape([groups, -1]) scores = np.mean(scores, axis=1) sorted_idx = scores.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") if groups > 1: mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 return mask.reshape(mask_shape) ``` 如上述代码所示,我们重载了`FilterPruner`基类的`cal_mask`方法,并在`L1NormFilterPruner`代码基础上,修改了计算通道重要性的语句,将其修改为了计算L2Norm的逻辑: ``` scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims))) ``` 接下来定义一个`L2NormFilterPruner`对象,并调用`prune_var`方法对单个卷积层进行剪裁,`prune_var`方法继承自`FilterPruner`,开发者不用再重载实现。 按以下代码调用`prune_var`方法后,参数名称为`conv2d_0.w_0`的卷积层会被裁掉50%的`filters`,与之相关关联的后续卷积和`BatchNorm`相关的参数也会被剪裁。`prune_var`不仅会对待裁模型进行`inplace`的裁剪,还会返回保存裁剪详细信息的`PruningPlan`对象,用户可以直接打印`PruningPlan`对象内容。 最后,可以通过调用`Pruner`的`restore`方法,将已被裁剪的模型恢复到初始状态。 ```python pruner = L2NormFilterPruner(net, [1, 3, 32, 32]) plan = pruner.prune_var("conv2d_0.w_0", 0, 0.5) print(plan) pruner.restore() ``` ## 4. FPGMFilterPruner 参考:[Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/abs/1811.00250) ### 4.1 原理介绍 如图4-1所示,传统基于Norm统计方法的filter重要性评估方式的有效性取决于卷积层权重数值的分布,比较理想的分布式要满足两个条件: - 偏差(deviation)要大 - 最小值要小(图4-1中v1) 满足上述条件后,我们才能裁掉更多Norm统计值较小的参数,如图4-1中红色部分所示。 ![](./self_define_filter_pruning/4-1.png)
图 4-1
而现实中的模型的权重分布如图4-2中绿色分布所示,总是有较小的偏差或较大的最小值。 ![](./self_define_filter_pruning/4-2.png)
图 4-2
考虑到上述传统方法的缺点,FPGM则用filter之间的几何距离来表示重要性,其遵循的原则就是:几何距离比较近的filters,作用也相近。 如图4-3所示,有3个filters,将各个filter展开为向量,并两两计算几何距离。其中,绿色filter的重要性得分就是它到其它两个filter的距离和,即0.7071+0.5831=1.2902。同理算出另外两个filters的得分,绿色filter得分最高,其重要性最高。 ![](./self_define_filter_pruning/4-3.png)
图 4-3
### 4.2 实现 以下代码通过继承`FilterPruner`并重载`cal_mask`实现了`FPGMFilterPruner`,其中,`get_distance_sum`用于计算第`out_idx`个filter的重要性。 ```python import numpy as np from paddleslim.dygraph import FilterPruner class FPGMFilterPruner(FilterPruner): def __init__(self, model, inputs, sen_file=None, opt=None): super(FPGMFilterPruner, self).__init__( model, inputs, sen_file=sen_file, opt=opt) def cal_mask(self, pruned_ratio, collection): var_name = collection.master_name pruned_axis = collection.master_axis value = collection.values[var_name] groups = 1 for _detail in collection.all_pruning_details(): assert (isinstance(_detail.axis, int)) if _detail.axis == 1: _groups = _detail.op.attr('groups') if _groups is not None and _groups > 1: groups = _groups break dist_sum_list = [] for out_i in range(value.shape[0]): dist_sum = self.get_distance_sum(value, out_i) dist_sum_list.append(dist_sum) scores = np.array(dist_sum_list) if groups > 1: scores = scores.reshape([groups, -1]) scores = np.mean(scores, axis=1) sorted_idx = scores.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") if groups > 1: mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 return mask.reshape(mask_shape) def get_distance_sum(self, value, out_idx): w = value.view() w.shape = value.shape[0], np.product(value.shape[1:]) selected_filter = np.tile(w[out_idx], (w.shape[0], 1)) x = w - selected_filter x = np.sqrt(np.sum(x * x, -1)) return x.sum() ``` 接下来声明一个FPGMFilterPruner对象进行验证: ```python pruner = FPGMFilterPruner(net, [1, 3, 32, 32]) plan = pruner.prune_var("conv2d_0.w_0", 0, 0.5) print(plan) pruner.restore() ``` ## 5. 敏感度剪裁 在第3节和第4节,开发者自定义实现的`L2NormFilterPruner`和`FPGMFilterPruner`也继承了`FilterPruner`的敏感度计算方法`sensitive`和剪裁方法`sensitive_prune`。 ### 5.1 预训练 ```python import paddle.vision.transforms as T transform = T.Compose([ T.Transpose(), T.Normalize([127.5], [127.5]) ]) train_dataset = paddle.vision.datasets.Cifar10(mode="train", backend="cv2",transform=transform) val_dataset = paddle.vision.datasets.Cifar10(mode="test", backend="cv2",transform=transform) from paddle.static import InputSpec as Input optimizer = paddle.optimizer.Momentum( learning_rate=0.1, parameters=net.parameters()) inputs = [Input([None, 3, 32, 32], 'float32', name='image')] labels = [Input([None, 1], 'int64', name='label')] net = mobilenet_v1(pretrained=False) model = paddle.Model(net, inputs, labels) model.prepare( optimizer, paddle.nn.CrossEntropyLoss(), paddle.metric.Accuracy(topk=(1, 5))) model.fit(train_dataset, epochs=2, batch_size=128, verbose=1) result = model.evaluate(val_dataset,batch_size=128, log_freq=10) print(result) ``` ### 5.2 计算敏感度 ```python pruner = FPGMFilterPruner(net, [1, 3, 32, 32], opt=optimizer) def eval_fn(): result = model.evaluate( val_dataset, batch_size=128) return result['acc_top1'] sen = pruner.sensitive(eval_func=eval_fn, sen_file="./fpgm_sen.pickle") print(sen) ``` ### 5.3 剪裁 ```python from paddleslim.analysis import dygraph_flops flops = dygraph_flops(net, [1, 3, 32, 32]) print(f"FLOPs before pruning: {flops}") plan = pruner.sensitive_prune(0.4, skip_vars=["conv2d_26.w_0"]) flops = dygraph_flops(net, [1, 3, 32, 32]) print(f"FLOPs after pruning: {flops}") print(f"Pruned FLOPs: {round(plan.pruned_flops*100, 2)}%") result = model.evaluate(val_dataset,batch_size=128, log_freq=10) print(f"before fine-tuning: {result}") ``` ### 5.4 重训练 ```python model.fit(train_dataset, epochs=2, batch_size=128, verbose=1) result = model.evaluate(val_dataset,batch_size=128, log_freq=10) print(f"after fine-tuning: {result}") ```