Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
ec632e66
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ec632e66
编写于
7月 16, 2019
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test code
上级
b444f36d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
44 addition
and
17 deletion
+44
-17
x2paddle/convert.py
x2paddle/convert.py
+5
-0
x2paddle/core/graph.py
x2paddle/core/graph.py
+5
-8
x2paddle/parser/tf_parser.py
x2paddle/parser/tf_parser.py
+34
-9
未找到文件。
x2paddle/convert.py
浏览文件 @
ec632e66
...
@@ -11,3 +11,8 @@
...
@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
x2paddle.parser.tf_parser
import
TFParser
parser
=
TFParser
(
'/ssd2/Jason/github/X2Paddle/x2paddle/tests/frozen_darknet_yolov3_model.pb'
,
in_nodes
=
[
'inputs'
],
out_nodes
=
[
'output_boxes'
],
in_shapes
=
[[
-
1
,
416
,
416
,
3
]])
x2paddle/core/graph.py
浏览文件 @
ec632e66
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
utils
import
*
import
collections
import
collections
...
@@ -44,7 +43,7 @@ class Graph(object):
...
@@ -44,7 +43,7 @@ class Graph(object):
self
.
topo_sort
=
list
()
self
.
topo_sort
=
list
()
self
.
model
=
model
self
.
model
=
model
def
build
(
self
,
input_format
):
def
build
(
self
):
self
.
_make_input_nodes
()
self
.
_make_input_nodes
()
self
.
_make_output_nodes
()
self
.
_make_output_nodes
()
self
.
_get_topo_sort
()
self
.
_get_topo_sort
()
...
@@ -65,7 +64,7 @@ class Graph(object):
...
@@ -65,7 +64,7 @@ class Graph(object):
num_inputs
[
name
]
=
len
(
node
.
inputs
)
num_inputs
[
name
]
=
len
(
node
.
inputs
)
self
.
topo_sort
=
self
.
input_nodes
[:]
self
.
topo_sort
=
self
.
input_nodes
[:]
while
idx
in
range
(
len
(
self
.
topo_sort
)):
for
idx
in
range
(
len
(
self
.
topo_sort
)):
current_node
=
self
.
node_map
[
self
.
topo_sort
[
idx
]]
current_node
=
self
.
node_map
[
self
.
topo_sort
[
idx
]]
for
node
in
current_node
.
outputs
:
for
node
in
current_node
.
outputs
:
num_inputs
[
node
.
layer_name
]
-=
1
num_inputs
[
node
.
layer_name
]
-=
1
...
@@ -79,8 +78,6 @@ class Graph(object):
...
@@ -79,8 +78,6 @@ class Graph(object):
return
self
.
node_map
[
name
]
return
self
.
node_map
[
name
]
def
connect
(
self
,
src
,
dst
):
def
connect
(
self
,
src
,
dst
):
if
src
.
layer_name
==
dst
.
layer_name
or
src
.
layer_name
not
in
\
if
dst
not
in
self
.
node_map
:
self
.
node_map
or
dst
.
layer_name
not
in
self
.
node_map
:
raise
Exception
(
"node[{}] not in graph"
.
format
(
dst
))
raise
Exception
(
'Warning: Node not exist or there is a self-loop'
)
self
.
node_map
[
dst
].
inputs
.
append
(
src
)
self
.
node_map
[
dst
.
layer_name
].
inputs
.
append
(
src
)
self
.
node_map
[
src
.
layer_name
].
outputs
.
append
(
dst
)
x2paddle/parser/tf_parser.py
浏览文件 @
ec632e66
...
@@ -13,18 +13,40 @@
...
@@ -13,18 +13,40 @@
# limitations under the License.
# limitations under the License.
from
x2paddle.core.graph
import
GraphNode
,
Graph
from
x2paddle.core.graph
import
GraphNode
,
Graph
from
tensorflow.python.platform
import
gfile
import
tensorflow
as
tf
import
copy
class
TFGraphNode
(
GraphNode
):
class
TFGraphNode
(
GraphNode
):
def
__init__
(
self
,
layer
,
layer_name
=
None
):
def
__init__
(
self
,
layer
,
layer_name
=
None
):
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
layer_name
)
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
layer_name
)
self
.
layer_type
=
layer
.
op
.
lower
()
self
.
layer_type
=
layer
.
op
class
TFGraph
(
Graph
):
class
TFGraph
(
Graph
):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
super
(
TFGraph
,
self
).
__init__
(
model
)
super
(
TFGraph
,
self
).
__init__
(
model
)
self
.
multi_output_ops
=
[
'Split'
,
'Unpack'
]
def
build
(
self
):
for
layer
in
self
.
model
.
node
:
self
.
node_map
[
layer
.
name
]
=
TFGraphNode
(
layer
)
for
layer_name
,
node
in
self
.
node_map
.
items
():
for
in_node
in
node
.
layer
.
input
:
if
in_node
not
in
self
.
node_map
:
if
in_node
.
strip
().
split
(
':'
)[
0
]
in
self
.
node_map
:
self
.
connect
(
in_node
,
layer_name
)
else
:
raise
Exception
(
'input[{}] of node[{}] does not exist in node_map'
.
format
(
in_node
,
layer_name
))
else
:
if
self
.
node_map
[
in_node
].
layer_type
in
self
.
multi_output_ops
:
in_node
+=
":0"
self
.
connect
(
in_node
,
layer_name
)
super
(
TFGraph
,
self
).
build
()
class
TFParser
(
object
):
class
TFParser
(
object
):
def
__init__
(
self
,
pb_model
,
in_nodes
=
None
,
out_nodes
=
None
,
in_shapes
=
None
):
def
__init__
(
self
,
pb_model
,
in_nodes
=
None
,
out_nodes
=
None
,
in_shapes
=
None
):
...
@@ -33,11 +55,14 @@ class TFParser(object):
...
@@ -33,11 +55,14 @@ class TFParser(object):
assert
in_shapes
is
not
None
,
"in_shapes should not be None"
assert
in_shapes
is
not
None
,
"in_shapes should not be None"
assert
len
(
in_shapes
)
==
len
(
in_nodes
),
"length of in_shapes and in_nodes should be equal"
assert
len
(
in_shapes
)
==
len
(
in_nodes
),
"length of in_shapes and in_nodes should be equal"
serialized_str
=
open
(
pb_model
,
'rb'
).
read
()
sess
=
tf
.
Session
()
tf
.
reset_default_graph
()
with
gfile
.
FastGFile
(
pb_model
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
serialized_str
)
graph_def
.
ParseFromString
(
f
.
read
())
sess
.
graph
.
as_default
()
sess
=
tf
.
Session
(
graph
=
tf
.
get_default_graph
())
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
global_variables_initializer
())
self
.
tf_graph
=
TFGraph
(
sess
.
graph
.
_as_graph_def
(
add_shapes
=
True
)[
0
])
self
.
tf_graph
.
build
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录