Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
7a96c492
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
1 年多 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7a96c492
编写于
1月 25, 2019
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
push tensorflow2fluid
上级
c4136505
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
2313 addition
and
0 deletion
+2313
-0
tensorflow2fluid/src/__init__.py
tensorflow2fluid/src/__init__.py
+0
-0
tensorflow2fluid/src/framework_pb2.py
tensorflow2fluid/src/framework_pb2.py
+1165
-0
tensorflow2fluid/src/graph.py
tensorflow2fluid/src/graph.py
+96
-0
tensorflow2fluid/src/name_generator.py
tensorflow2fluid/src/name_generator.py
+46
-0
tensorflow2fluid/src/paddle_emitter.py
tensorflow2fluid/src/paddle_emitter.py
+742
-0
tensorflow2fluid/src/tensorflow_graph.py
tensorflow2fluid/src/tensorflow_graph.py
+122
-0
tensorflow2fluid/src/tensorflow_parser.py
tensorflow2fluid/src/tensorflow_parser.py
+113
-0
tensorflow2fluid/src/transformer.py
tensorflow2fluid/src/transformer.py
+29
-0
未找到文件。
tensorflow2fluid/src/__init__.py
0 → 100644
浏览文件 @
7a96c492
tensorflow2fluid/src/framework_pb2.py
0 → 100644
浏览文件 @
7a96c492
此差异已折叠。
点击以展开。
tensorflow2fluid/src/graph.py
0 → 100644
浏览文件 @
7a96c492
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
name_generator
import
NameGenerator
class
GraphNode
(
object
):
def
__init__
(
self
,
layer
):
self
.
inputs
=
list
()
self
.
outputs
=
list
()
self
.
layer
=
layer
self
.
ref_name
=
None
self
.
output_name
=
None
def
__hash__
(
self
):
return
hash
(
self
.
layer
.
name
)
def
__eq__
(
self
,
other
):
if
self
.
layer
.
name
==
other
.
layer
.
name
:
return
True
return
False
class
Graph
(
object
):
def
__init__
(
self
,
model
):
self
.
node_map
=
dict
()
self
.
input_nodes
=
list
()
self
.
output_nodes
=
list
()
self
.
topological_sort
=
list
()
self
.
model
=
model
self
.
name_generator
=
NameGenerator
()
def
build
(
self
):
self
.
_make_input_nodes
()
self
.
_make_output_nodes
()
self
.
_get_topological_sort
()
self
.
_gen_newname_for_nodes
()
def
_make_input_nodes
(
self
):
for
name
,
node
in
self
.
node_map
.
items
():
if
len
(
node
.
outputs
)
==
0
and
len
(
node
.
inputs
)
==
0
:
continue
node
.
left_inputs
=
len
(
node
.
inputs
)
if
len
(
node
.
inputs
)
==
0
:
self
.
input_nodes
.
append
(
name
)
def
_make_output_nodes
(
self
):
for
name
,
node
in
self
.
node_map
.
items
():
if
len
(
node
.
outputs
)
==
0
and
len
(
node
.
inputs
)
==
0
:
continue
if
len
(
node
.
outputs
)
==
0
:
self
.
output_nodes
.
append
(
name
)
def
_get_topological_sort
(
self
):
self
.
topological_sort
=
self
.
input_nodes
[:]
idx
=
0
while
idx
<
len
(
self
.
topological_sort
):
current_node
=
self
.
node_map
[
self
.
topological_sort
[
idx
]]
for
next_node
in
current_node
.
outputs
:
next_node_info
=
self
.
node_map
[
next_node
.
layer_name
]
next_node_info
.
left_inputs
-=
1
if
next_node_info
.
left_inputs
==
0
:
self
.
topological_sort
.
append
(
next_node
.
layer_name
)
idx
+=
1
def
_gen_newname_for_nodes
(
self
):
for
node_name
in
self
.
topological_sort
:
node
=
self
.
node_map
[
node_name
]
ref_name
=
self
.
name_generator
.
get_name
(
node
)
self
.
node_map
[
node
.
layer
.
name
].
ref_name
=
ref_name
self
.
node_map
[
node
.
layer
.
name
].
output_name
=
ref_name
.
split
(
'['
)[
0
]
def
get_node
(
self
,
name
):
if
name
not
in
self
.
node_map
:
raise
Exception
(
"Graph doesn't have node [%s]."
%
name
)
else
:
return
self
.
node_map
[
name
]
def
_make_connection
(
self
,
src
,
dst
):
if
src
.
layer_name
==
dst
.
layer_name
or
src
.
layer_name
not
in
self
.
node_map
or
dst
.
layer_name
not
in
self
.
node_map
:
raise
Exception
(
'Warning: Node not exist or there is a self-loop'
)
if
src
not
in
self
.
node_map
[
dst
.
layer_name
].
inputs
:
self
.
node_map
[
dst
.
layer_name
].
inputs
.
append
(
src
)
if
dst
not
in
self
.
node_map
[
src
.
layer_name
].
outputs
:
self
.
node_map
[
src
.
layer_name
].
outputs
.
append
(
dst
)
tensorflow2fluid/src/name_generator.py
0 → 100644
浏览文件 @
7a96c492
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
NameGenerator
(
object
):
def
__init__
(
self
):
self
.
param_index
=
0
self
.
input_index
=
0
self
.
net_index
=
0
self
.
const_index
=
0
self
.
names
=
dict
()
def
get_name
(
self
,
node
):
ref_name
=
None
op_name
=
node
.
layer_type
if
node
.
layer
.
name
in
self
.
names
:
return
self
.
names
[
node
.
layer
.
name
]
if
op_name
==
"variablev2"
:
ref_name
=
"param_"
+
str
(
self
.
param_index
)
self
.
param_index
+=
1
elif
op_name
==
"placeholder"
:
ref_name
=
"input_"
+
str
(
self
.
input_index
)
self
.
input_index
+=
1
elif
op_name
==
"const"
:
ref_name
=
"const_"
+
str
(
self
.
const_index
)
self
.
const_index
+=
1
elif
op_name
.
lower
()
==
"identity"
:
ref_name
=
self
.
names
[
node
.
layer
.
input
[
0
]]
else
:
ref_name
=
"net_"
+
str
(
self
.
net_index
)
self
.
net_index
+=
1
self
.
names
[
node
.
layer
.
name
]
=
ref_name
return
ref_name
tensorflow2fluid/src/paddle_emitter.py
0 → 100644
浏览文件 @
7a96c492
此差异已折叠。
点击以展开。
tensorflow2fluid/src/tensorflow_graph.py
0 → 100644
浏览文件 @
7a96c492
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
graph
import
GraphNode
,
Graph
from
tensorflow.core.framework
import
attr_value_pb2
class
TensorflowGraphNode
(
GraphNode
):
def
__init__
(
self
,
layer
):
super
(
TensorflowGraphNode
,
self
).
__init__
(
layer
)
self
.
codes
=
list
()
self
.
data_format
=
'NCHW'
@
property
def
layer_type
(
self
):
return
self
.
layer
.
op
.
lower
()
@
property
def
layer_name
(
self
):
return
self
.
layer
.
name
def
get_attr
(
self
,
name
,
default_value
=
None
):
if
name
in
self
.
layer
.
attr
:
attr
=
self
.
layer
.
attr
[
name
]
field
=
attr
.
WhichOneof
(
'value'
)
val
=
getattr
(
attr
,
field
)
if
field
else
default_value
if
isinstance
(
val
,
attr_value_pb2
.
AttrValue
.
ListValue
):
return
list
(
val
.
ListFields
()[
0
][
1
])
else
:
return
val
.
decode
(
'utf-8'
)
if
isinstance
(
val
,
bytes
)
else
val
else
:
return
default_value
class
TensorflowGraph
(
Graph
):
def
__init__
(
self
,
tf_graph
):
super
(
TensorflowGraph
,
self
).
__init__
(
tf_graph
)
self
.
tf_graph
=
tf_graph
def
build
(
self
):
skip_node
=
set
([
'const'
])
for
i
,
layer
in
enumerate
(
self
.
tf_graph
.
node
):
self
.
node_map
[
layer
.
name
]
=
TensorflowGraphNode
(
layer
)
for
i
,
layer
in
enumerate
(
self
.
tf_graph
.
node
):
if
layer
.
op
.
lower
()
in
skip_node
:
continue
for
pred
in
layer
.
input
:
if
pred
not
in
self
.
node_map
and
pred
.
split
(
':'
)[
0
]
in
self
.
node_map
:
node
=
self
.
node_map
[
pred
.
split
(
':'
)[
0
]]
if
node
.
layer_type
==
"switch"
:
self
.
_make_connection
(
node
,
self
.
node_map
[
layer
.
name
])
else
:
raise
Exception
(
"Need to fix here"
)
elif
pred
in
self
.
node_map
:
self
.
_make_connection
(
self
.
node_map
[
pred
],
self
.
node_map
[
layer
.
name
])
else
:
raise
Exception
(
"input: {} not in node_map"
.
format
(
pred
))
super
(
TensorflowGraph
,
self
).
build
()
self
.
_remove_useless_nodes
()
self
.
_check_dataformat
()
def
_check_dataformat
(
self
):
ss
=
list
()
for
i
in
range
(
0
,
len
(
self
.
topological_sort
)):
current_node
=
self
.
node_map
[
self
.
topological_sort
[
i
]]
if
'data_format'
in
current_node
.
layer
.
attr
:
s
=
current_node
.
layer
.
attr
[
'data_format'
].
s
if
s
!=
'NHWC'
and
s
!=
'NCHW'
:
raise
Exception
(
'Unkown dataformat {}'
.
format
(
s
))
ss
.
append
(
s
)
if
len
(
set
(
ss
))
>
1
:
raise
Exception
(
"Two type of dataformat exist in this model"
)
if
len
(
set
(
ss
))
==
0
:
return
for
k
,
v
in
self
.
node_map
.
items
():
self
.
node_map
[
k
].
data_format
=
ss
[
0
]
def
_remove_useless_nodes
(
self
):
useless_type
=
set
(
[
'identity'
,
'placeholderwithdefault'
,
'switch'
,
'merge'
])
remove_index
=
list
()
for
i
in
range
(
0
,
len
(
self
.
topological_sort
)):
name
=
self
.
topological_sort
[
i
]
current_node
=
self
.
node_map
[
name
]
if
current_node
.
layer_type
in
useless_type
:
input
=
current_node
.
inputs
[
0
]
for
node
in
current_node
.
outputs
:
for
k
in
range
(
0
,
len
(
node
.
inputs
)):
if
node
.
inputs
[
k
]
==
current_node
:
node
.
inputs
[
k
]
=
input
if
node
not
in
input
.
outputs
:
input
.
outputs
.
append
(
node
)
input
.
outputs
.
remove
(
current_node
)
del
self
.
node_map
[
name
]
if
name
in
self
.
output_nodes
:
self
.
output_nodes
.
remove
(
name
)
if
name
in
self
.
input_nodes
:
self
.
input_nodes
.
remove
(
name
)
remove_index
.
append
(
i
)
remove_index
.
sort
(
reverse
=
True
)
for
i
in
range
(
0
,
len
(
remove_index
)):
del
self
.
topological_sort
[
remove_index
[
i
]]
tensorflow2fluid/src/tensorflow_parser.py
0 → 100644
浏览文件 @
7a96c492
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
tensorflow
from
tensorflow_graph
import
TensorflowGraph
from
tensorflow.python.framework
import
tensor_util
from
tensorflow.python.tools
import
strip_unused_lib
from
tensorflow.python.framework
import
dtypes
class
TensorflowCkptParser
(
object
):
def
__init__
(
self
,
meta_file
,
checkpoint_file
,
dest_nodes
,
input_shape
=
None
,
in_nodes
=
None
):
graph_def
=
None
self
.
weights
=
None
with
tensorflow
.
Session
()
as
sess
:
if
meta_file
is
None
:
raise
Exception
(
"meta_file must be provided"
)
new_saver
=
tensorflow
.
train
.
import_meta_graph
(
meta_file
)
if
checkpoint_file
is
not
None
:
self
.
weights
=
dict
()
new_saver
.
restore
(
sess
,
tensorflow
.
train
.
latest_checkpoint
(
checkpoint_file
))
for
var
in
tensorflow
.
global_variables
():
value
=
var
.
eval
(
sess
)
self
.
weights
[
var
.
name
.
split
(
':'
)[
0
]]
=
value
graph_def
,
ver
=
tensorflow
.
get_default_graph
().
_as_graph_def
(
add_shapes
=
True
)
if
in_nodes
is
not
None
and
input_shape
is
not
None
:
graph_def
=
strip_unused_lib
.
strip_unused
(
input_graph_def
=
graph_def
,
input_node_names
=
in_nodes
,
output_node_names
=
dest_nodes
,
placeholder_type_enum
=
dtypes
.
float32
.
as_datatype_enum
)
self
.
tf_graph
=
TensorflowGraph
(
graph_def
)
else
:
raise
Exception
(
'in_nodes and output_nodes need be provided'
)
self
.
tf_graph
.
build
()
class
TensorflowPbParser
(
object
):
def
__init__
(
self
,
pb_file
,
dest_nodes
,
input_shape
=
None
,
in_nodes
=
None
):
with
open
(
pb_file
)
as
f
:
serialized
=
f
.
read
()
tensorflow
.
reset_default_graph
()
original_graph_def
=
tensorflow
.
GraphDef
()
original_graph_def
.
ParseFromString
(
serialized
)
original_graph_def
=
strip_unused_lib
.
strip_unused
(
input_graph_def
=
original_graph_def
,
input_node_names
=
in_nodes
,
output_node_names
=
dest_nodes
,
placeholder_type_enum
=
dtypes
.
float32
.
as_datatype_enum
)
graph_def
=
tensorflow
.
GraphDef
()
graph_def
.
ParseFromString
(
original_graph_def
.
SerializeToString
())
in_type_list
=
dict
()
for
node
in
graph_def
.
node
:
if
node
.
name
in
in_nodes
:
in_type_list
[
node
.
name
]
=
node
.
attr
[
'dtype'
].
type
input_shape
=
list
(
input_shape
)
if
not
isinstance
(
input_shape
[
0
],
list
):
input_shape
=
[
input_shape
]
input_map
=
dict
()
for
i
in
range
(
len
(
input_shape
)):
if
in_type_list
[
in_nodes
[
i
]]
==
1
or
in_type_list
[
in_nodes
[
i
]]
==
0
:
dtype
=
tensorflow
.
float32
x
=
tensorflow
.
placeholder
(
dtype
,
shape
=
input_shape
[
i
])
elif
in_type_list
[
in_nodes
[
i
]]
==
3
:
dtype
=
tensorflow
.
int32
x
=
tensorflow
.
placehoder
(
dtype
,
shape
=
input_shape
[
i
])
else
:
raise
Exception
(
"Unexpected dtype for input, only support float32 and int32 now"
)
input_map
[
in_nodes
[
i
]
+
":0"
]
=
x
tensorflow
.
import_graph_def
(
graph_def
,
name
=
""
,
input_map
=
input_map
)
graph_def
=
tensorflow
.
get_default_graph
().
_as_graph_def
(
add_shapes
=
True
)[
0
]
node
=
graph_def
.
node
[
0
]
self
.
tf_graph
=
TensorflowGraph
(
graph_def
)
self
.
tf_graph
.
build
()
self
.
weights
=
dict
()
for
node
in
graph_def
.
node
:
if
node
.
op
.
lower
()
==
"const"
:
try
:
node
.
attr
[
'value'
].
tensor
.
tensor_content
weight
=
tensor_util
.
MakeNdarray
(
node
.
attr
[
'value'
].
tensor
)
self
.
weights
[
node
.
name
]
=
weight
except
:
continue
tensorflow2fluid/src/transformer.py
0 → 100644
浏览文件 @
7a96c492
from
paddle_emitter
import
PaddleEmitter
from
tensorflow_parser
import
TensorflowCkptParser
from
tensorflow_parser
import
TensorflowPbParser
class
Transformer
(
object
):
def
__init__
(
self
,
meta_file
,
ckpt_file
,
out_nodes
,
in_shape
,
in_nodes
,
save_dir
):
self
.
parser
=
TensorflowCkptParser
(
meta_file
,
ckpt_file
,
out_nodes
,
in_shape
,
in_nodes
)
self
.
emitter
=
PaddleEmitter
(
self
.
parser
,
save_dir
)
def
transform_code
(
self
):
codes
=
self
.
emitter
.
run
()
def
run
(
self
):
self
.
transform_code
()
class
PbTransformer
(
object
):
def
__init__
(
self
,
pb_file
,
out_nodes
,
in_shape
,
in_nodes
,
save_dir
):
self
.
parser
=
TensorflowPbParser
(
pb_file
,
out_nodes
,
in_shape
,
in_nodes
)
self
.
emitter
=
PaddleEmitter
(
self
.
parser
,
save_dir
)
node
=
self
.
parser
.
tf_graph
.
tf_graph
.
node
[
0
]
def
transform_code
(
self
):
codes
=
self
.
emitter
.
run
()
def
run
(
self
):
self
.
transform_code
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录