Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
6073bca5
Mace
项目概览
Xiaomi
/
Mace
通知
107
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6073bca5
编写于
8月 14, 2019
作者:
李
李滨
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'summary' into 'master'
Fix check tensors func See merge request !1173
上级
1fc9f7ed
aadc9ded
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
27 addition
and
30 deletion
+27
-30
.gitlab-ci.yml
.gitlab-ci.yml
+0
-2
tools/python/convert.py
tools/python/convert.py
+20
-9
tools/python/transform/transformer.py
tools/python/transform/transformer.py
+0
-13
tools/python/utils/util.py
tools/python/utils/util.py
+7
-6
未找到文件。
.gitlab-ci.yml
浏览文件 @
6073bca5
...
...
@@ -140,8 +140,6 @@ quantization_tests:
python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1;
done
-
rm -rf mace-models
only
:
-
triggers
dynamic_linking_test
:
stage
:
extra
...
...
tools/python/convert.py
浏览文件 @
6073bca5
...
...
@@ -117,6 +117,12 @@ def convert(conf, output):
model_conf
[
"weight_file_path"
],
model_conf
[
"weight_sha256_checksum"
],
model_output
)
model_conf
[
"weight_file_path"
]
=
weight_file
# TODO: remove the following after quantize tool is made
if
"quantize_range_file"
in
model_conf
:
range_file
=
util
.
download_or_get_file
(
model_conf
[
"quantize_range_file"
],
""
,
model_output
)
model_conf
[
"quantize_range_file"
]
=
range_file
mace_model
=
convert_model
(
model_conf
)
...
...
@@ -215,16 +221,16 @@ def convert_model(conf):
output_node
.
data_format
=
cvt
.
DataFormat
.
NHWC
option
.
add_output_node
(
output_node
)
if
"check_
node
"
in
conf
:
check_
node_names
=
to_list
(
conf
[
"check_node
"
])
check_
node
_shapes
=
[
parse_int_array_from_str
(
shape
)
for
shape
in
to_list
(
conf
[
"check_shape
"
])]
mace_check
(
len
(
check_
node_names
)
==
len
(
check_node
_shapes
),
"check
node
count and shape count do not match."
)
for
i
in
range
(
len
(
check_
node_name
s
)):
if
"check_
tensors
"
in
conf
:
check_
tensors
=
to_list
(
conf
[
"check_tensors
"
])
check_
tensors
_shapes
=
[
parse_int_array_from_str
(
shape
)
for
shape
in
to_list
(
conf
[
"check_shapes
"
])]
mace_check
(
len
(
check_
tensors
)
==
len
(
check_tensors
_shapes
),
"check
tensors
count and shape count do not match."
)
for
i
in
range
(
len
(
check_
tensor
s
)):
check_node
=
cvt
.
NodeInfo
()
check_node
.
name
=
check_
node_name
s
[
i
]
check_node
.
shape
=
check_
node
_shapes
[
i
]
check_node
.
name
=
check_
tensor
s
[
i
]
check_node
.
shape
=
check_
tensors
_shapes
[
i
]
option
.
add_check_node
(
check_node
)
else
:
option
.
check_nodes
=
option
.
output_nodes
...
...
@@ -276,18 +282,23 @@ def merge_params(net_def):
if
tensor
.
data_type
==
mace_pb2
.
DT_HALF
:
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float16
).
tobytes
())
tensor
.
data_size
=
len
(
tensor
.
float_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
:
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float32
).
tobytes
())
tensor
.
data_size
=
len
(
tensor
.
float_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
data
=
bytearray
(
np
.
array
(
tensor
.
int32_data
).
astype
(
np
.
int32
).
tobytes
())
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
data
=
bytearray
(
np
.
array
(
tensor
.
int32_data
).
astype
(
np
.
uint8
).
tolist
())
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
data
=
bytearray
(
np
.
array
(
tensor
.
float_data
).
astype
(
np
.
float16
).
tobytes
())
tensor
.
data_size
=
len
(
tensor
.
float_data
)
else
:
raise
Exception
(
'Tensor data type %s not supported'
%
tensor
.
data_type
)
...
...
tools/python/transform/transformer.py
浏览文件 @
6073bca5
...
...
@@ -1314,19 +1314,6 @@ class Transformer(base_converter.ConverterInterface):
data_type
=
self
.
_option
.
data_type
net
.
data_type
=
data_type
for
tensor
in
net
.
tensors
:
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
:
tensor
.
data_type
=
data_type
if
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT
\
or
tensor
.
data_type
==
mace_pb2
.
DT_HALF
\
or
tensor
.
data_type
==
mace_pb2
.
DT_FLOAT16
:
tensor
.
data_size
=
len
(
tensor
.
float_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_INT32
:
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
elif
tensor
.
data_type
==
mace_pb2
.
DT_UINT8
:
tensor
.
data_size
=
len
(
tensor
.
int32_data
)
if
self
.
_option
.
quantize
:
return
...
...
tools/python/utils/util.py
浏览文件 @
6073bca5
...
...
@@ -87,15 +87,16 @@ def file_checksum(fname):
def
download_or_get_file
(
file
,
sha256_checksum
,
output_dir
):
model_file
=
output_dir
+
"/"
+
sha256_checksum
+
".pb"
filename
=
os
.
path
.
basename
(
file
)
output_file
=
"%s/%s-%s.pb"
%
(
output_dir
,
filename
,
sha256_checksum
)
if
file
.
startswith
(
"http://"
)
or
file
.
startswith
(
"https://"
):
if
not
os
.
path
.
exists
(
model
_file
)
or
file_checksum
(
model
_file
)
!=
sha256_checksum
:
if
not
os
.
path
.
exists
(
output
_file
)
or
file_checksum
(
output
_file
)
!=
sha256_checksum
:
MaceLogger
.
info
(
"Downloading file %s, please wait ..."
%
file
)
urllib
.
urlretrieve
(
file
,
model
_file
)
urllib
.
urlretrieve
(
file
,
output
_file
)
MaceLogger
.
info
(
"Model downloaded successfully."
)
else
:
device
.
execute
(
"cp %s %s"
%
(
file
,
model
_file
))
device
.
execute
(
"cp %s %s"
%
(
file
,
output
_file
))
return
model
_file
return
output
_file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录