Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
61d5539f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
61d5539f
编写于
9月 08, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 08, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5784 fix codedex & reviewbot
Merge pull request !5784 from Margaret_wangrui/codedex_bot
上级
c65f32b2
1f107d5a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
81 addition
and
77 deletion
+81
-77
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+0
-1
mindspore/ccsrc/backend/session/executor.cc
mindspore/ccsrc/backend/session/executor.cc
+1
-1
mindspore/ccsrc/backend/session/infer_session.cc
mindspore/ccsrc/backend/session/infer_session.cc
+6
-25
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+70
-50
mindspore/ccsrc/backend/session/session_basic.h
mindspore/ccsrc/backend/session/session_basic.h
+4
-0
未找到文件。
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
浏览文件 @
61d5539f
...
@@ -1220,6 +1220,5 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
...
@@ -1220,6 +1220,5 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
}
}
return
true
;
return
true
;
}
}
}
// namespace session
}
// namespace session
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/session/executor.cc
浏览文件 @
61d5539f
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
session
{
namespace
session
{
namespace
{
namespace
{
void
UpdateOutputTensors
(
VectorRef
*
outputs
,
void
UpdateOutputTensors
(
const
VectorRef
*
outputs
,
const
std
::
map
<
tensor
::
TensorPtr
,
session
::
KernelWithIndex
>
&
tensor_to_node
)
{
const
std
::
map
<
tensor
::
TensorPtr
,
session
::
KernelWithIndex
>
&
tensor_to_node
)
{
MS_EXCEPTION_IF_NULL
(
outputs
);
MS_EXCEPTION_IF_NULL
(
outputs
);
for
(
auto
item
:
*
outputs
)
{
for
(
auto
item
:
*
outputs
)
{
...
...
mindspore/ccsrc/backend/session/infer_session.cc
浏览文件 @
61d5539f
...
@@ -35,7 +35,6 @@ using std::vector;
...
@@ -35,7 +35,6 @@ using std::vector;
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
mindspore
{
namespace
mindspore
{
namespace
inference
{
namespace
inference
{
std
::
shared_ptr
<
InferSession
>
InferSession
::
CreateSession
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
std
::
shared_ptr
<
InferSession
>
InferSession
::
CreateSession
(
const
std
::
string
&
device
,
uint32_t
device_id
)
{
try
{
try
{
auto
session
=
std
::
make_shared
<
MSInferSession
>
();
auto
session
=
std
::
make_shared
<
MSInferSession
>
();
...
@@ -271,36 +270,18 @@ void MSInferSession::RegAllOp() {
...
@@ -271,36 +270,18 @@ void MSInferSession::RegAllOp() {
MsContext
::
GetInstance
()
->
set_param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
MsContext
::
GetInstance
()
->
set_param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
Py_Initialize
();
Py_Initialize
();
auto
c_expression
=
PyImport_ImportModule
(
"mindspore._c_expression"
);
auto
c_expression
=
PyImport_ImportModule
(
"mindspore._c_expression"
);
if
(
c_expression
==
nullptr
)
{
MS_EXCEPTION_IF_NULL
(
c_expression
);
MS_LOG
(
EXCEPTION
)
<<
"Failed to import mindspore._c_expression module."
;
return
;
}
PyObject
*
c_expression_dict
=
PyModule_GetDict
(
c_expression
);
PyObject
*
c_expression_dict
=
PyModule_GetDict
(
c_expression
);
if
(
c_expression_dict
==
nullptr
)
{
MS_EXCEPTION_IF_NULL
(
c_expression_dict
);
MS_LOG
(
EXCEPTION
)
<<
"Failed to get dict from mindspore._c_expression module."
;
return
;
}
PyObject
*
op_info_loader_class
=
PyDict_GetItemString
(
c_expression_dict
,
"OpInfoLoaderPy"
);
PyObject
*
op_info_loader_class
=
PyDict_GetItemString
(
c_expression_dict
,
"OpInfoLoaderPy"
);
if
(
op_info_loader_class
==
nullptr
)
{
MS_EXCEPTION_IF_NULL
(
op_info_loader_class
);
MS_LOG
(
EXCEPTION
)
<<
"Failed to get op_info_loader_class from mindspore._c_expression."
;
return
;
}
PyObject
*
op_info_loader
=
PyInstanceMethod_New
(
op_info_loader_class
);
PyObject
*
op_info_loader
=
PyInstanceMethod_New
(
op_info_loader_class
);
if
(
op_info_loader
==
nullptr
)
{
MS_EXCEPTION_IF_NULL
(
op_info_loader
);
MS_LOG
(
EXCEPTION
)
<<
"Failed to create op_info_loader instance."
;
return
;
}
PyObject
*
op_info_loader_ins
=
PyObject_CallObject
(
op_info_loader
,
nullptr
);
PyObject
*
op_info_loader_ins
=
PyObject_CallObject
(
op_info_loader
,
nullptr
);
if
(
op_info_loader_ins
==
nullptr
)
{
MS_EXCEPTION_IF_NULL
(
op_info_loader_ins
);
MS_LOG
(
EXCEPTION
)
<<
"Failed to call op_info_loader instance."
;
return
;
}
auto
all_ops_info_vector_addr_ul
=
PyObject_CallMethod
(
op_info_loader_ins
,
"get_all_ops_info"
,
nullptr
);
auto
all_ops_info_vector_addr_ul
=
PyObject_CallMethod
(
op_info_loader_ins
,
"get_all_ops_info"
,
nullptr
);
if
(
all_ops_info_vector_addr_ul
==
nullptr
)
{
MS_EXCEPTION_IF_NULL
(
all_ops_info_vector_addr_ul
);
MS_LOG
(
EXCEPTION
)
<<
"Failed to call get_all_ops_addr."
;
return
;
}
auto
all_ops_info_vector_addr
=
PyLong_AsVoidPtr
(
all_ops_info_vector_addr_ul
);
auto
all_ops_info_vector_addr
=
PyLong_AsVoidPtr
(
all_ops_info_vector_addr_ul
);
auto
all_ops_info
=
static_cast
<
std
::
vector
<
kernel
::
OpInfo
*>
*>
(
all_ops_info_vector_addr
);
auto
all_ops_info
=
static_cast
<
std
::
vector
<
kernel
::
OpInfo
*>
*>
(
all_ops_info_vector_addr
);
for
(
auto
op_info
:
*
all_ops_info
)
{
for
(
auto
op_info
:
*
all_ops_info
)
{
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
61d5539f
...
@@ -494,54 +494,52 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
...
@@ -494,54 +494,52 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
return
make_tuple
;
return
make_tuple
;
}
}
CNodePtr
SessionBasic
::
CreateNewCNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
,
void
SessionBasic
::
GetCNodeInfo
(
const
CNodePtr
&
cnode
,
std
::
vector
<
AnfNodePtr
>
*
cnode_inputs
)
{
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
*
other_graph_cnode
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
cnode_inputs
);
MS_EXCEPTION_IF_NULL
(
other_graph_cnode
);
// get primitive of old node
std
::
vector
<
AnfNodePtr
>
cnode_inputs
;
auto
prim
=
AnfAlgo
::
GetCNodePrimitive
(
cnode
);
auto
prim
=
AnfAlgo
::
GetCNodePrimitive
(
cnode
);
if
(
prim
!=
nullptr
)
{
if
(
prim
!=
nullptr
)
{
// push attr to inputs[0] of new cnode
// push attr to inputs[0] of new cnode
cnode_inputs
.
push_back
(
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
*
prim
)));
cnode_inputs
->
push_back
(
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
*
prim
)));
}
else
{
}
else
{
auto
fg
=
AnfAlgo
::
GetCNodeFuncGraphPtr
(
cnode
);
auto
fg
=
AnfAlgo
::
GetCNodeFuncGraphPtr
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fg
);
MS_EXCEPTION_IF_NULL
(
fg
);
auto
new_fg
=
BasicClone
(
fg
);
auto
new_fg
=
BasicClone
(
fg
);
cnode_inputs
.
push_back
(
std
::
make_shared
<
ValueNode
>
(
new_fg
));
cnode_inputs
->
push_back
(
std
::
make_shared
<
ValueNode
>
(
new_fg
));
}
}
}
void
SessionBasic
::
GetNewCNodeInputs
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
,
std
::
vector
<
AnfNodePtr
>
*
cnode_inputs
,
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
*
other_graph_cnode
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
other_graph_cnode
);
MS_EXCEPTION_IF_NULL
(
cnode_inputs
);
auto
origin_inputs
=
cnode
->
inputs
();
auto
origin_inputs
=
cnode
->
inputs
();
bool
optimize_depend
=
false
;
bool
optimize_depend
=
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimDepend
)
&&
origin_inputs
.
size
()
==
3
&&
bool
optimize_control_depend
=
false
;
origin_inputs
[
kRealInputIndexInDepend
]
->
isa
<
ValueNode
>
();
if
(
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimDepend
)
&&
origin_inputs
.
size
()
==
3
&&
bool
optimize_control_depend
=
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimControlDepend
)
&&
origin_inputs
.
size
()
==
3
;
origin_inputs
[
kRealInputIndexInDepend
]
->
isa
<
ValueNode
>
())
{
optimize_depend
=
true
;
}
if
(
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimControlDepend
)
&&
origin_inputs
.
size
()
==
3
)
{
optimize_control_depend
=
true
;
}
// if has multiple depends,only select first depend as parameter
// if has multiple depends,only select first depend as parameter
for
(
size_t
input_idx
=
1
;
input_idx
<
origin_inputs
.
size
();
input_idx
++
)
{
for
(
size_t
input_idx
=
1
;
input_idx
<
origin_inputs
.
size
();
input_idx
++
)
{
auto
anf
=
origin_inputs
[
input_idx
];
auto
anf
=
origin_inputs
[
input_idx
];
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
anf
);
// anf has been created before
// anf has been created before
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
cnode_inputs
->
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
continue
;
continue
;
}
else
if
(
other_graph_cnode
->
find
(
anf
)
!=
other_graph_cnode
->
end
())
{
}
else
if
(
other_graph_cnode
->
find
(
anf
)
!=
other_graph_cnode
->
end
())
{
cnode_inputs
.
push_back
((
*
other_graph_cnode
)[
anf
]);
cnode_inputs
->
push_back
((
*
other_graph_cnode
)[
anf
]);
continue
;
continue
;
}
else
if
(
anf
->
isa
<
ValueNode
>
()
&&
!
IsValueNode
<
FuncGraph
>
(
anf
))
{
}
else
if
(
anf
->
isa
<
ValueNode
>
()
&&
!
IsValueNode
<
FuncGraph
>
(
anf
))
{
// if input is a value node,
// if input is a value node,
auto
new_value_node
=
CreateNewValueNode
(
anf
,
graph
);
auto
new_value_node
=
CreateNewValueNode
(
anf
,
graph
);
if
(
new_value_node
!=
nullptr
)
{
if
(
new_value_node
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
new_value_node
);
cnode_inputs
->
emplace_back
(
new_value_node
);
}
}
continue
;
continue
;
}
else
if
(
anf
->
isa
<
Parameter
>
())
{
}
else
if
(
anf
->
isa
<
Parameter
>
())
{
auto
new_parameter
=
CreateNewParameterFromParameter
(
anf
,
graph
);
auto
new_parameter
=
CreateNewParameterFromParameter
(
anf
,
graph
);
cnode_inputs
.
push_back
(
new_parameter
);
cnode_inputs
->
push_back
(
new_parameter
);
if
(
GetGraphIdByNode
(
anf
)
==
kInvalidGraphId
)
{
if
(
GetGraphIdByNode
(
anf
)
==
kInvalidGraphId
)
{
graph
->
FrontBackendlMapAdd
(
anf
,
new_parameter
);
graph
->
FrontBackendlMapAdd
(
anf
,
new_parameter
);
}
else
{
}
else
{
...
@@ -549,20 +547,31 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
...
@@ -549,20 +547,31 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
}
}
continue
;
continue
;
}
else
if
(
optimize_depend
&&
input_idx
==
kDependAttachNodeIndex
)
{
}
else
if
(
optimize_depend
&&
input_idx
==
kDependAttachNodeIndex
)
{
cnode_inputs
.
push_back
(
origin_inputs
[
kRealInputIndexInDepend
]);
cnode_inputs
->
push_back
(
origin_inputs
[
kRealInputIndexInDepend
]);
continue
;
continue
;
}
else
if
(
optimize_control_depend
)
{
}
else
if
(
optimize_control_depend
)
{
cnode_inputs
.
push_back
(
NewValueNode
(
MakeValue
(
SizeToInt
(
input_idx
))));
cnode_inputs
->
push_back
(
NewValueNode
(
MakeValue
(
SizeToInt
(
input_idx
))));
}
else
{
}
else
{
// the input node is a cnode from other graph
// the input node is a cnode from other graph
auto
parameter_from_cnode
=
CreateNewParameterFromCNode
(
anf
,
graph
);
auto
parameter_from_cnode
=
CreateNewParameterFromCNode
(
anf
,
graph
);
if
(
parameter_from_cnode
==
nullptr
)
{
if
(
parameter_from_cnode
==
nullptr
)
{
parameter_from_cnode
=
NewValueNode
(
MakeValue
(
SizeToInt
(
input_idx
)));
parameter_from_cnode
=
NewValueNode
(
MakeValue
(
SizeToInt
(
input_idx
)));
}
}
cnode_inputs
.
push_back
(
parameter_from_cnode
);
cnode_inputs
->
push_back
(
parameter_from_cnode
);
(
*
other_graph_cnode
)[
anf
]
=
parameter_from_cnode
;
(
*
other_graph_cnode
)[
anf
]
=
parameter_from_cnode
;
}
}
}
}
}
CNodePtr
SessionBasic
::
CreateNewCNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
,
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
*
other_graph_cnode
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
other_graph_cnode
);
// get primitive of old node
std
::
vector
<
AnfNodePtr
>
cnode_inputs
;
GetCNodeInfo
(
cnode
,
&
cnode_inputs
);
GetNewCNodeInputs
(
cnode
,
graph
,
&
cnode_inputs
,
other_graph_cnode
);
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceCopy
>
(
cnode
->
debug_info
()));
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceCopy
>
(
cnode
->
debug_info
()));
auto
new_cnode
=
graph
->
NewCNode
(
cnode_inputs
);
auto
new_cnode
=
graph
->
NewCNode
(
cnode_inputs
);
TraceManager
::
EndTrace
();
TraceManager
::
EndTrace
();
...
@@ -593,6 +602,42 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
...
@@ -593,6 +602,42 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
return
partial_node
;
return
partial_node
;
}
}
std
::
vector
<
AnfNodePtr
>
SessionBasic
::
CreateCallSwitchInputs
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
cnode_inputs
=
{
graph
->
NewValueNode
(
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimCall
->
name
())))};
auto
attr_input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
MS_EXCEPTION_IF_NULL
(
attr_input
);
auto
cnode_input
=
graph
->
GetBackendAnfByFrontAnf
(
attr_input
);
auto
switch_cnode
=
cnode_input
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
switch_cnode
);
if
(
cnode
->
inputs
().
size
()
<
2
)
{
cnode_inputs
=
switch_cnode
->
inputs
();
return
cnode_inputs
;
}
std
::
vector
<
AnfNodePtr
>
switch_inputs
=
{
switch_cnode
->
input
(
kAnfPrimitiveIndex
),
switch_cnode
->
input
(
kFirstDataInputIndex
)};
for
(
size_t
index
=
kFirstBranchInSwitch
;
index
<
switch_cnode
->
inputs
().
size
();
index
++
)
{
auto
node
=
switch_cnode
->
input
(
index
);
// there is real input in call, should put it to true and false branch in switch
if
(
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimPartial
))
{
auto
partial_node
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
partial_node
);
std
::
vector
<
AnfNodePtr
>
partial_inputs
=
partial_node
->
inputs
();
partial_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
cnode
->
input
(
kFirstDataInputIndex
)));
auto
new_partial
=
graph
->
NewCNode
(
partial_inputs
);
switch_inputs
.
emplace_back
(
new_partial
);
}
}
if
(
switch_inputs
.
size
()
<
kSwitchInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"Switch inputs size: "
<<
switch_inputs
.
size
()
<<
"less than "
<<
kSwitchInputSize
;
}
auto
switch_node
=
graph
->
NewCNode
(
switch_inputs
);
cnode_inputs
.
emplace_back
(
switch_node
);
return
cnode_inputs
;
}
std
::
vector
<
AnfNodePtr
>
SessionBasic
::
CreateSwitchOrPartialNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
)
{
std
::
vector
<
AnfNodePtr
>
SessionBasic
::
CreateSwitchOrPartialNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
...
@@ -618,32 +663,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
...
@@ -618,32 +663,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
});
});
return
cnode_inputs
;
return
cnode_inputs
;
}
else
if
(
AnfAlgo
::
CheckPrimitiveType
(
cnode_input
,
prim
::
kPrimSwitch
))
{
}
else
if
(
AnfAlgo
::
CheckPrimitiveType
(
cnode_input
,
prim
::
kPrimSwitch
))
{
auto
switch_cnode
=
cnode_input
->
cast
<
CNodePtr
>
();
return
CreateCallSwitchInputs
(
cnode
,
graph
);
MS_EXCEPTION_IF_NULL
(
switch_cnode
);
if
(
cnode
->
inputs
().
size
()
<
2
)
{
cnode_inputs
=
switch_cnode
->
inputs
();
return
cnode_inputs
;
}
std
::
vector
<
AnfNodePtr
>
switch_inputs
=
{
switch_cnode
->
input
(
kAnfPrimitiveIndex
),
switch_cnode
->
input
(
kFirstDataInputIndex
)};
for
(
size_t
index
=
kFirstBranchInSwitch
;
index
<
switch_cnode
->
inputs
().
size
();
index
++
)
{
auto
node
=
switch_cnode
->
input
(
index
);
// there is real input in call, should put it to true and false branch in switch
if
(
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimPartial
))
{
auto
partial_node
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
partial_node
);
std
::
vector
<
AnfNodePtr
>
partial_inputs
=
partial_node
->
inputs
();
partial_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
cnode
->
input
(
kFirstDataInputIndex
)));
auto
new_partial
=
graph
->
NewCNode
(
partial_inputs
);
switch_inputs
.
emplace_back
(
new_partial
);
}
}
if
(
switch_inputs
.
size
()
<
kSwitchInputSize
)
{
MS_LOG
(
EXCEPTION
)
<<
"Switch inputs size: "
<<
switch_inputs
.
size
()
<<
"less than "
<<
kSwitchInputSize
;
}
auto
switch_node
=
graph
->
NewCNode
(
switch_inputs
);
cnode_inputs
.
emplace_back
(
switch_node
);
return
cnode_inputs
;
}
}
MS_LOG
(
EXCEPTION
)
<<
"CNode input[0] must be partial or switch."
;
MS_LOG
(
EXCEPTION
)
<<
"CNode input[0] must be partial or switch."
;
}
}
...
...
mindspore/ccsrc/backend/session/session_basic.h
浏览文件 @
61d5539f
...
@@ -131,6 +131,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
...
@@ -131,6 +131,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std
::
vector
<
AnfNodePtr
>
CreateSwitchOrPartialNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
std
::
vector
<
AnfNodePtr
>
CreateSwitchOrPartialNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
std
::
vector
<
AnfNodePtr
>
CreateValueNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
std
::
vector
<
AnfNodePtr
>
CreateValueNode
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
void
CreateCNodeInputs
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
,
std
::
vector
<
AnfNodePtr
>
*
cnode_inputs
);
void
CreateCNodeInputs
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
,
std
::
vector
<
AnfNodePtr
>
*
cnode_inputs
);
std
::
vector
<
AnfNodePtr
>
CreateCallSwitchInputs
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
);
void
GetCNodeInfo
(
const
CNodePtr
&
cnode
,
std
::
vector
<
AnfNodePtr
>
*
cnode_inputs
);
void
GetNewCNodeInputs
(
const
CNodePtr
&
cnode
,
KernelGraph
*
graph
,
std
::
vector
<
AnfNodePtr
>
*
cnode_inputs
,
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
*
other_graph_cnode
);
protected:
protected:
void
RunInfer
(
NotNull
<
FuncGraphPtr
>
func_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
);
void
RunInfer
(
NotNull
<
FuncGraphPtr
>
func_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录