Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
5f10b2c3
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5f10b2c3
编写于
5月 16, 2018
作者:
刘
刘琦
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'transform' into 'master'
Add identity op See merge request !484
上级
f147aa67
bd88ead3
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
24 addition
and
10 deletion
+24
-10
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+9
-6
mace/python/tools/tf_ops_stats.py
mace/python/tools/tf_ops_stats.py
+5
-2
tools/validate.py
tools/validate.py
+10
-2
未找到文件。
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
5f10b2c3
...
...
@@ -101,6 +101,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
'AvgPool'
:
self
.
convert_pooling
,
'MaxPool'
:
self
.
convert_pooling
,
'Squeeze'
:
self
.
convert_identity
,
'Identity'
:
self
.
convert_identity
,
'Reshape'
:
self
.
convert_reshape
,
'Shape'
:
self
.
convert_nop
,
'Softmax'
:
self
.
convert_softmax
,
...
...
@@ -153,11 +154,13 @@ class TensorflowConverter(base_converter.ConverterInterface):
def
add_shape_info
(
self
,
tf_graph_def
):
for
node
in
tf_graph_def
.
node
:
if
node
.
name
in
self
.
_option
.
input_nodes
:
for
input_node
in
self
.
_option
.
input_nodes
.
values
():
if
node
.
name
==
input_node
.
name
\
or
node
.
name
+
':0'
==
input_node
.
name
:
del
node
.
attr
[
'shape'
].
shape
.
dim
[:]
node
.
attr
[
'shape'
].
shape
.
dim
.
extend
([
tensor_shape_pb2
.
TensorShapeProto
.
Dim
(
size
=
i
)
for
i
in
self
.
_option
.
input_nodes
[
node
.
name
]
.
shape
input_node
.
shape
])
@
staticmethod
...
...
mace/python/tools/tf_ops_stats.py
浏览文件 @
5f10b2c3
...
...
@@ -45,8 +45,11 @@ def to_int_list(long_list):
def
add_shape_info
(
input_graph_def
,
input_nodes
,
input_shapes
):
inputs_replaced_graph
=
graph_pb2
.
GraphDef
()
for
node
in
input_graph_def
.
node
:
if
node
.
name
in
input_nodes
or
node
.
name
+
':0'
in
input_nodes
:
if
node
.
name
in
input_nodes
:
idx
=
input_nodes
.
index
(
node
.
name
)
else
:
idx
=
input_nodes
.
index
(
node
.
name
+
':0'
)
input_shape
=
input_shapes
[
idx
]
print
input_shape
placeholder_node
=
copy
.
deepcopy
(
node
)
...
...
tools/validate.py
浏览文件 @
5f10b2c3
...
...
@@ -65,6 +65,13 @@ def compare_output(platform, device_type, output_name, mace_out_value,
sys
.
exit
(
-
1
)
def
normalize_tf_tensor_name
(
name
):
if
name
.
find
(
':'
)
==
-
1
:
return
name
+
':0'
else
:
return
name
def
validate_tf_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
output_names
):
import
tensorflow
as
tf
...
...
@@ -88,13 +95,14 @@ def validate_tf_model(platform, device_type, model_file, input_file,
common
.
formatted_file_name
(
input_file
,
input_names
[
i
]))
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
input_node
=
graph
.
get_tensor_by_name
(
input_names
[
i
]
+
':0'
)
normalize_tf_tensor_name
(
input_names
[
i
])
)
input_dict
[
input_node
]
=
input_value
output_nodes
=
[]
for
name
in
output_names
:
output_nodes
.
extend
(
[
graph
.
get_tensor_by_name
(
name
+
':0'
)])
[
graph
.
get_tensor_by_name
(
normalize_tf_tensor_name
(
name
))])
output_values
=
session
.
run
(
output_nodes
,
feed_dict
=
input_dict
)
for
i
in
range
(
len
(
output_names
)):
output_file_name
=
common
.
formatted_file_name
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录