Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
7a9ee4ca
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7a9ee4ca
编写于
7月 25, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'fix-tf-tf' into 'master'
Fix bug: transform fc of tensorflow. See merge request !683
上级
fdcd657b
e4c0d531
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
51 addition
and
22 deletion
+51
-22
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+1
-1
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+50
-21
未找到文件。
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
7a9ee4ca
...
...
@@ -578,7 +578,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
axis_arg
=
op
.
arg
.
add
()
axis_arg
.
name
=
MaceKeyword
.
mace_axis_str
axis
=
tf_op
.
inputs
[
-
1
].
eval
().
astype
(
np
.
int32
)
axis
=
4
+
axis
if
axis
<
0
else
axis
axis
=
len
(
op
.
output_shape
[
0
].
dims
)
+
axis
if
axis
<
0
else
axis
axis_arg
.
i
=
axis
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
-
1
].
name
)
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
7a9ee4ca
...
...
@@ -751,6 +751,15 @@ class Transformer(base_converter.ConverterInterface):
"only support concat at "
"channel dimension"
)
arg
.
i
=
3
producer
=
self
.
_producer
[
op
.
input
[
0
]]
input_shape
=
producer
.
output_shape
[
0
].
dims
if
producer
.
type
==
MaceOp
.
FullyConnected
.
name
and
\
len
(
input_shape
)
==
2
:
axis_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_axis_str
)
if
axis_arg
.
i
==
1
\
and
self
.
_target_data_format
==
DataFormat
.
NHWC
:
# noqa
axis_arg
.
i
=
3
elif
op
.
type
==
MaceOp
.
Squeeze
.
name
:
for
arg
in
op
.
arg
:
...
...
@@ -938,7 +947,10 @@ class Transformer(base_converter.ConverterInterface):
input_shape
=
list
(
input_op
.
output_shape
[
0
].
dims
)
input_data_format
=
ConverterUtil
.
data_format
(
input_op
)
weight
.
dims
[:]
=
[
weight
.
dims
[
0
]]
+
input_shape
[
1
:]
if
input_data_format
==
DataFormat
.
NHWC
:
if
len
(
input_shape
)
==
2
:
weight
.
dims
[:]
=
weight
.
dims
[:]
+
[
1
,
1
]
if
input_data_format
==
DataFormat
.
NHWC
and
\
len
(
input_shape
)
==
4
:
self
.
transpose_shape
(
weight
.
dims
,
[
0
,
3
,
1
,
2
])
return
False
...
...
@@ -1113,31 +1125,48 @@ class Transformer(base_converter.ConverterInterface):
net
=
self
.
_model
filter_format
=
self
.
filter_format
()
for
op
in
net
.
op
:
# transform
reshape + matmul ->
fc
# transform
input(4D) -> reshape(2D) -> matmul to
fc
# work for TensorFlow
if
op
.
type
==
MaceOp
.
Reshape
.
name
and
\
op
.
input
[
1
]
in
self
.
_consts
and
\
len
(
op
.
output_shape
[
0
].
dims
)
==
2
and
\
filter_format
==
FilterFormat
.
HWIO
:
input_op
=
self
.
_producer
[
op
.
input
[
0
]]
input_shape
=
input_op
.
output_shape
[
0
].
dims
# check input op
if
len
(
input_shape
)
==
4
and
\
np
.
prod
(
input_shape
[
1
:])
==
op
.
output_shape
[
0
].
dims
[
1
]:
is_fc
=
True
consumers
=
self
.
_consumers
[
op
.
output
[
0
]]
# check matmul op
for
matmul_op
in
consumers
:
if
matmul_op
.
type
!=
MaceOp
.
MatMul
.
name
:
is_fc
=
False
else
:
weight
=
self
.
_consts
[
matmul_op
.
input
[
1
]]
if
len
(
weight
.
dims
)
!=
2
or
\
weight
.
dims
[
0
]
!=
op
.
output_shape
[
0
].
dims
[
1
]:
is_fc
=
False
if
is_fc
:
print
'convert reshape and matmul to fc'
self
.
safe_remove_node
(
op
,
input_op
,
remove_input_tensor
=
True
)
for
matmul_op
in
consumers
:
weight
=
self
.
_consts
[
matmul_op
.
input
[
1
]]
matmul_op
.
type
=
MaceOp
.
FullyConnected
.
name
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight
.
dims
)
weight
.
dims
[:]
=
input_shape
[
1
:]
+
\
[
weight_data
.
shape
[
1
]]
return
True
# transform input(2D) -> matmul to fc
if
op
.
type
==
MaceOp
.
MatMul
.
name
and
\
filter_format
==
FilterFormat
.
HWIO
:
producer
=
self
.
_producer
[
op
.
input
[
0
]]
weight
=
self
.
_consts
[
op
.
input
[
1
]]
if
len
(
weight
.
dims
)
==
2
\
and
producer
.
type
==
MaceOp
.
Reshape
.
name
\
and
len
(
producer
.
output
)
==
1
\
and
producer
.
input
[
1
]
in
self
.
_consts
\
and
len
(
producer
.
output_shape
[
0
].
dims
)
==
2
:
input_op
=
self
.
_producer
[
producer
.
input
[
0
]]
input_shape
=
input_op
.
output_shape
[
0
].
dims
feature_size
=
np
.
prod
(
input_shape
[
1
:])
self
.
safe_remove_node
(
producer
,
input_op
,
remove_input_tensor
=
True
)
if
feature_size
==
producer
.
output_shape
[
0
].
dims
[
1
]:
print
'convert reshape and matmul to fc'
op
.
type
=
MaceOp
.
FullyConnected
.
name
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight
.
dims
)
weight
.
dims
[:]
=
input_shape
[
1
:]
+
\
[
weight_data
.
shape
[
1
]]
return
True
elif
len
(
weight
.
dims
)
==
2
and
\
if
len
(
weight
.
dims
)
==
2
and
\
producer
.
type
!=
MaceOp
.
Reshape
.
name
and
\
len
(
producer
.
output_shape
[
0
].
dims
)
==
2
and
\
weight
.
dims
[
0
]
==
producer
.
output_shape
[
0
].
dims
[
1
]:
print
'convert matmul to fc'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录