Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9c5f6b91
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9c5f6b91
编写于
4年前
作者:
W
wandongdong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add onnx ops for deepfm
上级
a5161a96
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
175 addition
and
15 deletion
+175
-15
mindspore/ccsrc/onnx/onnx_exporter.cc
mindspore/ccsrc/onnx/onnx_exporter.cc
+156
-14
tests/ut/python/onnx/test_onnx.py
tests/ut/python/onnx/test_onnx.py
+19
-1
未找到文件。
mindspore/ccsrc/onnx/onnx_exporter.cc
浏览文件 @
9c5f6b91
...
...
@@ -249,6 +249,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
.
Attr
(
"padding"
,
"auto_pad"
,
onnx
::
AttributeProto_AttributeType_STRING
,
SetPoolingPadMode
)
.
Attr
(
"strides"
,
"strides"
,
onnx
::
AttributeProto_AttributeType_INTS
,
SetAttrTupleValueToProto
<
2
>
))
OPERATOR_ONNX_CONVERT_DEFINE
(
GatherV2
,
Gather
,
OpNameInfo
())
OPERATOR_ONNX_CONVERT_DEFINE
(
make_tuple
,
SequenceConstruct
,
OpNameInfo
())
OPERATOR_ONNX_CONVERT_DEFINE
(
Concat
,
Concat
,
OpNameInfo
())
OPERATOR_ONNX_CONVERT_DEFINE
(
RealDiv
,
Div
,
OpNameInfo
())
OPERATOR_ONNX_CONVERT_DEFINE
(
ReduceSum
,
ReduceSum
,
OpNameInfo
())
OPERATOR_ONNX_CONVERT_DEFINE
(
Sub
,
Sub
,
OpNameInfo
())
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
void
RegisterOpConverters
(
const
std
::
function
<
void
(
OpNameInfo
&&
)
>
&
fn
)
{
...
...
@@ -269,6 +276,12 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn
(
OP_CONVERT_FUNCTION_NAME
(
Squeeze
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
BatchNorm
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
MatMul
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
make_tuple
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
Concat
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
RealDiv
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
BiasAdd
)());
fn
(
OP_CONVERT_FUNCTION_NAME
(
Sub
)());
}
class
OpConvertRegistry
{
...
...
@@ -325,8 +338,8 @@ class OnnxExporter {
void
ExportPrimReshape
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimReduce
Mean
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimReduce
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimCast
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimPReLU
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
...
...
@@ -335,6 +348,12 @@ class OnnxExporter {
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimDepthwiseConv2d
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimTile
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimSquare
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportPrimGatherV2
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
void
ExportMergeConv
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
graph_proto
);
...
...
@@ -628,16 +647,19 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const
node_proto
->
add_input
(
name_shape
);
}
void
OnnxExporter
::
ExportPrimReduceMean
(
const
FuncGraphPtr
&
/*func_graph*/
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
const
graph_proto
)
{
void
OnnxExporter
::
ExportPrimReduce
(
const
FuncGraphPtr
&
/*func_graph*/
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
const
graph_proto
)
{
auto
input_data
=
GetNodeInputName
(
node
->
input
(
1
),
node_map_ptr
,
graph_proto
);
auto
input_axis
=
node
->
input
(
2
);
auto
node_idx
=
AllocateNodeIndex
();
(
*
node_map_ptr
)[
node
]
=
node_idx
;
onnx
::
NodeProto
*
node_proto
=
graph_proto
->
add_node
();
node_proto
->
set_op_type
(
prim
::
kPrimReduceMean
->
name
());
auto
name
=
prim
::
kPrimReduceMean
->
name
();
if
(
node
->
IsApply
(
prim
::
kPrimReduceSum
))
{
name
=
prim
::
kPrimReduceSum
->
name
();
}
node_proto
->
set_op_type
(
name
);
node_proto
->
add_output
(
std
::
to_string
(
node_idx
));
node_proto
->
add_input
(
input_data
);
...
...
@@ -646,13 +668,18 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, con
attr_proto
->
set_name
(
"axes"
);
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_INTS
);
auto
axis_value
=
dyn_cast
<
ValueNode
>
(
input_axis
)
->
value
();
auto
tuple_ptr
=
dyn_cast
<
ValueTuple
>
(
axis_value
);
MS_EXCEPTION_IF_NULL
(
tuple_ptr
);
for
(
size_t
i
=
0
;
i
<
tuple_ptr
->
size
();
++
i
)
{
attr_proto
->
add_ints
(
GetValue
<
int
>
((
*
tuple_ptr
)[
i
]));
auto
int_ptr
=
dyn_cast
<
Int32Imm
>
(
axis_value
);
if
(
int_ptr
==
nullptr
)
{
auto
tuple_ptr
=
dyn_cast
<
ValueTuple
>
(
axis_value
);
MS_EXCEPTION_IF_NULL
(
tuple_ptr
);
for
(
size_t
i
=
0
;
i
<
tuple_ptr
->
size
();
++
i
)
{
attr_proto
->
add_ints
(
GetValue
<
int
>
((
*
tuple_ptr
)[
i
]));
}
}
else
{
attr_proto
->
add_ints
(
int_ptr
->
value
());
}
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Need to insert op convert variable from tuple to attributes for
ReduceMean."
;
MS_LOG
(
EXCEPTION
)
<<
"Need to insert op convert variable from tuple to attributes for
"
<<
name
;
}
}
...
...
@@ -826,6 +853,83 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
SetAttrTupleValueToProto
<
2
>
(
prim
->
GetAttr
(
"stride"
),
onnx
::
AttributeProto_AttributeType_INTS
,
onnx_attr_proto
,
prim
);
}
void
OnnxExporter
::
ExportPrimTile
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
const
graph_proto
)
{
auto
name_x
=
GetNodeInputName
(
node
->
input
(
1
),
node_map_ptr
,
graph_proto
);
auto
multiples
=
node
->
input
(
2
);
std
::
string
name_multiples
;
if
(
multiples
->
isa
<
ValueNode
>
())
{
auto
const_node_idx
=
AllocateNodeIndex
();
(
*
node_map_ptr
)[
multiples
]
=
const_node_idx
;
onnx
::
NodeProto
*
node_proto
=
graph_proto
->
add_node
();
name_multiples
=
std
::
to_string
(
const_node_idx
);
node_proto
->
add_output
(
name_multiples
);
node_proto
->
set_op_type
(
"Constant"
);
onnx
::
AttributeProto
*
attr_proto
=
node_proto
->
add_attribute
();
attr_proto
->
set_name
(
"repeat"
);
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_TENSOR
);
ConvertTupleToTensor
(
dyn_cast
<
ValueNode
>
(
multiples
)
->
value
(),
attr_proto
->
mutable_t
());
}
else
{
name_multiples
=
GetNodeInputName
(
multiples
,
node_map_ptr
,
graph_proto
);
MS_LOG
(
EXCEPTION
)
<<
"Need to insert op convert variable from tuple to tensor for Tile."
;
}
auto
node_idx
=
AllocateNodeIndex
();
(
*
node_map_ptr
)[
node
]
=
node_idx
;
onnx
::
NodeProto
*
node_proto
=
graph_proto
->
add_node
();
node_proto
->
set_op_type
(
"Tile"
);
node_proto
->
add_output
(
std
::
to_string
(
node_idx
));
node_proto
->
add_input
(
name_x
);
node_proto
->
add_input
(
name_multiples
);
}
void
OnnxExporter
::
ExportPrimSquare
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
const
graph_proto
)
{
auto
name_x
=
GetNodeInputName
(
node
->
input
(
1
),
node_map_ptr
,
graph_proto
);
std
::
string
name_exponent
;
auto
const_node_idx
=
AllocateNodeIndex
();
onnx
::
NodeProto
*
node_proto_exp
=
graph_proto
->
add_node
();
name_exponent
=
std
::
to_string
(
const_node_idx
);
node_proto_exp
->
add_output
(
name_exponent
);
node_proto_exp
->
set_op_type
(
"Constant"
);
onnx
::
AttributeProto
*
attr_proto
=
node_proto_exp
->
add_attribute
();
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_TENSOR
);
onnx
::
TensorProto
*
tensor_proto
=
attr_proto
->
mutable_t
();
tensor_proto
->
set_name
(
"exponent"
);
tensor_proto
->
add_dims
(
static_cast
<::
google
::
protobuf
::
int64
>
(
1
));
tensor_proto
->
set_data_type
(
onnx
::
TensorProto_DataType_INT64
);
tensor_proto
->
add_int64_data
(
2
);
auto
node_idx
=
AllocateNodeIndex
();
(
*
node_map_ptr
)[
node
]
=
node_idx
;
onnx
::
NodeProto
*
node_proto
=
graph_proto
->
add_node
();
node_proto
->
set_op_type
(
"Pow"
);
node_proto
->
add_output
(
std
::
to_string
(
node_idx
));
node_proto
->
add_input
(
name_x
);
node_proto
->
add_input
(
name_exponent
);
}
void
OnnxExporter
::
ExportPrimGatherV2
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
const
graph_proto
)
{
auto
name_x
=
GetNodeInputName
(
node
->
input
(
1
),
node_map_ptr
,
graph_proto
);
auto
name_indices
=
GetNodeInputName
(
node
->
input
(
2
),
node_map_ptr
,
graph_proto
);
auto
axis
=
node
->
input
(
3
)
->
cast
<
ValueNodePtr
>
()
->
value
();
auto
node_idx
=
AllocateNodeIndex
();
(
*
node_map_ptr
)[
node
]
=
node_idx
;
onnx
::
NodeProto
*
node_proto
=
graph_proto
->
add_node
();
node_proto
->
set_op_type
(
"Gather"
);
node_proto
->
add_output
(
std
::
to_string
(
node_idx
));
node_proto
->
add_input
(
name_x
);
node_proto
->
add_input
(
name_indices
);
onnx
::
AttributeProto
*
attr_proto
=
node_proto
->
add_attribute
();
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_INT
);
attr_proto
->
set_i
(
static_cast
<::
google
::
protobuf
::
int64
>
(
dyn_cast
<
Int32Imm
>
(
axis
)
->
value
()));
}
void
OnnxExporter
::
ExportCNode
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
std
::
map
<
AnfNodePtr
,
size_t
>
*
node_map_ptr
,
onnx
::
GraphProto
*
const
graph_proto
)
{
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
...
...
@@ -833,8 +937,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return
ExportPrimReshape
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
}
if
(
node
->
IsApply
(
prim
::
kPrimReduceMean
))
{
return
ExportPrimReduce
Mean
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
if
(
node
->
IsApply
(
prim
::
kPrimReduceMean
)
||
node
->
IsApply
(
prim
::
kPrimReduceSum
)
)
{
return
ExportPrimReduce
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
}
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
...
...
@@ -857,6 +961,21 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return
ExportPrimDepthwiseConv2d
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
}
// MindSpore Tile(x) --> ONNX Tile(x, repeat)
if
(
node
->
IsApply
(
prim
::
kPrimTile
))
{
return
ExportPrimTile
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
}
// MindSpore Square(x) --> ONNX Pow(x, 2)
if
(
node
->
IsApply
(
prim
::
kPrimSquare
))
{
return
ExportPrimSquare
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
}
// MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices)
if
(
node
->
IsApply
(
prim
::
kPrimGatherV2
))
{
return
ExportPrimGatherV2
(
func_graph
,
node
,
node_map_ptr
,
graph_proto
);
}
auto
inputs
=
node
->
inputs
();
if
(
inputs
.
size
()
<
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node is empty"
;
...
...
@@ -1054,7 +1173,30 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons
node_proto
->
set_op_type
(
"Constant"
);
onnx
::
AttributeProto
*
attr_proto
=
node_proto
->
add_attribute
();
attr_proto
->
set_name
(
"value"
);
MS_LOG
(
EXCEPTION
)
<<
"Need to set value "
<<
value
->
ToString
()
<<
" attribute for Constant node"
;
if
(
value
->
isa
<
Int32Imm
>
())
{
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_INT
);
auto
casted_value
=
dyn_cast
<
Int32Imm
>
(
value
);
if
(
casted_value
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cast value "
<<
value
->
ToString
()
<<
" to type T failed."
;
}
auto
attr_value
=
casted_value
->
value
();
attr_proto
->
set_i
(
static_cast
<::
google
::
protobuf
::
int64
>
(
attr_value
));
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_INT
);
}
else
if
(
value
->
isa
<
tensor
::
Tensor
>
())
{
attr_proto
->
set_type
(
onnx
::
AttributeProto_AttributeType_TENSOR
);
onnx
::
TensorProto
*
tensor_proto
=
attr_proto
->
mutable_t
();
auto
data
=
dyn_cast
<
tensor
::
Tensor
>
(
value
);
tensor_proto
->
set_raw_data
(
data
->
data
().
request
(
true
).
ptr
,
static_cast
<
size_t
>
(
data
->
data
().
nbytes
()));
auto
dtype
=
data
->
data_type
();
auto
shape
=
data
->
shape_c
();
tensor_proto
->
set_data_type
(
GetOnnxDataType
(
dtype
));
for
(
const
auto
&
dim
:
shape
)
{
tensor_proto
->
add_dims
(
dim
);
}
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Need to set value "
<<
value
->
ToString
()
<<
" attribute for Constant node"
;
}
}
std
::
string
GetOnnxProtoString
(
const
FuncGraphPtr
&
func_graph
)
{
...
...
This diff is collapsed.
Click to expand it.
tests/ut/python/onnx/test_onnx.py
浏览文件 @
9c5f6b91
...
...
@@ -142,6 +142,20 @@ class DepthwiseConv2dAndReLU6(nn.Cell):
x
=
self
.
relu6
(
x
)
return
x
class
DeepFMOpNet
(
nn
.
Cell
):
"""Net definition with Gatherv2 and Tile and Square."""
def
__init__
(
self
):
super
(
DeepFMOpNet
,
self
).
__init__
()
self
.
gather
=
P
.
GatherV2
()
self
.
square
=
P
.
Square
()
self
.
tile
=
P
.
Tile
()
def
construct
(
self
,
x
,
y
):
x
=
self
.
tile
(
x
,
(
1000
,
1
))
x
=
self
.
square
(
x
)
x
=
self
.
gather
(
x
,
y
,
0
)
return
x
# generate mindspore Tensor by shape and numpy datatype
def
gen_tensor
(
shape
,
dtype
=
np
.
float32
):
...
...
@@ -153,6 +167,7 @@ net_cfgs = [
(
'lenet'
,
LeNet5
(),
gen_tensor
([
1
,
1
,
32
,
32
])),
(
'maxpoolwithargmax'
,
DefinedNet
(),
gen_tensor
([
1
,
3
,
224
,
224
])),
(
'depthwiseconv_relu6'
,
DepthwiseConv2dAndReLU6
(
3
,
kernel_size
=
3
),
gen_tensor
([
1
,
3
,
32
,
32
])),
(
'deepfm_ops'
,
DeepFMOpNet
(),
(
gen_tensor
([
1
,
1
]),
gen_tensor
([
1000
,
1
],
dtype
=
np
.
int32
)))
]
...
...
@@ -164,7 +179,10 @@ def get_id(cfg):
@
pytest
.
mark
.
parametrize
(
'name, net, inp'
,
net_cfgs
,
ids
=
get_id
(
net_cfgs
))
def
test_onnx_export
(
name
,
net
,
inp
):
onnx_file
=
name
+
".onnx"
export
(
net
,
inp
,
file_name
=
onnx_file
,
file_format
=
'ONNX'
)
if
isinstance
(
inp
,
(
tuple
,
list
)):
export
(
net
,
*
inp
,
file_name
=
onnx_file
,
file_format
=
'ONNX'
)
else
:
export
(
net
,
inp
,
file_name
=
onnx_file
,
file_format
=
'ONNX'
)
# check existence of exported onnx file and delete it
assert
os
.
path
.
exists
(
onnx_file
)
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部