Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
a28e8128
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
a28e8128
编写于
6月 29, 2020
作者:
Y
yzchenmonkey
提交者:
GitHub
6月 29, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add MegEngine converter for MACE (#658)
上级
43b96415
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
833 addition
and
10 deletion
+833
-10
mace/ops/strided_slice.cc
mace/ops/strided_slice.cc
+2
-2
tools/converter.py
tools/converter.py
+1
-0
tools/device.py
tools/device.py
+4
-4
tools/python/convert.py
tools/python/convert.py
+4
-0
tools/python/run_model.py
tools/python/run_model.py
+1
-1
tools/python/transform/base_converter.py
tools/python/transform/base_converter.py
+1
-1
tools/python/transform/megengine_converter.py
tools/python/transform/megengine_converter.py
+696
-0
tools/python/utils/config_parser.py
tools/python/utils/config_parser.py
+1
-0
tools/python/validate.py
tools/python/validate.py
+54
-0
tools/sh_commands.py
tools/sh_commands.py
+14
-2
tools/validate.py
tools/validate.py
+55
-0
未找到文件。
mace/ops/strided_slice.cc
浏览文件 @
a28e8128
...
...
@@ -127,9 +127,9 @@ class StridedSliceOp : public Operation {
strides_data
,
strides_data
+
strides
->
size
());
MACE_CHECK
(
input
->
size
()
>
0
&&
input
->
dim_size
()
>
0
&&
input
->
dim_size
()
<=
4
,
input
->
dim_size
()
<=
5
,
// for megengine is 5, the others are 4
"The input size should larger than 0."
" And input dims should be an integer in (0,
4
]."
);
" And input dims should be an integer in (0,
5
]."
);
std
::
vector
<
index_t
>
output_shape
=
{};
...
...
tools/converter.py
浏览文件 @
a28e8128
...
...
@@ -60,6 +60,7 @@ PlatformTypeStrs = [
"tensorflow"
,
"caffe"
,
"onnx"
,
"megengine"
,
]
PlatformType
=
Enum
(
'PlatformType'
,
[(
ele
,
ele
)
for
ele
in
PlatformTypeStrs
],
type
=
str
)
...
...
tools/device.py
浏览文件 @
a28e8128
...
...
@@ -220,8 +220,8 @@ class DeviceWrapper:
"MACE_LOG_TENSOR_RANGE=%d"
%
(
1
if
quantize_stat
else
0
),
"%s/%s"
%
(
target_dir
,
target_name
),
"--model_name=%s"
%
model_tag
,
"--input_node=
%s
"
%
","
.
join
(
input_nodes
),
"--output_node=
%s
"
%
","
.
join
(
output_nodes
),
"--input_node=
'%s'
"
%
","
.
join
(
input_nodes
),
"--output_node=
'%s'
"
%
","
.
join
(
output_nodes
),
"--input_shape=%s"
%
":"
.
join
(
input_shapes
),
"--output_shape=%s"
%
":"
.
join
(
output_shapes
),
"--input_data_format=%s"
%
","
.
join
(
input_data_formats
),
...
...
@@ -322,8 +322,8 @@ class DeviceWrapper:
cmd
.
extend
([
"%s/%s"
%
(
self
.
data_dir
,
target_name
),
"--model_name=%s"
%
model_tag
,
"--input_node=
%s
"
%
","
.
join
(
input_nodes
),
"--output_node=
%s
"
%
","
.
join
(
output_nodes
),
"--input_node=
'%s'
"
%
","
.
join
(
input_nodes
),
"--output_node=
'%s'
"
%
","
.
join
(
output_nodes
),
"--input_shape=%s"
%
":"
.
join
(
input_shapes
),
"--output_shape=%s"
%
":"
.
join
(
output_shapes
),
"--input_data_format=%s"
%
","
.
join
(
input_data_formats
),
...
...
tools/python/convert.py
浏览文件 @
a28e8128
...
...
@@ -184,6 +184,10 @@ def convert_model(conf, quantize_stat):
from
transform
import
onnx_converter
converter
=
onnx_converter
.
OnnxConverter
(
option
,
conf
[
"model_file_path"
])
elif
platform
==
Platform
.
MEGENGINE
:
from
transform
import
megengine_converter
converter
=
megengine_converter
.
MegengineConverter
(
option
,
conf
[
"model_file_path"
])
else
:
mace_check
(
False
,
"Mace do not support platorm %s yet."
%
platform
)
...
...
tools/python/run_model.py
浏览文件 @
a28e8128
...
...
@@ -145,7 +145,7 @@ def run_model_for_device(flags, args, dev, model_name, model_conf):
"device"
:
runtime
.
name
}
opts
=
[
"--%s=
%s
"
%
(
arg_key
,
arg_val
)
for
arg_key
,
arg_val
in
opts
=
[
"--%s=
'%s'
"
%
(
arg_key
,
arg_val
)
for
arg_key
,
arg_val
in
model_args
.
items
()]
+
args
should_generate_data
=
(
flags
.
validate
or
flags
.
tune
or
"--benchmark"
in
opts
)
...
...
tools/python/transform/base_converter.py
浏览文件 @
a28e8128
...
...
@@ -86,6 +86,7 @@ class FrameworkType(Enum):
TENSORFLOW
=
0
CAFFE
=
1
ONNX
=
2
MEGENGINE
=
3
MaceSupportedOps
=
[
...
...
@@ -547,7 +548,6 @@ class ConverterOption(object):
# Model structure related transformation
TransformerRule
.
REMOVE_USELESS_OP
,
TransformerRule
.
TRANSFORM_FAKE_QUANTIZE
,
TransformerRule
.
REMOVE_USELESS_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
TransformerRule
.
TRANSFORM_LSTMCELL_ZEROSTATE
,
TransformerRule
.
TRANSFORM_BASIC_LSTMCELL
,
...
...
tools/python/transform/megengine_converter.py
0 → 100644
浏览文件 @
a28e8128
此差异已折叠。
点击以展开。
tools/python/utils/config_parser.py
浏览文件 @
a28e8128
...
...
@@ -149,6 +149,7 @@ class Platform(Enum):
TENSORFLOW
=
0
CAFFE
=
1
ONNX
=
2
MEGENGINE
=
3
def
parse_platform
(
str
):
...
...
tools/python/validate.py
浏览文件 @
a28e8128
...
...
@@ -318,6 +318,51 @@ def validate_onnx_model(model_file,
mace_out_value
,
value
,
validation_threshold
,
log_file
)
def
validate_megengine_model
(
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
):
import
megengine._internal
as
mgb
if
not
os
.
path
.
isfile
(
model_file
):
common
.
MaceLogger
.
error
(
VALIDATION_MODULE
,
"Input graph file '"
+
model_file
+
"' does not exist!"
,
)
feed_inputs
=
[]
for
i
in
range
(
len
(
input_names
)):
input_value
=
load_data
(
util
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
(
input_data_formats
[
i
]
==
DataFormat
.
NHWC
and
\
len
(
input_shapes
[
i
])
==
4
):
input_value
=
input_value
.
transpose
((
0
,
3
,
1
,
2
))
feed_inputs
.
append
(
input_value
)
cg
,
_
,
outputs
=
mgb
.
load_comp_graph_from_file
(
model_file
)
inputs
=
mgb
.
cgtools
.
get_dep_vars
(
outputs
,
"Host2DeviceCopy"
)
inputs
=
sorted
(
inputs
,
key
=
lambda
i
:
i
.
name
)
outputs
=
list
(
map
(
mgb
.
copy_output
,
outputs
))
if
len
(
outputs
)
==
1
:
(
outputs
,)
=
outputs
func
=
cg
.
compile
(
inputs
,
outputs
)
mge_output_value
=
func
(
*
feed_inputs
)
for
i
in
range
(
len
(
output_names
)):
output_file_name
=
\
util
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
if
(
output_data_formats
[
i
]
==
DataFormat
.
NHWC
and
\
len
(
output_shapes
[
i
])
==
4
):
mace_out_value
=
\
mace_out_value
.
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
3
,
1
,
2
))
compare_output
(
output_names
[
i
],
mace_out_value
,
mge_output_value
,
validation_threshold
,
log_file
)
def
validate
(
platform
,
model_file
,
weight_file
,
input_file
,
mace_out_file
,
input_shape
,
output_shape
,
input_data_format
,
...
...
@@ -354,3 +399,12 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_node
,
output_shape
,
output_data_format
,
validation_threshold
,
input_data_type
,
backend
,
log_file
)
elif
platform
==
Platform
.
MEGENGINE
:
validate_megengine_model
(
model_file
,
input_file
,
mace_out_file
,
input_node
,
input_shape
,
input_data_format
,
output_node
,
output_shape
,
output_data_format
,
validation_threshold
,
input_data_type
,
log_file
)
tools/sh_commands.py
浏览文件 @
a28e8128
...
...
@@ -748,8 +748,8 @@ def validate_model(abi,
"--input_file=/mace/%s"
%
input_file_name
,
"--mace_out_file=/mace/%s"
%
output_file_name
,
"--device_type=%s"
%
device_type
,
"--input_node=
%s
"
%
","
.
join
(
input_nodes
),
"--output_node=
%s
"
%
","
.
join
(
output_nodes
),
"--input_node=
'%s'
"
%
","
.
join
(
input_nodes
),
"--output_node=
'%s'
"
%
","
.
join
(
output_nodes
),
"--input_shape=%s"
%
":"
.
join
(
input_shapes
),
"--output_shape=%s"
%
":"
.
join
(
output_shapes
),
"--input_data_format=%s"
%
","
.
join
(
input_data_formats
),
...
...
@@ -761,6 +761,18 @@ def validate_model(abi,
validation_outputs_data
),
"--log_file=%s"
%
log_file
,
_fg
=
True
)
elif
platform
==
"megengine"
:
validate
(
platform
,
model_file_path
,
""
,
"%s/%s"
%
(
model_output_dir
,
input_file_name
),
"%s/%s"
%
(
model_output_dir
,
output_file_name
),
device_type
,
":"
.
join
(
input_shapes
),
":"
.
join
(
output_shapes
),
","
.
join
(
input_data_formats
),
","
.
join
(
output_data_formats
),
","
.
join
(
input_nodes
),
","
.
join
(
output_nodes
),
validation_threshold
,
","
.
join
(
input_data_types
),
backend
,
validation_outputs_data
,
log_file
)
six
.
print_
(
"Validation done!
\n
"
)
...
...
tools/validate.py
浏览文件 @
a28e8128
...
...
@@ -331,6 +331,52 @@ def validate_onnx_model(platform, device_type, model_file,
validation_threshold
,
log_file
)
def
validate_megengine_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
):
import
megengine._internal
as
mgb
if
not
os
.
path
.
isfile
(
model_file
):
common
.
MaceLogger
.
error
(
VALIDATION_MODULE
,
"Input graph file '"
+
model_file
+
"' does not exist!"
,
)
feed_inputs
=
[]
for
i
in
range
(
len
(
input_names
)):
input_value
=
load_data
(
common
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
(
input_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
\
len
(
input_shapes
[
i
])
==
4
):
input_value
=
input_value
.
transpose
((
0
,
3
,
1
,
2
))
feed_inputs
.
append
(
input_value
)
cg
,
_
,
outputs
=
mgb
.
load_comp_graph_from_file
(
model_file
)
inputs
=
mgb
.
cgtools
.
get_dep_vars
(
outputs
,
"Host2DeviceCopy"
)
inputs
=
sorted
(
inputs
,
key
=
lambda
i
:
i
.
name
)
outputs
=
list
(
map
(
mgb
.
copy_output
,
outputs
))
if
len
(
outputs
)
==
1
:
(
outputs
,)
=
outputs
func
=
cg
.
compile
(
inputs
,
outputs
)
mge_output_value
=
func
(
*
feed_inputs
)
for
i
in
range
(
len
(
output_names
)):
output_file_name
=
\
common
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
if
(
output_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
\
len
(
output_shapes
[
i
])
==
4
):
mace_out_value
=
\
mace_out_value
.
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
3
,
1
,
2
))
compare_output
(
platform
,
device_type
,
output_names
[
i
],
mace_out_value
,
mge_output_value
,
validation_threshold
,
log_file
)
def
validate
(
platform
,
model_file
,
weight_file
,
input_file
,
mace_out_file
,
device_type
,
input_shape
,
output_shape
,
input_data_format_str
,
output_data_format_str
,
input_node
,
output_node
,
...
...
@@ -385,6 +431,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
backend
,
log_file
)
elif
platform
==
'megengine'
:
validate_megengine_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
)
def
parse_args
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录