Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
8f9ccc88
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8f9ccc88
编写于
3月 15, 2019
作者:
L
liyin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support scalar input
上级
0894c8e9
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
69 addition
and
34 deletion
+69
-34
mace/python/tools/converter.py
mace/python/tools/converter.py
+10
-3
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+40
-26
mace/tools/validation/mace_run.cc
mace/tools/validation/mace_run.cc
+0
-1
mace/utils/utils.cc
mace/utils/utils.cc
+7
-1
tools/common.py
tools/common.py
+7
-0
tools/generate_data.py
tools/generate_data.py
+1
-1
tools/validate.py
tools/validate.py
+4
-2
未找到文件。
mace/python/tools/converter.py
浏览文件 @
8f9ccc88
...
...
@@ -67,12 +67,19 @@ def file_checksum(fname):
return
hash_func
.
hexdigest
()
def
split_shape
(
shape
):
if
shape
.
strip
()
==
""
:
return
[]
else
:
return
shape
.
split
(
','
)
def
parse_int_array_from_str
(
ints_str
):
return
[
int
(
i
nt_str
)
for
int_str
in
ints_str
.
split
(
','
)]
return
[
int
(
i
)
for
i
in
split_shape
(
ints_str
)]
def
parse_float_array_from_str
(
in
ts_str
):
return
[
float
(
i
nt_str
)
for
int_str
in
in
ts_str
.
split
(
','
)]
def
parse_float_array_from_str
(
floa
ts_str
):
return
[
float
(
i
)
for
i
in
floa
ts_str
.
split
(
','
)]
def
transpose_shape
(
shape
,
dst_order
):
...
...
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
8f9ccc88
...
...
@@ -288,6 +288,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
tf_graph_def
.
ParseFromString
(
f
.
read
())
self
.
_placeholders
=
{}
self
.
_skip_tensor
=
set
()
self
.
_output_shape
=
{}
print
(
"Run transform_graph: %s"
%
TFTransformGraphOptions
[
option
.
device
])
...
...
@@ -316,10 +318,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
with
session
.
graph
.
as_default
()
as
graph
:
tf
.
import_graph_def
(
transformed_graph_def
,
name
=
''
)
self
.
_tf_graph
=
graph
self
.
update_output_shapes
(
session
)
self
.
_skip_tensor
=
set
()
self
.
_output_shape_list
=
[]
self
.
_output_shape_op_list
=
[]
# we have polluted graph with 'shape' ops, so reset it and reload it
# again
tf
.
reset_default_graph
()
with
tf
.
Session
()
as
session
:
with
session
.
graph
.
as_default
()
as
graph
:
tf
.
import_graph_def
(
transformed_graph_def
,
name
=
''
)
self
.
_tf_graph
=
graph
def
run
(
self
):
with
tf
.
Session
()
as
session
:
...
...
@@ -364,10 +372,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
return
tensor_name
[:
idx
]
def
update_output_shapes
(
self
,
sess
):
output_shapes
=
sess
.
run
(
self
.
_output_shape_op_list
,
tensors
=
[]
shape_tensors
=
[]
for
tf_op
in
self
.
_tf_graph
.
get_operations
():
for
output
in
tf_op
.
outputs
:
tensors
.
append
(
output
.
name
)
shape_tensors
.
append
(
tf
.
shape
(
output
))
tensor_shapes
=
sess
.
run
(
shape_tensors
,
feed_dict
=
self
.
_placeholders
)
for
i
in
range
(
len
(
self
.
_output_shape_list
)):
self
.
_output_shape
_list
[
i
].
dims
.
extend
(
output_shapes
[
i
])
for
i
in
range
(
len
(
tensors
)):
self
.
_output_shape
[
tensors
[
i
]]
=
tensor_shapes
[
i
]
def
convert_ops
(
self
,
sess
):
for
tf_op
in
self
.
_tf_graph
.
get_operations
():
...
...
@@ -375,7 +390,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
"Mace does not support tensorflow op type %s yet"
%
tf_op
.
type
)
self
.
_op_converters
[
tf_op
.
type
](
tf_op
)
self
.
update_output_shapes
(
sess
)
self
.
convert_tensors
()
def
convert_tensors
(
self
):
...
...
@@ -409,18 +424,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
# this function tries to infer tensor shape, but some dimension shape
# may be undefined due to variance of input length
def
infer_tensor_shape
(
self
,
output_shape
,
tensor
):
inferred_tensor_shape
=
tensor
.
shape
.
as_list
()
inferred_success
=
True
for
_
,
dim
in
enumerate
(
inferred_tensor_shape
):
if
dim
is
None
:
inferred_success
=
False
break
if
inferred_success
:
output_shape
.
dims
.
extend
(
inferred_tensor_shape
)
def
infer_tensor_shape
(
self
,
tensor
,
output_shape
=
None
):
shape
=
None
if
tensor
.
name
in
self
.
_output_shape
:
shape
=
self
.
_output_shape
[
tensor
.
name
]
else
:
self
.
_output_shape_list
.
append
(
output_shape
)
self
.
_output_shape_op_list
.
append
(
tf
.
shape
(
tensor
))
shape
=
tensor
.
shape
.
as_list
()
if
output_shape
:
output_shape
.
dims
.
extend
(
shape
)
return
shape
def
convert_nop
(
self
,
tf_op
):
pass
...
...
@@ -433,7 +447,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
op
.
output
.
extend
([
tf_output
.
name
for
tf_output
in
tf_op
.
outputs
])
for
tf_output
in
tf_op
.
outputs
:
output_shape
=
op
.
output_shape
.
add
()
self
.
infer_tensor_shape
(
output_shape
,
tf_output
)
self
.
infer_tensor_shape
(
tf_output
,
output_shape
)
data_type_arg
=
op
.
arg
.
add
()
data_type_arg
.
name
=
'T'
...
...
@@ -516,10 +530,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
def
check_is_scalar
(
tf_op
):
if
len
(
tf_op
.
inputs
)
==
1
:
return
len
(
tf_op
.
inputs
[
0
].
shape
)
==
0
return
len
(
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
0
])
)
==
0
elif
len
(
tf_op
.
inputs
)
==
2
:
return
len
(
tf_op
.
inputs
[
0
].
shape
)
==
0
and
\
len
(
tf_op
.
inputs
[
1
].
shape
)
==
0
return
len
(
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
0
])
)
==
0
and
\
len
(
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
1
])
)
==
0
if
check_is_scalar
(
tf_op
):
op
.
type
=
MaceOp
.
ScalarMath
.
name
...
...
@@ -546,9 +560,9 @@ class TensorflowConverter(base_converter.ConverterInterface):
EltwiseType
.
SUM
,
EltwiseType
.
PROD
,
EltwiseType
.
MAX
,
EltwiseType
.
MIN
]
if
len
(
tf_op
.
inputs
)
>
1
and
\
len
(
tf_op
.
inputs
[
1
].
shape
)
==
0
and
\
tf_op
.
inputs
[
1
].
op
.
type
==
TFOpType
.
Const
.
name
:
if
(
len
(
tf_op
.
inputs
)
>
1
and
len
(
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
1
]))
==
0
and
tf_op
.
inputs
[
1
].
op
.
type
==
TFOpType
.
Const
.
name
)
:
scalar
=
tf_op
.
inputs
[
1
].
eval
().
astype
(
np
.
float32
)
value_arg
=
op
.
arg
.
add
()
value_arg
.
name
=
MaceKeyword
.
mace_scalar_input_str
...
...
@@ -560,7 +574,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
value_index_arg
.
i
=
1
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
1
].
name
)
del
op
.
input
[
1
]
elif
len
(
tf_op
.
inputs
[
0
].
shape
)
==
0
and
\
elif
len
(
self
.
infer_tensor_shape
(
tf_op
.
inputs
[
0
])
)
==
0
and
\
tf_op
.
inputs
[
0
].
op
.
type
==
TFOpType
.
Const
.
name
and
\
is_commutative
(
type_arg
.
i
):
scalar
=
tf_op
.
inputs
[
0
].
eval
().
astype
(
np
.
float32
)
...
...
mace/tools/validation/mace_run.cc
浏览文件 @
8f9ccc88
...
...
@@ -274,7 +274,6 @@ bool RunModel(const std::string &model_name,
MemoryMap
(
FLAGS_model_data_file
,
&
model_weights_data
,
&
model_weights_data_size
);
MACE_CHECK
(
model_weights_data
!=
nullptr
&&
model_weights_data_size
!=
0
);
}
std
::
shared_ptr
<
mace
::
MaceEngine
>
engine
;
...
...
mace/utils/utils.cc
浏览文件 @
8f9ccc88
...
...
@@ -122,6 +122,9 @@ void MemoryMap(const std::string &file,
struct
stat
st
;
fstat
(
fd
,
&
st
);
*
size
=
static_cast
<
size_t
>
(
st
.
st_size
);
if
(
*
size
==
0
)
{
return
;
}
*
data
=
static_cast
<
const
unsigned
char
*>
(
mmap
(
nullptr
,
*
size
,
PROT_READ
,
MAP_PRIVATE
,
fd
,
0
));
...
...
@@ -135,7 +138,10 @@ void MemoryMap(const std::string &file,
void
MemoryUnMap
(
const
unsigned
char
*
data
,
const
size_t
&
size
)
{
MACE_CHECK
(
data
!=
nullptr
&&
size
>
0
,
"data is null or size is 0"
);
if
(
size
==
0
)
{
return
;
}
MACE_CHECK
(
data
!=
nullptr
,
"data is null"
);
int
ret
=
munmap
(
const_cast
<
unsigned
char
*>
(
data
),
size
);
...
...
tools/common.py
浏览文件 @
8f9ccc88
...
...
@@ -531,3 +531,10 @@ class ToolchainType:
class
TargetSOCTag
:
all
=
'all'
random
=
'random'
def
split_shape
(
shape
):
if
shape
.
strip
()
==
""
:
return
[]
else
:
return
shape
.
split
(
','
)
tools/generate_data.py
浏览文件 @
8f9ccc88
...
...
@@ -59,7 +59,7 @@ def generate_input_data(input_file, input_node, input_shape, input_ranges,
assert
len
(
input_names
)
==
len
(
input_shapes
)
==
len
(
input_ranges
)
==
len
(
input_data_types
)
# noqa
for
i
in
range
(
len
(
input_names
)):
shape
=
[
int
(
x
)
for
x
in
input_shapes
[
i
].
split
(
','
)]
shape
=
[
int
(
x
)
for
x
in
common
.
split_shape
(
input_shapes
[
i
]
)]
input_range
=
[
float
(
x
)
for
x
in
input_ranges
[
i
].
split
(
','
)]
generate_data
(
input_names
[
i
],
shape
,
input_file
,
input_range
,
input_data_types
[
i
])
...
...
tools/validate.py
浏览文件 @
8f9ccc88
...
...
@@ -68,6 +68,8 @@ def calculate_similarity(u, v, data_type=np.float64):
def
calculate_pixel_accuracy
(
out_value
,
mace_out_value
):
if
len
(
out_value
.
shape
)
<
2
:
return
1.0
out_value
=
out_value
.
reshape
((
-
1
,
out_value
.
shape
[
-
1
]))
batches
=
out_value
.
shape
[
0
]
classes
=
out_value
.
shape
[
1
]
...
...
@@ -323,10 +325,10 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
validation_outputs_data
,
log_file
):
input_names
=
[
name
for
name
in
input_node
.
split
(
','
)]
input_shape_strs
=
[
shape
for
shape
in
input_shape
.
split
(
':'
)]
input_shapes
=
[[
int
(
x
)
for
x
in
shape
.
split
(
','
)]
input_shapes
=
[[
int
(
x
)
for
x
in
common
.
split_shape
(
shape
)]
for
shape
in
input_shape_strs
]
output_shape_strs
=
[
shape
for
shape
in
output_shape
.
split
(
':'
)]
output_shapes
=
[[
int
(
x
)
for
x
in
shape
.
split
(
','
)]
output_shapes
=
[[
int
(
x
)
for
x
in
common
.
split_shape
(
shape
)]
for
shape
in
output_shape_strs
]
input_data_formats
=
[
df
for
df
in
input_data_format_str
.
split
(
','
)]
output_data_formats
=
[
df
for
df
in
output_data_format_str
.
split
(
','
)]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录