Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
99582ade
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
99582ade
编写于
7月 18, 2019
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
common code generate and weight dump
上级
1f79d43d
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
96 addition
and
30 deletion
+96
-30
x2paddle/convert.py
x2paddle/convert.py
+64
-8
x2paddle/core/emitter.py
x2paddle/core/emitter.py
+15
-2
x2paddle/core/fluid_code.py
x2paddle/core/fluid_code.py
+1
-0
x2paddle/core/util.py
x2paddle/core/util.py
+0
-1
x2paddle/emitter/tf_emitter.py
x2paddle/emitter/tf_emitter.py
+4
-11
x2paddle/parser/tf_parser.py
x2paddle/parser/tf_parser.py
+12
-8
未找到文件。
x2paddle/convert.py
浏览文件 @
99582ade
...
@@ -14,14 +14,70 @@
...
@@ -14,14 +14,70 @@
from
x2paddle.parser.tf_parser
import
TFParser
from
x2paddle.parser.tf_parser
import
TFParser
from
x2paddle.optimizer.tf_optimizer
import
TFGraphOptimizer
from
x2paddle.optimizer.tf_optimizer
import
TFGraphOptimizer
from
x2paddle.emitter.tf_emitter
import
TFEmitter
from
x2paddle.emitter.tf_emitter
import
TFEmitter
from
six
import
text_type
as
_text_type
import
argparse
parser
=
TFParser
(
'/ssd2/Jason/github/X2Paddle/tool/vgg16_None.pb'
,
in_nodes
=
[
'inputs'
],
out_nodes
=
[
'output_boxes'
],
in_shapes
=
[[
-
1
,
416
,
416
,
3
]])
optimizer
=
TFGraphOptimizer
()
def
arg_parser
():
#parser.tf_graph.print()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model"
,
"-m"
,
type
=
_text_type
,
default
=
None
,
help
=
"model file path"
)
parser
.
add_argument
(
"--proto"
,
"-p"
,
type
=
_text_type
,
default
=
None
,
help
=
"proto file of caffe model"
)
parser
.
add_argument
(
"--weight"
,
"-w"
,
type
=
_text_type
,
default
=
None
,
help
=
"weight file of caffe model"
)
parser
.
add_argument
(
"--save_dir"
,
"-s"
,
type
=
_text_type
,
default
=
None
,
help
=
"path to save translated model"
)
parser
.
add_argument
(
"--framework"
,
"-f"
,
type
=
_text_type
,
default
=
None
,
help
=
"define which deeplearning framework"
)
return
parser
emitter
=
TFEmitter
(
parser
)
emitter
.
run
()
def
tf2paddle
(
model
,
save_dir
):
print
(
"Now translating model from tensorflow to paddle."
)
parser
=
TFParser
(
model
)
emitter
=
TFEmitter
(
parser
)
emitter
.
run
()
emitter
.
save_python_model
(
save_dir
)
def
caffe2paddle
(
proto
,
weight
,
save_dir
):
print
(
"Not implement yet."
)
def
main
():
parser
=
arg_parser
()
args
=
parser
.
parse_args
()
assert
args
.
framework
is
not
None
,
"--from is not defined(tensorflow/caffe)"
assert
args
.
save_dir
is
not
None
,
"--save_dir is not defined"
if
args
.
framework
==
"tensorflow"
:
assert
args
.
model
is
not
None
,
"--model should be defined while translate tensorflow model"
tf2paddle
(
args
.
model
,
args
.
save_dir
)
elif
args
.
framework
==
"caffe"
:
assert
args
.
proto
is
not
None
,
"--proto and --weight should be defined while translate caffe model"
caffe2paddle
(
args
.
proto
,
args
.
weight
,
args
.
save_dir
)
else
:
raise
Exception
(
"--framework only support tensorflow/caffe now"
)
if
__name__
==
"__main__"
:
main
()
x2paddle/core/emitter.py
浏览文件 @
99582ade
...
@@ -12,11 +12,16 @@
...
@@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
x2paddle.core.util
import
*
import
os
class
Emitter
(
object
):
class
Emitter
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
paddle_codes
=
""
self
.
paddle_codes
=
""
self
.
tab
=
" "
self
.
tab
=
" "
self
.
net_code
=
list
()
self
.
weights
=
dict
()
def
add_codes
(
self
,
codes
,
indent
=
0
):
def
add_codes
(
self
,
codes
,
indent
=
0
):
if
isinstance
(
codes
,
list
):
if
isinstance
(
codes
,
list
):
...
@@ -28,11 +33,19 @@ class Emitter(object):
...
@@ -28,11 +33,19 @@ class Emitter(object):
raise
Exception
(
"Unknown type of codes"
)
raise
Exception
(
"Unknown type of codes"
)
def
add_heads
(
self
):
def
add_heads
(
self
):
self
.
add_codes
(
"from paddle.fluid.initializer import Constant"
)
self
.
add_codes
(
"from paddle.fluid.param_attr import ParamAttr"
)
self
.
add_codes
(
"import paddle.fluid as fluid"
)
self
.
add_codes
(
"import paddle.fluid as fluid"
)
self
.
add_codes
(
""
)
self
.
add_codes
(
""
)
def
save_inference_model
(
self
):
def
save_inference_model
(
self
):
print
(
"Not Implement"
)
print
(
"Not Implement"
)
def
save_python_code
(
self
):
def
save_python_model
(
self
,
save_dir
):
print
(
"Not Implement"
)
for
name
,
param
in
self
.
weights
.
items
():
export_paddle_param
(
param
,
name
,
save_dir
)
self
.
add_heads
()
self
.
add_codes
(
self
.
net_code
)
fp
=
open
(
os
.
path
.
join
(
save_dir
,
"model.py"
),
'w'
)
fp
.
write
(
self
.
paddle_codes
)
fp
.
close
()
x2paddle/core/fluid_code.py
浏览文件 @
99582ade
...
@@ -100,3 +100,4 @@ class FluidCode(object):
...
@@ -100,3 +100,4 @@ class FluidCode(object):
codes
.
append
(
layer
.
get_code
())
codes
.
append
(
layer
.
get_code
())
elif
isinstance
(
layer
,
str
):
elif
isinstance
(
layer
,
str
):
codes
.
append
(
layer
)
codes
.
append
(
layer
)
return
codes
x2paddle/core/util.py
浏览文件 @
99582ade
...
@@ -44,7 +44,6 @@ def export_paddle_param(param, param_name, dir):
...
@@ -44,7 +44,6 @@ def export_paddle_param(param, param_name, dir):
if
len
(
shape
)
==
0
:
if
len
(
shape
)
==
0
:
assert
param
.
size
==
1
,
"Unexpected situation happend!"
assert
param
.
size
==
1
,
"Unexpected situation happend!"
shape
=
[
1
]
shape
=
[
1
]
print
(
"param dtype:"
,
param
.
dtype
)
assert
str
(
param
.
dtype
)
in
dtype_map
,
"Unknown dtype of params."
assert
str
(
param
.
dtype
)
in
dtype_map
,
"Unknown dtype of params."
fp
=
open
(
os
.
path
.
join
(
dir
,
param_name
),
'wb'
)
fp
=
open
(
os
.
path
.
join
(
dir
,
param_name
),
'wb'
)
...
...
x2paddle/emitter/tf_emitter.py
浏览文件 @
99582ade
...
@@ -28,7 +28,6 @@ class TFEmitter(Emitter):
...
@@ -28,7 +28,6 @@ class TFEmitter(Emitter):
# only for define attribute of op
# only for define attribute of op
self
.
attr_node
=
list
()
self
.
attr_node
=
list
()
self
.
omit_nodes
=
list
()
self
.
omit_nodes
=
list
()
self
.
weights
=
dict
()
def
run
(
self
):
def
run
(
self
):
print
(
"Total nodes: {}"
.
format
(
len
(
self
.
graph
.
topo_sort
)))
print
(
"Total nodes: {}"
.
format
(
len
(
self
.
graph
.
topo_sort
)))
...
@@ -44,13 +43,7 @@ class TFEmitter(Emitter):
...
@@ -44,13 +43,7 @@ class TFEmitter(Emitter):
if
node_name
in
self
.
omit_nodes
:
if
node_name
in
self
.
omit_nodes
:
continue
continue
node
=
self
.
graph
.
get_node
(
node_name
)
node
=
self
.
graph
.
get_node
(
node_name
)
for
layer
in
node
.
fluid_code
.
layers
:
self
.
net_code
+=
node
.
fluid_code
.
gen_codes
()
print
(
layer
.
get_code
())
for
name
,
param
in
self
.
weights
.
items
():
node
=
self
.
graph
.
get_node
(
name
)
export_paddle_param
(
param
,
node
.
layer_name
.
replace
(
'/'
,
'_'
),
"params1"
)
def
Placeholder
(
self
,
node
):
def
Placeholder
(
self
,
node
):
shape
=
node
.
out_shapes
[
0
]
shape
=
node
.
out_shapes
[
0
]
...
@@ -85,13 +78,13 @@ class TFEmitter(Emitter):
...
@@ -85,13 +78,13 @@ class TFEmitter(Emitter):
inputs
=
None
,
inputs
=
None
,
output
=
node
,
output
=
node
,
param_attr
=
attr
)
param_attr
=
attr
)
self
.
weights
[
node
.
layer_name
]
=
node
.
value
self
.
weights
[
node
.
layer_name
.
replace
(
'/'
,
'_'
)
]
=
node
.
value
def
Transpose
(
self
,
node
):
def
Transpose
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
perm
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
perm
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
assert
perm
.
layer_type
==
"Const"
,
"Perm of transpose OP should be Const"
assert
perm
.
layer_type
==
"Const"
,
"Perm of transpose OP should be Const"
del
self
.
weights
[
perm
.
layer_name
]
del
self
.
weights
[
perm
.
layer_name
.
replace
(
'/'
,
'_'
)
]
perm
.
fluid_code
.
clear
()
perm
.
fluid_code
.
clear
()
perm
=
perm
.
value
.
tolist
()
perm
=
perm
.
value
.
tolist
()
...
@@ -204,7 +197,7 @@ class TFEmitter(Emitter):
...
@@ -204,7 +197,7 @@ class TFEmitter(Emitter):
channel_first
=
data_format
==
"NCHW"
channel_first
=
data_format
==
"NCHW"
if
not
channel_first
:
if
not
channel_first
:
self
.
weights
[
kernel
.
layer_name
]
=
numpy
.
transpose
(
self
.
weights
[
kernel
.
layer_name
.
replace
(
'/'
,
'_'
)
]
=
numpy
.
transpose
(
kernel
.
value
,
(
3
,
2
,
0
,
1
))
kernel
.
value
,
(
3
,
2
,
0
,
1
))
attr
=
{
"perm"
:
[
0
,
3
,
1
,
2
]}
attr
=
{
"perm"
:
[
0
,
3
,
1
,
2
]}
node
.
fluid_code
.
add_layer
(
"transpose"
,
node
.
fluid_code
.
add_layer
(
"transpose"
,
...
...
x2paddle/parser/tf_parser.py
浏览文件 @
99582ade
...
@@ -24,9 +24,11 @@ import copy
...
@@ -24,9 +24,11 @@ import copy
class
TFGraphNode
(
GraphNode
):
class
TFGraphNode
(
GraphNode
):
def
__init__
(
self
,
layer
,
layer_name
=
None
):
def
__init__
(
self
,
layer
,
layer_name
=
None
):
if
layer_name
is
None
:
if
layer_name
is
None
:
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
layer
.
name
)
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
layer
.
name
.
replace
(
'/'
,
'_'
))
else
:
else
:
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
layer_name
)
super
(
TFGraphNode
,
self
).
__init__
(
layer
,
layer_name
.
replace
(
'/'
,
'_'
))
self
.
layer_type
=
layer
.
op
self
.
layer_type
=
layer
.
op
self
.
fluid_code
=
FluidCode
()
self
.
fluid_code
=
FluidCode
()
...
@@ -86,10 +88,11 @@ class TFGraph(Graph):
...
@@ -86,10 +88,11 @@ class TFGraph(Graph):
def
build
(
self
):
def
build
(
self
):
for
layer
in
self
.
model
.
node
:
for
layer
in
self
.
model
.
node
:
self
.
node_map
[
layer
.
name
]
=
TFGraphNode
(
layer
)
self
.
node_map
[
layer
.
name
.
replace
(
'/'
,
'_'
)
]
=
TFGraphNode
(
layer
)
for
layer_name
,
node
in
self
.
node_map
.
items
():
for
layer_name
,
node
in
self
.
node_map
.
items
():
for
in_node
in
node
.
layer
.
input
:
for
in_node
in
node
.
layer
.
input
:
in_node
=
in_node
.
replace
(
'/'
,
'_'
)
if
in_node
not
in
self
.
node_map
:
if
in_node
not
in
self
.
node_map
:
if
in_node
.
strip
().
split
(
':'
)[
0
]
in
self
.
node_map
:
if
in_node
.
strip
().
split
(
':'
)[
0
]
in
self
.
node_map
:
self
.
connect
(
in_node
.
strip
().
split
(
':'
)[
0
],
layer_name
)
self
.
connect
(
in_node
.
strip
().
split
(
':'
)[
0
],
layer_name
)
...
@@ -108,6 +111,7 @@ class TFGraph(Graph):
...
@@ -108,6 +111,7 @@ class TFGraph(Graph):
def
get_node
(
self
,
node_name
,
copy
=
False
):
def
get_node
(
self
,
node_name
,
copy
=
False
):
items
=
node_name
.
strip
().
split
(
':'
)
items
=
node_name
.
strip
().
split
(
':'
)
items
[
0
]
=
items
[
0
].
replace
(
'/'
,
'_'
)
if
items
[
0
]
in
self
.
identity_map
:
if
items
[
0
]
in
self
.
identity_map
:
items
[
0
]
=
self
.
identity_map
[
items
[
0
]]
items
[
0
]
=
self
.
identity_map
[
items
[
0
]]
new_node_name
=
":"
.
join
(
items
)
new_node_name
=
":"
.
join
(
items
)
...
@@ -151,11 +155,11 @@ class TFGraph(Graph):
...
@@ -151,11 +155,11 @@ class TFGraph(Graph):
class
TFParser
(
object
):
class
TFParser
(
object
):
def
__init__
(
self
,
pb_model
,
in_nodes
=
None
,
out_nodes
=
None
,
in_shapes
=
None
):
def
__init__
(
self
,
pb_model
,
in_nodes
=
None
,
out_nodes
=
None
,
in_shapes
=
None
):
assert
in_nodes
is
not
None
,
"in_nodes should not be None"
#
assert in_nodes is not None, "in_nodes should not be None"
assert
out_nodes
is
not
None
,
"out_nodes should not be None"
#
assert out_nodes is not None, "out_nodes should not be None"
assert
in_shapes
is
not
None
,
"in_shapes should not be None"
#
assert in_shapes is not None, "in_shapes should not be None"
assert
len
(
in_shapes
)
==
len
(
#
assert len(in_shapes) == len(
in_nodes
),
"length of in_shapes and in_nodes should be equal"
#
in_nodes), "length of in_shapes and in_nodes should be equal"
sess
=
tf
.
Session
()
sess
=
tf
.
Session
()
with
gfile
.
FastGFile
(
pb_model
,
'rb'
)
as
f
:
with
gfile
.
FastGFile
(
pb_model
,
'rb'
)
as
f
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录