Commit e3cd2f1b authored by wanghaoshuang

Support dynamic graph

Parent 48744e8b
......@@ -15,7 +15,7 @@
# limitations under the License.
from paddle.fluid import Program
from ..core import GraphWrapper, OpWrapper
from ..core import GraphWrapper
__all__ = ["LatencyEvaluator", "TableLatencyEvaluator"]
......@@ -65,7 +65,6 @@ class LatencyEvaluator(object):
return ops
def _conv_op_args(self, op):
assert isinstance(op, OpWrapper)
tmp, res = [], []
# op_name
tmp.append('conv')
......
......@@ -17,8 +17,9 @@ from .registry import Registry
__all__ = ['GraphWrapper', 'Registry']
try:
from .dy_graph import DyGraph
__all__ += ['DyGraph']
except Exception as e:
pass
#try:
from .dy_graph import DyGraph
__all__ += ['DyGraph']
#except Exception as e:
#    print(e)
# pass
......@@ -31,6 +31,13 @@ class VarWrapper(object):
self._is_parameter = is_parameter
self._tensor = tensor
def data(self):
# Return the values of the wrapped torch tensor as a numpy array.
return np.array(self._tensor.data)
def set_data(self, data, place=None):
# Overwrite the wrapped torch tensor in place with the given values.
assert self._tensor is not None
self._tensor.data = self._tensor.new_tensor(data)
def __eq__(self, v):
"""
Override this function so that Python's `in` syntax works on collections of VarWrapper.
......@@ -198,78 +205,92 @@ class DyGraph(object):
"""
super(DyGraph, self).__init__()
self.module = module
self._graph = torch.jit.trace(self.module,
torch.rand(input_shape)).graph
print self._graph
self.children = {}
for name, child in self.module.named_children():
self.children[name] = child
self.id2child = {}
for node in self._graph.nodes():
if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
0).debugName():
# print dir(node)
self.id2child[node.output().debugName()] = node["name"]
print self.id2child
self.vars = {}
self.nodes = {}
for node in self._graph.nodes():
if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
module_id = node.inputsAt(0).debugName()
node_id = node.output().debugName() + "-" + module_id
in_var_id = node.inputsAt(1).debugName()
out_var_id = node.output().debugName()
if node_id not in self.nodes:
self.nodes[node_id] = OpWrapper(node_id,
self.id2child[module_id])
self.nodes[node_id].module = self.children[self.id2child[
module_id]]
for param_id, param in self.nodes[
node_id].module.named_parameters():
param_id = ".".join([self.id2child[module_id], param_id])
if param_id not in self.vars:
self.vars[param_id] = VarWrapper(
param_id, is_parameter=True, tensor=param)
self.nodes[node_id].all_inputs().append(self.vars[
param_id])
self.vars[param_id].outputs().append(self.nodes[
node_id])
if in_var_id not in self.vars:
self.vars[in_var_id] = VarWrapper(in_var_id)
if out_var_id not in self.vars:
self.vars[out_var_id] = VarWrapper(out_var_id)
self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
self.vars[in_var_id].outputs().append(self.nodes[node_id])
self.vars[out_var_id].inputs().append(self.nodes[node_id])
elif node.kind().startswith("aten::"):
# print dir(node)
node_id = node.output().debugName() + "-" + node.kind()
# node_id = node.debugName()
if node_id not in self.nodes:
self.nodes[node_id] = OpWrapper(node_id, node.kind())
# self.nodes[node_id].type = node.kind()
for input in node.inputs():
in_var_id = input.debugName()
if in_var_id not in self.vars:
self.vars[in_var_id] = VarWrapper(in_var_id)
self.vars[in_var_id].outputs().append(self.nodes[node_id])
self.nodes[node_id].all_inputs().append(self.vars[
in_var_id])
for output in node.outputs():
out_var_id = output.debugName()
if out_var_id not in self.vars:
self.vars[out_var_id] = VarWrapper(out_var_id)
self.vars[out_var_id].inputs().append(self.nodes[node_id])
self.nodes[node_id].all_outputs().append(self.vars[
out_var_id])
# Trace the module with a random input of `input_shape` to obtain its TorchScript graph.
traced = torch.jit.trace(self.module, torch.rand(input_shape))
self._trace_graph(traced, input=None, nodes={}, vars={})
def _trace_graph(self, traced, input=None, nodes=None, vars=None):
# Walk the TorchScript graph of the traced module. Still a stub: only the
# first data input is registered (inputs[0] of the graph is the module itself).
nodes = nodes if nodes is not None else {}
vars = vars if vars is not None else {}
inputs = [i for i in traced.graph.inputs()]
input_id = inputs[1].debugName()
if input is None and input_id not in vars:
vars[input_id] = VarWrapper(input_id)
def all_parameters(self):
"""
......@@ -388,19 +409,14 @@ class DyGraph(object):
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for param in self.all_parameters():
tensor_shape = np.array(
scope.find_var(param.name()).get_tensor()).shape
param.set_shape(tensor_shape)
pass
def infer_shape(self):
"""
Update the groups of convolution layer according to current filters.
It is used after loading pruned parameters from file.
"""
for op in self.ops():
if op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc)
pass
def update_groups_of_conv(self):
for op in self.ops():
......
......@@ -17,7 +17,7 @@ import logging
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..core import GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
......
......@@ -108,7 +108,7 @@ class Conv2d(PruneWorker):
if pruned_axis == 0:
if len(self.op.all_inputs()) > 2: # has bias
self.pruned_params.append(
(self.op.all_inputs()[1], channel_axis, pruned_idx))
(self.op.all_inputs()[1], 0, pruned_idx))  # conv bias is 1-D, so prune it along axis 0
output_var = self.op.all_outputs()[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
......@@ -135,7 +135,7 @@ class Conv2d(PruneWorker):
if len(self.op.all_inputs()) > 2:
self.pruned_params.append(
(self.op.all_inputs()[1], channel_axis, pruned_idx))
(self.op.all_inputs()[1], 0, pruned_idx))
output_var = self.op.all_outputs()[0]
next_ops = output_var.outputs()
......
......@@ -14,7 +14,10 @@
# limitations under the License.
from ..core import GraphWrapper
from ..core import DyGraph
import paddle.fluid as fluid
from .prune_walker import conv2d as conv2d_walker
from .dy_prune_walker import Conv2d as dy_conv2d_walker
__all__ = ["collect_convs"]
......@@ -48,8 +51,10 @@ def collect_convs(params, graph):
list<list<tuple>>: The groups.
"""
if not isinstance(graph, GraphWrapper):
if isinstance(graph, fluid.Program):
graph = GraphWrapper(graph)
elif isinstance(graph, DyGraph):
conv2d_walker = dy_conv2d_walker
groups = []
for param in params:
visited = {}
......
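For orientation, a minimal usage sketch of collect_convs with both graph types; the parameter names and the paddleslim import paths below are assumptions, not taken from this diff:

import paddle.fluid as fluid
from paddleslim.core import DyGraph            # assumed import path
from paddleslim.prune import collect_convs     # assumed import path

# Static graph: pass a fluid.Program (wrapped in GraphWrapper internally);
# "conv1_weights" stands for a conv weight parameter in that program.
groups = collect_convs(["conv1_weights"], fluid.default_main_program())

# Dynamic graph: pass a DyGraph built from a torch.nn.Module.
# graph = DyGraph(torch_module, [1, 3, 224, 224])
# groups = collect_convs(["conv1.weight"], graph)

# Each entry of `groups` lists the (param_name, axis) pairs that must be
# pruned together with the given parameter.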
......@@ -58,7 +58,6 @@ def channel_score_sort(group, graph):
list: sorted indexes
"""
assert (isinstance(graph, GraphWrapper))
name, axis, score = group[
0] # sort channels by the first convolution's score
sorted_idx = score.argsort()
......
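To make the argsort-based channel ranking concrete, a small numeric illustration (the scores are made up):

import numpy as np

score = np.array([0.3, 0.1, 0.9, 0.5])   # one importance score per output channel
sorted_idx = score.argsort()              # ascending -> [1, 0, 3, 2], least important first
# With ratio 0.5 the pruner later keeps pruned_idx[:int(round(4 * 0.5))],
# i.e. channels 1 and 0, so the lowest-scoring channels are the ones removed.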
......@@ -17,11 +17,13 @@ import sys
import numpy as np
import paddle.fluid as fluid
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..core import GraphWrapper
from ..core import DyGraph
from .group_param import collect_convs
from .criterion import l1_norm
from .importance_sort import channel_score_sort, batch_norm_scale
from .importance_sort import channel_score_sort, batch_norm_scale_sort
from ..common import get_logger
import torch
__all__ = ["Pruner"]
......@@ -57,7 +59,8 @@ class Pruner():
lazy=False,
only_graph=False,
param_backup=False,
param_shape_backup=False):
param_shape_backup=False,
input_shape=None):
"""Pruning the given parameters.
Args:
......@@ -82,8 +85,11 @@ class Pruner():
if isinstance(graph, fluid.Program):
graph = GraphWrapper(graph.clone())
elif isinstance(graph, torch.nn.Module):
graph = DyGraph(graph)
conv2d_walker = dy_conv2d_walker
assert input_shape is not None, (
"input_shape can not be None while graph is an instance of torch.nn.Module")
graph = DyGraph(graph, input_shape)
else:
raise NotImplementedError('The type of graph is not supported.')
param_backup = {} if param_backup else None
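For context, a hedged sketch of calling prune() on a torch.nn.Module. Only input_shape is confirmed by this diff; the other argument names are inferred from the function body, and the model and parameter name are illustrative:

import torch.nn as nn
from paddleslim.prune import Pruner    # assumed import path

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
pruner = Pruner()
# input_shape is required for torch.nn.Module inputs: DyGraph traces the module
# with torch.jit.trace(module, torch.rand(input_shape)).
pruned_module, param_backup, shape_backup = pruner.prune(
    model,
    scope=None,            # no Paddle scope is used on the dynamic-graph path
    params=["0.weight"],   # illustrative parameter name
    ratios=[0.5],
    place=None,
    input_shape=[1, 3, 32, 32])
# On the dynamic-graph path prune() returns the pruned torch module plus the
# optional parameter and shape backups.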
......@@ -93,6 +99,7 @@ class Pruner():
pruned_params = []
for param, ratio in zip(params, ratios):
group = collect_convs([param], graph)[0] # [(name, axis)]
print("group: {}".format(group))
if only_graph:
param_v = graph.var(param)
......@@ -105,16 +112,17 @@ class Pruner():
group_values = []
for name, axis in group:
values = np.array(scope.find_var(name).get_tensor())
values = graph.var(name).data()
group_values.append((name, values, axis))
scores = self.criterion(
group_with_values) # [(name, axis, score)]
scores = self.criterion(group_values) # [(name, axis, score)]
print("scores: {}".format(scores))
group_idx = self.channel_sortor(
scores, graph=graph)  # [(name, axis, sorted_idx)]
print("group_idx: {}".format(group_idx))
for param, pruned_axis, pruned_idx in group_idx:
pruned_num = len(pruned_idx) * ratio
pruned_num = int(round(len(pruned_idx) * ratio))
print(pruned_num)
pruned_params.append((
param, pruned_axis,
pruned_idx[:pruned_num])) # [(name, axis, pruned_idx)]
......@@ -142,7 +150,7 @@ class Pruner():
new_shape[pruned_axis] -= len(pruned_idx)
param.set_shape(new_shape)
if not only_graph:
param_t = scope.find_var(param.name()).get_tensor()
param_t = graph.var(param.name()).data()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
......@@ -157,9 +165,12 @@ class Pruner():
_logger.error("Pruning {}, but get [{}]".format(
param.name(), e))
param_t.set(pruned_param, place)
graph.var(param.name()).set_data(pruned_param, place=place)
graph.update_groups_of_conv()
graph.infer_shape()
if isinstance(graph, DyGraph):
return graph.module, param_backup, param_shape_backup
else:
return graph.program, param_backup, param_shape_backup
def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
......