Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
af5019b9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
af5019b9
编写于
5月 22, 2020
作者:
Z
zhoufeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
link child graphs
Signed-off-by:
N
zhoufeng
<
zhoufeng54@huawei.com
>
上级
d9c74e0a
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
585 addition
and
78 deletion
+585
-78
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
+10
-6
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
+82
-15
mindspore/ccsrc/device/ascend/ascend_label_assign.h
mindspore/ccsrc/device/ascend/ascend_label_assign.h
+7
-1
mindspore/ccsrc/kernel/rts/label_switch.cc
mindspore/ccsrc/kernel/rts/label_switch.cc
+25
-4
mindspore/ccsrc/kernel/rts/label_switch.h
mindspore/ccsrc/kernel/rts/label_switch.h
+8
-0
mindspore/ccsrc/kernel/rts/rt_kernel_info.cc
mindspore/ccsrc/kernel/rts/rt_kernel_info.cc
+7
-1
mindspore/ccsrc/pipeline/action.cc
mindspore/ccsrc/pipeline/action.cc
+7
-5
mindspore/ccsrc/session/CMakeLists.txt
mindspore/ccsrc/session/CMakeLists.txt
+4
-3
mindspore/ccsrc/session/ascend_control_parser.cc
mindspore/ccsrc/session/ascend_control_parser.cc
+319
-0
mindspore/ccsrc/session/ascend_control_parser.h
mindspore/ccsrc/session/ascend_control_parser.h
+73
-0
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+25
-37
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+4
-2
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+9
-0
tests/ut/cpp/CMakeLists.txt
tests/ut/cpp/CMakeLists.txt
+1
-0
tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc
tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc
+4
-4
未找到文件。
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
浏览文件 @
af5019b9
...
...
@@ -29,6 +29,7 @@
#include "hccl/hcom.h"
#include "common/trans.h"
#include "runtime/context.h"
#include "device/ascend/ascend_label_assign.h"
#include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_memory_pool.h"
#include "framework/ge_runtime/model_runner.h"
...
...
@@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
return
true
;
}
AscendStreamAssign
&
assign_instance
=
AscendStreamAssign
::
GetInstance
();
AscendStreamAssign
&
stream_assign_instance
=
AscendStreamAssign
::
GetInstance
();
AscendLabelAssign
&
label_assign_instance
=
AscendLabelAssign
::
GetInstance
();
// the streams' flag not HEAD_STREAM
std
::
vector
<
uint32_t
>
wait_active_stream_list
;
assign_instance
.
GetWaitStreams
(
&
wait_active_stream_list
);
auto
force_copy_stream_list
=
assign_instance
.
hcom_streams
();
stream_
assign_instance
.
GetWaitStreams
(
&
wait_active_stream_list
);
auto
force_copy_stream_list
=
stream_
assign_instance
.
hcom_streams
();
MS_LOG
(
INFO
)
<<
"call DavinciModel total stream num:"
<<
assign_instance
.
GetTotalStreamNum
()
<<
", total event num:"
<<
assign_instance
.
total_event_num
()
MS_LOG
(
INFO
)
<<
"call DavinciModel total stream num:"
<<
stream_assign_instance
.
GetTotalStreamNum
()
<<
", total event num:"
<<
stream_assign_instance
.
total_event_num
()
<<
", total label num:"
<<
label_assign_instance
.
GetLabelNum
(
NOT_NULL
(
graph
))
<<
", wait_active_stream_list size:"
<<
wait_active_stream_list
.
size
()
<<
", force_copy_stream_list size:"
<<
force_copy_stream_list
.
size
();
std
::
vector
<
std
::
shared_ptr
<
ge
::
model_runner
::
OpInfo
>>
empty_list
;
std
::
shared_ptr
<
ge
::
model_runner
::
DavinciModel
>
model
=
std
::
make_shared
<
ge
::
model_runner
::
DavinciModel
>
(
task_info_list
,
empty_list
,
empty_list
,
empty_list
,
empty_list
,
wait_active_stream_list
,
force_copy_stream_list
,
0
,
0
,
0
,
0
,
0
,
0
,
assign_instance
.
GetTotalStreamNum
(),
1
,
assign_instance
.
total_event_num
(),
0
);
0
,
0
,
0
,
0
,
0
,
stream_assign_instance
.
GetTotalStreamNum
(),
label_assign_instance
.
GetLabelNum
(
NOT_NULL
(
graph
)),
stream_assign_instance
.
total_event_num
(),
0
);
auto
ret
=
graph_model_map_
.
insert
(
std
::
make_pair
(
graph
->
graph_id
(),
model
));
if
(
!
ret
.
second
)
{
...
...
mindspore/ccsrc/device/ascend/ascend_label_assign.cc
浏览文件 @
af5019b9
...
...
@@ -15,6 +15,8 @@
*/
#include <vector>
#include <string>
#include <set>
#include "device/ascend/ascend_label_assign.h"
#include "session/anf_runtime_algorithm.h"
...
...
@@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
uint32_t
goto_label_id
=
GetValue
<
uint32_t
>
(
value
);
AnfAlgo
::
SetNodeAttr
(
kAttrLabelIndex
,
MakeValue
<
uint32_t
>
(
goto_label_id
),
node
.
get
());
MS_LOG
(
INFO
)
<<
"Node "
<<
node
->
DebugString
()
<<
" goto label id "
<<
goto_label_id
;
node
->
set_inputs
({
node
->
input
(
0
)});
}
static
void
UpdateLabelSwitch
(
NotNull
<
CNodePtr
>
node
)
{
...
...
@@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
MS_LOG
(
INFO
)
<<
"Switch "
<<
node
->
DebugString
()
<<
" case "
<<
i
-
kLabelSwitchLabelId
<<
": id "
<<
goto_label_id
;
}
AnfAlgo
::
SetNodeAttr
(
kAttrLabelSwitchList
,
MakeValue
<
std
::
vector
<
uint32_t
>>
(
label_list
),
node
.
get
());
node
->
set_inputs
({
node
->
input
(
0
),
node
->
input
(
1
)});
}
void
AscendLabelAssign
::
AssignLabel
(
NotNull
<
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&>
graph
)
{
auto
cnode_list
=
graph
->
execution_order
();
// 1 assign label id to label_set
uint32_t
cur_label_id
=
0
;
for
(
auto
&
node
:
cnode_list
)
{
if
(
AnfAlgo
::
GetCNodeName
(
node
)
==
kLabelSetOpName
)
{
AnfAlgo
::
SetNodeAttr
(
kAttrLabelIndex
,
MakeValue
<
uint32_t
>
(
cur_label_id
),
node
);
MS_LOG
(
INFO
)
<<
"Node "
<<
node
->
DebugString
()
<<
" assign label id "
<<
cur_label_id
;
++
cur_label_id
;
static
void
AssignLabelForLabelSet
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
,
NotNull
<
uint32_t
*>
label_id
,
NotNull
<
std
::
set
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
*>
memo
)
{
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
}
MS_LOG
(
INFO
)
<<
"Assign label for "
<<
graph
->
ToString
();
auto
nodes
=
TopoSort
(
graph
->
get_return
());
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
std
::
string
node_name
=
AnfAlgo
::
GetCNodeName
(
node
);
if
(
node_name
==
kLabelSetOpName
&&
!
AnfAlgo
::
HasNodeAttr
(
kAttrLabelIndex
,
cnode
))
{
AnfAlgo
::
SetNodeAttr
(
kAttrLabelIndex
,
MakeValue
<
uint32_t
>
(
*
label_id
),
node
);
MS_LOG
(
INFO
)
<<
"Node "
<<
node
->
DebugString
()
<<
" assign label id "
<<
*
label_id
;
++
(
*
label_id
);
}
}
// 2 update label_switch / label_goto
for
(
auto
&
node
:
cnode_list
)
{
if
(
AnfAlgo
::
GetCNodeName
(
node
)
==
kLabelGotoOpName
)
{
UpdateLabelGoto
(
NOT_NULL
(
node
));
for
(
auto
&
cg
:
graph
->
child_graph_order
())
{
AssignLabelForLabelSet
(
NOT_NULL
(
cg
),
label_id
,
memo
);
}
}
static
void
AssignLabelForGotoSwitch
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
,
NotNull
<
std
::
set
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
*>
memo
)
{
if
(
memo
->
find
(
graph
.
get
())
!=
memo
->
end
())
{
return
;
}
MS_LOG
(
INFO
)
<<
"Process label goto/switch for "
<<
graph
->
ToString
();
auto
nodes
=
TopoSort
(
graph
->
get_return
());
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
}
if
(
AnfAlgo
::
GetCNodeName
(
node
)
==
kLabelSwitchOpName
)
{
UpdateLabelSwitch
(
NOT_NULL
(
node
));
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
std
::
string
node_name
=
AnfAlgo
::
GetCNodeName
(
node
);
if
(
node_name
==
kLabelGotoOpName
)
{
UpdateLabelGoto
(
NOT_NULL
(
cnode
));
cnode
->
set_abstract
(
nullptr
);
}
if
(
node_name
==
kLabelSwitchOpName
)
{
UpdateLabelSwitch
(
NOT_NULL
(
cnode
));
}
}
for
(
auto
&
cg
:
graph
->
child_graph_order
())
{
AssignLabelForGotoSwitch
(
NOT_NULL
(
cg
),
memo
);
}
}
void
AscendLabelAssign
::
AssignLabel
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
)
{
MS_LOG
(
INFO
)
<<
"Assign label start."
;
std
::
set
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
memo
;
uint32_t
label_id
=
0
;
AssignLabelForLabelSet
(
graph
,
NOT_NULL
(
&
label_id
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
label_num_mutex_
);
label_num_
[
graph
.
get
().
get
()]
=
label_id
;
}
AssignLabelForGotoSwitch
(
graph
,
NOT_NULL
(
&
memo
));
MS_LOG
(
INFO
)
<<
"Assign label end."
;
}
uint32_t
AscendLabelAssign
::
GetLabelNum
(
NotNull
<
const
session
::
KernelGraph
*>
graph
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
label_num_mutex_
);
auto
iter
=
label_num_
.
find
(
graph
.
get
());
if
(
iter
==
label_num_
.
end
())
{
MS_LOG
(
WARNING
)
<<
"Graph "
<<
graph
->
ToString
()
<<
" has not assigned label."
;
return
1
;
}
return
iter
->
second
;
}
uint32_t
AscendLabelAssign
::
GetLabelNum
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
)
{
return
GetLabelNum
(
NOT_NULL
(
graph
.
get
().
get
()));
}
}
// namespace ascend
...
...
mindspore/ccsrc/device/ascend/ascend_label_assign.h
浏览文件 @
af5019b9
...
...
@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
#include <memory>
#include <map>
#include "session/kernel_graph.h"
#include "utils/contract.h"
...
...
@@ -35,11 +36,16 @@ class AscendLabelAssign {
AscendLabelAssign
(
const
AscendLabelAssign
&
)
=
delete
;
AscendLabelAssign
&
operator
=
(
const
AscendLabelAssign
&
)
=
delete
;
void
AssignLabel
(
NotNull
<
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&>
graph
);
void
AssignLabel
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
);
uint32_t
GetLabelNum
(
NotNull
<
const
session
::
KernelGraph
*>
graph
);
uint32_t
GetLabelNum
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
);
private:
AscendLabelAssign
()
=
default
;
~
AscendLabelAssign
()
=
default
;
std
::
map
<
const
session
::
KernelGraph
*
,
uint32_t
>
label_num_
;
std
::
mutex
label_num_mutex_
;
};
}
// namespace ascend
}
// namespace device
...
...
mindspore/ccsrc/kernel/rts/label_switch.cc
浏览文件 @
af5019b9
...
...
@@ -17,6 +17,7 @@
#include "kernel/rts/label_switch.h"
#include <asm-generic/param.h>
#include <memory>
#include <string>
#include "runtime/stream.h"
#include "framework/ge_runtime/task_info.h"
#include "session/anf_runtime_algorithm.h"
...
...
@@ -66,13 +67,33 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr
MS_LOG
(
INFO
)
<<
"LabelSwitchKernel GenTask label size:"
<<
label_size_
<<
", stream id:"
<<
stream_id
;
std
::
vector
<
TaskInfoPtr
>
task_info_list
;
cond_
=
inputs
[
0
]
->
addr
;
// std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr =
// std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, &label_list_, cond_);
// need updata ge task info define
std
::
shared_ptr
<
LabelSwitchTaskInfo
>
task_info_ptr
=
std
::
make_shared
<
LabelSwitchTaskInfo
>
(
stream_id
,
label_size_
);
// todo: need update ge task info define
auto
task_info_ptr
=
std
::
make_shared
<
LabelSwitchTaskInfo
>
(
stream_id
,
0
);
// auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_);
MS_EXCEPTION_IF_NULL
(
task_info_ptr
);
task_info_list
.
emplace_back
(
task_info_ptr
);
return
task_info_list
;
}
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
LabelSwitchDesc
::
GetKernelInfo
()
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
label_switch_build_info
{};
vector
<
string
>
input_format
{
kOpFormat_DEFAULT
,
kOpFormat_DEFAULT
};
vector
<
TypeId
>
input_type
{
kNumberTypeUInt32
,
kNumberTypeBool
};
if
(
input_format
.
size
()
!=
input_type
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid param num, input_format size "
<<
input_format
.
size
()
<<
" input_type size "
<<
input_type
.
size
();
}
for
(
size_t
i
=
0
;
i
<
input_format
.
size
();
++
i
)
{
auto
builder
=
KernelBuildInfo
::
KernelBuildInfoBuilder
();
builder
.
SetInputsFormat
({
input_format
[
i
]});
builder
.
SetInputsDeviceType
({
input_type
[
i
]});
builder
.
SetProcessor
(
AICORE
);
builder
.
SetKernelType
(
RT_KERNEL
);
builder
.
SetFusionType
(
OPAQUE
);
label_switch_build_info
.
emplace_back
(
builder
.
Build
());
}
return
label_switch_build_info
;
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/rts/label_switch.h
浏览文件 @
af5019b9
...
...
@@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel {
void
*
cond_
;
};
class
LabelSwitchDesc
:
public
RtKerDesc
{
public:
LabelSwitchDesc
()
=
default
;
~
LabelSwitchDesc
()
override
=
default
;
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
GetKernelInfo
()
override
;
};
MS_REG_RTKERNEL_DESC
(
labelswitch
,
LabelSwitchDesc
);
MS_REG_RTKERNEL
(
labelswitch
,
LabelSwitchKernel
);
}
// namespace kernel
}
// namespace mindspore
...
...
mindspore/ccsrc/kernel/rts/rt_kernel_info.cc
浏览文件 @
af5019b9
...
...
@@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() {
return
_this
;
}
static
bool
IsDefaultKernelInfo
(
const
std
::
string
&
name
)
{
static
const
std
::
set
<
std
::
string
>
white_list
=
{
kStreamSwitchOpName
,
kStreamActiveOpName
,
kLabelSetOpName
,
kLabelGotoOpName
};
return
white_list
.
find
(
name
)
!=
white_list
.
end
();
}
void
GetRtKelInfo
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
...
...
@@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
}
// if can't find kernel info in kernel info database, use the default kernel info
auto
node_name
=
AnfAlgo
::
GetCNodeName
(
kernel_node
);
if
(
node_name
==
"StreamSwitch"
||
node_name
==
"StreamActive"
)
{
if
(
IsDefaultKernelInfo
(
node_name
)
)
{
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set input infos
auto
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
...
...
mindspore/ccsrc/pipeline/action.cc
浏览文件 @
af5019b9
...
...
@@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) {
}
auto
graph_id
=
res
->
results
()[
kOutput
].
cast
<
GraphId
>
();
auto
bc_ptr
=
res
->
results
()[
kBackend
].
cast
<
std
::
shared_ptr
<
compile
::
MsBackend
>>
();
std
::
shared_ptr
<
compile
::
Backend
>
bc_ptr
=
res
->
results
()[
kBackend
].
cast
<
std
::
shared_ptr
<
compile
::
Backend
>>
();
std
::
shared_ptr
<
compile
::
MsBackend
>
msbc_ptr
=
std
::
dynamic_pointer_cast
<
compile
::
MsBackend
>
(
bc_ptr
);
MS_EXCEPTION_IF_NULL
(
msbc_ptr
);
compile
::
VmEvalFuncPtr
run
=
std
::
make_shared
<
compile
::
VmEvalFunc
>
([
&
bc_ptr
,
graph_id
](
const
VectorRef
&
args
)
->
BaseRef
{
MS_LOG
(
INFO
)
<<
"Execute args size"
<<
args
.
size
();
auto
outs
=
bc_ptr
->
RunGraph
(
graph_id
,
args
);
MS_LOG
(
DEBUG
)
<<
"out size"
<<
outs
.
size
();
std
::
make_shared
<
compile
::
VmEvalFunc
>
([
ms
bc_ptr
,
graph_id
](
const
VectorRef
&
args
)
->
BaseRef
{
MS_LOG
(
INFO
)
<<
"Execute args size
"
<<
args
.
size
();
auto
outs
=
ms
bc_ptr
->
RunGraph
(
graph_id
,
args
);
MS_LOG
(
DEBUG
)
<<
"out size
"
<<
outs
.
size
();
return
outs
[
0
];
});
res
->
results
()[
kOutput
]
=
run
;
...
...
mindspore/ccsrc/session/CMakeLists.txt
浏览文件 @
af5019b9
...
...
@@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
)
if
(
ENABLE_GPU
)
file
(
GLOB_RECURSE _GPU_SRC_LIST
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
file
(
GLOB_RECURSE _GPU_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"gpu_session.cc"
)
list
(
APPEND _SESSION_SRC_LIST
${
_GPU_SRC_LIST
}
)
endif
()
if
(
ENABLE_CPU
)
file
(
GLOB_RECURSE _CPU_SRC_LIST
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
file
(
GLOB_RECURSE _CPU_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"cpu_session.cc"
)
list
(
APPEND _SESSION_SRC_LIST
${
_CPU_SRC_LIST
}
)
endif
()
if
(
ENABLE_D
)
file
(
GLOB_RECURSE _D_SRC_LIST
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
file
(
GLOB_RECURSE _D_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"ascend_session.cc"
"ascend_control_parser.cc"
)
list
(
APPEND _SESSION_SRC_LIST
${
_D_SRC_LIST
}
)
endif
()
...
...
mindspore/ccsrc/session/ascend_control_parser.cc
0 → 100644
浏览文件 @
af5019b9
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include <memory>
#include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h"
namespace
mindspore
{
namespace
session
{
static
VectorRef
GetCallArgs
(
std
::
vector
<
AnfNodePtr
>::
iterator
iter_begin
,
std
::
vector
<
AnfNodePtr
>::
iterator
iter_end
)
{
VectorRef
call_args
;
for
(
auto
iter
=
iter_begin
;
iter
!=
iter_end
;
++
iter
)
{
if
(
utils
::
isa
<
ValueNode
>
(
*
iter
))
{
call_args
.
push_back
(
GetValueNode
(
*
iter
));
}
else
{
call_args
.
push_back
(
*
iter
);
}
}
return
call_args
;
}
void
AscendControlParser
::
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
)
{
std
::
set
<
KernelGraphPtr
>
memo
;
ProcessKernelGraph
(
kg
,
nullptr
,
nullptr
,
{},
NOT_NULL
(
&
memo
));
}
NotNull
<
CNodePtr
>
AscendControlParser
::
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"Start process KernelGraph "
<<
kg
->
ToString
();
// 0. recursive condition
if
(
memo
->
find
(
kg
)
!=
memo
->
end
())
{
MS_LOG
(
INFO
)
<<
"KernelGraph has beed processed: "
<<
kg
->
ToString
();
return
NOT_NULL
(
kg
->
get_start_label
());
}
// 2. args replace placeholder
LinkParentGraph
(
kg
,
last_node
,
last_label
,
args
);
// 3. topological sort
std
::
vector
<
CNodePtr
>
nodes
=
GetCNodes
(
TopoSort
(
kg
->
get_return
()));
if
(
nodes
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"KernelGraph "
<<
kg
->
ToString
()
<<
" has no cnodes!"
;
}
// 4. insert first_label
auto
start_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSetOpName
))});
for
(
auto
node
:
nodes
)
{
if
(
!
IsPrimitiveCNode
(
node
,
prim
::
kPrimPartial
))
{
InsertControlDependToGraph
(
kg
,
NOT_NULL
(
start_label
),
NOT_NULL
(
node
));
break
;
}
}
kg
->
set_start_label
(
start_label
);
// 5. traverse
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
auto
&
cnode
=
nodes
[
i
];
if
(
cnode
->
size
()
<
kCNodePrim
+
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node is empty"
;
}
AnfNodePtr
fn
=
cnode
->
input
(
kCNodePrim
);
if
(
!
IsPrimitive
(
fn
,
prim
::
kPrimCall
)
||
cnode
->
size
()
<
kCNodeCallArg
+
1
)
{
MS_LOG
(
DEBUG
)
<<
"continue node "
<<
cnode
->
DebugString
();
continue
;
}
AnfNodePtr
arg
=
cnode
->
input
(
kCNodeCallArg
);
if
(
IsValueNode
<
KernelGraph
>
(
arg
))
{
RecurseCall
(
kg
,
NOT_NULL
(
cnode
),
(
i
+
1
<
nodes
.
size
()
?
nodes
[
i
+
1
]
:
nullptr
),
memo
);
}
else
if
(
!
arg
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Unknown type call node "
<<
cnode
->
DebugString
();
}
else
if
(
IsPrimitiveCNode
(
arg
->
cast
<
CNodePtr
>
(),
prim
::
kPrimSwitch
))
{
auto
arg_cnode
=
arg
->
cast
<
CNodePtr
>
();
cnode
->
set_inputs
(
cnode
->
inputs
());
RecurseSwitch
(
kg
,
NOT_NULL
(
cnode
),
memo
);
}
else
if
(
IsPrimitiveCNode
(
arg
->
cast
<
CNodePtr
>
(),
prim
::
kPrimSwitchLayer
))
{
auto
arg_cnode
=
arg
->
cast
<
CNodePtr
>
();
cnode
->
set_inputs
(
cnode
->
inputs
());
RecurseSwitchLayer
(
kg
,
NOT_NULL
(
cnode
),
memo
);
}
}
MS_LOG
(
INFO
)
<<
"End KernelGraph process: "
<<
kg
->
ToString
();
return
NOT_NULL
(
start_label
);
}
std
::
vector
<
CNodePtr
>
AscendControlParser
::
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
in
)
{
std
::
vector
<
CNodePtr
>
out
;
for
(
auto
&
node
:
in
)
{
if
(
node
->
isa
<
CNode
>
())
{
out
.
push_back
(
node
->
cast
<
CNodePtr
>
());
}
}
return
out
;
}
void
AscendControlParser
::
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
)
{
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
"depend"
))};
auto
return_node
=
kg
->
get_return
();
MS_EXCEPTION_IF_NULL
(
return_node
);
inputs
.
push_back
(
return_node
->
input
(
1
));
inputs
.
push_back
(
attch_node
.
get
());
auto
depend_node
=
kg
->
NewCNode
(
inputs
);
return_node
->
set_input
(
1
,
depend_node
);
}
void
AscendControlParser
::
InsertControlDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
first_node
,
NotNull
<
AnfNodePtr
>
second_node
)
{
MS_LOG
(
INFO
)
<<
"Insert control depend at the end of graph, the first node is "
<<
first_node
->
DebugString
()
<<
", the second node is "
<<
second_node
->
DebugString
();
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimControlDepend
->
name
())),
first_node
,
second_node
};
auto
control_depend
=
kg
->
NewCNode
(
inputs
);
InsertDependToGraph
(
kg
,
NOT_NULL
(
control_depend
));
}
void
AscendControlParser
::
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
)
{
if
(
from_graph_call_node
!=
nullptr
)
{
SetSubGraphInput
(
kg
,
NOT_NULL
(
from_graph_call_node
),
args
);
}
auto
origin_return
=
kg
->
get_return
();
std
::
vector
<
AnfNodePtr
>
origin_return_inputs
=
origin_return
->
inputs
();
// if entry graph, replace return with make_tuple
if
(
from_graph_call_node
==
nullptr
||
last_label
==
nullptr
)
{
MS_LOG
(
INFO
)
<<
kg
->
ToString
()
<<
" is entry graph."
;
std
::
vector
<
AnfNodePtr
>
make_tuple_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
prim
::
kPrimMakeTuple
)};
make_tuple_inputs
.
insert
(
make_tuple_inputs
.
end
(),
origin_return_inputs
.
begin
()
+
1
,
origin_return_inputs
.
end
());
auto
make_tuple
=
kg
->
NewCNode
(
make_tuple_inputs
);
origin_return
->
set_inputs
({
origin_return
->
input
(
kCNodePrim
),
make_tuple
});
}
else
{
// else replace return with label_goto
auto
label_goto
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelGotoOpName
)),
last_label
});
InsertDependToGraph
(
kg
,
NOT_NULL
(
label_goto
));
}
}
void
AscendControlParser
::
RecurseCall
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"process call func "
<<
cur_node
->
DebugString
();
// 1 get kernel graph
auto
origin_inputs
=
cur_node
->
inputs
();
std
::
vector
<
AnfNodePtr
>
new_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelGotoOpName
))};
auto
call_args
=
GetCallArgs
(
origin_inputs
.
begin
()
+
1
,
origin_inputs
.
end
());
if
(
!
IsValueNode
<
KernelGraph
>
(
origin_inputs
[
kCNodeCallArg
]))
{
MS_LOG
(
WARNING
)
<<
"Node "
<<
cur_node
->
DebugString
(
10
)
<<
" index "
<<
kCNodeCallArg
<<
" is not a ValueNode"
;
return
;
}
// 2 return label
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSetOpName
))});
// 3 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
if
(
next_node
!=
nullptr
&&
next_node
!=
kg
->
get_return
())
{
InsertControlDependToGraph
(
kg
,
NOT_NULL
(
back_label
),
NOT_NULL
(
next_node
));
}
auto
call_kg
=
GetValueNode
<
KernelGraphPtr
>
(
origin_inputs
[
kCNodeCallArg
]);
// 4 modify call op to goto op
cur_node
->
set_input
(
kCNodePrim
,
new_inputs
[
kCNodePrim
]);
// 5 recurse sub graph
CNodePtr
sub_label
=
ProcessKernelGraph
(
NOT_NULL
(
call_kg
),
cur_node
,
back_label
,
call_args
,
memo
);
new_inputs
.
push_back
(
sub_label
);
new_inputs
.
insert
(
new_inputs
.
end
(),
origin_inputs
.
begin
(),
origin_inputs
.
end
());
cur_node
->
set_inputs
(
new_inputs
);
cur_node
->
set_abstract
(
nullptr
);
MS_LOG
(
INFO
)
<<
"success process call func "
<<
cur_node
->
DebugString
();
}
void
AscendControlParser
::
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"process switch node "
<<
cur_node
->
DebugString
();
if
(
cur_node
->
size
()
<
kCNodeSwitchLength
)
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node must more than "
<<
kCNodeSwitchLength
;
}
// 1 return label
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
prim
::
kPrimLabelSet
)});
// 2 recurse sub graph
auto
origin_switch_inputs
=
cur_node
->
inputs
();
std
::
vector
<
AnfNodePtr
>
new_switch_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSwitchOpName
)),
origin_switch_inputs
[
kCNodeSwitchCond
]};
for
(
size_t
i
=
kCNodeSwitchCond
+
1
;
i
<
kCNodeSwitchLength
;
++
i
)
{
// 2.1 branch kernel graph and args
CNodePtr
partial
;
KernelGraphPtr
branch_fg
;
VectorRef
call_args
;
std
::
tie
(
partial
,
branch_fg
,
call_args
)
=
ParsePartial
(
NOT_NULL
(
origin_switch_inputs
[
i
]));
// 2.2 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
// 2.3 recurse sub graph
CNodePtr
branch_label
=
ProcessKernelGraph
(
NOT_NULL
(
branch_fg
),
cur_node
,
back_label
,
call_args
,
memo
);
new_switch_inputs
.
push_back
(
branch_label
);
}
std
::
swap
(
new_switch_inputs
[
kCNodeSwitchTrue
],
new_switch_inputs
[
kCNodeSwitchFalse
]);
new_switch_inputs
.
insert
(
new_switch_inputs
.
end
(),
origin_switch_inputs
.
begin
(),
origin_switch_inputs
.
end
());
cur_node
->
set_inputs
(
new_switch_inputs
);
cur_node
->
set_abstract
(
nullptr
);
MS_LOG
(
INFO
)
<<
"success process switch func "
<<
cur_node
->
DebugString
();
}
void
AscendControlParser
::
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
MS_LOG
(
INFO
)
<<
"process switch node "
<<
cur_node
->
DebugString
();
if
(
cur_node
->
size
()
<
kCNodeSwitchLayerLength
)
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node must more than "
<<
kCNodeSwitchLayerLength
;
}
auto
branch_tuple
=
cur_node
->
input
(
kCNodeSwitchLayerBranch
);
MS_EXCEPTION_IF_NULL
(
branch_tuple
);
if
(
!
branch_tuple
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node must more than "
<<
kCNodeSwitchLayerLength
;
}
auto
branch_partial
=
utils
::
cast
<
CNodePtr
>
(
branch_tuple
)
->
inputs
();
// 1 return label
auto
back_label
=
kg
->
NewCNode
({
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
kLabelSwitchOpName
))});
// 2 recurse sub graph
auto
origin_switch_inputs
=
cur_node
->
inputs
();
std
::
vector
<
AnfNodePtr
>
new_switch_inputs
=
{
std
::
make_shared
<
ValueNode
>
(
prim
::
kPrimLabelSwitch
),
origin_switch_inputs
[
kCNodeSwitchCond
]};
for
(
size_t
i
=
0
;
i
<
branch_partial
.
size
();
++
i
)
{
// 2.1 branch kernel graph and args
CNodePtr
partial
;
KernelGraphPtr
branch_fg
;
VectorRef
call_args
;
std
::
tie
(
partial
,
branch_fg
,
call_args
)
=
ParsePartial
(
NOT_NULL
(
origin_switch_inputs
[
i
]));
// 2.2 add depend relationship
InsertControlDependToGraph
(
kg
,
cur_node
,
NOT_NULL
(
back_label
));
// 2.3 recurse sub graph
CNodePtr
branch_label
=
ProcessKernelGraph
(
NOT_NULL
(
branch_fg
),
cur_node
,
back_label
,
call_args
,
memo
);
new_switch_inputs
.
push_back
(
branch_label
);
}
new_switch_inputs
.
insert
(
new_switch_inputs
.
end
(),
branch_partial
.
begin
(),
branch_partial
.
end
());
cur_node
->
set_inputs
(
new_switch_inputs
);
cur_node
->
set_abstract
(
nullptr
);
MS_LOG
(
INFO
)
<<
"success process switch layer "
<<
cur_node
->
DebugString
();
}
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
,
VectorRef
>
AscendControlParser
::
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
)
{
if
(
!
node
.
get
()
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Switch branches must be partial, node: "
<<
node
->
DebugString
();
}
// 2.1 branch kernel graph and args
auto
partial_cnode
=
utils
::
cast
<
CNodePtr
>
(
node
.
get
());
if
(
partial_cnode
->
size
()
<
kCNodePartialLength
)
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of partial node must more than "
<<
kCNodePartialLength
;
}
auto
partial_inputs
=
partial_cnode
->
inputs
();
auto
branch_kg
=
GetValueNode
<
KernelGraphPtr
>
(
partial_inputs
[
kCNodePartialFunc
]);
auto
call_args
=
GetCallArgs
(
partial_inputs
.
begin
()
+
kCNodePartialFunc
+
1
,
partial_inputs
.
end
());
return
{
partial_cnode
,
branch_kg
,
call_args
};
}
void
AscendControlParser
::
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
)
{
if
(
AnfAlgo
::
OutputAddrExist
(
from
,
0
)
&&
AnfAlgo
::
OutputAddrExist
(
to
,
0
)
&&
AnfAlgo
::
GetOutputAddr
(
from
,
0
)
==
AnfAlgo
::
GetOutputAddr
(
to
,
0
))
{
return
;
}
if
(
from
.
get
()
==
to
.
get
())
{
return
;
}
MS_LOG
(
INFO
)
<<
"Insert assign to graph "
<<
kg
->
ToString
()
<<
" from "
<<
from
->
DebugString
()
<<
" to "
<<
to
->
DebugString
();
// config inputs of assign node
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
"Assign"
)),
to
,
from
};
// generate a new cnode
auto
assign_node
=
kg
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
assign_node
);
assign_node
->
set_abstract
(
to
->
abstract
());
// append the assign at the end of from graph
InsertDependToGraph
(
kg
,
NOT_NULL
(
assign_node
));
}
size_t
AscendControlParser
::
SetChildGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
node
,
size_t
input_index
)
{
auto
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
node
);
if
(
output_num
>
1
&&
!
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimTupleGetItem
))
{
return
input_index
+
output_num
;
}
auto
&
graph_inputs
=
kg
->
inputs
();
if
(
input_index
>=
graph_inputs
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"input_index "
<<
input_index
<<
" out of range size "
<<
graph_inputs
.
size
();
}
auto
backend_parameter
=
graph_inputs
[
input_index
];
if
(
node
.
get
()
->
isa
<
Parameter
>
())
{
MS_EXCEPTION_IF_NULL
(
backend_parameter
);
MS_LOG
(
INFO
)
<<
"Reuse node ["
<<
node
->
DebugString
()
<<
"], old node["
<<
backend_parameter
->
DebugString
()
<<
"] will be replaced."
;
kg
->
ReplaceNode
(
backend_parameter
,
node
);
return
input_index
;
}
InsertAssignToGraph
(
kg
,
node
,
NOT_NULL
(
backend_parameter
));
return
input_index
+
1
;
}
void
AscendControlParser
::
SetSubGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
from_graph_call_node
,
const
VectorRef
&
args
)
{}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/ascend_control_parser.h
0 → 100644
浏览文件 @
af5019b9
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#include <set>
#include <vector>
#include <tuple>
#include "session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
namespace
mindspore
{
namespace
session
{
class
AscendControlParser
{
public:
static
void
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
);
static
void
InsertDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
attch_node
);
static
void
InsertControlDependToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
first_node
,
NotNull
<
AnfNodePtr
>
second_node
);
private:
static
NotNull
<
CNodePtr
>
ProcessKernelGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
last_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseCall
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
const
CNodePtr
&
next_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseSwitch
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
void
RecurseSwitchLayer
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
cur_node
,
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
);
static
std
::
vector
<
CNodePtr
>
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
in
);
static
void
LinkParentGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
const
CNodePtr
&
from_graph_call_node
,
const
CNodePtr
&
last_label
,
const
VectorRef
&
args
);
static
void
SetSubGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
CNodePtr
>
from_graph_call_node
,
const
VectorRef
&
args
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
,
VectorRef
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
void
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
size_t
SetChildGraphInput
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
node
,
size_t
input_index
);
static
constexpr
size_t
kCNodePrim
=
0
;
static
constexpr
size_t
kCNodeCallArg
=
1
;
static
constexpr
size_t
kCNodeSwitchCond
=
1
;
static
constexpr
size_t
kCNodeSwitchTrue
=
2
;
static
constexpr
size_t
kCNodeSwitchFalse
=
3
;
static
constexpr
size_t
kCNodeSwitchLength
=
4
;
static
constexpr
size_t
kCNodePartialLength
=
2
;
static
constexpr
size_t
kCNodePartialFunc
=
1
;
static
constexpr
size_t
kCNodeSwitchLayerCond
=
1
;
static
constexpr
size_t
kCNodeSwitchLayerBranch
=
2
;
static
constexpr
size_t
kCNodeSwitchLayerLength
=
3
;
};
}
// namespace session
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
af5019b9
...
...
@@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
std
::
vector
<
CNodePtr
>
GetCNodes
(
const
std
::
vector
<
AnfNodePtr
>
&
anf_nodes
)
{
std
::
vector
<
CNodePtr
>
cnodes
=
{};
size_t
i
=
0
;
for
(
const
auto
anf
:
anf_nodes
)
{
for
(
auto
anf
:
anf_nodes
)
{
MS_LOG
(
INFO
)
<<
"apply_list["
<<
i
++
<<
"] = "
<<
anf
->
DebugString
();
MS_EXCEPTION_IF_NULL
(
anf
);
if
(
anf
->
isa
<
CNode
>
())
{
cnodes
.
push_back
(
anf
->
cast
<
CNodePtr
>
());
}
}
return
std
::
move
(
cnodes
)
;
return
cnodes
;
}
std
::
vector
<
std
::
vector
<
CNodePtr
>>
GetChildList
(
const
KernelGraph
&
cur_graph
,
const
std
::
vector
<
CNodePtr
>
&
cnodes
)
{
...
...
@@ -189,7 +189,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
ret
.
push_back
(
std
::
vector
<
CNodePtr
>
(
cnodes
.
begin
()
+
after_call_index
,
cnodes
.
end
()));
}
}
return
std
::
move
(
ret
)
;
return
ret
;
}
void
UpdateRealInput
(
KernelGraph
*
graph
)
{
...
...
@@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) {
auto
ret
=
std
::
vector
<
AnfNodePtr
>
(
partial_cnode
->
inputs
().
begin
()
+
2
,
partial_cnode
->
inputs
().
end
());
partial_cnode
->
set_inputs
(
std
::
vector
<
AnfNodePtr
>
(
partial_cnode
->
inputs
().
begin
(),
partial_cnode
->
inputs
().
begin
()
+
2
));
return
std
::
move
(
ret
)
;
return
ret
;
};
bind_call_partial_with_parameter
(
child_graphs
[
0
]
->
inputs
(),
get_partial_args
(
2
),
child_graphs
[
0
].
get
());
bind_call_partial_with_parameter
(
child_graphs
[
1
]
->
inputs
(),
get_partial_args
(
3
),
child_graphs
[
1
].
get
());
...
...
@@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
// split switch
SplitGraph
(
graph
);
// insert goto labels and label_sets
LinkChildGraphs
(
graph
.
get
(
));
LinkChildGraphs
(
NOT_NULL
(
graph
));
// resource initialize
InitRuntimeResource
();
// ir fusion
IRFusion
(
graph
);
// kernel select
SelectKernelGraphKernel
(
*
graph
);
// convert model of predict module
ConvertPredictModel
(
graph
);
// hardware optimize
HardwareOptimizeGraphs
(
graph
);
// assign label
AssignLabel
(
NOT_NULL
(
graph
));
if
(
!
graph
->
executable
())
{
return
graph
->
graph_id
();
}
for
(
auto
iter
:
graphs_
)
{
if
(
iter
.
second
==
graph
)
{
MS_LOG
(
INFO
)
<<
"Entry graph "
<<
graph
->
ToString
()
<<
" graph id "
<<
graph
->
graph_id
();
final_graph_id_
=
graph
->
graph_id
();
}
MS_LOG
(
INFO
)
<<
"CompileChildGraph "
<<
iter
.
second
->
ToString
();
CompileChildGraph
(
iter
.
second
);
}
// adjust kernel
AdjustKernel
(
graph
);
// root graph valiate,include genearte execute order and so on
RootGraphExecutorValidate
(
graph
.
get
());
// assign stream
AssignStream
(
graph
);
// assign label
AssignLabel
(
NOT_NULL
(
graph
));
// build kernel if node is cnode
BuildKernel
(
graph
);
// alloc mem
MemoryAlloc
(
graph
.
get
());
// task generate
...
...
@@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG
(
INFO
)
<<
"Finish!"
;
}
void
AscendSession
::
AssignLabel
(
NotNull
<
const
KernelGraphPtr
&
>
kernel_graph
)
const
{
void
AscendSession
::
AssignLabel
(
NotNull
<
KernelGraphPtr
>
kernel_graph
)
const
{
MS_LOG
(
INFO
)
<<
"Start!"
;
device
::
ascend
::
AscendLabelAssign
::
GetInstance
().
AssignLabel
(
kernel_graph
);
MS_LOG
(
INFO
)
<<
"Finish!"
;
...
...
@@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived
}
void
AscendSession
::
InsertDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
attch_node
)
{
MS_LOG
(
INFO
)
<<
"Insert depend at the end of graph, the attach node is "
<<
attch_node
->
DebugString
();
auto
graph
=
GetGraph
(
graph_id
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
"depend"
))};
auto
return_node
=
graph
->
get_return
();
MS_EXCEPTION_IF_NULL
(
return_node
);
inputs
.
push_back
(
return_node
->
input
(
1
));
inputs
.
push_back
(
attch_node
);
auto
depend_node
=
graph
->
NewCNode
(
inputs
);
return_node
->
set_input
(
1
,
depend_node
);
AscendControlParser
::
InsertDependToGraph
(
NOT_NULL
(
GetGraph
(
graph_id
)),
NOT_NULL
(
attch_node
));
}
void
AscendSession
::
InsertControlDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
first_node
,
const
AnfNodePtr
&
second_node
)
{
MS_LOG
(
INFO
)
<<
"Insert control depend at the end of graph, the first node is "
<<
first_node
->
DebugString
()
<<
", the second node is "
<<
second_node
->
DebugString
();
auto
graph
=
GetGraph
(
graph_id
);
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
std
::
make_shared
<
Primitive
>
(
"ControlDepend"
))};
inputs
.
push_back
(
first_node
);
inputs
.
push_back
(
second_node
);
auto
control_depend
=
graph
->
NewCNode
(
inputs
);
InsertDependToGraph
(
graph_id
,
control_depend
);
AscendControlParser
::
InsertControlDependToGraph
(
NOT_NULL
(
GetGraph
(
graph_id
)),
NOT_NULL
(
first_node
),
NOT_NULL
(
second_node
));
}
size_t
AscendSession
::
ExecOrderOfChildGraph
(
GraphId
final_graph
,
GraphId
child_graph
)
{
...
...
@@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
SplitGraph
(
child_graph
);
}
}
void
AscendSession
::
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
)
{
AscendControlParser
::
LinkGraph
(
graph
);
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
af5019b9
...
...
@@ -28,6 +28,7 @@
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
#include "session/session_factory.h"
#include "session/ascend_control_parser.h"
namespace
mindspore
{
namespace
session
{
...
...
@@ -74,7 +75,7 @@ class AscendSession : public SessionBasic {
void
AdjustKernel
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
RunOpAdjustKernel
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
AssignStream
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
AssignLabel
(
NotNull
<
const
KernelGraphPtr
&
>
kernel_graph
)
const
;
void
AssignLabel
(
NotNull
<
KernelGraphPtr
>
kernel_graph
)
const
;
void
BuildKernel
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
MemoryAlloc
(
KernelGraph
*
kernel_graph
)
const
;
void
RunOpMemoryAlloc
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
;
...
...
@@ -96,7 +97,8 @@ class AscendSession : public SessionBasic {
void
SetFinalGraphOutput
(
const
VectorRef
&
vec_output
);
void
SplitGraph
(
const
KernelGraphPtr
&
graph
);
void
LinkChildGraphs
(
KernelGraph
*
graph
)
{}
void
LinkChildGraphs
(
NotNull
<
KernelGraphPtr
>
graph
);
void
IRFusion
(
const
KernelGraphPtr
&
graph
)
{}
void
SelectKernelGraphKernel
(
const
KernelGraph
&
graph
)
{}
void
ConvertPredictModel
(
const
KernelGraphPtr
graph
)
{}
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
af5019b9
...
...
@@ -28,6 +28,7 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/graph_utils.h"
#include "utils/contract.h"
#include "device/kernel_info.h"
namespace
mindspore
{
...
...
@@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph {
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
child_graph_order
()
const
{
return
child_graph_order_
;
}
// checkout whether current graph is leaf graph
bool
IsLeafGraph
()
const
;
// set input_tensors pointer of control parameter
void
set_input_ctrl_tensors
(
const
std
::
shared_ptr
<
std
::
vector
<
tensor
::
TensorPtr
>>
&
input_tensors_ptr
)
{
input_ctrl_tensors_
=
input_tensors_ptr
;
...
...
@@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph {
// used to dump ir
std
::
string
ToString
()
const
override
;
void
set_start_label
(
const
CNodePtr
&
start_label
)
{
start_label_
=
start_label
;
}
CNodePtr
get_start_label
()
{
return
start_label_
;
}
private:
// remove value node form graph
bool
RemoveValueNodeFromGraph
(
const
ValueNodePtr
&
value_node
);
...
...
@@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph {
std
::
map
<
AnfNodePtr
,
std
::
shared_ptr
<
KernelGraph
>>
node_to_child_graphs_
;
// child graph execute order in root graph
std
::
vector
<
std
::
shared_ptr
<
KernelGraph
>>
child_graph_order_
;
// input_tensors of control parameter
std
::
shared_ptr
<
std
::
vector
<
tensor
::
TensorPtr
>>
input_ctrl_tensors_
;
// parameter graph
std
::
shared_ptr
<
KernelGraph
>
parent_graph_
;
// record real parameters,inputs_ is the formal parameters
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
real_inputs_
;
CNodePtr
start_label_
;
};
}
// namespace session
using
KernelGraphPtr
=
std
::
shared_ptr
<
session
::
KernelGraph
>
;
...
...
tests/ut/cpp/CMakeLists.txt
浏览文件 @
af5019b9
...
...
@@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/transform/*.cc"
"../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc"
"../../../mindspore/ccsrc/session/ascend_session.cc"
"../../../mindspore/ccsrc/session/ascend_control_parser.cc"
"../../../mindspore/ccsrc/session/kernel_graph.cc"
"../../../mindspore/ccsrc/session/session_basic.cc"
"../../../mindspore/ccsrc/session/session_factory.cc"
...
...
tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc
浏览文件 @
af5019b9
...
...
@@ -22,7 +22,9 @@ namespace mindspore {
namespace
device
{
namespace
ascend
{
void
AscendLabelAssign
::
AssignLabel
(
NotNull
<
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&>
)
{}
void
AscendLabelAssign
::
AssignLabel
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
)
{}
uint32_t
AscendLabelAssign
::
GetLabelNum
(
NotNull
<
const
session
::
KernelGraph
*>
graph
)
{
return
1
;
}
uint32_t
AscendLabelAssign
::
GetLabelNum
(
NotNull
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
graph
)
{
return
1
;
}
void
AscendStreamAssign
::
AssignStreamNew
(
const
KernelGraphPtr
&
graph
)
{
return
;
}
...
...
@@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
}
// namespace ascend
void
KernelAdjust
::
Reorder
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph_ptr
)
{
return
;
}
void
KernelAdjust
::
InsertSwitchLoop
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph_ptr
)
{
return
;
}
bool
KernelAdjust
::
StepLoadCtrlInputs
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph_ptr
)
{
return
true
;
}
bool
KernelAdjust
::
StepLoadCtrlInputs
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph_ptr
)
{
return
true
;
}
bool
KernelAdjust
::
NeedInsertSwitch
()
{
return
true
;
}
void
KernelAdjust
::
Profiling
(
NotNull
<
session
::
KernelGraph
*>
kernel_graph_ptr
)
{
return
;
}
}
// namespace device
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录