Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Mr.Vain
Mace
提交
f13461ef
Mace
项目概览
Mr.Vain
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f13461ef
编写于
12月 20, 2018
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support multi fc
上级
8b9021f7
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
54 addition
and
3 deletion
+54
-3
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+2
-0
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+50
-1
tools/converter.py
tools/converter.py
+2
-2
未找到文件。
mace/python/tools/converter_tool/base_converter.py
浏览文件 @
f13461ef
...
...
@@ -256,6 +256,7 @@ class TransformerRule(Enum):
TRANSPOSE_MATMUL_WEIGHT
=
34
FOLD_EMBEDDING_LOOKUP
=
35
TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN
=
36
FOLD_FC_RESHAPE
=
37
class
ConverterInterface
(
object
):
...
...
@@ -461,6 +462,7 @@ class ConverterOption(object):
TransformerRule
.
FOLD_SQRDIFF_MEAN
,
TransformerRule
.
TRANSFORM_GLOBAL_CONV_TO_FC
,
TransformerRule
.
RESHAPE_FC_WEIGHT
,
TransformerRule
.
FOLD_FC_RESHAPE
,
# Model data format related transformation
TransformerRule
.
TRANSPOSE_FILTERS
,
TransformerRule
.
TRANSPOSE_DATA_FORMAT
,
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
f13461ef
...
...
@@ -75,6 +75,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule
.
TRANSPOSE_FILTERS
:
self
.
transpose_filters
,
TransformerRule
.
TRANSPOSE_MATMUL_WEIGHT
:
self
.
transpose_matmul_weight
,
TransformerRule
.
FOLD_FC_RESHAPE
:
self
.
fold_fc_reshape
,
TransformerRule
.
TRANSPOSE_DATA_FORMAT
:
self
.
transpose_data_format
,
TransformerRule
.
ADD_WINOGRAD_ARG
:
self
.
add_winograd_arg
,
TransformerRule
.
ADD_IN_OUT_TENSOR_INFO
:
...
...
@@ -1227,11 +1229,24 @@ class Transformer(base_converter.ConverterInterface):
return
True
return
False
def
is_after_fc
(
self
,
op
):
while
op
.
input
[
0
]
in
self
.
_producer
:
producer
=
self
.
_producer
[
op
.
input
[
0
]]
if
producer
.
type
in
[
MaceOp
.
Activation
.
name
,
MaceOp
.
BiasAdd
.
name
]:
op
=
producer
continue
elif
producer
.
type
==
MaceOp
.
FullyConnected
.
name
:
return
True
else
:
return
False
return
False
def
transform_matmul_to_fc
(
self
):
net
=
self
.
_model
filter_format
=
self
.
filter_format
()
for
op
in
net
.
op
:
# transform input(4D) -> reshape(2D) -> matmul to fc
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow
if
op
.
type
==
MaceOp
.
Reshape
.
name
and
\
len
(
op
.
input
)
==
2
and
\
...
...
@@ -1268,6 +1283,21 @@ class Transformer(base_converter.ConverterInterface):
[
weight_data
.
shape
[
1
]]
return
True
# transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)`
if
op
.
type
==
MaceOp
.
MatMul
.
name
and
\
filter_format
==
FilterFormat
.
HWIO
:
producer
=
self
.
_producer
[
op
.
input
[
0
]]
weight
=
self
.
_consts
[
op
.
input
[
1
]]
if
len
(
weight
.
dims
)
==
2
and
self
.
is_after_fc
(
op
)
and
\
len
(
producer
.
output_shape
[
0
].
dims
)
==
2
and
\
weight
.
dims
[
0
]
==
producer
.
output_shape
[
0
].
dims
[
1
]:
six
.
print_
(
'convert matmul to fc'
)
op
.
type
=
MaceOp
.
FullyConnected
.
name
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight
.
dims
)
weight
.
dims
[:]
=
[
1
,
1
]
+
list
(
weight_data
.
shape
)
return
True
return
False
def
update_float_op_data_type
(
self
):
...
...
@@ -1750,3 +1780,22 @@ class Transformer(base_converter.ConverterInterface):
shape_tensor
.
data_type
=
mace_pb2
.
DT_INT32
shape_tensor
.
int32_data
.
extend
(
dims
)
op
.
input
.
append
(
shape_tensor
.
name
)
def
fold_fc_reshape
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
# whether to reshape fc output(default 4D)
if
op
.
type
==
MaceOp
.
FullyConnected
.
name
:
consumers
=
self
.
_consumers
[
op
.
output
[
0
]]
op_output_shape
=
op
.
output_shape
[
0
].
dims
[:]
for
consumer
in
consumers
:
if
consumer
.
type
==
MaceOp
.
Reshape
.
name
and
\
consumer
.
input
[
1
]
in
self
.
_consts
and
\
self
.
_consts
[
consumer
.
input
[
1
]].
int32_data
[:]
==
\
[
op_output_shape
[
0
],
1
,
1
,
op_output_shape
[
1
]]:
# work for tensorflow
net
.
tensors
.
remove
(
self
.
_consts
[
consumer
.
input
[
1
]])
del
consumer
.
input
[
1
]
self
.
safe_remove_node
(
consumer
,
None
)
return
True
return
False
tools/converter.py
浏览文件 @
f13461ef
...
...
@@ -411,7 +411,7 @@ def format_model_config(flags):
ModuleName
.
YAML_CONFIG
,
"'input_data_formats' must be in "
+
str
(
DataFormatStrs
)
+
", but got "
+
input_data_format
s
)
+
input_data_format
)
else
:
subgraph
[
YAMLKeyword
.
input_data_formats
]
=
[
DataFormat
.
NHWC
]
...
...
@@ -431,7 +431,7 @@ def format_model_config(flags):
subgraph
[
YAMLKeyword
.
output_data_formats
]:
mace_check
(
output_data_format
in
DataFormatStrs
,
ModuleName
.
YAML_CONFIG
,
"'
in
put_data_formats' must be in "
"'
out
put_data_formats' must be in "
+
str
(
DataFormatStrs
))
else
:
subgraph
[
YAMLKeyword
.
output_data_formats
]
=
[
DataFormat
.
NHWC
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录