diff --git a/paddleslim/analysis/latency.py b/paddleslim/analysis/latency.py
index 3c9211e4e189e2d56d8cf6b2cd0d8501039b946c..3975856b1d4232e78ede0145f693075b16472191 100644
--- a/paddleslim/analysis/latency.py
+++ b/paddleslim/analysis/latency.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 from paddle.fluid import Program
-from ..core import GraphWrapper, OpWrapper
+from ..core import GraphWrapper
 
 __all__ = ["LatencyEvaluator", "TableLatencyEvaluator"]
 
@@ -65,7 +65,6 @@ class LatencyEvaluator(object):
         return ops
 
     def _conv_op_args(self, op):
-        assert isinstance(op, OpWrapper)
         tmp, res = [], []
         # op_name
         tmp.append('conv')
diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py
index 969768de83691d227ac7dafcca5c8e34a2451469..ff2c287eaec4dd9c5613c2a5b09e8d4bd5aed91a 100644
--- a/paddleslim/core/__init__.py
+++ b/paddleslim/core/__init__.py
@@ -17,8 +17,9 @@
 from .registry import Registry
 
 __all__ = ['GraphWrapper', 'Registry']
-try:
-    from .dy_graph import DyGraph
-    __all__ += ['DyGraph']
-except Exception as e:
-    pass
+#try:
+from .dy_graph import DyGraph
+__all__ += ['DyGraph']
+#except Exception as e:
+#    print e
+#    pass
diff --git a/paddleslim/core/dy_graph.py b/paddleslim/core/dy_graph.py
index a48d013bae73acb4dc1f82716cec1536df6d37b6..5e92d540da3bc05ed7301992d410bb544d61979c 100644
--- a/paddleslim/core/dy_graph.py
+++ b/paddleslim/core/dy_graph.py
@@ -31,6 +31,13 @@ class VarWrapper(object):
         self._is_parameter = is_parameter
         self._tensor = tensor
 
+    def data(self):
+        return np.array(self._tensor.data)
+
+    def set_data(self, data, place=None):
+        assert self._tensor is not None
+        self._tensor.data = self._tensor.new_tensor(data)
+
     def __eq__(self, v):
         """
         Overwrite this function for ...in... syntax in python.
@@ -198,78 +205,92 @@
         """
         super(DyGraph, self).__init__()
         self.module = module
-        self._graph = torch.jit.trace(self.module,
-                                      torch.rand(input_shape)).graph
-        print self._graph
-        self.children = {}
-        for name, child in self.module.named_children():
-            self.children[name] = child
-
-        self.id2child = {}
-        for node in self._graph.nodes():
-            if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
-                    0).debugName():
-                # print dir(node)
-                self.id2child[node.output().debugName()] = node["name"]
-
-        print self.id2child
-
-        self.vars = {}
-        self.nodes = {}
-        for node in self._graph.nodes():
-            if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
-                module_id = node.inputsAt(0).debugName()
-                node_id = node.output().debugName() + "-" + module_id
-                in_var_id = node.inputsAt(1).debugName()
-                out_var_id = node.output().debugName()
-                if node_id not in self.nodes:
-                    self.nodes[node_id] = OpWrapper(node_id,
-                                                    self.id2child[module_id])
-                    self.nodes[node_id].module = self.children[self.id2child[
-                        module_id]]
-
-                    for param_id, param in self.nodes[
-                            node_id].module.named_parameters():
-                        param_id = ".".join([self.id2child[module_id], param_id])
-                        if param_id not in self.vars:
-                            self.vars[param_id] = VarWrapper(
-                                param_id, is_parameter=True, tensor=param)
-                        self.nodes[node_id].all_inputs().append(self.vars[
-                            param_id])
-                        self.vars[param_id].outputs().append(self.nodes[
-                            node_id])
-
-                if in_var_id not in self.vars:
-                    self.vars[in_var_id] = VarWrapper(in_var_id)
-                if out_var_id not in self.vars:
-                    self.vars[out_var_id] = VarWrapper(out_var_id)
-                self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
-                self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
-                self.vars[in_var_id].outputs().append(self.nodes[node_id])
-                self.vars[out_var_id].inputs().append(self.nodes[node_id])
-            elif node.kind().startswith("aten::"):
-                # print dir(node)
-                node_id = node.output().debugName() + "-" + node.kind()
-                # node_id = node.debugName()
-                if node_id not in self.nodes:
-                    self.nodes[node_id] = OpWrapper(node_id, node.kind())
-
-#                    self.nodes[node_id].type = node.kind()
-                for input in node.inputs():
-                    in_var_id = input.debugName()
-                    if in_var_id not in self.vars:
-                        self.vars[in_var_id] = VarWrapper(in_var_id)
-                    self.vars[in_var_id].outputs().append(self.nodes[node_id])
-                    self.nodes[node_id].all_inputs().append(self.vars[
-                        in_var_id])
-
-                for output in node.outputs():
-                    out_var_id = output.debugName()
-                    if out_var_id not in self.vars:
-                        self.vars[out_var_id] = VarWrapper(out_var_id)
-                    self.vars[out_var_id].inputs().append(self.nodes[node_id])
-                    self.nodes[node_id].all_outputs().append(self.vars[
-                        out_var_id])
+        traced = torch.jit.trace(self.module, torch.rand(input_shape))
+
+        self._trace_graph(traced, input=None, nodes={}, vars={})
+
+#        self._graph = traced.graph
+#        for name,child in traced.named_modules():
+#            print name, child.graph
+#        print dir(traced)
+#        print self._graph
+#        self.children = {}
+#        for name, child in self.module.named_modules():
+#            self.children[name] = child
+##            print "child: {}".format(name)
+#
+#        self.id2child = {}
+#        for node in self._graph.nodes():
+#            if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
+#                    0).debugName():
+#                print node.output().graph
+#                self.id2child[node.output().debugName()] = node["name"]
+#
+#        print self.id2child
+#
+#        self.vars = {}
+#        self.nodes = {}
+#        for node in self._graph.nodes():
+#            if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
+#                module_id = node.inputsAt(0).debugName()
+#                node_id = node.output().debugName() + "-" + module_id
+#                in_var_id = node.inputsAt(1).debugName()
+#                out_var_id = node.output().debugName()
+#                if node_id not in self.nodes:
+#                    self.nodes[node_id] = OpWrapper(node_id,
+#                                                    self.id2child[module_id])
+#                    self.nodes[node_id].module = self.children[self.id2child[
+#                        module_id]]
+#
+#                    for param_id, param in self.nodes[
+#                            node_id].module.named_parameters():
+#                        param_id = ".".join([self.id2child[module_id], param_id])
+#                        if param_id not in self.vars:
+#                            self.vars[param_id] = VarWrapper(
+#                                param_id, is_parameter=True, tensor=param)
+#                        self.nodes[node_id].all_inputs().append(self.vars[
+#                            param_id])
+#                        self.vars[param_id].outputs().append(self.nodes[
+#                            node_id])
+#
+#                if in_var_id not in self.vars:
+#                    self.vars[in_var_id] = VarWrapper(in_var_id)
+#                if out_var_id not in self.vars:
+#                    self.vars[out_var_id] = VarWrapper(out_var_id)
+#                self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
+#                self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
+#                self.vars[in_var_id].outputs().append(self.nodes[node_id])
+#                self.vars[out_var_id].inputs().append(self.nodes[node_id])
+#            elif node.kind().startswith("aten::"):
+#                # print dir(node)
+#                node_id = node.output().debugName() + "-" + node.kind()
+#                # node_id = node.debugName()
+#                if node_id not in self.nodes:
+#                    self.nodes[node_id] = OpWrapper(node_id, node.kind())
+#
+##                    self.nodes[node_id].type = node.kind()
+#                for input in node.inputs():
+#                    in_var_id = input.debugName()
+#                    if in_var_id not in self.vars:
+#                        self.vars[in_var_id] = VarWrapper(in_var_id)
+#                    self.vars[in_var_id].outputs().append(self.nodes[node_id])
+#                    self.nodes[node_id].all_inputs().append(self.vars[
+#                        in_var_id])
+#
+#                for output in node.outputs():
+#                    out_var_id = output.debugName()
+#                    if out_var_id not in self.vars:
+#                        self.vars[out_var_id] = VarWrapper(out_var_id)
+#                    self.vars[out_var_id].inputs().append(self.nodes[node_id])
+#                    self.nodes[node_id].all_outputs().append(self.vars[
+#                        out_var_id])
+
+    def _trace_graph(self, traced, input=None, nodes=None, vars=None):
+        nodes = {} if nodes is None else nodes
+        vars = {} if vars is None else vars
+        input_id = list(traced.graph.inputs())[1].debugName()
+        if input is None and input_id not in vars:
+            vars[input_id] = VarWrapper(input_id)
 
     def all_parameters(self):
         """
@@ -388,19 +409,14 @@
         Update the shape of parameters in the graph according to tensors in scope.
         It is used after loading pruned parameters from file.
         """
-        for param in self.all_parameters():
-            tensor_shape = np.array(
-                scope.find_var(param.name()).get_tensor()).shape
-            param.set_shape(tensor_shape)
+        pass
 
     def infer_shape(self):
         """
         Update the groups of convolution layer according to current filters.
         It is used after loading pruned parameters from file.
         """
-        for op in self.ops():
-            if op.type() != 'conditional_block':
-                op._op.desc.infer_shape(op._op.block.desc)
+        pass
 
     def update_groups_of_conv(self):
         for op in self.ops():
diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py
index 55ec0f55e387c5d8729ea85c233e8f2d54f82352..f8421f9039d7ae1bb71d5dcc0839f11f66fee109 100644
--- a/paddleslim/prune/auto_pruner.py
+++ b/paddleslim/prune/auto_pruner.py
@@ -17,7 +17,7 @@ import logging
 import numpy as np
 import paddle.fluid as fluid
 from .pruner import Pruner
-from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..core import GraphWrapper
 from ..common import SAController
 from ..common import get_logger
 from ..analysis import flops
diff --git a/paddleslim/prune/dy_prune_walker.py b/paddleslim/prune/dy_prune_walker.py
index dfbcc4e0fc4588f97e58cb73a99a2eebecef161b..daf858a10036ba67d6c343ab24608a0631b7a937 100644
--- a/paddleslim/prune/dy_prune_walker.py
+++ b/paddleslim/prune/dy_prune_walker.py
@@ -108,7 +108,7 @@ class Conv2d(PruneWorker):
         if pruned_axis == 0:
             if len(self.op.all_inputs()) > 2:  # has bias
                 self.pruned_params.append(
-                    (self.op.all_inputs()[1], channel_axis, pruned_idx))
+                    (self.op.all_inputs()[1], 0, pruned_idx))
             output_var = self.op.all_outputs()[0]
             self._visit(output_var, channel_axis)
             next_ops = output_var.outputs()
@@ -135,7 +135,7 @@ class Conv2d(PruneWorker):
 
             if len(self.op.all_inputs()) > 2:
                 self.pruned_params.append(
-                    (self.op.all_inputs()[1], channel_axis, pruned_idx))
+                    (self.op.all_inputs()[1], 0, pruned_idx))
 
             output_var = self.op.all_outputs()[0]
             next_ops = output_var.outputs()
diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py
index 52075c9a47d34723d0f90b8c69b982a610aeb2f7..c5505e33aac1ab1028f1f5d5fc86a1bc8a9fca7b 100644
--- a/paddleslim/prune/group_param.py
+++ b/paddleslim/prune/group_param.py
@@ -14,7 +14,10 @@
 # limitations under the License.
 
 from ..core import GraphWrapper
+from ..core import DyGraph
+import paddle.fluid as fluid
 from .prune_walker import conv2d as conv2d_walker
+from .dy_prune_walker import Conv2d as dy_conv2d_walker
 
 __all__ = ["collect_convs"]
 
@@ -48,8 +51,10 @@ def collect_convs(params, graph):
     Returns:
        list<list<tuple>>: The groups.
""" - if not isinstance(graph, GraphWrapper): + if isinstance(graph, fluid.Program): graph = GraphWrapper(graph) + elif isinstance(graph, DyGraph): + conv2d_walker = dy_conv2d_walker groups = [] for param in params: visited = {} diff --git a/paddleslim/prune/importance_sort.py b/paddleslim/prune/importance_sort.py index c8a8e3fd2236056fddff76762feb9dd975e0390a..e8939e896d2e01ef7a21206f913f9f1cbbee6333 100644 --- a/paddleslim/prune/importance_sort.py +++ b/paddleslim/prune/importance_sort.py @@ -58,7 +58,6 @@ def channel_score_sort(group, graph): list: sorted indexes """ - assert (isinstance(graph, GraphWrapper)) name, axis, score = group[ 0] # sort channels by the first convolution's score sorted_idx = score.argsort() diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index d18ad223dc8a322efa3568ee2588329eae163c98..d3a96b8d365b0dadc7b7e83026c9313486deac4a 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -17,11 +17,13 @@ import sys import numpy as np import paddle.fluid as fluid import copy -from ..core import VarWrapper, OpWrapper, GraphWrapper +from ..core import GraphWrapper +from ..core import DyGraph from .group_param import collect_convs from .criterion import l1_norm -from .importance_sort import channel_score_sort, batch_norm_scale +from .importance_sort import channel_score_sort, batch_norm_scale_sort from ..common import get_logger +import torch __all__ = ["Pruner"] @@ -57,7 +59,8 @@ class Pruner(): lazy=False, only_graph=False, param_backup=False, - param_shape_backup=False): + param_shape_backup=False, + input_shape=None): """Pruning the given parameters. Args: @@ -82,8 +85,11 @@ class Pruner(): if isinstance(graph, fluid.Program): graph = GraphWrapper(program.clone()) elif isinstance(graph, torch.nn.Module): - graph = DyGraph(graph) - conv2d_walker = dy_conv2d_walker + assert ( + input_shape is not None, + "input_shape can not be None while graph is instance of torch.nn.Module" + ) + graph = DyGraph(graph, input_shape) else: raise NotImplementedError('The type of graph is not supported.') param_backup = {} if param_backup else None @@ -93,6 +99,7 @@ class Pruner(): pruned_params = [] for param, ratio in zip(params, ratios): group = collect_convs([param], graph)[0] # [(name, axis)] + print "group: {}".format(group) if only_graph: param_v = graph.var(param) @@ -105,16 +112,17 @@ class Pruner(): group_values = [] for name, axis in group: - values = np.array(scope.find_var(name).get_tensor()) + values = graph.var(name).data() group_values.append((name, values, axis)) - scores = self.criterion( - group_with_values) # [(name, axis, score)] - + scores = self.criterion(group_values) # [(name, axis, score)] + print "scores: {}".format(scores) group_idx = self.channel_sortor( scores, graph=graph) # [(name, axis, soted_idx)] + print "group_idx: {}".format(group_idx) for param, pruned_axis, pruned_idx in group_idx: - pruned_num = len(pruned_idx) * ratio + pruned_num = int(round(len(pruned_idx) * ratio)) + print pruned_num pruned_params.append(( param, pruned_axis, pruned_idx[:pruned_num])) # [(name, axis, pruned_idx)] @@ -142,7 +150,7 @@ class Pruner(): new_shape[pruned_axis] -= len(pruned_idx) param.set_shape(new_shape) if not only_graph: - param_t = scope.find_var(param.name()).get_tensor() + param_t = graph.var(param_name).data() if param_backup is not None and ( param.name() not in param_backup): param_backup[param.name()] = copy.deepcopy( @@ -157,10 +165,13 @@ class Pruner(): _logger.error("Pruning {}, but get [{}]".format( param.name(), 
 
-                param_t.set(pruned_param, place)
+                graph.var(param.name()).set_data(pruned_param, place=place)
         graph.update_groups_of_conv()
         graph.infer_shape()
-        return graph.program, param_backup, param_shape_backup
+        if isinstance(graph, DyGraph):
+            return graph.module, param_backup, param_shape_backup
+        else:
+            return graph.program, param_backup, param_shape_backup
 
     def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
         """
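
For context, the sketch below shows how the torch.nn.Module path added in this patch is expected to be driven. It is a hypothetical illustration, not part of the patch: the Pruner constructor arguments, the exact parameter-id format produced by DyGraph, and the positional order of the leading prune() arguments (graph, scope, params, ratios, place) are not fully visible in this diff, so every name below should be read as an assumption rather than the final API.

    import torch
    import torch.nn as nn

    from paddleslim.prune import Pruner

    # A tiny module to trace; any torch.nn.Module with named children would do.
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 16, 3, padding=1))

    pruner = Pruner()  # assumed default criterion/sorter; ctor args not shown in this diff
    module, param_backup, param_shape_backup = pruner.prune(
        model,                       # torch.nn.Module -> wrapped in DyGraph(module, input_shape)
        None,                        # scope: assumed unused on the dynamic-graph path
        params=["0.weight"],         # hypothetical id; DyGraph joins child and parameter names with "."
        ratios=[0.5],                # drop half of the output channels of the listed parameter
        place=None,
        input_shape=[1, 3, 32, 32])  # required so DyGraph can run torch.jit.trace(module, torch.rand(input_shape))

On the DyGraph branch the patched prune() returns graph.module (the pruned torch.nn.Module) instead of a Paddle program, which is why the sketch unpacks three values.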