Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
毕竟曾有刹那
Mace
提交
e2e0a7f3
Mace
项目概览
毕竟曾有刹那
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e2e0a7f3
编写于
8月 16, 2018
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support input with int32
上级
92f18fc6
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
83 addition
and
26 deletion
+83
-26
docs/user_guide/advanced_usage.rst
docs/user_guide/advanced_usage.rst
+2
-0
mace/kernels/strided_slice.h
mace/kernels/strided_slice.h
+2
-1
tools/converter.py
tools/converter.py
+30
-3
tools/generate_data.py
tools/generate_data.py
+20
-11
tools/sh_commands.py
tools/sh_commands.py
+6
-2
tools/validate.py
tools/validate.py
+23
-9
未找到文件。
docs/user_guide/advanced_usage.rst
浏览文件 @
e2e0a7f3
...
...
@@ -82,6 +82,8 @@ in one deployment file.
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP.
* - input_data_types
- [optional] The input data type for specific op(eg. gather), which can be [int32, float32], default to float32.
* - limit_opencl_kernel_time
- [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0.
* - obfuscate
...
...
mace/kernels/strided_slice.h
浏览文件 @
e2e0a7f3
...
...
@@ -194,7 +194,8 @@ struct StridedSliceFunctor {
strides_data
[
2
]
>
0
?
k
<
real_end_indices
[
2
]
:
k
>
real_end_indices
[
2
];
k
+=
strides_data
[
2
])
{
*
output_data
++
=
input_data
[(
i
*
input
->
dim
(
1
)
+
j
)
*
input
->
dim
(
2
)
+
k
];
*
output_data
++
=
input_data
[(
i
*
input
->
dim
(
1
)
+
j
)
*
input
->
dim
(
2
)
+
k
];
}
}
}
...
...
tools/converter.py
浏览文件 @
e2e0a7f3
...
...
@@ -130,6 +130,16 @@ class RuntimeType(object):
cpu_gpu
=
'cpu+gpu'
InputDataTypeStrs
=
[
"int32"
,
"float32"
,
]
InputDataType
=
Enum
(
'InputDataType'
,
[(
ele
,
ele
)
for
ele
in
InputDataTypeStrs
],
type
=
str
)
CPUDataTypeStrs
=
[
"fp32"
,
]
...
...
@@ -183,6 +193,7 @@ class YAMLKeyword(object):
output_shapes
=
'output_shapes'
runtime
=
'runtime'
data_type
=
'data_type'
input_data_types
=
'input_data_types'
limit_opencl_kernel_time
=
'limit_opencl_kernel_time'
nnlib_graph_mode
=
'nnlib_graph_mode'
obfuscate
=
'obfuscate'
...
...
@@ -447,6 +458,18 @@ def format_model_config(flags):
if
not
isinstance
(
value
,
list
):
subgraph
[
key
]
=
[
value
]
input_data_types
=
subgraph
.
get
(
YAMLKeyword
.
input_data_types
,
""
)
if
input_data_types
:
if
not
isinstance
(
input_data_types
,
list
):
subgraph
[
YAMLKeyword
.
input_data_types
]
=
[
input_data_types
]
for
input_data_type
in
input_data_types
:
mace_check
(
input_data_type
in
InputDataTypeStrs
,
ModuleName
.
YAML_CONFIG
,
"'input_data_types' must be in "
+
str
(
InputDataTypeStrs
))
else
:
subgraph
[
YAMLKeyword
.
input_data_types
]
=
[]
validation_threshold
=
subgraph
.
get
(
YAMLKeyword
.
validation_threshold
,
{})
if
not
isinstance
(
validation_threshold
,
dict
):
...
...
@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config,
subgraphs
[
0
][
YAMLKeyword
.
input_tensors
],
subgraphs
[
0
][
YAMLKeyword
.
input_shapes
],
subgraphs
[
0
][
YAMLKeyword
.
validation_inputs_data
],
input_ranges
=
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
])
input_ranges
=
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
],
input_data_types
=
subgraphs
[
0
][
YAMLKeyword
.
input_data_types
])
sh_commands
.
tuning_run
(
abi
=
target_abi
,
...
...
@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi,
subgraphs
[
0
][
YAMLKeyword
.
input_tensors
],
subgraphs
[
0
][
YAMLKeyword
.
input_shapes
],
subgraphs
[
0
][
YAMLKeyword
.
validation_inputs_data
],
input_ranges
=
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
])
input_ranges
=
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
],
input_data_types
=
subgraphs
[
0
][
YAMLKeyword
.
input_data_types
])
runtime_list
=
[]
if
target_abi
==
ABIType
.
host
:
...
...
@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi,
output_shapes
=
subgraphs
[
0
][
YAMLKeyword
.
output_shapes
],
model_output_dir
=
model_output_dir
,
phone_data_dir
=
PHONE_DATA_DIR
,
input_data_types
=
subgraphs
[
0
][
YAMLKeyword
.
input_data_types
],
# noqa
caffe_env
=
flags
.
caffe_env
,
validation_threshold
=
subgraphs
[
0
][
YAMLKeyword
.
validation_threshold
][
device_type
])
# noqa
if
flags
.
report
and
flags
.
round
>
0
:
...
...
@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
subgraphs
[
0
][
YAMLKeyword
.
input_tensors
],
subgraphs
[
0
][
YAMLKeyword
.
input_shapes
],
subgraphs
[
0
][
YAMLKeyword
.
validation_inputs_data
],
input_ranges
=
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
])
input_ranges
=
subgraphs
[
0
][
YAMLKeyword
.
input_ranges
],
input_data_types
=
subgraphs
[
0
][
YAMLKeyword
.
input_data_types
])
runtime_list
=
[]
if
target_abi
==
ABIType
.
host
:
runtime_list
.
extend
([
RuntimeType
.
cpu
])
...
...
tools/generate_data.py
浏览文件 @
e2e0a7f3
...
...
@@ -27,30 +27,37 @@ import common
# --input_ranges -1,1
def
generate_data
(
name
,
shape
,
input_file
,
tensor_range
):
def
generate_data
(
name
,
shape
,
input_file
,
tensor_range
,
input_data_type
):
np
.
random
.
seed
()
data
=
np
.
random
.
random
(
shape
)
*
(
tensor_range
[
1
]
-
tensor_range
[
0
])
\
+
tensor_range
[
0
]
input_file_name
=
common
.
formatted_file_name
(
input_file
,
name
)
print
'Generate input file: '
,
input_file_name
data
.
astype
(
np
.
float32
).
tofile
(
input_file_name
)
if
input_data_type
==
'float32'
:
np_data_type
=
np
.
float32
elif
input_data_type
==
'int32'
:
np_data_type
=
np
.
int32
data
.
astype
(
np_data_type
).
tofile
(
input_file_name
)
def
generate_input_data
(
input_file
,
input_node
,
input_shape
,
input_ranges
):
def
generate_input_data
(
input_file
,
input_node
,
input_shape
,
input_ranges
,
input_data_type
):
input_names
=
[
name
for
name
in
input_node
.
split
(
','
)]
input_shapes
=
[
shape
for
shape
in
input_shape
.
split
(
':'
)]
if
input_ranges
:
input_ranges
=
[
r
for
r
in
input_ranges
.
split
(
':'
)]
else
:
input_ranges
=
None
assert
len
(
input_names
)
==
len
(
input_shapes
)
input_ranges
=
[[
-
1
,
1
]]
*
len
(
input_names
)
if
input_data_type
:
input_data_types
=
[
data_type
for
data_type
in
input_data_type
.
split
(
','
)]
else
:
input_data_types
=
[
'float32'
]
*
len
(
input_names
)
assert
len
(
input_names
)
==
len
(
input_shapes
)
==
len
(
input_ranges
)
==
len
(
input_data_types
)
# noqa
for
i
in
range
(
len
(
input_names
)):
shape
=
[
int
(
x
)
for
x
in
input_shapes
[
i
].
split
(
','
)]
if
input_ranges
:
input_range
=
[
float
(
x
)
for
x
in
input_ranges
[
i
].
split
(
','
)]
else
:
input_range
=
[
-
1
,
1
]
generate_data
(
input_names
[
i
],
shape
,
input_file
,
input_range
)
generate_data
(
input_names
[
i
],
shape
,
input_file
,
input_ranges
[
i
],
input_data_types
[
i
])
print
"Generate input file done."
...
...
@@ -66,6 +73,8 @@ def parse_args():
"--input_shape"
,
type
=
str
,
default
=
"1,64,64,3"
,
help
=
"input shape."
)
parser
.
add_argument
(
"--input_ranges"
,
type
=
str
,
default
=
"-1,1"
,
help
=
"input range."
)
parser
.
add_argument
(
"--input_data_type"
,
type
=
str
,
default
=
""
,
help
=
"input range."
)
return
parser
.
parse_known_args
()
...
...
@@ -73,4 +82,4 @@ def parse_args():
if
__name__
==
'__main__'
:
FLAGS
,
unparsed
=
parse_args
()
generate_input_data
(
FLAGS
.
input_file
,
FLAGS
.
input_node
,
FLAGS
.
input_shape
,
FLAGS
.
input_ranges
)
FLAGS
.
input_ranges
,
FLAGS
.
input_data_type
)
tools/sh_commands.py
浏览文件 @
e2e0a7f3
...
...
@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir,
input_shapes
,
input_files
,
input_ranges
,
input_data_types
,
input_file_name
=
"model_input"
):
for
input_name
in
input_nodes
:
formatted_name
=
common
.
formatted_file_name
(
...
...
@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir,
input_nodes_str
=
","
.
join
(
input_nodes
)
input_shapes_str
=
":"
.
join
(
input_shapes
)
input_ranges_str
=
":"
.
join
(
input_ranges
)
input_data_types_str
=
","
.
join
(
input_data_types
)
generate_input_data
(
"%s/%s"
%
(
model_output_dir
,
input_file_name
),
input_nodes_str
,
input_shapes_str
,
input_ranges_str
)
input_ranges_str
,
input_data_types_str
)
input_file_list
=
[]
if
isinstance
(
input_files
,
list
):
...
...
@@ -800,6 +803,7 @@ def validate_model(abi,
output_shapes
,
model_output_dir
,
phone_data_dir
,
input_data_types
,
caffe_env
,
input_file_name
=
"model_input"
,
output_file_name
=
"model_out"
,
...
...
@@ -821,7 +825,7 @@ def validate_model(abi,
"%s/%s"
%
(
model_output_dir
,
output_file_name
),
device_type
,
":"
.
join
(
input_shapes
),
":"
.
join
(
output_shapes
),
","
.
join
(
input_nodes
),
","
.
join
(
output_nodes
),
validation_threshold
)
validation_threshold
,
","
.
join
(
input_data_types
)
)
elif
platform
==
"caffe"
:
image_name
=
"mace-caffe:latest"
container_name
=
"mace_caffe_validator"
...
...
tools/validate.py
浏览文件 @
e2e0a7f3
...
...
@@ -40,10 +40,12 @@ import common
VALIDATION_MODULE
=
'VALIDATION'
def
load_data
(
file
):
def
load_data
(
file
,
data_type
=
'float32'
):
if
os
.
path
.
isfile
(
file
):
if
data_type
==
'float32'
:
return
np
.
fromfile
(
file
=
file
,
dtype
=
np
.
float32
)
else
:
elif
data_type
==
'int32'
:
return
np
.
fromfile
(
file
=
file
,
dtype
=
np
.
int32
)
return
np
.
empty
([
0
])
...
...
@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name):
def
validate_tf_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
output_names
,
validation_threshold
):
output_names
,
validation_threshold
,
input_data_types
):
import
tensorflow
as
tf
if
not
os
.
path
.
isfile
(
model_file
):
common
.
MaceLogger
.
error
(
...
...
@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file,
input_dict
=
{}
for
i
in
range
(
len
(
input_names
)):
input_value
=
load_data
(
common
.
formatted_file_name
(
input_file
,
input_names
[
i
]))
common
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
input_node
=
graph
.
get_tensor_by_name
(
normalize_tf_tensor_name
(
input_names
[
i
]))
...
...
@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
def
validate
(
platform
,
model_file
,
weight_file
,
input_file
,
mace_out_file
,
device_type
,
input_shape
,
output_shape
,
input_node
,
output_node
,
validation_threshold
):
validation_threshold
,
input_data_type
):
input_names
=
[
name
for
name
in
input_node
.
split
(
','
)]
input_shape_strs
=
[
shape
for
shape
in
input_shape
.
split
(
':'
)]
input_shapes
=
[[
int
(
x
)
for
x
in
shape
.
split
(
','
)]
for
shape
in
input_shape_strs
]
if
input_data_type
:
input_data_types
=
[
data_type
for
data_type
in
input_data_type
.
split
(
','
)]
else
:
input_data_types
=
[
'float32'
]
*
len
(
input_names
)
output_names
=
[
name
for
name
in
output_node
.
split
(
','
)]
assert
len
(
input_names
)
==
len
(
input_shapes
)
if
platform
==
'tensorflow'
:
validate_tf_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
output_names
,
validation_threshold
)
output_names
,
validation_threshold
,
input_data_types
)
elif
platform
==
'caffe'
:
output_shape_strs
=
[
shape
for
shape
in
output_shape
.
split
(
':'
)]
output_shapes
=
[[
int
(
x
)
for
x
in
shape
.
split
(
','
)]
...
...
@@ -220,6 +228,11 @@ def parse_args():
"--output_shape"
,
type
=
str
,
default
=
"1,64,64,2"
,
help
=
"output shape."
)
parser
.
add_argument
(
"--input_node"
,
type
=
str
,
default
=
"input_node"
,
help
=
"input node"
)
parser
.
add_argument
(
"--input_data_type"
,
type
=
str
,
default
=
""
,
help
=
"input data type"
)
parser
.
add_argument
(
"--output_node"
,
type
=
str
,
default
=
"output_node"
,
help
=
"output node"
)
parser
.
add_argument
(
...
...
@@ -241,4 +254,5 @@ if __name__ == '__main__':
FLAGS
.
output_shape
,
FLAGS
.
input_node
,
FLAGS
.
output_node
,
FLAGS
.
validation_threshold
)
FLAGS
.
validation_threshold
,
FLAGS
.
input_data_type
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录