Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
edd42662
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
1 年多 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
edd42662
编写于
8月 20, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify the prim2code
上级
f76d0121
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
196 addition
and
95 deletion
+196
-95
x2paddle/core/program.py
x2paddle/core/program.py
+5
-4
x2paddle/op_mapper/pytorch2paddle/aten.py
x2paddle/op_mapper/pytorch2paddle/aten.py
+63
-0
x2paddle/op_mapper/pytorch2paddle/prim2code.py
x2paddle/op_mapper/pytorch2paddle/prim2code.py
+65
-45
x2paddle/optimizer/fusion/__init__.py
x2paddle/optimizer/fusion/__init__.py
+2
-0
x2paddle/optimizer/fusion/fc_fuse_pass.py
x2paddle/optimizer/fusion/fc_fuse_pass.py
+4
-4
x2paddle/optimizer/optimizer.py
x2paddle/optimizer/optimizer.py
+1
-1
x2paddle/optimizer/pass_.py
x2paddle/optimizer/pass_.py
+2
-21
x2paddle/optimizer/pattern_matcher.py
x2paddle/optimizer/pattern_matcher.py
+54
-20
未找到文件。
x2paddle/core/program.py
浏览文件 @
edd42662
...
...
@@ -357,7 +357,8 @@ class PaddleGraph(object):
for
layer_id
,
layer
in
self
.
layers
.
items
():
if
self
.
edges_in
.
get
(
layer_id
,
0
)
==
0
and
self
.
edges_out
.
get
(
layer_id
,
0
)
==
0
and
layer
.
kernel
!=
"prim.assert"
:
layer_id
,
0
)
==
0
and
layer
.
kernel
!=
"prim.assert"
\
and
layer
.
kernel
!=
"prim.exception"
:
continue
if
"dygraph"
in
layer
.
kernel
:
line
=
"{}"
.
format
(
...
...
@@ -396,9 +397,9 @@ class PaddleGraph(object):
self
.
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
elif
"prim"
in
layer
.
kernel
:
func_name
=
layer
.
kernel
.
replace
(
"."
,
"_"
)
from
.
import
convert_prim
if
hasattr
(
convert_prim
,
func_name
):
func
=
getattr
(
convert_prim
,
func_name
)
from
x2paddle.op_mapper.pytorch2paddle
import
prim2code
if
hasattr
(
prim2code
,
func_name
):
func
=
getattr
(
prim2code
,
func_name
)
func
(
layer
,
indent
=
indent
,
...
...
x2paddle/op_mapper/pytorch2paddle/aten.py
浏览文件 @
edd42662
...
...
@@ -532,6 +532,41 @@ def aten_dropout(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_dropout_
(
mapper
,
graph
,
node
):
""" 构造Dropout的PaddleLayer。
TorchScript示例:
%119 : Tensor = aten::dropout_(%result.3, %117, %118)
参数含义:
%119 (Tensor): Dropout后的Tensor。
%result.3 (Tensor): 输入Tensor。
%118 (bool): 是否是训练阶段。
"""
if
"dropout"
in
mapper
.
dygraph_name_id
:
mapper
.
dygraph_name_id
[
"dropout"
]
+=
1
else
:
mapper
.
dygraph_name_id
[
"dropout"
]
=
0
dropout_name
=
"dropout"
+
str
(
mapper
.
dygraph_name_id
[
"dropout"
])
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
dropout_name
,
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%119
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"input"
]
=
inputs_name
[
0
]
# 获取当前节点输入、输出的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"fluid.dygraph.Dropout"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
,
p
=
0.0
)
return
current_inputs
,
current_outputs
def
aten_eq
(
mapper
,
graph
,
node
):
""" 构造判断数值是否相等的PaddleLayer。
...
...
@@ -994,6 +1029,34 @@ def aten___not__(mapper, graph, node):
return
current_inputs
,
current_outputs
def
aten_relu
(
mapper
,
graph
,
node
):
""" 构造ReLU激活的PaddleLayer。
TorchScript示例:
%result.3 : Tensor = aten::relu(%input.5)
参数含义:
%result.3 (Tensor): 输出,ReLU后的结果。
%result.5 (Tensor): 需要ReLU的Tensor。
注意: inplace这个参数在paddle中未实现
"""
output_name
=
mapper
.
_get_outputs_name
(
node
)[
0
]
layer_outputs
=
[
output_name
]
layer_inputs
=
{}
inputs_name
,
inputs_node
=
mapper
.
_get_inputs_name
(
node
)
# 获取当前节点输出的list
current_outputs
=
[
output_name
]
# 处理输入0,即%result.5
mapper
.
_check_input
(
graph
,
inputs_node
[
0
],
inputs_name
[
0
],
current_outputs
)
layer_inputs
[
"x"
]
=
inputs_name
[
0
]
# 获取当前节点输入的list
current_inputs
=
list
(
layer_inputs
.
values
())
graph
.
add_layer
(
"fluid.layers.relu"
,
inputs
=
layer_inputs
,
outputs
=
layer_outputs
)
return
current_inputs
,
current_outputs
def
aten_relu_
(
mapper
,
graph
,
node
):
""" 构造ReLU激活的PaddleLayer。
...
...
x2paddle/
core/convert_prim
.py
→
x2paddle/
op_mapper/pytorch2paddle/prim2code
.py
浏览文件 @
edd42662
...
...
@@ -24,26 +24,39 @@ def gen_codes(code_list, indent=0):
return
codes
def
get_value
(
layer
,
key
):
""" 进行optimizer后可能把inputs的value直接用数值代替(ConstantFuser),
会把input换成attr,所以需要此处的操作。
"""
if
key
in
layer
.
inputs
:
return
layer
.
inputs
[
key
]
else
:
return
str
(
layer
.
attrs
[
key
])
def
prim_add
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} + {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} + {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_add_
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
attrs
[
"alpha"
],
layer
.
inputs
[
"y"
])
line
=
"{} = {} + {} * {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
layer
.
attrs
[
"alpha"
],
get_value
(
layer
,
"y"
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_and
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} and {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} and {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_append
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{}.append({})"
.
format
(
layer
.
inputs
[
"list"
],
layer
.
inputs
[
"element"
])
line
=
"{}.append({})"
.
format
(
get_value
(
layer
,
"list"
),
get_value
(
layer
,
"element"
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
...
...
@@ -72,23 +85,23 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[]):
def
prim_eq
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} == {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_equal
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_exception
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"raise RaiseException({})"
.
format
(
layer
.
inputs
[
"input"
]
)
line
=
"raise RaiseException({})"
.
format
(
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_if
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"if {} :"
.
format
(
list
(
layer
.
inputs
.
values
())[
0
]
)
line
=
"if {} :"
.
format
(
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
b_init_lines
,
b_forward_lines
=
block
.
gen_dygraph_code
(
indent
=
indent
+
1
)
...
...
@@ -105,45 +118,47 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]):
def
prim_getitem
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"list"
],
layer
.
inputs
[
"index"
])
line
=
"{} = {}[{}]"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"list"
),
get_value
(
layer
,
"index"
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_gt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} > {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} > {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_le
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} <= {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} <= {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_len
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = len({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_lt
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_list
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
inputs_list
=
list
(
layer
.
inputs
.
values
())
input_len
=
len
(
layer
.
inputs
)
+
len
(
layer
.
attrs
)
inputs_list
=
list
()
for
i
in
range
(
input_len
):
inputs_list
.
append
(
get_value
(
layer
,
"input{}"
.
format
(
i
)))
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = [{}]"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_loop
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
loop_range
=
list
(
layer
.
inputs
.
values
())[
0
]
if
list
(
layer
.
inputs
.
values
())[
0
]
is
None
:
loop_range
=
str
(
layer
.
attrs
[
list
(
layer
.
inputs
.
keys
())[
0
]])
loop_range
=
get_value
(
layer
,
"input"
)
line
=
"for {} in range({}):"
.
format
(
layer
.
outputs
[
1
],
loop_range
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
block
=
layer
.
blocks
[
0
]
...
...
@@ -153,66 +168,71 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
def
prim_min
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = min({})"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_mul
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} * {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} * {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_ne
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} < {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_neg
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = -{}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = -{}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_not
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = not {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = not {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_requires_grad
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = not {}.stop_gradient"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_select
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}["
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = {}["
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
for
dim
in
range
(
layer
.
attrs
[
"dim"
]):
line
+=
":, "
line
+=
(
layer
.
inputs
[
"index"
]
+
"]"
)
line
+=
(
get_value
(
layer
,
"index"
)
+
"]"
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_shape
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
]
)
line
=
"{} = {}.shape"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_slice
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {}[{}: {}: {}]"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"input"
],
layer
.
inputs
[
"start"
],
layer
.
inputs
[
"end"
],
layer
.
inputs
[
"step"
])
line
=
"{} = {}[{}: {}: {}]"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"input"
),
get_value
(
layer
,
"start"
),
get_value
(
layer
,
"end"
),
get_value
(
layer
,
"step"
))
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_sub
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
line
=
"{} = {} - {}"
.
format
(
layer
.
outputs
[
0
],
layer
.
inputs
[
"x"
],
layer
.
inputs
[
"y"
]
)
line
=
"{} = {} - {}"
.
format
(
layer
.
outputs
[
0
],
get_value
(
layer
,
"x"
),
get_value
(
layer
,
"y"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_tuple
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
inputs_list
=
list
(
layer
.
inputs
.
values
())
input_len
=
len
(
layer
.
inputs
)
+
len
(
layer
.
attrs
)
inputs_list
=
list
()
for
i
in
range
(
input_len
):
inputs_list
.
append
(
get_value
(
layer
,
"input{}"
.
format
(
i
)))
inputs_str
=
', '
.
join
(
inputs_list
)
line
=
"{} = ({})"
.
format
(
layer
.
outputs
[
0
],
inputs_str
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
...
...
@@ -220,13 +240,13 @@ def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
def
prim_tuple_unpack
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
outputs_str
=
', '
.
join
(
layer
.
outputs
)
line
=
"{} = {}"
.
format
(
outputs_str
,
layer
.
inputs
[
"input"
]
)
line
=
"{} = {}"
.
format
(
outputs_str
,
get_value
(
layer
,
"input"
)
)
forward_func
.
extend
(
gen_codes
([
line
],
indent
=
indent
))
def
prim_warnings
(
layer
,
indent
=
1
,
init_func
=
[],
forward_func
=
[]):
lines
=
[
"import warnings"
]
line
=
"warnings.warn({}, stacklevel={})"
.
format
(
layer
.
inputs
[
"input"
],
layer
.
attrs
[
"stacklevel"
])
line
=
"warnings.warn({}, stacklevel={})"
.
format
(
get_value
(
layer
,
"input"
),
layer
.
attrs
[
"stacklevel"
])
lines
.
append
(
line
)
forward_func
.
extend
(
gen_codes
(
lines
,
indent
=
indent
))
x2paddle/optimizer/fusion/__init__.py
浏览文件 @
edd42662
...
...
@@ -14,3 +14,5 @@
from
.fc_fuser
import
FcFuser
from
.fc_fuse_pass
import
FcFusePass
from
.constant_fuser
import
ConstantFuser
from
.constant_fuse_pass
import
ConstantFusePass
x2paddle/optimizer/fusion/fc_fuse_pass.py
浏览文件 @
edd42662
...
...
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
x2paddle.optimizer.pass_
import
P
rogramP
ass
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion
import
FcFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
FcFusePass
(
P
rogramP
ass
):
class
FcFusePass
(
Pass
):
name
=
"fc_fuse_pass"
def
__init__
(
self
):
P
rogramP
ass
.
__init__
(
self
)
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
FcFuser
()
fuser
.
operate
(
graph
)
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
...
...
x2paddle/optimizer/optimizer.py
浏览文件 @
edd42662
...
...
@@ -18,7 +18,7 @@ from x2paddle.optimizer.pass_manager import PassManager
class
GraphOptimizer
(
object
):
def
__init__
(
self
):
self
.
passes
=
[
"fc_fuse_pass"
]
self
.
passes
=
[
"fc_fuse_pass"
,
"constant_fuse_pass"
]
def
optimize
(
self
,
graph
):
for
pass_name
in
self
.
passes
:
...
...
x2paddle/optimizer/pass_.py
浏览文件 @
edd42662
...
...
@@ -12,19 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
enum
import
Enum
class
Kind
(
Enum
):
Program
=
1
Code
=
2
class
Pass
(
object
):
name
=
"pass"
def
__init__
(
self
,
kind
):
self
.
kind
=
kind
def
__init__
(
self
):
pass
def
apply
(
self
,
graph
):
raise
NotImplementedError
(
"The apply function must be implemented!"
)
...
...
@@ -32,13 +23,3 @@ class Pass(object):
@
classmethod
def
get_name
(
cls
):
return
cls
.
name
class
ProgramPass
(
Pass
):
def
__init__
(
self
):
super
(
ProgramPass
,
self
).
__init__
(
Kind
.
Program
)
class
CodePass
(
Pass
):
def
__init__
(
self
):
super
(
CodePass
,
self
).
__init__
(
Kind
.
Code
)
x2paddle/optimizer/pattern_matcher.py
浏览文件 @
edd42662
...
...
@@ -18,14 +18,18 @@ from x2paddle.core.program import PaddleGraph
class
PatternMatcher
(
object
):
def
__init__
(
self
,
pattern
):
self
.
pattern
=
pattern
self
.
subgraphs
=
list
()
# matches的每个match是按照拓扑排序组成layer的dict
self
.
matches
=
list
()
def
operate
(
self
,
graph
):
self
.
detect_patterns
(
graph
)
def
operate
(
self
,
graph
,
match_kind
=
"topo"
):
if
match_kind
==
"topo"
:
self
.
detect_patterns_by_topo
(
graph
)
elif
match_kind
==
"edge"
:
self
.
detect_patterns_by_edge
(
graph
)
self
.
remove_overlapped_match
()
return
self
.
subgraph
s
return
self
.
matche
s
def
detect_patterns
(
self
,
graph
):
def
detect_patterns
_by_topo
(
self
,
graph
):
""" 找到与模式匹配的子图,
并将子图的id以拓扑排序存放到subgraph_id2layers。
"""
...
...
@@ -101,49 +105,79 @@ class PatternMatcher(object):
for
i
,
(
layer_id
,
layer
)
in
enumerate
(
graph
.
layers
.
items
()):
match_info
=
get_subgraph
(
self
.
pattern
,
graph
,
i
)
if
match_info
:
self
.
subgraph
s
.
append
(
match_info
)
self
.
matche
s
.
append
(
match_info
)
for
j
,
block
in
enumerate
(
layer
.
blocks
):
if
len
(
block
.
layers
)
>
0
:
self
.
detect_patterns
(
layer
.
blocks
[
j
])
self
.
detect_patterns_by_topo
(
layer
.
blocks
[
j
])
def
detect_patterns_by_edge
(
self
,
graph
):
"""当遇见顺序没有强制规定的pattern时使用该方式
"""
pass
def
remove_overlapped_match
(
self
):
""" 如果2个子图有重叠,只取前一个子图。
"""
match_ids
=
[]
for
i
,
subgraph
in
enumerate
(
self
.
subgraph
s
):
for
i
,
match
in
enumerate
(
self
.
matche
s
):
is_overlapped
=
False
for
id
in
subgrap
h
.
keys
():
for
id
in
matc
h
.
keys
():
if
id
in
match_ids
:
self
.
subgraph
s
.
pop
(
i
)
self
.
matche
s
.
pop
(
i
)
is_overlapped
=
True
break
if
not
is_overlapped
:
match_ids
.
extend
(
list
(
subgraph
.
keys
()))
match_ids
.
extend
(
list
(
match
.
keys
()))
def
get_subgraph
(
prefix_layer_id
,
suffix_layer_id
,
graph
):
""" 根据prefix_layer_id和suffix_layer_id获取需要子图。
Args:
prefix_layer_id (str): 起初为一个空字符串,之后为suffix_layer_id分割出来的前缀。
suffix_layer_id (str): 起初为以一个layer的id,之后将分割部分给prefix_layer_id;例如”57.0.1“;
graph (x2paddle.core.program.PaddleGraph): 需要进行pass的子图。
"""
id_part
=
suffix_layer_id
.
split
(
"."
)
if
len
(
id_part
)
==
1
:
return
graph
if
prefix_layer_id
==
""
:
layer_id
=
id_part
[
0
]
prefix_layer_id
+=
"."
.
join
(
id_part
[:
2
])
else
:
layer_id
=
prefix_layer_id
+
"."
+
id_part
[
0
]
prefix_layer_id
+=
(
"."
+
"."
.
join
(
id_part
[:
2
]))
subgraph
=
graph
.
layers
[
layer_id
].
blocks
[
int
(
id_part
[
1
])]
suffix_layer_id
=
"."
.
join
(
id_part
[
2
:])
return
get_subgraph
(
prefix_layer_id
,
suffix_layer_id
,
subgraph
)
class
FuseBase
(
object
):
def
__init__
(
self
):
self
.
pattern
=
PaddleGraph
()
def
operate
(
self
,
graph
):
def
operate
(
self
,
graph
,
match_kind
=
"topo"
):
self
.
build_pattern
()
self
.
perform_pattern_matcher
(
graph
)
for
subgraph
in
self
.
subgraphs
:
self
.
insert_new_layer
(
graph
,
subgraph
)
self
.
perform_pattern_matcher
(
graph
,
match_kind
)
for
match
in
self
.
matches
:
first_layer_id
=
list
(
match
.
keys
())[
0
]
subgraph
=
get_subgraph
(
""
,
first_layer_id
,
graph
)
self
.
insert_new_layer
(
subgraph
,
match
)
self
.
delete_inter_layer
(
graph
)
graph
.
build
()
def
perform_pattern_matcher
(
self
,
graph
):
def
perform_pattern_matcher
(
self
,
graph
,
match_kind
=
"topo"
):
""" 执行模式匹配,找到匹配的子图。
"""
pattern_matcher
=
PatternMatcher
(
self
.
pattern
)
self
.
subgraphs
=
pattern_matcher
.
operate
(
graph
)
self
.
matches
=
pattern_matcher
.
operate
(
graph
,
match_kind
)
def
delete_inter_layer
(
self
,
graph
):
""" 删除不需要的中间layer及其对应参数。
"""
for
subgraph
in
self
.
subgraphs
:
for
layer_id
,
layer
in
subgraph
.
items
():
for
match
in
self
.
matches
:
first_layer_id
=
list
(
match
.
keys
())[
0
]
subgraph
=
get_subgraph
(
""
,
first_layer_id
,
graph
)
for
layer_id
,
layer
in
match
.
items
():
if
layer
.
kernel
==
"fluid.dygraph.base.to_variable"
and
\
layer
.
attrs
[
"value"
].
startswith
(
"params["
):
param_name
=
layer
.
attrs
[
"value"
][
8
:
-
2
]
...
...
@@ -151,4 +185,4 @@ class FuseBase(object):
graph
.
parameters
.
pop
(
param_name
)
if
layer_id
in
graph
.
layers
:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
graph
.
layers
.
pop
(
layer_id
)
sub
graph
.
layers
.
pop
(
layer_id
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录