提交 d3bc72ea 编写于 作者: W whs 提交者: GitHub

Revert "add support of conditional block for pruning (#450)"

This reverts commit 07f7bffb.
上级 3d2a5924
...@@ -357,38 +357,9 @@ class GraphWrapper(object): ...@@ -357,38 +357,9 @@ class GraphWrapper(object):
Update the groups of convolution layer according to current filters. Update the groups of convolution layer according to current filters.
It is used after loading pruned parameters from file. It is used after loading pruned parameters from file.
""" """
head_op = []
visited = []
for op in self.ops(): for op in self.ops():
if op.type() != 'conditional_block': if op.type() != 'conditional_block':
if len(self.pre_ops(op)) == 0:
head_op.append(op)
candidate_op = self.ops()
def recursive_infer(op, infer=False):
if op in candidate_op:
if op.type() != 'conditional_block':
if infer:
op._op.desc.infer_shape(op._op.block.desc)
else:
visited.append(op)
candidate_op.remove(op)
for next_op in self.next_ops(op):
recursive_infer(next_op)
# Find ops which not in the DAG, some ops, such as optimizer op,
# should be infered before normal cumputation ops.
for op in head_op:
recursive_infer(op, infer=False)
# Infer ops which not in the DAG firstly.
candidate_op = self.ops()
for op in candidate_op:
if op not in visited and op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc) op._op.desc.infer_shape(op._op.block.desc)
# Infer the remain ops in topological order.
for op in head_op:
recursive_infer(op, infer=True)
def update_groups_of_conv(self): def update_groups_of_conv(self):
for op in self.ops(): for op in self.ops():
......
...@@ -54,22 +54,10 @@ def collect_convs(params, graph, visited={}): ...@@ -54,22 +54,10 @@ def collect_convs(params, graph, visited={}):
for param in params: for param in params:
pruned_params = [] pruned_params = []
param = graph.var(param) param = graph.var(param)
conv_op = param.outputs()[0]
target_op = param.outputs()[0] cls = PRUNE_WORKER.get(conv_op.type())
if target_op.type() == 'conditional_block': walker = cls(conv_op, pruned_params=pruned_params, visited=visited)
for op in param.outputs():
if op.type() in PRUNE_WORKER._module_dict.keys():
cls = PRUNE_WORKER.get(op.type())
walker = cls(op,
pruned_params=pruned_params,
visited=visited)
break
else:
cls = PRUNE_WORKER.get(target_op.type())
walker = cls(target_op,
pruned_params=pruned_params,
visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0]) walker.prune(param, pruned_axis=0, pruned_idx=[0])
groups.append(pruned_params) groups.append(pruned_params)
visited = set() visited = set()
......
...@@ -56,7 +56,6 @@ class PruneWorker(object): ...@@ -56,7 +56,6 @@ class PruneWorker(object):
def _visit(self, var, pruned_axis): def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.idx()), var.name()]) key = "_".join([str(self.op.idx()), var.name()])
key = "_".join([key, self.op.all_inputs()[0].name()])
if pruned_axis not in self.visited: if pruned_axis not in self.visited:
self.visited[pruned_axis] = {} self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]: if key in self.visited[pruned_axis]:
......
...@@ -15,13 +15,10 @@ import sys ...@@ -15,13 +15,10 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.prune import Pruner from paddleslim.prune import Pruner
from static_case import StaticCase from static_case import StaticCase
from layers import conv_bn_layer from layers import conv_bn_layer
import random
from paddleslim.core import GraphWrapper
class TestPrune(StaticCase): class TestPrune(StaticCase):
...@@ -45,29 +42,7 @@ class TestPrune(StaticCase): ...@@ -45,29 +42,7 @@ class TestPrune(StaticCase):
conv4 = conv_bn_layer(conv3, 8, 3, "conv4") conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1 sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
sum3 = fluid.layers.sum([sum2, conv5])
flag = fluid.layers.fill_constant([1], value=1, dtype='int32')
rand_flag = paddle.randint(2, dtype='int32')
cond = fluid.layers.less_than(x=flag, y=rand_flag)
cond_output = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=False,
name='cond_output')
def cond_block1():
cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1")
fluid.layers.assign(input=cond_conv, output=cond_output)
def cond_block2():
cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1")
cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2")
fluid.layers.assign(input=cond_conv2, output=cond_output)
fluid.layers.cond(cond, cond_block1, cond_block2)
sum3 = fluid.layers.sum([sum2, cond_output])
conv6 = conv_bn_layer(sum3, 8, 3, "conv6") conv6 = conv_bn_layer(sum3, 8, 3, "conv6")
sub1 = conv6 - sum3 sub1 = conv6 - sum3
mult = sub1 * sub1 mult = sub1 * sub1
...@@ -77,7 +52,8 @@ class TestPrune(StaticCase): ...@@ -77,7 +52,8 @@ class TestPrune(StaticCase):
scaled = fluid.layers.scale(floored) scaled = fluid.layers.scale(floored)
concated = fluid.layers.concat([scaled, mult], axis=1) concated = fluid.layers.concat([scaled, mult], axis=1)
conv8 = conv_bn_layer(concated, 8, 3, "conv8") conv8 = conv_bn_layer(concated, 8, 3, "conv8")
predict = fluid.layers.fc(input=conv8, size=10, act='softmax') feature = fluid.layers.reshape(conv8, [-1, 128, 16])
predict = fluid.layers.fc(input=feature, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01) adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
...@@ -87,10 +63,8 @@ class TestPrune(StaticCase): ...@@ -87,10 +63,8 @@ class TestPrune(StaticCase):
for param in main_program.all_parameters(): for param in main_program.all_parameters():
if 'conv' in param.name: if 'conv' in param.name:
params.append(param.name) params.append(param.name)
#TODO: To support pruning convolution before fc layer.
params.remove('conv8_weights')
place = fluid.CUDAPlace(0) place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program) exe.run(startup_program)
x = np.random.random(size=(10, 3, 16, 16)).astype('float32') x = np.random.random(size=(10, 3, 16, 16)).astype('float32')
...@@ -111,11 +85,6 @@ class TestPrune(StaticCase): ...@@ -111,11 +85,6 @@ class TestPrune(StaticCase):
param_backup=None, param_backup=None,
param_shape_backup=None) param_shape_backup=None)
loss_data, = exe.run(main_program,
feed={"image": x,
"label": label},
fetch_list=[cost.name])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册