Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
b4beabb7
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b4beabb7
编写于
8月 27, 2019
作者:
J
Jason
提交者:
GitHub
8月 27, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1 from PaddlePaddle/develop
develop
上级
405a2f18
753d19eb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
87 addition
and
2 deletion
+87
-2
FAQ.md
FAQ.md
+29
-0
x2paddle/op_mapper/tf_op_mapper.py
x2paddle/op_mapper/tf_op_mapper.py
+58
-2
未找到文件。
FAQ.md
浏览文件 @
b4beabb7
...
@@ -2,3 +2,32 @@
...
@@ -2,3 +2,32 @@
**Q1. TensorFlow模型转换过程中,提示『Unknown shape for input tensor[tensor name: "input"], Please define shape of input here』?**
**Q1. TensorFlow模型转换过程中,提示『Unknown shape for input tensor[tensor name: "input"], Please define shape of input here』?**
A:该提示信息表示无法从TensorFlow的pb模型中获取到输入tensor(tensor名为"input:)的shape信息,所以需要用户手动在提示后输入详细的shape信息,如None,224,224,3 其中None表示Batch
A:该提示信息表示无法从TensorFlow的pb模型中获取到输入tensor(tensor名为"input:)的shape信息,所以需要用户手动在提示后输入详细的shape信息,如None,224,224,3 其中None表示Batch
**Q2. TensorFlow模型转换失败怎么解决?**
A: 目前TensorFlow模型转换失败存在几个问题。1) 存在暂未支持的OP,此信息会在转换时输出; 2) NHWC优化导致部分参数出错;3)Batch维度带来的出错 4)其它
对于(1)问题,建议自行添加或发起Issue;
其中(2)、(3)、(4)问题目前没有明确的报错信息,当您遇到模型转换失败时,请尝试如下的步骤后,再进行转换测试
```
x2paddle -f tensorflow -m tf.pb -s pd-model --without_data_format_optimization --define_input_shape
```
#### without_data_format_optimization : 关闭NHWC优化
TensorFlow的CV模型,大多的输入格式为
`NHWC`
,而Paddle目前仅支持
`NCHW`
,如若直接转换,需要在conv2d、pool2d等操作前后添加transpose解决,这样会带来性能的损耗。X2Paddle在模型转换过程中,对此问题进行了优化,避免transpose操作带来的性能问题,但目前仅在部分模型上进行了测试,不一定适用于其它模型,因此,如若模型转换存在问题时,我们建议你关闭NHWC的优化。
在模型转换时添加参数 --without_data_format_optimization
```
x2paddle -f tensorflow -m tf.pb -s pd-model --without_data_format_optimization
```
### define_input_shape : 固定Batch大小
受限于不同框架的运行机制,在转换过程中,Batch维度也有一定可能会带来模型转换失败的问题。可以尝试固定Batch维度后再转换
在模型转换时添加参数 --define_input_shape
```
x2paddle -f tensorflow -m tf.pb -s pd-model --define_input_shape
```
如原tensorflow模型的输入shape为
`[None, 224, 224, 3]`
,可添加参数后,根据提示,把输入的shape修改为
`[2, 224, 224, 3]`
x2paddle/op_mapper/tf_op_mapper.py
浏览文件 @
b4beabb7
...
@@ -17,6 +17,7 @@ from x2paddle.core.op_mapper import OpMapper
...
@@ -17,6 +17,7 @@ from x2paddle.core.op_mapper import OpMapper
from
x2paddle.core.util
import
*
from
x2paddle.core.util
import
*
import
inspect
import
inspect
import
numpy
import
numpy
import
sys
# compute padding size for SAME mode
# compute padding size for SAME mode
...
@@ -83,18 +84,31 @@ class TFOpMapper(OpMapper):
...
@@ -83,18 +84,31 @@ class TFOpMapper(OpMapper):
del
self
.
graph
.
input_nodes
[
idx
]
del
self
.
graph
.
input_nodes
[
idx
]
print
(
"Total nodes: {}"
.
format
(
len
(
self
.
graph
.
topo_sort
)))
print
(
"Total nodes: {}"
.
format
(
len
(
self
.
graph
.
topo_sort
)))
unsupported_ops
=
set
()
for
node_name
in
self
.
graph
.
topo_sort
:
for
node_name
in
self
.
graph
.
topo_sort
:
node
=
self
.
graph
.
get_node
(
node_name
)
node
=
self
.
graph
.
get_node
(
node_name
)
op
=
node
.
layer_type
op
=
node
.
layer_type
if
op
in
self
.
directly_map_ops
:
if
op
in
self
.
directly_map_ops
:
if
len
(
unsupported_ops
)
>
0
:
continue
self
.
directly_map
(
node
)
self
.
directly_map
(
node
)
elif
op
in
self
.
elementwise_ops
:
elif
op
in
self
.
elementwise_ops
:
if
len
(
unsupported_ops
)
>
0
:
continue
self
.
elementwise_map
(
node
)
self
.
elementwise_map
(
node
)
elif
hasattr
(
self
,
op
):
elif
hasattr
(
self
,
op
):
if
len
(
unsupported_ops
)
>
0
:
continue
func
=
getattr
(
self
,
op
)
func
=
getattr
(
self
,
op
)
func
(
node
)
func
(
node
)
else
:
else
:
raise
Exception
(
"OP: [{}] not support yet"
.
format
(
op
))
unsupported_ops
.
add
(
op
)
if
len
(
unsupported_ops
)
>
0
:
print
(
"=========={} Ops are not supported yet======"
.
format
(
len
(
unsupported_ops
)))
for
op
in
unsupported_ops
:
print
(
"========== {} =========="
.
format
(
op
))
sys
.
exit
(
-
1
)
def
directly_map
(
self
,
node
):
def
directly_map
(
self
,
node
):
assert
node
.
layer_type
in
self
.
directly_map_ops
assert
node
.
layer_type
in
self
.
directly_map_ops
...
@@ -773,7 +787,15 @@ class TFOpMapper(OpMapper):
...
@@ -773,7 +787,15 @@ class TFOpMapper(OpMapper):
begin
=
[
begin
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
begin
=
[
begin
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
end
=
[
end
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
end
=
[
end
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
attr
=
{
"axes"
:
range
(
len
(
strides
)),
"starts"
:
begin
,
"ends"
:
end
}
for
i
in
range
(
len
(
end
)):
if
end
[
i
]
==
0
:
end
[
i
]
=
999999
attr
=
{
"axes"
:
[
i
for
i
in
range
(
len
(
strides
))],
"starts"
:
begin
,
"ends"
:
end
}
node
.
fluid_code
.
add_layer
(
"slice"
,
node
.
fluid_code
.
add_layer
(
"slice"
,
inputs
=
input
,
inputs
=
input
,
output
=
node
,
output
=
node
,
...
@@ -955,3 +977,37 @@ class TFOpMapper(OpMapper):
...
@@ -955,3 +977,37 @@ class TFOpMapper(OpMapper):
inputs
=
input
,
inputs
=
input
,
output
=
node
,
output
=
node
,
param_attr
=
attr
)
param_attr
=
attr
)
def
ResizeNearestNeighbor
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
resize_shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
self
.
omit_nodes
.
append
(
resize_shape
.
layer_name
)
if
resize_shape
.
layer_type
==
"Const"
:
resize_shape
=
resize_shape
.
value
.
tolist
()
else
:
resize_shape
=
self
.
decoder
.
infer_shape_tensor
(
resize_shape
)
align_corners
=
node
.
get_attr
(
"align_corners"
)
attr
=
{
"align_corners"
:
align_corners
,
"out_shape"
:
resize_shape
}
node
.
fluid_code
.
add_layer
(
"resize_nearest"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
def
ResizeBilinear
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
resize_shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
self
.
omit_nodes
.
append
(
resize_shape
.
layer_name
)
if
resize_shape
.
layer_type
==
"Const"
:
resize_shape
=
resize_shape
.
value
.
tolist
()
else
:
resize_shape
=
self
.
decoder
.
infer_shape_tensor
(
resize_shape
)
align_corners
=
node
.
get_attr
(
"align_corners"
)
attr
=
{
"align_corners"
:
align_corners
,
"out_shape"
:
resize_shape
,
"align_mode"
:
1
}
node
.
fluid_code
.
add_layer
(
"resize_bilinear"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录