Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
e3cd2f1b
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e3cd2f1b
编写于
3月 19, 2020
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support dynamic graph
上级
48744e8b
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
135 addition
and
104 deletion
+135
-104
paddleslim/analysis/latency.py
paddleslim/analysis/latency.py
+1
-2
paddleslim/core/__init__.py
paddleslim/core/__init__.py
+6
-5
paddleslim/core/dy_graph.py
paddleslim/core/dy_graph.py
+95
-79
paddleslim/prune/auto_pruner.py
paddleslim/prune/auto_pruner.py
+1
-1
paddleslim/prune/dy_prune_walker.py
paddleslim/prune/dy_prune_walker.py
+2
-2
paddleslim/prune/group_param.py
paddleslim/prune/group_param.py
+6
-1
paddleslim/prune/importance_sort.py
paddleslim/prune/importance_sort.py
+0
-1
paddleslim/prune/pruner.py
paddleslim/prune/pruner.py
+24
-13
未找到文件。
paddleslim/analysis/latency.py
浏览文件 @
e3cd2f1b
...
...
@@ -15,7 +15,7 @@
# limitations under the License.
from
paddle.fluid
import
Program
from
..core
import
GraphWrapper
,
OpWrapper
from
..core
import
GraphWrapper
__all__
=
[
"LatencyEvaluator"
,
"TableLatencyEvaluator"
]
...
...
@@ -65,7 +65,6 @@ class LatencyEvaluator(object):
return
ops
def
_conv_op_args
(
self
,
op
):
assert
isinstance
(
op
,
OpWrapper
)
tmp
,
res
=
[],
[]
# op_name
tmp
.
append
(
'conv'
)
...
...
paddleslim/core/__init__.py
浏览文件 @
e3cd2f1b
...
...
@@ -17,8 +17,9 @@ from .registry import Registry
__all__
=
[
'GraphWrapper'
,
'Registry'
]
try
:
from
.dy_graph
import
DyGraph
__all__
+=
[
'DyGraph'
]
except
Exception
as
e
:
pass
#try:
from
.dy_graph
import
DyGraph
__all__
+=
[
'DyGraph'
]
#except Exception as e:
# print e
# pass
paddleslim/core/dy_graph.py
浏览文件 @
e3cd2f1b
...
...
@@ -31,6 +31,13 @@ class VarWrapper(object):
self
.
_is_parameter
=
is_parameter
self
.
_tensor
=
tensor
def
data
(
self
):
return
np
.
array
(
self
.
_tensor
.
data
)
def
set_data
(
self
,
data
,
place
=
None
):
assert
self
.
_tensor
is
not
None
self
.
_tensor
.
data
=
self
.
_tensor
.
new_tensor
(
data
)
def
__eq__
(
self
,
v
):
"""
Overwrite this function for ...in... syntax in python.
...
...
@@ -198,78 +205,92 @@ class DyGraph(object):
"""
super
(
DyGraph
,
self
).
__init__
()
self
.
module
=
module
self
.
_graph
=
torch
.
jit
.
trace
(
self
.
module
,
torch
.
rand
(
input_shape
)).
graph
print
self
.
_graph
self
.
children
=
{}
for
name
,
child
in
self
.
module
.
named_children
():
self
.
children
[
name
]
=
child
self
.
id2child
=
{}
for
node
in
self
.
_graph
.
nodes
():
if
"prim::GetAttr"
==
node
.
kind
()
and
"self.1"
==
node
.
inputsAt
(
0
).
debugName
():
# print dir(node)
self
.
id2child
[
node
.
output
().
debugName
()]
=
node
[
"name"
]
print
self
.
id2child
self
.
vars
=
{}
self
.
nodes
=
{}
for
node
in
self
.
_graph
.
nodes
():
if
"prim::CallMethod"
==
node
.
kind
()
and
"forward"
==
node
[
"name"
]:
module_id
=
node
.
inputsAt
(
0
).
debugName
()
node_id
=
node
.
output
().
debugName
()
+
"-"
+
module_id
in_var_id
=
node
.
inputsAt
(
1
).
debugName
()
out_var_id
=
node
.
output
().
debugName
()
if
node_id
not
in
self
.
nodes
:
self
.
nodes
[
node_id
]
=
OpWrapper
(
node_id
,
self
.
id2child
[
module_id
])
self
.
nodes
[
node_id
].
module
=
self
.
children
[
self
.
id2child
[
module_id
]]
for
param_id
,
param
in
self
.
nodes
[
node_id
].
module
.
named_parameters
():
param_id
=
"."
.
join
([
self
.
id2child
[
module_id
],
param_id
])
if
param_id
not
in
self
.
vars
:
self
.
vars
[
param_id
]
=
VarWrapper
(
param_id
,
is_parameter
=
True
,
tensor
=
param
)
self
.
nodes
[
node_id
].
all_inputs
().
append
(
self
.
vars
[
param_id
])
self
.
vars
[
param_id
].
outputs
().
append
(
self
.
nodes
[
node_id
])
if
in_var_id
not
in
self
.
vars
:
self
.
vars
[
in_var_id
]
=
VarWrapper
(
in_var_id
)
if
out_var_id
not
in
self
.
vars
:
self
.
vars
[
out_var_id
]
=
VarWrapper
(
out_var_id
)
self
.
nodes
[
node_id
].
all_inputs
().
append
(
self
.
vars
[
in_var_id
])
self
.
nodes
[
node_id
].
all_outputs
().
append
(
self
.
vars
[
out_var_id
])
self
.
vars
[
in_var_id
].
outputs
().
append
(
self
.
nodes
[
node_id
])
self
.
vars
[
out_var_id
].
inputs
().
append
(
self
.
nodes
[
node_id
])
elif
node
.
kind
().
startswith
(
"aten::"
):
# print dir(node)
node_id
=
node
.
output
().
debugName
()
+
"-"
+
node
.
kind
()
# node_id = node.debugName()
if
node_id
not
in
self
.
nodes
:
self
.
nodes
[
node_id
]
=
OpWrapper
(
node_id
,
node
.
kind
())
# self.nodes[node_id].type = node.kind()
for
input
in
node
.
inputs
():
in_var_id
=
input
.
debugName
()
if
in_var_id
not
in
self
.
vars
:
self
.
vars
[
in_var_id
]
=
VarWrapper
(
in_var_id
)
self
.
vars
[
in_var_id
].
outputs
().
append
(
self
.
nodes
[
node_id
])
self
.
nodes
[
node_id
].
all_inputs
().
append
(
self
.
vars
[
in_var_id
])
for
output
in
node
.
outputs
():
out_var_id
=
output
.
debugName
()
if
out_var_id
not
in
self
.
vars
:
self
.
vars
[
out_var_id
]
=
VarWrapper
(
out_var_id
)
self
.
vars
[
out_var_id
].
inputs
().
append
(
self
.
nodes
[
node_id
])
self
.
nodes
[
node_id
].
all_outputs
().
append
(
self
.
vars
[
out_var_id
])
traced
=
torch
.
jit
.
trace
(
self
.
module
,
torch
.
rand
(
input_shape
))
self
.
_trace_graph
(
traced
,
input
=
None
,
nodes
=
{},
vars
=
{})
# self._graph = traced.graph
# for name,child in traced.named_modules():
# print name, child.graph
# print dir(traced)
# print self._graph
# self.children = {}
# for name, child in self.module.named_modules():
# self.children[name] = child
## print "child: {}".format(name)
#
# self.id2child = {}
# for node in self._graph.nodes():
# if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
# 0).debugName():
# print node.output().graph
# self.id2child[node.output().debugName()] = node["name"]
#
# print self.id2child
#
# self.vars = {}
# self.nodes = {}
# for node in self._graph.nodes():
# if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
# module_id = node.inputsAt(0).debugName()
# node_id = node.output().debugName() + "-" + module_id
# in_var_id = node.inputsAt(1).debugName()
# out_var_id = node.output().debugName()
# if node_id not in self.nodes:
# self.nodes[node_id] = OpWrapper(node_id,
# self.id2child[module_id])
# self.nodes[node_id].module = self.children[self.id2child[
# module_id]]
#
# for param_id, param in self.nodes[
# node_id].module.named_parameters():
# param_id = ".".join([self.id2child[module_id], param_id])
# if param_id not in self.vars:
# self.vars[param_id] = VarWrapper(
# param_id, is_parameter=True, tensor=param)
# self.nodes[node_id].all_inputs().append(self.vars[
# param_id])
# self.vars[param_id].outputs().append(self.nodes[
# node_id])
#
# if in_var_id not in self.vars:
# self.vars[in_var_id] = VarWrapper(in_var_id)
# if out_var_id not in self.vars:
# self.vars[out_var_id] = VarWrapper(out_var_id)
# self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
# self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
# self.vars[in_var_id].outputs().append(self.nodes[node_id])
# self.vars[out_var_id].inputs().append(self.nodes[node_id])
# elif node.kind().startswith("aten::"):
# # print dir(node)
# node_id = node.output().debugName() + "-" + node.kind()
# # node_id = node.debugName()
# if node_id not in self.nodes:
# self.nodes[node_id] = OpWrapper(node_id, node.kind())
#
## self.nodes[node_id].type = node.kind()
# for input in node.inputs():
# in_var_id = input.debugName()
# if in_var_id not in self.vars:
# self.vars[in_var_id] = VarWrapper(in_var_id)
# self.vars[in_var_id].outputs().append(self.nodes[node_id])
# self.nodes[node_id].all_inputs().append(self.vars[
# in_var_id])
#
# for output in node.outputs():
# out_var_id = output.debugName()
# if out_var_id not in self.vars:
# self.vars[out_var_id] = VarWrapper(out_var_id)
# self.vars[out_var_id].inputs().append(self.nodes[node_id])
# self.nodes[node_id].all_outputs().append(self.vars[
# out_var_id])
def
_trace_graph
(
self
,
traced
,
input
=
None
,
nodes
=
{},
vars
=
{}):
inputs
=
[
i
for
i
in
traced
.
graph
.
inputs
()]
print
inputs
[
1
]
input_id
=
inputs
[
1
].
debugName
()
if
input
is
None
and
input_id
not
in
vars
:
vars
[
input_id
]
=
VarWrapper
(
input_id
)
def
all_parameters
(
self
):
"""
...
...
@@ -388,19 +409,14 @@ class DyGraph(object):
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for
param
in
self
.
all_parameters
():
tensor_shape
=
np
.
array
(
scope
.
find_var
(
param
.
name
()).
get_tensor
()).
shape
param
.
set_shape
(
tensor_shape
)
pass
def
infer_shape
(
self
):
"""
Update the groups of convolution layer according to current filters.
It is used after loading pruned parameters from file.
"""
for
op
in
self
.
ops
():
if
op
.
type
()
!=
'conditional_block'
:
op
.
_op
.
desc
.
infer_shape
(
op
.
_op
.
block
.
desc
)
pass
def
update_groups_of_conv
(
self
):
for
op
in
self
.
ops
():
...
...
paddleslim/prune/auto_pruner.py
浏览文件 @
e3cd2f1b
...
...
@@ -17,7 +17,7 @@ import logging
import
numpy
as
np
import
paddle.fluid
as
fluid
from
.pruner
import
Pruner
from
..core
import
VarWrapper
,
OpWrapper
,
GraphWrapper
from
..core
import
GraphWrapper
from
..common
import
SAController
from
..common
import
get_logger
from
..analysis
import
flops
...
...
paddleslim/prune/dy_prune_walker.py
浏览文件 @
e3cd2f1b
...
...
@@ -108,7 +108,7 @@ class Conv2d(PruneWorker):
if
pruned_axis
==
0
:
if
len
(
self
.
op
.
all_inputs
())
>
2
:
# has bias
self
.
pruned_params
.
append
(
(
self
.
op
.
all_inputs
()[
1
],
channel_axis
,
pruned_idx
))
(
self
.
op
.
all_inputs
()[
1
],
0
,
pruned_idx
))
output_var
=
self
.
op
.
all_outputs
()[
0
]
self
.
_visit
(
output_var
,
channel_axis
)
next_ops
=
output_var
.
outputs
()
...
...
@@ -135,7 +135,7 @@ class Conv2d(PruneWorker):
if
len
(
self
.
op
.
all_inputs
())
>
2
:
self
.
pruned_params
.
append
(
(
self
.
op
.
all_inputs
()[
1
],
channel_axis
,
pruned_idx
))
(
self
.
op
.
all_inputs
()[
1
],
0
,
pruned_idx
))
output_var
=
self
.
op
.
all_outputs
()[
0
]
next_ops
=
output_var
.
outputs
()
...
...
paddleslim/prune/group_param.py
浏览文件 @
e3cd2f1b
...
...
@@ -14,7 +14,10 @@
# limitations under the License.
from
..core
import
GraphWrapper
from
..core
import
DyGraph
import
paddle.fluid
as
fluid
from
.prune_walker
import
conv2d
as
conv2d_walker
from
.dy_prune_walker
import
Conv2d
as
dy_conv2d_walker
__all__
=
[
"collect_convs"
]
...
...
@@ -48,8 +51,10 @@ def collect_convs(params, graph):
list<list<tuple>>: The groups.
"""
if
not
isinstance
(
graph
,
GraphWrapper
):
if
isinstance
(
graph
,
fluid
.
Program
):
graph
=
GraphWrapper
(
graph
)
elif
isinstance
(
graph
,
DyGraph
):
conv2d_walker
=
dy_conv2d_walker
groups
=
[]
for
param
in
params
:
visited
=
{}
...
...
paddleslim/prune/importance_sort.py
浏览文件 @
e3cd2f1b
...
...
@@ -58,7 +58,6 @@ def channel_score_sort(group, graph):
list: sorted indexes
"""
assert
(
isinstance
(
graph
,
GraphWrapper
))
name
,
axis
,
score
=
group
[
0
]
# sort channels by the first convolution's score
sorted_idx
=
score
.
argsort
()
...
...
paddleslim/prune/pruner.py
浏览文件 @
e3cd2f1b
...
...
@@ -17,11 +17,13 @@ import sys
import
numpy
as
np
import
paddle.fluid
as
fluid
import
copy
from
..core
import
VarWrapper
,
OpWrapper
,
GraphWrapper
from
..core
import
GraphWrapper
from
..core
import
DyGraph
from
.group_param
import
collect_convs
from
.criterion
import
l1_norm
from
.importance_sort
import
channel_score_sort
,
batch_norm_scale
from
.importance_sort
import
channel_score_sort
,
batch_norm_scale
_sort
from
..common
import
get_logger
import
torch
__all__
=
[
"Pruner"
]
...
...
@@ -57,7 +59,8 @@ class Pruner():
lazy
=
False
,
only_graph
=
False
,
param_backup
=
False
,
param_shape_backup
=
False
):
param_shape_backup
=
False
,
input_shape
=
None
):
"""Pruning the given parameters.
Args:
...
...
@@ -82,8 +85,11 @@ class Pruner():
if
isinstance
(
graph
,
fluid
.
Program
):
graph
=
GraphWrapper
(
program
.
clone
())
elif
isinstance
(
graph
,
torch
.
nn
.
Module
):
graph
=
DyGraph
(
graph
)
conv2d_walker
=
dy_conv2d_walker
assert
(
input_shape
is
not
None
,
"input_shape can not be None while graph is instance of torch.nn.Module"
)
graph
=
DyGraph
(
graph
,
input_shape
)
else
:
raise
NotImplementedError
(
'The type of graph is not supported.'
)
param_backup
=
{}
if
param_backup
else
None
...
...
@@ -93,6 +99,7 @@ class Pruner():
pruned_params
=
[]
for
param
,
ratio
in
zip
(
params
,
ratios
):
group
=
collect_convs
([
param
],
graph
)[
0
]
# [(name, axis)]
print
"group: {}"
.
format
(
group
)
if
only_graph
:
param_v
=
graph
.
var
(
param
)
...
...
@@ -105,16 +112,17 @@ class Pruner():
group_values
=
[]
for
name
,
axis
in
group
:
values
=
np
.
array
(
scope
.
find_var
(
name
).
get_tensor
()
)
values
=
graph
.
var
(
name
).
data
(
)
group_values
.
append
((
name
,
values
,
axis
))
scores
=
self
.
criterion
(
group_with_values
)
# [(name, axis, score)]
scores
=
self
.
criterion
(
group_values
)
# [(name, axis, score)]
print
"scores: {}"
.
format
(
scores
)
group_idx
=
self
.
channel_sortor
(
scores
,
graph
=
graph
)
# [(name, axis, soted_idx)]
print
"group_idx: {}"
.
format
(
group_idx
)
for
param
,
pruned_axis
,
pruned_idx
in
group_idx
:
pruned_num
=
len
(
pruned_idx
)
*
ratio
pruned_num
=
int
(
round
(
len
(
pruned_idx
)
*
ratio
))
print
pruned_num
pruned_params
.
append
((
param
,
pruned_axis
,
pruned_idx
[:
pruned_num
]))
# [(name, axis, pruned_idx)]
...
...
@@ -142,7 +150,7 @@ class Pruner():
new_shape
[
pruned_axis
]
-=
len
(
pruned_idx
)
param
.
set_shape
(
new_shape
)
if
not
only_graph
:
param_t
=
scope
.
find_var
(
param
.
name
()).
get_tensor
()
param_t
=
graph
.
var
(
param_name
).
data
()
if
param_backup
is
not
None
and
(
param
.
name
()
not
in
param_backup
):
param_backup
[
param
.
name
()]
=
copy
.
deepcopy
(
...
...
@@ -157,9 +165,12 @@ class Pruner():
_logger
.
error
(
"Pruning {}, but get [{}]"
.
format
(
param
.
name
(),
e
))
param_t
.
set
(
pruned_param
,
place
)
graph
.
var
(
param_name
).
set_data
(
pruned_param
,
place
=
place
)
graph
.
update_groups_of_conv
()
graph
.
infer_shape
()
if
isinstance
(
graph
,
DyGraph
):
return
graph
.
module
,
param_backup
,
param_shape_backup
else
:
return
graph
.
program
,
param_backup
,
param_shape_backup
def
_cal_pruned_idx
(
self
,
graph
,
scope
,
param
,
ratio
,
axis
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录