Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
33166c4e
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
33166c4e
编写于
1月 13, 2019
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
first commit
上级
df8bfe33
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
2022 addition
and
0 deletion
+2022
-0
TensorFlow2Paddle/README.md
TensorFlow2Paddle/README.md
+1
-0
TensorFlow2Paddle/framework_pb2.py
TensorFlow2Paddle/framework_pb2.py
+1165
-0
TensorFlow2Paddle/graph.py
TensorFlow2Paddle/graph.py
+82
-0
TensorFlow2Paddle/name_generator.py
TensorFlow2Paddle/name_generator.py
+46
-0
TensorFlow2Paddle/paddle_emitter.py
TensorFlow2Paddle/paddle_emitter.py
+548
-0
TensorFlow2Paddle/tensorflow_graph.py
TensorFlow2Paddle/tensorflow_graph.py
+83
-0
TensorFlow2Paddle/tensorflow_parser.py
TensorFlow2Paddle/tensorflow_parser.py
+66
-0
TensorFlow2Paddle/transformer.py
TensorFlow2Paddle/transformer.py
+31
-0
未找到文件。
TensorFlow2Paddle/README.md
0 → 100644
浏览文件 @
33166c4e
Warning: TensorFlow2Paddle is not stable yet
TensorFlow2Paddle/framework_pb2.py
0 → 100644
浏览文件 @
33166c4e
此差异已折叠。
点击以展开。
TensorFlow2Paddle/graph.py
0 → 100644
浏览文件 @
33166c4e
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
name_generator
import
NameGenerator
class
GraphNode
(
object
):
def
__init__
(
self
,
layer
):
self
.
in_edges
=
list
()
self
.
out_edges
=
list
()
self
.
layer
=
layer
self
.
ref_name
=
None
class
Graph
(
object
):
def
__init__
(
self
,
model
):
self
.
node_map
=
dict
()
self
.
input_nodes
=
list
()
self
.
output_nodes
=
list
()
self
.
topological_sort
=
list
()
self
.
model
=
model
self
.
name_generator
=
NameGenerator
()
def
build
(
self
):
self
.
_make_input_nodes
()
self
.
_make_output_nodes
()
self
.
_get_topological_sort
()
self
.
_gen_newname_for_nodes
()
def
_make_input_nodes
(
self
):
for
name
,
node
in
self
.
node_map
.
items
():
node
.
left_in_edges
=
len
(
node
.
in_edges
)
if
len
(
node
.
in_edges
)
==
0
:
self
.
input_nodes
.
append
(
name
)
def
_make_output_nodes
(
self
):
for
name
,
node
in
self
.
node_map
.
items
():
if
len
(
node
.
out_edges
)
==
0
:
self
.
output_nodes
.
append
(
name
)
def
_get_topological_sort
(
self
):
self
.
topological_sort
=
self
.
input_nodes
[:]
idx
=
0
while
idx
<
len
(
self
.
topological_sort
):
current_node
=
self
.
node_map
[
self
.
topological_sort
[
idx
]]
for
next_node
in
current_node
.
out_edges
:
next_node_info
=
self
.
node_map
[
next_node
]
next_node_info
.
left_in_edges
-=
1
if
next_node_info
.
left_in_edges
==
0
:
self
.
topological_sort
.
append
(
next_node
)
idx
+=
1
def
_gen_newname_for_nodes
(
self
):
for
node_name
in
self
.
topological_sort
:
node
=
self
.
node_map
[
node_name
]
ref_name
=
self
.
name_generator
.
get_name
(
node
)
self
.
node_map
[
node
.
layer
.
name
].
ref_name
=
ref_name
def
get_node
(
self
,
name
):
if
name
not
in
self
.
node_map
:
raise
Exception
(
"Graph doesn't have node [%s]."
%
name
)
else
:
return
self
.
node_map
[
name
]
def
_make_connection
(
self
,
src
,
dst
):
if
src
==
dst
or
src
not
in
self
.
node_map
or
dst
not
in
self
.
node_map
:
raise
Exception
(
'Warning: Node not exist or there is a self-loop'
)
if
src
not
in
self
.
node_map
[
dst
].
in_edges
:
self
.
node_map
[
dst
].
in_edges
.
append
(
src
)
if
dst
not
in
self
.
node_map
[
src
].
out_edges
:
self
.
node_map
[
src
].
out_edges
.
append
(
dst
)
TensorFlow2Paddle/name_generator.py
0 → 100644
浏览文件 @
33166c4e
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
NameGenerator
(
object
):
def
__init__
(
self
):
self
.
param_index
=
0
self
.
input_index
=
0
self
.
net_index
=
0
self
.
const_index
=
0
self
.
names
=
dict
()
def
get_name
(
self
,
node
):
ref_name
=
None
op_name
=
node
.
type
if
node
.
layer
.
name
in
self
.
names
:
return
self
.
names
[
node
.
layer
.
name
]
if
op_name
==
"variablev2"
:
ref_name
=
"param_"
+
str
(
self
.
param_index
)
self
.
param_index
+=
1
elif
op_name
==
"placeholder"
:
ref_name
=
"input_"
+
str
(
self
.
input_index
)
self
.
input_index
+=
1
elif
op_name
==
"const"
:
ref_name
=
"const_"
+
str
(
self
.
const_index
)
self
.
const_index
+=
1
elif
op_name
.
lower
()
==
"identity"
:
ref_name
=
self
.
names
[
node
.
layer
.
input
[
0
]]
else
:
ref_name
=
"net_"
+
str
(
self
.
net_index
)
self
.
net_index
+=
1
self
.
names
[
node
.
layer
.
name
]
=
ref_name
return
ref_name
TensorFlow2Paddle/paddle_emitter.py
0 → 100644
浏览文件 @
33166c4e
此差异已折叠。
点击以展开。
TensorFlow2Paddle/tensorflow_graph.py
0 → 100644
浏览文件 @
33166c4e
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
graph
import
GraphNode
,
Graph
from
tensorflow.core.framework
import
attr_value_pb2
class
TensorflowGraphNode
(
GraphNode
):
def
__init__
(
self
,
layer
):
super
(
TensorflowGraphNode
,
self
).
__init__
(
layer
)
self
.
codes
=
list
()
self
.
dataformat
=
'NCHW'
@
property
def
type
(
self
):
return
self
.
layer
.
op
.
lower
()
@
property
def
name
(
self
):
return
self
.
layer
.
name
# TODO
def
get_attr
(
self
,
name
,
default_value
=
None
):
if
name
in
self
.
layer
.
attr
:
attr
=
self
.
layer
.
attr
[
name
]
field
=
attr
.
WhichOneof
(
'value'
)
val
=
getattr
(
attr
,
field
)
if
field
else
default_value
if
isinstance
(
val
,
attr_value_pb2
.
AttrValue
.
ListValue
):
return
list
(
val
.
ListFields
()[
0
][
1
])
else
:
return
val
.
decode
(
'utf-8'
)
if
isinstance
(
val
,
bytes
)
else
val
else
:
return
default_value
class
TensorflowGraph
(
Graph
):
def
__init__
(
self
,
model
):
super
(
TensorflowGraph
,
self
).
__init__
(
model
)
self
.
model
=
model
def
build
(
self
):
for
i
,
layer
in
enumerate
(
self
.
model
.
node
):
self
.
node_map
[
layer
.
name
]
=
TensorflowGraphNode
(
layer
)
for
pred
in
layer
.
input
:
if
pred
not
in
self
.
node_map
:
raise
Exception
(
'input: {} not in node_map'
.
format
(
pred
))
self
.
_make_connection
(
pred
,
layer
.
name
)
super
(
TensorflowGraph
,
self
).
build
()
self
.
_check_dataformat
()
# check the dataformat of network
def
_check_dataformat
(
self
):
ss
=
list
()
for
i
in
range
(
0
,
len
(
self
.
topological_sort
)):
current_node
=
self
.
node_map
[
self
.
topological_sort
[
i
]]
if
current_node
.
type
==
'conv2d'
:
s
=
current_node
.
layer
.
attr
[
'data_format'
].
s
if
s
!=
'NHWC'
and
s
!=
'NCHW'
:
raise
Exception
(
'Unkown dataformat {}'
.
format
(
s
))
ss
.
append
(
s
)
if
len
(
set
(
ss
))
>
1
:
raise
Exception
(
"Two type of dataformat exist in this model"
)
if
len
(
set
(
ss
))
==
0
:
return
for
i
in
range
(
0
,
len
(
self
.
topological_sort
)):
current_node
=
self
.
node_map
[
self
.
topological_sort
[
i
]]
current_node
.
dataformat
=
ss
[
0
]
TensorFlow2Paddle/tensorflow_parser.py
0 → 100644
浏览文件 @
33166c4e
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
tensorflow
from
tensorflow_graph
import
TensorflowGraph
class
TensorflowParser
(
object
):
def
__init__
(
self
,
meta_file
,
checkpoint_file
,
dest_nodes
,
input_shape
=
None
,
in_nodes
=
None
):
graph_def
=
None
self
.
weights
=
dict
()
with
tensorflow
.
Session
()
as
sess
:
if
meta_file
is
None
:
raise
Exception
(
"meta_file must be provided"
)
new_saver
=
tensorflow
.
train
.
import_meta_graph
(
meta_file
)
if
checkpoint_file
is
not
None
:
new_saver
.
restore
(
sess
,
tensorflow
.
train
.
latest_checkpoint
(
checkpoint_file
))
for
var
in
tensorflow
.
global_variables
():
value
=
var
.
eval
(
sess
)
self
.
weights
[
var
.
name
.
split
(
':'
)[
0
]]
=
value
graph_def
,
ver
=
tensorflow
.
get_default_graph
().
_as_graph_def
(
add_shapes
=
True
)
if
in_nodes
is
not
None
and
input_shape
is
not
None
:
from
tensorflow.python.tools
import
strip_unused_lib
from
tensorflow.python.framework
import
dtypes
graph_def
=
strip_unused_lib
.
strip_unused
(
input_graph_def
=
graph_def
,
input_node_names
=
in_nodes
,
output_node_names
=
dest_nodes
,
placeholder_type_enum
=
dtypes
.
float32
.
as_datatype_enum
)
input_list
=
[
None
]
for
i
in
range
(
len
(
input_shape
)):
input_list
.
append
(
tensorflow
.
Dimension
(
input_shape
[
i
]))
tensor_input
=
tensorflow
.
TensorShape
(
input_list
)
self
.
tf_graph
=
TensorflowGraph
(
graph_def
)
for
node
in
self
.
tf_graph
.
model
.
node
:
if
node
.
name
in
in_nodes
:
node
.
attr
[
'shape'
].
list
.
shape
.
extend
(
[
tensor_input
.
as_proto
()])
node
.
attr
[
'_output_shapes'
].
list
.
shape
.
pop
()
node
.
attr
[
'_output_shapes'
].
list
.
shape
.
extend
(
[
tensor_input
.
as_proto
()])
else
:
raise
Exception
(
'in_nodes and output_nodes need be provided'
)
self
.
tf_graph
.
build
()
TensorFlow2Paddle/transformer.py
0 → 100644
浏览文件 @
33166c4e
from
paddle_emitter
import
PaddleEmitter
from
tensorflow_parser
import
TensorflowParser
class
Transformer
(
object
):
def
__init__
(
self
,
meta_file
,
ckpt_file
,
out_nodes
,
in_shape
,
in_nodes
):
self
.
parser
=
TensorflowParser
(
meta_file
,
ckpt_file
,
out_nodes
,
in_shape
,
in_nodes
)
self
.
emitter
=
PaddleEmitter
(
self
.
parser
.
tf_graph
)
def
transform_code
(
self
,
out_py_file
):
filew
=
open
(
out_py_file
,
'w'
)
codes
=
self
.
emitter
.
gen_code
()
filew
.
write
(
codes
)
filew
.
close
()
def
transform_weight
(
self
,
out_dir
):
self
.
emitter
.
gen_weight
(
self
.
parser
.
weights
,
out_dir
)
def
run
(
self
,
dst_dir
):
import
os
if
os
.
path
.
isdir
(
dst_dir
)
or
os
.
path
.
isfile
(
dst_dir
):
print
(
"{} already exists, set a new directory"
)
return
if
not
os
.
path
.
isdir
(
dst_dir
):
os
.
mkdir
(
dst_dir
)
self
.
transform_code
(
dst_dir
+
"/mymodel.py"
)
if
(
len
(
self
.
parser
.
weights
)
==
0
):
print
(
"There is no tensorflow model weight translate to paddle"
)
else
:
self
.
transform_weight
(
dst_dir
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录