Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
320b509c
Mace
项目概览
Xiaomi
/
Mace
通知
107
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
320b509c
编写于
8月 09, 2019
作者:
L
liyin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor model converter & encrypt process
上级
64419daa
变更
48
隐藏空白更改
内联
并排
Showing
48 changed file
with
1217 addition
and
1295 deletion
+1217
-1295
.gitignore
.gitignore
+1
-0
docs/user_guide/quantization_usage.rst
docs/user_guide/quantization_usage.rst
+1
-1
mace/codegen/BUILD.bazel
mace/codegen/BUILD.bazel
+2
-2
mace/python/tools/BUILD.bazel
mace/python/tools/BUILD.bazel
+0
-53
mace/python/tools/binary_codegen.py
mace/python/tools/binary_codegen.py
+0
-112
mace/python/tools/convert_util.py
mace/python/tools/convert_util.py
+0
-91
mace/python/tools/converter.py
mace/python/tools/converter.py
+0
-422
mace/python/tools/encrypt_opencl_codegen.py
mace/python/tools/encrypt_opencl_codegen.py
+2
-1
mace/python/tools/graph_util.py
mace/python/tools/graph_util.py
+0
-68
mace/python/tools/mace_engine_factory_codegen.py
mace/python/tools/mace_engine_factory_codegen.py
+4
-4
mace/python/tools/opencl_binary_codegen.py
mace/python/tools/opencl_binary_codegen.py
+2
-1
mace/python/tools/quantization/quantize_util_test.py
mace/python/tools/quantization/quantize_util_test.py
+0
-16
mace/python/tools/tf_ops_stats.py
mace/python/tools/tf_ops_stats.py
+0
-242
mace/python/tools/visualization/BUILD.bazel
mace/python/tools/visualization/BUILD.bazel
+0
-14
tools/converter.py
tools/converter.py
+85
-118
tools/python/__init__.py
tools/python/__init__.py
+0
-0
tools/python/config/model/model-template.yaml
tools/python/config/model/model-template.yaml
+0
-0
tools/python/convert.py
tools/python/convert.py
+342
-0
tools/python/encrypt.py
tools/python/encrypt.py
+237
-0
tools/python/py_proto/__init__.py
tools/python/py_proto/__init__.py
+12
-0
tools/python/quantize/__init__.py
tools/python/quantize/__init__.py
+0
-0
tools/python/quantize/quantize_stat.py
tools/python/quantize/quantize_stat.py
+18
-0
tools/python/quantize/quantize_util.py
tools/python/quantize/quantize_util.py
+23
-3
tools/python/quantize/quantize_util_test.py
tools/python/quantize/quantize_util_test.py
+11
-24
tools/python/run.py
tools/python/run.py
+0
-0
tools/python/template/model.jinja2
tools/python/template/model.jinja2
+0
-5
tools/python/template/model_header.jinja2
tools/python/template/model_header.jinja2
+0
-0
tools/python/template/operator.jinja2
tools/python/template/operator.jinja2
+182
-0
tools/python/template/tensor_data.jinja2
tools/python/template/tensor_data.jinja2
+0
-0
tools/python/template/tensor_source.jinja2
tools/python/template/tensor_source.jinja2
+2
-2
tools/python/transform/__init__.py
tools/python/transform/__init__.py
+0
-0
tools/python/transform/apu_converter.py
tools/python/transform/apu_converter.py
+12
-13
tools/python/transform/base_converter.py
tools/python/transform/base_converter.py
+1
-1
tools/python/transform/caffe_converter.py
tools/python/transform/caffe_converter.py
+14
-14
tools/python/transform/hexagon_converter.py
tools/python/transform/hexagon_converter.py
+23
-30
tools/python/transform/onnx_converter.py
tools/python/transform/onnx_converter.py
+14
-14
tools/python/transform/shape_inference.py
tools/python/transform/shape_inference.py
+6
-6
tools/python/transform/tensorflow_converter.py
tools/python/transform/tensorflow_converter.py
+14
-14
tools/python/transform/transformer.py
tools/python/transform/transformer.py
+38
-20
tools/python/utils/__init__.py
tools/python/utils/__init__.py
+0
-0
tools/python/utils/config_parser.py
tools/python/utils/config_parser.py
+65
-0
tools/python/utils/device.py
tools/python/utils/device.py
+0
-0
tools/python/utils/target.py
tools/python/utils/target.py
+0
-0
tools/python/utils/util.py
tools/python/utils/util.py
+101
-0
tools/python/visualize/__init__.py
tools/python/visualize/__init__.py
+0
-0
tools/python/visualize/index.html
tools/python/visualize/index.html
+0
-0
tools/python/visualize/visualize_model.py
tools/python/visualize/visualize_model.py
+5
-3
tools/sh_commands.py
tools/sh_commands.py
+0
-1
未找到文件。
.gitignore
浏览文件 @
320b509c
...
@@ -28,3 +28,4 @@ examples/android/macelibrary/src/main/cpp/include/
...
@@ -28,3 +28,4 @@ examples/android/macelibrary/src/main/cpp/include/
examples/android/macelibrary/src/main/cpp/lib/arm64-v8a/
examples/android/macelibrary/src/main/cpp/lib/arm64-v8a/
examples/android/macelibrary/src/main/jniLibs/arm64-v8a/
examples/android/macelibrary/src/main/jniLibs/arm64-v8a/
tools/python/py_proto/*_pb2.py
docs/user_guide/quantization_usage.rst
浏览文件 @
320b509c
...
@@ -62,7 +62,7 @@ MACE provides tools to do statistics with following steps:
...
@@ -62,7 +62,7 @@ MACE provides tools to do statistics with following steps:
.. code:: sh
.. code:: sh
python
mace/python/tools/quantization
/quantize_stat.py --log_file range_log > overall_range
python
tools/python/tools/quantize
/quantize_stat.py --log_file range_log > overall_range
4. Convert quantized model (by setting `target_abis` to the final target abis, e.g., `armeabi-v7a`,
4. Convert quantized model (by setting `target_abis` to the final target abis, e.g., `armeabi-v7a`,
...
...
mace/codegen/BUILD.bazel
浏览文件 @
320b509c
...
@@ -9,8 +9,8 @@ load("//mace:mace.bzl", "encrypt_opencl_kernel_genrule", "mace_version_genrule")
...
@@ -9,8 +9,8 @@ load("//mace:mace.bzl", "encrypt_opencl_kernel_genrule", "mace_version_genrule")
cc_library
(
cc_library
(
name
=
"generated_models"
,
name
=
"generated_models"
,
srcs
=
glob
([
"models/*/*.cc"
]),
srcs
=
glob
([
"models/*
*
/*.cc"
]),
hdrs
=
glob
([
"models/*/*.h"
]),
hdrs
=
glob
([
"models/*
*
/*.h"
]),
copts
=
[
copts
=
[
"-Werror"
,
"-Werror"
,
"-Wextra"
,
"-Wextra"
,
...
...
mace/python/tools/BUILD.bazel
浏览文件 @
320b509c
py_library
(
name
=
"quantization_lib"
,
srcs
=
[
"quantization/quantize_util.py"
,
],
srcs_version
=
"PY2AND3"
,
)
py_library
(
name
=
"converter_lib"
,
srcs
=
[
"convert_util.py"
,
"converter_tool/base_converter.py"
,
"converter_tool/caffe_converter.py"
,
"converter_tool/hexagon_converter.py"
,
"converter_tool/onnx_converter.py"
,
"converter_tool/shape_inference.py"
,
"converter_tool/tensorflow_converter.py"
,
"converter_tool/apu_converter.py"
,
"converter_tool/transformer.py"
,
"graph_util.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":quantization_lib"
,
"//mace/proto:mace_py"
,
"//third_party/caffe:caffe_py"
,
],
)
py_library
(
name
=
"model_saver_lib"
,
srcs
=
[
"model_saver.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
"//mace/proto:mace_py"
,
],
)
py_binary
(
name
=
"converter"
,
srcs
=
[
"converter.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":converter_lib"
,
":model_saver_lib"
,
"//mace/python/tools/visualization:visualization_lib"
,
"@six_archive//:six"
,
],
)
py_binary
(
py_binary
(
name
=
"archive_static_lib"
,
name
=
"archive_static_lib"
,
srcs
=
[
"archive_static_lib.py"
],
srcs
=
[
"archive_static_lib.py"
],
...
...
mace/python/tools/binary_codegen.py
已删除
100644 → 0
浏览文件 @
64419daa
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
sys
import
struct
import
jinja2
import
numpy
as
np
import
six
# python mace/python/tools/binary_codegen.py \
# --binary_dirs=${BIN_FILE} \
# --binary_file_name=mace_run.config \
# --output_path=${CODE_GEN_PATH} --variable_name=kTuningParamsData
FLAGS
=
None
def
generate_cpp_source
(
binary_dirs
,
binary_file_name
,
variable_name
):
data_map
=
{}
for
binary_dir
in
binary_dirs
.
split
(
","
):
binary_path
=
os
.
path
.
join
(
binary_dir
,
binary_file_name
)
if
not
os
.
path
.
exists
(
binary_path
):
continue
with
open
(
binary_path
,
"rb"
)
as
f
:
binary_array
=
np
.
fromfile
(
f
,
dtype
=
np
.
uint8
)
six
.
print_
(
"Generate binary from"
,
binary_path
)
idx
=
0
size
,
=
struct
.
unpack
(
"Q"
,
binary_array
[
idx
:
idx
+
8
])
idx
+=
8
for
_
in
six
.
moves
.
range
(
size
):
key_size
,
=
struct
.
unpack
(
"i"
,
binary_array
[
idx
:
idx
+
4
])
idx
+=
4
key
,
=
struct
.
unpack
(
str
(
key_size
)
+
"s"
,
binary_array
[
idx
:
idx
+
key_size
])
idx
+=
key_size
params_size
,
=
struct
.
unpack
(
"i"
,
binary_array
[
idx
:
idx
+
4
])
idx
+=
4
data_map
[
key
]
=
[]
count
=
params_size
/
4
params
=
struct
.
unpack
(
str
(
count
)
+
"i"
,
binary_array
[
idx
:
idx
+
params_size
])
for
i
in
params
:
data_map
[
key
].
append
(
i
)
idx
+=
params_size
env
=
jinja2
.
Environment
(
loader
=
jinja2
.
FileSystemLoader
(
sys
.
path
[
0
]))
return
env
.
get_template
(
'str2vec_maps.cc.jinja2'
).
render
(
maps
=
data_map
,
data_type
=
'unsigned int'
,
variable_name
=
variable_name
)
def
tuning_param_codegen
(
binary_dirs
,
binary_file_name
,
output_path
,
variable_name
):
cpp_binary_source
=
generate_cpp_source
(
binary_dirs
,
binary_file_name
,
variable_name
)
if
os
.
path
.
isfile
(
output_path
):
os
.
remove
(
output_path
)
with
open
(
output_path
,
"w"
)
as
w_file
:
w_file
.
write
(
cpp_binary_source
)
def
parse_args
():
"""Parses command line arguments."""
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--binary_dirs"
,
type
=
str
,
default
=
""
,
help
=
"The binaries file path."
)
parser
.
add_argument
(
"--binary_file_name"
,
type
=
str
,
default
=
"mace_run.config"
,
help
=
"The binary file name."
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
default
=
""
,
help
=
"The path of generated C++ source file which contains the binary."
)
parser
.
add_argument
(
"--variable_name"
,
type
=
str
,
default
=
"kTuningParamsData"
,
help
=
"global variable name."
)
return
parser
.
parse_known_args
()
if
__name__
==
'__main__'
:
FLAGS
,
unparsed
=
parse_args
()
tuning_param_codegen
(
FLAGS
.
binary_dirs
,
FLAGS
.
binary_file_name
,
FLAGS
.
output_path
,
FLAGS
.
variable_name
)
mace/python/tools/convert_util.py
已删除
100644 → 0
浏览文件 @
64419daa
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
enum
def
mace_check
(
condition
,
msg
):
if
not
condition
:
raise
Exception
(
msg
)
def
roundup_div4
(
value
):
return
int
((
value
+
3
)
//
4
)
class
OpenCLBufferType
(
enum
.
Enum
):
CONV2D_FILTER
=
0
IN_OUT_CHANNEL
=
1
ARGUMENT
=
2
IN_OUT_HEIGHT
=
3
IN_OUT_WIDTH
=
4
WINOGRAD_FILTER
=
5
DW_CONV2D_FILTER
=
6
WEIGHT_HEIGHT
=
7
WEIGHT_WIDTH
=
8
def
calculate_image_shape
(
buffer_type
,
shape
,
winograd_blk_size
=
0
):
# keep the same with mace/kernel/opencl/helper.cc
image_shape
=
[
0
,
0
]
if
buffer_type
==
OpenCLBufferType
.
CONV2D_FILTER
:
mace_check
(
len
(
shape
)
==
4
,
"Conv2D filter buffer should be 4D"
)
image_shape
[
0
]
=
shape
[
1
]
image_shape
[
1
]
=
shape
[
2
]
*
shape
[
3
]
*
roundup_div4
(
shape
[
0
])
elif
buffer_type
==
OpenCLBufferType
.
IN_OUT_CHANNEL
:
mace_check
(
len
(
shape
)
==
2
or
len
(
shape
)
==
4
,
"input/output buffer should be 2D|4D"
)
if
len
(
shape
)
==
4
:
image_shape
[
0
]
=
roundup_div4
(
shape
[
3
])
*
shape
[
2
]
image_shape
[
1
]
=
shape
[
0
]
*
shape
[
1
]
elif
len
(
shape
)
==
2
:
image_shape
[
0
]
=
roundup_div4
(
shape
[
1
])
image_shape
[
1
]
=
shape
[
0
]
elif
buffer_type
==
OpenCLBufferType
.
ARGUMENT
:
mace_check
(
len
(
shape
)
==
1
,
"Argument buffer should be 1D not "
+
str
(
shape
))
image_shape
[
0
]
=
roundup_div4
(
shape
[
0
])
image_shape
[
1
]
=
1
elif
buffer_type
==
OpenCLBufferType
.
IN_OUT_HEIGHT
:
if
len
(
shape
)
==
4
:
image_shape
[
0
]
=
shape
[
2
]
*
shape
[
3
]
image_shape
[
1
]
=
shape
[
0
]
*
roundup_div4
(
shape
[
1
])
elif
len
(
shape
)
==
2
:
image_shape
[
0
]
=
shape
[
0
]
image_shape
[
1
]
=
roundup_div4
(
shape
[
1
])
elif
buffer_type
==
OpenCLBufferType
.
IN_OUT_WIDTH
:
mace_check
(
len
(
shape
)
==
4
,
"Input/output buffer should be 4D"
)
image_shape
[
0
]
=
roundup_div4
(
shape
[
2
])
*
shape
[
3
]
image_shape
[
1
]
=
shape
[
0
]
*
shape
[
1
]
elif
buffer_type
==
OpenCLBufferType
.
WINOGRAD_FILTER
:
mace_check
(
len
(
shape
)
==
4
,
"Winograd filter buffer should be 4D"
)
image_shape
[
0
]
=
roundup_div4
(
shape
[
1
])
image_shape
[
1
]
=
(
shape
[
0
]
*
(
winograd_blk_size
+
2
)
*
(
winograd_blk_size
+
2
))
elif
buffer_type
==
OpenCLBufferType
.
DW_CONV2D_FILTER
:
mace_check
(
len
(
shape
)
==
4
,
"Winograd filter buffer should be 4D"
)
image_shape
[
0
]
=
shape
[
0
]
*
shape
[
2
]
*
shape
[
3
]
image_shape
[
1
]
=
roundup_div4
(
shape
[
1
])
elif
buffer_type
==
OpenCLBufferType
.
WEIGHT_HEIGHT
:
mace_check
(
len
(
shape
)
==
4
,
"Weight buffer should be 4D"
)
image_shape
[
0
]
=
shape
[
1
]
*
shape
[
2
]
*
shape
[
3
]
image_shape
[
1
]
=
roundup_div4
(
shape
[
0
])
elif
buffer_type
==
OpenCLBufferType
.
WEIGHT_WIDTH
:
mace_check
(
len
(
shape
)
==
4
,
"Weight buffer should be 4D"
)
image_shape
[
0
]
=
roundup_div4
(
shape
[
1
])
*
shape
[
2
]
*
shape
[
3
]
image_shape
[
1
]
=
shape
[
0
]
else
:
mace_check
(
False
,
"OpenCL Image do not support type "
+
str
(
buffer_type
))
return
image_shape
mace/python/tools/converter.py
已删除
100644 → 0
浏览文件 @
64419daa
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
sys
import
hashlib
import
os.path
import
copy
import
six
from
mace.proto
import
mace_pb2
from
mace.python.tools
import
model_saver
from
mace.python.tools.converter_tool
import
base_converter
as
cvt
from
mace.python.tools.converter_tool
import
transformer
from
mace.python.tools.convert_util
import
mace_check
from
mace.python.tools.visualization
import
visualize_model
# ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb \
# --output quantized_test_dsp.pb \
# --runtime dsp \
# --input_dim input_node,1,28,28,3
FLAGS
=
None
device_type_map
=
{
'cpu'
:
cvt
.
DeviceType
.
CPU
.
value
,
'gpu'
:
cvt
.
DeviceType
.
GPU
.
value
,
'dsp'
:
cvt
.
DeviceType
.
HEXAGON
.
value
,
'hta'
:
cvt
.
DeviceType
.
HTA
.
value
,
'apu'
:
cvt
.
DeviceType
.
APU
.
value
,
'cpu+gpu'
:
cvt
.
DeviceType
.
CPU
.
value
}
data_format_map
=
{
'NONE'
:
cvt
.
DataFormat
.
NONE
,
'NHWC'
:
cvt
.
DataFormat
.
NHWC
,
'NCHW'
:
cvt
.
DataFormat
.
NCHW
,
'OIHW'
:
cvt
.
DataFormat
.
OIHW
,
}
data_type_map
=
{
'float32'
:
mace_pb2
.
DT_FLOAT
,
'int32'
:
mace_pb2
.
DT_INT32
,
}
def
parse_data_type
(
data_type
,
device_type
):
if
device_type
==
cvt
.
DeviceType
.
CPU
.
value
or
\
device_type
==
cvt
.
DeviceType
.
GPU
.
value
:
if
data_type
==
'fp32_fp32'
:
return
mace_pb2
.
DT_FLOAT
else
:
return
mace_pb2
.
DT_HALF
elif
device_type
==
cvt
.
DeviceType
.
HEXAGON
.
value
or
\
device_type
==
cvt
.
DeviceType
.
HTA
.
value
:
return
mace_pb2
.
DT_FLOAT
elif
device_type
==
cvt
.
DeviceType
.
APU
.
value
:
return
mace_pb2
.
DT_FLOAT
else
:
print
(
"Invalid device type: "
+
str
(
device_type
))
def
file_checksum
(
fname
):
hash_func
=
hashlib
.
sha256
()
with
open
(
fname
,
"rb"
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
hash_func
.
update
(
chunk
)
return
hash_func
.
hexdigest
()
def
split_shape
(
shape
):
if
shape
.
strip
()
==
""
:
return
[]
else
:
return
shape
.
split
(
','
)
def
parse_int_array_from_str
(
ints_str
):
return
[
int
(
i
)
for
i
in
split_shape
(
ints_str
)]
def
parse_float_array_from_str
(
floats_str
):
return
[
float
(
i
)
for
i
in
floats_str
.
split
(
','
)]
def
transpose_shape
(
shape
,
dst_order
):
t_shape
=
[
0
]
*
len
(
shape
)
for
i
in
range
(
len
(
shape
)):
t_shape
[
i
]
=
shape
[
dst_order
[
i
]]
return
t_shape
def
main
(
unused_args
):
if
not
os
.
path
.
isfile
(
FLAGS
.
model_file
):
six
.
print_
(
"Input graph file '"
+
FLAGS
.
model_file
+
"' does not exist!"
,
file
=
sys
.
stderr
)
sys
.
exit
(
-
1
)
model_checksum
=
file_checksum
(
FLAGS
.
model_file
)
if
FLAGS
.
model_checksum
!=
""
and
FLAGS
.
model_checksum
!=
model_checksum
:
six
.
print_
(
"Model checksum mismatch: %s != %s"
%
(
model_checksum
,
FLAGS
.
model_checksum
),
file
=
sys
.
stderr
)
sys
.
exit
(
-
1
)
weight_checksum
=
None
if
FLAGS
.
platform
==
'caffe'
:
if
not
os
.
path
.
isfile
(
FLAGS
.
weight_file
):
six
.
print_
(
"Input weight file '"
+
FLAGS
.
weight_file
+
"' does not exist!"
,
file
=
sys
.
stderr
)
sys
.
exit
(
-
1
)
weight_checksum
=
file_checksum
(
FLAGS
.
weight_file
)
if
FLAGS
.
weight_checksum
!=
""
and
\
FLAGS
.
weight_checksum
!=
weight_checksum
:
six
.
print_
(
"Weight checksum mismatch: %s != %s"
%
(
weight_checksum
,
FLAGS
.
weight_checksum
),
file
=
sys
.
stderr
)
sys
.
exit
(
-
1
)
if
FLAGS
.
platform
not
in
[
'tensorflow'
,
'caffe'
,
'onnx'
]:
six
.
print_
(
"platform %s is not supported."
%
FLAGS
.
platform
,
file
=
sys
.
stderr
)
sys
.
exit
(
-
1
)
if
FLAGS
.
runtime
not
in
[
'cpu'
,
'gpu'
,
'dsp'
,
'hta'
,
'apu'
,
'cpu+gpu'
]:
six
.
print_
(
"runtime %s is not supported."
%
FLAGS
.
runtime
,
file
=
sys
.
stderr
)
sys
.
exit
(
-
1
)
option
=
cvt
.
ConverterOption
()
if
FLAGS
.
graph_optimize_options
:
option
.
transformer_option
=
FLAGS
.
graph_optimize_options
.
split
(
','
)
option
.
winograd
=
FLAGS
.
winograd
option
.
quantize
=
FLAGS
.
quantize
option
.
quantize_large_weights
=
FLAGS
.
quantize_large_weights
option
.
quantize_range_file
=
FLAGS
.
quantize_range_file
option
.
change_concat_ranges
=
FLAGS
.
change_concat_ranges
option
.
cl_mem_type
=
FLAGS
.
cl_mem_type
option
.
device
=
device_type_map
[
FLAGS
.
runtime
]
option
.
data_type
=
parse_data_type
(
FLAGS
.
data_type
,
option
.
device
)
input_node_names
=
FLAGS
.
input_node
.
split
(
','
)
input_data_types
=
FLAGS
.
input_data_types
.
split
(
','
)
input_node_shapes
=
FLAGS
.
input_shape
.
split
(
':'
)
input_node_formats
=
FLAGS
.
input_data_formats
.
split
(
","
)
if
FLAGS
.
input_range
:
input_node_ranges
=
FLAGS
.
input_range
.
split
(
':'
)
else
:
input_node_ranges
=
[]
if
len
(
input_node_names
)
!=
len
(
input_node_shapes
):
raise
Exception
(
'input node count and shape count do not match.'
)
for
i
in
six
.
moves
.
range
(
len
(
input_node_names
)):
input_node
=
cvt
.
NodeInfo
()
input_node
.
name
=
input_node_names
[
i
]
input_node
.
data_type
=
data_type_map
[
input_data_types
[
i
]]
input_node
.
data_format
=
data_format_map
[
input_node_formats
[
i
]]
input_node
.
shape
=
parse_int_array_from_str
(
input_node_shapes
[
i
])
if
input_node
.
data_format
==
cvt
.
DataFormat
.
NCHW
and
\
len
(
input_node
.
shape
)
==
4
:
input_node
.
shape
=
transpose_shape
(
input_node
.
shape
,
[
0
,
2
,
3
,
1
])
input_node
.
data_format
=
cvt
.
DataFormat
.
NHWC
if
len
(
input_node_ranges
)
>
i
:
input_node
.
range
=
parse_float_array_from_str
(
input_node_ranges
[
i
])
option
.
add_input_node
(
input_node
)
output_node_names
=
FLAGS
.
output_node
.
split
(
','
)
output_data_types
=
FLAGS
.
output_data_types
.
split
(
','
)
output_node_shapes
=
FLAGS
.
output_shape
.
split
(
':'
)
output_node_formats
=
FLAGS
.
output_data_formats
.
split
(
","
)
if
len
(
output_node_names
)
!=
len
(
output_node_shapes
):
raise
Exception
(
'output node count and shape count do not match.'
)
for
i
in
six
.
moves
.
range
(
len
(
output_node_names
)):
output_node
=
cvt
.
NodeInfo
()
output_node
.
name
=
output_node_names
[
i
]
output_node
.
data_type
=
data_type_map
[
output_data_types
[
i
]]
output_node
.
data_format
=
data_format_map
[
output_node_formats
[
i
]]
output_node
.
shape
=
parse_int_array_from_str
(
output_node_shapes
[
i
])
if
output_node
.
data_format
==
cvt
.
DataFormat
.
NCHW
and
\
len
(
output_node
.
shape
)
==
4
:
output_node
.
shape
=
transpose_shape
(
output_node
.
shape
,
[
0
,
2
,
3
,
1
])
output_node
.
data_format
=
cvt
.
DataFormat
.
NHWC
option
.
add_output_node
(
output_node
)
if
FLAGS
.
check_node
!=
''
:
check_node_names
=
FLAGS
.
check_node
.
split
(
','
)
check_node_shapes
=
FLAGS
.
check_shape
.
split
(
':'
)
if
len
(
check_node_names
)
!=
len
(
check_node_shapes
):
raise
Exception
(
'check node count and shape count do not match.'
)
for
i
in
six
.
moves
.
range
(
len
(
check_node_names
)):
check_node
=
cvt
.
NodeInfo
()
check_node
.
name
=
check_node_names
[
i
]
check_node
.
shape
=
parse_int_array_from_str
(
check_node_shapes
[
i
])
option
.
add_check_node
(
check_node
)
else
:
option
.
check_nodes
=
option
.
output_nodes
option
.
build
()
print
(
"Transform model to one that can better run on device"
)
if
FLAGS
.
platform
==
'tensorflow'
:
from
mace.python.tools.converter_tool
import
tensorflow_converter
converter
=
tensorflow_converter
.
TensorflowConverter
(
option
,
FLAGS
.
model_file
)
elif
FLAGS
.
platform
==
'caffe'
:
from
mace.python.tools.converter_tool
import
caffe_converter
converter
=
caffe_converter
.
CaffeConverter
(
option
,
FLAGS
.
model_file
,
FLAGS
.
weight_file
)
elif
FLAGS
.
platform
==
'onnx'
:
from
mace.python.tools.converter_tool
import
onnx_converter
converter
=
onnx_converter
.
OnnxConverter
(
option
,
FLAGS
.
model_file
)
else
:
six
.
print_
(
"Mace do not support platorm %s yet."
%
FLAGS
.
platform
,
file
=
sys
.
stderr
)
exit
(
1
)
output_graph_def
=
converter
.
run
()
mace_transformer
=
transformer
.
Transformer
(
option
,
output_graph_def
)
output_graph_def
,
quantize_activation_info
=
mace_transformer
.
run
()
if
option
.
device
in
[
cvt
.
DeviceType
.
HEXAGON
.
value
,
cvt
.
DeviceType
.
HTA
.
value
]:
from
mace.python.tools.converter_tool
import
hexagon_converter
converter
=
hexagon_converter
.
HexagonConverter
(
option
,
output_graph_def
,
quantize_activation_info
)
output_graph_def
=
converter
.
run
()
elif
FLAGS
.
runtime
==
'apu'
:
if
FLAGS
.
platform
!=
'tensorflow'
:
raise
Exception
(
'apu only support model from tensorflow'
)
from
mace.python.tools.converter_tool
import
apu_converter
converter
=
apu_converter
.
ApuConverter
(
option
,
output_graph_def
,
quantize_activation_info
)
output_graph_def
=
converter
.
run
()
try
:
visualizer
=
visualize_model
.
ModelVisualizer
(
FLAGS
.
model_tag
,
output_graph_def
)
visualizer
.
save_html
()
except
:
# noqa
print
(
"Failed to visualize model:"
,
sys
.
exc_info
()[
0
])
model_saver
.
save_model
(
option
,
output_graph_def
,
model_checksum
,
weight_checksum
,
FLAGS
.
template_dir
,
FLAGS
.
obfuscate
,
FLAGS
.
model_tag
,
FLAGS
.
output_dir
,
FLAGS
.
embed_model_data
,
FLAGS
.
winograd
,
FLAGS
.
model_graph_format
)
def
str2bool
(
v
):
if
v
.
lower
()
in
(
'yes'
,
'true'
,
't'
,
'y'
,
'1'
):
return
True
elif
v
.
lower
()
in
(
'no'
,
'false'
,
'f'
,
'n'
,
'0'
):
return
False
else
:
raise
argparse
.
ArgumentTypeError
(
'Boolean value expected.'
)
def
parse_args
():
"""Parses command line arguments."""
parser
=
argparse
.
ArgumentParser
()
parser
.
register
(
"type"
,
"bool"
,
lambda
v
:
v
.
lower
()
==
"true"
)
parser
.
add_argument
(
"--model_file"
,
type
=
str
,
default
=
""
,
help
=
"TensorFlow
\'
GraphDef
\'
file to load, "
"Onnx model file .onnx to load, "
"Caffe prototxt file to load."
)
parser
.
add_argument
(
"--weight_file"
,
type
=
str
,
default
=
""
,
help
=
"Caffe data file to load."
)
parser
.
add_argument
(
"--model_checksum"
,
type
=
str
,
default
=
""
,
help
=
"Model file sha256 checksum"
)
parser
.
add_argument
(
"--weight_checksum"
,
type
=
str
,
default
=
""
,
help
=
"Weight file sha256 checksum"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
""
,
help
=
"File to save the output graph to."
)
parser
.
add_argument
(
"--runtime"
,
type
=
str
,
default
=
""
,
help
=
"Runtime: cpu/gpu/dsp/apu"
)
parser
.
add_argument
(
"--input_node"
,
type
=
str
,
default
=
"input_node"
,
help
=
"e.g., input_node"
)
parser
.
add_argument
(
"--input_data_types"
,
type
=
str
,
default
=
"float32"
,
help
=
"e.g., float32|int32"
)
parser
.
add_argument
(
"--input_data_formats"
,
type
=
str
,
default
=
"NHWC"
,
help
=
"e.g., NHWC,NONE"
)
parser
.
add_argument
(
"--output_node"
,
type
=
str
,
default
=
"softmax"
,
help
=
"e.g., softmax"
)
parser
.
add_argument
(
"--output_data_types"
,
type
=
str
,
default
=
"float32"
,
help
=
"e.g., float32|int32"
)
parser
.
add_argument
(
"--output_data_formats"
,
type
=
str
,
default
=
"NHWC"
,
help
=
"e.g., NHWC,NONE"
)
parser
.
add_argument
(
"--check_node"
,
type
=
str
,
default
=
"softmax"
,
help
=
"e.g., softmax"
)
parser
.
add_argument
(
"--template_dir"
,
type
=
str
,
default
=
""
,
help
=
"template path"
)
parser
.
add_argument
(
"--obfuscate"
,
type
=
str2bool
,
nargs
=
'?'
,
const
=
False
,
default
=
False
,
help
=
"obfuscate model names"
)
parser
.
add_argument
(
"--model_tag"
,
type
=
str
,
default
=
""
,
help
=
"model tag for generated function and namespace"
)
parser
.
add_argument
(
"--winograd"
,
type
=
int
,
default
=
0
,
help
=
"Which version of winograd convolution to use. [2 | 4]"
)
parser
.
add_argument
(
"--dsp_mode"
,
type
=
int
,
default
=
0
,
help
=
"dsp run mode, defalut=0"
)
parser
.
add_argument
(
"--input_shape"
,
type
=
str
,
default
=
""
,
help
=
"input shape."
)
parser
.
add_argument
(
"--input_range"
,
type
=
str
,
default
=
""
,
help
=
"input range."
)
parser
.
add_argument
(
"--output_shape"
,
type
=
str
,
default
=
""
,
help
=
"output shape."
)
parser
.
add_argument
(
"--check_shape"
,
type
=
str
,
default
=
""
,
help
=
"check shape."
)
parser
.
add_argument
(
"--platform"
,
type
=
str
,
default
=
"tensorflow"
,
help
=
"tensorflow/caffe/onnx"
)
parser
.
add_argument
(
"--embed_model_data"
,
type
=
str2bool
,
default
=
True
,
help
=
"embed model data."
)
parser
.
add_argument
(
"--model_graph_format"
,
type
=
str
,
default
=
"file"
,
help
=
"[file|code] build models to code"
+
"or `Protobuf` file."
)
parser
.
add_argument
(
"--data_type"
,
type
=
str
,
default
=
"fp16_fp32"
,
help
=
"fp16_fp32/fp32_fp32"
)
parser
.
add_argument
(
"--graph_optimize_options"
,
type
=
str
,
default
=
""
,
help
=
"graph optimize options"
)
parser
.
add_argument
(
"--quantize"
,
type
=
str2bool
,
nargs
=
'?'
,
const
=
False
,
default
=
False
,
help
=
"quantize model"
)
parser
.
add_argument
(
"--quantize_large_weights"
,
type
=
str2bool
,
nargs
=
'?'
,
const
=
False
,
default
=
False
,
help
=
"quantize large weights for compression"
)
parser
.
add_argument
(
"--quantize_range_file"
,
type
=
str
,
default
=
""
,
help
=
"file path of quantize range for each tensor"
)
parser
.
add_argument
(
"--change_concat_ranges"
,
type
=
str2bool
,
nargs
=
'?'
,
const
=
False
,
default
=
False
,
help
=
"change ranges to use memcpy for quantized concat"
)
parser
.
add_argument
(
"--cl_mem_type"
,
type
=
str
,
default
=
"image"
,
help
=
"which memory type to use.[image|buffer]"
)
return
parser
.
parse_known_args
()
if
__name__
==
'__main__'
:
FLAGS
,
unparsed
=
parse_args
()
main
(
unused_args
=
[
sys
.
argv
[
0
]]
+
unparsed
)
mace/python/tools/encrypt_opencl_codegen.py
浏览文件 @
320b509c
...
@@ -55,7 +55,8 @@ def create_output_dir(dir_path):
...
@@ -55,7 +55,8 @@ def create_output_dir(dir_path):
def
write_cl_encrypted_kernel_to_file
(
def
write_cl_encrypted_kernel_to_file
(
encrypted_code_maps
,
template_path
,
output_path
):
encrypted_code_maps
,
template_path
,
output_path
):
env
=
jinja2
.
Environment
(
loader
=
jinja2
.
FileSystemLoader
(
sys
.
path
[
0
]))
cwd
=
os
.
path
.
dirname
(
__file__
)
env
=
jinja2
.
Environment
(
loader
=
jinja2
.
FileSystemLoader
(
cwd
))
cl_encrypted_kernel
=
env
.
get_template
(
template_path
).
render
(
cl_encrypted_kernel
=
env
.
get_template
(
template_path
).
render
(
tag
=
'codegen'
,
tag
=
'codegen'
,
maps
=
encrypted_code_maps
,
maps
=
encrypted_code_maps
,
...
...
mace/python/tools/graph_util.py
已删除
100644 → 0
浏览文件 @
64419daa
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
tensorflow
as
tf
from
mace.proto
import
mace_pb2
from
collections
import
OrderedDict
def
sort_tf_node
(
node
,
nodes_map
,
ordered_nodes_map
):
if
node
.
name
not
in
ordered_nodes_map
:
for
input_tensor_name
in
node
.
input
:
input_node_name
=
input_tensor_name
.
split
(
':'
)[
0
]
if
':'
in
input_tensor_name
else
input_tensor_name
if
input_node_name
not
in
nodes_map
or
\
input_node_name
in
ordered_nodes_map
:
continue
input_node
=
nodes_map
[
input_node_name
]
sort_tf_node
(
input_node
,
nodes_map
,
ordered_nodes_map
)
ordered_nodes_map
[
node
.
name
]
=
node
def
sort_tf_graph
(
graph_def
):
nodes_map
=
{}
ordered_nodes_map
=
OrderedDict
()
for
node
in
graph_def
.
node
:
nodes_map
[
node
.
name
]
=
node
for
node
in
graph_def
.
node
:
sort_tf_node
(
node
,
nodes_map
,
ordered_nodes_map
)
sorted_graph
=
tf
.
GraphDef
()
sorted_graph
.
node
.
extend
([
node
for
node
in
ordered_nodes_map
.
values
()])
return
sorted_graph
def
sort_mace_node
(
node
,
nodes_map
,
ordered_nodes_map
):
if
node
.
name
not
in
ordered_nodes_map
:
for
input_tensor_name
in
node
.
input
:
input_node_name
=
input_tensor_name
.
split
(
':'
)[
0
]
if
':'
in
input_tensor_name
else
input_tensor_name
if
input_node_name
not
in
nodes_map
or
\
input_node_name
in
ordered_nodes_map
:
continue
input_node
=
nodes_map
[
input_node_name
]
sort_mace_node
(
input_node
,
nodes_map
,
ordered_nodes_map
)
ordered_nodes_map
[
node
.
name
]
=
node
def
sort_mace_graph
(
graph_def
,
output_name
):
nodes_map
=
{}
ordered_nodes_map
=
OrderedDict
()
for
node
in
graph_def
.
op
:
nodes_map
[
node
.
name
]
=
node
sort_mace_node
(
nodes_map
[
output_name
],
nodes_map
,
ordered_nodes_map
)
del
graph_def
.
op
[:]
graph_def
.
op
.
extend
([
node
for
node
in
ordered_nodes_map
.
values
()])
return
graph_def
mace/python/tools/mace_engine_factory_codegen.py
浏览文件 @
320b509c
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
os
from
jinja2
import
Environment
,
FileSystemLoader
from
jinja2
import
Environment
,
FileSystemLoader
...
@@ -20,11 +21,10 @@ from jinja2 import Environment, FileSystemLoader
...
@@ -20,11 +21,10 @@ from jinja2 import Environment, FileSystemLoader
FLAGS
=
None
FLAGS
=
None
def
gen_mace_engine_factory
(
model_tags
,
template_dir
,
def
gen_mace_engine_factory
(
model_tags
,
embed_model_data
,
output_dir
):
embed_model_data
,
output_dir
):
cwd
=
os
.
path
.
dirname
(
__file__
)
# Create the jinja2 environment.
j2_env
=
Environment
(
j2_env
=
Environment
(
loader
=
FileSystemLoader
(
template_dir
),
trim_blocks
=
True
)
loader
=
FileSystemLoader
(
cwd
),
trim_blocks
=
True
)
# generate mace_run BUILD file
# generate mace_run BUILD file
template_name
=
'mace_engine_factory.h.jinja2'
template_name
=
'mace_engine_factory.h.jinja2'
model_tags
=
list
(
model_tags
)
model_tags
=
list
(
model_tags
)
...
...
mace/python/tools/opencl_binary_codegen.py
浏览文件 @
320b509c
...
@@ -29,8 +29,9 @@ def generate_opencl_code(binary_file_name, load_func_name, size_func_name,
...
@@ -29,8 +29,9 @@ def generate_opencl_code(binary_file_name, load_func_name, size_func_name,
with
open
(
binary_file_name
,
'rb'
)
as
f
:
with
open
(
binary_file_name
,
'rb'
)
as
f
:
binary_array
=
np
.
fromfile
(
f
,
dtype
=
np
.
uint8
)
binary_array
=
np
.
fromfile
(
f
,
dtype
=
np
.
uint8
)
cwd
=
os
.
path
.
dirname
(
__file__
)
env
=
jinja2
.
Environment
(
env
=
jinja2
.
Environment
(
loader
=
jinja2
.
FileSystemLoader
(
sys
.
path
[
0
]
))
loader
=
jinja2
.
FileSystemLoader
(
cwd
))
content
=
env
.
get_template
(
'file_binary.cc.jinja2'
).
render
(
content
=
env
.
get_template
(
'file_binary.cc.jinja2'
).
render
(
data
=
binary_array
,
data
=
binary_array
,
data_size
=
len
(
binary_array
),
data_size
=
len
(
binary_array
),
...
...
mace/python/tools/quantization/quantize_util_test.py
已删除
100644 → 0
浏览文件 @
64419daa
import
unittest
import
numpy
as
np
import
quantize_util
class
TestQuantize
(
unittest
.
TestCase
):
def
test_quantize_dequantize
(
self
):
test_input
=
np
.
random
.
rand
(
20
,
30
)
*
5
quantized_data
=
quantize_util
.
quantize
(
test_input
)
dequantized_output
=
quantize_util
.
dequantize
(
quantized_data
)
np
.
testing
.
assert_array_almost_equal
(
test_input
,
dequantized_output
,
2
)
if
__name__
==
'__main__'
:
unittest
.
main
()
mace/python/tools/tf_ops_stats.py
已删除
100644 → 0
浏览文件 @
64419daa
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
operator
import
functools
import
argparse
import
sys
import
copy
import
six
import
tensorflow
as
tf
from
tensorflow
import
gfile
from
tensorflow.core.framework
import
graph_pb2
from
tensorflow.core.framework
import
tensor_shape_pb2
# ./bazel-bin/mace/python/tools/tf_ops_stats --input model.pb
FLAGS
=
None
def
hist_inc
(
hist
,
key
):
if
key
in
hist
:
hist
[
key
]
+=
1
else
:
hist
[
key
]
=
1
def
to_int_list
(
long_list
):
int_list
=
[]
for
value
in
long_list
:
int_list
.
append
(
int
(
value
))
return
int_list
def
add_shape_info
(
input_graph_def
,
input_nodes
,
input_shapes
):
inputs_replaced_graph
=
graph_pb2
.
GraphDef
()
for
node
in
input_graph_def
.
node
:
if
node
.
name
in
input_nodes
or
node
.
name
+
':0'
in
input_nodes
:
if
node
.
name
in
input_nodes
:
idx
=
input_nodes
.
index
(
node
.
name
)
else
:
idx
=
input_nodes
.
index
(
node
.
name
+
':0'
)
input_shape
=
input_shapes
[
idx
]
print
(
input_shape
)
placeholder_node
=
copy
.
deepcopy
(
node
)
placeholder_node
.
attr
.
clear
()
placeholder_node
.
attr
[
'shape'
].
shape
.
dim
.
extend
([
tensor_shape_pb2
.
TensorShapeProto
.
Dim
(
size
=
i
)
for
i
in
input_shape
])
placeholder_node
.
attr
[
'dtype'
].
CopyFrom
(
node
.
attr
[
'dtype'
])
inputs_replaced_graph
.
node
.
extend
([
placeholder_node
])
else
:
inputs_replaced_graph
.
node
.
extend
([
copy
.
deepcopy
(
node
)])
return
inputs_replaced_graph
def
main
(
unused_args
):
if
not
FLAGS
.
input
or
not
gfile
.
Exists
(
FLAGS
.
input
):
print
(
'Input graph file '
+
FLAGS
.
input
+
' does not exist!'
)
return
-
1
input_graph_def
=
tf
.
GraphDef
()
with
gfile
.
Open
(
FLAGS
.
input
,
'rb'
)
as
f
:
data
=
f
.
read
()
input_graph_def
.
ParseFromString
(
data
)
input_nodes
=
[
x
for
x
in
FLAGS
.
input_tensors
.
split
(
','
)]
input_shapes
=
[]
if
FLAGS
.
input_shapes
!=
""
:
input_shape_strs
=
[
x
for
x
in
FLAGS
.
input_shapes
.
split
(
':'
)]
for
shape_str
in
input_shape_strs
:
input_shapes
.
extend
([[
int
(
x
)
for
x
in
shape_str
.
split
(
','
)]])
input_graph_def
=
add_shape_info
(
input_graph_def
,
input_nodes
,
input_shapes
)
with
tf
.
Session
()
as
session
:
with
session
.
graph
.
as_default
()
as
graph
:
tf
.
import_graph_def
(
input_graph_def
,
name
=
''
)
stats
=
{}
ops
=
graph
.
get_operations
()
# extract kernel size for conv_2d
tensor_shapes
=
{}
tensor_values
=
{}
print
(
"=========================consts============================"
)
for
op
in
ops
:
if
op
.
type
==
'Const'
:
for
output
in
op
.
outputs
:
tensor_name
=
output
.
name
tensor
=
output
.
eval
()
tensor_shape
=
list
(
tensor
.
shape
)
tensor_shapes
[
tensor_name
]
=
tensor_shape
print
(
"Const %s: %s, %d"
%
(
tensor_name
,
tensor_shape
,
functools
.
reduce
(
operator
.
mul
,
tensor_shape
,
1
)))
if
len
(
tensor_shape
)
==
1
and
tensor_shape
[
0
]
<
10
:
tensor_values
[
tensor_name
]
=
list
(
tensor
)
print
(
"=========================ops============================"
)
for
op
in
ops
:
if
op
.
type
in
[
'Conv2D'
]:
padding
=
op
.
get_attr
(
'padding'
)
strides
=
to_int_list
(
op
.
get_attr
(
'strides'
))
data_format
=
op
.
get_attr
(
'data_format'
)
ksize
=
'Unknown'
input
=
op
.
inputs
[
1
]
input_name
=
input
.
name
if
input_name
.
endswith
(
'read:0'
):
ksize
=
input
.
shape
.
as_list
()
elif
input_name
in
tensor_shapes
:
ksize
=
tensor_shapes
[
input_name
]
print
(
'%s(padding=%s, strides=%s, ksize=%s, format=%s) %s => %s'
%
(
op
.
type
,
padding
,
strides
,
ksize
,
data_format
,
op
.
inputs
[
0
].
shape
,
op
.
outputs
[
0
].
shape
))
key
=
'%s(padding=%s, strides=%s, ksize=%s, format=%s)'
%
(
op
.
type
,
padding
,
strides
,
ksize
,
data_format
)
hist_inc
(
stats
,
key
)
elif
op
.
type
in
[
'FusedResizeAndPadConv2D'
]:
padding
=
op
.
get_attr
(
'padding'
)
strides
=
to_int_list
(
op
.
get_attr
(
'strides'
))
resize_align_corners
=
op
.
get_attr
(
'resize_align_corners'
)
ksize
=
'Unknown'
for
input
in
op
.
inputs
:
input_name
=
input
.
name
if
input_name
.
endswith
(
'weights:0'
)
and
input_name
in
tensor_shapes
:
ksize
=
tensor_shapes
[
input_name
]
break
key
=
'%s(padding=%s, strides=%s, ksize=%s, '
\
'resize_align_corners=%s)'
%
(
op
.
type
,
padding
,
strides
,
ksize
,
resize_align_corners
)
hist_inc
(
stats
,
key
)
elif
op
.
type
in
[
'ResizeBilinear'
]:
align_corners
=
op
.
get_attr
(
'align_corners'
)
size
=
'Unknown'
for
input
in
op
.
inputs
:
input_name
=
input
.
name
if
input_name
.
endswith
(
'size:0'
)
and
input_name
in
tensor_values
:
size
=
tensor_values
[
input_name
]
break
key
=
'%s(size=%s, align_corners=%s)'
%
(
op
.
type
,
size
,
align_corners
)
print
(
key
)
hist_inc
(
stats
,
key
)
elif
op
.
type
in
[
'AvgPool'
,
'MaxPool'
]:
padding
=
op
.
get_attr
(
'padding'
)
strides
=
to_int_list
(
op
.
get_attr
(
'strides'
))
ksize
=
to_int_list
(
op
.
get_attr
(
'ksize'
))
data_format
=
op
.
get_attr
(
'data_format'
)
key
=
'%s(padding=%s, strides=%s, ksize=%s)'
%
(
op
.
type
,
padding
,
strides
,
ksize
)
hist_inc
(
stats
,
key
)
elif
op
.
type
in
[
'SpaceToBatchND'
,
'BatchToSpaceND'
]:
block_shape
=
'Unknown'
for
input
in
op
.
inputs
:
input_name
=
input
.
name
if
input_name
.
endswith
(
'block_shape:0'
)
and
input_name
in
tensor_values
:
block_shape
=
tensor_values
[
input_name
]
break
paddings
=
'Unknown'
for
input
in
op
.
inputs
:
input_name
=
input
.
name
if
input_name
.
endswith
(
'paddings:0'
)
and
input_name
in
tensor_values
:
paddings
=
tensor_values
[
input_name
]
break
crops
=
'Unknown'
for
input
in
op
.
inputs
:
input_name
=
input
.
name
if
input_name
.
endswith
(
'crops:0'
)
and
input_name
in
tensor_values
:
paddings
=
tensor_values
[
input_name
]
break
if
op
.
type
==
'SpaceToBatchND'
:
key
=
'%s(block_shape=%s, paddings=%s)'
%
(
op
.
type
,
block_shape
,
paddings
)
else
:
key
=
'%s(block_shape=%s, crops=%s)'
%
(
op
.
type
,
block_shape
,
crops
)
print
(
key
)
hist_inc
(
stats
,
key
)
elif
op
.
type
==
'Pad'
:
paddings
=
'Unknown'
for
input
in
op
.
inputs
:
input_name
=
input
.
name
if
input_name
.
endswith
(
'paddings:0'
)
and
input_name
in
tensor_values
:
paddings
=
tensor_values
[
input_name
]
break
key
=
'%s(paddings=%s)'
%
(
op
.
type
,
paddings
)
hist_inc
(
stats
,
key
)
else
:
hist_inc
(
stats
,
op
.
type
)
print
(
"=========================stats============================"
)
for
key
,
value
in
sorted
(
six
.
iteritems
(
stats
)):
print
(
'%s: %d'
%
(
key
,
value
))
def
parse_args
():
"""Parses command line arguments."""
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
''
,
help
=
'TensorFlow
\'
GraphDef
\'
file to load.'
)
parser
.
add_argument
(
'--input_tensors'
,
type
=
str
,
default
=
''
,
help
=
'input tensor names split by comma.'
)
parser
.
add_argument
(
'--input_shapes'
,
type
=
str
,
default
=
''
,
help
=
'input tensor shapes split by colon and comma.'
)
return
parser
.
parse_known_args
()
if
__name__
==
'__main__'
:
FLAGS
,
unparsed
=
parse_args
()
main
(
unused_args
=
[
sys
.
argv
[
0
]]
+
unparsed
)
mace/python/tools/visualization/BUILD.bazel
已删除
100644 → 0
浏览文件 @
64419daa
py_library
(
name
=
"visualization_lib"
,
srcs
=
[
"visualize_model.py"
,
],
data
=
[
"index.html"
,
],
srcs_version
=
"PY2AND3"
,
visibility
=
[
"//visibility:public"
],
deps
=
[
"//mace/proto:mace_py"
,
],
)
tools/converter.py
浏览文件 @
320b509c
...
@@ -12,20 +12,25 @@
...
@@ -12,20 +12,25 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
argparse
import
glob
import
glob
import
sh
import
sh
import
sys
import
sys
import
time
import
time
import
yaml
import
yaml
from
enum
import
Enum
import
six
import
sh_commands
import
sh_commands
from
enum
import
Enum
sys
.
path
.
insert
(
0
,
"tools/python"
)
# noqa
from
common
import
*
from
common
import
*
from
device
import
DeviceWrapper
,
DeviceManager
from
device
import
DeviceWrapper
,
DeviceManager
from
utils
import
config_parser
import
convert
import
encrypt
################################
################################
# set environment
# set environment
...
@@ -711,111 +716,6 @@ def print_configuration(configs):
...
@@ -711,111 +716,6 @@ def print_configuration(configs):
MaceLogger
.
summary
(
StringFormatter
.
table
(
header
,
data
,
title
))
MaceLogger
.
summary
(
StringFormatter
.
table
(
header
,
data
,
title
))
def
convert_model
(
configs
,
cl_mem_type
):
# Remove previous output dirs
library_name
=
configs
[
YAMLKeyword
.
library_name
]
if
not
os
.
path
.
exists
(
BUILD_OUTPUT_DIR
):
os
.
makedirs
(
BUILD_OUTPUT_DIR
)
elif
os
.
path
.
exists
(
os
.
path
.
join
(
BUILD_OUTPUT_DIR
,
library_name
)):
sh
.
rm
(
"-rf"
,
os
.
path
.
join
(
BUILD_OUTPUT_DIR
,
library_name
))
os
.
makedirs
(
os
.
path
.
join
(
BUILD_OUTPUT_DIR
,
library_name
))
if
not
os
.
path
.
exists
(
BUILD_DOWNLOADS_DIR
):
os
.
makedirs
(
BUILD_DOWNLOADS_DIR
)
model_output_dir
=
\
'%s/%s/%s'
%
(
BUILD_OUTPUT_DIR
,
library_name
,
MODEL_OUTPUT_DIR_NAME
)
model_header_dir
=
\
'%s/%s/%s'
%
(
BUILD_OUTPUT_DIR
,
library_name
,
MODEL_HEADER_DIR_PATH
)
# clear output dir
if
os
.
path
.
exists
(
model_output_dir
):
sh
.
rm
(
"-rf"
,
model_output_dir
)
os
.
makedirs
(
model_output_dir
)
if
os
.
path
.
exists
(
model_header_dir
):
sh
.
rm
(
"-rf"
,
model_header_dir
)
embed_model_data
=
\
configs
[
YAMLKeyword
.
model_data_format
]
==
ModelFormat
.
code
if
os
.
path
.
exists
(
MODEL_CODEGEN_DIR
):
sh
.
rm
(
"-rf"
,
MODEL_CODEGEN_DIR
)
if
os
.
path
.
exists
(
ENGINE_CODEGEN_DIR
):
sh
.
rm
(
"-rf"
,
ENGINE_CODEGEN_DIR
)
if
configs
[
YAMLKeyword
.
model_graph_format
]
==
ModelFormat
.
code
:
os
.
makedirs
(
model_header_dir
)
sh_commands
.
gen_mace_engine_factory_source
(
configs
[
YAMLKeyword
.
models
].
keys
(),
embed_model_data
)
sh
.
cp
(
"-f"
,
glob
.
glob
(
"mace/codegen/engine/*.h"
),
model_header_dir
)
for
model_name
in
configs
[
YAMLKeyword
.
models
]:
MaceLogger
.
header
(
StringFormatter
.
block
(
"Convert %s model"
%
model_name
))
model_config
=
configs
[
YAMLKeyword
.
models
][
model_name
]
runtime
=
model_config
[
YAMLKeyword
.
runtime
]
if
cl_mem_type
:
model_config
[
YAMLKeyword
.
cl_mem_type
]
=
cl_mem_type
else
:
model_config
[
YAMLKeyword
.
cl_mem_type
]
=
"image"
data_type
=
model_config
[
YAMLKeyword
.
data_type
]
# TODO(liuqi): support multiple subgraphs
subgraphs
=
model_config
[
YAMLKeyword
.
subgraphs
]
model_codegen_dir
=
"%s/%s"
%
(
MODEL_CODEGEN_DIR
,
model_name
)
sh_commands
.
gen_model_code
(
model_codegen_dir
,
model_config
[
YAMLKeyword
.
platform
],
model_config
[
YAMLKeyword
.
model_file_path
],
model_config
[
YAMLKeyword
.
weight_file_path
],
model_config
[
YAMLKeyword
.
model_sha256_checksum
],
model_config
[
YAMLKeyword
.
weight_sha256_checksum
],
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
input_tensors
]),
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
input_data_types
]),
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
input_data_formats
]),
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
output_tensors
]),
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
output_data_types
]),
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
output_data_formats
]),
","
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
check_tensors
]),
runtime
,
model_name
,
":"
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
input_shapes
]),
":"
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
]),
":"
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
output_shapes
]),
":"
.
join
(
subgraphs
[
0
][
YAMLKeyword
.
check_shapes
]),
model_config
[
YAMLKeyword
.
nnlib_graph_mode
],
embed_model_data
,
model_config
[
YAMLKeyword
.
winograd
],
model_config
[
YAMLKeyword
.
quantize
],
model_config
[
YAMLKeyword
.
quantize_large_weights
],
model_config
[
YAMLKeyword
.
quantize_range_file
],
model_config
[
YAMLKeyword
.
change_concat_ranges
],
model_config
[
YAMLKeyword
.
obfuscate
],
configs
[
YAMLKeyword
.
model_graph_format
],
data_type
,
model_config
[
YAMLKeyword
.
cl_mem_type
],
","
.
join
(
model_config
.
get
(
YAMLKeyword
.
graph_optimize_options
,
[])))
if
configs
[
YAMLKeyword
.
model_graph_format
]
==
ModelFormat
.
file
:
sh
.
mv
(
"-f"
,
'%s/%s.pb'
%
(
model_codegen_dir
,
model_name
),
model_output_dir
)
sh
.
mv
(
"-f"
,
'%s/%s.data'
%
(
model_codegen_dir
,
model_name
),
model_output_dir
)
else
:
if
not
embed_model_data
:
sh
.
mv
(
"-f"
,
'%s/%s.data'
%
(
model_codegen_dir
,
model_name
),
model_output_dir
)
sh
.
cp
(
"-f"
,
glob
.
glob
(
"mace/codegen/models/*/*.h"
),
model_header_dir
)
MaceLogger
.
summary
(
StringFormatter
.
block
(
"Model %s converted"
%
model_name
))
def
build_model_lib
(
configs
,
address_sanitizer
,
debug_mode
):
def
build_model_lib
(
configs
,
address_sanitizer
,
debug_mode
):
MaceLogger
.
header
(
StringFormatter
.
block
(
"Building model library"
))
MaceLogger
.
header
(
StringFormatter
.
block
(
"Building model library"
))
...
@@ -863,13 +763,85 @@ def print_library_summary(configs):
...
@@ -863,13 +763,85 @@ def print_library_summary(configs):
def
convert_func
(
flags
):
def
convert_func
(
flags
):
configs
=
format_model_config
(
flags
)
configs
=
config_parser
.
parse
(
flags
.
config
)
library_name
=
configs
[
YAMLKeyword
.
library_name
]
if
not
os
.
path
.
exists
(
BUILD_OUTPUT_DIR
):
os
.
makedirs
(
BUILD_OUTPUT_DIR
)
elif
os
.
path
.
exists
(
os
.
path
.
join
(
BUILD_OUTPUT_DIR
,
library_name
)):
sh
.
rm
(
"-rf"
,
os
.
path
.
join
(
BUILD_OUTPUT_DIR
,
library_name
))
os
.
makedirs
(
os
.
path
.
join
(
BUILD_OUTPUT_DIR
,
library_name
))
if
not
os
.
path
.
exists
(
BUILD_DOWNLOADS_DIR
):
os
.
makedirs
(
BUILD_DOWNLOADS_DIR
)
print_configuration
(
configs
)
model_output_dir
=
\
'%s/%s/%s'
%
(
BUILD_OUTPUT_DIR
,
library_name
,
MODEL_OUTPUT_DIR_NAME
)
model_header_dir
=
\
'%s/%s/%s'
%
(
BUILD_OUTPUT_DIR
,
library_name
,
MODEL_HEADER_DIR_PATH
)
# clear output dir
if
os
.
path
.
exists
(
model_output_dir
):
sh
.
rm
(
"-rf"
,
model_output_dir
)
os
.
makedirs
(
model_output_dir
)
if
os
.
path
.
exists
(
model_header_dir
):
sh
.
rm
(
"-rf"
,
model_header_dir
)
convert_model
(
configs
,
flags
.
cl_mem_type
)
if
os
.
path
.
exists
(
MODEL_CODEGEN_DIR
):
sh
.
rm
(
"-rf"
,
MODEL_CODEGEN_DIR
)
if
os
.
path
.
exists
(
ENGINE_CODEGEN_DIR
):
sh
.
rm
(
"-rf"
,
ENGINE_CODEGEN_DIR
)
if
configs
[
YAMLKeyword
.
model_graph_format
]
==
ModelFormat
.
code
:
if
flags
.
model_data_format
:
model_data_format
=
flags
.
model_data_format
else
:
model_data_format
=
configs
.
get
(
YAMLKeyword
.
model_data_format
,
"file"
)
embed_model_data
=
model_data_format
==
ModelFormat
.
code
if
flags
.
model_graph_format
:
model_graph_format
=
flags
.
model_graph_format
else
:
model_graph_format
=
configs
.
get
(
YAMLKeyword
.
model_graph_format
,
"file"
)
if
model_graph_format
==
ModelFormat
.
code
:
os
.
makedirs
(
model_header_dir
)
sh_commands
.
gen_mace_engine_factory_source
(
configs
[
YAMLKeyword
.
models
].
keys
(),
embed_model_data
)
sh
.
cp
(
"-f"
,
glob
.
glob
(
"mace/codegen/engine/*.h"
),
model_header_dir
)
convert
.
convert
(
configs
,
MODEL_CODEGEN_DIR
)
for
model_name
,
model_config
in
configs
[
YAMLKeyword
.
models
].
items
():
model_codegen_dir
=
"%s/%s"
%
(
MODEL_CODEGEN_DIR
,
model_name
)
encrypt
.
encrypt
(
model_name
,
"%s/%s.pb"
%
(
model_codegen_dir
,
model_name
),
"%s/%s.data"
%
(
model_codegen_dir
,
model_name
),
model_config
[
YAMLKeyword
.
runtime
],
model_codegen_dir
,
bool
(
model_config
[
YAMLKeyword
.
obfuscate
]))
if
model_graph_format
==
ModelFormat
.
file
:
sh
.
mv
(
"-f"
,
'%s/file/%s.pb'
%
(
model_codegen_dir
,
model_name
),
model_output_dir
)
sh
.
mv
(
"-f"
,
'%s/file/%s.data'
%
(
model_codegen_dir
,
model_name
),
model_output_dir
)
sh
.
rm
(
"-rf"
,
'%s/code'
%
model_codegen_dir
)
else
:
if
not
embed_model_data
:
sh
.
mv
(
"-f"
,
'%s/file/%s.data'
%
(
model_codegen_dir
,
model_name
),
model_output_dir
)
sh
.
rm
(
'%s/code/tensor_data.cc'
%
model_codegen_dir
)
sh
.
cp
(
"-f"
,
glob
.
glob
(
"mace/codegen/models/*/code/*.h"
),
model_header_dir
)
MaceLogger
.
summary
(
StringFormatter
.
block
(
"Model %s converted"
%
model_name
))
if
model_graph_format
==
ModelFormat
.
code
:
build_model_lib
(
configs
,
flags
.
address_sanitizer
,
flags
.
debug_mode
)
build_model_lib
(
configs
,
flags
.
address_sanitizer
,
flags
.
debug_mode
)
print_library_summary
(
configs
)
print_library_summary
(
configs
)
...
@@ -1047,11 +1019,6 @@ def parse_args():
...
@@ -1047,11 +1019,6 @@ def parse_args():
'convert'
,
'convert'
,
parents
=
[
all_type_parent_parser
,
convert_run_parent_parser
],
parents
=
[
all_type_parent_parser
,
convert_run_parent_parser
],
help
=
'convert to mace model (file or code)'
)
help
=
'convert to mace model (file or code)'
)
convert
.
add_argument
(
"--cl_mem_type"
,
type
=
str
,
default
=
None
,
help
=
"Which type of OpenCL memory type to use [image | buffer]."
)
convert
.
set_defaults
(
func
=
convert_func
)
convert
.
set_defaults
(
func
=
convert_func
)
run
=
subparsers
.
add_parser
(
run
=
subparsers
.
add_parser
(
...
...
mace/python/tools/converter_tool
/__init__.py
→
tools/python
/__init__.py
浏览文件 @
320b509c
文件已移动
tools/
experimental
/config/model/model-template.yaml
→
tools/
python
/config/model/model-template.yaml
浏览文件 @
320b509c
文件已移动
tools/python/convert.py
0 → 100644
浏览文件 @
320b509c
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
os
import
sys
import
numpy
as
np
from
utils
import
config_parser
from
utils
import
util
from
utils.util
import
mace_check
from
py_proto
import
mace_pb2
from
transform
import
base_converter
as
cvt
from
transform
import
transformer
from
visualize
import
visualize_model
device_type_map
=
{
'cpu'
:
cvt
.
DeviceType
.
CPU
.
value
,
'gpu'
:
cvt
.
DeviceType
.
GPU
.
value
,
'dsp'
:
cvt
.
DeviceType
.
HEXAGON
.
value
,
'hta'
:
cvt
.
DeviceType
.
HTA
.
value
,
'apu'
:
cvt
.
DeviceType
.
APU
.
value
,
'cpu+gpu'
:
cvt
.
DeviceType
.
CPU
.
value
}
data_format_map
=
{
'NONE'
:
cvt
.
DataFormat
.
NONE
,
'NHWC'
:
cvt
.
DataFormat
.
NHWC
,
'NCHW'
:
cvt
.
DataFormat
.
NCHW
,
'OIHW'
:
cvt
.
DataFormat
.
OIHW
,
}
data_type_map
=
{
'float32'
:
mace_pb2
.
DT_FLOAT
,
'int32'
:
mace_pb2
.
DT_INT32
,
}
def
parse_data_type
(
data_type
,
quantize
):
if
quantize
or
data_type
==
'fp32_fp32'
:
return
mace_pb2
.
DT_FLOAT
else
:
return
mace_pb2
.
DT_HALF
def
split_shape
(
shape
):
if
shape
.
strip
()
==
""
:
return
[]
else
:
return
shape
.
split
(
','
)
def
parse_int_array_from_str
(
ints_str
):
return
[
int
(
i
)
for
i
in
split_shape
(
ints_str
)]
def
parse_float_array_from_str
(
floats_str
):
return
[
float
(
i
)
for
i
in
floats_str
.
split
(
','
)]
def
transpose_shape
(
shape
,
dst_order
):
t_shape
=
[
0
]
*
len
(
shape
)
for
i
in
range
(
len
(
shape
)):
t_shape
[
i
]
=
shape
[
dst_order
[
i
]]
return
t_shape
def
to_list
(
x
):
if
isinstance
(
x
,
list
):
return
x
else
:
return
[
x
]
def
separate_params
(
mace_model
):
tensors
=
mace_model
.
tensors
params
=
mace_pb2
.
NetDef
()
params
.
tensors
.
extend
(
tensors
)
model
=
mace_model
del
model
.
tensors
[:]
return
model
,
params
def
convert
(
conf
,
output
):
if
not
os
.
path
.
exists
(
output
):
os
.
mkdir
(
output
)
for
model_name
,
model_conf
in
conf
[
"models"
].
items
():
model_output
=
output
+
"/"
+
model_name
if
not
os
.
path
.
exists
(
model_output
):
os
.
mkdir
(
model_output
)
subgraph
=
model_conf
[
"subgraphs"
][
0
]
del
model_conf
[
"subgraphs"
]
model_conf
.
update
(
subgraph
)
model_file
=
util
.
download_or_get_file
(
model_conf
[
"model_file_path"
],
model_conf
[
"model_sha256_checksum"
],
model_output
)
model_conf
[
"model_file_path"
]
=
model_file
if
"weight_file_path"
in
model_conf
:
weight_file
=
util
.
download_or_get_file
(
model_conf
[
"weight_file_path"
],
model_conf
[
"weight_sha256_checksum"
],
model_output
)
model_conf
[
"weight_file_path"
]
=
weight_file
mace_model
=
convert_model
(
model_conf
)
try
:
visualizer
=
visualize_model
.
ModelVisualizer
(
model_name
,
mace_model
,
model_output
)
visualizer
.
save_html
()
except
:
# noqa
print
(
"Failed to visualize model:"
,
sys
.
exc_info
()[
0
])
model
,
params
=
merge_params
(
mace_model
)
output_model_file
=
model_output
+
"/"
+
model_name
+
".pb"
output_params_file
=
model_output
+
"/"
+
model_name
+
".data"
with
open
(
output_model_file
,
"wb"
)
as
f
:
f
.
write
(
model
.
SerializeToString
())
with
open
(
output_params_file
,
"wb"
)
as
f
:
f
.
write
(
bytearray
(
params
))
with
open
(
output_model_file
+
"_txt"
,
"w"
)
as
f
:
f
.
write
(
str
(
model
))
def
convert_model
(
conf
):
print
(
conf
)
platform
=
conf
[
"platform"
]
mace_check
(
platform
in
[
'tensorflow'
,
'caffe'
,
'onnx'
],
"platform not supported"
)
runtime
=
conf
[
"runtime"
]
mace_check
(
runtime
in
[
'cpu'
,
'gpu'
,
'dsp'
,
'hta'
,
'apu'
,
'cpu+gpu'
],
"runtime not supported"
)
option
=
cvt
.
ConverterOption
()
if
"graph_optimize_options"
in
conf
:
option
.
transformer_option
=
conf
[
"graph_optimize_options"
].
split
(
','
)
option
.
winograd
=
conf
.
get
(
"winograd"
,
0
)
option
.
quantize
=
bool
(
conf
.
get
(
"quantize"
,
0
))
option
.
quantize_large_weights
=
bool
(
conf
.
get
(
"quantize_large_weights"
,
0
))
option
.
quantize_range_file
=
conf
.
get
(
"quantize_range_file"
,
""
)
option
.
change_concat_ranges
=
bool
(
conf
.
get
(
"change_concat_ranges"
,
0
))
option
.
cl_mem_type
=
conf
.
get
(
"cl_mem_type"
,
"image"
)
option
.
device
=
device_type_map
[
conf
.
get
(
"runtime"
,
"cpu"
)]
option
.
data_type
=
parse_data_type
(
conf
.
get
(
"data_type"
,
"fp16_fp32"
),
option
.
quantize
)
input_tensors
=
to_list
(
conf
[
"input_tensors"
])
input_shapes
=
[
parse_int_array_from_str
(
shape
)
for
shape
in
to_list
(
conf
[
"input_shapes"
])]
mace_check
(
len
(
input_tensors
)
==
len
(
input_shapes
),
"input node count and shape count do not match"
)
input_count
=
len
(
input_tensors
)
input_data_types
=
[
data_type_map
[
dt
]
for
dt
in
to_list
(
conf
.
get
(
"input_data_types"
,
[
"float32"
]
*
input_count
))]
input_data_formats
=
[
data_format_map
[
df
]
for
df
in
to_list
(
conf
.
get
(
"input_data_formats"
,
[
"NHWC"
]
*
input_count
))]
input_ranges
=
[
parse_float_array_from_str
(
r
)
for
r
in
to_list
(
conf
.
get
(
"input_ranges"
,
[
"-1.0,1.0"
]
*
input_count
))]
for
i
in
range
(
len
(
input_tensors
)):
input_node
=
cvt
.
NodeInfo
()
input_node
.
name
=
input_tensors
[
i
]
input_node
.
shape
=
input_shapes
[
i
]
input_node
.
data_type
=
input_data_types
[
i
]
input_node
.
data_format
=
input_data_formats
[
i
]
if
(
input_node
.
data_format
==
cvt
.
DataFormat
.
NCHW
and
len
(
input_node
.
shape
)
==
4
):
input_node
.
shape
=
transpose_shape
(
input_node
.
shape
,
[
0
,
2
,
3
,
1
])
input_node
.
data_format
=
cvt
.
DataFormat
.
NHWC
input_node
.
range
=
input_ranges
[
i
]
option
.
add_input_node
(
input_node
)
output_tensors
=
to_list
(
conf
[
"output_tensors"
])
output_shapes
=
[
parse_int_array_from_str
(
shape
)
for
shape
in
to_list
(
conf
[
"output_shapes"
])]
mace_check
(
len
(
output_tensors
)
==
len
(
output_shapes
),
"output node count and shape count do not match"
)
output_count
=
len
(
output_tensors
)
output_data_types
=
[
data_type_map
[
dt
]
for
dt
in
to_list
(
conf
.
get
(
"output_data_types"
,
[
"float32"
]
*
output_count
))]
output_data_formats
=
[
data_format_map
[
df
]
for
df
in
to_list
(
conf
.
get
(
"output_data_formats"
,
[
"NHWC"
]
*
output_count
))]
for
i
in
range
(
len
(
output_tensors
)):
output_node
=
cvt
.
NodeInfo
()
output_node
.
name
=
output_tensors
[
i
]
output_node
.
shape
=
output_shapes
[
i
]
output_node
.
data_type
=
output_data_types
[
i
]
output_node
.
data_format
=
output_data_formats
[
i
]
if
output_node
.
data_format
==
cvt
.
DataFormat
.
NCHW
and
len
(
output_node
.
shape
)
==
4
:
output_node
.
shape
=
transpose_shape
(
output_node
.
shape
,
[
0
,
2
,
3
,
1
])
output_node
.
data_format
=
cvt
.
DataFormat
.
NHWC
option
.
add_output_node
(
output_node
)
if
"check_node"
in
conf
:
check_node_names
=
to_list
(
conf
[
"check_node"
])
check_node_shapes
=
[
parse_int_array_from_str
(
shape
)
for
shape
in
to_list
(
conf
[
"check_shape"
])]
mace_check
(
len
(
check_node_names
)
==
len
(
check_node_shapes
),
"check node count and shape count do not match."
)
for
i
in
range
(
len
(
check_node_names
)):
check_node
=
cvt
.
NodeInfo
()
check_node
.
name
=
check_node_names
[
i
]
check_node
.
shape
=
check_node_shapes
[
i
]
option
.
add_check_node
(
check_node
)
else
:
option
.
check_nodes
=
option
.
output_nodes
option
.
build
()
print
(
"Transform model to one that can better run on device"
)
if
platform
==
'tensorflow'
:
from
transform
import
tensorflow_converter
converter
=
tensorflow_converter
.
TensorflowConverter
(
option
,
conf
[
"model_file_path"
])
elif
platform
==
'caffe'
:
from
transform
import
caffe_converter
converter
=
caffe_converter
.
CaffeConverter
(
option
,
conf
[
"model_file_path"
],
conf
[
"weight_file_path"
])
elif
platform
==
'onnx'
:
from
transform
import
onnx_converter
converter
=
onnx_converter
.
OnnxConverter
(
option
,
conf
[
"model_file_path"
])
else
:
mace_check
(
False
,
"Mace do not support platorm %s yet."
%
platform
)
output_graph_def
=
converter
.
run
()
mace_transformer
=
transformer
.
Transformer
(
option
,
output_graph_def
)
output_graph_def
,
quantize_activation_info
=
mace_transformer
.
run
()
if
option
.
device
in
[
cvt
.
DeviceType
.
HEXAGON
.
value
,
cvt
.
DeviceType
.
HTA
.
value
]:
from
transform
import
hexagon_converter
converter
=
hexagon_converter
.
HexagonConverter
(
option
,
output_graph_def
,
quantize_activation_info
)
output_graph_def
=
converter
.
run
()
elif
runtime
==
'apu'
:
mace_check
(
platform
==
"tensorflow"
,
"apu only support model from tensorflow"
)
from
transform
import
apu_converter
converter
=
apu_converter
.
ApuConverter
(
option
,
output_graph_def
,
quantize_activation_info
)
output_graph_def
=
converter
.
run
()
return
output_graph_def
def
merge_params
(
net_def
):
def
tensor_to_bytes
(
tensor
):
if
tensor
.
data_type
==
mace_pb2
.
DT_HALF
:
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float16
).
tobytes
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
:
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float32
).
tobytes
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
data
=
bytearray
(
np
.
array
(
tensor
.
int32_data
).
astype
(
np
.
int32
).
tobytes
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
data
=
bytearray
(
np
.
array
(
tensor
.
int32_data
).
astype
(
np
.
uint8
).
tolist
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float16
).
tobytes
())
else
:
raise
Exception
(
'Tensor data type %s not supported'
%
tensor
.
data_type
)
return
data
model_data
=
[]
offset
=
0
for
tensor
in
net_def
.
tensors
:
raw_data
=
tensor_to_bytes
(
tensor
)
if
tensor
.
data_type
!=
mace_pb2
.
DT_UINT8
and
offset
%
4
!=
0
:
padding
=
4
-
offset
%
4
model_data
.
extend
(
bytearray
([
0
]
*
padding
))
offset
+=
padding
tensor
.
offset
=
offset
model_data
.
extend
(
raw_data
)
offset
+=
len
(
raw_data
)
for
tensor
in
net_def
.
tensors
:
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
\
or
tensor
.
data_type
==
mace_pb2
.
DT_HALF
\
or
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
del
tensor
.
float_data
[:]
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
del
tensor
.
int32_data
[:]
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
del
tensor
.
int32_data
[:]
return
net_def
,
model_data
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--config'
,
type
=
str
,
default
=
""
,
required
=
True
,
help
=
"the path of model yaml configuration file."
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
"."
,
help
=
"output dir"
)
flgs
,
_
=
parser
.
parse_known_args
()
return
flgs
if
__name__
==
'__main__'
:
flags
=
parse_args
()
conf
=
config_parser
.
parse
(
flags
.
config
)
convert
(
conf
,
flags
.
output
)
mace/python/tools/model_saver
.py
→
tools/python/encrypt
.py
浏览文件 @
320b509c
# Copyright 201
8
The MACE Authors. All Rights Reserved.
# Copyright 201
9
The MACE Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,31 +12,28 @@
...
@@ -12,31 +12,28 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
datetime
import
datetime
import
os
import
os
import
six
import
uuid
import
numpy
as
np
import
hashlib
import
hashlib
from
enum
import
Enum
from
mace.proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
as
cvt
from
mace.python.tools.convert_util
import
mace_check
from
jinja2
import
Environment
,
FileSystemLoader
from
jinja2
import
Environment
,
FileSystemLoader
from
py_proto
import
mace_pb2
from
utils
import
util
from
transform
import
base_converter
as
cvt
from
utils.util
import
mace_check
from
utils.config_parser
import
CPP_KEYWORDS
GENERATED_NAME
=
set
()
GENERATED_NAME
=
set
()
class
ModelFormat
(
object
):
file
=
"file"
code
=
"code"
def
generate_obfuscated_name
(
namespace
,
name
):
def
generate_obfuscated_name
(
namespace
,
name
):
md5
=
hashlib
.
md5
()
md5
=
hashlib
.
md5
()
md5
.
update
(
six
.
b
(
namespace
)
)
md5
.
update
(
namespace
)
md5
.
update
(
six
.
b
(
name
)
)
md5
.
update
(
name
)
md5_digest
=
md5
.
hexdigest
()
md5_digest
=
md5
.
hexdigest
()
name
=
md5_digest
[:
8
]
name
=
md5_digest
[:
8
]
...
@@ -76,19 +73,23 @@ def generate_in_out_map(ops, tensor_map):
...
@@ -76,19 +73,23 @@ def generate_in_out_map(ops, tensor_map):
return
in_out_map
return
in_out_map
def
obfuscate_name
(
option
,
net_def
):
def
stringfy
(
value
):
return
', '
.
join
(
'"{0}"'
.
format
(
w
)
for
w
in
value
)
def
obfuscate_name
(
model
):
input_nodes
=
set
()
input_nodes
=
set
()
for
name
in
option
.
input_nodes
:
for
input_node
in
model
.
input_info
:
input_nodes
.
add
(
name
)
input_nodes
.
add
(
input_node
.
name
)
output_nodes
=
set
()
output_nodes
=
set
()
for
name
in
option
.
output_nodes
:
for
output_node
in
model
.
output_info
:
output_nodes
.
add
(
name
)
output_nodes
.
add
(
output_node
.
name
)
tensor_map
=
generate_tensor_map
(
net_def
.
tensors
)
tensor_map
=
generate_tensor_map
(
model
.
tensors
)
in_out_map
=
generate_in_out_map
(
net_def
.
op
,
tensor_map
)
in_out_map
=
generate_in_out_map
(
model
.
op
,
tensor_map
)
for
t
in
net_def
.
tensors
:
for
t
in
model
.
tensors
:
if
t
.
name
not
in
input_nodes
and
t
.
name
not
in
output_nodes
:
if
t
.
name
not
in
input_nodes
and
t
.
name
not
in
output_nodes
:
t
.
name
=
tensor_map
[
t
.
name
]
t
.
name
=
tensor_map
[
t
.
name
]
for
op
in
net_def
.
op
:
for
op
in
model
.
op
:
for
i
in
range
(
len
(
op
.
input
)):
for
i
in
range
(
len
(
op
.
input
)):
if
op
.
input
[
i
]
not
in
input_nodes
:
if
op
.
input
[
i
]
not
in
input_nodes
:
op
.
input
[
i
]
=
in_out_map
[
op
.
input
[
i
]]
op
.
input
[
i
]
=
in_out_map
[
op
.
input
[
i
]]
...
@@ -97,201 +98,140 @@ def obfuscate_name(option, net_def):
...
@@ -97,201 +98,140 @@ def obfuscate_name(option, net_def):
op
.
output
[
i
]
=
in_out_map
[
op
.
output
[
i
]]
op
.
output
[
i
]
=
in_out_map
[
op
.
output
[
i
]]
def
stringfy
(
value
):
def
save_model_to_code
(
namespace
,
model
,
params
,
model_checksum
,
return
', '
.
join
(
'"{0}"'
.
format
(
w
)
for
w
in
value
)
params_checksum
,
device
,
output
):
if
not
os
.
path
.
exists
(
output
):
os
.
mkdir
(
output
)
class
TensorInfo
:
cwd
=
os
.
path
.
dirname
(
__file__
)
def
__init__
(
self
,
id
,
tensor
):
self
.
id
=
id
self
.
data_type
=
tensor
.
data_type
if
tensor
.
data_type
==
mace_pb2
.
DT_HALF
:
self
.
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float16
).
tobytes
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
:
self
.
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float32
).
tobytes
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
self
.
data
=
bytearray
(
np
.
array
(
tensor
.
int32_data
).
astype
(
np
.
int32
).
tobytes
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
self
.
data
=
bytearray
(
np
.
array
(
tensor
.
int32_data
).
astype
(
np
.
uint8
).
tolist
())
elif
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
self
.
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float16
).
tobytes
())
else
:
raise
Exception
(
'Tensor data type %s not supported'
%
tensor
.
data_type
)
def
update_tensor_infos
(
net_def
,
data_type
):
offset
=
0
counter
=
0
tensor_infos
=
[]
for
tensor
in
net_def
.
tensors
:
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
:
tensor
.
data_type
=
data_type
# Add offset and data_size
tensor_info
=
TensorInfo
(
counter
,
tensor
)
tensor_infos
.
append
(
tensor_info
)
# align
if
tensor_info
.
data_type
!=
mace_pb2
.
DT_UINT8
and
offset
%
4
!=
0
:
padding
=
4
-
offset
%
4
offset
+=
padding
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
\
or
tensor
.
data_type
==
mace_pb2
.
DT_HALF
\
or
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
tensor
.
data_size
=
len
(
tensor
.
float_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
tensor
.
offset
=
offset
offset
+=
len
(
tensor_info
.
data
)
counter
+=
1
def
extract_model_data
(
net_def
):
model_data
=
[]
offset
=
0
counter
=
0
for
tensor
in
net_def
.
tensors
:
tensor_info
=
TensorInfo
(
counter
,
tensor
)
# align
mace_check
(
offset
<=
tensor
.
offset
,
"Current offset should be <= tensor.offset"
)
if
offset
<
tensor
.
offset
:
model_data
.
extend
(
bytearray
([
0
]
*
(
tensor
.
offset
-
offset
)))
offset
=
tensor
.
offset
model_data
.
extend
(
tensor_info
.
data
)
offset
+=
len
(
tensor_info
.
data
)
counter
+=
1
return
model_data
def
save_model_data
(
net_def
,
model_tag
,
output_dir
):
model_data
=
extract_model_data
(
net_def
)
# generate tensor data
with
open
(
output_dir
+
model_tag
+
'.data'
,
"wb"
)
as
f
:
f
.
write
(
bytearray
(
model_data
))
def
save_model_to_proto
(
net_def
,
model_tag
,
output_dir
):
for
tensor
in
net_def
.
tensors
:
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
\
or
tensor
.
data_type
==
mace_pb2
.
DT_HALF
\
or
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
del
tensor
.
float_data
[:]
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
del
tensor
.
int32_data
[:]
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
del
tensor
.
int32_data
[:]
proto_file_path
=
output_dir
+
model_tag
+
'.pb'
with
open
(
proto_file_path
,
"wb"
)
as
f
:
f
.
write
(
net_def
.
SerializeToString
())
with
open
(
proto_file_path
+
'_txt'
,
"w"
)
as
f
:
f
.
write
(
str
(
net_def
))
return
proto_file_path
def
save_model_to_code
(
net_def
,
model_tag
,
device
,
template_dir
,
output_dir
,
embed_model_data
,
model_checksum
,
weight_checksum
,
obfuscate
,
winograd_conv
):
# Create the jinja2 environment.
j2_env
=
Environment
(
j2_env
=
Environment
(
loader
=
FileSystemLoader
(
template_dir
),
trim_blocks
=
True
)
loader
=
FileSystemLoader
(
cwd
+
"/template"
),
trim_blocks
=
True
)
j2_env
.
filters
[
'stringfy'
]
=
stringfy
j2_env
.
filters
[
"stringfy"
]
=
stringfy
# generate tensor source files
template_name
=
'tensor_source.jinja2'
template_name
=
"tensor_source.jinja2"
counter
=
0
counter
=
0
for
tensor
in
net_def
.
tensors
:
for
tensor
in
model
.
tensors
:
tensor_info
=
TensorInfo
(
counter
,
tensor
)
# convert tensor
# convert tensor
source
=
j2_env
.
get_template
(
template_name
).
render
(
source
=
j2_env
.
get_template
(
template_name
).
render
(
tensor_info
=
tensor_info
,
tensor
=
tensor
,
tensor
=
tensor
,
tag
=
model_tag
,
tensor_id
=
counter
,
tag
=
namespace
,
)
)
with
open
(
output
_dir
+
'tensor'
+
str
(
counter
)
+
'.cc'
,
"w"
)
as
f
:
with
open
(
output
+
"/tensor"
+
str
(
counter
)
+
".cc"
,
"w"
)
as
f
:
f
.
write
(
source
)
f
.
write
(
source
)
counter
+=
1
counter
+=
1
# generate tensor data
template_name
=
"tensor_data.jinja2"
if
embed_model_data
:
source
=
j2_env
.
get_template
(
template_name
).
render
(
model_data
=
extract_model_data
(
net_def
)
tag
=
namespace
,
template_name
=
'tensor_data.jinja2'
model_data_size
=
len
(
params
),
source
=
j2_env
.
get_template
(
template_name
).
render
(
model_data
=
params
)
tag
=
model_tag
,
with
open
(
output
+
"/tensor_data.cc"
,
"w"
)
as
f
:
model_data_size
=
len
(
model_data
),
f
.
write
(
source
)
model_data
=
model_data
)
with
open
(
output_dir
+
'tensor_data'
+
'.cc'
,
"w"
)
as
f
:
f
.
write
(
source
)
# generate op source files
template_name
=
"operator.jinja2"
template_name
=
'operator.jinja2'
counter
=
0
counter
=
0
op_size
=
len
(
net_def
.
op
)
op_size
=
len
(
model
.
op
)
try
:
device
=
cvt
.
DeviceType
[
device
.
upper
()]
except
:
# noqa
if
device
.
upper
==
"DSP"
:
device
=
cvt
.
DeviceType
.
HEXAGON
else
:
device
=
cvt
.
DeviceType
.
CPU
for
start
in
range
(
0
,
op_size
,
10
):
for
start
in
range
(
0
,
op_size
,
10
):
source
=
j2_env
.
get_template
(
template_name
).
render
(
source
=
j2_env
.
get_template
(
template_name
).
render
(
start
=
start
,
start
=
start
,
end
=
min
(
start
+
10
,
op_size
),
end
=
min
(
start
+
10
,
op_size
),
net
=
net_def
,
net
=
model
,
tag
=
model_tag
,
tag
=
namespace
,
device
=
device
,
device
=
device
,
)
)
with
open
(
output
_dir
+
'op'
+
str
(
counter
)
+
'.cc'
,
"w"
)
as
f
:
with
open
(
output
+
"/op"
+
str
(
counter
)
+
".cc"
,
"w"
)
as
f
:
f
.
write
(
source
)
f
.
write
(
source
)
counter
+=
1
counter
+=
1
# generate model source files
# generate model source files
build_time
=
datetime
.
datetime
.
now
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
build_time
=
datetime
.
datetime
.
now
().
strftime
(
"%Y-%m-%d %H:%M:%S"
)
template_name
=
'model.jinja2'
template_name
=
"model.jinja2"
checksum
=
model_checksum
checksum
=
"{},{}"
.
format
(
model_checksum
,
params_checksum
)
if
weight_checksum
is
not
None
:
checksum
=
"{},{}"
.
format
(
model_checksum
,
weight_checksum
)
source
=
j2_env
.
get_template
(
template_name
).
render
(
source
=
j2_env
.
get_template
(
template_name
).
render
(
net
=
net_def
,
net
=
model
,
tag
=
model_tag
,
tag
=
namespace
,
obfuscate
=
obfuscate
,
embed_model_data
=
embed_model_data
,
winograd_conv
=
winograd_conv
,
checksum
=
checksum
,
checksum
=
checksum
,
build_time
=
build_time
)
build_time
=
build_time
)
with
open
(
output
_dir
+
'model.cc'
,
"w"
)
as
f
:
with
open
(
output
+
"/model.cc"
,
"w"
)
as
f
:
f
.
write
(
source
)
f
.
write
(
source
)
# generate model header file
template_name
=
'model_header.jinja2'
template_name
=
'model_header.jinja2'
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
model_tag
,
)
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
namespace
,
)
with
open
(
output
_dir
+
model_tag
+
'.h'
,
"w"
)
as
f
:
with
open
(
output
+
"/"
+
namespace
+
'.h'
,
"w"
)
as
f
:
f
.
write
(
source
)
f
.
write
(
source
)
def
save_model
(
option
,
net_def
,
model_checksum
,
weight_checksum
,
template_dir
,
def
save_model_to_file
(
model_name
,
model
,
params
,
output
):
obfuscate
,
model_tag
,
output_dir
,
embed_model_data
,
if
not
os
.
path
.
exists
(
output
):
winograd_conv
,
model_graph_format
):
os
.
mkdir
(
output
)
if
obfuscate
:
with
open
(
output
+
"/"
+
model_name
+
".pb"
,
"wb"
)
as
f
:
obfuscate_name
(
option
,
net_def
)
f
.
write
(
model
.
SerializeToString
())
with
open
(
output
+
"/"
+
model_name
+
".data"
,
"wb"
)
as
f
:
output_dir
=
output_dir
+
'/'
f
.
write
(
params
)
net_def
.
data_type
=
option
.
data_type
# update tensor type
update_tensor_infos
(
net_def
,
option
.
data_type
)
def
encrypt
(
model_name
,
model_file
,
params_file
,
device
,
output
,
is_obfuscate
=
False
):
if
model_graph_format
==
ModelFormat
.
file
or
not
embed_model_data
:
model_checksum
=
util
.
file_checksum
(
model_file
)
save_model_data
(
net_def
,
model_tag
,
output_dir
)
params_checksum
=
util
.
file_checksum
(
params_file
)
if
model_graph_format
==
ModelFormat
.
file
:
with
open
(
model_file
,
"rb"
)
as
model_file
:
save_model_to_proto
(
net_def
,
model_tag
,
output_dir
)
with
open
(
params_file
,
"rb"
)
as
params_file
:
else
:
model
=
mace_pb2
.
NetDef
()
save_model_to_code
(
net_def
,
model_tag
,
option
.
device
,
model
.
ParseFromString
(
model_file
.
read
())
template_dir
,
output_dir
,
embed_model_data
,
params
=
bytearray
(
params_file
.
read
())
model_checksum
,
weight_checksum
,
obfuscate
,
winograd_conv
)
if
is_obfuscate
:
obfuscate_name
(
model
)
save_model_to_file
(
model_name
,
model
,
params
,
output
+
"/file/"
)
save_model_to_code
(
model_name
,
model
,
params
,
model_checksum
,
params_checksum
,
device
,
output
+
"/code/"
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name'
,
type
=
str
,
help
=
"the namespace of gernerated code"
)
parser
.
add_argument
(
'--model_file'
,
type
=
str
,
help
=
"model file"
)
parser
.
add_argument
(
'--params_file'
,
type
=
str
,
help
=
"params file"
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cpu'
,
help
=
"cpu/gpu/hexagon/hta/apu"
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
"."
,
help
=
"output dir"
)
parser
.
add_argument
(
"--obfuscate"
,
action
=
"store_true"
,
help
=
"obfuscate model names"
)
flgs
,
_
=
parser
.
parse_known_args
()
mace_check
(
flags
.
model_name
not
in
CPP_KEYWORDS
,
"model name cannot be cpp"
"keywords"
)
return
flgs
if
__name__
==
'__main__'
:
flags
=
parse_args
()
encrypt
(
flags
.
model_name
,
flags
.
model_file
,
flags
.
params_file
,
flags
.
device
,
flags
.
output
,
flags
.
obfuscate
)
tools/
experimental/utils/util
.py
→
tools/
python/py_proto/__init__
.py
浏览文件 @
320b509c
...
@@ -15,3 +15,15 @@
...
@@ -15,3 +15,15 @@
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
from
utils
import
device
cwd
=
os
.
path
.
dirname
(
__file__
)
# TODO: Remove bazel deps
device
.
execute
(
"bazel build //mace/proto:mace_py"
)
device
.
execute
(
"cp -f bazel-genfiles/mace/proto/mace_pb2.py %s"
%
cwd
)
device
.
execute
(
"bazel build //third_party/caffe:caffe_py"
)
device
.
execute
(
"cp -f bazel-genfiles/third_party/caffe/caffe_pb2.py %s"
%
cwd
)
mace/python/tools/quantization
/__init__.py
→
tools/python/quantize
/__init__.py
浏览文件 @
320b509c
文件已移动
mace/python/tools/quantization
/quantize_stat.py
→
tools/python/quantize
/quantize_stat.py
浏览文件 @
320b509c
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
argparse
import
numpy
as
np
import
numpy
as
np
...
...
mace/python/tools/quantization
/quantize_util.py
→
tools/python/quantize
/quantize_util.py
浏览文件 @
320b509c
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
numpy
as
np
import
math
import
math
from
mace.python.tools.converter_tool
.base_converter
import
DeviceType
from
transform
.base_converter
import
DeviceType
class
QuantizedData
(
object
):
class
QuantizedData
(
object
):
...
@@ -126,7 +144,8 @@ def cal_multiplier_and_shift(scale):
...
@@ -126,7 +144,8 @@ def cal_multiplier_and_shift(scale):
def
quantize_with_scale_and_zero
(
data
,
scale
,
zero
):
def
quantize_with_scale_and_zero
(
data
,
scale
,
zero
):
output
=
np
.
round
(
zero
+
data
/
scale
).
astype
(
np
.
int32
)
np_data
=
np
.
array
(
data
).
astype
(
float
)
output
=
np
.
round
(
zero
+
np_data
/
scale
).
astype
(
np
.
int32
)
quantized_data
=
QuantizedData
()
quantized_data
=
QuantizedData
()
quantized_data
.
data
=
output
quantized_data
.
data
=
output
quantized_data
.
scale
=
scale
quantized_data
.
scale
=
scale
...
@@ -140,7 +159,8 @@ def quantize(data, device, non_zero):
...
@@ -140,7 +159,8 @@ def quantize(data, device, non_zero):
in_max
=
np_data
.
max
()
in_max
=
np_data
.
max
()
scale
,
zero
,
out_min
,
out_max
=
adjust_range
(
in_min
,
in_max
,
device
,
scale
,
zero
,
out_min
,
out_max
=
adjust_range
(
in_min
,
in_max
,
device
,
non_zero
=
non_zero
)
non_zero
=
non_zero
)
output
=
np
.
clip
((
np
.
round
(
zero
+
data
/
scale
).
astype
(
np
.
int32
)),
0
,
255
)
output
=
np
.
clip
((
np
.
round
(
zero
+
np_data
/
scale
).
astype
(
np
.
int32
)),
0
,
255
)
quantized_data
=
QuantizedData
()
quantized_data
=
QuantizedData
()
quantized_data
.
data
=
output
quantized_data
.
data
=
output
...
...
tools/
experimental/utils/config_parser
.py
→
tools/
python/quantize/quantize_util_test
.py
浏览文件 @
320b509c
...
@@ -12,32 +12,19 @@
...
@@ -12,32 +12,19 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_impor
t
import
unittes
t
from
__future__
import
division
import
numpy
as
np
from
__future__
import
print_function
import
quantize.quantize_util
import
re
import
os
import
yaml
class
TestQuantize
(
unittest
.
TestCase
):
def
sanitize_load
(
s
):
def
test_quantize_dequantize
(
self
):
# do not let yaml parse ON/OFF to boolean
test_input
=
np
.
random
.
rand
(
20
,
30
)
*
5
for
w
in
[
"ON"
,
"OFF"
,
"on"
,
"off"
]:
quantized_data
=
quantize_util
.
quantize
(
test_input
)
s
=
re
.
sub
(
r
":\s+"
+
w
,
r
": '"
+
w
+
"'"
,
s
)
dequantized_output
=
quantize_util
.
dequantize
(
quantized_data
)
np
.
testing
.
assert_array_almost_equal
(
test_input
,
dequantized_output
,
2
)
# sub ${} to env value
s
=
re
.
sub
(
r
"\${(\w+)}"
,
lambda
x
:
os
.
environ
[
x
.
group
(
1
)],
s
)
return
yaml
.
load
(
s
)
if
__name__
==
'__main__'
:
def
parse
(
path
):
unittest
.
main
()
with
open
(
path
)
as
f
:
config
=
sanitize_load
(
f
.
read
())
return
config
def
parse_device_info
(
path
):
conf
=
parse
(
path
)
return
conf
[
"devices"
]
tools/
experimental
/run.py
→
tools/
python
/run.py
浏览文件 @
320b509c
文件已移动
mace/python/tools
/model.jinja2
→
tools/python/template
/model.jinja2
浏览文件 @
320b509c
...
@@ -166,10 +166,5 @@ const std::string ModelBuildTime() {
...
@@ -166,10 +166,5 @@ const std::string ModelBuildTime() {
return {{ build_time|tojson }};
return {{ build_time|tojson }};
}
}
const std::string ModelBuildOptions() {
return {{ "obfuscate: {}, embed_model_data: {}, winograd: {}"
.format(obfuscate, embed_model_data, winograd_conv)|tojson }};
}
} // namespace {{tag}}
} // namespace {{tag}}
} // namespace mace
} // namespace mace
mace/python/tools
/model_header.jinja2
→
tools/python/template
/model_header.jinja2
浏览文件 @
320b509c
文件已移动
tools/python/template/operator.jinja2
0 → 100644
浏览文件 @
320b509c
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This is a generated file. DO NOT EDIT!
#include <vector>
#include <string>
#include "mace/proto/mace.pb.h"
#include "mace/public/mace.h"
#include "mace/port/env.h"
#include "mace/utils/logging.h"
namespace mace {
namespace {
void UpdateOp(mace::OperatorDef *op,
const std::string &name,
const std::string &type,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs,
const std::vector<mace::DataType> &output_types,
uint32_t node_id,
const std::vector<int> &mem_ids) {
op->set_name(name);
op->set_type(type);
op->set_node_id(node_id);
op->mutable_input()->Reserve(inputs.size());
for (auto input : inputs) {
op->add_input(input);
}
op->mutable_output()->Reserve(outputs.size());
for (auto output : outputs) {
op->add_output(output);
}
op->mutable_output_type()->Reserve(output_types.size());
for (auto output_type : output_types) {
op->add_output_type(output_type);
}
op->mutable_mem_id()->Reserve(mem_ids.size());
for (auto mem_id : mem_ids) {
op->add_mem_id(mem_id);
}
}
} // namespace
} // namespace mace
namespace mace {
namespace {{tag}} {
{% for i in range(start, end) %}
void CreateOperator{{i}}(mace::OperatorDef *op) {
MACE_LATENCY_LOGGER(2, "Create operator {{ net.op[i].name }}");
mace::Argument *arg = nullptr;
op->mutable_arg()->Reserve({{ net.op[i].arg|length }});
{% for arg in net.op[i].arg %}
arg = op->add_arg();
arg->set_name({{ arg.name|tojson }});
{%- if arg.HasField('f') %}
arg->set_f({{ arg.f }});
{%- endif %}
{%- if arg.HasField('i') %}
arg->set_i({{ arg.i }});
{%- endif %}
{%- if arg.HasField('s') %}
arg->set_s({{ arg.s.decode('utf-8')|tojson }});
{%- endif %}
{% if arg.floats|length > 0 %}
arg->mutable_floats()->Reserve({{ arg.floats|length }});
{% for float_value in arg.floats %}
arg->add_floats({{ float_value }});
{% endfor %}
{% endif %}
{% if arg.ints|length > 0 %}
arg->mutable_ints()->Reserve({{ arg.ints|length }});
{% for int_value in arg.ints %}
arg->add_ints({{ int_value }});
{% endfor %}
{% endif %}
{% endfor %}
{% if net.op[i].output_shape|length > 0 %}
op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
{% for shape in net.op[i].output_shape %}
{% if shape.dims|length > 0 %}
{
mace::OutputShape *output_shape = op->add_output_shape();
output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
{% for dim in shape.dims %}
output_shape->add_dims({{ dim }});
{% endfor %}
}
{% endif %}
{% endfor %}
{% endif %}
std::vector<int> output_types_int({ {{ net.op[i].output_type | join(', ') }} });
std::vector<mace::DataType> output_types({{ net.op[i].output_type | length }});
for (int k = 0; k < {{ net.op[i].output_type | length }}; ++k) {
output_types[k] = static_cast<mace::DataType>(output_types_int[k]);
}
UpdateOp(op, {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}},
{ {{ net.op[i].input|stringfy }} },
{ {{ net.op[i].output|stringfy }} },
output_types,
{{ net.op[i].node_id }},
{ {{ net.op[i].mem_id | join(', ') }} });
op->mutable_quantize_info()->Reserve({{ net.op[i].quantize_info | length }});
{% for j in range(net.op[i].quantize_info|length) %}
auto quantize_info{{j}} = op->add_quantize_info();
quantize_info{{j}}->set_scale({{ net.op[i].quantize_info[j].scale }});
quantize_info{{j}}->set_zero_point({{ net.op[i].quantize_info[j].zero_point }});
quantize_info{{j}}->set_minval({{ net.op[i].quantize_info[j].minval }});
quantize_info{{j}}->set_maxval({{ net.op[i].quantize_info[j].maxval }});
{% endfor %}
{% if device == 3 %}
op->set_padding({{ net.op[i].padding }});
{% if net.op[i].node_input | length > 0 %}
std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ')}} });
mace::NodeInput *node_input = nullptr;
op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }});
for (size_t i = 0; i < {{ net.op[i].node_input|length }}; ++i) {
node_input = op->add_node_input();
node_input->set_node_id(input_node_ids[i]);
node_input->set_output_port(input_output_ports[i]);
}
{% endif %}
{% if net.op[i].out_max_byte_size | length > 0 %}
std::vector<int> out_max_byte_sizes {{ net.op[i].out_max_byte_size | replace('[', '{') | replace(']', '}') }};
op->mutable_out_max_byte_size()->Reserve({{ net.op[i].out_max_byte_size|length }});
for (size_t i = 0; i < {{ net.op[i].out_max_byte_size|length }}; ++i) {
op->add_out_max_byte_size(out_max_byte_sizes[i]);
}
{% endif %}
{% endif %}
{% if device == 5 %}
{% if net.op[i].node_input | length > 0 %}
std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
mace::NodeInput *node_input = nullptr;
op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }});
for (size_t i = 0; i < {{ net.op[i].node_input|length }}; ++i) {
node_input = op->add_node_input();
node_input->set_node_id(input_node_ids[i]);
}
{% endif %}
{% endif %}
}
{% endfor %}
} // namespace {{tag}}
} // namespace mace
mace/python/tools
/tensor_data.jinja2
→
tools/python/template
/tensor_data.jinja2
浏览文件 @
320b509c
文件已移动
mace/python/tools
/tensor_source.jinja2
→
tools/python/template
/tensor_source.jinja2
浏览文件 @
320b509c
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
namespace mace {
namespace mace {
namespace {{tag}} {
namespace {{tag}} {
void CreateTensor{{tensor_i
nfo.i
d}}(mace::ConstTensor *const_tensor) {
void CreateTensor{{tensor_id}}(mace::ConstTensor *const_tensor) {
MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}");
MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}");
const_tensor->set_name({{ tensor.name|tojson }});
const_tensor->set_name({{ tensor.name|tojson }});
const_tensor->set_offset({{ tensor.offset }});
const_tensor->set_offset({{ tensor.offset }});
...
@@ -30,7 +30,7 @@ void CreateTensor{{tensor_info.id}}(mace::ConstTensor *const_tensor) {
...
@@ -30,7 +30,7 @@ void CreateTensor{{tensor_info.id}}(mace::ConstTensor *const_tensor) {
{% for dim in tensor.dims %}
{% for dim in tensor.dims %}
const_tensor->add_dims({{ dim }});
const_tensor->add_dims({{ dim }});
{% endfor %}
{% endfor %}
const_tensor->set_data_type(static_cast<DataType>({{ tensor
_info
.data_type }}));
const_tensor->set_data_type(static_cast<DataType>({{ tensor.data_type }}));
const_tensor->set_node_id({{ tensor.node_id }});
const_tensor->set_node_id({{ tensor.node_id }});
const_tensor->set_scale({{ tensor.scale }});
const_tensor->set_scale({{ tensor.scale }});
const_tensor->set_zero_point({{ tensor.zero_point }});
const_tensor->set_zero_point({{ tensor.zero_point }});
...
...
tools/__init__.py
→
tools/
python/transform/
__init__.py
浏览文件 @
320b509c
文件已移动
mace/python/tools/converter_tool
/apu_converter.py
→
tools/python/transform
/apu_converter.py
浏览文件 @
320b509c
...
@@ -17,19 +17,18 @@ import numpy as np
...
@@ -17,19 +17,18 @@ import numpy as np
from
enum
import
Enum
from
enum
import
Enum
from
operator
import
mul
from
operator
import
mul
from
mace.proto
import
mace_pb2
from
py_proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
transform
import
base_converter
from
mace.python.tools.converter_tool.base_converter
import
ConverterUtil
from
transform.base_converter
import
ConverterUtil
from
mace.python.tools.converter_tool.base_converter
import
EltwiseType
from
transform.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool.base_converter
import
MaceKeyword
from
transform.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool.base_converter
import
MaceOp
from
transform.base_converter
import
MaceOp
from
mace.python.tools.converter_tool.base_converter
import
PaddingMode
from
transform.base_converter
import
PaddingMode
from
mace.python.tools.converter_tool.base_converter
import
PoolingType
from
transform.base_converter
import
PoolingType
from
mace.python.tools.converter_tool.base_converter
import
ReduceType
from
transform.base_converter
import
ReduceType
from
mace.python.tools.converter_tool.base_converter
import
DataFormat
from
transform.base_converter
import
DataFormat
from
mace.python.tools.converter_tool.base_converter
import
FrameworkType
from
transform.base_converter
import
FrameworkType
from
mace.python.tools.convert_util
import
mace_check
from
utils.util
import
mace_check
from
mace.python.tools
import
graph_util
ApuSupportedOps
=
[
ApuSupportedOps
=
[
...
...
mace/python/tools/converter_tool
/base_converter.py
→
tools/python/transform
/base_converter.py
浏览文件 @
320b509c
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
from
enum
import
Enum
from
enum
import
Enum
from
mace.
proto
import
mace_pb2
from
py_
proto
import
mace_pb2
class
DeviceType
(
Enum
):
class
DeviceType
(
Enum
):
...
...
mace/python/tools/converter_tool
/caffe_converter.py
→
tools/python/transform
/caffe_converter.py
浏览文件 @
320b509c
...
@@ -19,20 +19,20 @@ import numpy as np
...
@@ -19,20 +19,20 @@ import numpy as np
import
six
import
six
import
google.protobuf.text_format
import
google.protobuf.text_format
from
mace.
proto
import
mace_pb2
from
py_
proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
transform
import
base_converter
from
mace.python.tools.converter_tool
import
shape_inference
from
transform
import
shape_inference
from
mace.python.tools.converter_tool
.base_converter
import
PoolingType
from
transform
.base_converter
import
PoolingType
from
mace.python.tools.converter_tool
.base_converter
import
ActivationType
from
transform
.base_converter
import
ActivationType
from
mace.python.tools.converter_tool
.base_converter
import
EltwiseType
from
transform
.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool
.base_converter
import
FrameworkType
from
transform
.base_converter
import
FrameworkType
from
mace.python.tools.converter_tool
.base_converter
import
DataFormat
from
transform
.base_converter
import
DataFormat
from
mace.python.tools.converter_tool
.base_converter
import
MaceOp
from
transform
.base_converter
import
MaceOp
from
mace.python.tools.converter_tool
.base_converter
import
MaceKeyword
from
transform
.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool
.base_converter
import
ConverterUtil
from
transform
.base_converter
import
ConverterUtil
from
mace.python.tools.convert_
util
import
mace_check
from
utils.
util
import
mace_check
from
third_party.caffe
import
caffe_pb2
from
py_proto
import
caffe_pb2
caffe_group_str
=
'group'
caffe_group_str
=
'group'
caffe_kernel_h_str
=
'kernel_h'
caffe_kernel_h_str
=
'kernel_h'
...
...
mace/python/tools/converter_tool
/hexagon_converter.py
→
tools/python/transform
/hexagon_converter.py
浏览文件 @
320b509c
...
@@ -17,18 +17,17 @@ import numpy as np
...
@@ -17,18 +17,17 @@ import numpy as np
from
enum
import
Enum
from
enum
import
Enum
from
operator
import
mul
from
operator
import
mul
from
mace.proto
import
mace_pb2
from
py_proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
transform
import
base_converter
from
mace.python.tools.converter_tool.base_converter
import
ConverterUtil
from
transform.base_converter
import
ConverterUtil
from
mace.python.tools.converter_tool.base_converter
import
DeviceType
from
transform.base_converter
import
DeviceType
from
mace.python.tools.converter_tool.base_converter
import
EltwiseType
from
transform.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool.base_converter
import
MaceKeyword
from
transform.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool.base_converter
import
MaceOp
from
transform.base_converter
import
MaceOp
from
mace.python.tools.converter_tool.base_converter
import
PaddingMode
from
transform.base_converter
import
PaddingMode
from
mace.python.tools.converter_tool.base_converter
import
PoolingType
from
transform.base_converter
import
PoolingType
from
mace.python.tools.converter_tool.base_converter
import
ReduceType
from
transform.base_converter
import
ReduceType
from
mace.python.tools.convert_util
import
mace_check
from
utils.util
import
mace_check
from
mace.python.tools
import
graph_util
from
six.moves
import
reduce
from
six.moves
import
reduce
...
@@ -144,23 +143,18 @@ class HexagonConverter(base_converter.ConverterInterface):
...
@@ -144,23 +143,18 @@ class HexagonConverter(base_converter.ConverterInterface):
return
self
.
_model
return
self
.
_model
def
add_port_for_tensors
(
self
,
tensors
):
for
i
in
range
(
len
(
tensors
)):
if
':'
not
in
tensors
[
i
]:
node_name
=
tensors
[
i
]
tensors
[
i
]
+=
':0'
if
node_name
in
self
.
_quantize_activation_info
:
self
.
_quantize_activation_info
[
tensors
[
i
]]
=
\
self
.
_quantize_activation_info
[
node_name
]
def
convert_ops
(
self
):
def
convert_ops
(
self
):
print
(
"Convert mace graph to hexagon."
)
print
(
"Convert mace graph to hexagon."
)
for
op
in
self
.
_model
.
op
:
for
op
in
self
.
_model
.
op
:
if
not
self
.
_hexagon_ops
.
has_op
(
op
.
type
):
if
not
self
.
_hexagon_ops
.
has_op
(
op
.
type
):
raise
Exception
(
'Unsupported op: '
,
op
)
raise
Exception
(
'Unsupported op: '
,
op
)
for
i
in
range
(
len
(
op
.
input
)):
self
.
add_port_for_tensors
(
op
.
input
)
if
':'
not
in
op
.
input
[
i
]:
self
.
add_port_for_tensors
(
op
.
output
)
node_name
=
op
.
input
[
i
]
op
.
input
[
i
]
+=
':0'
if
node_name
in
self
.
_quantize_activation_info
:
self
.
_quantize_activation_info
[
op
.
input
[
i
]]
=
\
self
.
_quantize_activation_info
[
node_name
]
if
op
.
type
==
MaceOp
.
Conv2D
.
name
\
if
op
.
type
==
MaceOp
.
Conv2D
.
name
\
or
op
.
type
==
MaceOp
.
DepthwiseConv2d
.
name
:
or
op
.
type
==
MaceOp
.
DepthwiseConv2d
.
name
:
...
@@ -488,15 +482,13 @@ class HexagonConverter(base_converter.ConverterInterface):
...
@@ -488,15 +482,13 @@ class HexagonConverter(base_converter.ConverterInterface):
for
tensor
in
self
.
_model
.
tensors
:
for
tensor
in
self
.
_model
.
tensors
:
tensor
.
node_id
=
node_id_counter
tensor
.
node_id
=
node_id_counter
node_id_counter
+=
1
node_id_counter
+=
1
node_id_map
[
tensor
.
name
]
=
tensor
.
node_id
tensor_op
,
port
=
get_op_and_port_from_tensor
(
tensor
.
name
)
node_id_map
[
tensor_op
]
=
tensor
.
node_id
print
(
"Hexagon op:"
)
print
(
"Hexagon op:"
)
index
=
0
index
=
0
for
op
in
self
.
_model
.
op
:
for
op
in
self
.
_model
.
op
:
op
.
node_id
=
node_id_counter
op
.
node_id
=
node_id_counter
node_id_counter
+=
1
for
output
in
op
.
output
:
node_id_map
[
output
]
=
op
.
node_id
if
op
.
type
not
in
[
HexagonOp
.
QuantizeINPUT_f_to_8
,
if
op
.
type
not
in
[
HexagonOp
.
QuantizeINPUT_f_to_8
,
HexagonOp
.
DequantizeOUTPUT_8tof
.
name
]:
HexagonOp
.
DequantizeOUTPUT_8tof
.
name
]:
index_str
=
str
(
index
)
index_str
=
str
(
index
)
...
@@ -505,10 +497,11 @@ class HexagonConverter(base_converter.ConverterInterface):
...
@@ -505,10 +497,11 @@ class HexagonConverter(base_converter.ConverterInterface):
index_str
=
''
index_str
=
''
print
(
'Op: %s (%s, node_id:%d, index:%s)'
%
print
(
'Op: %s (%s, node_id:%d, index:%s)'
%
(
op
.
name
,
op
.
type
,
op
.
node_id
,
index_str
))
(
op
.
name
,
op
.
type
,
op
.
node_id
,
index_str
))
node_id_counter
+=
1
node_id_map
[
op
.
name
]
=
op
.
node_id
for
ipt
in
op
.
input
:
for
ipt
in
op
.
input
:
op_name
,
port
=
get_op_and_port_from_tensor
(
ipt
)
op_name
,
port
=
get_op_and_port_from_tensor
(
ipt
)
tensor_name
=
ipt
if
port
==
0
else
op_name
+
':0'
node_id
=
node_id_map
[
op_name
]
node_id
=
node_id_map
[
tensor_name
]
node_input
=
op
.
node_input
.
add
()
node_input
=
op
.
node_input
.
add
()
node_input
.
node_id
=
node_id
node_input
.
node_id
=
node_id
node_input
.
output_port
=
int
(
port
)
node_input
.
output_port
=
int
(
port
)
mace/python/tools/converter_tool
/onnx_converter.py
→
tools/python/transform
/onnx_converter.py
浏览文件 @
320b509c
...
@@ -17,20 +17,20 @@ import sys
...
@@ -17,20 +17,20 @@ import sys
from
enum
import
Enum
from
enum
import
Enum
import
six
import
six
from
mace.
proto
import
mace_pb2
from
py_
proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
transform
import
base_converter
from
mace.python.tools.converter_tool
.base_converter
import
PoolingType
from
transform
.base_converter
import
PoolingType
from
mace.python.tools.converter_tool
.base_converter
import
PaddingMode
from
transform
.base_converter
import
PaddingMode
from
mace.python.tools.converter_tool
.base_converter
import
ActivationType
from
transform
.base_converter
import
ActivationType
from
mace.python.tools.converter_tool
.base_converter
import
EltwiseType
from
transform
.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool
.base_converter
import
ReduceType
from
transform
.base_converter
import
ReduceType
from
mace.python.tools.converter_tool
.base_converter
import
FrameworkType
from
transform
.base_converter
import
FrameworkType
from
mace.python.tools.converter_tool
.base_converter
import
RoundMode
from
transform
.base_converter
import
RoundMode
from
mace.python.tools.converter_tool
.base_converter
import
DataFormat
from
transform
.base_converter
import
DataFormat
from
mace.python.tools.converter_tool
.base_converter
import
MaceOp
from
transform
.base_converter
import
MaceOp
from
mace.python.tools.converter_tool
.base_converter
import
MaceKeyword
from
transform
.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool
.base_converter
import
ConverterUtil
from
transform
.base_converter
import
ConverterUtil
from
mace.python.tools.convert_
util
import
mace_check
from
utils.
util
import
mace_check
import
numpy
as
np
import
numpy
as
np
...
...
mace/python/tools/converter_tool
/shape_inference.py
→
tools/python/transform
/shape_inference.py
浏览文件 @
320b509c
...
@@ -18,12 +18,12 @@ import math
...
@@ -18,12 +18,12 @@ import math
import
numpy
as
np
import
numpy
as
np
import
six
import
six
from
mace.python.tools.converter_tool
.transformer
import
Transformer
from
transform
.transformer
import
Transformer
from
mace.python.tools.converter_tool
.base_converter
import
DataFormat
from
transform
.base_converter
import
DataFormat
from
mace.python.tools.converter_tool
.base_converter
import
MaceOp
from
transform
.base_converter
import
MaceOp
from
mace.python.tools.converter_tool
.base_converter
import
MaceKeyword
from
transform
.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool
.base_converter
import
ConverterUtil
from
transform
.base_converter
import
ConverterUtil
from
mace.python.tools.convert_
util
import
mace_check
from
utils.
util
import
mace_check
class
ShapeInference
(
object
):
class
ShapeInference
(
object
):
...
...
mace/python/tools/converter_tool
/tensorflow_converter.py
→
tools/python/transform
/tensorflow_converter.py
浏览文件 @
320b509c
...
@@ -19,20 +19,20 @@ import six
...
@@ -19,20 +19,20 @@ import six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
enum
import
Enum
from
enum
import
Enum
from
mace.
proto
import
mace_pb2
from
py_
proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
transform
import
base_converter
from
mace.python.tools.converter_tool
.base_converter
import
PoolingType
from
transform
.base_converter
import
PoolingType
from
mace.python.tools.converter_tool
.base_converter
import
PaddingMode
from
transform
.base_converter
import
PaddingMode
from
mace.python.tools.converter_tool
.base_converter
import
ActivationType
from
transform
.base_converter
import
ActivationType
from
mace.python.tools.converter_tool
.base_converter
import
EltwiseType
from
transform
.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool
.base_converter
import
PadType
from
transform
.base_converter
import
PadType
from
mace.python.tools.converter_tool
.base_converter
import
FrameworkType
from
transform
.base_converter
import
FrameworkType
from
mace.python.tools.converter_tool
.base_converter
import
ReduceType
from
transform
.base_converter
import
ReduceType
from
mace.python.tools.converter_tool
.base_converter
import
DataFormat
from
transform
.base_converter
import
DataFormat
from
mace.python.tools.converter_tool
.base_converter
import
MaceOp
from
transform
.base_converter
import
MaceOp
from
mace.python.tools.converter_tool
.base_converter
import
MaceKeyword
from
transform
.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool
.base_converter
import
ConverterUtil
from
transform
.base_converter
import
ConverterUtil
from
mace.python.tools.convert_
util
import
mace_check
from
utils.
util
import
mace_check
from
tensorflow.core.framework
import
tensor_shape_pb2
from
tensorflow.core.framework
import
tensor_shape_pb2
from
tensorflow.tools.graph_transforms
import
TransformGraph
from
tensorflow.tools.graph_transforms
import
TransformGraph
...
...
mace/python/tools/converter_tool
/transformer.py
→
tools/python/transform
/transformer.py
浏览文件 @
320b509c
...
@@ -18,22 +18,22 @@ import re
...
@@ -18,22 +18,22 @@ import re
import
numpy
as
np
import
numpy
as
np
import
six
import
six
from
mace.
proto
import
mace_pb2
from
py_
proto
import
mace_pb2
from
mace.python.tools.converter_tool
import
base_converter
from
transform
import
base_converter
from
mace.python.tools.converter_tool
.base_converter
import
ConverterUtil
from
transform
.base_converter
import
ConverterUtil
from
mace.python.tools.converter_tool
.base_converter
import
DataFormat
from
transform
.base_converter
import
DataFormat
from
mace.python.tools.converter_tool
.base_converter
import
DeviceType
from
transform
.base_converter
import
DeviceType
from
mace.python.tools.converter_tool
.base_converter
import
EltwiseType
from
transform
.base_converter
import
EltwiseType
from
mace.python.tools.converter_tool
.base_converter
import
FrameworkType
from
transform
.base_converter
import
FrameworkType
from
mace.python.tools.converter_tool
.base_converter
import
MaceKeyword
from
transform
.base_converter
import
MaceKeyword
from
mace.python.tools.converter_tool
.base_converter
import
MaceOp
from
transform
.base_converter
import
MaceOp
from
mace.python.tools.converter_tool
.base_converter
import
MaceFixedDataFormatOps
# noqa
from
transform
.base_converter
import
MaceFixedDataFormatOps
# noqa
from
mace.python.tools.converter_tool
.base_converter
import
MaceTransposableDataFormatOps
# noqa
from
transform
.base_converter
import
MaceTransposableDataFormatOps
# noqa
from
mace.python.tools.converter_tool
.base_converter
import
PaddingMode
from
transform
.base_converter
import
PaddingMode
from
mace.python.tools.converter_tool
.base_converter
import
ReduceType
from
transform
.base_converter
import
ReduceType
from
mace.python.tools.converter_tool
.base_converter
import
TransformerRule
from
transform
.base_converter
import
TransformerRule
from
mace.python.tools.convert_util
import
mace_check
from
quantize
import
quantize_util
from
mace.python.tools.quantization
import
quantize_util
from
utils.util
import
mace_check
class
Transformer
(
base_converter
.
ConverterInterface
):
class
Transformer
(
base_converter
.
ConverterInterface
):
...
@@ -1309,12 +1309,27 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1309,12 +1309,27 @@ class Transformer(base_converter.ConverterInterface):
return
False
return
False
def
update_float_op_data_type
(
self
):
def
update_float_op_data_type
(
self
):
if
self
.
_option
.
quantize
:
return
print
(
"update op with float data type"
)
print
(
"update op with float data type"
)
net
=
self
.
_model
net
=
self
.
_model
data_type
=
self
.
_option
.
data_type
data_type
=
self
.
_option
.
data_type
net
.
data_type
=
data_type
for
tensor
in
net
.
tensors
:
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
:
tensor
.
data_type
=
data_type
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
\
or
tensor
.
data_type
==
mace_pb2
.
DT_HALF
\
or
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
tensor
.
data_size
=
len
(
tensor
.
float_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
if
self
.
_option
.
quantize
:
return
for
op
in
net
.
op
:
for
op
in
net
.
op
:
data_type_arg
=
ConverterUtil
.
get_arg
(
data_type_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_op_data_type_str
)
op
,
MaceKeyword
.
mace_op_data_type_str
)
...
@@ -1732,6 +1747,8 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1732,6 +1747,8 @@ class Transformer(base_converter.ConverterInterface):
self
.
_quantize_activation_info
[
op
.
input
[
0
]]
=
quantize_info
self
.
_quantize_activation_info
[
op
.
input
[
0
]]
=
quantize_info
# for add -> fakequant pattern
# for add -> fakequant pattern
self
.
_quantize_activation_info
[
op
.
output
[
0
]]
=
quantize_info
self
.
_quantize_activation_info
[
op
.
output
[
0
]]
=
quantize_info
print
(
op
.
input
[
0
],
op
.
output
[
0
])
op
.
type
=
MaceOp
.
Identity
.
name
op
.
type
=
MaceOp
.
Identity
.
name
return
False
return
False
...
@@ -1857,7 +1874,8 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1857,7 +1874,8 @@ class Transformer(base_converter.ConverterInterface):
self
.
copy_quantize_info
(
self
.
copy_quantize_info
(
op
,
self
.
_quantize_activation_info
[
new_input_name
])
op
,
self
.
_quantize_activation_info
[
new_input_name
])
else
:
else
:
self
.
copy_quantize_info
(
op
,
producer_op
.
quantize_info
[
0
])
self
.
copy_quantize_info
(
op
,
producer_op
.
quantize_info
[
0
])
self
.
_quantize_activation_info
[
op
.
output
[
0
]]
=
\
self
.
_quantize_activation_info
[
op
.
output
[
0
]]
=
\
op
.
quantize_info
[
0
]
op
.
quantize_info
[
0
]
elif
(
op
.
type
==
MaceOp
.
Concat
.
name
elif
(
op
.
type
==
MaceOp
.
Concat
.
name
...
...
tools/
experimental
/__init__.py
→
tools/
python/utils
/__init__.py
浏览文件 @
320b509c
文件已移动
tools/python/utils/config_parser.py
0 → 100644
浏览文件 @
320b509c
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
re
import
os
import
yaml
CPP_KEYWORDS
=
[
'alignas'
,
'alignof'
,
'and'
,
'and_eq'
,
'asm'
,
'atomic_cancel'
,
'atomic_commit'
,
'atomic_noexcept'
,
'auto'
,
'bitand'
,
'bitor'
,
'bool'
,
'break'
,
'case'
,
'catch'
,
'char'
,
'char16_t'
,
'char32_t'
,
'class'
,
'compl'
,
'concept'
,
'const'
,
'constexpr'
,
'const_cast'
,
'continue'
,
'co_await'
,
'co_return'
,
'co_yield'
,
'decltype'
,
'default'
,
'delete'
,
'do'
,
'double'
,
'dynamic_cast'
,
'else'
,
'enum'
,
'explicit'
,
'export'
,
'extern'
,
'false'
,
'float'
,
'for'
,
'friend'
,
'goto'
,
'if'
,
'import'
,
'inline'
,
'int'
,
'long'
,
'module'
,
'mutable'
,
'namespace'
,
'new'
,
'noexcept'
,
'not'
,
'not_eq'
,
'nullptr'
,
'operator'
,
'or'
,
'or_eq'
,
'private'
,
'protected'
,
'public'
,
'register'
,
'reinterpret_cast'
,
'requires'
,
'return'
,
'short'
,
'signed'
,
'sizeof'
,
'static'
,
'static_assert'
,
'static_cast'
,
'struct'
,
'switch'
,
'synchronized'
,
'template'
,
'this'
,
'thread_local'
,
'throw'
,
'true'
,
'try'
,
'typedef'
,
'typeid'
,
'typename'
,
'union'
,
'unsigned'
,
'using'
,
'virtual'
,
'void'
,
'volatile'
,
'wchar_t'
,
'while'
,
'xor'
,
'xor_eq'
,
'override'
,
'final'
,
'transaction_safe'
,
'transaction_safe_dynamic'
,
'if'
,
'elif'
,
'else'
,
'endif'
,
'defined'
,
'ifdef'
,
'ifndef'
,
'define'
,
'undef'
,
'include'
,
'line'
,
'error'
,
'pragma'
,
]
def
sanitize_load
(
s
):
# do not let yaml parse ON/OFF to boolean
for
w
in
[
"ON"
,
"OFF"
,
"on"
,
"off"
]:
s
=
re
.
sub
(
r
":\s+"
+
w
,
r
": '"
+
w
+
"'"
,
s
)
# sub ${} to env value
s
=
re
.
sub
(
r
"\${(\w+)}"
,
lambda
x
:
os
.
environ
[
x
.
group
(
1
)],
s
)
return
yaml
.
load
(
s
)
def
parse
(
path
):
with
open
(
path
)
as
f
:
config
=
sanitize_load
(
f
.
read
())
return
config
def
parse_device_info
(
path
):
conf
=
parse
(
path
)
return
conf
[
"devices"
]
tools/
experimental
/utils/device.py
→
tools/
python
/utils/device.py
浏览文件 @
320b509c
文件已移动
tools/
experimental
/utils/target.py
→
tools/
python
/utils/target.py
浏览文件 @
320b509c
文件已移动
tools/python/utils/util.py
0 → 100644
浏览文件 @
320b509c
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
inspect
import
hashlib
import
os
import
urllib
from
utils
import
device
################################
# log
################################
class
CMDColors
:
PURPLE
=
'
\033
[95m'
BLUE
=
'
\033
[94m'
GREEN
=
'
\033
[92m'
YELLOW
=
'
\033
[93m'
RED
=
'
\033
[91m'
ENDC
=
'
\033
[0m'
BOLD
=
'
\033
[1m'
UNDERLINE
=
'
\033
[4m'
def
get_frame_info
(
level
=
2
):
caller_frame
=
inspect
.
stack
()[
level
]
info
=
inspect
.
getframeinfo
(
caller_frame
[
0
])
return
info
.
filename
+
':'
+
str
(
info
.
lineno
)
+
': '
class
MaceLogger
:
@
staticmethod
def
header
(
message
):
print
(
CMDColors
.
PURPLE
+
message
+
CMDColors
.
ENDC
)
@
staticmethod
def
summary
(
message
):
print
(
CMDColors
.
GREEN
+
message
+
CMDColors
.
ENDC
)
@
staticmethod
def
info
(
message
):
print
(
get_frame_info
()
+
message
)
@
staticmethod
def
warning
(
message
):
print
(
CMDColors
.
YELLOW
+
'WARNING: '
+
get_frame_info
()
+
message
+
CMDColors
.
ENDC
)
@
staticmethod
def
error
(
message
):
print
(
CMDColors
.
RED
+
'ERROR: '
+
get_frame_info
()
+
message
+
CMDColors
.
ENDC
)
exit
(
1
)
def
mace_check
(
condition
,
message
):
if
not
condition
:
MaceLogger
.
error
(
message
)
################################
# file
################################
def
file_checksum
(
fname
):
hash_func
=
hashlib
.
sha256
()
with
open
(
fname
,
"rb"
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
hash_func
.
update
(
chunk
)
return
hash_func
.
hexdigest
()
def
download_or_get_file
(
file
,
sha256_checksum
,
output_dir
):
model_file
=
output_dir
+
"/"
+
sha256_checksum
+
".pb"
if
file
.
startswith
(
"http://"
)
or
file
.
startswith
(
"https://"
):
if
not
os
.
path
.
exists
(
model_file
)
or
file_checksum
(
model_file
)
!=
sha256_checksum
:
MaceLogger
.
info
(
"Downloading file %s, please wait ..."
%
file
)
urllib
.
urlretrieve
(
file
,
model_file
)
MaceLogger
.
info
(
"Model downloaded successfully."
)
else
:
device
.
execute
(
"cp %s %s"
%
(
file
,
model_file
))
return
model_file
tools/
experimental/utils
/__init__.py
→
tools/
python/visualize
/__init__.py
浏览文件 @
320b509c
文件已移动
mace/python/tools/visualization
/index.html
→
tools/python/visualize
/index.html
浏览文件 @
320b509c
文件已移动
mace/python/tools/visualization
/visualize_model.py
→
tools/python/visualize
/visualize_model.py
浏览文件 @
320b509c
import
os
import
json
import
json
import
numpy
as
np
import
numpy
as
np
...
@@ -19,8 +20,8 @@ class NPEncoder(json.JSONEncoder):
...
@@ -19,8 +20,8 @@ class NPEncoder(json.JSONEncoder):
class
ModelVisualizer
(
object
):
class
ModelVisualizer
(
object
):
def
__init__
(
self
,
model_name
,
proto
):
def
__init__
(
self
,
model_name
,
proto
,
output_dir
):
self
.
_output_file
=
"
build/%s_index.html"
%
model_name
self
.
_output_file
=
"
%s/%s_index.html"
%
(
output_dir
,
model_name
)
self
.
_proto
=
proto
self
.
_proto
=
proto
def
render_html
(
self
):
def
render_html
(
self
):
...
@@ -82,7 +83,8 @@ class ModelVisualizer(object):
...
@@ -82,7 +83,8 @@ class ModelVisualizer(object):
json_msg
=
json
.
dumps
(
json_obj
,
cls
=
NPEncoder
)
json_msg
=
json
.
dumps
(
json_obj
,
cls
=
NPEncoder
)
with
open
(
"mace/python/tools/visualization/index.html"
)
as
f
:
cwd
=
os
.
path
.
dirname
(
__file__
)
with
open
(
cwd
+
"/index.html"
)
as
f
:
html
=
f
.
read
()
html
=
f
.
read
()
return
html
%
json_msg
return
html
%
json_msg
...
...
tools/sh_commands.py
浏览文件 @
320b509c
...
@@ -368,7 +368,6 @@ def gen_mace_engine_factory_source(model_tags,
...
@@ -368,7 +368,6 @@ def gen_mace_engine_factory_source(model_tags,
sh
.
mkdir
(
"-p"
,
codegen_tools_dir
)
sh
.
mkdir
(
"-p"
,
codegen_tools_dir
)
gen_mace_engine_factory
(
gen_mace_engine_factory
(
model_tags
,
model_tags
,
"mace/python/tools"
,
embed_model_data
,
embed_model_data
,
codegen_tools_dir
)
codegen_tools_dir
)
six
.
print_
(
"Generate mace engine creator source done!
\n
"
)
six
.
print_
(
"Generate mace engine creator source done!
\n
"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录