magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4fff5f44

Authored on Apr 28, 2020 by mindspore-ci-bot, committed via Gitee on Apr 28, 2020

!791 onnx adapter for MaxPoolWithArgmax

Merge pull request !791 from 梅晓蔚/master

Parents: 2b68b1e6 cd899fba
Showing 2 changed files with 64 additions and 5 deletions (+64 -5):

mindspore/ccsrc/onnx/onnx_exporter.cc    +39 -5
tests/ut/python/utils/test_serialize.py  +25 -0
mindspore/ccsrc/onnx/onnx_exporter.cc (view file @ 4fff5f44)
@@ -29,11 +29,12 @@
 namespace mindspore {
 enum OpMergeMode {
-  OP_MERGE_UNDEFINED = 0,   // undefined behavior
-  OP_MERGE_IGNORE = 1,      // indicate an input op merged into other op in compute node list
-  OP_MERGE_CONV = 2,        // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
-  OP_MERGE_GEMM = 3,        // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
-  OP_MERGE_BATCH_NORM = 4,  // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
+  OP_MERGE_UNDEFINED = 0,            // undefined behavior
+  OP_MERGE_IGNORE = 1,               // indicate an input op merged into other op in compute node list
+  OP_MERGE_CONV = 2,                 // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
+  OP_MERGE_GEMM = 3,                 // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
+  OP_MERGE_BATCH_NORM = 4,           // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
+  OP_MERGE_MAXPOOL_WITH_ARGMAX = 5,  // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
 };

 struct OpMergedInfo {
@@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
     .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
     .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

+OPERATOR_ONNX_CONVERT_DEFINE(
+  MaxPoolWithArgmax, MaxPool,
+  OpNameInfo()
+    .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
+    .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
+    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
+
 OPERATOR_ONNX_CONVERT_DEFINE(AvgPool, AveragePool, OpNameInfo()
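Note (not part of this commit, added for illustration): the new converter renames MindSpore's pooling attributes to their ONNX MaxPool counterparts: ksize becomes kernel_shape, padding becomes auto_pad, and strides keeps its name. A minimal sketch of the kind of node this mapping targets, built with the onnx Python helper (the attribute values below are made up for the example):

    import onnx
    from onnx import helper

    # Roughly the ONNX node the converter aims to produce for MaxPoolWithArgmax:
    # MindSpore `ksize` -> kernel_shape, `strides` -> strides, `padding` -> auto_pad.
    node = helper.make_node(
        "MaxPool",
        inputs=["x"],
        outputs=["y"],
        kernel_shape=[2, 2],
        strides=[2, 2],
        auto_pad="SAME_UPPER",
    )
    print(node)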
@@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
   fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
   fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
+  fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
   fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
   fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
@@ -328,6 +337,8 @@ class OnnxExporter {
                             onnx::GraphProto *graph_proto);
   void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
+  void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                    std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
                     onnx::GraphProto *graph_proto);
@@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto
       op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
       op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
       op_merged_infos[cnode->input(1)].referred_count -= 1;
+    } else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
+               IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
+               GetInt32Value(cnode->input(2)) == 0) {
+      op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
+      op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
+      op_merged_infos[cnode->input(1)].referred_count -= 1;
     }
   }
 }
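Note (not part of this commit, added for illustration): MaxPoolWithArgmax produces a tuple (output, argmax), so MindSpore graphs read the pooled tensor through TupleGetItem with index 0; the branch above matches exactly that pattern and folds it into a single ONNX MaxPool, while the argmax output has no ONNX counterpart here. A small sketch of the tuple output, assuming PyNative mode on a backend that implements this kernel (e.g. Ascend):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # The op returns two tensors; only the first (index 0) maps to ONNX MaxPool.
    maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
    x = Tensor(np.ones([1, 3, 8, 8]).astype(np.float32))
    out, argmax = maxpool(x)
    print(out, argmax)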
@@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
       case OP_MERGE_BATCH_NORM:
         ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
         break;
+      case OP_MERGE_MAXPOOL_WITH_ARGMAX:
+        ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
+        break;
       default:
         ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
         break;
@@ -811,6 +831,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
   (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
 }

+void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                                std::map<AnfNodePtr, size_t> *node_map_ptr,
+                                                onnx::GraphProto *const graph_proto) {
+  auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(1));
+  PrimitivePtr prim_maxpool_with_argmax =
+    dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(0)))->value());
+  std::vector<AnfNodePtr> inputs;
+  for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) {
+    inputs.push_back(maxpool_with_argmax_node->input(i));
+  }
+
+  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
+}
+
 void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
                                 std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
   if (node->inputs().size() != 2) {
tests/ut/python/utils/test_serialize.py (view file @ 4fff5f44)
@@ -362,6 +362,31 @@ def test_lenet5_onnx_export():
     net = LeNet5()
     export(net, input, file_name='lenet5.onnx', file_format='ONNX')


+class DefinedNet(nn.Cell):
+    """simple Net definition with maxpoolwithargmax."""
+    def __init__(self, num_classes=10):
+        super(DefinedNet, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU()
+        self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
+        self.flatten = nn.Flatten()
+        self.fc = nn.Dense(int(56 * 56 * 64), num_classes)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x, argmax = self.maxpool(x)
+        x = self.flatten(x)
+        x = self.fc(x)
+        return x
+
+
+def test_net_onnx_maxpoolwithargmax_export():
+    input = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)
+    net = DefinedNet()
+    export(net, input, file_name='definedNet.onnx', file_format='ONNX')


 @run_on_onnxruntime
 def test_lenet5_onnx_load_run():
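Note (not part of this commit, added for illustration): one hypothetical way to double-check the new test's output is to load the exported file with the onnx package and list the node types; the MaxPoolWithArgmax layer should surface as a single MaxPool node:

    import onnx

    # Load the file written by test_net_onnx_maxpoolwithargmax_export and verify it.
    model = onnx.load('definedNet.onnx')
    onnx.checker.check_model(model)
    print([node.op_type for node in model.graph.node])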