未验证 提交 5438acae 编写于 作者: W whs 提交者: GitHub

Fix docs (#827)

上级 df7c9f2e
......@@ -79,7 +79,7 @@ L1NormFilterPruner
from paddleslim import L1NormFilterPruner
net = mobilenet_v1(pretrained=False)
pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
plan = pruner.prun_var("conv2d_26.w_0", [0])
plan = pruner.prune_var("conv2d_26.w_0", [0])
print(f"plan: {plan}")
paddle.summary(net, (1, 3, 224, 224))
..
......@@ -111,7 +111,7 @@ L1NormFilterPruner
from paddleslim import L1NormFilterPruner
net = mobilenet_v1(pretrained=False)
pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
plan = pruner.prun_vars({"conv2d_26.w_0": 0.5}, [0])
plan = pruner.prune_vars({"conv2d_26.w_0": 0.5}, [0])
print(f"plan: {plan}")
paddle.summary(net, (1, 3, 224, 224))
..
......
......@@ -69,27 +69,41 @@ import numpy as np
from paddleslim.dygraph import FilterPruner
class L2NormFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None, opt=None):
def __init__(self, model, inputs, sen_file=None, opt=None):
super(L2NormFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file, opt=opt)
def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value']
pruned_dims = group[var_name]['pruned_dims']
reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims
]
# scores = np.mean(np.abs(value), axis=tuple(reduce_dims))
model, inputs, sen_file=sen_file, opt=opt)
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
if groups > 1:
scores = scores.reshape([groups, -1])
scores = np.mean(scores, axis=1)
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[i] for i in pruned_dims]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask
return mask.reshape(mask_shape)
```
如上述代码所示,我们重载了`FilterPruner`基类的`cal_mask`方法,并在`L1NormFilterPruner`代码基础上,修改了计算通道重要性的语句,将其修改为了计算L2Norm的逻辑:
......@@ -147,15 +161,22 @@ import numpy as np
from paddleslim.dygraph import FilterPruner
class FPGMFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None, opt=None):
def __init__(self, model, inputs, sen_file=None, opt=None):
super(FPGMFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file, opt=opt)
def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value']
pruned_dims = group[var_name]['pruned_dims']
assert(pruned_dims == [0])
model, inputs, sen_file=sen_file, opt=opt)
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
dist_sum_list = []
for out_i in range(value.shape[0]):
......@@ -163,13 +184,19 @@ class FPGMFilterPruner(FilterPruner):
dist_sum_list.append(dist_sum)
scores = np.array(dist_sum_list)
if groups > 1:
scores = scores.reshape([groups, -1])
scores = np.mean(scores, axis=1)
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[i] for i in pruned_dims]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask
return mask.reshape(mask_shape)
def get_distance_sum(self, value, out_idx):
w = value.view()
......@@ -210,6 +237,7 @@ optimizer = paddle.optimizer.Momentum(
inputs = [Input([None, 3, 32, 32], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
net = mobilenet_v1(pretrained=False)
model = paddle.Model(net, inputs, labels)
model.prepare(
optimizer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册