Commit 2f42466d
Authored Sep 07, 2020 by hexia
export and load model for serving
Parent: 8e442ce7

4 changed files with 381 additions and 189 deletions (+381 -189)
mindspore/ccsrc/transform/onnx/ir_exporter.cc        +145  -92
mindspore/ccsrc/utils/load_onnx/anf_converter.cc     +1    -1
mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc  +231  -91
mindspore/ccsrc/utils/load_onnx/anf_model_parser.h   +4    -5
mindspore/ccsrc/transform/onnx/ir_exporter.cc
@@ -90,14 +90,17 @@ class IrExportBuilder {
   void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
   void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
   void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
-  void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
-                           std::string suffix = "0");
+  void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
+                           std::string *const seq_string);
   void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
   void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
   void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
   void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
-  void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
-  void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
+  void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
+  void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
+                                   std::string *const seq_string);
+  void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
+                                  std::string *const seq_string);
   onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
   onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
@@ -105,8 +108,10 @@ class IrExportBuilder {
   std::string GetNodeName(const AnfNodePtr &node);
   std::string GetUniqueNodeName(const AnfNodePtr &node);
   std::string GetOpTypeName(const AnfNodePtr &node);
-  size_t AllocateIndex() { return ++node_index_; }
-  void ResetIndex() { node_index_ = 0; }
+  size_t GetNodeIndex() { return ++node_index_; }
+  void ResetNodeIndex() { node_index_ = 0; }
+  size_t GetTupleIndex() { return ++shape_index_; }
+  void ResetTupleIndex() { shape_index_ = 0; }

  private:
   onnx::ModelProto model_;
@@ -114,6 +119,7 @@ class IrExportBuilder {
   std::list<FuncGraphPtr> todo_;
   std::map<AnfNodePtr, size_t> node_index_map_;
   size_t node_index_{0};
+  size_t shape_index_{0};
 };

 using IrExporterPtr = std::shared_ptr<IrExporter>;
@@ -146,7 +152,7 @@ void IrExportBuilder::BuildModelInfo() {
 void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
   onnx::GraphProto *graph_proto = model_.mutable_graph();
   graph_proto->set_name(func_graph->ToString());
-  ResetIndex();
+  ResetNodeIndex();
   todo_.clear();
   todo_.push_back(func_graph);
   while (!todo_.empty()) {
@@ -177,7 +183,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
     input_proto->set_name(param_name);
     SetValueInfoProto(param, input_proto);
     if (!param->has_default()) {
-      MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
+      MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
       continue;
     }
@@ -232,13 +238,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr
     auto elem_type = tensor->element();
     const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
     type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
-    for (const auto &dim : dims) {
-      MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
-      type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+    if (dims.size() == 0) {
+      MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
+      type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+    } else {
+      for (const auto &dim : dims) {
+        MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
+        type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+      }
     }
   } else if (type->isa<Tuple>()) {
     auto tup_shape = shape->cast<abstract::TupleShapePtr>();
-    type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
+    type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
   } else if (type->isa<Number>() || type->isa<String>()) {
     type_proto->set_denotation(type->type_name());
   } else {
     MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
   }
@@ -248,9 +261,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
-  attr_proto->set_ref_attr_name("tensor");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
+  attr_proto->set_ref_attr_name("tensor:value0");
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
+  onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
+  tensor_proto->set_name("value0");
   auto data = value->cast<tensor::TensorPtr>();
   tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
   auto dtype = data->data_type();
@@ -284,6 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::Ten
 void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
+  bool is_only_return = true;
   for (const AnfNodePtr &node : nodes) {
     if (!node->isa<CNode>()) {
       MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
@@ -291,9 +306,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
     }
     auto cnode = node->cast<CNodePtr>();
     if (cnode == func_graph->get_return()) {
+      if (is_only_return) {
+        MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!";
+      }
       BuildOutput(cnode, graph_proto);
     } else {
       BuildCNode(cnode, graph_proto);
+      is_only_return = false;
     }
   }
 }
@@ -303,24 +322,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const
     MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
   }
   AnfNodePtr arg = node->input(1);
-  // Using make_tuple to set multi-output
-  if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
-    auto tuple_node = arg->cast<CNodePtr>();
-    for (size_t i = 1; i < tuple_node->size(); i++) {
-      auto input_node = arg->cast<CNodePtr>()->input(i);
-      onnx::ValueInfoProto *output_proto = graph_proto->add_output();
-      auto output_name = GetUniqueNodeName(tuple_node->input(i));
-      output_proto->set_name(output_name);
-      last_node_->add_output(output_name);
-      SetValueInfoProto(tuple_node->input(i), output_proto);
-    }
-  } else {
-    onnx::ValueInfoProto *output_proto = graph_proto->add_output();
-    std::string output_name = GetUniqueNodeName(node);
-    output_proto->set_name(output_name);
-    last_node_->add_output(output_name);
-    SetValueInfoProto(arg, output_proto);
-  }
+  onnx::ValueInfoProto *output_proto = graph_proto->add_output();
+  std::string output_name = GetUniqueNodeName(node);
+  output_proto->set_name(output_name);
+  last_node_->set_output(0, output_name);
+  SetValueInfoProto(arg, output_proto);
 }

 std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
@@ -343,45 +349,44 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
 }

 void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
-                                          onnx::NodeProto *const node_proto, std::string suffix) {
-  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
-  attr_proto->set_ref_attr_name("shape");
-  if (suffix.compare("0") != 0) {
-    attr_proto->set_name("shape" + suffix);
-  } else {
-    attr_proto->set_name("shape");
-  }
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  SetTensorProto(type, shape, tensor_proto);
-}
-
-void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
-  // Get shape of cnode
-  // 1. prim ArgMaxWithValue need to get shape from tuple element
-  // 2. some cnode doesn't has shape, such as LayerNorm
-  // 3. other cnodes have shape
-  if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) {
-    auto type = node->Type();
-    auto shape = node->Shape();
-    if (!type->isa<Tuple>()) {
-      MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
-    }
+                                          onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
+  if (type->isa<Tuple>() && seq_string != nullptr) {
+    *seq_string += "Tuple[";
     auto elements = type->cast<TuplePtr>()->elements();
     auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
     for (size_t i = 0; i < elements.size(); i++) {
-      SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
+      SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
     }
+    *seq_string += "],";
+  } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
+    string shape_name = "shape" + std::to_string(GetTupleIndex());
+    *seq_string += shape_name + ",";
+    onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
+    tensor_proto->set_name(shape_name);
+    SetTensorProto(type, shape, tensor_proto);
+  } else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
+    *seq_string += type->type_name() + ",";
   } else {
-    auto type = node->Type();
-    auto shape = node->Shape();
-    if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
-      MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
-      return;
-    }
-    SetShapeToNodeProto(type, shape, node_proto);
+    MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
   }
 }

+void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
+  // Get shape of cnode
+  // 1. need to get shape from tuple element
+  // 2. save shape in TensorProto
+  // 3. save tuple string in ref_attr_name
+  MS_EXCEPTION_IF_NULL(node);
+  auto type = node->Type();
+  auto shape = node->Shape();
+  ResetTupleIndex();
+  std::string seq_string = "shape:";
+  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
+  SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
+  attr_proto->set_ref_attr_name(seq_string);
+  MS_LOG(DEBUG) << "CNode shape: " << seq_string;
+}
+
 void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
   auto inputs_size = node->size();
   if (inputs_size < 1) {
@@ -443,15 +448,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
   std::string node_name = "";
   if (node->isa<Parameter>()) {
     node_name = GetNodeName(node);
-  } else if (node->isa<CNode>() || node->isa<ValueNode>()) {
+  } else if (node->isa<CNode>()) {
     auto iter = node_index_map_.find(node);
     if (iter != node_index_map_.end()) {
       node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
     } else {
-      auto node_idx = AllocateIndex();
+      auto node_idx = GetNodeIndex();
       node_index_map_[node] = node_idx;
       node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
     }
+  } else if (node->isa<ValueNode>()) {
+    auto node_idx = GetNodeIndex();
+    node_index_map_[node] = node_idx;
+    node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
   } else {
     MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
   }
@@ -485,17 +494,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
-  attr_proto->set_ref_attr_name("type");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
+  onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
   if (value->isa<Int>()) {
+    attr_proto->set_ref_attr_name("type:value0");
+    tensor_proto->set_name("value0");
     auto int_value = value->cast<IntPtr>();
     tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
   } else if (value->isa<Float>()) {
+    attr_proto->set_ref_attr_name("type:value0");
+    tensor_proto->set_name("value0");
     auto float_value = value->cast<FloatPtr>();
     tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
   } else if (value->isa<TensorType>()) {
-    tensor_proto->set_name("tensor");
+    attr_proto->set_ref_attr_name("type:tensor0");
+    tensor_proto->set_name("tensor0");
     auto elem_type = value->cast<TensorTypePtr>()->element();
     if (elem_type->isa<Int>()) {
       auto int_value = elem_type->cast<IntPtr>();
@@ -519,10 +532,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
     SetScalarToAttributeProto(value, attr_proto);
   } else if (value->isa<Number>() || value->isa<TensorType>()) {
     SetTypeToAttributeProto(value, attr_proto);
   } else if (value->isa<ValueSequeue>()) {
-    SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
+    ResetTupleIndex();
+    std::string seq_string = "scalar:";
+    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
+    SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
+    attr_proto->set_ref_attr_name(seq_string);
+    MS_LOG(DEBUG) << "Attr string: " << seq_string;
   } else if (value->isa<tensor::Tensor>()) {
     SetTensorToAttributeProto(value, attr_proto);
+  } else if (value->isa<None>()) {
+    attr_proto->set_ref_attr_name("none");
+    MS_LOG(DEBUG) << "Attr string: " << value->type_name();
   } else {
     MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
   }
@@ -532,16 +553,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
-  attr_proto->set_ref_attr_name("scalar");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  SetScalarToProto(value, tensor_proto);
+  attr_proto->set_ref_attr_name("scalar:value0");
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
+  onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
+  SetScalarToProto(value, tensor_proto, "value0");
 }

-void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
+void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
+                                       const std::string &value_name) {
   if (value == nullptr || tensor_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
   }
+  tensor_proto->set_name(value_name);
   if (value->isa<StringImm>()) {
     tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
     tensor_proto->add_string_data(GetValue<std::string>(value));
@@ -560,44 +583,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto
   } else if (value->isa<Int64Imm>()) {
     tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
     tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
-  } else if (value->isa<FloatImm>()) {
+  } else if (value->isa<UInt8Imm>()) {
+    tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
+    tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
+  } else if (value->isa<UInt16Imm>()) {
+    tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
+    tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
+  } else if (value->isa<UInt32Imm>()) {
+    tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
+    tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
+  } else if (value->isa<UInt64Imm>()) {
+    tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
+    tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
+  } else if (value->isa<FP32Imm>()) {
+    tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
+    tensor_proto->add_float_data(GetValue<float>(value));
+  } else if (value->isa<FP64Imm>()) {
+    tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
+    tensor_proto->add_double_data(GetValue<double>(value));
   } else {
     MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
   }
 }

-void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto) {
+void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
+                                                 std::string *const seq_string) {
+  string value_name = "value" + std::to_string(GetTupleIndex());
+  if (seq_string != nullptr) {
+    *seq_string += value_name + ",";
+  }
+  onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
+  SetScalarToProto(value, tensor_proto, value_name);
+}
+
+void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
+                                                  std::string *const seq_string) {
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
   }
-  attr_proto->set_ref_attr_name("scalar");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  if (value->isa<ValueTuple>()) {
+  if (value->isa<ValueTuple>() && seq_string != nullptr) {
+    *seq_string += "Tuple[";
     const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
     if (tuple_value->value().size() == 0) {
       MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
       return;
     }
-    auto type_id = tuple_value->value()[0]->type()->type_id();
-    tensor_proto->set_data_type(GetOnnxDataType(type_id));
     for (const auto &item : tuple_value->value()) {
-      SetScalarToProto(item, tensor_proto);
+      if (item->isa<ValueTuple>()) {
+        SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
+      } else {
+        SetSeqElemToAttributeProto(item, attr_proto, seq_string);
+      }
     }
-  } else if (value->isa<ValueList>()) {
+    *seq_string += "],";
+  } else if (value->isa<ValueList>() && seq_string != nullptr) {
+    *seq_string += "List[";
     const ValueListPtr &list_value = value->cast<ValueListPtr>();
     if (list_value->value().size() == 0) {
-      MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
+      MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
       return;
     }
-    auto type_id = list_value->value()[0]->type()->type_id();
-    tensor_proto->set_data_type(GetOnnxDataType(type_id));
     for (const auto &item : list_value->value()) {
-      SetScalarToProto(item, tensor_proto);
+      if (item->isa<ValueList>()) {
+        SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
+      } else {
+        SetSeqElemToAttributeProto(item, attr_proto, seq_string);
+      }
     }
+    *seq_string += "],";
   }
 }
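Note on the exported format (illustration, not part of the commit): after this change an attribute no longer stores a single tensor in field t; it stores several named tensors (shape0, value0, ...) via add_tensors(), and ref_attr_name records their nesting as a bracketed string, so a two-element tuple of shapes would serialize roughly as "shape:Tuple[shape0,shape1,],". The short standalone C++ sketch below mirrors how SetShapeToNodeProto and SetSeqElemToAttributeProto assemble that string; all names in it are illustrative, not identifiers from the repository.

// Illustrative sketch only: builds the bracketed ref_attr_name string the
// exporter writes, e.g. "shape:Tuple[shape0,shape1,]," for a 2-element tuple.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::string seq_string = "shape:";
  std::vector<std::string> element_names = {"shape0", "shape1"};  // one named tensor per tuple element
  seq_string += "Tuple[";
  for (const auto &name : element_names) {
    seq_string += name + ",";  // each leaf contributes "<name>,"
  }
  seq_string += "],";  // closing a tuple appends "],"
  std::cout << seq_string << std::endl;  // prints: shape:Tuple[shape0,shape1,],
  return 0;
}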
mindspore/ccsrc/utils/load_onnx/anf_converter.cc

@@ -57,7 +57,7 @@ int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string file
 bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
   std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
-  int fd = open(onnx_file.get(), O_RDONLY);
+  int fd = open(modelFile.c_str(), O_RDONLY);
   if (fd < 0) {
     MS_LOG(EXCEPTION) << "failed to open file";
   }
mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc

@@ -18,8 +18,12 @@
 #include <functional>
 #include <map>
 #include <memory>
+#include <stack>
 #include <string>
 #include <vector>
+#include <unordered_map>
+#include <utility>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "ir/tensor.h"
+#include "ir/param_info.h"
 #include "frontend/operator/ops.h"
@@ -55,6 +59,97 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
   {onnx::TensorProto_DataType_STRING, kObjectTypeString},
 };

+template <typename T, typename P>
+std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
+  std::stack<std::string> rules;
+  std::stack<P> value;
+  int count = 0;
+  for (size_t i = 0; i < str.length(); i++) {
+    if (str[i] == '[') {
+      rules.push("[");
+    } else if (str[i] == ']') {
+      // rules
+      std::vector<P> vec;
+      while (rules.top() != "[") {
+        rules.pop();
+        vec.push_back(value.top());
+        value.pop();
+      }
+      // pop "["
+      rules.pop();
+      // make tuple for names
+      std::string res = "dummy";
+      // make tuple for values
+      reverse(vec.begin(), vec.end());
+      auto vt = std::make_shared<T>(vec);
+      if (rules.empty() && value.empty()) {
+        return vt;
+      }
+      rules.push(res);
+      value.push(vt);
+    } else if (str[i] == ',') {
+      continue;
+    } else {
+      count++;
+      if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
+        auto value_name = str.substr(i - count + 1, count);
+        value.push(kv.at(value_name));
+        rules.push(value_name);
+        count = 0;
+      }
+    }
+  }
+  return {};
+}
+
+std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
+                                                  const std::unordered_map<string, ValuePtr> &kv) {
+  std::string str = attr_name;
+  auto replace = [&](const string &orgStr, const string &newStr) {
+    std::string::size_type pos(0);
+    while ((pos = str.find(orgStr)) != std::string::npos) {
+      str.replace(pos, orgStr.length(), newStr);
+    }
+    return str;
+  };
+  // remove "scalar:"
+  str = replace("scalar:", "");
+  // remove "Tuple"
+  str = replace("Tuple", "");
+  // remove "List"
+  str = replace("List", "");
+  auto result = ParserAttr<ValueTuple>(str, kv);
+  if (!result) {
+    return {};
+  }
+  return result;
+}
+
+std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
+  const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
+  std::string str = attr_name;
+  auto replace = [&](const string &orgStr, const string &newStr) {
+    std::string::size_type pos(0);
+    while ((pos = str.find(orgStr)) != std::string::npos) {
+      str.replace(pos, orgStr.length(), newStr);
+    }
+    return str;
+  };
+  // remove "scalar:"
+  str = replace("shape:", "");
+  // remove "Tuple"
+  str = replace("Tuple", "");
+  // remove "List"
+  str = replace("List", "");
+  auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
+  if (!result) {
+    return {};
+  }
+  return result;
+}
+
+#if 0
 #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype)                                                \
   void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
                                               const onnx::TensorProto &attr_tensor) {                 \
@@ -67,9 +162,16 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
     if (attr_value_vec.size() == 1) {                                                                 \
       prim->AddAttr(attr_name, attr_value_vec[0]);                                                    \
     } else {                                                                                          \
-      prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec));                          \
+      ParserScalarAttrValue(prim, attr_name, attr_value_vec);                                         \
     }                                                                                                 \
   }
+#endif
+
+#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype)                                    \
+  ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
+    auto value = static_cast<valuetype>(attr_tensor.type##_data(0));                      \
+    return MakeValue<valuetype>(value);                                                   \
+  }

 PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
 PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
@@ -110,6 +212,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons
   tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
+  MS_EXCEPTION_IF_NULL(tensor_info);
   // tensor_info->MallocData();
   auto tensor_abstract = tensor_info->ToAbstract();
   MS_EXCEPTION_IF_NULL(tensor_abstract);
   node->set_abstract(tensor_abstract);
@@ -167,45 +270,35 @@ bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const
   return true;
 }

-bool MSANFModelParser::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                                   const onnx::TensorProto &attr_tensor) {
-  MS_EXCEPTION_IF_NULL(prim);
+ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
   const int attr_tensor_type = attr_tensor.data_type();
   switch (attr_tensor_type) {
     case onnx::TensorProto_DataType_STRING: {
-      ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
-      break;
+      return ParseAttrInScalar_string_string(attr_tensor);
     }
     case onnx::TensorProto_DataType_INT32: {
-      ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
-      break;
+      return ParseAttrInScalar_int32_int32(attr_tensor);
     }
     case onnx::TensorProto_DataType_INT64: {
-      ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
-      break;
+      return ParseAttrInScalar_int64_int64(attr_tensor);
     }
     case onnx::TensorProto_DataType_UINT64: {
-      ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
-      break;
+      return ParseAttrInScalar_uint64_uint64(attr_tensor);
     }
     case onnx::TensorProto_DataType_FLOAT: {
-      ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
-      break;
+      return ParseAttrInScalar_float_float(attr_tensor);
    }
     case onnx::TensorProto_DataType_DOUBLE: {
-      ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
-      break;
+      return ParseAttrInScalar_double_double(attr_tensor);
     }
     case onnx::TensorProto_DataType_BOOL: {
-      ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
-      auto value = prim->GetAttr(attr_name);
-      break;
+      return ParseAttrInScalar_int32_bool(attr_tensor);
     }
     default:
       MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
-      return false;
+      return {};
   }
-  return true;
+  return {};
 }

 bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
@@ -223,21 +316,48 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx
     return false;
   }
   const std::string &ref_attr_name = attr_proto.ref_attr_name();
-  const onnx::TensorProto &attr_tensor = attr_proto.t();
-  switch (kParseTypeSwitchMap[ref_attr_name]) {
-    case FORM_PARSE_TYPE: {
-      return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
-    }
-    case FORM_PARSE_SCALAR: {
-      return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
-    }
-    case FORM_PARSE_TENSOR: {
-      return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
-    }
-    default:
-      MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
-      return false;
+  string type;
+  std::size_t pos(0);
+  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
+    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
+  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
+    type = ref_attr_name.substr(pos, string("type:").length() - 1);
+  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
+    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
+  }
+  std::unordered_map<std::string, ValuePtr> kv;
+  for (int i = 0; i < attr_proto.tensors_size(); i++) {
+    const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
+    switch (kParseTypeSwitchMap[type]) {
+      case FORM_PARSE_TYPE: {
+        ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
+        break;
+      }
+      case FORM_PARSE_SCALAR: {
+        auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
+        kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
+        break;
+      }
+      case FORM_PARSE_TENSOR: {
+        ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
+        break;
+      }
+      default:
+        MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
+        return false;
+    }
+  }
+  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
+    if (kv.size() == 1) {
+      auto iter = kv.begin();
+      prim->AddAttr(attr_name, iter->second);
+    } else {
+      auto res = ParserScalarAttrValue(ref_attr_name, kv);
+      prim->AddAttr(attr_name, res);
+    }
   }
   return true;
 }

 bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
                                                    const onnx::TensorProto &attr_tensor) {
@@ -247,6 +367,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
     shape.push_back(attr_tensor.dims(i));
   }
   tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
   // tensor_info->MallocData();
   const std::string &tensor_buf = attr_tensor.raw_data();
   auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
   auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
@@ -324,22 +445,58 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n
   return true;
 }

-bool MSANFModelParser::GetAttrValueForValueNode(const std::string &ref_attr_name, const std::string &value_node_name,
-                                                const onnx::TensorProto &attr_tensor) {
-  switch (kParseTypeSwitchMap[ref_attr_name]) {
-    case FORM_PARSE_SCALAR: {
-      return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
-    }
-    case FORM_PARSE_TENSOR: {
-      return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
-    }
-    case FORM_PARSE_TYPE: {
-      return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
-    }
-    default:
-      MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
-      return false;
+bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
+                                                const onnx::AttributeProto &attr_proto) {
+  if (!attr_proto.has_ref_attr_name()) {
+    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
+    return false;
+  }
+  const std::string &ref_attr_name = attr_proto.ref_attr_name();
+  string type;
+  std::size_t pos(0);
+  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
+    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
+  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
+    type = ref_attr_name.substr(pos, string("type:").length() - 1);
+  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
+    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
+  }
+  std::unordered_map<std::string, ValuePtr> kv;
+  for (int i = 0; i < attr_proto.tensors_size(); i++) {
+    const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
+    auto attr_name = attr_tensor.name();
+    switch (kParseTypeSwitchMap[type]) {
+      case FORM_PARSE_TYPE: {
+        return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
+      }
+      case FORM_PARSE_SCALAR: {
+        auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
+        kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
+        break;
+      }
+      case FORM_PARSE_TENSOR: {
+        return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
+      }
+      default:
+        MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
+        return false;
+    }
+  }
+  ValueNodePtr new_value_node;
+  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
+    if (kv.size() == 1) {
+      auto iter = kv.begin();
+      new_value_node = NewValueNode(iter->second);
+      new_value_node->set_abstract(iter->second->ToAbstract());
+    } else {
+      auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
+      new_value_node = NewValueNode(value_ptr);
+      new_value_node->set_abstract(value_ptr->ToAbstract());
+    }
+    anfnode_build_map_[value_node_name] = new_value_node;
   }
   return true;
 }

 bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
@@ -349,24 +506,26 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_pr
     MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
     return false;
   }
-  const std::string &ref_attr_name = attr_proto.ref_attr_name();
-  const onnx::TensorProto &attr_tensor = attr_proto.t();
-  return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
+  return GetAttrValueForValueNode(value_node_name, attr_proto);
 }

-AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
-  ShapeVector shape_vec;
-  const onnx::TensorProto &attr_tensor = attr_proto.t();
-  for (int i = 0; i < attr_tensor.dims_size(); ++i) {
-    shape_vec.push_back(attr_tensor.dims(i));
-  }
-  tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
-  MS_EXCEPTION_IF_NULL(tensor_info);
-  auto abstract = tensor_info->ToAbstract();
-  MS_EXCEPTION_IF_NULL(abstract);
-  return abstract;
+std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
+  const onnx::AttributeProto &attr_proto) {
+  std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
+  for (int i = 0; i < attr_proto.tensors_size(); ++i) {
+    ShapeVector shape_vec;
+    const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
+    for (int j = 0; j < attr_tensor.dims_size(); ++j) {
+      shape_vec.push_back(attr_tensor.dims(j));
+    }
+    tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
+    MS_EXCEPTION_IF_NULL(tensor_info);
+    auto abstract = tensor_info->ToAbstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
+  }
+  return kv;
 }

 CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
MS_EXCEPTION_IF_NULL
(
prim
);
prim
->
set_instance_name
(
node_type
);
AbstractBasePtr
abstract
=
nullptr
;
AbstractBasePtr
abstract_first
=
nullptr
;
AbstractBasePtr
abstract_second
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
abstract
::
AbstractBasePtr
>
kv
;
string
shape_ref_attr_name
;
for
(
int
i
=
0
;
i
<
node_proto
.
attribute_size
();
++
i
)
{
const
onnx
::
AttributeProto
&
attr_proto
=
node_proto
.
attribute
(
i
);
if
(
attr_proto
.
name
()
==
kCNodeShapeAttr
)
{
abstract
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
attr_proto
.
name
()
==
kCNodeShape1Attr
)
{
abstract_first
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
attr_proto
.
name
()
==
kCNodeShape2Attr
)
{
abstract_second
=
GetAbstractForCNode
(
attr_proto
);
if
(
attr_proto
.
ref_attr_name
().
find
(
"shape:"
)
!=
string
::
npos
)
{
shape_ref_attr_name
=
attr_proto
.
ref_attr_name
();
kv
=
GetAbstractForCNode
(
attr_proto
);
continue
;
}
if
(
!
GetAttrValueForCNode
(
prim
,
attr_proto
))
{
...
...
@@ -419,24 +570,17 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
   }

   CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(cnode_ptr);
-  if (node_type == "LayerNorm") {
-    AbstractBasePtrList elem;
-    elem.push_back(abstract);
-    elem.push_back(abstract_first);
-    elem.push_back(abstract_second);
-    cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
-  } else if (node_type == "ArgMaxWithValue") {
-    AbstractBasePtrList elem;
-    elem.push_back(abstract);
-    elem.push_back(abstract_first);
-    cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
-  } else if (nullptr == abstract) {
+  if (0 == kv.size()) {
     AbstractBasePtrList elem;
     for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
       elem.push_back(cnode_ptr->input(index)->abstract());
     }
     cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
+  } else if (1 == kv.size()) {
+    std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
+    cnode_ptr->set_abstract(iter->second);
+  } else {
+    auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
+    cnode_ptr->set_abstract(abstract);
   }

   cnode_ptr->set_fullname_with_scope(fullname_with_scope);
@@ -471,19 +615,15 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
   } else {
     const onnx::ValueInfoProto &output_node = importProto.output(0);
     const onnx::TypeProto &output_typeproto = output_node.type();
     int output_type = output_typeproto.tensor_type().elem_type();
     ShapeVector output_shape;
     for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
       output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value());
     }
     tensor::TensorPtr tensor_return = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[output_type], output_shape);
     inputs.clear();
     inputs.push_back(NewValueNode(prim::kPrimReturn));
     inputs.push_back(cnode_ptr);
     auto return_node = outputFuncGraph->NewCNode(inputs);
     MS_EXCEPTION_IF_NULL(return_node);
     return_node->set_abstract(tensor_return->ToAbstract());
     outputFuncGraph->set_return(return_node);
     MS_LOG(INFO) << "Construct funcgraph finined, all success!";
   }
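On the loading side, the ParserAttr template added above walks the bracketed ref_attr_name string with an explicit stack, resolving each leaf name through the name-to-value map collected from the attribute's tensors. The standalone sketch below (hypothetical types and names, not the commit's code) shows the same parse done recursively, under the assumption that the "scalar:"/"Tuple"/"List" prefixes have already been stripped as ParserScalarAttrValue does.

// Illustrative sketch only: parses a reduced attr string such as
// "[value0,[value1,value2,],]," into nested groups, looking each leaf
// name up in a name-to-value map, mirroring what ParserAttr does.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Node {
  int leaf = 0;                // set when the node is a leaf value
  std::vector<Node> children;  // set when the node is a bracketed group
};

Node Parse(const std::string &str, const std::map<std::string, int> &kv, size_t *pos) {
  Node group;
  while (*pos < str.size()) {
    char c = str[*pos];
    if (c == '[') {
      ++*pos;
      group.children.push_back(Parse(str, kv, pos));  // recurse into a nested group
    } else if (c == ']') {
      ++*pos;
      return group;  // close the current group
    } else if (c == ',') {
      ++*pos;  // separators carry no value
    } else {
      size_t start = *pos;
      while (*pos < str.size() && str[*pos] != '[' && str[*pos] != ']' && str[*pos] != ',') ++*pos;
      Node leaf;
      leaf.leaf = kv.at(str.substr(start, *pos - start));  // resolve the leaf name to its stored value
      group.children.push_back(leaf);
    }
  }
  return group;
}

int main() {
  std::map<std::string, int> kv = {{"value0", 10}, {"value1", 20}, {"value2", 30}};
  size_t pos = 0;
  Node root = Parse("[value0,[value1,value2,],],", kv, &pos);
  assert(root.children.size() == 1);              // the outer "[ ... ]," group
  assert(root.children[0].children.size() == 2);  // value0 plus the nested pair
  std::cout << root.children[0].children[0].leaf << std::endl;  // prints 10
  return 0;
}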
mindspore/ccsrc/utils/load_onnx/anf_model_parser.h

@@ -52,18 +52,17 @@ class MSANFModelParser {
   bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
   bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
                                  const onnx::TensorProto &attr_tensor);
-  bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                   const onnx::TensorProto &attr_tensor);
+  ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
   bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
                                    const onnx::TensorProto &attr_tensor);
   bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
   bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
   bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
-  bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
-                                const onnx::TensorProto &attr_tensor);
+  bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor);
   bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
-  AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
+  std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(const onnx::AttributeProto &attr_proto);

   std::string producer_name_;
   int model_version_;