Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
041a5d2e
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
041a5d2e
编写于
6月 29, 2020
作者:
L
luxuhui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat: support splitV op for tensorflow
N/A Signed-off-by:
N
Luxuhui
<
luxuhui@xiaomi.com
>
上级
a28e8128
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
45 addition
and
21 deletion
+45
-21
tools/device.py
tools/device.py
+2
-2
tools/python/transform/base_converter.py
tools/python/transform/base_converter.py
+2
-0
tools/python/transform/megengine_converter.py
tools/python/transform/megengine_converter.py
+6
-4
tools/python/transform/tensorflow_converter.py
tools/python/transform/tensorflow_converter.py
+24
-5
tools/python/validate.py
tools/python/validate.py
+5
-4
tools/sh_commands.py
tools/sh_commands.py
+2
-2
tools/validate.py
tools/validate.py
+4
-4
未找到文件。
tools/device.py
浏览文件 @
041a5d2e
...
...
@@ -220,8 +220,8 @@ class DeviceWrapper:
"MACE_LOG_TENSOR_RANGE=%d"
%
(
1
if
quantize_stat
else
0
),
"%s/%s"
%
(
target_dir
,
target_name
),
"--model_name=%s"
%
model_tag
,
"--input_node=
'%s'
"
%
","
.
join
(
input_nodes
),
"--output_node=
'%s'
"
%
","
.
join
(
output_nodes
),
"--input_node=
%s
"
%
","
.
join
(
input_nodes
),
"--output_node=
%s
"
%
","
.
join
(
output_nodes
),
"--input_shape=%s"
%
":"
.
join
(
input_shapes
),
"--output_shape=%s"
%
":"
.
join
(
output_shapes
),
"--input_data_format=%s"
%
","
.
join
(
input_data_formats
),
...
...
tools/python/transform/base_converter.py
浏览文件 @
041a5d2e
...
...
@@ -236,6 +236,7 @@ class MaceKeyword(object):
mace_end_axis_str
=
'end_axis'
mace_num_axes_str
=
'num_axes'
mace_num_split_str
=
'num_split'
mace_size_splits_str
=
'size_splits'
mace_keepdims_str
=
'keepdims'
mace_shape_str
=
'shape'
mace_winograd_filter_transformed
=
'is_filter_transformed'
...
...
@@ -548,6 +549,7 @@ class ConverterOption(object):
# Model structure related transformation
TransformerRule
.
REMOVE_USELESS_OP
,
TransformerRule
.
TRANSFORM_FAKE_QUANTIZE
,
TransformerRule
.
REMOVE_USELESS_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
TransformerRule
.
TRANSFORM_LSTMCELL_ZEROSTATE
,
TransformerRule
.
TRANSFORM_BASIC_LSTMCELL
,
...
...
tools/python/transform/megengine_converter.py
浏览文件 @
041a5d2e
...
...
@@ -166,7 +166,7 @@ class MegengineConverter(base_converter.ConverterInterface):
if
","
in
op
.
input
[
i
]:
op_name
=
op
.
input
[
i
]
op_name
=
op_name
.
replace
(
","
,
"#"
)
if
(
op_name
in
self
.
_option
.
input_nodes
or
\
if
(
op_name
in
self
.
_option
.
input_nodes
or
op_name
in
self
.
_option
.
output_nodes
):
op
.
input
[
i
]
=
op_name
for
i
in
six
.
moves
.
range
(
len
(
op
.
output
)):
...
...
@@ -195,7 +195,8 @@ class MegengineConverter(base_converter.ConverterInterface):
kernels_arg
.
name
=
MaceKeyword
.
mace_kernel_str
kernels_arg
.
ints
.
extend
(
kernel
)
if
op_def
.
type
in
(
MaceOp
.
Conv2D
.
name
,
MaceOp
.
DepthwiseConv2d
.
name
,
MaceOp
.
Deconv2D
.
name
,
MaceOp
.
DepthwiseDeconv2d
.
name
):
MaceOp
.
Deconv2D
.
name
,
MaceOp
.
DepthwiseDeconv2d
.
name
):
dilation
=
[
params
[
mge_dilate_h_str
],
params
[
mge_dilate_w_str
]]
dilation_arg
=
op_def
.
arg
.
add
()
dilation_arg
.
name
=
MaceKeyword
.
mace_dilations_str
...
...
@@ -426,13 +427,14 @@ class MegengineConverter(base_converter.ConverterInterface):
# check the case of counting include padding
mode
=
mge_op
.
params
[
"mode"
]
if
mode
==
"AVERAGE_COUNT_EXCLUDE_PADDING"
or
\
(
mode
==
"AVERAGE"
and
mge_op
.
params
[
"pad_w"
]
==
0
and
\
(
mode
==
"AVERAGE"
and
mge_op
.
params
[
"pad_w"
]
==
0
and
mge_op
.
params
[
"pad_h"
]
==
0
):
pool_type_arg
.
i
=
PoolingType
.
AVG
.
value
elif
mode
==
"MAX"
:
pool_type_arg
.
i
=
PoolingType
.
MAX
.
value
else
:
mace_check
(
False
,
"AVERAGE pooling should not count padding values"
)
mace_check
(
False
,
"AVERAGE pooling should not count padding values"
)
self
.
add_stride_pad_kernel_arg
(
mge_op
.
params
,
op
)
...
...
tools/python/transform/tensorflow_converter.py
浏览文件 @
041a5d2e
...
...
@@ -116,6 +116,7 @@ TFSupportedOps = [
'SpaceToBatchND'
,
'SpaceToDepth'
,
'Split'
,
'SplitV'
,
'Sqrt'
,
'Square'
,
'SquaredDifference'
,
...
...
@@ -279,6 +280,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType
.
SpaceToBatchND
.
name
:
self
.
convert_space_batch
,
TFOpType
.
SpaceToDepth
.
name
:
self
.
convert_space_depth
,
TFOpType
.
Split
.
name
:
self
.
convert_split
,
TFOpType
.
SplitV
.
name
:
self
.
convert_splitv
,
TFOpType
.
Sqrt
.
name
:
self
.
convert_elementwise
,
TFOpType
.
Squeeze
.
name
:
self
.
convert_squeeze
,
TFOpType
.
Stack
.
name
:
self
.
convert_stack
,
...
...
@@ -1057,14 +1059,15 @@ class TensorflowConverter(base_converter.ConverterInterface):
keep_dims_arg
.
name
=
MaceKeyword
.
mace_keepdims_str
keep_dims_arg
.
i
=
0
def
convert_split
(
self
,
tf_op
):
def
convert_split
(
self
,
tf_op
,
axis_idx
=
0
):
op
=
self
.
convert_general_op
(
tf_op
)
num_or_size_splits
=
tf_op
.
get_attr
(
'num_split'
)
if
num_or_size_splits
==
1
:
is_split
=
(
num_or_size_splits
>
1
)
if
not
is_split
:
op
.
type
=
MaceOp
.
Identity
.
name
else
:
op
.
type
=
MaceOp
.
Split
.
name
axis
=
tf_op
.
inputs
[
0
].
eval
().
astype
(
np
.
int32
)
axis
=
tf_op
.
inputs
[
axis_idx
].
eval
().
astype
(
np
.
int32
)
axis
=
len
(
op
.
output_shape
[
0
].
dims
)
+
axis
if
axis
<
0
else
axis
axis_arg
=
op
.
arg
.
add
()
...
...
@@ -1074,8 +1077,24 @@ class TensorflowConverter(base_converter.ConverterInterface):
num_split_arg
=
op
.
arg
.
add
()
num_split_arg
.
name
=
MaceKeyword
.
mace_num_split_str
num_split_arg
.
i
=
num_or_size_splits
del
op
.
input
[
0
]
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
0
].
name
)
del
op
.
input
[
axis_idx
]
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
axis_idx
].
name
)
return
(
op
,
is_split
)
def
convert_splitv
(
self
,
tf_op
):
(
op
,
is_split
)
=
self
.
convert_split
(
tf_op
,
2
)
if
not
is_split
:
return
size_splits_arg
=
op
.
arg
.
add
()
size_splits_arg
.
name
=
MaceKeyword
.
mace_size_splits_str
size_splits
=
tf_op
.
inputs
[
1
].
eval
().
astype
(
np
.
int32
)
del
op
.
input
[
1
]
self
.
_skip_tensor
.
add
(
tf_op
.
inputs
[
1
].
name
)
# todo(luxuhui): support size_splits
for
size
in
size_splits
:
mace_check
(
size
==
size_splits
[
0
],
"SplitV Only support even distribution"
)
size_splits_arg
.
ints
.
extend
(
size_splits
)
def
convert_tile
(
self
,
tf_op
):
op
=
self
.
convert_general_op
(
tf_op
)
...
...
tools/python/validate.py
浏览文件 @
041a5d2e
...
...
@@ -318,6 +318,7 @@ def validate_onnx_model(model_file,
mace_out_value
,
value
,
validation_threshold
,
log_file
)
def
validate_megengine_model
(
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
...
...
@@ -337,7 +338,7 @@ def validate_megengine_model(model_file, input_file,
util
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
(
input_data_formats
[
i
]
==
DataFormat
.
NHWC
and
\
if
(
input_data_formats
[
i
]
==
DataFormat
.
NHWC
and
len
(
input_shapes
[
i
])
==
4
):
input_value
=
input_value
.
transpose
((
0
,
3
,
1
,
2
))
feed_inputs
.
append
(
input_value
)
...
...
@@ -356,10 +357,10 @@ def validate_megengine_model(model_file, input_file,
output_file_name
=
\
util
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
if
(
output_data_formats
[
i
]
==
DataFormat
.
NHWC
and
\
if
(
output_data_formats
[
i
]
==
DataFormat
.
NHWC
and
len
(
output_shapes
[
i
])
==
4
):
mace_out_value
=
\
mace_out_value
.
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
3
,
1
,
2
))
mace_out_value
=
mace_out_value
.
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
3
,
1
,
2
))
compare_output
(
output_names
[
i
],
mace_out_value
,
mge_output_value
,
validation_threshold
,
log_file
)
...
...
tools/sh_commands.py
浏览文件 @
041a5d2e
...
...
@@ -748,8 +748,8 @@ def validate_model(abi,
"--input_file=/mace/%s"
%
input_file_name
,
"--mace_out_file=/mace/%s"
%
output_file_name
,
"--device_type=%s"
%
device_type
,
"--input_node=
'%s'
"
%
","
.
join
(
input_nodes
),
"--output_node=
'%s'
"
%
","
.
join
(
output_nodes
),
"--input_node=
%s
"
%
","
.
join
(
input_nodes
),
"--output_node=
%s
"
%
","
.
join
(
output_nodes
),
"--input_shape=%s"
%
":"
.
join
(
input_shapes
),
"--output_shape=%s"
%
":"
.
join
(
output_shapes
),
"--input_data_format=%s"
%
","
.
join
(
input_data_formats
),
...
...
tools/validate.py
浏览文件 @
041a5d2e
...
...
@@ -350,7 +350,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
common
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
(
input_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
\
if
(
input_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
len
(
input_shapes
[
i
])
==
4
):
input_value
=
input_value
.
transpose
((
0
,
3
,
1
,
2
))
feed_inputs
.
append
(
input_value
)
...
...
@@ -369,10 +369,10 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
output_file_name
=
\
common
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
if
(
output_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
\
if
(
output_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
len
(
output_shapes
[
i
])
==
4
):
mace_out_value
=
\
mace_out_value
.
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
3
,
1
,
2
))
mace_out_value
=
mace_out_value
.
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
3
,
1
,
2
))
compare_output
(
platform
,
device_type
,
output_names
[
i
],
mace_out_value
,
mge_output_value
,
validation_threshold
,
log_file
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录