Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
32b98f8e
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32b98f8e
编写于
7月 23, 2019
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
mobilenet support for tensorflow
上级
4e3cdf05
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
256 addition
and
19 deletion
+256
-19
x2paddle/core/graph.py
x2paddle/core/graph.py
+4
-3
x2paddle/core/op_mapper.py
x2paddle/core/op_mapper.py
+8
-3
x2paddle/core/util.py
x2paddle/core/util.py
+21
-5
x2paddle/decoder/tf_decoder.py
x2paddle/decoder/tf_decoder.py
+32
-4
x2paddle/op_mapper/tf_op_mapper.py
x2paddle/op_mapper/tf_op_mapper.py
+191
-4
未找到文件。
x2paddle/core/graph.py
浏览文件 @
32b98f8e
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
import
collections
from
copy
import
deepcopy
import
copy
as
cp
class
GraphNode
(
object
):
...
...
@@ -77,7 +77,7 @@ class Graph(object):
if
name
.
split
(
':'
)[
0
]
in
self
.
node_map
:
name_prefix
,
idx
=
name
.
split
(
':'
)
if
copy
:
node
=
deep
copy
(
self
.
node_map
[
name_prefix
])
node
=
cp
.
copy
(
self
.
node_map
[
name_prefix
])
else
:
node
=
self
.
node_map
[
name_prefix
]
node
.
index
=
int
(
idx
)
...
...
@@ -86,7 +86,7 @@ class Graph(object):
raise
Exception
(
"Graph doesn't have node [%s]."
%
name
)
else
:
if
copy
:
node
=
deep
copy
(
self
.
node_map
[
name
])
node
=
cp
.
copy
(
self
.
node_map
[
name
])
else
:
node
=
self
.
node_map
[
name
]
return
node
...
...
@@ -110,6 +110,7 @@ class Graph(object):
del
self
.
node_map
[
input
].
inputs
[
idx
]
del
self
.
node_map
[
node_name
]
print
(
"remove topo"
,
node_name
)
idx
=
self
.
topo_sort
.
index
(
node_name
)
del
self
.
topo_sort
[
idx
]
...
...
x2paddle/core/op_mapper.py
浏览文件 @
32b98f8e
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
from
x2paddle.core.util
import
*
import
inspect
import
os
...
...
@@ -33,7 +34,8 @@ class OpMapper(object):
if
len
(
unsupported_ops
)
==
0
:
return
True
else
:
print
(
"There are {} ops not supported yet, list as below"
)
print
(
"There are {} ops not supported yet, list as below"
.
format
(
len
(
unsupported_ops
)))
for
op
in
unsupported_ops
:
print
(
op
)
return
False
...
...
@@ -41,9 +43,10 @@ class OpMapper(object):
def
add_codes
(
self
,
codes
,
indent
=
0
):
if
isinstance
(
codes
,
list
):
for
code
in
codes
:
self
.
paddle_codes
+=
(
self
.
tab
*
indent
+
code
+
'
\n
'
)
self
.
paddle_codes
+=
(
self
.
tab
*
indent
+
code
.
strip
(
'
\n
'
)
+
'
\n
'
)
elif
isinstance
(
codes
,
str
):
self
.
paddle_codes
+=
(
self
.
tab
*
indent
+
codes
+
'
\n
'
)
self
.
paddle_codes
+=
(
self
.
tab
*
indent
+
codes
.
strip
(
'
\n
'
)
+
'
\n
'
)
else
:
raise
Exception
(
"Unknown type of codes"
)
...
...
@@ -61,6 +64,8 @@ class OpMapper(object):
export_paddle_param
(
param
,
name
,
save_dir
)
self
.
add_heads
()
self
.
add_codes
(
self
.
net_code
)
self
.
add_codes
(
""
)
self
.
add_codes
(
inspect
.
getsourcelines
(
init_net
)[
0
])
fp
=
open
(
os
.
path
.
join
(
save_dir
,
"model.py"
),
'w'
)
fp
.
write
(
self
.
paddle_codes
)
fp
.
close
()
x2paddle/core/util.py
浏览文件 @
32b98f8e
...
...
@@ -13,7 +13,8 @@
# limitations under the License.
from
paddle.fluid.proto
import
framework_pb2
import
struct
import
paddle.fluid
as
fluid
import
numpy
import
math
import
os
...
...
@@ -49,14 +50,29 @@ def export_paddle_param(param, param_name, dir):
os
.
makedirs
(
dir
)
fp
=
open
(
os
.
path
.
join
(
dir
,
param_name
),
'wb'
)
fp
.
write
(
struct
.
pack
(
'i'
,
0
)
)
fp
.
write
(
struct
.
pack
(
'L'
,
0
)
)
fp
.
write
(
struct
.
pack
(
'i'
,
0
)
)
numpy
.
array
([
0
],
dtype
=
'int32'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int64'
).
tofile
(
fp
)
numpy
.
array
([
0
],
dtype
=
'int32'
).
tofile
(
fp
)
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
.
data_type
=
dtype_map
[
str
(
param
.
dtype
)][
0
]
tensor_desc
.
dims
.
extend
(
shape
)
desc_size
=
tensor_desc
.
ByteSize
()
fp
.
write
(
struct
.
pack
(
'i'
,
desc_size
)
)
numpy
.
array
([
desc_size
],
dtype
=
'int32'
).
tofile
(
fp
)
fp
.
write
(
tensor_desc
.
SerializeToString
())
param
.
tofile
(
fp
)
fp
.
close
()
def
init_net
(
param_dir
=
"./"
):
import
os
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
())
def
if_exist
(
var
):
b
=
os
.
path
.
exists
(
os
.
path
.
join
(
param_dir
,
var
.
name
))
return
b
fluid
.
io
.
load_vars
(
exe
,
param_dir
,
fluid
.
default_main_program
(),
predicate
=
if_exist
)
x2paddle/decoder/tf_decoder.py
浏览文件 @
32b98f8e
...
...
@@ -18,7 +18,8 @@ from tensorflow.python.framework import tensor_util
from
tensorflow.python.platform
import
gfile
from
tensorflow.core.framework
import
attr_value_pb2
import
tensorflow
as
tf
import
copy
import
copy
as
cp
import
sys
class
TFGraphNode
(
GraphNode
):
...
...
@@ -121,11 +122,12 @@ class TFGraph(Graph):
# delete isolated nodes
isolated_nodes
=
list
()
for
node_name
in
self
.
node_map
.
keys
():
if
len
(
self
.
get_node
(
node_name
).
inputs
)
==
0
or
len
(
if
len
(
self
.
get_node
(
node_name
).
inputs
)
==
0
and
len
(
self
.
get_node
(
node_name
).
outputs
)
==
0
:
isolated_nodes
.
append
(
node_name
)
self
.
remove_node
(
node_name
)
for
node_name
in
isolated_nodes
:
self
.
remove_node
(
node_name
)
def
_remove_identity_node
(
self
):
identity_node
=
list
()
...
...
@@ -153,14 +155,40 @@ class TFGraph(Graph):
del
self
.
topo_sort
[
idx
]
def
check_input_shape
(
graph_def
):
graph_def
=
cp
.
deepcopy
(
graph_def
)
input_map
=
dict
()
for
layer
in
graph_def
.
node
:
if
layer
.
op
!=
"Placeholder"
:
continue
graph_node
=
TFGraphNode
(
layer
)
dtype
=
graph_node
.
dtype
# print("shape:", graph_node.out_shapes)
if
not
graph_node
.
get_attr
(
"shape"
):
sys
.
stderr
.
write
(
"Unknown shape for input tensor[{}]
\n
"
.
format
(
layer
.
name
))
shape
=
input
(
"Please define shape of input here: "
)
shape
=
[
None
if
dim
==
"None"
else
int
(
dim
)
for
dim
in
shape
.
strip
().
split
(
','
)
]
x2paddle_input
=
tf
.
placeholder
(
dtype
=
dtype
,
shape
=
shape
,
name
=
"x2paddle_{}"
.
format
(
layer
.
name
))
input_map
[
"{}:0"
.
format
(
layer
.
name
)]
=
x2paddle_input
return
input_map
class
TFDecoder
(
object
):
def
__init__
(
self
,
pb_model
):
sess
=
tf
.
Session
()
with
gfile
.
FastGFile
(
pb_model
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
input_map
=
check_input_shape
(
graph_def
)
sess
.
graph
.
as_default
()
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
tf
.
import_graph_def
(
graph_def
,
name
=
''
,
input_map
=
input_map
)
sess
.
run
(
tf
.
global_variables_initializer
())
...
...
x2paddle/op_mapper/tf_op_mapper.py
浏览文件 @
32b98f8e
...
...
@@ -48,6 +48,8 @@ class TFOpMapper(OpMapper):
def
Placeholder
(
self
,
node
):
shape
=
node
.
out_shapes
[
0
]
assert
len
(
shape
)
!=
0
,
"Unknown shape of input nodes[{}]."
.
format
(
node
.
layer_name
)
dtype
=
node
.
dtype
attr
=
{
'dtype'
:
string
(
dtype
),
...
...
@@ -171,10 +173,11 @@ class TFOpMapper(OpMapper):
"pool_type"
:
string
(
"max"
),
"pool_stride"
:
strides
[
2
:
4
]
}
node
.
fluid_code
.
add_layer
(
"pool2d"
,
inputs
=
input
if
channel_first
else
node
,
output
=
node
,
param_attr
=
attr
)
node
.
fluid_code
.
add_layer
(
"pool2d"
,
inputs
=
input
if
channel_first
and
pad_mode
!=
"SAME"
else
node
,
output
=
node
,
param_attr
=
attr
)
if
not
channel_first
:
attr
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
...
...
@@ -227,6 +230,102 @@ class TFOpMapper(OpMapper):
"stride"
:
strides
[
2
:
4
],
"dilation"
:
dilations
[
2
:
4
]
}
node
.
fluid_code
.
add_layer
(
"conv2d"
,
inputs
=
input
if
channel_first
and
pad_mode
!=
"SAME"
else
node
,
output
=
node
,
param_attr
=
attr
)
if
not
channel_first
:
attr
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
node
,
output
=
node
,
param_attr
=
attr
)
def
Relu6
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
node
.
fluid_code
.
add_layer
(
"relu6"
,
inputs
=
input
,
output
=
node
,
param_attr
=
None
)
def
FusedBatchNorm
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
gamma
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
beta
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
2
],
copy
=
True
)
moving_mean
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
3
],
copy
=
True
)
moving_var
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
4
],
copy
=
True
)
assert
gamma
.
layer_type
==
"Const"
assert
beta
.
layer_type
==
"Const"
assert
moving_mean
.
layer_type
==
"Const"
assert
moving_var
.
layer_type
==
"Const"
self
.
omit_nodes
.
append
(
gamma
.
layer_name
)
self
.
omit_nodes
.
append
(
beta
.
layer_name
)
self
.
omit_nodes
.
append
(
moving_mean
.
layer_name
)
self
.
omit_nodes
.
append
(
moving_var
.
layer_name
)
attr
=
{
"epsilon"
:
node
.
get_attr
(
"epsilon"
),
"param_attr"
:
string
(
gamma
.
layer_name
),
"data_layout"
:
string
(
node
.
get_attr
(
"data_format"
).
decode
()),
"bias_attr"
:
string
(
beta
.
layer_name
),
"moving_mean_name"
:
string
(
moving_mean
.
layer_name
),
"moving_variance_name"
:
string
(
moving_var
.
layer_name
),
"is_test"
:
True
}
node
.
fluid_code
.
add_layer
(
"batch_norm"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
def
DepthwiseConv2dNative
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
kernel
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
assert
kernel
.
layer_type
==
"Const"
,
"Kernel of DepthwiseConv2DNative should be Const"
self
.
omit_nodes
.
append
(
kernel
.
layer_name
)
in_shape
=
input
.
out_shapes
[
0
]
k_size
=
kernel
.
out_shapes
[
0
]
strides
=
node
.
get_attr
(
"strides"
)
dilations
=
node
.
get_attr
(
"dilations"
)
data_format
=
node
.
get_attr
(
"data_format"
).
decode
()
pad_mode
=
node
.
get_attr
(
"padding"
).
decode
()
channel_first
=
data_format
==
"NCHW"
if
not
channel_first
:
self
.
weights
[
kernel
.
layer_name
.
replace
(
'/'
,
'_'
)]
=
numpy
.
transpose
(
kernel
.
value
,
(
2
,
3
,
0
,
1
))
attr
=
{
"perm"
:
[
0
,
3
,
1
,
2
]}
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
in_shape
=
[
in_shape
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
strides
=
[
strides
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
dilations
=
[
dilations
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
if
pad_mode
==
"SAME"
:
pad_h
=
get_same_padding
(
in_shape
[
2
],
k_size
[
0
],
strides
[
2
])
pad_w
=
get_same_padding
(
in_shape
[
3
],
k_size
[
1
],
strides
[
3
])
attr
=
{
"paddings"
:
pad_h
+
pad_w
,
"pad_value"
:
0.0
}
if
pad_h
[
0
]
+
pad_h
[
1
]
+
pad_w
[
0
]
+
pad_w
[
1
]
!=
0
:
node
.
fluid_code
.
add_layer
(
"pad2d"
,
inputs
=
input
if
channel_first
and
pad_mode
!=
"SAME"
else
node
,
output
=
node
,
param_attr
=
attr
)
attr
=
{
"bias_attr"
:
False
,
"param_attr"
:
string
(
kernel
.
layer_name
),
"num_filters"
:
in_shape
[
1
],
"filter_size"
:
k_size
[
0
:
2
],
"stride"
:
strides
[
2
:
4
],
"dilation"
:
dilations
[
2
:
4
],
"groups"
:
k_size
[
3
]
*
in_shape
[
1
]
}
node
.
fluid_code
.
add_layer
(
"conv2d"
,
inputs
=
input
if
channel_first
else
node
,
output
=
node
,
...
...
@@ -238,3 +337,91 @@ class TFOpMapper(OpMapper):
inputs
=
node
,
output
=
node
,
param_attr
=
attr
)
def
Shape
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
node
.
fluid_code
.
add_layer
(
"shape"
,
inputs
=
input
,
output
=
node
,
param_attr
=
None
)
def
Reshape
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
param
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
if
param
.
layer_type
==
"Const"
:
attr
=
{
"shape"
:
param
.
value
.
tolist
()}
else
:
# Here is a trick method to solove tensor parameter in tensorflow
assert
len
(
param
.
out_shapes
[
0
]
)
==
1
,
"Unexpected situation of shape parameter"
attr
=
{
"num_or_sections"
:
param
.
out_shapes
[
0
][
0
],
"dim"
:
0
}
node
.
fluid_code
.
add_layer
(
"split"
,
inputs
=
param
,
output
=
node
,
param_attr
=
attr
)
new_param
=
"["
for
i
in
range
(
param
.
out_shapes
[
0
][
0
]):
new_param
+=
(
node
.
layer_name
+
"[{}]"
.
format
(
i
)
+
", "
)
new_param
=
new_param
.
strip
(
", "
)
+
"]"
attr
=
{
"shape"
:
new_param
}
node
.
fluid_code
.
add_layer
(
"reshape"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
def
Add
(
self
,
node
):
x
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
y
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
inputs
=
{
"x"
:
x
,
"y"
:
y
}
node
.
fluid_code
.
add_layer
(
"elementwise_add"
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
def
AvgPool
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
in_shape
=
input
.
out_shapes
[
0
]
k_size
=
node
.
get_attr
(
"ksize"
)
strides
=
node
.
get_attr
(
"strides"
)
data_format
=
node
.
get_attr
(
"data_format"
).
decode
()
pad_mode
=
node
.
get_attr
(
"padding"
).
decode
()
channel_first
=
data_format
==
"NCHW"
if
not
channel_first
:
attr
=
{
"perm"
:
[
0
,
3
,
1
,
2
]}
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
in_shape
=
[
in_shape
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
strides
=
[
strides
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
attr
=
{
"pool_size"
:
k_size
[
1
:
3
],
"pool_type"
:
string
(
"avg"
),
"pool_stride"
:
strides
[
2
:
4
]
}
if
pad_mode
==
"SAME"
:
pad_h
=
get_same_padding
(
in_shape
[
2
],
k_size
[
0
],
strides
[
2
])
pad_w
=
get_same_padding
(
in_shape
[
3
],
k_size
[
1
],
strides
[
3
])
assert
pad_h
[
0
]
==
pad_h
[
1
]
and
pad_w
[
0
]
==
pad_w
[
1
],
"Cannot map AvgPool"
attr
[
"pool_padding"
]
=
[
pad_h
[
0
],
pad_w
[
0
]]
node
.
fluid_code
.
add_layer
(
"pool2d"
,
inputs
=
input
if
channel_first
else
node
,
output
=
node
,
param_attr
=
attr
)
if
not
channel_first
:
attr
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
node
,
output
=
node
,
param_attr
=
attr
)
def
Softmax
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
node
.
fluid_code
.
add_layer
(
"softmax"
,
inputs
=
input
,
output
=
node
,
param_attr
=
None
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录