Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
毕竟曾有刹那
Mace
提交
a422aa26
Mace
项目概览
毕竟曾有刹那
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a422aa26
编写于
5月 17, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add tensorflow fc support
上级
cc3ea692
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
185 addition
and
63 deletion
+185
-63
mace/core/arg_helper.cc
mace/core/arg_helper.cc
+2
-1
mace/python/tools/converter.py
mace/python/tools/converter.py
+1
-1
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+22
-18
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+22
-1
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+138
-42
未找到文件。
mace/core/arg_helper.cc
浏览文件 @
a422aa26
...
...
@@ -23,7 +23,8 @@ namespace mace {
ArgumentHelper
::
ArgumentHelper
(
const
OperatorDef
&
def
)
{
for
(
auto
&
arg
:
def
.
arg
())
{
if
(
arg_map_
.
find
(
arg
.
name
())
!=
arg_map_
.
end
())
{
LOG
(
WARNING
)
<<
"Duplicated argument name found in operator def."
;
LOG
(
WARNING
)
<<
"Duplicated argument name found in operator def: "
<<
def
.
name
()
<<
" "
<<
arg
.
name
();
}
arg_map_
[
arg
.
name
()]
=
arg
;
...
...
mace/python/tools/converter.py
浏览文件 @
a422aa26
...
...
@@ -128,7 +128,7 @@ def main(unused_args):
FLAGS
.
weight_file
)
output_graph_def
=
converter
.
run
()
print
(
"Transform model to one that can better run on device
.
"
)
print
(
"Transform model to one that can better run on device"
)
if
not
FLAGS
.
runtime
:
cpu_graph_def
=
copy
.
deepcopy
(
output_graph_def
)
option
.
device
=
mace_pb2
.
CPU
...
...
mace/python/tools/converter_tool/base_converter.py
浏览文件 @
a422aa26
...
...
@@ -136,23 +136,25 @@ class MaceKeyword(object):
class
TransformerRule
(
Enum
):
REMOVE_IDENTITY_OP
=
0
TRANSFORM_GLOBAL_POOLING
=
1
FOLD_SOFTMAX
=
2
FOLD_BATCHNORM
=
3
,
FOLD_CONV_AND_BN
=
4
,
FOLD_DEPTHWISE_CONV_AND_BN
=
5
,
TRANSFORM_GPU_WINOGRAD
=
6
,
TRANSFORM_ADD_TO_BIASADD
=
7
,
FOLD_BIASADD
=
8
,
FOLD_ACTIVATION
=
9
,
TRANSPOSE_FILTERS
=
10
,
RESHAPE_FC_WEIGHT
=
11
,
TRANSPOSE_DATA_FORMAT
=
12
,
TRANSFORM_GLOBAL_CONV_TO_FC
=
13
,
TRANSFORM_BUFFER_IMAGE
=
14
,
ADD_DEVICE_AND_DATA_TYPE
=
15
,
SORT_BY_EXECUTION
=
16
REMOVE_USELESS_RESHAPE_OP
=
0
REMOVE_IDENTITY_OP
=
1
TRANSFORM_GLOBAL_POOLING
=
2
FOLD_RESHAPE
=
3
TRANSFORM_MATMUL_TO_FC
=
4
FOLD_BATCHNORM
=
5
FOLD_CONV_AND_BN
=
6
FOLD_DEPTHWISE_CONV_AND_BN
=
7
TRANSFORM_GPU_WINOGRAD
=
8
TRANSFORM_ADD_TO_BIASADD
=
9
FOLD_BIASADD
=
10
FOLD_ACTIVATION
=
11
TRANSPOSE_FILTERS
=
12
RESHAPE_FC_WEIGHT
=
13
TRANSPOSE_DATA_FORMAT
=
14
TRANSFORM_GLOBAL_CONV_TO_FC
=
15
TRANSFORM_BUFFER_IMAGE
=
16
ADD_DEVICE_AND_DATA_TYPE
=
17
SORT_BY_EXECUTION
=
18
class
ConverterInterface
(
object
):
...
...
@@ -199,9 +201,11 @@ class ConverterOption(object):
self
.
_device
=
mace_pb2
.
CPU
self
.
_winograd_enabled
=
False
self
.
_transformer_option
=
[
TransformerRule
.
REMOVE_USELESS_RESHAPE_OP
,
TransformerRule
.
REMOVE_IDENTITY_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
TransformerRule
.
FOLD_SOFTMAX
,
TransformerRule
.
FOLD_RESHAPE
,
TransformerRule
.
TRANSFORM_MATMUL_TO_FC
,
TransformerRule
.
FOLD_BATCHNORM
,
TransformerRule
.
FOLD_CONV_AND_BN
,
TransformerRule
.
FOLD_DEPTHWISE_CONV_AND_BN
,
...
...
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
a422aa26
...
...
@@ -101,9 +101,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
'AvgPool'
:
self
.
convert_pooling
,
'MaxPool'
:
self
.
convert_pooling
,
'Squeeze'
:
self
.
convert_identity
,
'MatMul'
:
self
.
convert_matmul
,
'Identity'
:
self
.
convert_identity
,
'Reshape'
:
self
.
convert_reshape
,
'Shape'
:
self
.
convert_nop
,
'Transpose'
:
self
.
convert_transpose
,
'Softmax'
:
self
.
convert_softmax
,
'ResizeBilinear'
:
self
.
convert_resize_bilinear
,
'Placeholder'
:
self
.
convert_nop
,
...
...
@@ -144,7 +146,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
for
i
in
xrange
(
len
(
op
.
input
)):
if
op
.
input
[
i
][
-
2
:]
==
':0'
:
op_name
=
op
.
input
[
i
][:
-
2
]
if
op_name
in
self
.
_option
.
input_nodes
:
if
op_name
in
self
.
_option
.
input_nodes
\
or
op_name
in
self
.
_option
.
output_nodes
:
op
.
input
[
i
]
=
op_name
for
i
in
xrange
(
len
(
op
.
output
)):
if
op
.
output
[
i
][
-
2
:]
==
':0'
:
...
...
@@ -411,6 +414,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
self
.
_skip_tensor
.
update
(
tf_op
.
inputs
[
-
1
].
name
)
def
convert_matmul
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
MatMul
.
name
def
convert_reshape
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
MaceOp
.
Reshape
.
name
...
...
@@ -430,6 +437,20 @@ class TensorflowConverter(base_converter.ConverterInterface):
shape_arg
.
ints
.
extend
(
shape_value
)
def
convert_transpose
(
self
,
tf_op
):
perm
=
tf_op
.
inputs
[
1
].
eval
().
astype
(
np
.
int32
)
ordered_perm
=
np
.
sort
(
perm
)
mace_check
(
np
.
array_equal
(
perm
,
ordered_perm
),
"Transpose not supported yet, only internal transpose"
" in composed ops might be supported"
)
op
=
self
.
convert_general_op
(
tf_op
)
op
.
type
=
'Identity'
del
op
.
input
[
1
:]
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
1
].
name
)
def
convert_mean
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
del
op
.
input
[
1
:]
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
a422aa26
...
...
@@ -53,9 +53,11 @@ class Transformer(base_converter.ConverterInterface):
def
__init__
(
self
,
option
,
model
):
# DO NOT reorder the following transformers
self
.
_registered_transformers_order
=
[
TransformerRule
.
REMOVE_USELESS_RESHAPE_OP
,
TransformerRule
.
REMOVE_IDENTITY_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
TransformerRule
.
FOLD_SOFTMAX
,
TransformerRule
.
FOLD_RESHAPE
,
TransformerRule
.
TRANSFORM_MATMUL_TO_FC
,
TransformerRule
.
FOLD_BATCHNORM
,
TransformerRule
.
FOLD_CONV_AND_BN
,
TransformerRule
.
FOLD_DEPTHWISE_CONV_AND_BN
,
...
...
@@ -72,10 +74,14 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule
.
SORT_BY_EXECUTION
,
]
self
.
_registered_transformers
=
{
TransformerRule
.
REMOVE_USELESS_RESHAPE_OP
:
self
.
remove_useless_reshape_op
,
TransformerRule
.
REMOVE_IDENTITY_OP
:
self
.
remove_identity_op
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
:
self
.
transform_global_pooling
,
TransformerRule
.
FOLD_SOFTMAX
:
self
.
fold_softmax
,
TransformerRule
.
FOLD_RESHAPE
:
self
.
fold_reshape
,
TransformerRule
.
TRANSFORM_MATMUL_TO_FC
:
self
.
transform_matmul_to_fc
,
TransformerRule
.
FOLD_BATCHNORM
:
self
.
fold_batchnorm
,
TransformerRule
.
FOLD_CONV_AND_BN
:
self
.
fold_conv_and_bn
,
# data_format related
...
...
@@ -161,18 +167,26 @@ class Transformer(base_converter.ConverterInterface):
for
output_tensor
in
op
.
output
:
self
.
_producer
[
output_tensor
]
=
op
for
input_node
in
self
.
_option
.
input_nodes
.
values
():
op
=
mace_pb2
.
OperatorDef
()
op
.
name
=
self
.
normalize_op_name
(
input_node
.
name
)
op
.
type
=
'Input'
op
.
output
.
extend
(
input_node
.
name
)
output_shape
=
op
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
input_node
.
shape
)
if
self
.
_option
.
device
==
mace_pb2
.
CPU
:
self
.
transpose_shape
(
output_shape
.
dims
,
[
0
,
3
,
1
,
2
])
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NCHW
)
else
:
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NHWC
)
self
.
_producer
[
op
.
output
[
0
]]
=
op
input_node_existed
=
False
for
op
in
self
.
_model
.
op
:
if
input_node
.
name
in
op
.
output
:
input_node_existed
=
True
break
if
not
input_node_existed
:
op
=
mace_pb2
.
OperatorDef
()
op
.
name
=
self
.
normalize_op_name
(
input_node
.
name
)
op
.
type
=
'Input'
op
.
output
.
extend
([
input_node
.
name
])
output_shape
=
op
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
input_node
.
shape
)
if
ConverterUtil
.
data_format
(
self
.
_consumers
[
input_node
.
name
][
0
])
\
==
DataFormat
.
NCHW
:
self
.
transpose_shape
(
output_shape
.
dims
,
[
0
,
3
,
1
,
2
])
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NCHW
)
else
:
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NHWC
)
self
.
_producer
[
op
.
output
[
0
]]
=
op
@
staticmethod
def
replace
(
obj_list
,
source
,
target
):
...
...
@@ -191,6 +205,12 @@ class Transformer(base_converter.ConverterInterface):
def
normalize_op_name
(
name
):
return
name
.
replace
(
':'
,
'_'
)
def
get_tensor_shape
(
self
,
tensor
):
producer
=
self
.
_producer
[
tensor
]
for
i
in
xrange
(
len
(
producer
.
output
)):
if
producer
.
output
[
i
]
==
tensor
:
return
list
(
producer
.
output_shape
[
i
].
dims
)
def
consumer_count
(
self
,
tensor_name
):
return
len
(
self
.
_consumers
.
get
(
tensor_name
,
[]))
...
...
@@ -203,23 +223,68 @@ class Transformer(base_converter.ConverterInterface):
return
False
def
replace_output_node
(
self
,
op
):
"""if it is an output node, change output node to the op before it"""
if
self
.
is_op_output_node
(
op
):
real_output_node
=
self
.
_producer
[
op
.
input
[
0
]]
self
.
replace
(
real_output_node
.
output
,
op
.
input
[
0
],
op
.
output
[
0
])
print
(
"change %s to %s"
%
(
real_output_node
.
name
,
op
.
name
))
def
safe_remove_node
(
self
,
op
,
replace_op
):
"""remove op.
1. change the inputs of its consumers to the outputs of replace_op
2. if the op is output node, change output node to replace op"""
if
replace_op
is
None
:
# When no replace op specified, we change the inputs of
# its consumers to the input of the op. This handles the case
# that the op is identity op and its input is a tensor.
mace_check
(
len
(
op
.
output
)
==
1
and
len
(
op
.
input
)
==
1
,
"cannot remove op that w/o replace op specified"
" and input/output length > 1"
+
str
(
op
))
for
consumer_op
in
self
.
_consumers
.
get
(
op
.
output
[
0
],
[]):
self
.
replace
(
consumer_op
.
input
,
op
.
output
[
0
],
op
.
input
[
0
])
mace_check
(
op
.
output
[
0
]
not
in
self
.
_option
.
output_nodes
,
"cannot remove op that is output node"
)
else
:
mace_check
(
len
(
op
.
output
)
==
len
(
replace_op
.
output
),
"cannot remove op since len(op.output) "
"!= len(replace_op.output)"
)
for
i
in
xrange
(
len
(
op
.
output
)):
for
consumer_op
in
self
.
_consumers
.
get
(
op
.
output
[
i
],
[]):
self
.
replace
(
consumer_op
.
input
,
op
.
output
[
i
],
replace_op
.
output
[
i
])
# if the op is output node, change replace_op output name to the op
# output name
for
i
in
xrange
(
len
(
op
.
output
)):
if
op
.
output
[
i
]
in
self
.
_option
.
output_nodes
:
for
consumer
in
self
.
_consumers
.
get
(
replace_op
.
output
[
i
],
[]):
self
.
replace
(
consumer
.
input
,
replace_op
.
output
[
i
],
op
.
output
[
i
])
replace_op
.
output
[
i
]
=
op
.
output
[
i
]
self
.
_model
.
op
.
remove
(
op
)
def
remove_useless_reshape_op
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
if
op
.
type
==
MaceOp
.
Reshape
.
name
:
shape
=
list
(
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_shape_str
).
ints
)
if
shape
==
self
.
get_tensor_shape
(
op
.
input
[
0
]):
print
(
"Remove useless reshape: %s(%s)"
%
(
op
.
name
,
op
.
type
))
op
.
type
=
'Identity'
return
False
def
remove_identity_op
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
if
op
.
type
==
'Identity'
:
print
(
"Remove identity: %s(%s)"
%
(
op
.
name
,
op
.
type
))
for
consumer_op
in
self
.
_consumers
.
get
(
op
.
output
[
0
],
[]):
Transformer
.
replace
(
consumer_op
.
input
,
op
.
output
[
0
],
op
.
input
[
0
])
self
.
replace_output_node
(
op
)
net
.
op
.
remove
(
op
)
self
.
safe_remove_node
(
op
,
self
.
_producer
.
get
(
op
.
input
[
0
],
None
))
return
True
return
False
...
...
@@ -264,10 +329,10 @@ class Transformer(base_converter.ConverterInterface):
and
len
(
self
.
_consts
[
consumer_op
.
input
[
1
]].
dims
)
==
1
:
print
(
"Fold batchnorm: %s(%s)"
%
(
op
.
name
,
op
.
type
))
consumer_op
.
type
=
MaceOp
.
FoldedBatchNorm
.
name
inputs
=
[
op
.
input
[
0
],
op
.
input
[
1
],
consumer_op
.
input
[
1
]]
consumer_op
.
input
[:]
=
inputs
[:
]
consumer_op
.
input
[:]
=
[
op
.
input
[
0
],
op
.
input
[
1
],
consumer_op
.
input
[
1
]
]
net
.
op
.
remove
(
op
)
self
.
safe_remove_node
(
op
,
None
)
return
True
return
False
...
...
@@ -514,7 +579,7 @@ class Transformer(base_converter.ConverterInterface):
filter
.
float_data
[:]
=
weight_tensor_value
.
flat
[:]
filter
.
dims
[:]
=
weight_tensor_value
.
shape
[:]
net
.
op
.
remove
(
op
)
self
.
safe_remove_node
(
op
,
iwt_
op
)
return
False
...
...
@@ -544,10 +609,8 @@ class Transformer(base_converter.ConverterInterface):
consumer_op
=
self
.
_consumers
[
op
.
output
[
0
]][
0
]
if
consumer_op
.
type
==
MaceOp
.
BiasAdd
.
name
:
print
(
"Fold biasadd: %s(%s)"
%
(
op
.
name
,
op
.
type
))
op
.
name
=
consumer_op
.
name
op
.
input
.
append
(
consumer_op
.
input
[
1
])
op
.
output
[
0
]
=
consumer_op
.
output
[
0
]
net
.
op
.
remove
(
consumer_op
)
self
.
safe_remove_node
(
consumer_op
,
op
)
return
True
return
False
...
...
@@ -575,7 +638,7 @@ class Transformer(base_converter.ConverterInterface):
or
arg
.
name
==
MaceKeyword
.
mace_activation_max_limit_str
:
# noqa
op
.
arg
.
extend
([
arg
])
net
.
op
.
remove
(
consumer_
op
)
self
.
safe_remove_node
(
consumer_op
,
op
)
return
True
return
False
...
...
@@ -651,11 +714,14 @@ class Transformer(base_converter.ConverterInterface):
op
.
output
.
extend
([
input_node
.
name
])
output_shape
=
op
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
input_node
.
shape
)
self
.
transpose_shape
(
output_shape
.
dims
,
[
0
,
3
,
1
,
2
])
dims_arg
=
op
.
arg
.
add
()
dims_arg
.
name
=
MaceKeyword
.
mace_dims_str
dims_arg
.
ints
.
extend
([
0
,
3
,
1
,
2
])
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NCHW
)
for
output_node
in
self
.
_option
.
output_nodes
.
values
():
output_name
=
MaceKeyword
.
mace_output_node_name
\
+
'_'
+
output_node
.
name
...
...
@@ -673,6 +739,8 @@ class Transformer(base_converter.ConverterInterface):
dims_arg
.
name
=
MaceKeyword
.
mace_dims_str
dims_arg
.
ints
.
extend
([
0
,
2
,
3
,
1
])
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NHWC
)
return
False
def
transpose_filters
(
self
):
...
...
@@ -695,12 +763,19 @@ class Transformer(base_converter.ConverterInterface):
filter_data
=
filter_data
.
transpose
(
3
,
2
,
0
,
1
)
filter
.
float_data
[:]
=
filter_data
.
flat
filter
.
dims
[:]
=
filter_data
.
shape
if
op
.
type
==
MaceOp
.
FullyConnected
.
name
:
weight
=
self
.
_consts
[
op
.
input
[
1
]]
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight
.
dims
)
weight_data
=
weight_data
.
transpose
(
1
,
0
)
weight
.
float_data
[:]
=
weight_data
.
flat
weight
.
dims
[:]
=
weight_data
.
shape
self
.
set_filter_format
(
FilterFormat
.
OIHW
)
return
False
def
reshape_fc_weight
(
self
):
print
(
"Reshape fully connec
r
ted weight shape"
)
print
(
"Reshape fully connected weight shape"
)
net
=
self
.
_model
for
op
in
net
.
op
:
if
op
.
type
==
MaceOp
.
FullyConnected
.
name
:
...
...
@@ -789,6 +864,8 @@ class Transformer(base_converter.ConverterInterface):
arg
.
name
=
MaceKeyword
.
mace_buffer_type
arg
.
i
=
OpenCLBufferType
.
IN_OUT_CHANNEL
.
value
ConverterUtil
.
add_data_format_arg
(
op_def
,
DataFormat
.
NHWC
)
for
output_node
in
self
.
_option
.
output_nodes
.
values
():
output_name
=
MaceKeyword
.
mace_output_node_name
\
+
'_'
+
output_node
.
name
...
...
@@ -804,14 +881,16 @@ class Transformer(base_converter.ConverterInterface):
arg
.
name
=
MaceKeyword
.
mace_buffer_type
arg
.
i
=
OpenCLBufferType
.
IN_OUT_CHANNEL
.
value
ConverterUtil
.
add_data_format_arg
(
op_def
,
DataFormat
.
NHWC
)
return
False
def
fold_
softmax
(
self
):
def
fold_
reshape
(
self
):
changed
=
False
net
=
self
.
_model
for
op
in
net
.
op
:
if
op
.
type
==
MaceOp
.
Softmax
.
name
:
print
(
"Fold
softmax
: %s(%s)"
%
(
op
.
name
,
op
.
type
))
if
op
.
type
==
MaceOp
.
Softmax
.
name
or
op
.
type
==
MaceOp
.
MatMul
.
name
:
print
(
"Fold
reshape
: %s(%s)"
%
(
op
.
name
,
op
.
type
))
if
self
.
consumer_count
(
op
.
output
[
0
])
==
1
:
consumer
=
self
.
_consumers
[
op
.
output
[
0
]][
0
]
if
consumer
.
type
==
MaceOp
.
Reshape
.
name
:
...
...
@@ -819,15 +898,14 @@ class Transformer(base_converter.ConverterInterface):
MaceKeyword
.
mace_shape_str
).
ints
# noqa
del
op
.
output_shape
[
0
].
dims
[:]
op
.
output_shape
[
0
].
dims
.
extend
(
shape
)
self
.
replace_output_node
(
consumer
)
net
.
op
.
remove
(
consumer
)
self
.
safe_remove_node
(
consumer
,
op
)
changed
=
True
producer
=
self
.
_producer
[
op
.
input
[
0
]]
if
producer
.
type
==
MaceOp
.
Reshape
.
name
:
op
.
input
[
0
]
=
producer
.
input
[
0
]
self
.
replace_output_node
(
producer
)
net
.
op
.
remove
(
producer
)
self
.
safe_remove_node
(
producer
,
self
.
_producer
[
producer
.
input
[
0
]]
)
changed
=
True
if
len
(
op
.
output_shape
[
0
].
dims
)
<
4
:
...
...
@@ -840,6 +918,20 @@ class Transformer(base_converter.ConverterInterface):
return
False
def
transform_matmul_to_fc
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
if
op
.
type
==
MaceOp
.
MatMul
.
name
:
input_shape
=
self
.
get_tensor_shape
(
op
.
input
[
0
])
_
,
h
,
w
,
_
=
self
.
sort_feature_map_shape
(
input_shape
,
ConverterUtil
.
data_format
(
self
.
_producer
[
op
.
input
[
0
]]))
# noqa
if
h
==
1
and
w
==
1
and
op
.
input
[
1
]
in
self
.
_consts
:
weight
=
self
.
_consts
[
op
.
input
[
1
]]
if
len
(
weight
.
dims
)
==
2
:
op
.
type
=
MaceOp
.
FullyConnected
.
name
return
False
def
transform_global_conv_to_fc
(
self
):
"""Transform global conv to fc should be placed after transposing
input/output and filter"""
...
...
@@ -918,4 +1010,8 @@ class Transformer(base_converter.ConverterInterface):
del
net
.
op
[:]
net
.
op
.
extend
(
sorted_nodes
)
print
(
"Final ops:"
)
for
op
in
net
.
op
:
print
(
"%s (%s)"
%
(
op
.
name
,
op
.
type
))
return
False
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录