Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
32405f9a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32405f9a
编写于
7月 03, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2817 add internal output
Merge pull request !2817 from kisnwang/optimize-sub-graph-memcpy
上级
0f399f0a
e9067b4a
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
269 addition
and
97 deletion
+269
-97
mindspore/ccsrc/device/kernel_runtime.cc
mindspore/ccsrc/device/kernel_runtime.cc
+2
-1
mindspore/ccsrc/device/kernel_runtime.h
mindspore/ccsrc/device/kernel_runtime.h
+1
-1
mindspore/ccsrc/ir/anf.cc
mindspore/ccsrc/ir/anf.cc
+40
-1
mindspore/ccsrc/ir/anf.h
mindspore/ccsrc/ir/anf.h
+1
-1
mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc
.../ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc
+11
-1
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+0
-9
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+0
-2
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+73
-0
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+10
-0
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+127
-42
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+4
-0
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+0
-39
未找到文件。
mindspore/ccsrc/device/kernel_runtime.cc
浏览文件 @
32405f9a
...
...
@@ -340,7 +340,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
}
void
KernelRuntime
::
AssignStaticMemoryOutput
(
const
session
::
KernelGraph
*
graph
)
{
void
KernelRuntime
::
AssignStaticMemoryOutput
(
session
::
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
auto
nodes
=
AnfAlgo
::
GetAllOutput
(
graph
->
output
(),
{
prim
::
kPrimTupleGetItem
});
std
::
vector
<
session
::
KernelWithIndex
>
non_communication_op
;
...
...
@@ -351,6 +351,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph)
if
(
!
item_with_index
.
first
->
isa
<
CNode
>
()
||
!
AnfAlgo
::
IsRealKernel
(
item_with_index
.
first
))
{
continue
;
}
graph
->
AddFinalOutputKernel
(
item_with_index
.
first
);
if
(
AnfAlgo
::
IsCommunicationOp
(
item_with_index
.
first
))
{
AssignCommunicationNodeMem
(
kStaticMem
,
item_with_index
.
first
);
}
else
{
...
...
mindspore/ccsrc/device/kernel_runtime.h
浏览文件 @
32405f9a
...
...
@@ -95,7 +95,7 @@ class KernelRuntime {
#endif
private:
void
AssignStaticMemoryOutput
(
const
session
::
KernelGraph
*
graph
);
void
AssignStaticMemoryOutput
(
session
::
KernelGraph
*
graph
);
void
GenLaunchArgs
(
const
session
::
KernelGraph
&
graph
,
const
AnfNodePtr
&
kernel
,
AddressPtrList
*
kernel_inputs
,
AddressPtrList
*
kernel_workspaces
,
AddressPtrList
*
kernel_outputs
);
bool
LaunchKernelMod
(
const
session
::
KernelGraph
&
graph
);
...
...
mindspore/ccsrc/ir/anf.cc
浏览文件 @
32405f9a
...
...
@@ -25,7 +25,7 @@
#include "ir/func_graph.h"
#include "ir/primitive_base.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
namespace
mindspore
{
...
...
@@ -179,4 +179,43 @@ std::string get_id(const AnfNodePtr &node) {
void
reset_id
()
{
node_ids
.
clear
();
}
}
// namespace id_generator
std
::
string
GetCNodeTarget
(
const
AnfNodePtr
&
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
default_target
=
context_ptr
->
device_target
();
if
(
!
node
->
isa
<
CNode
>
())
{
return
default_target
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
attr_input
=
cnode
->
input
(
0
);
if
(
attr_input
==
nullptr
)
{
return
default_target
;
}
auto
value_node
=
attr_input
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
default_target
;
}
auto
value
=
value_node
->
value
();
if
(
value
==
nullptr
)
{
return
default_target
;
}
if
(
!
value
->
isa
<
Primitive
>
())
{
return
default_target
;
}
auto
primitive
=
value
->
cast
<
PrimitivePtr
>
();
auto
att_target
=
primitive
->
GetAttr
(
"primitive_target"
);
if
(
att_target
!=
nullptr
)
{
if
(
!
att_target
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Only support string CPU|GPU|Ascend for primitive_target"
;
}
auto
target
=
GetValue
<
std
::
string
>
(
att_target
);
if
(
kTargetSet
.
find
(
target
)
==
kTargetSet
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Only support string CPU|GPU|Ascend for primitive_target"
;
}
return
target
;
}
return
default_target
;
}
}
// namespace mindspore
mindspore/ccsrc/ir/anf.h
浏览文件 @
32405f9a
...
...
@@ -448,7 +448,7 @@ void reset_id();
}
// namespace id_generator
using
TaggedNodeMap
=
std
::
unordered_map
<
AnfNodePtr
,
size_t
>
;
using
TaggedGraph
=
std
::
pair
<
FuncGraphPtr
,
TaggedNodeMap
>
;
std
::
string
GetCNodeTarget
(
const
AnfNodePtr
&
node
);
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_ANF_H_
mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc
浏览文件 @
32405f9a
...
...
@@ -46,6 +46,11 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
if
(
node
==
nullptr
||
!
AnfAlgo
::
IsRealKernel
(
node
))
{
return
nullptr
;
}
AnfNodePtr
front_node
;
auto
kernel_graph
=
func_graph
->
cast
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
();
if
(
kernel_graph
!=
nullptr
&&
kernel_graph
->
IsInternalOutput
(
node
))
{
front_node
=
kernel_graph
->
GetFrontNodeByInternalOutput
(
node
);
}
AnfAlgo
::
SetNodeAttr
(
kAttrVisited
,
MakeValue
(
true
),
node
);
MS_LOG
(
DEBUG
)
<<
"====process op: "
<<
node
->
DebugString
();
AnfNodePtr
new_node
=
InsertTransOpForInput
(
func_graph
,
node
,
kernel_select_
);
...
...
@@ -56,7 +61,12 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
return
new_node
;
}
}
return
InsertTransOpForOutput
(
func_graph
,
new_node
,
kernel_select_
);
auto
final_node
=
InsertTransOpForOutput
(
func_graph
,
new_node
,
kernel_select_
);
if
(
kernel_graph
!=
nullptr
&&
front_node
!=
nullptr
)
{
auto
old_node
=
kernel_graph
->
GetInternalOutputByFrontNode
(
front_node
);
kernel_graph
->
ReplaceInternalOutput
(
old_node
,
final_node
);
}
return
final_node
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
32405f9a
...
...
@@ -987,15 +987,6 @@ void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
}
}
KernelGraphPtr
AscendSession
::
GetGraph
(
mindspore
::
GraphId
graph_id
)
{
auto
it
=
graphs_
.
find
(
graph_id
);
if
(
it
==
graphs_
.
end
())
{
MS_LOG
(
WARNING
)
<<
"Can't find graph "
<<
graph_id
;
return
nullptr
;
}
return
it
->
second
;
}
void
AscendSession
::
InsertSwitchToGraph
(
GraphId
condition_graph_id
,
GraphId
true_graph_id
)
{
MS_LOG
(
INFO
)
<<
"Start!"
;
MS_LOG
(
INFO
)
<<
"Condition graph id["
<<
condition_graph_id
<<
"],true graph id["
<<
true_graph_id
<<
"]"
;
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
32405f9a
...
...
@@ -128,8 +128,6 @@ class AscendSession : public SessionBasic {
void
InsertDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
attch_node
);
// insert depend to graph, used to attch control nodes to graph
void
InsertControlDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
first_node
,
const
AnfNodePtr
&
second_node
);
// Get graph by graph id ,if not exist return null ptr
KernelGraphPtr
GetGraph
(
GraphId
graph_id
);
// set child graph parameter if front arg is a anf
void
SetChildGraphParameter
(
const
AnfNodePtr
&
front_anf
,
GraphId
to_graph_id
,
size_t
input_idx
);
// set child graph parameter if front arg is a tensor
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
32405f9a
...
...
@@ -329,6 +329,9 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
FrontBackendlMapUpdate
(
cnode
,
new_cnode
);
}
AnfAlgo
::
SetGraphId
(
graph_id_
,
cnode
.
get
());
if
(
IsInternalOutput
(
cnode
))
{
ReplaceInternalOutput
(
cnode
,
new_cnode
);
}
return
new_cnode
;
}
...
...
@@ -872,6 +875,76 @@ void KernelGraph::PrintGraphExecuteOrder() const {
}
}
void
KernelGraph
::
AddInternalOutput
(
const
AnfNodePtr
&
front_node
,
const
AnfNodePtr
&
node
)
{
if
(
front_node
==
nullptr
||
node
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
"Front node or node is nullptr"
;
return
;
}
MS_LOG
(
INFO
)
<<
"Add internal node "
<<
node
->
DebugString
()
<<
" with front node "
<<
front_node
->
DebugString
();
front_to_internal_outputs_map_
[
front_node
]
=
node
;
internal_outputs_to_front_map_
[
node
]
=
front_node
;
}
void
KernelGraph
::
ReplaceInternalOutput
(
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
new_node
)
{
if
(
new_node
==
nullptr
||
node
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
"New node or node is nullptr"
;
return
;
}
if
(
node
==
new_node
)
{
MS_LOG
(
INFO
)
<<
"New node and node is the same"
;
return
;
}
auto
iter
=
internal_outputs_to_front_map_
.
find
(
node
);
if
(
iter
==
internal_outputs_to_front_map_
.
end
())
{
MS_LOG
(
INFO
)
<<
"Node is not internal output"
;
return
;
}
MS_LOG
(
INFO
)
<<
"Replace internal node "
<<
node
->
DebugString
()
<<
" To "
<<
new_node
->
DebugString
();
internal_outputs_to_front_map_
[
new_node
]
=
iter
->
second
;
front_to_internal_outputs_map_
[
iter
->
second
]
=
new_node
;
internal_outputs_to_front_map_
.
erase
(
iter
);
}
AnfNodePtr
KernelGraph
::
GetInternalOutputByFrontNode
(
const
AnfNodePtr
&
front_node
)
const
{
auto
iter
=
front_to_internal_outputs_map_
.
find
(
front_node
);
if
(
iter
!=
front_to_internal_outputs_map_
.
end
())
{
return
iter
->
second
;
}
return
nullptr
;
}
bool
KernelGraph
::
IsInternalOutput
(
const
AnfNodePtr
&
node
)
const
{
if
(
internal_outputs_to_front_map_
.
find
(
node
)
!=
internal_outputs_to_front_map_
.
end
())
{
return
true
;
}
return
false
;
}
AnfNodePtr
KernelGraph
::
GetFrontNodeByInternalOutput
(
const
AnfNodePtr
&
node
)
const
{
auto
iter
=
internal_outputs_to_front_map_
.
find
(
node
);
if
(
iter
!=
internal_outputs_to_front_map_
.
end
())
{
return
iter
->
second
;
}
return
nullptr
;
}
void
KernelGraph
::
AddFinalOutputKernel
(
const
AnfNodePtr
&
node
)
{
if
(
node
==
nullptr
)
{
return
;
}
(
void
)
final_output_kernels_
.
insert
(
node
);
}
bool
KernelGraph
::
IsFinalOutputKernel
(
const
AnfNodePtr
&
node
)
const
{
if
(
node
==
nullptr
)
{
return
false
;
}
if
(
final_output_kernels_
.
find
(
node
)
!=
final_output_kernels_
.
end
())
{
return
true
;
}
return
false
;
}
std
::
string
KernelGraph
::
ToString
()
const
{
return
std
::
string
(
"kernel_graph_"
).
append
(
std
::
to_string
(
graph_id_
));
}
KernelGraph
::~
KernelGraph
()
{
device
::
KernelRuntimeManager
::
Instance
().
ClearGraphResource
(
graph_id_
);
}
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
32405f9a
...
...
@@ -144,6 +144,13 @@ class KernelGraph : public FuncGraph {
void
PrintGraphExecuteOrder
()
const
;
const
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
&
summary_nodes
()
const
{
return
summary_nodes_
;
}
void
set_summary_nodes
(
const
std
::
map
<
std
::
string
,
std
::
pair
<
AnfNodePtr
,
int
>>
&
nodes
)
{
summary_nodes_
=
nodes
;
}
void
AddInternalOutput
(
const
AnfNodePtr
&
front_node
,
const
AnfNodePtr
&
node
);
void
ReplaceInternalOutput
(
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
new_node
);
AnfNodePtr
GetInternalOutputByFrontNode
(
const
AnfNodePtr
&
front_node
)
const
;
bool
IsInternalOutput
(
const
AnfNodePtr
&
node
)
const
;
AnfNodePtr
GetFrontNodeByInternalOutput
(
const
AnfNodePtr
&
node
)
const
;
void
AddFinalOutputKernel
(
const
AnfNodePtr
&
node
);
bool
IsFinalOutputKernel
(
const
AnfNodePtr
&
node
)
const
;
private:
// remove value node form graph
...
...
@@ -202,6 +209,9 @@ class KernelGraph : public FuncGraph {
CNodePtr
start_label_
;
CNodePtr
end_goto_
;
bool
null_output_
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
front_to_internal_outputs_map_
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
internal_outputs_to_front_map_
;
std
::
set
<
AnfNodePtr
>
final_output_kernels_
;
};
}
// namespace session
using
KernelGraphPtr
=
std
::
shared_ptr
<
session
::
KernelGraph
>
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
32405f9a
...
...
@@ -95,6 +95,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
TypeId
type_id
=
kNumberTypeFloat32
;
type_id
=
AnfAlgo
::
GetOutputInferDataType
(
node
,
output_index
);
std
::
vector
<
int
>
temp_shape
;
if
(
graph
.
IsInternalOutput
(
node
))
{
temp_shape
.
emplace_back
(
1
);
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_id
,
temp_shape
);
tensor
->
set_device_address
(
address
);
tensor
->
set_dirty
(
false
);
return
tensor
;
}
(
void
)
std
::
copy
(
shape
.
begin
(),
shape
.
end
(),
std
::
back_inserter
(
temp_shape
));
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_id
,
temp_shape
);
// if in paynative mode,data only copyed to host when user want to print data
...
...
@@ -172,48 +179,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
return
new_value_node
;
}
std
::
vector
<
AnfNodePtr
>
CreateParameterFromTuple
(
const
AnfNodePtr
&
node
,
bool
valid_input
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
parameters
;
std
::
vector
<
AnfNodePtr
>
pre_graph_out
=
{
node
};
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
if
(
!
AnfAlgo
::
IsRealKernel
(
node
))
{
pre_graph_out
=
AnfAlgo
::
GetAllOutput
(
node
,
{
prim
::
kPrimTupleGetItem
});
}
auto
valid_inputs
=
graph
->
MutableValidInputs
();
MS_EXCEPTION_IF_NULL
(
valid_inputs
);
auto
graph_inputs
=
graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
auto
create_parameter
=
[
&
](
const
AbstractBasePtr
&
abstract
)
->
void
{
auto
parameter
=
graph
->
NewParameter
();
MS_EXCEPTION_IF_NULL
(
parameter
);
parameter
->
set_abstract
(
abstract
);
auto
new_parameter
=
graph
->
NewParameter
(
parameter
);
parameters
.
push_back
(
new_parameter
);
valid_inputs
->
push_back
(
valid_input
);
graph_inputs
->
push_back
(
new_parameter
);
};
for
(
const
auto
&
out_node
:
pre_graph_out
)
{
MS_EXCEPTION_IF_NULL
(
out_node
);
auto
abstract
=
out_node
->
abstract
();
MS_EXCEPTION_IF_NULL
(
abstract
);
// create multiple parameters if is a tuple output real kernel
if
(
abstract
->
isa
<
abstract
::
AbstractTuple
>
()
&&
!
AnfAlgo
::
CheckPrimitiveType
(
out_node
,
prim
::
kPrimTupleGetItem
))
{
auto
tuple_abstract
=
abstract
->
cast
<
abstract
::
AbstractTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
tuple_abstract
);
MS_LOG
(
INFO
)
<<
"Tuple_size ["
<<
tuple_abstract
->
size
()
<<
"]"
;
for
(
size_t
output_idx
=
0
;
output_idx
<
tuple_abstract
->
size
();
output_idx
++
)
{
create_parameter
((
*
tuple_abstract
)[
output_idx
]);
}
continue
;
}
// create single parameter if is a abstract real kernel
create_parameter
(
out_node
->
abstract
());
}
return
parameters
;
}
size_t
LoadCtrlInputTensor
(
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
,
std
::
vector
<
tensor
::
TensorPtr
>
*
inputs
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"Load kInputCtrlTensors"
;
...
...
@@ -323,6 +288,103 @@ bool ExistSummaryNode(const KernelGraph *graph) {
}
// namespace
GraphId
SessionBasic
::
graph_sum_
=
0
;
KernelGraphPtr
SessionBasic
::
GetGraph
(
mindspore
::
GraphId
graph_id
)
{
auto
it
=
graphs_
.
find
(
graph_id
);
if
(
it
==
graphs_
.
end
())
{
MS_LOG
(
WARNING
)
<<
"Can't find graph "
<<
graph_id
;
return
nullptr
;
}
return
it
->
second
;
}
void
SessionBasic
::
InitInternalOutputParameter
(
const
AnfNodePtr
&
out_node
,
const
AnfNodePtr
&
parameter
)
{
auto
graph_id
=
GetGraphIdByNode
(
out_node
);
if
(
graph_id
==
kInvalidGraphId
)
{
return
;
}
auto
node_graph
=
GetGraph
(
graph_id
);
if
(
node_graph
==
nullptr
)
{
return
;
}
MS_LOG
(
INFO
)
<<
"Init parameter with pre graph output node: "
<<
out_node
->
DebugString
();
auto
ref_node
=
node_graph
->
GetInternalOutputByFrontNode
(
out_node
);
if
(
ref_node
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
"No corresponding internal output for output node"
;
return
;
}
auto
real_kernel
=
AnfAlgo
::
VisitKernel
(
ref_node
,
0
);
auto
ref_real_node
=
real_kernel
.
first
;
auto
ref_real_node_index
=
real_kernel
.
second
;
if
(
ref_real_node
->
isa
<
CNode
>
()
&&
node_graph
->
IsInternalOutput
(
ref_real_node
)
&&
node_graph
->
IsFinalOutputKernel
(
ref_real_node
))
{
auto
kernel_info
=
ref_real_node
->
kernel_info
();
if
(
kernel_info
==
nullptr
||
kernel_info
->
select_kernel_build_info
()
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
"No kernel info"
;
return
;
}
auto
address
=
AnfAlgo
::
GetMutableOutputAddr
(
ref_real_node
,
ref_real_node_index
);
if
(
address
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
"No kernel address"
;
return
;
}
auto
format
=
AnfAlgo
::
GetOutputFormat
(
ref_real_node
,
ref_real_node_index
);
auto
type
=
AnfAlgo
::
GetOutputDeviceDataType
(
ref_real_node
,
ref_real_node_index
);
parameter
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
auto
d_kernel_info
=
parameter
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsDeviceType
({
type
});
builder
.
SetOutputsFormat
({
format
});
d_kernel_info
->
set_select_kernel_build_info
(
builder
.
Build
());
AnfAlgo
::
SetOutputAddr
(
address
,
0
,
parameter
.
get
());
}
}
std
::
vector
<
AnfNodePtr
>
SessionBasic
::
CreateParameterFromTuple
(
const
AnfNodePtr
&
node
,
bool
valid_input
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
parameters
;
std
::
vector
<
AnfNodePtr
>
pre_graph_out
=
{
node
};
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
if
(
!
AnfAlgo
::
IsRealKernel
(
node
))
{
pre_graph_out
=
AnfAlgo
::
GetAllOutput
(
node
,
{
prim
::
kPrimTupleGetItem
});
}
auto
valid_inputs
=
graph
->
MutableValidInputs
();
MS_EXCEPTION_IF_NULL
(
valid_inputs
);
auto
graph_inputs
=
graph
->
MutableInputs
();
MS_EXCEPTION_IF_NULL
(
graph_inputs
);
auto
create_parameter
=
[
&
](
const
AbstractBasePtr
&
abstract
)
->
void
{
auto
parameter
=
graph
->
NewParameter
();
MS_EXCEPTION_IF_NULL
(
parameter
);
parameter
->
set_abstract
(
abstract
);
auto
new_parameter
=
graph
->
NewParameter
(
parameter
);
parameters
.
push_back
(
new_parameter
);
valid_inputs
->
push_back
(
valid_input
);
graph_inputs
->
push_back
(
new_parameter
);
};
for
(
const
auto
&
out_node
:
pre_graph_out
)
{
MS_EXCEPTION_IF_NULL
(
out_node
);
auto
abstract
=
out_node
->
abstract
();
MS_EXCEPTION_IF_NULL
(
abstract
);
// create multiple parameters if is a tuple output real kernel
if
(
abstract
->
isa
<
abstract
::
AbstractTuple
>
()
&&
!
AnfAlgo
::
CheckPrimitiveType
(
out_node
,
prim
::
kPrimTupleGetItem
))
{
auto
tuple_abstract
=
abstract
->
cast
<
abstract
::
AbstractTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
tuple_abstract
);
MS_LOG
(
INFO
)
<<
"Tuple_size ["
<<
tuple_abstract
->
size
()
<<
"]"
;
for
(
size_t
output_idx
=
0
;
output_idx
<
tuple_abstract
->
size
();
output_idx
++
)
{
create_parameter
((
*
tuple_abstract
)[
output_idx
]);
}
continue
;
}
// create single parameter if is a abstract real kernel
create_parameter
(
out_node
->
abstract
());
InitInternalOutputParameter
(
out_node
,
parameters
[
parameters
.
size
()
-
1
]);
}
return
parameters
;
}
ParameterPtr
SessionBasic
::
CreateNewParameterFromParameter
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
anf
);
...
...
@@ -877,6 +939,29 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
auto
FindEqu
=
[
graph
,
outputs
](
const
AnfNodePtr
&
out
)
->
AnfNodePtr
{
auto
backend_anf
=
graph
->
GetBackendAnfByFrontAnf
(
out
);
if
(
backend_anf
!=
nullptr
)
{
auto
front_real_kernel
=
AnfAlgo
::
VisitKernel
(
out
,
0
);
auto
backend_real_kernel
=
AnfAlgo
::
VisitKernel
(
backend_anf
,
0
);
MS_EXCEPTION_IF_NULL
(
out
);
auto
out_func_graph
=
out
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
out_func_graph
);
auto
out_func_graph_manager
=
out_func_graph
->
manager
();
if
(
out_func_graph_manager
==
nullptr
)
{
return
backend_anf
;
}
auto
node_users
=
out_func_graph_manager
->
node_users
();
auto
users
=
node_users
[
out
];
bool
internal_output
=
true
;
std
::
string
kernel_target
=
GetCNodeTarget
(
front_real_kernel
.
first
);
for
(
auto
user
:
users
)
{
if
(
!
AnfAlgo
::
IsRealKernel
(
user
.
first
)
||
kernel_target
!=
GetCNodeTarget
(
user
.
first
))
{
internal_output
=
false
;
break
;
}
}
if
(
internal_output
)
{
MS_LOG
(
INFO
)
<<
"Internal output1: "
<<
out
->
DebugString
()
<<
"To "
<<
backend_real_kernel
.
first
->
DebugString
();
graph
->
AddInternalOutput
(
out
,
backend_real_kernel
.
first
);
}
return
backend_anf
;
}
MS_LOG
(
EXCEPTION
)
<<
"Can't find the node in the equiv map!"
;
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
32405f9a
...
...
@@ -110,6 +110,8 @@ class SessionBasic {
#endif
protected:
// Get graph by graph id ,if not exist return null ptr
KernelGraphPtr
GetGraph
(
GraphId
graph_id
);
virtual
void
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
)
const
;
void
UpdateOutputs
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
VectorRef
*
const
outputs
,
...
...
@@ -127,11 +129,13 @@ class SessionBasic {
BaseRef
TransformBaseRefListToTuple
(
const
BaseRef
&
base_ref
);
// create a new kernel graph and update the graph sum
KernelGraphPtr
NewKernelGraph
();
std
::
vector
<
AnfNodePtr
>
CreateParameterFromTuple
(
const
AnfNodePtr
&
node
,
bool
valid_input
,
KernelGraph
*
graph
);
virtual
ParameterPtr
CreateNewParameterFromParameter
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
);
ValueNodePtr
CreateValueNodeKernelGraph
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
);
ParameterPtr
CreateNewParameter
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
);
AnfNodePtr
CreateNewParameterFromCNode
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
);
void
AddParameterToGraphInputs
(
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
KernelGraph
*
graph
);
void
InitInternalOutputParameter
(
const
AnfNodePtr
&
out_node
,
const
AnfNodePtr
&
parameter
);
std
::
unordered_map
<
GraphId
,
std
::
shared_ptr
<
KernelGraph
>>
graphs_
;
std
::
unordered_map
<
GraphInfo
,
std
::
shared_ptr
<
KernelGraph
>>
run_op_graphs_
;
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
32405f9a
...
...
@@ -52,45 +52,6 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
}
namespace
{
std
::
string
GetCNodeTarget
(
const
AnfNodePtr
&
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
default_target
=
context_ptr
->
device_target
();
if
(
!
node
->
isa
<
CNode
>
())
{
return
default_target
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
attr_input
=
cnode
->
input
(
kAnfPrimitiveIndex
);
if
(
attr_input
==
nullptr
)
{
return
default_target
;
}
auto
value_node
=
attr_input
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
default_target
;
}
auto
value
=
value_node
->
value
();
if
(
value
==
nullptr
)
{
return
default_target
;
}
if
(
!
value
->
isa
<
Primitive
>
())
{
return
default_target
;
}
auto
primitive
=
value
->
cast
<
PrimitivePtr
>
();
auto
att_target
=
primitive
->
GetAttr
(
"primitive_target"
);
if
(
att_target
!=
nullptr
)
{
if
(
!
att_target
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Only support string CPU|GPU|Ascend for primitive_target"
;
}
auto
target
=
GetValue
<
std
::
string
>
(
att_target
);
if
(
kTargetSet
.
find
(
target
)
==
kTargetSet
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Only support string CPU|GPU|Ascend for primitive_target"
;
}
return
target
;
}
return
default_target
;
}
bool
ContainMultiTarget
(
const
std
::
vector
<
AnfNodePtr
>
&
nodes
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录