未验证 提交 d3dd3ba3 编写于 作者: Y yukavio 提交者: GitHub

Add unittest for prune walker (#436)

上级 a412b6fa
...@@ -22,7 +22,8 @@ def conv_bn_layer(input, ...@@ -22,7 +22,8 @@ def conv_bn_layer(input,
stride=1, stride=1,
groups=1, groups=1,
act=None, act=None,
bias=False): bias=False,
use_cudnn=True):
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
...@@ -33,7 +34,8 @@ def conv_bn_layer(input, ...@@ -33,7 +34,8 @@ def conv_bn_layer(input,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
bias_attr=bias, bias_attr=bias,
name=name + "_out") name=name + "_out",
use_cudnn=use_cudnn)
bn_name = name + "_bn" bn_name = name + "_bn"
return fluid.layers.batch_norm( return fluid.layers.batch_norm(
input=conv, input=conv,
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
import sys import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.prune import Pruner from paddleslim.prune import Pruner
from paddleslim.core import GraphWrapper
from paddleslim.prune import conv2d as conv2d_walker
from layers import conv_bn_layer from layers import conv_bn_layer
...@@ -34,30 +33,56 @@ class TestPrune(unittest.TestCase): ...@@ -34,30 +33,56 @@ class TestPrune(unittest.TestCase):
# O: prune input channels # O: prune input channels
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16]) input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1") label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv2 = conv_bn_layer(conv1, 8, 3, "conv2") conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu')
conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu')
sum1 = conv1 + conv2 sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3") conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6')
conv4 = conv_bn_layer(conv3, 8, 3, "conv4") conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1 sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6") sum3 = fluid.layers.sum([sum2, conv5])
conv6 = conv_bn_layer(sum3, 8, 3, "conv6")
sub1 = conv6 - sum3
mult = sub1 * sub1
conv7 = conv_bn_layer(
mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False)
floored = fluid.layers.floor(conv7)
scaled = fluid.layers.scale(floored)
concated = fluid.layers.concat([scaled, mult], axis=1)
conv8 = conv_bn_layer(concated, 8, 3, "conv8")
feature = fluid.layers.reshape(conv8, [-1, 128, 16])
predict = fluid.layers.fc(input=feature, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
avg_cost = fluid.layers.mean(cost)
adam_optimizer.minimize(avg_cost)
shapes = {} params = []
for param in main_program.global_block().all_parameters(): for param in main_program.all_parameters():
shapes[param.name] = param.shape if 'conv' in param.name:
params.append(param.name)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
scope = fluid.Scope() exe.run(startup_program)
exe.run(startup_program, scope=scope) x = np.random.random(size=(10, 3, 16, 16)).astype('float32')
label = np.random.random(size=(10, 1)).astype('int64')
graph = GraphWrapper(main_program) loss_data, = exe.run(main_program,
feed={"image": x,
conv_op = graph.var("conv4_weights").outputs()[0] "label": label},
walker = conv2d_walker(conv_op, []) fetch_list=[cost.name])
walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[]) pruner = Pruner()
print(walker.pruned_params) main_program, _, _ = pruner.prune(
main_program,
fluid.global_scope(),
params=params,
ratios=[0.5] * len(params),
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册