Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d4671497
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4671497
编写于
8月 10, 2020
作者:
Y
yeyunpeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix op multi output problem
上级
8e3c8f3d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
119 addition
and
71 deletion
+119
-71
mindspore/lite/schema/model.fbs
mindspore/lite/schema/model.fbs
+3
-1
mindspore/lite/schema/ops.fbs
mindspore/lite/schema/ops.fbs
+6
-0
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+71
-59
mindspore/lite/src/common/anf_exporter/anf_exporter.h
mindspore/lite/src/common/anf_exporter/anf_exporter.h
+2
-3
mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc
...e/lite/src/common/anf_importer/import_from_meta_graphT.cc
+37
-8
未找到文件。
mindspore/lite/schema/model.fbs
浏览文件 @
d4671497
...
@@ -189,7 +189,9 @@ union PrimitiveType {
...
@@ -189,7 +189,9 @@ union PrimitiveType {
ActivationGrad,
ActivationGrad,
PriorBox,
PriorBox,
SpaceToBatchND,
SpaceToBatchND,
TopKV2
TopKV2,
Return,
MakeTuple
}
}
enum QuantType: int {
enum QuantType: int {
...
...
mindspore/lite/schema/ops.fbs
浏览文件 @
d4671497
...
@@ -864,3 +864,9 @@ table TopKV2 {
...
@@ -864,3 +864,9 @@ table TopKV2 {
sorted : bool = true;
sorted : bool = true;
}
}
table MakeTuple {
}
table Return {
}
\ No newline at end of file
mindspore/lite/src/common/anf_exporter/anf_exporter.cc
浏览文件 @
d4671497
...
@@ -81,8 +81,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
...
@@ -81,8 +81,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
return
false
;
return
false
;
}
}
ValueNodePtr
valueNode
=
utils
::
cast
<
ValueNodePtr
>
(
indexNode
);
ValueNodePtr
valueNode
=
utils
::
cast
<
ValueNodePtr
>
(
indexNode
);
mapRemoveGetItem_
[
tupleGetItemNode
->
input
(
1
)
->
fullname_with_scope
()]
=
mapRemoveGetItem_
[
tupleGetItemNode
->
input
(
1
)
->
fullname_with_scope
()]
=
GetValue
<
int
>
(
valueNode
->
value
());
GetValue
<
int
>
(
valueNode
->
value
());
}
else
{
}
else
{
inputs
.
emplace_back
(
cnode
->
input
(
i
));
inputs
.
emplace_back
(
cnode
->
input
(
i
));
}
}
...
@@ -114,16 +113,34 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
...
@@ -114,16 +113,34 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto
metaGraphT
=
std
::
make_unique
<
schema
::
MetaGraphT
>
();
auto
metaGraphT
=
std
::
make_unique
<
schema
::
MetaGraphT
>
();
for
(
const
auto
&
cnode
:
cnodes
)
{
for
(
const
auto
&
cnode
:
cnodes
)
{
auto
primitive
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
auto
primitive
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
if
(
primitive
!=
nullptr
&&
if
(
primitive
!=
nullptr
)
{
RemoveNodeInAnfExporter
.
count
(
primitive
->
name
())
!=
0
)
{
if
(
RemoveNodeInAnfExporter
.
count
(
primitive
->
name
())
!=
0
)
{
continue
;
continue
;
}
}
else
{
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
PrimitiveTValue
>>
(
cnode
->
input
(
0
));
auto
primT
=
primitiveT_value
->
GetPrimitiveT
();
if
(
primT
->
value
.
type
==
schema
::
PrimitiveType_TupleGetItem
||
primT
->
value
.
type
==
schema
::
PrimitiveType_MakeTuple
)
{
continue
;
}
}
}
mapRemoveGetItem_
.
clear
();
mapRemoveGetItem_
.
clear
();
RemoveIfMakeTuple
(
cnode
);
RemoveIfMakeTuple
(
cnode
);
RemoveIfTupleGetItem
(
cnode
);
RemoveIfTupleGetItem
(
cnode
);
if
(
primitive
!=
nullptr
&&
primitive
->
name
()
==
prim
::
kPrimReturn
->
name
())
{
AddOutPutIfReturn
(
metaGraphT
,
cnode
);
if
(
primitive
!=
nullptr
)
{
continue
;
if
(
primitive
->
name
()
==
prim
::
kPrimReturn
->
name
())
{
AddOutPutIfReturn
(
metaGraphT
,
cnode
);
continue
;
}
}
else
{
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
PrimitiveTValue
>>
(
cnode
->
input
(
0
));
auto
primT
=
primitiveT_value
->
GetPrimitiveT
();
if
(
primT
->
value
.
type
==
schema
::
PrimitiveType_Return
)
{
AddOutPutIfReturn
(
metaGraphT
,
cnode
);
continue
;
}
}
}
auto
node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
auto
node
=
std
::
make_unique
<
schema
::
CNodeT
>
();
...
@@ -134,27 +151,24 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
...
@@ -134,27 +151,24 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
primitive
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
primitive
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
MS_ASSERT
(
primitive
!=
nullptr
);
MS_ASSERT
(
primitive
!=
nullptr
);
std
::
string
opType
=
primitive
->
name
();
std
::
string
opType
=
primitive
->
name
();
auto
nodeParser
=
auto
nodeParser
=
AnfNodePopulaterRegistry
::
GetInstance
()
->
GetNodePopulater
(
opType
);
AnfNodePopulaterRegistry
::
GetInstance
()
->
GetNodePopulater
(
opType
);
if
(
nodeParser
==
nullptr
)
{
if
(
nodeParser
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Find op parser failed, opType: "
<<
opType
;
MS_LOG
(
ERROR
)
<<
"Find op parser failed, opType: "
<<
opType
;
return
nullptr
;
return
nullptr
;
}
}
std
::
vector
<
schema
::
TensorT
*>
outputs
;
std
::
vector
<
schema
::
TensorT
*>
outputs
;
if
(
utils
::
isa
<
abstract
::
AbstractSequeue
>
(
cnode
->
abstract
()))
{
if
(
utils
::
isa
<
abstract
::
AbstractSequeue
>
(
cnode
->
abstract
()))
{
auto
abstract_cnode
=
auto
abstract_cnode
=
utils
::
cast
<
abstract
::
AbstractSequeuePtr
>
(
cnode
->
abstract
());
utils
::
cast
<
abstract
::
AbstractSequeuePtr
>
(
cnode
->
abstract
());
outputs
.
resize
(
abstract_cnode
->
size
());
outputs
.
resize
(
abstract_cnode
->
size
());
}
}
nodeParser
->
Parse
(
cnode
,
node
.
get
(),
&
outputs
);
nodeParser
->
Parse
(
cnode
,
node
.
get
(),
&
outputs
);
SetOpInputNode
(
cnode
,
metaGraphT
.
get
(),
node
.
get
());
SetOpInputNode
(
cnode
,
metaGraphT
.
get
(),
node
.
get
());
SetOpOutputNode
(
outputs
,
metaGraphT
.
get
(),
node
.
get
());
SetOpOutputNode
(
cnode
,
outputs
,
metaGraphT
.
get
(),
node
.
get
());
metaGraphT
->
nodes
.
emplace_back
(
std
::
move
(
node
));
metaGraphT
->
nodes
.
emplace_back
(
std
::
move
(
node
));
continue
;
continue
;
}
}
auto
primitiveT_value
=
auto
primitiveT_value
=
GetValueNode
<
std
::
shared_ptr
<
PrimitiveTValue
>>
(
cnode
->
input
(
0
));
GetValueNode
<
std
::
shared_ptr
<
PrimitiveTValue
>>
(
cnode
->
input
(
0
));
if
(
primitiveT_value
==
nullptr
)
{
if
(
primitiveT_value
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"PrimitiveT_value is nullptr"
;
MS_LOG
(
ERROR
)
<<
"PrimitiveT_value is nullptr"
;
return
nullptr
;
return
nullptr
;
...
@@ -166,11 +180,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
...
@@ -166,11 +180,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
return
nullptr
;
return
nullptr
;
}
}
node
->
primitive
=
node
->
primitive
=
std
::
unique_ptr
<
schema
::
PrimitiveT
>
(
primitiveT_value
->
GetPrimitiveT
());
std
::
unique_ptr
<
schema
::
PrimitiveT
>
(
primitiveT_value
->
GetPrimitiveT
());
std
::
vector
<
schema
::
TensorT
*>
outputs
;
std
::
vector
<
schema
::
TensorT
*>
outputs
;
SetOpInputNode
(
cnode
,
metaGraphT
.
get
(),
node
.
get
());
SetOpInputNode
(
cnode
,
metaGraphT
.
get
(),
node
.
get
());
SetOpOutputNode
(
outputs
,
metaGraphT
.
get
(),
node
.
get
());
SetOpOutputNode
(
cnode
,
outputs
,
metaGraphT
.
get
(),
node
.
get
());
// add quant param
// add quant param
node
->
quantType
=
primitiveT_value
->
GetQuantType
();
node
->
quantType
=
primitiveT_value
->
GetQuantType
();
...
@@ -244,9 +257,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
...
@@ -244,9 +257,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
return
metaGraphT
.
release
();
return
metaGraphT
.
release
();
}
}
void
AnfExporter
::
SetOpInputNode
(
const
CNodePtr
&
cnode
,
void
AnfExporter
::
SetOpInputNode
(
const
CNodePtr
&
cnode
,
schema
::
MetaGraphT
*
meta_graph
,
schema
::
CNodeT
*
fbNode
)
{
schema
::
MetaGraphT
*
meta_graph
,
schema
::
CNodeT
*
fbNode
)
{
MS_ASSERT
(
nullptr
!=
meta_graph
);
MS_ASSERT
(
nullptr
!=
meta_graph
);
MS_ASSERT
(
nullptr
!=
fbNode
);
MS_ASSERT
(
nullptr
!=
fbNode
);
if
(
cnode
->
inputs
().
size
()
<=
1
)
{
if
(
cnode
->
inputs
().
size
()
<=
1
)
{
...
@@ -281,38 +292,30 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
...
@@ -281,38 +292,30 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
auto
paramTensor
=
std
::
make_unique
<
schema
::
TensorT
>
();
auto
paramTensor
=
std
::
make_unique
<
schema
::
TensorT
>
();
auto
abstractBase
=
paramNode
->
abstract
();
auto
abstractBase
=
paramNode
->
abstract
();
if
(
abstractBase
==
nullptr
)
{
if
(
abstractBase
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Abstract of parameter is nullptr, "
MS_LOG
(
ERROR
)
<<
"Abstract of parameter is nullptr, "
<<
paramNode
->
name
();
<<
paramNode
->
name
();
MS_ASSERT
(
false
);
MS_ASSERT
(
false
);
return
;
return
;
}
}
if
(
!
utils
::
isa
<
abstract
::
AbstractTensorPtr
>
(
abstractBase
))
{
if
(
!
utils
::
isa
<
abstract
::
AbstractTensorPtr
>
(
abstractBase
))
{
MS_LOG
(
ERROR
)
<<
"Abstract of parameter should be anstract tensor, "
MS_LOG
(
ERROR
)
<<
"Abstract of parameter should be anstract tensor, "
<<
paramNode
->
name
();
<<
paramNode
->
name
();
MS_ASSERT
(
false
);
MS_ASSERT
(
false
);
return
;
return
;
}
}
auto
abstractTensor
=
auto
abstractTensor
=
utils
::
cast
<
abstract
::
AbstractTensorPtr
>
(
abstractBase
);
utils
::
cast
<
abstract
::
AbstractTensorPtr
>
(
abstractBase
);
auto
typePtr
=
abstractTensor
->
element
()
->
GetTypeTrack
();
auto
typePtr
=
abstractTensor
->
element
()
->
GetTypeTrack
();
MS_ASSERT
(
typePtr
!=
nullptr
);
MS_ASSERT
(
typePtr
!=
nullptr
);
paramTensor
->
dataType
=
typePtr
->
type_id
();
paramTensor
->
dataType
=
typePtr
->
type_id
();
if
(
!
utils
::
isa
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
()))
{
if
(
!
utils
::
isa
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
()))
{
MS_LOG
(
ERROR
)
<<
"Shape of Abstract of parameter should be ShapePtr, "
MS_LOG
(
ERROR
)
<<
"Shape of Abstract of parameter should be ShapePtr, "
<<
paramNode
->
name
();
<<
paramNode
->
name
();
MS_ASSERT
(
false
);
MS_ASSERT
(
false
);
return
;
return
;
}
}
paramTensor
->
dims
=
paramTensor
->
dims
=
utils
::
cast
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
())
->
shape
();
utils
::
cast
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
())
auto
paramValue
=
std
::
dynamic_pointer_cast
<
ParamValueLite
>
(
paramNode
->
default_param
());
->
shape
();
auto
paramValue
=
std
::
dynamic_pointer_cast
<
ParamValueLite
>
(
paramNode
->
default_param
());
if
(
paramValue
!=
nullptr
)
{
if
(
paramValue
!=
nullptr
)
{
paramTensor
->
nodeType
=
schema
::
NodeType_ValueNode
;
paramTensor
->
nodeType
=
schema
::
NodeType_ValueNode
;
paramTensor
->
data
.
resize
(
paramValue
->
tensor_size
());
paramTensor
->
data
.
resize
(
paramValue
->
tensor_size
());
memcpy
(
paramTensor
->
data
.
data
(),
paramValue
->
tensor_addr
(),
memcpy
(
paramTensor
->
data
.
data
(),
paramValue
->
tensor_addr
(),
paramValue
->
tensor_size
());
paramValue
->
tensor_size
());
for
(
auto
&
ite
:
paramValue
->
quant_param
())
{
for
(
auto
&
ite
:
paramValue
->
quant_param
())
{
auto
quantPar
=
std
::
make_unique
<
schema
::
QuantParamT
>
();
auto
quantPar
=
std
::
make_unique
<
schema
::
QuantParamT
>
();
quantPar
->
scale
=
ite
->
scale
;
quantPar
->
scale
=
ite
->
scale
;
...
@@ -326,8 +329,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
...
@@ -326,8 +329,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
paramTensor
->
dataType
=
paramValue
->
tensor_type
();
paramTensor
->
dataType
=
paramValue
->
tensor_type
();
}
}
}
}
nodeIdMap
[
paramNode
->
fullname_with_scope
()]
=
nodeIdMap
[
paramNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
meta_graph
->
allTensors
.
size
();
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
}
else
if
(
inputNode
->
isa
<
ValueNode
>
())
{
}
else
if
(
inputNode
->
isa
<
ValueNode
>
())
{
...
@@ -336,19 +338,15 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
...
@@ -336,19 +338,15 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
auto
value
=
valueNode
->
value
();
auto
value
=
valueNode
->
value
();
if
(
value
->
isa
<
lite
::
tensor
::
Tensor
>
())
{
if
(
value
->
isa
<
lite
::
tensor
::
Tensor
>
())
{
auto
valueAbstract
=
valueNode
->
abstract
();
auto
valueAbstract
=
valueNode
->
abstract
();
auto
abstractTensor
=
auto
abstractTensor
=
utils
::
cast
<
abstract
::
AbstractTensorPtr
>
(
valueAbstract
);
utils
::
cast
<
abstract
::
AbstractTensorPtr
>
(
valueAbstract
);
auto
typePtr
=
abstractTensor
->
element
()
->
GetTypeTrack
();
auto
typePtr
=
abstractTensor
->
element
()
->
GetTypeTrack
();
paramTensor
->
dataType
=
typePtr
->
type_id
();
paramTensor
->
dataType
=
typePtr
->
type_id
();
paramTensor
->
dims
=
paramTensor
->
dims
=
utils
::
cast
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
())
->
shape
();
utils
::
cast
<
abstract
::
ShapePtr
>
(
abstractTensor
->
BuildShape
())
->
shape
();
paramTensor
->
nodeType
=
schema
::
NodeType_ValueNode
;
paramTensor
->
nodeType
=
schema
::
NodeType_ValueNode
;
auto
data
=
value
->
cast
<
lite
::
tensor
::
TensorPtr
>
();
auto
data
=
value
->
cast
<
lite
::
tensor
::
TensorPtr
>
();
paramTensor
->
data
.
resize
(
data
->
Size
());
paramTensor
->
data
.
resize
(
data
->
Size
());
memcpy
(
paramTensor
->
data
.
data
(),
data
->
Data
(),
data
->
Size
());
memcpy
(
paramTensor
->
data
.
data
(),
data
->
Data
(),
data
->
Size
());
nodeIdMap
[
valueNode
->
fullname_with_scope
()]
=
nodeIdMap
[
valueNode
->
fullname_with_scope
()]
=
meta_graph
->
allTensors
.
size
();
meta_graph
->
allTensors
.
size
();
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
fbNode
->
inputIndex
.
emplace_back
(
meta_graph
->
allTensors
.
size
());
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
meta_graph
->
allTensors
.
emplace_back
(
std
::
move
(
paramTensor
));
}
else
if
(
value
->
isa
<
mindspore
::
Int32Imm
>
())
{
}
else
if
(
value
->
isa
<
mindspore
::
Int32Imm
>
())
{
...
@@ -376,30 +374,44 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
...
@@ -376,30 +374,44 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode,
}
}
}
}
void
AnfExporter
::
SetOpOutputNode
(
void
AnfExporter
::
SetOpOutputNode
(
const
CNodePtr
&
cnode
,
const
std
::
vector
<
schema
::
TensorT
*>
&
outputTensors
,
const
std
::
vector
<
schema
::
TensorT
*>
&
outputTensors
,
schema
::
MetaGraphT
*
graph
,
schema
::
CNodeT
*
fbnode
)
{
schema
::
MetaGraphT
*
graph
,
schema
::
CNodeT
*
cnode
)
{
MS_ASSERT
(
nullptr
!=
graph
);
MS_ASSERT
(
nullptr
!=
graph
);
MS_ASSERT
(
nullptr
!=
c
node
);
MS_ASSERT
(
nullptr
!=
fb
node
);
std
::
string
cnodeName
=
c
node
->
name
;
std
::
string
cnodeName
=
fb
node
->
name
;
if
(
!
outputTensors
.
empty
())
{
if
(
!
outputTensors
.
empty
())
{
int
i
=
0
;
int
i
=
0
;
for
(
auto
outputTensor
:
outputTensors
)
{
for
(
auto
outputTensor
:
outputTensors
)
{
std
::
string
name
=
cnodeName
+
"_o:"
+
std
::
to_string
(
i
);
std
::
string
name
=
cnodeName
+
"_o:"
+
std
::
to_string
(
i
);
auto
msTensor
=
new
schema
::
TensorT
();
msTensor
->
nodeType
=
schema
::
NodeType_Parameter
;
nodeIdMap
[
name
]
=
graph
->
allTensors
.
size
();
nodeIdMap
[
name
]
=
graph
->
allTensors
.
size
();
c
node
->
outputIndex
.
emplace_back
(
graph
->
allTensors
.
size
());
fb
node
->
outputIndex
.
emplace_back
(
graph
->
allTensors
.
size
());
graph
->
allTensors
.
emplace_back
(
ms
Tensor
);
graph
->
allTensors
.
emplace_back
(
output
Tensor
);
i
++
;
i
++
;
}
}
return
;
return
;
}
}
auto
msTensor
=
new
schema
::
TensorT
();
msTensor
->
nodeType
=
schema
::
NodeType_Parameter
;
if
(
utils
::
isa
<
abstract
::
AbstractTuple
>
(
cnode
->
abstract
()))
{
cnode
->
outputIndex
.
emplace_back
(
graph
->
allTensors
.
size
());
auto
tuple
=
std
::
reinterpret_pointer_cast
<
abstract
::
AbstractTuple
>
(
cnode
->
abstract
());
nodeIdMap
[
cnodeName
]
=
graph
->
allTensors
.
size
();
for
(
int
i
=
0
;
i
<
tuple
->
size
();
i
++
)
{
graph
->
allTensors
.
emplace_back
(
msTensor
);
auto
msTensor
=
new
schema
::
TensorT
();
msTensor
->
nodeType
=
schema
::
NodeType_Parameter
;
fbnode
->
outputIndex
.
emplace_back
(
graph
->
allTensors
.
size
());
if
(
tuple
->
size
()
==
1
)
{
nodeIdMap
[
cnodeName
]
=
graph
->
allTensors
.
size
();
}
else
{
std
::
string
name
=
cnodeName
+
"_o:"
+
std
::
to_string
(
i
);
nodeIdMap
[
name
]
=
graph
->
allTensors
.
size
();
}
graph
->
allTensors
.
emplace_back
(
msTensor
);
}
}
else
{
auto
msTensor
=
new
schema
::
TensorT
();
msTensor
->
nodeType
=
schema
::
NodeType_Parameter
;
fbnode
->
outputIndex
.
emplace_back
(
graph
->
allTensors
.
size
());
nodeIdMap
[
cnodeName
]
=
graph
->
allTensors
.
size
();
graph
->
allTensors
.
emplace_back
(
msTensor
);
}
}
}
schema
::
MetaGraphT
*
Export
(
const
FuncGraphPtr
&
funcGraph
)
{
schema
::
MetaGraphT
*
Export
(
const
FuncGraphPtr
&
funcGraph
)
{
...
...
mindspore/lite/src/common/anf_exporter/anf_exporter.h
浏览文件 @
d4671497
...
@@ -32,8 +32,8 @@ class AnfExporter {
...
@@ -32,8 +32,8 @@ class AnfExporter {
AnfExporter
()
=
default
;
AnfExporter
()
=
default
;
virtual
~
AnfExporter
()
=
default
;
virtual
~
AnfExporter
()
=
default
;
schema
::
MetaGraphT
*
Export
(
const
FuncGraphPtr
&
funcGraph
);
schema
::
MetaGraphT
*
Export
(
const
FuncGraphPtr
&
funcGraph
);
void
SetOpOutputNode
(
const
std
::
vector
<
schema
::
TensorT
*>
&
outputTensors
,
schema
::
MetaGraphT
*
graph
,
void
SetOpOutputNode
(
const
CNodePtr
&
cnode
,
const
std
::
vector
<
schema
::
TensorT
*>
&
outputTensors
,
schema
::
CNodeT
*
c
node
);
schema
::
MetaGraphT
*
graph
,
schema
::
CNodeT
*
fb
node
);
void
SetOpInputNode
(
const
CNodePtr
&
cnode
,
schema
::
MetaGraphT
*
meta_graph
,
schema
::
CNodeT
*
fbNode
);
void
SetOpInputNode
(
const
CNodePtr
&
cnode
,
schema
::
MetaGraphT
*
meta_graph
,
schema
::
CNodeT
*
fbNode
);
void
RemoveIfMakeTuple
(
const
CNodePtr
&
cnode
);
void
RemoveIfMakeTuple
(
const
CNodePtr
&
cnode
);
bool
RemoveIfTupleGetItem
(
const
CNodePtr
&
cnode
);
bool
RemoveIfTupleGetItem
(
const
CNodePtr
&
cnode
);
...
@@ -47,4 +47,3 @@ class AnfExporter {
...
@@ -47,4 +47,3 @@ class AnfExporter {
schema
::
MetaGraphT
*
Export
(
const
FuncGraphPtr
&
funcGraph
);
schema
::
MetaGraphT
*
Export
(
const
FuncGraphPtr
&
funcGraph
);
}
// namespace mindspore::lite
}
// namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc
浏览文件 @
d4671497
...
@@ -71,11 +71,11 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
...
@@ -71,11 +71,11 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
for
(
size_t
i
=
0
;
i
<
meta_graph_
->
nodes
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
meta_graph_
->
nodes
.
size
();
i
++
)
{
auto
&
cNode
=
meta_graph_
->
nodes
.
at
(
i
);
auto
&
cNode
=
meta_graph_
->
nodes
.
at
(
i
);
MS_EXCEPTION_IF_NULL
(
cNode
);
MS_EXCEPTION_IF_NULL
(
cNode
);
auto
tensor_id
=
cNode
->
outputIndex
.
front
();
if
(
nullptr
!=
GetNode
(
tensor_id
))
{
continue
;
}
bool
flag
=
false
;
if
(
cNode
->
outputIndex
.
size
()
>
1
)
{
flag
=
true
;
}
auto
primTValue
=
std
::
make_shared
<
PrimitiveTValue
>
(
cNode
->
primitive
.
release
());
auto
primTValue
=
std
::
make_shared
<
PrimitiveTValue
>
(
cNode
->
primitive
.
release
());
cNode
->
primitive
=
nullptr
;
cNode
->
primitive
=
nullptr
;
auto
value_node
=
NewValueNode
(
primTValue
);
auto
value_node
=
NewValueNode
(
primTValue
);
...
@@ -90,9 +90,39 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
...
@@ -90,9 +90,39 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
// todo: CheckInputNodeType, the first node should be op;
// todo: CheckInputNodeType, the first node should be op;
op_inputs
.
push_back
(
node
);
op_inputs
.
push_back
(
node
);
}
}
auto
cnode
=
func_graph_
->
NewCNode
(
op_inputs
);
cnode
->
set_fullname_with_scope
(
cNode
->
name
);
auto
new_cnode
=
func_graph_
->
NewCNode
(
op_inputs
);
AddNode
(
tensor_id
,
cnode
);
new_cnode
->
set_fullname_with_scope
(
cNode
->
name
);
std
::
vector
<
uint32_t
>
out_tensor_ids
=
cNode
->
outputIndex
;
AbstractBasePtrList
ptr_list
;
int
total
=
0
;
for
(
auto
out_tensor_id
:
out_tensor_ids
)
{
if
(
nullptr
!=
GetNode
(
out_tensor_id
))
{
ptr_list
.
push_back
(
GetNode
(
out_tensor_id
)
->
abstract
());
continue
;
}
std
::
vector
<
int
>
shape
;
auto
&
tensor
=
meta_graph_
->
allTensors
.
at
(
out_tensor_id
);
for
(
int
&
dim
:
tensor
->
dims
)
{
shape
.
push_back
(
dim
);
}
auto
type_id
=
static_cast
<
TypeId
>
(
tensor
->
dataType
);
auto
type_ptr
=
TypeIdToType
(
type_id
);
auto
abstract_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
type_ptr
,
shape
);
auto
getItemPrim
=
NewValueNode
(
prim
::
kPrimTupleGetItem
);
if
(
flag
)
{
auto
getItemIndex
=
NewValueNode
(
MakeValue
<
int
>
(
total
++
));
std
::
vector
<
AnfNodePtr
>
inputs
{
getItemPrim
,
new_cnode
,
getItemIndex
};
CNodePtr
new_item_cnode
=
func_graph_
->
NewCNode
(
inputs
);
AddNode
(
out_tensor_id
,
new_item_cnode
);
}
else
{
AddNode
(
out_tensor_id
,
new_cnode
);
}
ptr_list
.
push_back
(
std
::
move
(
abstract_tensor
));
}
new_cnode
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
ptr_list
));
}
}
return
RET_OK
;
return
RET_OK
;
}
}
...
@@ -120,4 +150,3 @@ void AnfImporterFromMetaGraphT::AddReturnCNode() {
...
@@ -120,4 +150,3 @@ void AnfImporterFromMetaGraphT::AddReturnCNode() {
FuncGraphPtr
AnfImporterFromMetaGraphT
::
GetResult
()
{
return
this
->
func_graph_
;
}
FuncGraphPtr
AnfImporterFromMetaGraphT
::
GetResult
()
{
return
this
->
func_graph_
;
}
}
// namespace mindspore::lite
}
// namespace mindspore::lite
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录