Commit e3cd2f1b authored by wanghaoshuang

Support dynamic graph

Parent 48744e8b
@@ -15,7 +15,7 @@
 # limitations under the License.
 from paddle.fluid import Program
-from ..core import GraphWrapper, OpWrapper
+from ..core import GraphWrapper
 __all__ = ["LatencyEvaluator", "TableLatencyEvaluator"]
@@ -65,7 +65,6 @@ class LatencyEvaluator(object):
         return ops

     def _conv_op_args(self, op):
-        assert isinstance(op, OpWrapper)
         tmp, res = [], []
         # op_name
         tmp.append('conv')
......
@@ -17,8 +17,9 @@ from .registry import Registry
 __all__ = ['GraphWrapper', 'Registry']
-try:
+#try:
 from .dy_graph import DyGraph
 __all__ += ['DyGraph']
-except Exception as e:
-    pass
+#except Exception as e:
+#    print e
+#    pass
@@ -31,6 +31,13 @@ class VarWrapper(object):
         self._is_parameter = is_parameter
         self._tensor = tensor

+    def data(self):
+        return np.array(self._tensor.data)
+
+    def set_data(self, data, place=None):
+        assert self._tensor is not None
+        self._tensor.data = self._tensor.new_tensor(data)
+
     def __eq__(self, v):
         """
         Overwrite this function for ...in... syntax in python.
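The new data()/set_data() pair lets the pruner read a parameter out as a numpy array and write a (possibly smaller) array back into the wrapped torch tensor in place. A self-contained sketch of the same mechanism, outside the class (VarWrapperSketch and the Conv2d instance below are illustrative only, not part of the diff):

import torch

class VarWrapperSketch(object):
    """Illustrative stand-in for the VarWrapper above, wrapping a torch parameter."""

    def __init__(self, name, tensor):
        self._name = name
        self._tensor = tensor                      # e.g. a torch.nn.Parameter

    def data(self):
        # Read the current values out as a numpy array.
        return self._tensor.data.cpu().numpy()

    def set_data(self, data):
        # Overwrite the parameter in place; new_tensor keeps dtype and device.
        self._tensor.data = self._tensor.new_tensor(data)

conv = torch.nn.Conv2d(3, 8, 3)
w = VarWrapperSketch("conv.weight", conv.weight)
pruned = w.data()[:4]            # keep only the first 4 output channels
w.set_data(pruned)
print(conv.weight.shape)         # torch.Size([4, 3, 3, 3])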
@@ -198,78 +205,92 @@ class DyGraph(object):
         """
         super(DyGraph, self).__init__()
         self.module = module
-        self._graph = torch.jit.trace(self.module,
-                                      torch.rand(input_shape)).graph
-        print self._graph
-        self.children = {}
-        for name, child in self.module.named_children():
-            self.children[name] = child
-
-        self.id2child = {}
-        for node in self._graph.nodes():
-            if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
-                    0).debugName():
-                # print dir(node)
-                self.id2child[node.output().debugName()] = node["name"]
-
-        print self.id2child
-
-        self.vars = {}
-        self.nodes = {}
-        for node in self._graph.nodes():
-            if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
-                module_id = node.inputsAt(0).debugName()
-                node_id = node.output().debugName() + "-" + module_id
-                in_var_id = node.inputsAt(1).debugName()
-                out_var_id = node.output().debugName()
-                if node_id not in self.nodes:
-                    self.nodes[node_id] = OpWrapper(node_id,
-                                                    self.id2child[module_id])
-                    self.nodes[node_id].module = self.children[self.id2child[
-                        module_id]]
-
-                    for param_id, param in self.nodes[
-                            node_id].module.named_parameters():
-                        param_id = ".".join([self.id2child[module_id], param_id])
-                        if param_id not in self.vars:
-                            self.vars[param_id] = VarWrapper(
-                                param_id, is_parameter=True, tensor=param)
-                        self.nodes[node_id].all_inputs().append(self.vars[
-                            param_id])
-                        self.vars[param_id].outputs().append(self.nodes[
-                            node_id])
-
-                if in_var_id not in self.vars:
-                    self.vars[in_var_id] = VarWrapper(in_var_id)
-                if out_var_id not in self.vars:
-                    self.vars[out_var_id] = VarWrapper(out_var_id)
-                self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
-                self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
-                self.vars[in_var_id].outputs().append(self.nodes[node_id])
-                self.vars[out_var_id].inputs().append(self.nodes[node_id])
-            elif node.kind().startswith("aten::"):
-                # print dir(node)
-                node_id = node.output().debugName() + "-" + node.kind()
-                # node_id = node.debugName()
-                if node_id not in self.nodes:
-                    self.nodes[node_id] = OpWrapper(node_id, node.kind())
-
-                # self.nodes[node_id].type = node.kind()
-                for input in node.inputs():
-                    in_var_id = input.debugName()
-                    if in_var_id not in self.vars:
-                        self.vars[in_var_id] = VarWrapper(in_var_id)
-                    self.vars[in_var_id].outputs().append(self.nodes[node_id])
-                    self.nodes[node_id].all_inputs().append(self.vars[
-                        in_var_id])
-
-                for output in node.outputs():
-                    out_var_id = output.debugName()
-                    if out_var_id not in self.vars:
-                        self.vars[out_var_id] = VarWrapper(out_var_id)
-                    self.vars[out_var_id].inputs().append(self.nodes[node_id])
-                    self.nodes[node_id].all_outputs().append(self.vars[
-                        out_var_id])
+        traced = torch.jit.trace(self.module, torch.rand(input_shape))
+        self._trace_graph(traced, input=None, nodes={}, vars={})
+
+        # self._graph = traced.graph
+        # for name, child in traced.named_modules():
+        #     print name, child.graph
+        # print dir(traced)
+        # print self._graph
+        # self.children = {}
+        # for name, child in self.module.named_modules():
+        #     self.children[name] = child
+        ##     print "child: {}".format(name)
+        #
+        # self.id2child = {}
+        # for node in self._graph.nodes():
+        #     if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
+        #             0).debugName():
+        #         print node.output().graph
+        #         self.id2child[node.output().debugName()] = node["name"]
+        #
+        # print self.id2child
+        #
+        # self.vars = {}
+        # self.nodes = {}
+        # for node in self._graph.nodes():
+        #     if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
+        #         module_id = node.inputsAt(0).debugName()
+        #         node_id = node.output().debugName() + "-" + module_id
+        #         in_var_id = node.inputsAt(1).debugName()
+        #         out_var_id = node.output().debugName()
+        #         if node_id not in self.nodes:
+        #             self.nodes[node_id] = OpWrapper(node_id,
+        #                                             self.id2child[module_id])
+        #             self.nodes[node_id].module = self.children[self.id2child[
+        #                 module_id]]
+        #
+        #             for param_id, param in self.nodes[
+        #                     node_id].module.named_parameters():
+        #                 param_id = ".".join([self.id2child[module_id], param_id])
+        #                 if param_id not in self.vars:
+        #                     self.vars[param_id] = VarWrapper(
+        #                         param_id, is_parameter=True, tensor=param)
+        #                 self.nodes[node_id].all_inputs().append(self.vars[
+        #                     param_id])
+        #                 self.vars[param_id].outputs().append(self.nodes[
+        #                     node_id])
+        #
+        #         if in_var_id not in self.vars:
+        #             self.vars[in_var_id] = VarWrapper(in_var_id)
+        #         if out_var_id not in self.vars:
+        #             self.vars[out_var_id] = VarWrapper(out_var_id)
+        #         self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
+        #         self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
+        #         self.vars[in_var_id].outputs().append(self.nodes[node_id])
+        #         self.vars[out_var_id].inputs().append(self.nodes[node_id])
+        #     elif node.kind().startswith("aten::"):
+        #         # print dir(node)
+        #         node_id = node.output().debugName() + "-" + node.kind()
+        #         # node_id = node.debugName()
+        #         if node_id not in self.nodes:
+        #             self.nodes[node_id] = OpWrapper(node_id, node.kind())
+        #
+        ##         self.nodes[node_id].type = node.kind()
+        #         for input in node.inputs():
+        #             in_var_id = input.debugName()
+        #             if in_var_id not in self.vars:
+        #                 self.vars[in_var_id] = VarWrapper(in_var_id)
+        #             self.vars[in_var_id].outputs().append(self.nodes[node_id])
+        #             self.nodes[node_id].all_inputs().append(self.vars[
+        #                 in_var_id])
+        #
+        #         for output in node.outputs():
+        #             out_var_id = output.debugName()
+        #             if out_var_id not in self.vars:
+        #                 self.vars[out_var_id] = VarWrapper(out_var_id)
+        #             self.vars[out_var_id].inputs().append(self.nodes[node_id])
+        #             self.nodes[node_id].all_outputs().append(self.vars[
+        #                 out_var_id])
+
+    def _trace_graph(self, traced, input=None, nodes={}, vars={}):
+        inputs = [i for i in traced.graph.inputs()]
+        print inputs[1]
+        input_id = inputs[1].debugName()
+        if input is None and input_id not in vars:
+            vars[input_id] = VarWrapper(input_id)

     def all_parameters(self):
         """
@@ -388,19 +409,14 @@ class DyGraph(object):
         Update the shape of parameters in the graph according to tensors in scope.
         It is used after loading pruned parameters from file.
         """
-        for param in self.all_parameters():
-            tensor_shape = np.array(
-                scope.find_var(param.name()).get_tensor()).shape
-            param.set_shape(tensor_shape)
+        pass

     def infer_shape(self):
         """
         Update the groups of convolution layer according to current filters.
         It is used after loading pruned parameters from file.
         """
-        for op in self.ops():
-            if op.type() != 'conditional_block':
-                op._op.desc.infer_shape(op._op.block.desc)
+        pass

     def update_groups_of_conv(self):
         for op in self.ops():
......
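Both bodies become pass because a dynamic graph has no separate shape metadata to keep in sync: a parameter's shape is whatever its live torch tensor currently is, so once set_data() has replaced the values the graph is already consistent. A one-line illustration of this assumption:

import torch

conv = torch.nn.Conv2d(3, 8, 3)
conv.weight.data = conv.weight.new_zeros(4, 3, 3, 3)   # "prune" down to 4 filters
print(tuple(conv.weight.shape))                         # (4, 3, 3, 3) -- no extra shape sync needed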
@@ -17,7 +17,7 @@ import logging
 import numpy as np
 import paddle.fluid as fluid
 from .pruner import Pruner
-from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..core import GraphWrapper
 from ..common import SAController
 from ..common import get_logger
 from ..analysis import flops
......
@@ -108,7 +108,7 @@ class Conv2d(PruneWorker):
             if pruned_axis == 0:
                 if len(self.op.all_inputs()) > 2:  # has bias
                     self.pruned_params.append(
-                        (self.op.all_inputs()[1], channel_axis, pruned_idx))
+                        (self.op.all_inputs()[1], 0, pruned_idx))
                 output_var = self.op.all_outputs()[0]
                 self._visit(output_var, channel_axis)
                 next_ops = output_var.outputs()
@@ -135,7 +135,7 @@ class Conv2d(PruneWorker):
                 if len(self.op.all_inputs()) > 2:
                     self.pruned_params.append(
-                        (self.op.all_inputs()[1], channel_axis, pruned_idx))
+                        (self.op.all_inputs()[1], 0, pruned_idx))
                 output_var = self.op.all_outputs()[0]
                 next_ops = output_var.outputs()
......
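Both hunks change the pruned axis of the convolution bias from channel_axis to 0: the bias is a 1-D tensor with one entry per output channel, so whichever axis the weight is pruned on, the bias can only ever be sliced along axis 0. A small shape-only check (all names and values here are illustrative):

import numpy as np

weight = np.zeros((8, 3, 3, 3))   # [out_channels, in_channels, kh, kw]
bias = np.zeros((8,))             # one value per output channel
pruned_idx = [1, 5]               # output channels to remove

weight = np.delete(weight, pruned_idx, axis=0)   # weight pruned on axis 0
bias = np.delete(bias, pruned_idx, axis=0)       # bias only has axis 0
print(weight.shape, bias.shape)                  # (6, 3, 3, 3) (6,)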
@@ -14,7 +14,10 @@
 # limitations under the License.
 from ..core import GraphWrapper
+from ..core import DyGraph
+import paddle.fluid as fluid
 from .prune_walker import conv2d as conv2d_walker
+from .dy_prune_walker import Conv2d as dy_conv2d_walker

 __all__ = ["collect_convs"]
@@ -48,8 +51,10 @@ def collect_convs(params, graph):
         list<list<tuple>>: The groups.
     """
-    if not isinstance(graph, GraphWrapper):
+    if isinstance(graph, fluid.Program):
         graph = GraphWrapper(graph)
+    elif isinstance(graph, DyGraph):
+        conv2d_walker = dy_conv2d_walker
     groups = []
     for param in params:
         visited = {}
......
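collect_convs now picks its prune walker from the type of the incoming graph: a raw fluid.Program is wrapped in GraphWrapper and walked with the static-graph conv2d walker, while a DyGraph keeps the torch-based walker. Reusing the imports from the hunk above, the selection logic amounts to the sketch below; select_walker is a hypothetical helper written only to show the dispatch and is not part of the diff:

def select_walker(graph):
    # Static graph: a raw fluid.Program is wrapped before walking.
    if isinstance(graph, fluid.Program):
        return GraphWrapper(graph), conv2d_walker
    # Dynamic graph: the DyGraph wrapper gets its own torch-based walker.
    if isinstance(graph, DyGraph):
        return graph, dy_conv2d_walker
    # Already-wrapped static graph.
    if isinstance(graph, GraphWrapper):
        return graph, conv2d_walker
    raise NotImplementedError("Unsupported graph type: {}".format(type(graph)))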
@@ -58,7 +58,6 @@ def channel_score_sort(group, graph):
         list: sorted indexes
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
......
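Dropping the GraphWrapper assert lets the same argsort-based sorter serve both graph types: channels of the whole group are ordered by the first convolution's per-channel score, and the pruner later takes the lowest-scoring prefix. A minimal numpy illustration with made-up scores and a 0.5 ratio:

import numpy as np

# Per-channel scores of the first convolution, e.g. the L1-norm of each filter.
score = np.array([0.9, 0.1, 0.5, 0.3])
sorted_idx = score.argsort()               # [1, 3, 2, 0] -- least important first
pruned_num = int(round(len(sorted_idx) * 0.5))
print(sorted_idx[:pruned_num])             # channels to drop at ratio 0.5: [1, 3]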
@@ -17,11 +17,13 @@ import sys
 import numpy as np
 import paddle.fluid as fluid
 import copy
-from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..core import GraphWrapper
+from ..core import DyGraph
 from .group_param import collect_convs
 from .criterion import l1_norm
-from .importance_sort import channel_score_sort, batch_norm_scale
+from .importance_sort import channel_score_sort, batch_norm_scale_sort
 from ..common import get_logger
+import torch

 __all__ = ["Pruner"]
@@ -57,7 +59,8 @@ class Pruner():
               lazy=False,
               only_graph=False,
               param_backup=False,
-              param_shape_backup=False):
+              param_shape_backup=False,
+              input_shape=None):
         """Pruning the given parameters.
         Args:
@@ -82,8 +85,11 @@ class Pruner():
         if isinstance(graph, fluid.Program):
             graph = GraphWrapper(program.clone())
         elif isinstance(graph, torch.nn.Module):
-            graph = DyGraph(graph)
-            conv2d_walker = dy_conv2d_walker
+            assert (
+                input_shape is not None,
+                "input_shape can not be None while graph is instance of torch.nn.Module"
+            )
+            graph = DyGraph(graph, input_shape)
         else:
             raise NotImplementedError('The type of graph is not supported.')
         param_backup = {} if param_backup else None
@@ -93,6 +99,7 @@ class Pruner():
         pruned_params = []
         for param, ratio in zip(params, ratios):
             group = collect_convs([param], graph)[0]  # [(name, axis)]
+            print "group: {}".format(group)
             if only_graph:
                 param_v = graph.var(param)
@@ -105,16 +112,17 @@ class Pruner():
             group_values = []
             for name, axis in group:
-                values = np.array(scope.find_var(name).get_tensor())
+                values = graph.var(name).data()
                 group_values.append((name, values, axis))

-            scores = self.criterion(
-                group_with_values)  # [(name, axis, score)]
+            scores = self.criterion(group_values)  # [(name, axis, score)]
+            print "scores: {}".format(scores)
             group_idx = self.channel_sortor(
                 scores, graph=graph)  # [(name, axis, soted_idx)]
+            print "group_idx: {}".format(group_idx)
             for param, pruned_axis, pruned_idx in group_idx:
-                pruned_num = len(pruned_idx) * ratio
+                pruned_num = int(round(len(pruned_idx) * ratio))
+                print pruned_num
                 pruned_params.append((
                     param, pruned_axis,
                     pruned_idx[:pruned_num]))  # [(name, axis, pruned_idx)]
@@ -142,7 +150,7 @@ class Pruner():
                     new_shape[pruned_axis] -= len(pruned_idx)
                     param.set_shape(new_shape)
                 if not only_graph:
-                    param_t = scope.find_var(param.name()).get_tensor()
+                    param_t = graph.var(param_name).data()
                     if param_backup is not None and (
                             param.name() not in param_backup):
                         param_backup[param.name()] = copy.deepcopy(
@@ -157,10 +165,13 @@ class Pruner():
                     _logger.error("Pruning {}, but get [{}]".format(
                         param.name(), e))

-                param_t.set(pruned_param, place)
+                graph.var(param_name).set_data(pruned_param, place=place)
         graph.update_groups_of_conv()
         graph.infer_shape()
-        return graph.program, param_backup, param_shape_backup
+        if isinstance(graph, DyGraph):
+            return graph.module, param_backup, param_shape_backup
+        else:
+            return graph.program, param_backup, param_shape_backup

     def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
         """
......
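With these changes, prune() can in principle accept a torch.nn.Module as well as a fluid.Program; input_shape is required so that DyGraph can run torch.jit.trace on the module. A hedged usage sketch: only the input_shape keyword and the returned module are visible in this diff, so the positional argument order, the scope/place handling for dynamic graphs, and the parameter-name format ("0.weight") are all assumptions, not confirmed API:

import torch
from paddleslim.prune import Pruner   # assumption: Pruner is exported from paddleslim.prune

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, 3, padding=1))

pruner = Pruner()
pruned_model, _, _ = pruner.prune(
    model,                       # a torch.nn.Module instead of a fluid.Program
    None,                        # scope: assumed unused for a dynamic graph
    ["0.weight"],                # assumed: names as reported by model.named_parameters()
    [0.5],                       # pruning ratio per listed parameter
    place=None,                  # assumed unused for a dynamic graph
    input_shape=[1, 3, 32, 32])  # needed so DyGraph can torch.jit.trace the module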