Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
3e82ad67
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3e82ad67
编写于
5月 08, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor model converter and transformer
上级
04f7a34a
变更
20
展开全部
隐藏空白更改
内联
并排
Showing
20 changed file
with
2411 addition
and
2802 deletion
+2411
-2802
mace/core/mace.cc
mace/core/mace.cc
+5
-5
mace/ops/fully_connected.cc
mace/ops/fully_connected.cc
+3
-3
mace/ops/fully_connected_benchmark.cc
mace/ops/fully_connected_benchmark.cc
+2
-2
mace/ops/fully_connected_test.cc
mace/ops/fully_connected_test.cc
+6
-6
mace/proto/mace.proto
mace/proto/mace.proto
+1
-0
mace/python/tools/BUILD
mace/python/tools/BUILD
+14
-22
mace/python/tools/caffe_converter_lib.py
mace/python/tools/caffe_converter_lib.py
+0
-1213
mace/python/tools/convert_util.py
mace/python/tools/convert_util.py
+6
-0
mace/python/tools/converter.py
mace/python/tools/converter.py
+72
-16
mace/python/tools/converter_tool/__init__.py
mace/python/tools/converter_tool/__init__.py
+0
-0
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+259
-0
mace/python/tools/converter_tool/caffe_converter.py
mace/python/tools/converter_tool/caffe_converter.py
+508
-0
mace/python/tools/converter_tool/shape_inference.py
mace/python/tools/converter_tool/shape_inference.py
+149
-0
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+442
-0
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+914
-0
mace/python/tools/memory_optimizer.py
mace/python/tools/memory_optimizer.py
+10
-4
mace/python/tools/source_converter_lib.py
mace/python/tools/source_converter_lib.py
+12
-1
mace/python/tools/tf_converter_lib.py
mace/python/tools/tf_converter_lib.py
+0
-1522
mace/test/mace_api_mt_test.cc
mace/test/mace_api_mt_test.cc
+4
-4
mace/test/mace_api_test.cc
mace/test/mace_api_test.cc
+4
-4
未找到文件。
mace/core/mace.cc
浏览文件 @
3e82ad67
...
...
@@ -119,11 +119,11 @@ MaceEngine::Impl::Impl(const NetDef *net_def,
LOG
(
INFO
)
<<
"MACE version: "
<<
MaceVersion
();
// Set storage path for internal usage
for
(
auto
input_name
:
input_nodes
)
{
ws_
->
CreateTensor
(
MakeString
(
"mace_input_node_"
,
input_name
,
":0"
),
ws_
->
CreateTensor
(
MakeString
(
"mace_input_node_"
,
input_name
),
GetDeviceAllocator
(
device_type_
),
DT_FLOAT
);
}
for
(
auto
output_name
:
output_nodes
)
{
ws_
->
CreateTensor
(
MakeString
(
"mace_output_node_"
,
output_name
,
":0"
),
ws_
->
CreateTensor
(
MakeString
(
"mace_output_node_"
,
output_name
),
GetDeviceAllocator
(
device_type_
),
DT_FLOAT
);
}
#ifdef MACE_ENABLE_HEXAGON
...
...
@@ -182,7 +182,7 @@ MaceStatus MaceEngine::Impl::Run(
"The Inputs' shape must be 4-dimension with NHWC format,"
" please use 1 to fill missing dimensions"
);
Tensor
*
input_tensor
=
ws_
->
GetTensor
(
MakeString
(
"mace_input_node_"
,
input
.
first
,
":0"
));
ws_
->
GetTensor
(
MakeString
(
"mace_input_node_"
,
input
.
first
));
input_tensor
->
Resize
(
input
.
second
.
shape
());
{
Tensor
::
MappingGuard
input_guard
(
input_tensor
);
...
...
@@ -199,7 +199,7 @@ MaceStatus MaceEngine::Impl::Run(
" please use 1 to fill missing dimensions"
);
}
Tensor
*
output_tensor
=
ws_
->
GetTensor
(
MakeString
(
"mace_output_node_"
,
output
.
first
+
":0"
));
ws_
->
GetTensor
(
MakeString
(
"mace_output_node_"
,
output
.
first
));
output_tensors
.
push_back
(
output_tensor
);
}
#ifdef MACE_ENABLE_HEXAGON
...
...
@@ -223,7 +223,7 @@ MaceStatus MaceEngine::Impl::Run(
#endif
for
(
auto
&
output
:
*
outputs
)
{
Tensor
*
output_tensor
=
ws_
->
GetTensor
(
MakeString
(
"mace_output_node_"
,
output
.
first
+
":0"
));
ws_
->
GetTensor
(
MakeString
(
"mace_output_node_"
,
output
.
first
));
// save output
if
(
output_tensor
!=
nullptr
&&
output
.
second
.
data
()
!=
nullptr
)
{
Tensor
::
MappingGuard
output_guard
(
output_tensor
);
...
...
mace/ops/fully_connected.cc
浏览文件 @
3e82ad67
...
...
@@ -18,20 +18,20 @@ namespace mace {
namespace
ops
{
void
Register_FullyConnected
(
OperatorRegistry
*
op_registry
)
{
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"F
C
"
)
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"F
ullyConnected
"
)
.
Device
(
DeviceType
::
CPU
)
.
TypeConstraint
<
float
>
(
"T"
)
.
Build
(),
FullyConnectedOp
<
DeviceType
::
CPU
,
float
>
);
#ifdef MACE_ENABLE_OPENCL
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"F
C
"
)
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"F
ullyConnected
"
)
.
Device
(
DeviceType
::
GPU
)
.
TypeConstraint
<
float
>
(
"T"
)
.
Build
(),
FullyConnectedOp
<
DeviceType
::
GPU
,
float
>
);
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"F
C
"
)
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"F
ullyConnected
"
)
.
Device
(
DeviceType
::
GPU
)
.
TypeConstraint
<
half
>
(
"T"
)
.
Build
(),
...
...
mace/ops/fully_connected_benchmark.cc
浏览文件 @
3e82ad67
...
...
@@ -37,7 +37,7 @@ void FCBenchmark(
net
.
AddRandomInput
<
D
,
float
>
(
"Bias"
,
{
out_channel
});
if
(
D
==
DeviceType
::
CPU
)
{
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"Input"
)
.
Input
(
"Weight"
)
.
Input
(
"Bias"
)
...
...
@@ -52,7 +52,7 @@ void FCBenchmark(
BufferToImage
<
D
,
T
>
(
&
net
,
"Bias"
,
"BiasImage"
,
kernels
::
BufferType
::
ARGUMENT
);
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"InputImage"
)
.
Input
(
"WeightImage"
)
.
Input
(
"BiasImage"
)
...
...
mace/ops/fully_connected_test.cc
浏览文件 @
3e82ad67
...
...
@@ -42,7 +42,7 @@ void Simple(const std::vector<index_t> &input_shape,
if
(
D
==
DeviceType
::
CPU
)
{
net
.
Transpose2D
<
D
,
float
>
(
"Weight"
,
"WeightTranspose"
);
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"Input"
)
.
Input
(
"Weight"
)
.
Input
(
"Bias"
)
...
...
@@ -59,7 +59,7 @@ void Simple(const std::vector<index_t> &input_shape,
BufferToImage
<
D
,
float
>
(
&
net
,
"Bias"
,
"BiasImage"
,
kernels
::
BufferType
::
ARGUMENT
);
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"InputImage"
)
.
Input
(
"WeightImage"
)
.
Input
(
"BiasImage"
)
...
...
@@ -142,7 +142,7 @@ void Complex(const index_t batch,
"Weight"
,
{
out_channel
,
height
*
width
*
channels
});
net
.
AddRandomInput
<
DeviceType
::
GPU
,
float
>
(
"Bias"
,
{
out_channel
});
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"Input"
)
.
Input
(
"Weight"
)
.
Input
(
"Bias"
)
...
...
@@ -166,7 +166,7 @@ void Complex(const index_t batch,
BufferToImage
<
DeviceType
::
GPU
,
float
>
(
&
net
,
"Bias"
,
"BiasImage"
,
kernels
::
BufferType
::
ARGUMENT
);
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"InputImage"
)
.
Input
(
"WeightImage"
)
.
Input
(
"BiasImage"
)
...
...
@@ -231,7 +231,7 @@ void TestWXFormat(const index_t batch,
"Weight"
,
{
out_channel
,
height
*
width
*
channels
});
net
.
AddRandomInput
<
DeviceType
::
GPU
,
float
>
(
"Bias"
,
{
out_channel
});
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"Input"
)
.
Input
(
"Weight"
)
.
Input
(
"Bias"
)
...
...
@@ -255,7 +255,7 @@ void TestWXFormat(const index_t batch,
BufferToImage
<
DeviceType
::
GPU
,
T
>
(
&
net
,
"Bias"
,
"BiasImage"
,
kernels
::
BufferType
::
ARGUMENT
);
OpDefBuilder
(
"F
C
"
,
"FullyConnectedTest"
)
OpDefBuilder
(
"F
ullyConnected
"
,
"FullyConnectedTest"
)
.
Input
(
"InputImage"
)
.
Input
(
"WeightImage"
)
.
Input
(
"BiasImage"
)
...
...
mace/proto/mace.proto
浏览文件 @
3e82ad67
...
...
@@ -10,6 +10,7 @@ enum NetMode {
enum
DeviceType
{
CPU
=
0
;
// In default, we will use CPU.
GPU
=
2
;
HEXAGON
=
3
;
}
enum
DataType
{
...
...
mace/python/tools/BUILD
浏览文件 @
3e82ad67
py_library
(
name
=
"
tf_
converter_lib"
,
name
=
"converter_lib"
,
srcs
=
[
"convert_util.py"
,
"graph_util.py"
,
"tf_converter_lib.py"
,
"tf_dsp_converter_lib.py"
,
"converter_tool/base_converter.py"
,
"converter_tool/shape_inference.py"
,
"converter_tool/tensorflow_converter.py"
,
"converter_tool/caffe_converter.py"
,
"converter_tool/transformer.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":memory_optimizer"
,
"//mace/proto:mace_py"
,
],
)
py_library
(
name
=
"caffe_converter_lib"
,
srcs
=
[
"caffe_converter_lib.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":memory_optimizer"
,
"//mace/third_party/caffe:caffe_py"
,
],
)
...
...
@@ -37,22 +30,21 @@ py_library(
)
py_binary
(
name
=
"
convert
er"
,
srcs
=
[
"
convert
er.py"
],
name
=
"
memory_optimiz
er"
,
srcs
=
[
"
memory_optimiz
er.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":caffe_converter_lib"
,
":source_converter_lib"
,
":tf_converter_lib"
,
"@six_archive//:six"
,
"//mace/proto:mace_py"
,
],
)
py_binary
(
name
=
"
memory_optimiz
er"
,
srcs
=
[
"
memory_optimiz
er.py"
],
name
=
"
convert
er"
,
srcs
=
[
"
convert
er.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
"//mace/proto:mace_py"
,
":converter_lib"
,
":source_converter_lib"
,
"@six_archive//:six"
,
],
)
mace/python/tools/caffe_converter_lib.py
已删除
100644 → 0
浏览文件 @
04f7a34a
此差异已折叠。
点击以展开。
mace/python/tools/convert_util.py
浏览文件 @
3e82ad67
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
tensorflow
as
tf
from
mace.proto
import
mace_pb2
...
...
@@ -40,3 +41,8 @@ def tf_dtype_2_mace_dtype(tf_dtype):
if
not
mace_dtype
:
raise
Exception
(
"Not supported tensorflow dtype: "
+
tf_dtype
)
return
mace_dtype
def
mace_check
(
condition
,
msg
):
if
not
condition
:
raise
Exception
(
msg
)
mace/python/tools/converter.py
浏览文件 @
3e82ad67
...
...
@@ -16,7 +16,16 @@ import argparse
import
sys
import
hashlib
import
os.path
from
mace.proto
import
mace_pb2
from
mace.python.tools
import
tf_dsp_converter_lib
from
mace.python.tools
import
memory_optimizer
from
mace.python.tools
import
source_converter_lib
from
mace.python.tools.converter_tool
import
base_converter
as
cvt
from
mace.python.tools.converter_tool
import
tensorflow_converter
from
mace.python.tools.converter_tool
import
caffe_converter
from
mace.python.tools.converter_tool
import
transformer
# ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb \
# --output quantized_test_dsp.pb \
...
...
@@ -25,6 +34,12 @@ from mace.python.tools import source_converter_lib
FLAGS
=
None
data_type_map
=
{
'DT_HALF'
:
mace_pb2
.
DT_HALF
,
'DT_FLOAT'
:
mace_pb2
.
DT_FLOAT
}
device_type_map
=
{
'cpu'
:
mace_pb2
.
CPU
,
'gpu'
:
mace_pb2
.
GPU
,
'dsp'
:
mace_pb2
.
HEXAGON
}
def
file_checksum
(
fname
):
hash_func
=
hashlib
.
sha256
()
...
...
@@ -34,6 +49,10 @@ def file_checksum(fname):
return
hash_func
.
hexdigest
()
def
parse_int_array_from_str
(
ints_str
):
return
[
int
(
int_str
)
for
int_str
in
ints_str
.
split
(
','
)]
def
main
(
unused_args
):
if
not
os
.
path
.
isfile
(
FLAGS
.
model_file
):
print
(
"Input graph file '"
+
FLAGS
.
model_file
+
"' does not exist!"
)
...
...
@@ -59,27 +78,64 @@ def main(unused_args):
(
weight_checksum
,
FLAGS
.
weight_checksum
))
sys
.
exit
(
-
1
)
if
FLAGS
.
runtime
==
'dsp'
:
print
(
"DSP not support caffe model yet."
)
sys
.
exit
(
-
1
)
if
FLAGS
.
platform
not
in
[
'tensorflow'
,
'caffe'
]:
print
(
"platform %s is not supported."
%
FLAGS
.
platform
)
sys
.
exit
(
-
1
)
if
FLAGS
.
runtime
not
in
[
'cpu'
,
'gpu'
,
'dsp'
]:
print
(
"runtime %s is not supported."
%
FLAGS
.
runtime
)
sys
.
exit
(
-
1
)
from
mace.python.tools
import
caffe_converter_lib
output_graph_def
=
caffe_converter_lib
.
convert_to_mace_pb
(
FLAGS
.
model_file
,
FLAGS
.
weight_file
,
FLAGS
.
input_node
,
FLAGS
.
input_shape
,
FLAGS
.
output_node
,
FLAGS
.
data_type
,
FLAGS
.
runtime
,
FLAGS
.
winograd
)
elif
FLAGS
.
platform
==
'tensorflow'
:
if
FLAGS
.
runtime
==
'dsp'
:
from
mace.python.tools
import
tf_dsp_converter_lib
if
FLAGS
.
runtime
==
'dsp'
:
if
FLAGS
.
platform
==
'tensorflow'
:
output_graph_def
=
tf_dsp_converter_lib
.
convert_to_mace_pb
(
FLAGS
.
model_file
,
FLAGS
.
input_node
,
FLAGS
.
output_node
,
FLAGS
.
dsp_mode
)
else
:
from
mace.python.tools
import
tf_converter_lib
output_graph_def
=
tf_converter_lib
.
convert_to_mace_pb
(
FLAGS
.
model_file
,
FLAGS
.
input_node
,
FLAGS
.
input_shape
,
FLAGS
.
output_node
,
FLAGS
.
data_type
,
FLAGS
.
runtime
,
FLAGS
.
winograd
)
print
(
"%s does not support dsp runtime yet."
%
FLAGS
.
platform
)
sys
.
exit
(
-
1
)
else
:
option
=
cvt
.
ConverterOption
()
option
.
data_type
=
data_type_map
[
FLAGS
.
data_type
]
option
.
device
=
device_type_map
[
FLAGS
.
runtime
]
option
.
winograd_enabled
=
bool
(
FLAGS
.
winograd
)
input_node_names
=
FLAGS
.
input_node
.
split
(
','
)
input_node_shapes
=
FLAGS
.
input_shape
.
split
(
':'
)
if
len
(
input_node_names
)
!=
len
(
input_node_shapes
):
raise
Exception
(
'input node count and shape count do not match.'
)
for
i
in
xrange
(
len
(
input_node_names
)):
input_node
=
cvt
.
NodeInfo
()
input_node
.
name
=
input_node_names
[
i
]
input_node
.
shape
=
parse_int_array_from_str
(
FLAGS
.
input_shape
)
option
.
add_input_node
(
input_node
)
output_node_names
=
FLAGS
.
output_node
.
split
(
','
)
for
i
in
xrange
(
len
(
output_node_names
)):
output_node
=
cvt
.
NodeInfo
()
output_node
.
name
=
output_node_names
[
i
]
option
.
add_output_node
(
output_node
)
print
(
"Convert model to mace model."
)
if
FLAGS
.
platform
==
'tensorflow'
:
converter
=
tensorflow_converter
.
TensorflowConverter
(
option
,
FLAGS
.
model_file
)
# noqa
elif
FLAGS
.
platform
==
'caffe'
:
converter
=
caffe_converter
.
CaffeConverter
(
option
,
FLAGS
.
model_file
,
FLAGS
.
weight_file
)
output_graph_def
=
converter
.
run
()
print
(
"Transform model to one that can better run on device."
)
# TODO(liuqi/liyin): transform gpu/cpu and merge their ops
mace_transformer
=
transformer
.
Transformer
(
option
,
output_graph_def
)
output_graph_def
=
mace_transformer
.
run
()
print
"start optimize memory."
if
FLAGS
.
runtime
==
'gpu'
:
memory_optimizer
.
optimize_gpu_memory
(
output_graph_def
)
elif
FLAGS
.
runtime
==
'cpu'
:
memory_optimizer
.
optimize_cpu_memory
(
output_graph_def
)
print
"Memory optimization done."
if
FLAGS
.
output_type
==
'source'
:
source_converter_lib
.
convert_to_source
(
...
...
mace/python/tools/converter_tool/__init__.py
0 → 100644
浏览文件 @
3e82ad67
mace/python/tools/converter_tool/base_converter.py
0 → 100644
浏览文件 @
3e82ad67
from
enum
import
Enum
from
mace.proto
import
mace_pb2
class
DataFormat
(
Enum
):
NHWC
=
0
NCHW
=
1
class
FilterFormat
(
Enum
):
HWIO
=
0
OIHW
=
1
HWOI
=
2
class
PaddingMode
(
Enum
):
VALID
=
0
SAME
=
1
FULL
=
2
class
PoolingType
(
Enum
):
AVG
=
1
MAX
=
2
class
ActivationType
(
Enum
):
NOOP
=
0
RELU
=
1
RELUX
=
2
PRELU
=
3
TANH
=
4
SIGMOID
=
5
class
EltwiseType
(
Enum
):
SUM
=
0
SUB
=
1
PROD
=
2
DIV
=
3
MIN
=
4
MAX
=
5
NEG
=
6
ABS
=
7
SQR_DIFF
=
8
POW
=
9
MaceSupportedOps
=
[
'Activation'
,
'AddN'
,
'BatchNorm'
,
'BatchToSpaceND'
,
'BiasAdd'
,
'ChannelShuffle'
,
'Concat'
,
'Conv2D'
,
'Deconv2D'
,
'DepthToSpace'
,
'DepthwiseConv2d'
,
'Dequantize'
,
'Eltwise'
,
'FoldedBatchNorm'
,
'FullyConnected'
,
'LocalResponseNorm'
,
'MatMul'
,
'Pad'
,
'Pooling'
,
'Proposal'
,
'PSROIAlign'
,
'Quantize'
,
'Requantize'
,
'Reshape'
,
'ResizeBilinear'
,
'Slice'
,
'Softmax'
,
'SpaceToBatchND'
,
'SpaceToDepth'
,
'Transpose'
,
'WinogradInverseTransform'
,
'WinogradTransform'
,
]
MaceOp
=
Enum
(
'MaceOp'
,
[(
op
,
op
)
for
op
in
MaceSupportedOps
],
type
=
str
)
class
MaceKeyword
(
object
):
# node related str
mace_input_node_name
=
'mace_input_node'
mace_output_node_name
=
'mace_output_node'
mace_buffer_type
=
'buffer_type'
mace_mode
=
'mode'
mace_buffer_to_image
=
'BufferToImage'
mace_image_to_buffer
=
'ImageToBuffer'
# arg related str
mace_padding_str
=
'padding'
mace_padding_values_str
=
'padding_values'
mace_strides_str
=
'strides'
mace_dilations_str
=
'dilations'
mace_pooling_type_str
=
'pooling_type'
mace_global_pooling_str
=
'global_pooling'
mace_kernel_str
=
'kernels'
mace_data_format_str
=
'data_format'
mace_filter_format_str
=
'filter_format'
mace_element_type_str
=
'type'
mace_activation_type_str
=
'activation'
mace_activation_max_limit_str
=
'max_limit'
mace_resize_size_str
=
'size'
mace_batch_to_space_crops_str
=
'crops'
mace_paddings_str
=
'paddings'
mace_align_corners_str
=
'align_corners'
mace_space_batch_block_shape_str
=
'block_shape'
mace_space_depth_block_size_str
=
'block_size'
mace_constant_value_str
=
'constant_value'
mace_dims_str
=
'dims'
mace_axis_str
=
'axis'
mace_shape_str
=
'shape'
mace_winograd_filter_transformed
=
'is_filter_transformed'
class
ConverterInterface
(
object
):
"""Base class for converting external models to mace models."""
def
run
(
self
):
raise
NotImplementedError
(
'run'
)
class
NodeInfo
(
object
):
"""A class for describing node information"""
def
__init__
(
self
):
self
.
_name
=
None
self
.
_shape
=
[]
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
shape
(
self
):
return
self
.
_shape
@
name
.
setter
def
name
(
self
,
name
):
self
.
_name
=
name
@
shape
.
setter
def
shape
(
self
,
shape
):
self
.
_shape
=
shape
def
__str__
(
self
):
return
'%s %s'
%
(
self
.
_name
,
str
(
self
.
_shape
))
class
ConverterOption
(
object
):
"""A class for specifying options passed to converter tool"""
def
__init__
(
self
):
self
.
_input_nodes
=
{}
self
.
_output_nodes
=
{}
self
.
_data_type
=
mace_pb2
.
DT_FLOAT
self
.
_device
=
mace_pb2
.
CPU
self
.
_winograd_enabled
=
False
@
property
def
input_nodes
(
self
):
return
self
.
_input_nodes
@
property
def
output_nodes
(
self
):
return
self
.
_output_nodes
@
property
def
data_type
(
self
):
return
self
.
_data_type
@
property
def
device
(
self
):
return
self
.
_device
@
property
def
winograd_enabled
(
self
):
return
self
.
_winograd_enabled
@
input_nodes
.
setter
def
input_nodes
(
self
,
input_nodes
):
for
node
in
input_nodes
:
self
.
_input_nodes
[
node
.
name
]
=
node
def
add_input_node
(
self
,
input_node
):
self
.
_input_nodes
[
input_node
.
name
]
=
input_node
@
output_nodes
.
setter
def
output_nodes
(
self
,
output_nodes
):
for
node
in
output_nodes
:
self
.
output_nodes
[
node
.
name
]
=
node
def
add_output_node
(
self
,
output_node
):
self
.
_output_nodes
[
output_node
.
name
]
=
output_node
@
data_type
.
setter
def
data_type
(
self
,
data_type
):
self
.
_data_type
=
data_type
@
device
.
setter
def
device
(
self
,
device
):
self
.
_device
=
device
@
winograd_enabled
.
setter
def
winograd_enabled
(
self
,
winograd_enabled
):
self
.
_winograd_enabled
=
winograd_enabled
class
ConverterUtil
(
object
):
@
staticmethod
def
get_arg
(
op
,
arg_name
):
for
arg
in
op
.
arg
:
if
arg
.
name
==
arg_name
:
return
arg
return
None
@
staticmethod
def
add_data_format_arg
(
op
,
data_format
):
data_format_arg
=
op
.
arg
.
add
()
data_format_arg
.
name
=
MaceKeyword
.
mace_data_format_str
data_format_arg
.
i
=
data_format
.
value
@
staticmethod
def
data_format
(
op
):
arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_data_format_str
)
if
arg
is
None
:
return
None
elif
arg
.
i
==
DataFormat
.
NHWC
.
value
:
return
DataFormat
.
NHWC
elif
arg
.
i
==
DataFormat
.
NCHW
.
value
:
return
DataFormat
.
NCHW
else
:
return
None
@
staticmethod
def
set_filter_format
(
net
,
filter_format
):
arg
=
net
.
arg
.
add
()
arg
.
name
=
MaceKeyword
.
mace_filter_format_str
arg
.
i
=
filter_format
.
value
@
staticmethod
def
filter_format
(
net
):
arg
=
ConverterUtil
.
get_arg
(
net
,
MaceKeyword
.
mace_filter_format_str
)
if
arg
is
None
:
return
None
elif
arg
.
i
==
FilterFormat
.
HWIO
.
value
:
return
FilterFormat
.
HWIO
elif
arg
.
i
==
FilterFormat
.
HWOI
.
value
:
return
FilterFormat
.
HWOI
elif
arg
.
i
==
FilterFormat
.
OIHW
.
value
:
return
FilterFormat
.
OIHW
else
:
return
None
mace/python/tools/converter_tool/caffe_converter.py
0 → 100644
浏览文件 @
3e82ad67
import
math
import
numpy
as
np
import
google.protobuf.text_format
from
mace.proto
import
mace_pb2
from
mace.third_party.caffe
import
caffe_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
mace.python.tools.converter_tool
import
shape_inference
from
mace.python.tools.converter_tool.base_converter
import
PoolingType
from
mace.python.tools.converter_tool.base_converter
import
ActivationType
from
mace.python.tools.converter_tool.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool.base_converter
import
DataFormat
from
mace.python.tools.converter_tool.base_converter
import
FilterFormat
from
mace.python.tools.converter_tool.base_converter
import
MaceOp
from
mace.python.tools.converter_tool.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool.base_converter
import
ConverterUtil
from
mace.python.tools.convert_util
import
mace_check
caffe_group_str
=
'group'
caffe_kernel_h_str
=
'kernel_h'
caffe_kernel_w_str
=
'kernel_w'
caffe_stride_h_str
=
'stride_h'
caffe_stride_w_str
=
'stride_w'
caffe_pad_h_str
=
'pad_h'
caffe_pad_w_str
=
'pad_w'
class
CaffeOperator
(
object
):
"""CaffeOperator merges and provides both layer and weights information.
Layer records caffe layer proto, while blobs records the weight data in
format of numpy ndarray.
"""
def
__init__
(
self
):
self
.
_layer
=
None
self
.
_blobs
=
None
@
property
def
name
(
self
):
return
self
.
_layer
.
name
@
property
def
type
(
self
):
return
self
.
_layer
.
type
@
property
def
layer
(
self
):
return
self
.
_layer
@
property
def
blobs
(
self
):
return
self
.
_blobs
@
layer
.
setter
def
layer
(
self
,
layer
):
self
.
_layer
=
layer
@
blobs
.
setter
def
blobs
(
self
,
blobs
):
self
.
_blobs
=
[
self
.
blob_to_nparray
(
blob
)
for
blob
in
blobs
]
def
get_blob
(
self
,
index
):
mace_check
(
index
<
len
(
self
.
_blobs
),
"blob out of index"
)
return
self
.
_blobs
[
index
]
@
staticmethod
def
blob_to_nparray
(
blob
):
if
blob
.
num
!=
0
:
return
(
np
.
asarray
(
blob
.
data
,
dtype
=
np
.
float32
).
reshape
(
(
blob
.
num
,
blob
.
channels
,
blob
.
height
,
blob
.
width
)))
else
:
return
np
.
asarray
(
blob
.
data
,
dtype
=
np
.
float32
).
reshape
(
blob
.
shape
.
dim
)
class
CaffeNet
(
object
):
"""CaffeNet contains caffe operations. Output of each layer has unique
name as we replace duplicated output name with unique one, while keep
mace input/output name which user specifies unchanged."""
def
__init__
(
self
):
self
.
_ops
=
{}
self
.
_consumers
=
{}
# for in-place op, its input name is the same with output name,
# so we change the output name to an alias
self
.
_alias_op_output_name
=
{}
self
.
_used_op_output_name
=
set
()
@
property
def
ops
(
self
):
return
self
.
_ops
.
values
()
def
get_op
(
self
,
op_name
):
return
self
.
_ops
.
get
(
op_name
,
None
)
def
get_consumers
(
self
,
tensor_name
):
return
self
.
_consumers
.
get
(
tensor_name
,
[])
def
add_layer
(
self
,
layer
):
op
=
CaffeOperator
()
op
.
layer
=
layer
self
.
_ops
[
layer
.
name
]
=
op
# change op output name if it is an in-place op
layer
.
bottom
[:]
=
[
self
.
_alias_op_output_name
.
get
(
layer_input
,
layer_input
)
for
layer_input
in
layer
.
bottom
][:]
for
i
in
xrange
(
len
(
layer
.
top
)):
old_name
=
layer
.
top
[
i
]
if
layer
.
type
==
'Input'
:
new_name
=
old_name
else
:
idx
=
0
new_name
=
old_name
+
'#'
+
str
(
idx
)
while
new_name
in
self
.
_used_op_output_name
:
idx
+=
1
new_name
=
old_name
+
'#'
+
str
(
idx
)
layer
.
top
[
i
]
=
new_name
self
.
_alias_op_output_name
[
old_name
]
=
new_name
self
.
_used_op_output_name
.
update
([
new_name
])
for
input_tensor
in
layer
.
bottom
:
if
input_tensor
not
in
self
.
_consumers
:
self
.
_consumers
[
input_tensor
]
=
[]
self
.
_consumers
[
input_tensor
].
append
(
op
)
def
add_blob
(
self
,
weight
):
if
weight
.
name
in
self
.
_ops
:
op
=
self
.
_ops
[
weight
.
name
]
op
.
blobs
=
list
(
weight
.
blobs
)
class
CaffeConverter
(
base_converter
.
ConverterInterface
):
"""A class for convert caffe model to mace model."""
pooling_type_mode
=
{
caffe_pb2
.
PoolingParameter
.
AVE
:
PoolingType
.
AVG
,
caffe_pb2
.
PoolingParameter
.
MAX
:
PoolingType
.
MAX
}
eltwise_type
=
{
caffe_pb2
.
EltwiseParameter
.
PROD
:
EltwiseType
.
PROD
,
caffe_pb2
.
EltwiseParameter
.
SUM
:
EltwiseType
.
SUM
,
caffe_pb2
.
EltwiseParameter
.
MAX
:
EltwiseType
.
MAX
,
}
activation_type
=
{
'ReLU'
:
ActivationType
.
RELU
,
'PReLU'
:
ActivationType
.
PRELU
,
'TanH'
:
ActivationType
.
TANH
,
}
def
__init__
(
self
,
option
,
src_model_file
,
src_weight_file
):
self
.
_op_converters
=
{
'Input'
:
self
.
convert_nop
,
'Convolution'
:
self
.
convert_conv2d
,
'Eltwise'
:
self
.
convert_elementwise
,
'Add'
:
self
.
convert_add
,
'ReLU'
:
self
.
convert_activation
,
'TanH'
:
self
.
convert_activation
,
'Sigmoid'
:
self
.
convert_activation
,
'PReLU'
:
self
.
convert_activation
,
'Pooling'
:
self
.
convert_pooling
,
'Concat'
:
self
.
convert_concat
,
'Slice'
:
self
.
convert_slice
,
'Softmax'
:
self
.
convert_softmax
,
'InnerProduct'
:
self
.
convert_fully_connected
,
'BatchNorm'
:
self
.
convert_folded_batchnorm
,
}
self
.
_option
=
option
self
.
_mace_net_def
=
mace_pb2
.
NetDef
()
ConverterUtil
.
set_filter_format
(
self
.
_mace_net_def
,
FilterFormat
.
OIHW
)
self
.
_caffe_net
=
CaffeNet
()
self
.
_caffe_layers
=
caffe_pb2
.
NetParameter
()
caffe_weights
=
caffe_pb2
.
NetParameter
()
# parse prototxt
with
open
(
src_model_file
,
'rb'
)
as
f
:
google
.
protobuf
.
text_format
.
Merge
(
str
(
f
.
read
()),
self
.
_caffe_layers
)
self
.
filter_test_layers
(
self
.
_caffe_layers
)
for
layer
in
self
.
_caffe_layers
.
layer
:
self
.
_caffe_net
.
add_layer
(
layer
)
# parse model weight
with
open
(
src_weight_file
,
'rb'
)
as
f
:
caffe_weights
.
ParseFromString
(
f
.
read
())
self
.
filter_test_layers
(
caffe_weights
)
for
weight
in
caffe_weights
.
layer
:
self
.
_caffe_net
.
add_blob
(
weight
)
self
.
_skip_ops
=
[]
def
run
(
self
):
self
.
convert_ops
()
shape_inferer
=
shape_inference
.
ShapeInference
(
self
.
_mace_net_def
,
self
.
_option
.
input_nodes
.
values
())
shape_inferer
.
run
()
self
.
replace_output_tensor_name
()
return
self
.
_mace_net_def
@
staticmethod
def
replace_input_name
(
ops
,
src_name
,
dst_name
):
for
op
in
ops
:
for
i
in
xrange
(
len
(
op
.
input
)):
if
op
.
input
[
i
]
==
src_name
:
op
.
input
[
i
]
=
dst_name
def
replace_output_tensor_name
(
self
):
consumers
=
{}
for
op
in
self
.
_mace_net_def
.
op
:
for
input_name
in
op
.
input
:
if
input_name
not
in
consumers
:
consumers
[
input_name
]
=
[]
consumers
[
input_name
].
append
(
op
)
# replace the last op with same prefix name with the original top name
ops
=
[
op
for
op
in
self
.
_mace_net_def
.
op
]
ops
.
reverse
()
visited
=
set
()
for
op
in
ops
:
for
i
in
xrange
(
len
(
op
.
output
)):
original_output_name
=
op
.
output
[
i
].
split
(
'#'
)[
0
]
if
original_output_name
not
in
visited
:
self
.
replace_input_name
(
consumers
.
get
(
op
.
output
[
i
],
[]),
op
.
output
[
i
],
original_output_name
)
op
.
output
[
i
]
=
original_output_name
visited
.
update
([
original_output_name
])
# if user set op name as output node, replace it with op name
for
op
in
self
.
_mace_net_def
.
op
:
if
op
.
name
in
self
.
_option
.
output_nodes
:
if
len
(
op
.
output
)
>
0
:
self
.
replace_input_name
(
consumers
.
get
(
op
.
output
[
0
],
[]),
op
.
output
,
op
.
name
)
op
.
output
[
0
]
=
op
.
name
@
staticmethod
def
filter_test_layers
(
layers
):
phase_map
=
{
0
:
'train'
,
1
:
'test'
}
while
True
:
changed
=
False
for
layer
in
layers
.
layer
:
phase
=
'test'
if
len
(
layer
.
include
):
phase
=
phase_map
[
layer
.
include
[
0
].
phase
]
if
len
(
layer
.
exclude
):
phase
=
phase_map
[
layer
.
exclude
[
0
].
phase
]
if
phase
!=
'test'
or
layer
.
type
==
'Dropout'
:
print
(
"Remove layer %s (%s)"
%
(
layer
.
name
,
layer
.
type
))
layers
.
layer
.
remove
(
layer
)
changed
=
True
break
if
not
changed
:
break
@staticmethod
def add_stride_pad_kernel_arg(param, op_def):
    """Copy stride/pad/kernel settings from a caffe layer param onto
    op_def as mace args.

    Caffe protos expose these fields either as repeated (possibly
    empty) fields or as scalars depending on the layer type; the
    try/except on len() handles both forms. Pad values are doubled —
    presumably mace stores total (both-sides) padding while caffe
    stores per-side padding; confirm against the mace kernels.
    """
    try:
        # Repeated-field form; only a single square value is supported.
        if len(param.stride) > 1 or len(param.kernel_size) > 1 or len(
                param.pad) > 1:
            raise Exception(
                'Mace does not support multiple stride/kernel_size/pad')
        stride = [param.stride[0], param.stride[0]] if len(
            param.stride) else [1, 1]
        pad = [param.pad[0] * 2, param.pad[0] * 2] if len(
            param.pad) else [0, 0]
        kernel = [param.kernel_size[0], param.kernel_size[0]] if len(
            param.kernel_size) else [0, 0]
    except TypeError:
        # Scalar-field form: len() raised TypeError.
        stride = [param.stride, param.stride]
        pad = [param.pad * 2, param.pad * 2]
        kernel = [param.kernel_size, param.kernel_size]

    # Explicit *_h / *_w fields override the square values above.
    if param.HasField(caffe_stride_h_str) or param.HasField(
            caffe_stride_w_str):
        stride = [param.stride_h, param.stride_w]
    if param.HasField(caffe_pad_h_str) or param.HasField(caffe_pad_w_str):
        pad = [param.pad_h * 2, param.pad_w * 2]

    strides_arg = op_def.arg.add()
    strides_arg.name = MaceKeyword.mace_strides_str
    strides_arg.ints.extend(stride)
    padding_arg = op_def.arg.add()
    padding_arg.name = MaceKeyword.mace_padding_values_str
    padding_arg.ints.extend(pad)

    # Kernel size is only attached for pooling; convolution gets its
    # kernel from the filter tensor's shape instead.
    if op_def.type == MaceOp.Pooling.name:
        if param.HasField(caffe_kernel_h_str) or param.HasField(
                caffe_kernel_w_str):
            kernel = [param.kernel_h, param.kernel_w]
        kernels_arg = op_def.arg.add()
        kernels_arg.name = MaceKeyword.mace_kernel_str
        kernels_arg.ints.extend(kernel)
        if param.HasField('global_pooling'):
            global_pooling_arg = op_def.arg.add()
            global_pooling_arg.name = MaceKeyword.mace_global_pooling_str
            global_pooling_arg.i = 1
def convert_ops(self):
    """Dispatch each caffe layer to its registered converter, skipping
    layers already consumed by a fused op (e.g. Scale after BatchNorm)."""
    for layer in self._caffe_layers.layer:
        caffe_op = self._caffe_net.get_op(layer.name)
        if caffe_op in self._skip_ops:
            continue
        converter = self._op_converters.get(layer.type)
        mace_check(converter is not None,
                   "Mace does not support caffe op type %s yet" % layer.type)
        converter(caffe_op)
def add_tensor(self, name, shape, data_type, value):
    """Append a constant tensor to the mace net.

    value is a numpy array; its flattened contents are stored as
    float_data regardless of the declared data_type.
    """
    new_tensor = self._mace_net_def.tensors.add()
    new_tensor.name = name
    new_tensor.data_type = data_type
    new_tensor.dims.extend(list(shape))
    new_tensor.float_data.extend(value.flat)
def convert_nop(self, layer):
    """Converter for caffe layers that need no mace op at all
    (intentional no-op)."""
    return None
def convert_general_op(self, caffe_op):
    """Create a mace op mirroring a caffe layer: copies name/type,
    wires bottom -> input and top -> output, and tags the op with the
    data type 'T' and NCHW data format. Returns the new op so callers
    can specialize it further."""
    op = self._mace_net_def.op.add()
    op.name = caffe_op.name
    op.type = caffe_op.type
    op.input.extend(caffe_op.layer.bottom)
    op.output.extend(caffe_op.layer.top)

    dtype_arg = op.arg.add()
    dtype_arg.name = 'T'
    dtype_arg.i = self._option.data_type
    ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
    return op
def convert_conv2d(self, caffe_op):
    """Convolution -> mace Conv2D, or DepthwiseConv2d when the caffe
    group count equals the input channel count with a single output
    channel per group. General grouped convolution is rejected."""
    op = self.convert_general_op(caffe_op)
    param = caffe_op.layer.convolution_param
    is_depthwise = False
    if param.HasField(caffe_group_str):
        # NOTE(review): 'blob' here vs 'blobs' everywhere else in this
        # file — confirm the caffe op wrapper exposes both attributes,
        # otherwise this branch raises AttributeError for grouped conv.
        mace_check(param.group == caffe_op.blob[0].shape[1] and
                   caffe_op.blob[0].shape[0] == 1,
                   "Mace do not support group convolution yet")
        is_depthwise = True
    if is_depthwise:
        op.type = MaceOp.DepthwiseConv2d.name
    else:
        op.type = MaceOp.Conv2D.name

    self.add_stride_pad_kernel_arg(param, op)
    # dilation is specific for convolution in caffe
    dilations = [1, 1]
    if len(param.dilation) > 0:
        dilation_arg = op.arg.add()
        dilation_arg.name = MaceKeyword.mace_dilations_str
        if len(param.dilation) == 1:
            dilations = [param.dilation[0], param.dilation[0]]
        elif len(param.dilation) == 2:
            dilations = [param.dilation[0], param.dilation[1]]
        dilation_arg.ints.extend(dilations)

    # The filter blob becomes an extra constant input tensor.
    filter_tensor_name = op.name + '_filter'
    filter_data = caffe_op.blobs[0]
    self.add_tensor(filter_tensor_name, filter_data.shape,
                    mace_pb2.DT_FLOAT, filter_data)
    op.input.extend([filter_tensor_name])

    # A second blob, when present, is the bias.
    if len(caffe_op.blobs) == 2:
        bias_tensor_name = op.name + '_bias'
        bias_data = caffe_op.blobs[1]
        self.add_tensor(bias_tensor_name, bias_data.shape,
                        mace_pb2.DT_FLOAT, bias_data)
        op.input.extend([bias_tensor_name])
def convert_elementwise(self, caffe_op):
    """Eltwise layer: the operation kind is mapped through
    eltwise_type; optional per-input coefficients are copied when
    present."""
    op = self.convert_general_op(caffe_op)
    eltwise_param = caffe_op.layer.eltwise_param
    op.type = MaceOp.Eltwise.name

    type_arg = op.arg.add()
    type_arg.name = MaceKeyword.mace_element_type_str
    type_arg.i = self.eltwise_type[eltwise_param.operation].value
    if len(eltwise_param.coeff) > 0:
        coeff_arg = op.arg.add()
        coeff_arg.name = 'coeff'
        coeff_arg.floats.extend(list(eltwise_param.coeff))
def convert_add(self, caffe_op):
    """Caffe n-way add maps directly onto mace AddN."""
    self.convert_general_op(caffe_op).type = MaceOp.AddN.name
def convert_activation(self, caffe_op):
    """Activation layers map onto mace Activation; PReLU additionally
    carries its learned alpha blob as an extra constant input."""
    op = self.convert_general_op(caffe_op)
    op.type = MaceOp.Activation.name

    act_arg = op.arg.add()
    act_arg.name = MaceKeyword.mace_activation_type_str
    act_arg.s = self.activation_type[caffe_op.type].name

    if caffe_op.type == 'PReLU':
        alpha_name = caffe_op.name + '_alpha'
        alpha_data = caffe_op.blobs[0]
        self.add_tensor(alpha_name, alpha_data.shape, mace_pb2.DT_FLOAT,
                        alpha_data)
        op.input.extend([alpha_name])
def convert_folded_batchnorm(self, caffe_op):
    """Fold a caffe BatchNorm + following Scale pair into one mace
    FoldedBatchNorm with precomputed scale/offset constant tensors.

    The consumed Scale layer is appended to _skip_ops so convert_ops
    will not translate it a second time; the fused op takes over the
    Scale layer's top as its output.
    """
    op = self.convert_general_op(caffe_op)
    op.type = MaceOp.FoldedBatchNorm.name

    # Locate the Scale layer consuming this BatchNorm's output.
    scale_op = None
    for consumer in self._caffe_net.get_consumers(caffe_op.layer.top[0]):
        if consumer.type == 'Scale':
            scale_op = consumer
    mace_check(scale_op is not None, "batchnorm is not followed by scale")
    self._skip_ops.append(scale_op)

    # blobs[2][0] is caffe's moving-average scale factor; mean and
    # variance (blobs[0], blobs[1]) must be divided by it first.
    epsilon_value = caffe_op.layer.batch_norm_param.eps
    mace_check(caffe_op.blobs[2][0] != 0, "batchnorm scalar is zero")
    mean_value = (1. / caffe_op.blobs[2][0]) * caffe_op.blobs[0]
    var_value = (1. / caffe_op.blobs[2][0]) * caffe_op.blobs[1]
    gamma_value = scale_op.blobs[0]
    # Scale without a second blob has no bias -> beta defaults to 0.
    beta_value = np.zeros_like(mean_value)
    if len(scale_op.blobs) == 2:
        beta_value = scale_op.blobs[1]

    # y = scale * x + offset, with
    #   scale  = gamma / sqrt(var + eps)
    #   offset = beta - mean * scale
    scale_value = (
        (1.0 / np.vectorize(math.sqrt)(var_value + epsilon_value)) *
        gamma_value).reshape(-1)
    offset_value = ((-mean_value * scale_value) + beta_value).reshape(-1)

    input_names = [op.name + '_scale', op.name + '_offset']
    self.add_tensor(input_names[0], scale_value.shape,
                    mace_pb2.DT_FLOAT, scale_value)
    self.add_tensor(input_names[1], offset_value.shape,
                    mace_pb2.DT_FLOAT, offset_value)
    op.input.extend([name for name in input_names])
    op.output[:] = scale_op.layer.top[:]
def convert_pooling(self, caffe_op):
    """Pooling layer: stride/pad/kernel come from pooling_param; the
    pooling kind (avg/max) is mapped through pooling_type_mode."""
    op = self.convert_general_op(caffe_op)
    pool_param = caffe_op.layer.pooling_param
    op.type = MaceOp.Pooling.name

    self.add_stride_pad_kernel_arg(pool_param, op)
    ptype_arg = op.arg.add()
    ptype_arg.name = MaceKeyword.mace_pooling_type_str
    ptype_arg.i = self.pooling_type_mode[pool_param.pool].value
def convert_softmax(self, caffe_op):
    """Softmax maps 1:1; the generic conversion already copies the
    caffe layer type 'Softmax', which matches the mace op name, so no
    extra attributes are needed."""
    self.convert_general_op(caffe_op)
def convert_concat(self, caffe_op):
    """Concat layer: axis from 'axis' or legacy 'concat_dim' (default
    1); only channel-dimension concat is supported."""
    op = self.convert_general_op(caffe_op)
    concat_param = caffe_op.layer.concat_param
    op.type = MaceOp.Concat.name

    axis_arg = op.arg.add()
    axis_arg.name = MaceKeyword.mace_axis_str
    if concat_param.HasField('axis'):
        axis_arg.i = concat_param.axis
    elif concat_param.HasField('concat_dim'):
        axis_arg.i = concat_param.concat_dim
    else:
        axis_arg.i = 1
    mace_check(axis_arg.i == 1, "only support concat at channel dimension")
def convert_slice(self, caffe_op):
    """Slice (split) along channels only; a non-channel axis or
    explicit slice_point values are rejected."""
    op = self.convert_general_op(caffe_op)
    op.type = MaceOp.Slice.name

    if caffe_op.layer.HasField('slice_param'):
        slice_param = caffe_op.layer.slice_param
        mace_check(
            not slice_param.HasField('axis') or slice_param.axis == 1,
            "Mace do not support slice with axis %d" % slice_param.axis)
        mace_check(len(slice_param.slice_point) == 0,
                   "Mace do not support slice with slice_point")

    axis_arg = op.arg.add()
    axis_arg.name = MaceKeyword.mace_axis_str
    axis_arg.i = 1
def convert_fully_connected(self, caffe_op):
    """InnerProduct -> mace FullyConnected.

    Weights are reshaped to (num_output, input_size); 4D weights are
    only accepted in the [1, 1, *, *] layout. An optional second blob
    becomes the bias input.
    """
    op = self.convert_general_op(caffe_op)
    param = caffe_op.layer.inner_product_param
    op.type = MaceOp.FullyConnected.name

    mace_check(param.axis == 1 and not param.transpose,
               "Do not support non-default axis and transpose")
    # Fixed typo in error message ("weigth" -> "weight").
    mace_check(caffe_op.blobs[0].ndim in [2, 4],
               "Unexpected fc weight ndim.")
    if caffe_op.blobs[0].ndim == 4:
        # Fixed misleading message: [1, 1, *, *] is the *supported*
        # layout, so shapes failing this check are the unsupported ones.
        mace_check(list(caffe_op.blobs[0].shape[:2]) == [1, 1],
                   "Only support 4D weight with shape [1, 1, *, *]")

    weight_tensor_name = op.name + '_weight'
    weight_data = caffe_op.blobs[0].reshape(param.num_output, -1)
    self.add_tensor(weight_tensor_name, weight_data.shape,
                    mace_pb2.DT_FLOAT, weight_data)
    op.input.extend([weight_tensor_name])

    if len(caffe_op.blobs) == 2:
        bias_tensor_name = op.name + '_bias'
        bias_data = caffe_op.blobs[1]
        self.add_tensor(bias_tensor_name, bias_data.shape,
                        mace_pb2.DT_FLOAT, bias_data)
        op.input.extend([bias_tensor_name])
--- File: mace/python/tools/converter_tool/shape_inference.py (new file, mode 100644, viewed at commit 3e82ad67) ---
import
math
import
numpy
as
np
from
mace.python.tools.converter_tool.transformer
import
Transformer
from
mace.python.tools.converter_tool.base_converter
import
DataFormat
from
mace.python.tools.converter_tool.base_converter
import
FilterFormat
from
mace.python.tools.converter_tool.base_converter
import
MaceOp
from
mace.python.tools.converter_tool.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool.base_converter
import
ConverterUtil
from
mace.python.tools.convert_util
import
mace_check
class ShapeInference(object):
    """Currently we only use it to infer caffe shape, we use tensorflow engine
    to infer tensorflow op shapes, since tensorflow has too many ops."""

    def __init__(self, net, input_nodes):
        """net: the mace NetDef under construction; input_nodes: the
        user-declared inputs whose shapes seed the inference cache."""
        # mace op type name -> shape inference routine.
        self._op_shape_inference = {
            MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape,
            MaceOp.Eltwise.name: self.infer_shape_general,
            MaceOp.FoldedBatchNorm.name: self.infer_shape_general,
            MaceOp.AddN.name: self.infer_shape_general,
            MaceOp.Activation.name: self.infer_shape_general,
            MaceOp.Pooling.name: self.infer_shape_conv_pool_shape,
            MaceOp.Concat.name: self.infer_shape_concat,
            MaceOp.Slice.name: self.infer_shape_slice,
            MaceOp.Softmax.name: self.infer_shape_general,
            MaceOp.FullyConnected.name: self.infer_shape_fully_connected,
        }

        self._net = net
        # tensor/output name -> shape (list of ints).
        self._output_shape_cache = {}
        for input_node in input_nodes:
            input_shape = input_node.shape[:]
            # transpose the user-facing NHWC input shape to caffe's NCHW
            Transformer.transpose_shape(input_shape, [0, 3, 1, 2])
            self._output_shape_cache[input_node.name] = input_shape
        for tensor in net.tensors:
            self._output_shape_cache[tensor.name] = list(tensor.dims)

    def run(self):
        """Infer and record output shapes for every op in the net."""
        for op in self._net.op:
            mace_check(op.type in self._op_shape_inference,
                       "Mace does not support caffe op type %s yet"
                       % op.type)
            self._op_shape_inference[op.type](op)

    def add_output_shape(self, op, shapes):
        """Attach one shape per op output and record each in the cache."""
        mace_check(len(op.output) == len(shapes),
                   "Op %s (%s) output count is different from "
                   "output shape count" % (op.name, op.type))
        for i in range(len(shapes)):
            output_name = op.output[i]
            output_shape = op.output_shape.add()
            output_shape.dims.extend(shapes[i])
            self._output_shape_cache[output_name] = shapes[i]

    def infer_shape_general(self, op):
        """Shape-preserving ops: output shape equals first input's."""
        if len(op.input) > 0:
            mace_check(op.input[0] in self._output_shape_cache,
                       "%s does not exist" % op.input[0])
            input_shape = self._output_shape_cache[op.input[0]]
            self.add_output_shape(op, [input_shape])

    def infer_shape_conv_pool_shape(self, op):
        """Conv2D / Pooling spatial arithmetic (NCHW input + OIHW filter
        only)."""
        input_shape = self._output_shape_cache[op.input[0]]
        output_shape = np.zeros_like(input_shape)
        # Bug fix: op.type holds the op's *name* string, so the enum
        # member MaceOp.Pooling never compared equal; pooling ops then
        # fell into the conv branch and indexed a non-existent
        # op.input[1].
        if op.type == MaceOp.Pooling.name:
            filter_shape = list(
                ConverterUtil.get_arg(op, MaceKeyword.mace_kernel_str).ints)
            if ConverterUtil.data_format(op) == DataFormat.NCHW:
                filter_shape = [input_shape[1],
                                input_shape[1]] + filter_shape
                if ConverterUtil.get_arg(
                        op, MaceKeyword.mace_global_pooling_str) \
                        is not None:
                    # Global pooling covers the whole spatial extent.
                    filter_shape[2] = input_shape[2]
                    filter_shape[3] = input_shape[3]
            else:  # NHWC
                filter_shape = filter_shape + [input_shape[1],
                                               input_shape[1]]
                if ConverterUtil.get_arg(
                        op, MaceKeyword.mace_global_pooling_str) \
                        is not None:
                    filter_shape[0] = input_shape[1]
                    filter_shape[1] = input_shape[2]
        else:
            filter_shape = self._output_shape_cache[op.input[1]]

        paddings = ConverterUtil.get_arg(
            op, MaceKeyword.mace_padding_values_str).ints
        strides = ConverterUtil.get_arg(
            op, MaceKeyword.mace_strides_str).ints
        dilations_arg = ConverterUtil.get_arg(
            op, MaceKeyword.mace_dilations_str)
        if dilations_arg is not None:
            dilations = dilations_arg.ints
        else:
            dilations = [1, 1]
        # Pooling rounds up (ceil); convolution rounds down (floor).
        # Bug fix: compare against .name here as well.
        if op.type == MaceOp.Pooling.name:
            round_func = math.ceil
        else:
            round_func = math.floor

        output_shape[0] = input_shape[0]
        if ConverterUtil.data_format(op) == DataFormat.NCHW \
                and ConverterUtil.filter_format(self._net) == \
                FilterFormat.OIHW:
            # filter format: OIHW
            output_shape[1] = filter_shape[0]
            output_shape[2] = int(
                round_func((input_shape[2] + paddings[0] - filter_shape[2] -
                            (filter_shape[2] - 1) *
                            (dilations[0] - 1)) / float(strides[0]))) + 1
            output_shape[3] = int(
                round_func((input_shape[3] + paddings[1] - filter_shape[3] -
                            (filter_shape[3] - 1) *
                            (dilations[1] - 1)) / float(strides[1]))) + 1
        else:
            mace_check(False,
                       "Mace can only infer shape for"
                       " NCHW input and OIHW filter")

        self.add_output_shape(op, [output_shape])

    def infer_shape_concat(self, op):
        """Concat: the axis dim is the sum of that dim over all inputs."""
        # Bug fix: copy instead of aliasing the cached input shape and
        # zero the axis dim first — the old code double-counted input 0
        # and mutated the cached shape in place.
        output_shape = list(self._output_shape_cache[op.input[0]])
        axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
        output_shape[axis] = 0
        for input_node in op.input:
            input_shape = self._output_shape_cache[input_node]
            output_shape[axis] += input_shape[axis]
        self.add_output_shape(op, [output_shape])

    def infer_shape_slice(self, op):
        """Slice: the axis dim is split evenly across the outputs."""
        # Bug fix: copy instead of mutating the cached input shape, and
        # give each output its own list — all outputs previously shared
        # one list object, so a later edit would alias them all.
        output_shape = list(self._output_shape_cache[op.input[0]])
        axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
        output_shape[axis] //= len(op.output)
        output_shapes = [list(output_shape) for _ in op.output]
        self.add_output_shape(op, output_shapes)

    def infer_shape_fully_connected(self, op):
        """FullyConnected: output is (N, num_output, 1, 1) in NCHW."""
        input_shape = self._output_shape_cache[op.input[0]]
        weight_shape = self._output_shape_cache[op.input[1]]
        if ConverterUtil.data_format(op) == DataFormat.NCHW:
            output_shape = [input_shape[0], weight_shape[0], 1, 1]
        else:
            mace_check(False, "format %s is not supported" %
                       ConverterUtil.data_format(op))
        self.add_output_shape(op, [output_shape])
--- File: mace/python/tools/converter_tool/tensorflow_converter.py (new file, mode 100644, viewed at commit 3e82ad67) ---
import
math
import
numpy
as
np
import
tensorflow
as
tf
from
mace.proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
mace.python.tools.converter_tool.base_converter
import
PoolingType
from
mace.python.tools.converter_tool.base_converter
import
PaddingMode
from
mace.python.tools.converter_tool.base_converter
import
ActivationType
from
mace.python.tools.converter_tool.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool.base_converter
import
DataFormat
from
mace.python.tools.converter_tool.base_converter
import
FilterFormat
from
mace.python.tools.converter_tool.base_converter
import
MaceOp
from
mace.python.tools.converter_tool.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool.base_converter
import
ConverterUtil
from
mace.python.tools.convert_util
import
mace_check
from
tensorflow.core.framework
import
tensor_shape_pb2
# Names of the tensorflow node attributes read during conversion.
tf_padding_str = 'padding'
tf_strides_str = 'strides'
tf_dilations_str = 'dilations'
tf_data_format_str = 'data_format'
tf_kernel_str = 'ksize'
tf_epsilon_str = 'epsilon'
tf_align_corners = 'align_corners'
tf_block_size = 'block_size'
class TensorflowConverter(base_converter.ConverterInterface):
    """A class for convert tensorflow frozen model to mace model.
    We use tensorflow engine to infer op output shapes, since they are of
    too many types."""

    # tf 'padding' attr value -> mace PaddingMode
    padding_mode = {'VALID': PaddingMode.VALID,
                    'SAME': PaddingMode.SAME,
                    'FULL': PaddingMode.FULL}
    # tf pooling op type -> mace PoolingType
    pooling_type_mode = {'AvgPool': PoolingType.AVG,
                         'MaxPool': PoolingType.MAX}
    # tf arithmetic op type -> mace EltwiseType
    eltwise_type = {'Add': EltwiseType.SUM,
                    'Sub': EltwiseType.SUB,
                    'Mul': EltwiseType.PROD,
                    'Div': EltwiseType.DIV,
                    'Min': EltwiseType.MIN,
                    'Max': EltwiseType.MAX,
                    'Neg': EltwiseType.NEG,
                    'Abs': EltwiseType.ABS,
                    'RealDiv': EltwiseType.DIV,
                    'SquaredDifference': EltwiseType.SQR_DIFF,
                    'Pow': EltwiseType.POW}
    # tf activation op type -> mace ActivationType
    activation_type = {'Relu': ActivationType.RELU,
                       'Relu6': ActivationType.RELUX,
                       'Tanh': ActivationType.TANH,
                       'Sigmoid': ActivationType.SIGMOID}
def __init__(self, option, src_model_file):
    """Load a frozen tensorflow GraphDef from src_model_file, pin the
    user-declared input shapes onto it, and import it into a tf graph
    ready for conversion."""
    # Dispatch table: tensorflow op type -> converter method.
    self._op_converters = {
        'Conv2D': self.convert_conv2d,
        'DepthwiseConv2dNative': self.convert_conv2d,
        'Conv2DBackpropInput': self.convert_conv2d,
        'BiasAdd': self.convert_biasadd,
        'Add': self.convert_add,
        'Sub': self.convert_elementwise,
        'Mul': self.convert_elementwise,
        'Div': self.convert_elementwise,
        'Min': self.convert_elementwise,
        'Max': self.convert_elementwise,
        'Neg': self.convert_elementwise,
        'Abs': self.convert_elementwise,
        'RealDiv': self.convert_elementwise,
        'SquaredDifference': self.convert_elementwise,
        'Pow': self.convert_elementwise,
        'Relu': self.convert_activation,
        'Relu6': self.convert_activation,
        'Tanh': self.convert_activation,
        'Sigmoid': self.convert_activation,
        'FusedBatchNorm': self.convert_fused_batchnorm,
        'AvgPool': self.convert_pooling,
        'MaxPool': self.convert_pooling,
        'Squeeze': self.convert_identity,
        'Reshape': self.convert_reshape,
        'Shape': self.convert_nop,
        'Softmax': self.convert_softmax,
        'ResizeBilinear': self.convert_resize_bilinear,
        'Placeholder': self.convert_nop,
        'SpaceToBatchND': self.convert_space_batch,
        'BatchToSpaceND': self.convert_space_batch,
        'DepthToSpace': self.convert_space_depth,
        'SpaceToDepth': self.convert_space_depth,
        'Pad': self.convert_pad,
        'ConcatV2': self.convert_concat,
        'Mean': self.convert_mean,
        # Const converter_tool should be placed at the end
        'Const': self.convert_tensor,
    }
    self._option = option
    self._mace_net_def = mace_pb2.NetDef()
    # Tensorflow filters are laid out HWIO.
    ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.HWIO)

    tf_graph_def = tf.GraphDef()
    with tf.gfile.Open(src_model_file, 'rb') as f:
        tf_graph_def.ParseFromString(f.read())
    self.add_shape_info(tf_graph_def)

    with tf.Session() as session:
        with session.graph.as_default() as graph:
            tf.import_graph_def(tf_graph_def, name='')
            self._tf_graph = graph

    # Tensor names claimed by fused ops or converted into attributes;
    # convert_tensor skips these Const nodes.
    self._skip_tensor = set()
def run(self):
    """Convert all ops, then normalize input/output tensor names.

    Returns the populated mace NetDef. The session variable is unused
    by name, but entering the `with` block installs a default session,
    which the converters need for Tensor.eval() calls.
    """
    with tf.Session() as session:
        self.convert_ops()
    self.replace_input_output_tensor_name()
    return self._mace_net_def
def replace_input_output_tensor_name(self):
    """Strip the tensorflow ':0' suffix from edge names that refer to
    the model's declared input/output nodes, so graph edges match the
    bare node names the user specified."""
    input_nodes = self._option.input_nodes
    output_nodes = self._option.output_nodes
    for op in self._mace_net_def.op:
        for idx, name in enumerate(op.input):
            if name.endswith(':0') and name[:-2] in input_nodes:
                op.input[idx] = name[:-2]
        for idx, name in enumerate(op.output):
            if name.endswith(':0') and name[:-2] in output_nodes:
                op.output[idx] = name[:-2]
def add_shape_info(self, tf_graph_def):
    """Overwrite the 'shape' attr of each declared input node with the
    user-provided shape, so tensorflow can infer concrete shapes for
    all downstream ops."""
    for node in tf_graph_def.node:
        if node.name in self._option.input_nodes:
            del node.attr['shape'].shape.dim[:]
            node.attr['shape'].shape.dim.extend([
                tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
                self._option.input_nodes[node.name].shape
            ])
@staticmethod
def get_scope(tensor_name):
    """Return the scope prefix of a tensor name: everything before the
    last '/', or the whole name when it contains no '/'."""
    scope, sep, _ = tensor_name.rpartition('/')
    return scope if sep else tensor_name
def convert_ops(self):
    """Dispatch every op of the frozen graph to its registered
    converter; unknown op types are a hard error."""
    for tf_op in self._tf_graph.get_operations():
        converter = self._op_converters.get(tf_op.type)
        mace_check(converter is not None,
                   "Mace does not support tensorflow op type %s yet"
                   % tf_op.type)
        converter(tf_op)
def convert_tensor(self, tf_op):
    """Convert a Const node into a mace constant tensor, unless its
    output name was claimed (skipped) by an earlier conversion that
    folded it into an op attribute.

    Only float32 and int32 constants are supported.
    """
    output_name = tf_op.outputs[0].name
    if output_name not in self._skip_tensor:
        tensor = self._mace_net_def.tensors.add()
        tensor.name = tf_op.outputs[0].name
        # eval() requires the active default tf session (see run()).
        tf_tensor = tf_op.outputs[0].eval()
        tensor.dims.extend(list(tf_tensor.shape))

        tf_dt = tf_op.get_attr('dtype')
        if tf_dt == tf.float32:
            tensor.data_type = mace_pb2.DT_FLOAT
            tensor.float_data.extend(tf_tensor.astype(np.float32).flat)
        elif tf_dt == tf.int32:
            tensor.data_type = mace_pb2.DT_INT32
            tensor.int32_data.extend(tf_tensor.astype(np.int32).flat)
        else:
            mace_check(False,
                       "Not supported tensor type: %s" % tf_dt.name)
def add_tensor(self, name, shape, data_type, value):
    """Append a constant tensor to the mace net.

    value is a numpy array; its flattened contents are stored as
    float_data regardless of the declared data_type.
    """
    new_tensor = self._mace_net_def.tensors.add()
    new_tensor.name = name
    new_tensor.data_type = data_type
    new_tensor.dims.extend(list(shape))
    new_tensor.float_data.extend(value.flat)
def convert_nop(self, tf_op):
    """Converter for tf ops (Shape, Placeholder) that need no mace op
    (intentional no-op)."""
    return None
def convert_general_op(self, tf_op):
    """Create a mace op from a tensorflow op: copies name/type, all
    input/output edge names, per-output shapes and data types, and
    tags the op with the data type 'T' and NHWC data format.

    Returns the new op so callers can specialize it further.
    """
    op = self._mace_net_def.op.add()
    op.name = tf_op.name
    op.type = tf_op.type
    op.input.extend([tf_input.name for tf_input in tf_op.inputs])
    op.output.extend([tf_output.name for tf_output in tf_op.outputs])
    for tf_output in tf_op.outputs:
        output_shape = op.output_shape.add()
        output_shape.dims.extend(tf_output.shape.as_list())
        # one output_type entry per output
        op.output_type.append(self._option.data_type)

    data_type_arg = op.arg.add()
    data_type_arg.name = 'T'
    data_type_arg.i = self._option.data_type
    ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
    return op
def convert_identity(self, tf_op):
    """Squeeze is represented as a plain Identity op in mace."""
    self.convert_general_op(tf_op).type = 'Identity'
def convert_conv2d(self, tf_op):
    """Conv2D / DepthwiseConv2dNative / Conv2DBackpropInput (deconv)
    -> the corresponding mace conv op, carrying padding mode, the
    spatial strides, and (except for deconv) the spatial dilations."""
    op = self.convert_general_op(tf_op)
    if tf_op.type == 'DepthwiseConv2dNative':
        op.type = MaceOp.DepthwiseConv2d.name
    elif tf_op.type == 'Conv2DBackpropInput':
        op.type = MaceOp.Deconv2D.name
    else:
        op.type = MaceOp.Conv2D.name

    padding_arg = op.arg.add()
    padding_arg.name = MaceKeyword.mace_padding_str
    padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value
    strides_arg = op.arg.add()
    strides_arg.name = MaceKeyword.mace_strides_str
    # tf strides/dilations are NHWC 4-vectors; [1:3] keeps H and W.
    strides_arg.ints.extend(tf_op.get_attr(tf_strides_str)[1:3])
    if op.type != MaceOp.Deconv2D.name:
        dilation_arg = op.arg.add()
        dilation_arg.name = MaceKeyword.mace_dilations_str
        dilation_arg.ints.extend(tf_op.get_attr(tf_dilations_str)[1:3])
def convert_elementwise(self, tf_op):
    """Map a unary/binary arithmetic tf op onto mace Eltwise with the
    matching element_type argument."""
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Eltwise.name

    element_type = self.eltwise_type[tf_op.type]
    type_arg = op.arg.add()
    type_arg.name = MaceKeyword.mace_element_type_str
    type_arg.i = element_type.value
def convert_biasadd(self, tf_op):
    """BiasAdd maps directly onto mace BiasAdd."""
    self.convert_general_op(tf_op).type = MaceOp.BiasAdd.name
def convert_add(self, tf_op):
    """Two-operand Add becomes Eltwise SUM; n-ary Add becomes AddN."""
    if len(tf_op.inputs) == 2:
        self.convert_elementwise(tf_op)
        return
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.AddN.name
def convert_activation(self, tf_op):
    """Relu/Relu6/Tanh/Sigmoid -> mace Activation; Relu6 additionally
    carries its clamp limit of 6.0."""
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Activation.name

    act_arg = op.arg.add()
    act_arg.name = MaceKeyword.mace_activation_type_str
    act_arg.s = self.activation_type[tf_op.type].name

    if tf_op.type == 'Relu6':
        limit_arg = op.arg.add()
        limit_arg.name = MaceKeyword.mace_activation_max_limit_str
        limit_arg.f = 6.0
def convert_fused_batchnorm(self, tf_op):
    """FusedBatchNorm -> mace FoldedBatchNorm.

    gamma/beta/mean/var are evaluated from the graph and folded into
    two constant tensors:
      scale  = gamma / sqrt(var + eps)
      offset = beta - mean * scale
    The four original parameter inputs are dropped (and their Const
    nodes marked for skipping), as are the extra tf outputs.
    """
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.FoldedBatchNorm.name

    gamma_value = tf_op.inputs[1].eval().astype(np.float32)
    beta_value = tf_op.inputs[2].eval().astype(np.float32)
    mean_value = tf_op.inputs[3].eval().astype(np.float32)
    var_value = tf_op.inputs[4].eval().astype(np.float32)
    epsilon_value = tf_op.get_attr(tf_epsilon_str)

    scale_name = self.get_scope(tf_op.name) + '/scale:0'
    offset_name = self.get_scope(tf_op.name) + '/offset:0'
    scale_value = (
        (1.0 / np.vectorize(math.sqrt)(var_value + epsilon_value)) *
        gamma_value)
    offset_value = (-mean_value * scale_value) + beta_value
    self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
                    scale_value)
    self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
                    offset_value)
    # Skip the gamma/beta/mean/var Const nodes during convert_tensor.
    self._skip_tensor.update([inp.name for inp in tf_op.inputs][1:])

    del op.input[1:]
    op.input.extend([scale_name, offset_name])
    # FusedBatchNorm has 5 tf outputs; mace keeps only the first.
    del op.output[1:]
    del op.output_shape[1:]
    del op.output_type[1:]
def convert_pooling(self, tf_op):
    """AvgPool/MaxPool -> mace Pooling with pooling type, padding mode,
    and the spatial (H, W) strides and kernel sizes."""
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Pooling.name

    def _new_arg(arg_name):
        arg = op.arg.add()
        arg.name = arg_name
        return arg

    _new_arg(MaceKeyword.mace_pooling_type_str).i = \
        self.pooling_type_mode[tf_op.type].value
    _new_arg(MaceKeyword.mace_padding_str).i = \
        self.padding_mode[tf_op.get_attr(tf_padding_str)].value
    _new_arg(MaceKeyword.mace_strides_str).ints.extend(
        tf_op.get_attr(tf_strides_str)[1:3])
    _new_arg(MaceKeyword.mace_kernel_str).ints.extend(
        tf_op.get_attr(tf_kernel_str)[1:3])
def convert_softmax(self, tf_op):
    """Softmax maps directly onto mace Softmax."""
    self.convert_general_op(tf_op).type = MaceOp.Softmax.name
def convert_resize_bilinear(self, tf_op):
    """ResizeBilinear -> mace ResizeBilinear; the size input becomes
    the resize_size arg and the align_corners attr is carried over."""
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.ResizeBilinear.name
    del op.input[1:]

    size_arg = op.arg.add()
    size_arg.name = MaceKeyword.mace_resize_size_str
    size_value = tf_op.inputs[1].eval().astype(np.int32)
    size_arg.ints.extend(size_value)
    # Bug fix: set.update(str) added individual *characters* instead of
    # the tensor name, so the size Const was never actually skipped.
    self._skip_tensor.add(tf_op.inputs[1].name)
    align_corners_arg = op.arg.add()
    align_corners_arg.name = MaceKeyword.mace_align_corners_str
    align_corners_arg.i = tf_op.get_attr(tf_align_corners)
def convert_space_batch(self, tf_op):
    """SpaceToBatchND / BatchToSpaceND -> the corresponding mace op;
    the block-shape and paddings/crops inputs are folded into args."""
    # Fixed the Python-2-only print statement to the function form.
    print("""You might want to try 'flatten_atrous_conv' in
transform graph to turn atrous conv2d into regular conv2d.
This may give you performance benefit on GPU.
(see https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/tools/graph_transforms/README.md#flatten_atrous_conv)
""")
    op = self.convert_general_op(tf_op)
    del op.input[1:]

    size_arg = op.arg.add()
    size_arg.name = MaceKeyword.mace_space_batch_block_shape_str
    size_value = tf_op.inputs[1].eval().astype(np.int32)
    size_arg.ints.extend(size_value)

    crops_or_paddings_arg = op.arg.add()
    if op.type == 'BatchToSpaceND':
        op.type = MaceOp.BatchToSpaceND.name
        crops_or_paddings_arg.name = \
            MaceKeyword.mace_batch_to_space_crops_str
    else:
        op.type = MaceOp.SpaceToBatchND.name
        crops_or_paddings_arg.name = MaceKeyword.mace_paddings_str
    crops_or_paddings_value = tf_op.inputs[2].eval().astype(np.int32).flat
    crops_or_paddings_arg.ints.extend(crops_or_paddings_value)

    # Bug fix: set.update(str) added individual *characters* instead of
    # the tensor names, so these Const inputs were never skipped.
    self._skip_tensor.add(tf_op.inputs[1].name)
    self._skip_tensor.add(tf_op.inputs[2].name)
def convert_space_depth(self, tf_op):
    """SpaceToDepth / DepthToSpace keep their meaning; only the block
    size attribute is carried over."""
    op = self.convert_general_op(tf_op)
    op.type = (MaceOp.SpaceToDepth.name if op.type == 'SpaceToDepth'
               else MaceOp.DepthToSpace.name)

    size_arg = op.arg.add()
    size_arg.name = MaceKeyword.mace_space_depth_block_size_str
    size_arg.i = tf_op.get_attr(tf_block_size)
def convert_pad(self, tf_op):
    """Pad -> mace Pad; the paddings input (and the optional constant
    pad value for PadV2-style 3-input nodes) become op args."""
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Pad.name
    del op.input[1:]

    paddings_arg = op.arg.add()
    paddings_arg.name = MaceKeyword.mace_paddings_str
    paddings_value = tf_op.inputs[1].eval().astype(np.int32).flat
    paddings_arg.ints.extend(paddings_value)
    # Bug fix: set.update(str) added individual *characters* instead of
    # the tensor name, so the paddings Const was never skipped.
    self._skip_tensor.add(tf_op.inputs[1].name)

    if len(tf_op.inputs) == 3:
        constant_value_arg = op.arg.add()
        constant_value_arg.name = MaceKeyword.mace_constant_value_str
        constant_value = tf_op.inputs[2].eval().astype(np.int32).flat[0]
        constant_value_arg.i = constant_value
        # Bug fix: same set.update(str) character bug.
        self._skip_tensor.add(tf_op.inputs[2].name)
def convert_concat(self, tf_op):
    """ConcatV2 -> mace Concat; the trailing axis input is folded into
    the axis arg. Only channel-dimension (axis 3) concat is allowed."""
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Concat.name
    del op.input[-1]

    axis_arg = op.arg.add()
    axis_arg.name = MaceKeyword.mace_axis_str
    axis = tf_op.inputs[-1].eval().astype(np.int32)
    axis_arg.i = axis
    mace_check(axis == 3, "only support concat at channel dimension")

    # Bug fix: set.update(str) added individual *characters* instead of
    # the tensor name, so the axis Const was never skipped.
    self._skip_tensor.add(tf_op.inputs[-1].name)
def convert_reshape(self, tf_op):
    """Reshape -> mace Reshape with a static shape arg.

    A Const shape is evaluated directly; a Shape-of-tensor input uses
    that tensor's static shape instead.
    """
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Reshape.name
    del op.input[1:]

    shape_arg = op.arg.add()
    shape_arg.name = MaceKeyword.mace_shape_str
    shape_value = []
    if tf_op.inputs[1].op.type == 'Const':
        # Replace tf's -1 wildcard with 1 — assumes the wildcard always
        # lands on a batch-like dim of size 1; TODO confirm.
        shape_value = [1 if dim == -1 else dim for dim in
                       tf_op.inputs[1].eval().astype(np.int32)]
        # Bug fix: set.update(str) added individual *characters*
        # instead of the tensor name, so the shape Const was never
        # skipped.
        self._skip_tensor.add(tf_op.inputs[1].name)
    elif tf_op.inputs[1].op.type == 'Shape':
        shape_value = list(tf_op.inputs[1].op.inputs[0].shape.as_list())
    shape_arg.ints.extend(shape_value)
def convert_mean(self, tf_op):
    """Mean over spatial dims (1, 2) -> mace global average Pooling.

    The reduce-dims Const input is folded away: the pooling kernel is
    set to the input's full spatial extent with stride 1 and VALID
    padding, which is exactly a spatial mean.
    """
    op = self.convert_general_op(tf_op)
    del op.input[1:]

    reduce_dims = tf_op.inputs[1].eval()
    mace_check(reduce_dims[0] == 1 and reduce_dims[1] == 2,
               "Mean only support reduce dim 1, 2")

    op.type = MaceOp.Pooling.name
    pooling_type_arg = op.arg.add()
    pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
    pooling_type_arg.i = PoolingType.AVG.value
    padding_arg = op.arg.add()
    padding_arg.name = MaceKeyword.mace_padding_str
    padding_arg.i = PaddingMode.VALID.value
    strides_arg = op.arg.add()
    strides_arg.name = MaceKeyword.mace_strides_str
    strides_arg.ints.extend([1, 1])
    kernels_arg = op.arg.add()
    kernels_arg.name = MaceKeyword.mace_kernel_str
    # Kernel = full (H, W) of the NHWC input.
    kernels_arg.ints.extend(tf_op.inputs[0].shape.as_list()[1:3])

    # Note: correctly uses set.add (unlike the set.update(str) calls in
    # several sibling converters).
    self._skip_tensor.add(tf_op.inputs[1].name)
--- File: mace/python/tools/converter_tool/transformer.py (new file, mode 100644, viewed at commit 3e82ad67; diff collapsed in this capture — click to expand) ---
--- File: mace/python/tools/memory_optimizer.py (viewed at commit 3e82ad67) ---
...
...
@@ -129,7 +129,7 @@ class MemoryOptimizer(object):
self
.
idle_mem
.
remove
(
mem_id
)
if
mem_id
==
-
1
:
mem_id
=
self
.
total_mem_count
mem_id
=
self
.
mem_id_base
()
+
self
.
total_mem_count
self
.
total_mem_count
+=
1
self
.
mem_block
[
mem_id
]
=
op_mem_block
...
...
@@ -147,10 +147,13 @@ class MemoryOptimizer(object):
self
.
add_net_mem_blocks
()
print
(
'total op: %d'
,
len
(
self
.
net_def
.
op
))
print
(
'origin mem: %d, optimized mem: %d'
,
print
(
"total op: %d"
%
len
(
self
.
net_def
.
op
))
print
(
"origin mem: %d, optimized mem: %d"
%
(
self
.
get_total_origin_mem_size
(),
self
.
get_total_optimized_mem_size
())
self
.
get_total_optimized_mem_size
()))
def mem_id_base(self):
    # Base offset for memory block ids; subclasses override this to
    # keep their id ranges disjoint (the GPU optimizer uses 20000).
    return 0
class
GPUMemoryOptimizer
(
MemoryOptimizer
):
...
...
@@ -189,6 +192,9 @@ class GPUMemoryOptimizer(MemoryOptimizer):
block
.
x
=
self
.
mem_block
[
mem
][
0
]
block
.
y
=
self
.
mem_block
[
mem
][
1
]
def mem_id_base(self):
    """Return the smallest memory id this optimizer may assign.

    GPU ids start at 20000 — presumably to keep them disjoint from the
    CPU optimizer's 0-based ids; verify against MemoryOptimizer.
    """
    gpu_mem_id_offset = 20000
    return gpu_mem_id_offset
def
optimize_gpu_memory
(
net_def
):
mem_optimizer
=
GPUMemoryOptimizer
(
net_def
)
...
...
mace/python/tools/source_converter_lib.py
浏览文件 @
3e82ad67
...
...
@@ -84,11 +84,20 @@ def obfuscate_name(net_def):
op
.
output
[
i
]
=
in_out_map
[
op
.
output
[
i
]]
def normalize_op_name(op_name):
    """Strip the trailing ':<index>' tensor suffix from an op name.

    Names without a ':' are returned unchanged; otherwise everything
    from the last ':' onward is dropped (e.g. 'conv/w:0' -> 'conv/w').
    """
    colon_pos = op_name.rfind(':')
    return op_name if colon_pos == -1 else op_name[:colon_pos]
def
rename_tensor
(
net_def
):
tensor_map
=
{}
for
t
in
net_def
.
tensors
:
if
t
.
name
not
in
tensor_map
:
tensor_map
[
t
.
name
]
=
"_"
+
t
.
name
[:
-
2
].
replace
(
"/"
,
"_"
)
tensor_map
[
t
.
name
]
=
"_"
+
normalize_op_name
(
t
.
name
).
replace
(
"/"
,
"_"
)
t
.
name
=
tensor_map
[
t
.
name
]
for
op
in
net_def
.
op
:
for
i
in
range
(
len
(
op
.
input
)):
...
...
@@ -118,6 +127,8 @@ class TensorInfo:
elif
t
.
data_type
==
mace_pb2
.
DT_UINT8
:
self
.
data
=
bytearray
(
np
.
array
(
t
.
int32_data
).
astype
(
np
.
uint8
).
tolist
())
else
:
raise
Exception
(
'Tensor data type %s not supported'
%
t
.
data_type
)
def
stringfy
(
value
):
...
...
mace/python/tools/tf_converter_lib.py
已删除
100644 → 0
浏览文件 @
04f7a34a
此差异已折叠。
点击以展开。
mace/test/mace_api_mt_test.cc
浏览文件 @
3e82ad67
...
...
@@ -152,7 +152,7 @@ void CheckOutputs(const NetDef &net_def,
memcpy
(
input_data
.
data
(),
input
.
second
.
data
().
get
(),
data_size
*
sizeof
(
float
));
std
::
string
input_name
=
MakeString
(
"mace_input_node_"
,
input
.
first
,
":0"
);
input
.
first
);
net
.
AddInputFromArray
<
D
,
float
>
(
input_name
,
input
.
second
.
shape
(),
input_data
);
}
...
...
@@ -181,7 +181,7 @@ void CheckOutputs(const NetDef &net_def,
float
*
data
=
tmp_tensor
->
mutable_data
<
float
>
();
memcpy
(
data
,
output
.
second
.
data
().
get
(),
data_size
*
sizeof
(
float
));
std
::
string
output_name
=
MakeString
(
"mace_output_node_"
,
output
.
first
,
":0"
);
output
.
first
);
ops
::
test
::
ExpectTensorNear
<
float
>
(
*
tmp_tensor
,
*
net
.
GetOutput
(
output_name
.
data
()),
1e-5
);
...
...
@@ -265,7 +265,7 @@ void MaceRunFunc(const int in_out_size) {
for
(
size_t
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
std
::
string
input_name
=
MakeString
(
"mace_input_node_"
,
input_names
[
i
]
,
":0"
);
input_names
[
i
]);
BufferToImage
<
half
>
(
input_name
,
input_names
[
i
],
mace
::
kernels
::
IN_OUT_CHANNEL
,
{
mem_map
[
input_names
[
i
]]},
...
...
@@ -281,7 +281,7 @@ void MaceRunFunc(const int in_out_size) {
}
for
(
size_t
i
=
0
;
i
<
output_names
.
size
();
++
i
)
{
std
::
string
output_name
=
MakeString
(
"mace_output_node_"
,
output_names
[
i
]
,
":0"
);
output_names
[
i
]);
ImageToBuffer
<
float
>
(
output_names
[
i
],
output_name
,
mace
::
kernels
::
IN_OUT_CHANNEL
,
&
net_def
);
}
...
...
mace/test/mace_api_test.cc
浏览文件 @
3e82ad67
...
...
@@ -162,7 +162,7 @@ void CheckOutputs(const NetDef &net_def,
memcpy
(
input_data
.
data
(),
input
.
second
.
data
().
get
(),
data_size
*
sizeof
(
float
));
std
::
string
input_name
=
MakeString
(
"mace_input_node_"
,
input
.
first
,
":0"
);
input
.
first
);
net
.
AddInputFromArray
<
D
,
float
>
(
input_name
,
input
.
second
.
shape
(),
input_data
);
}
...
...
@@ -191,7 +191,7 @@ void CheckOutputs(const NetDef &net_def,
float
*
data
=
tmp_tensor
->
mutable_data
<
float
>
();
memcpy
(
data
,
output
.
second
.
data
().
get
(),
data_size
*
sizeof
(
float
));
std
::
string
output_name
=
MakeString
(
"mace_output_node_"
,
output
.
first
,
":0"
);
output
.
first
);
ops
::
test
::
ExpectTensorNear
<
float
>
(
*
tmp_tensor
,
*
net
.
GetOutput
(
output_name
.
data
()),
1e-5
);
...
...
@@ -275,7 +275,7 @@ void MaceRun(const int in_out_size,
for
(
size_t
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
std
::
string
input_name
=
MakeString
(
"mace_input_node_"
,
input_names
[
i
]
,
":0"
);
input_names
[
i
]);
BufferToImage
<
half
>
(
input_name
,
input_names
[
i
],
mace
::
kernels
::
IN_OUT_CHANNEL
,
{
mem_map
[
input_names
[
i
]]},
...
...
@@ -291,7 +291,7 @@ void MaceRun(const int in_out_size,
}
for
(
size_t
i
=
0
;
i
<
output_names
.
size
();
++
i
)
{
std
::
string
output_name
=
MakeString
(
"mace_output_node_"
,
output_names
[
i
]
,
":0"
);
output_names
[
i
]);
ImageToBuffer
<
float
>
(
output_names
[
i
],
output_name
,
mace
::
kernels
::
IN_OUT_CHANNEL
,
&
net_def
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录