Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
8fee1a69
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8fee1a69
编写于
10月 11, 2019
作者:
J
Jason
提交者:
GitHub
10月 11, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #160 from mamingjie-China/develop
modify the codes in TensorFlow converting
上级
2bcb8fd3
bac4164d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
233 addition
and
102 deletion
+233
-102
x2paddle/op_mapper/tf_op_mapper.py
x2paddle/op_mapper/tf_op_mapper.py
+4
-0
x2paddle/op_mapper/tf_op_mapper_nhwc.py
x2paddle/op_mapper/tf_op_mapper_nhwc.py
+51
-11
x2paddle/optimizer/tf_optimizer.py
x2paddle/optimizer/tf_optimizer.py
+178
-91
未找到文件。
x2paddle/op_mapper/tf_op_mapper.py
浏览文件 @
8fee1a69
...
@@ -168,7 +168,11 @@ class TFOpMapper(OpMapper):
...
@@ -168,7 +168,11 @@ class TFOpMapper(OpMapper):
x_input
=
y
x_input
=
y
y_input
=
x
y_input
=
x
x_shape
=
y
.
out_shapes
[
0
]
x_shape
=
y
.
out_shapes
[
0
]
if
len
(
x_shape
)
==
0
:
x_shape
=
[
1
]
y_shape
=
x
.
out_shapes
[
0
]
y_shape
=
x
.
out_shapes
[
0
]
if
len
(
y_shape
)
==
0
:
y_shape
=
[
1
]
else
:
else
:
if
len
(
x_shape
)
==
1
and
len
(
y_shape
)
==
4
and
x_shape
[
if
len
(
x_shape
)
==
1
and
len
(
y_shape
)
==
4
and
x_shape
[
0
]
==
y_shape
[
-
1
]
and
y_shape
.
count
(
-
1
)
<
1
:
0
]
==
y_shape
[
-
1
]
and
y_shape
.
count
(
-
1
)
<
1
:
...
...
x2paddle/op_mapper/tf_op_mapper_nhwc.py
浏览文件 @
8fee1a69
...
@@ -121,6 +121,25 @@ class TFOpMapperNHWC(OpMapper):
...
@@ -121,6 +121,25 @@ class TFOpMapperNHWC(OpMapper):
pd_param_name
=
list
(
param
.
values
())[
0
]
pd_param_name
=
list
(
param
.
values
())[
0
]
tf_param
=
node
.
get_attr
(
tf_param_name
)
tf_param
=
node
.
get_attr
(
tf_param_name
)
attr
[
pd_param_name
]
=
tf_param
attr
[
pd_param_name
]
=
tf_param
if
len
(
input
.
out_shapes
[
0
])
==
4
and
op_info
[
0
]
!=
'shape'
:
attr1
=
{
"perm"
:
[
0
,
3
,
1
,
2
]}
node
.
fluid_code
.
add_layer
(
'transpose'
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr1
)
input
=
node
node
.
fluid_code
.
add_layer
(
op_info
[
0
],
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
input
=
node
attr2
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
node
.
fluid_code
.
add_layer
(
'transpose'
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr2
)
else
:
node
.
fluid_code
.
add_layer
(
op_info
[
0
],
node
.
fluid_code
.
add_layer
(
op_info
[
0
],
inputs
=
input
,
inputs
=
input
,
output
=
node
,
output
=
node
,
...
@@ -149,7 +168,11 @@ class TFOpMapperNHWC(OpMapper):
...
@@ -149,7 +168,11 @@ class TFOpMapperNHWC(OpMapper):
x_input
=
y
x_input
=
y
y_input
=
x
y_input
=
x
x_shape
=
y
.
out_shapes
[
0
]
x_shape
=
y
.
out_shapes
[
0
]
if
len
(
x_shape
)
==
0
:
x_shape
=
[
1
]
y_shape
=
x
.
out_shapes
[
0
]
y_shape
=
x
.
out_shapes
[
0
]
if
len
(
y_shape
)
==
0
:
y_shape
=
[
1
]
else
:
else
:
raise
Exception
(
"Unexpected situation happend"
)
raise
Exception
(
"Unexpected situation happend"
)
...
@@ -193,6 +216,25 @@ class TFOpMapperNHWC(OpMapper):
...
@@ -193,6 +216,25 @@ class TFOpMapperNHWC(OpMapper):
output
=
"y_tmp"
,
output
=
"y_tmp"
,
param_attr
=
attr
)
param_attr
=
attr
)
y_input
=
"y_tmp"
y_input
=
"y_tmp"
if
len
(
x_shape
)
==
4
and
len
(
y_shape
)
==
4
:
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
x_input
,
output
=
x_input
,
param_attr
=
{
'perm'
:
[
0
,
3
,
1
,
2
]})
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
y_input
,
output
=
y_input
,
param_attr
=
{
'perm'
:
[
0
,
3
,
1
,
2
]})
inputs
=
{
"x"
:
x_input
,
"y"
:
y_input
}
node
.
fluid_code
.
add_layer
(
op_type
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
node
,
output
=
node
,
param_attr
=
{
'perm'
:
[
0
,
2
,
3
,
1
]})
else
:
inputs
=
{
"x"
:
x_input
,
"y"
:
y_input
}
inputs
=
{
"x"
:
x_input
,
"y"
:
y_input
}
node
.
fluid_code
.
add_layer
(
op_type
,
node
.
fluid_code
.
add_layer
(
op_type
,
inputs
=
inputs
,
inputs
=
inputs
,
...
@@ -978,9 +1020,7 @@ class TFOpMapperNHWC(OpMapper):
...
@@ -978,9 +1020,7 @@ class TFOpMapperNHWC(OpMapper):
if
pad_mode
==
"SAME"
:
if
pad_mode
==
"SAME"
:
if
node
.
tf_data_format
==
"NHWC"
:
if
node
.
tf_data_format
==
"NHWC"
:
print
(
out_shape
)
out_shape
=
[
out_shape
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
out_shape
=
[
out_shape
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
print
(
out_shape
)
for
i
in
range
(
4
):
for
i
in
range
(
4
):
if
out_shape
[
i
]
<
0
:
if
out_shape
[
i
]
<
0
:
out_shape
[
i
]
=
999999
out_shape
[
i
]
=
999999
...
...
x2paddle/optimizer/tf_optimizer.py
浏览文件 @
8fee1a69
...
@@ -232,84 +232,35 @@ class TFOptimizer(object):
...
@@ -232,84 +232,35 @@ class TFOptimizer(object):
'act'
]
'act'
]
node
.
fluid_code
.
clear
()
node
.
fluid_code
.
clear
()
self
.
graph
.
remove_node
(
node
.
layer_name
)
self
.
graph
.
remove_node
(
node
.
layer_name
)
self
.
graph
.
identity_map
[
node
.
layer_name
]
=
input
.
layer_name
def
remove_transpose
(
self
):
def
remove_transpose
(
self
):
graph_copy
=
cp
.
deepcopy
(
self
.
graph
)
graph_copy
=
cp
.
deepcopy
(
self
.
graph
)
nhwc_insensitive_ops
=
[
nhwc_insensitive_ops
=
[
'Relu'
,
'Relu6'
,
'Abs'
,
'Sigmoid'
,
'Exp'
,
'Rsqrt'
,
'swish_f32'
,
'Relu'
,
'Relu6'
,
'Abs'
,
'Sigmoid'
,
'Exp'
,
'Rsqrt'
,
'swish_f32'
,
'LeakyRelu'
,
'Cast'
'LeakyRelu'
,
'Cast'
,
'Tanh'
]
]
elementwise_ops
=
[
elementwise_ops
=
[
'Sub'
,
'Add'
,
'RealDiv'
,
'Maximum'
,
'Mul'
,
'FloorDiv'
,
'Sub'
,
'Add'
,
'RealDiv'
,
'Maximum'
,
'Mul'
,
'FloorDiv'
,
'GreaterEqual'
'GreaterEqual'
]
]
for
node_name
in
self
.
graph
.
topo_sort
:
node
=
graph_copy
.
get_node
(
node_name
)
if
node
is
None
:
continue
if
node
.
layer_type
in
nhwc_insensitive_ops
:
graph_copy
.
remove_node
(
node_name
)
optimize_ops
=
[
optimize_ops
=
[
'Conv2D'
,
'MaxPool'
,
'FusedBatchNorm'
,
'DepthwiseConv2dNative'
,
'Conv2D'
,
'MaxPool'
,
'FusedBatchNorm'
,
'DepthwiseConv2dNative'
,
'AvgPool'
,
'Pad'
,
'Conv2DBackpropInput'
,
'ResizeNearestNeighbor'
,
'AvgPool'
,
'Pad'
,
'Conv2DBackpropInput'
,
'ResizeNearestNeighbor'
,
'ResizeBilinear'
,
"Placeholder"
'ResizeBilinear'
,
"Placeholder"
]
]
can_be_optimized_ops
=
[
'Conv2D'
,
'MaxPool'
,
'FusedBatchNorm'
,
'DepthwiseConv2dNative'
,
'AvgPool'
,
'Pad'
,
'Conv2DBackpropInput'
,
'ResizeNearestNeighbor'
,
'ResizeBilinear'
,
"Placeholder"
,
'Relu'
,
'Relu6'
,
'Abs'
,
'Sigmoid'
,
'Exp'
,
'Rsqrt'
,
'swish_f32'
,
'LeakyRelu'
,
'Cast'
,
'Tanh'
]
for
node_name
in
self
.
graph
.
topo_sort
:
for
node_name
in
self
.
graph
.
topo_sort
:
node
=
graph_copy
.
get_node
(
node_name
)
node
=
graph_copy
.
get_node
(
node_name
)
if
node
is
None
:
if
node
is
None
:
continue
continue
if
node
.
layer_type
in
elementwise_ops
:
if
node
.
layer_type
in
can_be_optimized_ops
:
is_nhwc
=
True
for
in_name
in
node
.
inputs
:
in_node
=
graph_copy
.
get_node
(
in_name
)
if
hasattr
(
in_node
,
"is_nhwc"
):
if
not
in_node
.
is_nhwc
:
is_nhwc
=
False
else
:
if
len
(
in_node
.
fluid_code
.
layers
)
<
2
:
is_nhwc
=
False
continue
if
in_node
.
fluid_code
.
layers
[
-
1
].
op
!=
"transpose"
or
in_node
.
fluid_code
.
layers
[
-
1
].
param_attr
[
"perm"
]
!=
[
0
,
2
,
3
,
1
]:
is_nhwc
=
False
continue
node
.
is_nhwc
=
is_nhwc
for
i
in
range
(
len
(
self
.
graph
.
topo_sort
)):
node_name
=
self
.
graph
.
topo_sort
[
-
1
*
i
-
1
]
node
=
graph_copy
.
get_node
(
node_name
)
if
node
is
None
:
continue
if
node
.
layer_type
in
elementwise_ops
:
can_be_removed
=
True
if
len
(
node
.
fluid_code
.
layers
)
>
1
:
can_be_removed
=
False
if
not
node
.
is_nhwc
:
can_be_removed
=
False
for
out_name
in
node
.
outputs
:
out_node
=
graph_copy
.
get_node
(
out_name
)
if
hasattr
(
out_node
,
"is_nhwc"
):
if
not
out_node
.
is_nhwc
:
can_be_removed
=
False
else
:
if
len
(
out_node
.
fluid_code
.
layers
)
<
2
:
can_be_removed
=
False
break
if
out_node
.
fluid_code
.
layers
[
0
].
op
!=
"transpose"
or
out_node
.
fluid_code
.
layers
[
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
node
.
can_be_removed
=
can_be_removed
for
node_name
in
self
.
graph
.
topo_sort
:
node
=
graph_copy
.
get_node
(
node_name
)
if
node
is
None
:
continue
if
node
.
layer_type
in
optimize_ops
:
if
node
.
fluid_code
.
layers
[
if
node
.
fluid_code
.
layers
[
-
1
].
op
!=
"transpose"
or
node
.
fluid_code
.
layers
[
-
1
].
op
!=
"transpose"
or
node
.
fluid_code
.
layers
[
-
1
].
param_attr
[
"perm"
]
!=
[
0
,
2
,
3
,
1
]:
-
1
].
param_attr
[
"perm"
]
!=
[
0
,
2
,
3
,
1
]:
...
@@ -327,6 +278,9 @@ class TFOptimizer(object):
...
@@ -327,6 +278,9 @@ class TFOptimizer(object):
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
can_be_removed
=
False
break
break
elif
out_node
.
layer_type
in
elementwise_ops
:
can_be_removed
=
False
break
if
can_be_removed
and
len
(
node
.
fluid_code
.
layers
)
>
1
:
if
can_be_removed
and
len
(
node
.
fluid_code
.
layers
)
>
1
:
true_node
=
self
.
graph
.
get_node
(
node_name
)
true_node
=
self
.
graph
.
get_node
(
node_name
)
if
true_node
.
layer_type
==
"Placeholder"
:
if
true_node
.
layer_type
==
"Placeholder"
:
...
@@ -346,8 +300,6 @@ class TFOptimizer(object):
...
@@ -346,8 +300,6 @@ class TFOptimizer(object):
del
true_node
.
fluid_code
.
layers
[
-
1
]
del
true_node
.
fluid_code
.
layers
[
-
1
]
for
out_name
in
output_names
:
for
out_name
in
output_names
:
out_node
=
self
.
graph
.
get_node
(
out_name
)
out_node
=
self
.
graph
.
get_node
(
out_name
)
if
out_node
.
layer_type
in
elementwise_ops
:
continue
out_node
.
fluid_code
.
layers
[
out_node
.
fluid_code
.
layers
[
1
].
inputs
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
1
].
inputs
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
del
out_node
.
fluid_code
.
layers
[
0
]
...
@@ -357,43 +309,178 @@ class TFOptimizer(object):
...
@@ -357,43 +309,178 @@ class TFOptimizer(object):
if
node
is
None
:
if
node
is
None
:
continue
continue
if
node
.
layer_type
in
elementwise_ops
:
if
node
.
layer_type
in
elementwise_ops
:
if
not
node
.
can_be_removed
:
can_be_removed
=
True
if
node
.
fluid_code
.
layers
[
-
1
].
op
!=
"transpose"
or
node
.
fluid_code
.
layers
[
-
1
].
param_attr
[
"perm"
]
!=
[
0
,
2
,
3
,
1
]:
continue
can_be_removed
=
True
output_names
=
node
.
outputs
for
out_name
in
output_names
:
out_node
=
graph_copy
.
get_node
(
out_name
)
if
len
(
out_node
.
fluid_code
.
layers
)
<
3
:
can_be_removed
=
False
break
if
hasattr
(
out_node
,
"can_be_removed"
):
if
not
out_node
.
can_be_removed
:
can_be_removed
=
False
break
if
out_node
.
layer_type
in
can_be_optimized_ops
:
if
out_node
.
fluid_code
.
layers
[
0
].
op
!=
"transpose"
or
out_node
.
fluid_code
.
layers
[
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
elif
out_node
.
layer_type
in
elementwise_ops
:
if
out_node
.
fluid_code
.
layers
[
0
].
op
!=
"transpose"
and
out_node
.
fluid_code
.
layers
[
1
].
op
!=
"transpose"
:
can_be_removed
=
False
break
if
out_node
.
fluid_code
.
layers
[
0
].
op
==
"transpose"
:
if
out_node
.
fluid_code
.
layers
[
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
if
out_node
.
fluid_code
.
layers
[
1
].
op
==
"transpose"
:
if
out_node
.
fluid_code
.
layers
[
1
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
if
can_be_removed
and
len
(
node
.
fluid_code
.
layers
)
>
1
:
true_node
=
self
.
graph
.
get_node
(
node_name
)
true_node
=
self
.
graph
.
get_node
(
node_name
)
for
i
,
in_name
in
enumerate
(
node
.
inputs
):
true_node
.
fluid_code
.
layers
[
in_node
=
graph_copy
.
get_node
(
in_name
)
-
2
].
output
=
true_node
.
fluid_code
.
layers
[
-
1
].
output
if
hasattr
(
in_node
,
"is_nhwc"
)
and
in_node
.
is_nhwc
:
del
true_node
.
fluid_code
.
layers
[
-
1
]
if
i
==
0
:
for
out_name
in
output_names
:
l
=
Layer
()
out_node
=
self
.
graph
.
get_node
(
out_name
)
l
.
op
=
"transpose"
if
out_node
.
layer_type
in
can_be_optimized_ops
:
l
.
inputs
=
true_node
.
fluid_code
.
layers
[
out_node
.
fluid_code
.
layers
[
0
].
inputs
[
"x"
]
1
].
inputs
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
l
.
param_attr
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
del
out_node
.
fluid_code
.
layers
[
0
]
l
.
output
=
"nhwc_"
+
l
.
inputs
.
layer_name
elif
out_node
.
layer_type
in
elementwise_ops
:
true_node
.
fluid_code
.
layers
[
0
].
inputs
[
if
out_node
.
inputs
[
0
]
in
node
.
layer_name
:
"x"
]
=
l
.
output
if
out_node
.
fluid_code
.
layers
[
true_node
.
fluid_code
.
layers
.
insert
(
0
,
l
)
1
].
op
==
'transpose'
:
elif
i
==
1
:
out_node
.
fluid_code
.
layers
[
2
].
inputs
[
l
=
Layer
()
'x'
]
=
out_node
.
fluid_code
.
layers
[
l
.
op
=
"transpose"
0
].
inputs
l
.
inputs
=
true_node
.
fluid_code
.
layers
[
del
out_node
.
fluid_code
.
layers
[
0
]
0
].
inputs
[
"y"
]
l
.
param_attr
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
l
.
output
=
"nhwc_"
+
l
.
inputs
.
layer_name
true_node
.
fluid_code
.
layers
[
0
].
inputs
[
"y"
]
=
l
.
output
true_node
.
fluid_code
.
layers
.
insert
(
0
,
l
)
else
:
else
:
raise
Exception
(
"Unexpected situation happend"
)
out_node
.
fluid_code
.
layers
[
1
].
inputs
[
continue
'x'
]
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
elif
out_node
.
inputs
[
1
]
in
node
.
layer_name
:
if
out_node
.
fluid_code
.
layers
[
1
].
op
==
'transpose'
:
out_node
.
fluid_code
.
layers
[
2
].
inputs
[
'y'
]
=
out_node
.
fluid_code
.
layers
[
1
].
inputs
del
out_node
.
fluid_code
.
layers
[
1
]
else
:
else
:
for
out_name
in
node
.
outputs
:
out_node
.
fluid_code
.
layers
[
1
].
inputs
[
'y'
]
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
graph_copy
=
cp
.
deepcopy
(
self
.
graph
)
for
node_name
in
self
.
graph
.
topo_sort
:
node
=
graph_copy
.
get_node
(
node_name
)
if
node
is
None
or
len
(
node
.
fluid_code
.
layers
)
<
2
:
continue
if
node
.
layer_type
in
can_be_optimized_ops
and
node
.
layer_type
!=
"Placeholder"
:
if
node
.
fluid_code
.
layers
[
-
1
].
op
!=
"transpose"
or
node
.
fluid_code
.
layers
[
-
1
].
param_attr
[
"perm"
]
!=
[
0
,
2
,
3
,
1
]:
continue
can_be_removed
=
True
output_names
=
node
.
outputs
for
out_name
in
output_names
:
out_node
=
graph_copy
.
get_node
(
out_name
)
if
hasattr
(
out_node
,
"can_be_removed"
):
if
not
out_node
.
can_be_removed
:
can_be_removed
=
False
break
if
len
(
out_node
.
fluid_code
.
layers
)
<
2
:
can_be_removed
=
False
break
if
out_node
.
layer_type
in
can_be_optimized_ops
:
if
out_node
.
fluid_code
.
layers
[
0
].
op
!=
"transpose"
or
out_node
.
fluid_code
.
layers
[
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
elif
out_node
.
layer_type
in
elementwise_ops
:
if
out_node
.
fluid_code
.
layers
[
0
].
op
!=
"transpose"
and
out_node
.
fluid_code
.
layers
[
1
].
op
!=
"transpose"
:
can_be_removed
=
False
break
if
out_node
.
fluid_code
.
layers
[
0
].
op
==
"expand"
or
out_node
.
fluid_code
.
layers
[
1
].
op
==
"expand"
:
can_be_removed
=
False
break
if
out_node
.
fluid_code
.
layers
[
0
].
op
==
"transpose"
:
if
out_node
.
fluid_code
.
layers
[
0
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
if
out_node
.
fluid_code
.
layers
[
1
].
op
==
"transpose"
:
if
out_node
.
fluid_code
.
layers
[
1
].
param_attr
[
"perm"
]
!=
[
0
,
3
,
1
,
2
]:
can_be_removed
=
False
break
elif
out_node
.
layer_type
not
in
elementwise_ops
and
out_node
.
layer_type
not
in
can_be_optimized_ops
:
can_be_removed
=
False
break
if
can_be_removed
:
true_node
=
self
.
graph
.
get_node
(
node_name
)
if
len
(
true_node
.
fluid_code
.
layers
)
<
2
:
continue
true_node
.
fluid_code
.
layers
[
-
2
].
output
=
true_node
.
fluid_code
.
layers
[
-
1
].
output
del
true_node
.
fluid_code
.
layers
[
-
1
]
for
out_name
in
output_names
:
out_node
=
self
.
graph
.
get_node
(
out_name
)
out_node
=
self
.
graph
.
get_node
(
out_name
)
if
out_node
.
layer_type
not
in
elementwise_ops
:
if
out_node
.
layer_type
in
can_be_optimized_ops
:
assert
out_node
.
fluid_code
.
layers
[
0
].
op
==
"transpose"
,
"unexpected situation happend"
out_node
.
fluid_code
.
layers
[
out_node
.
fluid_code
.
layers
[
1
].
inputs
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
1
].
inputs
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
del
out_node
.
fluid_code
.
layers
[
0
]
elif
out_node
.
layer_type
in
elementwise_ops
:
if
out_node
.
inputs
[
0
]
in
node
.
layer_name
:
if
out_node
.
fluid_code
.
layers
[
1
].
op
==
'transpose'
:
if
out_node
.
fluid_code
.
layers
[
2
].
op
==
'transpose'
:
out_node
.
fluid_code
.
layers
[
3
].
inputs
[
'x'
]
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
else
:
out_node
.
fluid_code
.
layers
[
2
].
inputs
[
'x'
]
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
else
:
out_node
.
fluid_code
.
layers
[
1
].
inputs
[
'x'
]
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
elif
out_node
.
inputs
[
1
]
in
node
.
layer_name
:
if
out_node
.
fluid_code
.
layers
[
1
].
op
==
'transpose'
:
out_node
.
fluid_code
.
layers
[
2
].
inputs
[
'y'
]
=
out_node
.
fluid_code
.
layers
[
1
].
inputs
del
out_node
.
fluid_code
.
layers
[
1
]
else
:
out_node
.
fluid_code
.
layers
[
1
].
inputs
[
'y'
]
=
out_node
.
fluid_code
.
layers
[
0
].
inputs
del
out_node
.
fluid_code
.
layers
[
0
]
def
make_nchw_input_output
(
self
):
def
make_nchw_input_output
(
self
):
for
i
,
name
in
enumerate
(
self
.
graph
.
input_nodes
):
for
i
,
name
in
enumerate
(
self
.
graph
.
input_nodes
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录