Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
fef5149c
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fef5149c
编写于
8月 02, 2019
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add NCHW change
上级
0a1078ca
变更
3
展开全部
显示空白变更内容
内联
并排
Showing
3 changed file
with
295 addition
and
153 deletion
+295
-153
x2paddle/core/op_mapper.py
x2paddle/core/op_mapper.py
+1
-1
x2paddle/decoder/tf_decoder.py
x2paddle/decoder/tf_decoder.py
+19
-5
x2paddle/op_mapper/tf_op_mapper.py
x2paddle/op_mapper/tf_op_mapper.py
+275
-147
未找到文件。
x2paddle/core/op_mapper.py
浏览文件 @
fef5149c
...
...
@@ -116,7 +116,7 @@ class OpMapper(object):
feeded_var_names
=
input_names
,
target_vars
=
outputs
,
executor
=
exe
,
params_filename
=
"__params__"
)
params_filename
=
None
)
except
:
raise
Exception
(
"Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
...
...
x2paddle/decoder/tf_decoder.py
浏览文件 @
fef5149c
...
...
@@ -24,7 +24,7 @@ import sys
class
TFGraphNode
(
GraphNode
):
def
__init__
(
self
,
layer
,
layer_name
=
None
):
def
__init__
(
self
,
layer
,
layer_name
=
None
,
data_format
=
"NHWC"
):
if
layer_name
is
None
:
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
...
...
@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode):
layer_name
.
replace
(
'/'
,
'_'
).
replace
(
'-'
,
'_'
))
self
.
layer_type
=
layer
.
op
self
.
tf_data_format
=
data_format
self
.
pd_data_format
=
"NCHW"
self
.
fluid_code
=
FluidCode
()
self
.
dtype_map
=
{
1
:
"float32"
,
3
:
"int32"
,
4
:
"int8"
,
9
:
"int64"
}
...
...
@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode):
class
TFGraph
(
Graph
):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
,
data_format
=
"NHWC"
):
super
(
TFGraph
,
self
).
__init__
(
model
)
self
.
identity_map
=
dict
()
self
.
multi_out_ops
=
[
'Split'
,
'SplitV'
]
self
.
tf_data_format
=
data_format
def
build
(
self
):
for
layer
in
self
.
model
.
node
:
self
.
node_map
[
layer
.
name
.
replace
(
'/'
,
'_'
).
replace
(
'-'
,
'_'
)]
=
TFGraphNode
(
layer
)
'-'
,
'_'
)]
=
TFGraphNode
(
layer
,
data_format
=
self
.
tf_data_format
)
for
layer_name
,
node
in
self
.
node_map
.
items
():
for
in_node
in
node
.
layer
.
input
:
...
...
@@ -166,9 +169,20 @@ class TFGraph(Graph):
idx
=
self
.
output_nodes
.
index
(
node_name
)
self
.
output_nodes
[
idx
]
=
input_node
.
layer_name
def
data_format_propagation
(
self
,
node
):
current_node
=
self
.
node_map
[
node
.
layer_name
]
current_node
=
node
.
tf_data_format
outputs
=
current_node
.
outputs
if
len
(
outputs
)
==
0
:
return
for
out
in
outputs
:
next_node
=
self
.
node_map
[
out
]
next_node
.
tf_data_format
=
node
.
tf_data_format
self
.
data_format_propagation
(
next_node
)
class
TFDecoder
(
object
):
def
__init__
(
self
,
pb_model
):
def
__init__
(
self
,
pb_model
,
data_format
=
"NHWC"
):
self
.
sess
=
tf
.
Session
()
self
.
input_info
=
dict
()
with
gfile
.
FastGFile
(
pb_model
,
'rb'
)
as
f
:
...
...
@@ -186,7 +200,7 @@ class TFDecoder(object):
self
.
sess
.
run
(
tf
.
global_variables_initializer
())
self
.
tf_graph
=
TFGraph
(
self
.
sess
.
graph
.
_as_graph_def
(
add_shapes
=
True
)[
0
])
self
.
sess
.
graph
.
_as_graph_def
(
add_shapes
=
True
)[
0
]
,
data_format
)
self
.
tf_graph
.
build
()
def
_fix_output_shape
(
self
,
graph
):
...
...
x2paddle/op_mapper/tf_op_mapper.py
浏览文件 @
fef5149c
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录