Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
817d1ae2
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
817d1ae2
编写于
6月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2403 Session code review
Merge pull request !2403 from JoyLvliang/session-code-review
上级
fa216697
4de9e250
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
20 addition
and
9 deletion
+20
-9
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+19
-4
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+1
-5
未找到文件。
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
817d1ae2
...
@@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
...
@@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return
nullptr
;
return
nullptr
;
}
}
auto
param_value
=
std
::
dynamic_pointer_cast
<
ParamValuePy
>
(
parameter
->
default_param
());
auto
param_value
=
std
::
dynamic_pointer_cast
<
ParamValuePy
>
(
parameter
->
default_param
());
MS_EXCEPTION_IF_NULL
(
param_value
);
auto
py_param
=
param_value
->
value
();
auto
py_param
=
param_value
->
value
();
return
py_param
.
ptr
();
return
py_param
.
ptr
();
}
}
...
@@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
...
@@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
}
if
(
node
->
isa
<
Parameter
>
())
{
if
(
node
->
isa
<
Parameter
>
())
{
for
(
size_t
input_idx
=
0
;
input_idx
<
graph
.
inputs
().
size
();
input_idx
++
)
{
for
(
size_t
input_idx
=
0
;
input_idx
<
graph
.
inputs
().
size
();
input_idx
++
)
{
if
(
input_idx
>
input_tensors
.
size
())
{
if
(
input_idx
>
=
input_tensors
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"input idx:"
<<
input_idx
<<
"out of range:"
<<
input_tensors
.
size
();
MS_LOG
(
EXCEPTION
)
<<
"input idx:"
<<
input_idx
<<
"out of range:"
<<
input_tensors
.
size
();
}
}
if
(
graph
.
inputs
()[
input_idx
]
==
node
)
{
if
(
graph
.
inputs
()[
input_idx
]
==
node
)
{
...
@@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
...
@@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
}
}
ValueNodePtr
CreateNewValueNode
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
ValueNodePtr
CreateNewValueNode
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
value_node
=
anf
->
cast
<
ValueNodePtr
>
();
auto
value_node
=
anf
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
value
=
value_node
->
value
();
auto
value
=
value_node
->
value
();
...
@@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
...
@@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
auto
value_node
=
std
::
make_shared
<
ValueNode
>
(
input_tensor
);
auto
value_node
=
std
::
make_shared
<
ValueNode
>
(
input_tensor
);
MS_EXCEPTION_IF_NULL
(
value_node
);
// construct abstract of value node
// construct abstract of value node
auto
type_of_tensor
=
input_tensor
->
Dtype
();
auto
type_of_tensor
=
input_tensor
->
Dtype
();
auto
shape_of_tensor
=
input_tensor
->
shape
();
auto
shape_of_tensor
=
input_tensor
->
shape
();
...
@@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
...
@@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
ParameterPtr
ConstructRunOpParameter
(
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
,
const
tensor
::
TensorPtr
&
input_tensor
,
ParameterPtr
ConstructRunOpParameter
(
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
,
const
tensor
::
TensorPtr
&
input_tensor
,
int
tensor_mask
)
{
int
tensor_mask
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
auto
param
=
graph
->
NewParameter
();
auto
param
=
graph
->
NewParameter
();
MS_EXCEPTION_IF_NULL
(
param
);
MS_EXCEPTION_IF_NULL
(
param
);
if
(
tensor_mask
==
kParameterWeightTensorMask
)
{
if
(
tensor_mask
==
kParameterWeightTensorMask
)
{
...
@@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
...
@@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
}
}
bool
ExistSummaryNode
(
const
KernelGraph
*
graph
)
{
bool
ExistSummaryNode
(
const
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
auto
ret
=
graph
->
get_return
();
auto
ret
=
graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
ret
);
MS_EXCEPTION_IF_NULL
(
ret
);
auto
all_nodes
=
DeepLinkedGraphSearch
(
ret
);
auto
all_nodes
=
DeepLinkedGraphSearch
(
ret
);
...
@@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
...
@@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if
(
!
anf
->
isa
<
Parameter
>
())
{
if
(
!
anf
->
isa
<
Parameter
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"anf["
<<
anf
->
DebugString
()
<<
"] is not a parameter"
;
MS_LOG
(
EXCEPTION
)
<<
"anf["
<<
anf
->
DebugString
()
<<
"] is not a parameter"
;
}
}
MS_EXCEPTION_IF_NULL
(
graph
);
auto
m_tensor
=
GetParamDefaultInputTensor
(
anf
);
auto
m_tensor
=
GetParamDefaultInputTensor
(
anf
);
auto
valid_inputs
=
graph
->
MutableValidInputs
();
auto
valid_inputs
=
graph
->
MutableValidInputs
();
MS_EXCEPTION_IF_NULL
(
valid_inputs
);
MS_EXCEPTION_IF_NULL
(
valid_inputs
);
...
@@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
...
@@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
AnfNodePtr
SessionBasic
::
CreateNewParameterFromCNode
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
)
{
AnfNodePtr
SessionBasic
::
CreateNewParameterFromCNode
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"Create a new parameter from cnode["
<<
anf
->
DebugString
()
<<
"]"
;
MS_LOG
(
INFO
)
<<
"Create a new parameter from cnode["
<<
anf
->
DebugString
()
<<
"]"
;
auto
parameters
=
CreateParameterFromTuple
(
anf
,
valid_input
,
graph
);
auto
parameters
=
CreateParameterFromTuple
(
anf
,
valid_input
,
graph
);
if
(
parameters
.
empty
())
{
if
(
parameters
.
empty
())
{
...
@@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
...
@@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
ValueNodePtr
SessionBasic
::
CreateValueNodeKernelGraph
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
ValueNodePtr
SessionBasic
::
CreateValueNodeKernelGraph
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
value_node
=
anf
->
cast
<
ValueNodePtr
>
();
auto
value_node
=
anf
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
anf
);
auto
sub_func_graph
=
AnfAlgo
::
GetValueNodeFuncGraph
(
anf
);
...
@@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
...
@@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
ParameterPtr
SessionBasic
::
CreateNewParameter
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
ParameterPtr
SessionBasic
::
CreateNewParameter
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
anf
);
MS_EXCEPTION_IF_NULL
(
graph
);
if
(
!
anf
->
isa
<
Parameter
>
())
{
if
(
!
anf
->
isa
<
Parameter
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"anf["
<<
anf
->
DebugString
()
<<
"] is not a parameter"
;
MS_LOG
(
EXCEPTION
)
<<
"anf["
<<
anf
->
DebugString
()
<<
"] is not a parameter"
;
}
}
...
@@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
...
@@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
KernelGraphPtr
SessionBasic
::
ConstructKernelGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
KernelGraphPtr
SessionBasic
::
ConstructKernelGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
other_graph_cnode
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
other_graph_cnode
;
auto
graph
=
NewKernelGraph
();
auto
graph
=
NewKernelGraph
();
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph
->
graph_id
();
size_t
from_other_graph_depend_num
=
0
;
size_t
from_other_graph_depend_num
=
0
;
for
(
const
auto
&
node
:
lst
)
{
for
(
const
auto
&
node
:
lst
)
{
...
@@ -585,6 +595,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
...
@@ -585,6 +595,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL
(
all_out_graph
);
MS_EXCEPTION_IF_NULL
(
all_out_graph
);
auto
node_list
=
TopoSort
(
func_graph
->
get_return
());
auto
node_list
=
TopoSort
(
func_graph
->
get_return
());
auto
graph
=
NewKernelGraph
();
auto
graph
=
NewKernelGraph
();
MS_EXCEPTION_IF_NULL
(
graph
);
front_backend_graph_map_
[
func_graph
]
=
graph
;
front_backend_graph_map_
[
func_graph
]
=
graph
;
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph
->
graph_id
();
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph
->
graph_id
();
...
@@ -724,8 +735,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
...
@@ -724,8 +735,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
}
}
auto
anf_outputs
=
kernel_graph
->
outputs
();
auto
anf_outputs
=
kernel_graph
->
outputs
();
for
(
auto
&
item
:
anf_outputs
)
{
for
(
auto
&
item
:
anf_outputs
)
{
MS_LOG
(
INFO
)
<<
"update output["
<<
item
->
DebugString
()
<<
"]"
;
MS_EXCEPTION_IF_NULL
(
item
);
MS_EXCEPTION_IF_NULL
(
item
);
MS_LOG
(
INFO
)
<<
"update output["
<<
item
->
DebugString
()
<<
"]"
;
if
(
AnfAlgo
::
IsTupleOutput
(
item
)
&&
AnfAlgo
::
IsRealKernel
(
item
))
{
if
(
AnfAlgo
::
IsTupleOutput
(
item
)
&&
AnfAlgo
::
IsRealKernel
(
item
))
{
outputs
->
emplace_back
(
CreatTupleForOutput
(
item
,
*
kernel_graph
,
input_tensors
));
outputs
->
emplace_back
(
CreatTupleForOutput
(
item
,
*
kernel_graph
,
input_tensors
));
continue
;
continue
;
...
@@ -761,6 +772,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
...
@@ -761,6 +772,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
auto
node
=
cnode
->
input
(
kSummaryGetItem
);
auto
node
=
cnode
->
input
(
kSummaryGetItem
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
node
,
0
,
true
);
auto
item_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
node
,
0
,
true
);
MS_EXCEPTION_IF_NULL
(
item_with_index
.
first
);
if
(
!
AnfAlgo
::
IsRealKernel
(
item_with_index
.
first
))
{
if
(
!
AnfAlgo
::
IsRealKernel
(
item_with_index
.
first
))
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected node:"
<<
item_with_index
.
first
->
DebugString
();
MS_LOG
(
EXCEPTION
)
<<
"Unexpected node:"
<<
item_with_index
.
first
->
DebugString
();
}
}
...
@@ -812,6 +824,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
...
@@ -812,6 +824,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
output_args
;
std
::
vector
<
AnfNodePtr
>
output_args
;
for
(
const
auto
&
output
:
outputs
)
{
for
(
const
auto
&
output
:
outputs
)
{
MS_EXCEPTION_IF_NULL
(
output
);
MS_LOG
(
INFO
)
<<
"output:"
<<
output
->
DebugString
();
MS_LOG
(
INFO
)
<<
"output:"
<<
output
->
DebugString
();
}
}
auto
FindEqu
=
[
graph
,
outputs
](
const
AnfNodePtr
&
out
)
->
AnfNodePtr
{
auto
FindEqu
=
[
graph
,
outputs
](
const
AnfNodePtr
&
out
)
->
AnfNodePtr
{
...
@@ -883,7 +896,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
...
@@ -883,7 +896,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
}
}
auto
parameter
=
ConstructRunOpParameter
(
graph
,
input_tensors
[
i
],
tensors_mask
[
i
]);
auto
parameter
=
ConstructRunOpParameter
(
graph
,
input_tensors
[
i
],
tensors_mask
[
i
]);
inputs
.
push_back
(
parameter
);
inputs
.
push_back
(
parameter
);
graph
->
MutableInputs
()
->
push_back
(
parameter
);
auto
mutable_inputs
=
graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
mutable_inputs
);
mutable_inputs
->
push_back
(
parameter
);
}
}
// set execution order
// set execution order
auto
cnode
=
graph
->
NewCNode
(
inputs
);
auto
cnode
=
graph
->
NewCNode
(
inputs
);
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
817d1ae2
...
@@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
...
@@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class
SessionBasic
{
class
SessionBasic
{
public:
public:
SessionBasic
()
:
device_id_
(
0
)
{
SessionBasic
()
:
context_
(
nullptr
),
summary_callback_
(
nullptr
),
device_id_
(
0
)
{}
graphs_
=
{};
run_op_graphs_
=
{};
summary_callback_
=
nullptr
;
}
virtual
void
Init
(
uint32_t
device_id
)
{
device_id_
=
device_id
;
}
virtual
void
Init
(
uint32_t
device_id
)
{
device_id_
=
device_id
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录