Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
cc54bb56
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cc54bb56
编写于
4月 17, 2020
作者:
C
chenfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move opt to build graph
上级
64abbeaa
变更
14
展开全部
隐藏空白更改
内联
并排
Showing
14 changed file
with
381 addition
and
152 deletion
+381
-152
mindspore/ccsrc/kernel/kernel_build_info.cc
mindspore/ccsrc/kernel/kernel_build_info.cc
+8
-4
mindspore/ccsrc/kernel/kernel_build_info.h
mindspore/ccsrc/kernel/kernel_build_info.h
+3
-0
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+3
-3
mindspore/ccsrc/kernel/mng/rt_kernel_info.cc
mindspore/ccsrc/kernel/mng/rt_kernel_info.cc
+24
-8
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+8
-4
mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc
.../ccsrc/pre_activate/common/common_backend_optimization.cc
+1
-0
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+24
-4
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+219
-113
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+21
-2
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+43
-6
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+12
-1
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+12
-7
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+2
-0
mindspore/ops/_op_impl/tbe/assign.py
mindspore/ops/_op_impl/tbe/assign.py
+1
-0
未找到文件。
mindspore/ccsrc/kernel/kernel_build_info.cc
浏览文件 @
cc54bb56
...
...
@@ -22,28 +22,32 @@ namespace mindspore {
namespace
kernel
{
std
::
string
KernelBuildInfo
::
GetInputFormat
(
size_t
input_index
)
const
{
if
(
input_index
>=
inputs_format_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
input_index
<<
"] is exceed the number of input node"
;
MS_LOG
(
ERROR
)
<<
"The index ["
<<
input_index
<<
"] is exceed the number of input node"
;
return
kInvalidFormat
;
}
return
inputs_format_
[
input_index
];
}
std
::
string
KernelBuildInfo
::
GetOutputFormat
(
size_t
output_index
)
const
{
if
(
output_index
>=
outputs_format_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_index
<<
"] is exceed the number of input node"
;
MS_LOG
(
ERROR
)
<<
"The index ["
<<
output_index
<<
"] is exceed the number of input node"
;
return
kInvalidFormat
;
}
return
outputs_format_
[
output_index
];
}
TypeId
KernelBuildInfo
::
GetInputDeviceType
(
size_t
input_index
)
const
{
if
(
input_index
>=
inputs_device_type_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
input_index
<<
"] is exceed the number of input node"
;
MS_LOG
(
ERROR
)
<<
"The index ["
<<
input_index
<<
"] is exceed the number of input"
;
return
TypeId
::
kNumberTypeEnd
;
}
return
inputs_device_type_
[
input_index
];
}
TypeId
KernelBuildInfo
::
GetOutputDeviceType
(
size_t
output_index
)
const
{
if
(
output_index
>=
outputs_device_type_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_index
<<
"] is exceed the number of input node"
;
MS_LOG
(
ERROR
)
<<
"The index ["
<<
output_index
<<
"] is exceed the number of output"
;
return
TypeId
::
kNumberTypeEnd
;
}
return
outputs_device_type_
[
output_index
];
}
...
...
mindspore/ccsrc/kernel/kernel_build_info.h
浏览文件 @
cc54bb56
...
...
@@ -82,6 +82,9 @@ class KernelBuildInfo {
bool
operator
==
(
const
KernelBuildInfo
&
other
)
const
;
public:
static
auto
constexpr
kInvalidFormat
=
"InvalidFormat"
;
private:
KernelType
kernel_type_
;
std
::
vector
<
std
::
string
>
inputs_format_
;
...
...
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
cc54bb56
...
...
@@ -26,7 +26,7 @@
namespace
mindspore
{
namespace
kernel
{
namespace
{
void
FilterInva
il
dKernelInfo
(
const
CNodePtr
&
kernel_node
,
void
FilterInva
li
dKernelInfo
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
filtered_list
;
...
...
@@ -63,9 +63,9 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
HcclMetadataInfo
(
kernel_node
,
kernel_info_list
);
}
if
(
kernel_info_list
->
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"
op
"
<<
kernel_node
->
DebugString
()
<<
"kernel query fail!"
;
MS_LOG
(
EXCEPTION
)
<<
"
Op
"
<<
kernel_node
->
DebugString
()
<<
"kernel query fail!"
;
}
FilterInva
il
dKernelInfo
(
kernel_node
,
kernel_info_list
);
FilterInva
li
dKernelInfo
(
kernel_node
,
kernel_info_list
);
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/mng/rt_kernel_info.cc
浏览文件 @
cc54bb56
...
...
@@ -46,24 +46,40 @@ RtKerDescFactory &RtKerDescFactory::Get() {
void
GetRtKelInfo
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
MS_LOG
(
INFO
)
<<
"Mng kernel Info."
;
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
MS_EXCEPTION_IF_NULL
(
kernel_node
);
std
::
string
opNameLower
=
AnfAlgo
::
GetCNodeName
(
kernel_node
);
(
void
)
std
::
transform
(
opNameLower
.
begin
(),
opNameLower
.
end
(),
opNameLower
.
begin
(),
::
tolower
);
auto
ker_desc_ptr
=
RtKerDescFactory
::
Create
(
opNameLower
);
if
(
ker_desc_ptr
==
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"Mng can't find op ["
<<
opNameLower
<<
"]."
;
if
(
ker_desc_ptr
!=
nullptr
&&
!
ker_desc_ptr
->
GetKernelInfo
().
empty
()
)
{
*
kernel_info_list
=
ker_desc_ptr
->
GetKernelInfo
()
;
return
;
}
MS_EXCEPTION_IF_NULL
(
ker_desc_ptr
);
auto
kernel_info
=
ker_desc_ptr
->
GetKernelInfo
();
if
(
kernel_info
.
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Rt dose not have op ["
<<
opNameLower
<<
"]."
;
// if can't find kernel info in kernel info database, use the default kernel info
auto
node_name
=
AnfAlgo
::
GetCNodeName
(
kernel_node
);
if
(
node_name
==
"StreamSwitch"
||
node_name
==
"StreamActive"
)
{
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set input infos
auto
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
kernel_build_info_builder
->
SetInputsFormat
(
std
::
vector
<
std
::
string
>
(
input_num
,
kOpFormat_DEFAULT
));
std
::
vector
<
TypeId
>
input_types
=
{};
for
(
size_t
i
=
0
;
i
<
input_num
;
i
++
)
{
input_types
.
push_back
(
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
kernel_node
,
i
));
}
kernel_build_info_builder
->
SetInputsDeviceType
(
input_types
);
// set output info
auto
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
(
output_num
,
kOpFormat_DEFAULT
));
kernel_build_info_builder
->
SetOutputsDeviceType
(
std
::
vector
<
TypeId
>
(
output_num
,
TypeId
::
kTypeUnknown
));
// set ohter info
kernel_build_info_builder
->
SetFusionType
(
kernel
::
FusionType
::
OPAQUE
);
kernel_build_info_builder
->
SetProcessor
(
kernel
::
Processor
::
AICORE
);
kernel_build_info_builder
->
SetKernelType
(
KernelType
::
RT_KERNEL
);
kernel_info_list
->
push_back
(
kernel_build_info_builder
->
Build
());
return
;
}
*
kernel_info_list
=
kernel_info
;
MS_LOG
(
DEBUG
)
<<
"Rt dose not have op ["
<<
opNameLower
<<
"]."
;
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
cc54bb56
...
...
@@ -186,7 +186,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
save_graphs_path
=
"."
;
}
if
(
save_graphs
)
{
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_ir_fusion_before.ir"
;
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_ir_fusion_before"
+
"_graph_"
+
std
::
to_string
(
kernel_graph
->
graph_id
())
+
".ir"
;
DumpIR
(
file_path
,
kernel_graph
);
DumpIRProto
(
kernel_graph
,
"before_hwopt"
);
}
...
...
@@ -208,7 +209,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
if
(
save_graphs
)
{
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_ir_fusion_after.ir"
;
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_ir_fusion_after"
+
"_graph_"
+
std
::
to_string
(
kernel_graph
->
graph_id
())
+
".ir "
;
DumpIR
(
file_path
,
kernel_graph
);
}
}
...
...
@@ -252,7 +254,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
save_graphs_path
=
"."
;
}
if
(
save_graphs
)
{
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_before.ir"
;
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_before"
+
"_graph_"
+
std
::
to_string
(
kernel_graph
->
graph_id
())
+
".ir"
;
DumpIR
(
file_path
,
kernel_graph
);
}
// data layout optimization
...
...
@@ -278,7 +281,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
if
(
save_graphs
)
{
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_end.ir"
;
std
::
string
file_path
=
save_graphs_path
+
"/"
+
"hwopt_d_end"
+
"_graph_"
+
std
::
to_string
(
kernel_graph
->
graph_id
())
+
".ir"
;
DumpIR
(
file_path
,
kernel_graph
,
true
);
DumpIRProto
(
kernel_graph
,
"after_hwopt"
);
}
...
...
mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc
浏览文件 @
cc54bb56
...
...
@@ -27,6 +27,7 @@
namespace
mindspore
{
namespace
opt
{
void
BackendCommonOptimization
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
MS_LOG
(
INFO
)
<<
"start common opt graph:"
<<
kernel_graph
->
graph_id
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
();
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
cc54bb56
...
...
@@ -300,7 +300,12 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
return
build_info
->
GetOutputFormat
(
output_idx
);
auto
format
=
build_info
->
GetOutputFormat
(
output_idx
);
if
(
format
==
kernel
::
KernelBuildInfo
::
kInvalidFormat
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node ["
<<
node
->
DebugString
()
<<
"]"
<<
" has a invalid output format"
;
}
return
format
;
}
std
::
string
AnfRuntimeAlgorithm
::
GetInputFormat
(
const
AnfNodePtr
&
node
,
size_t
input_idx
)
{
...
...
@@ -314,7 +319,12 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
return
build_info
->
GetInputFormat
(
input_idx
);
auto
format
=
build_info
->
GetInputFormat
(
input_idx
);
if
(
format
==
kernel
::
KernelBuildInfo
::
kInvalidFormat
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node ["
<<
node
->
DebugString
()
<<
"]"
<<
" has a invalid input format"
;
}
return
format
;
}
KernelWithIndex
AnfRuntimeAlgorithm
::
GetPrevNodeOutput
(
const
AnfNodePtr
&
anf_node
,
size_t
input_idx
)
{
...
...
@@ -481,7 +491,12 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
return
build_info
->
GetOutputDeviceType
(
output_idx
);
auto
dtype
=
build_info
->
GetOutputDeviceType
(
output_idx
);
if
(
dtype
==
TypeId
::
kNumberTypeEnd
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node ["
<<
node
->
DebugString
()
<<
"]"
<<
" has a invalid dtype"
;
}
return
dtype
;
}
TypeId
AnfRuntimeAlgorithm
::
GetInputDeviceDataType
(
const
AnfNodePtr
&
node
,
size_t
input_idx
)
{
...
...
@@ -494,7 +509,12 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
return
build_info
->
GetInputDeviceType
(
input_idx
);
auto
dtype
=
build_info
->
GetInputDeviceType
(
input_idx
);
if
(
dtype
==
TypeId
::
kNumberTypeEnd
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node ["
<<
node
->
DebugString
()
<<
"]"
<<
" has a invalid dtype"
;
}
return
dtype
;
}
TypeId
AnfRuntimeAlgorithm
::
GetPrevNodeOutputDeviceDataType
(
const
AnfNodePtr
&
anf_node
,
size_t
input_idx
)
{
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
cc54bb56
此差异已折叠。
点击以展开。
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
cc54bb56
...
...
@@ -21,6 +21,9 @@
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "session/session_basic.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
...
...
@@ -60,6 +63,8 @@ class AscendSession : public SessionBasic {
GraphId
GetFinalRunGraph
()
const
override
{
return
final_graph_id_
;
}
// insert active to graph
void
SetActive
(
GraphId
,
GraphId
)
override
;
// compile child graph when session have multiple child graphs
void
CompileChildGraph
(
const
KernelGraphPtr
&
child_graph
);
private:
void
InitRuntimeResource
();
...
...
@@ -95,12 +100,16 @@ class AscendSession : public SessionBasic {
size_t
ExecOrderOfChildGraph
(
GraphId
final_graph
,
GraphId
child_graph
);
// handle condition graph from vm
void
InsertSwitchToGraph
(
GraphId
condition_graph_id
,
GraphId
true_graph_id
);
// insert depend to graph, used to attch control nodes to graph
void
InsertDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
attch_node
);
// insert depend to graph, used to attch control nodes to graph
void
InsertControlDependToGraph
(
GraphId
graph_id
,
const
AnfNodePtr
&
first_node
,
const
AnfNodePtr
&
second_node
);
// Get graph by graph id ,if not exist return null ptr
KernelGraphPtr
GetGraph
(
GraphId
graph_id
);
// set child graph parameter if front arg is a anf
void
SetChildGraphParameter
(
const
AnfNodePtr
&
front_anf
,
const
AnfNodePtr
&
backend_parameter
);
void
SetChildGraphParameter
(
const
AnfNodePtr
&
front_anf
,
GraphId
to_graph_id
,
size_t
input_idx
);
// set child graph parameter if front arg is a tensor
void
SetChildGraphParameter
(
const
tensor
::
TensorPtr
&
front_tensor
,
const
AnfNodePtr
&
backend_parameter
);
void
SetChildGraphParameter
(
const
tensor
::
TensorPtr
&
front_tensor
,
GraphId
to_graph_id
,
size_t
input_idx
);
// update the execution order of all child graphs
void
UpdateGraphOrder
(
GraphId
to_graph
);
// handle switch when merge
...
...
@@ -113,6 +122,12 @@ class AscendSession : public SessionBasic {
void
CopyOutputOfIf
(
GraphId
false_graph_id
);
// check if graph cache exist
bool
GraphCacheExist
(
const
GraphInfo
&
graph_info
)
const
;
// insert all assign to child graph
void
InsertAllAssigns
();
// create fake output of final graph
AnfNodePtr
CreateFakeOutput
(
GraphId
final_graph_id
,
const
AnfNodePtr
&
true_output
);
// sync intial tensors' data to device
void
SyncInitialTenosrToDevice
();
// member variables
// key is final_graph_id,value is child graph execute order of final graph
...
...
@@ -124,6 +139,10 @@ class AscendSession : public SessionBasic {
// record all conditions
std
::
unordered_map
<
GraphId
,
std
::
pair
<
GraphId
,
GraphId
>>
switches_
;
std
::
unordered_map
<
GraphId
,
AnfNodePtr
>
condition_output_
;
// share parameters
std
::
set
<
std
::
tuple
<
AnfNodePtr
,
GraphId
,
size_t
>>
assigns_
;
// initial tensors, these tensor will sync data to device before run graph
std
::
map
<
std
::
pair
<
GraphId
,
size_t
>
,
tensor
::
TensorPtr
>
initial_tenosrs_
;
// final_graph_id is used in every root graph has it's own session situation
GraphId
final_graph_id_
;
};
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
cc54bb56
...
...
@@ -295,10 +295,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
kOpFormat_DEFAULT
});
// set value node initial device data type = infer data type
std
::
vector
<
TypeId
>
types
;
for
(
size_t
index
=
0
;
index
<
AnfAlgo
::
GetOutputTensorNum
(
value_node
);
++
index
)
{
types
.
push_back
(
kTypeUnknown
);
}
std
::
vector
<
TypeId
>
types
=
std
::
vector
<
TypeId
>
(
AnfAlgo
::
GetOutputTensorNum
(
value_node
),
kTypeUnknown
);
kernel_build_info_builder
->
SetOutputsDeviceType
(
types
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
new_value_node
.
get
());
AnfAlgo
::
SetGraphId
(
graph_id_
,
new_value_node
.
get
());
...
...
@@ -330,10 +327,11 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
MS_LOG
(
EXCEPTION
)
<<
"old can't be same with new"
;
}
if
(
backend_front_anf_map_
.
find
(
old_backend_anf
)
==
backend_front_anf_map_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"old_backend_anf "
<<
old_backend_anf
->
DebugString
()
<<
" is not exist in the map"
;
MS_LOG
(
DEBUG
)
<<
"old_backend_anf "
<<
old_backend_anf
->
DebugString
()
<<
" is not exist in the map"
;
return
;
}
if
(
front_backend_anf_map_
.
find
(
backend_front_anf_map_
[
old_backend_anf
])
==
front_backend_anf_map_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"anf is not exist in the map
e
,old "
<<
old_backend_anf
->
DebugString
();
MS_LOG
(
EXCEPTION
)
<<
"anf is not exist in the map ,old "
<<
old_backend_anf
->
DebugString
();
}
front_backend_anf_map_
[
backend_front_anf_map_
[
old_backend_anf
]]
=
new_backend_anf
;
backend_front_anf_map_
[
new_backend_anf
]
=
backend_front_anf_map_
[
old_backend_anf
];
...
...
@@ -528,5 +526,44 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
}
return
false
;
}
void
KernelGraph
::
ReplaceNode
(
const
AnfNodePtr
&
old_anf_node
,
AnfNodePtr
new_anf_node
)
{
MS_EXCEPTION_IF_NULL
(
old_anf_node
);
MS_EXCEPTION_IF_NULL
(
new_anf_node
);
MS_EXCEPTION_IF_NULL
(
inputs_
);
auto
it
=
node_output_edges_
.
find
(
old_anf_node
);
if
(
it
==
node_output_edges_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Can't find anf node in node_output_edges map"
;
}
auto
&
outputs
=
it
->
second
;
for
(
auto
&
output_node
:
outputs
)
{
auto
output_cnode
=
output_node
.
first
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
output_cnode
);
auto
&
output_node_inputs
=
output_cnode
->
inputs
();
for
(
size_t
i
=
1
;
i
<
output_node_inputs
.
size
();
i
++
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
)
{
output_cnode
->
set_input
(
i
,
new_anf_node
);
}
}
// update graph inputs
for
(
size_t
i
=
0
;
i
<
inputs_
->
size
();
i
++
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
)
{
(
*
inputs_
)[
i
]
=
new_anf_node
;
break
;
}
}
}
// update front to backend map
FrontBackendlMapUpdate
(
old_anf_node
,
new_anf_node
);
// update output depend relations
node_output_edges_
[
new_anf_node
]
=
it
->
second
;
(
void
)
node_output_edges_
.
erase
(
old_anf_node
);
}
void
KernelGraph
::
UpdateExecuteKernelStreamLabel
()
{
for
(
auto
&
kernel
:
execution_order_
)
{
AnfAlgo
::
SetStreamDistinctionLabel
(
stream_distinction_label_
,
kernel
.
get
());
}
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
cc54bb56
...
...
@@ -27,6 +27,7 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/graph_utils.h"
#include "device/kernel_info.h"
namespace
mindspore
{
namespace
session
{
...
...
@@ -37,6 +38,7 @@ class KernelGraph : public FuncGraph {
inputs_
=
std
::
make_shared
<
std
::
vector
<
AnfNodePtr
>>
();
execution_order_
=
{};
executable_
=
true
;
stream_distinction_label_
=
kInvalidDistincLabel
;
}
~
KernelGraph
()
override
=
default
;
...
...
@@ -88,7 +90,15 @@ class KernelGraph : public FuncGraph {
void
set_executable
(
bool
executable
)
{
executable_
=
executable
;
}
// set invalid inputs for control sink
std
::
vector
<
bool
>
*
MutableValidInputs
()
{
return
&
valid_inputs_
;
}
const
std
::
vector
<
bool
>
&
ValidInputs
()
const
{
return
valid_inputs_
;
}
std
::
vector
<
bool
>
valid_inputs
()
const
{
return
valid_inputs_
;
}
// replace node in graph
void
ReplaceNode
(
const
AnfNodePtr
&
old_anf_node
,
AnfNodePtr
new_anf_node
);
// set stream label of graph
void
set_stream_distinction_label
(
uint32_t
stream_label
)
{
stream_distinction_label_
=
stream_label
;
}
// get stream label of graph
uint32_t
stream_distinction_label
()
{
return
stream_distinction_label_
;
}
// refresh execute kernel stream label
void
UpdateExecuteKernelStreamLabel
();
private:
// remove value node form graph
...
...
@@ -108,6 +118,7 @@ class KernelGraph : public FuncGraph {
std
::
shared_ptr
<
std
::
vector
<
AnfNodePtr
>>
inputs_
;
std
::
vector
<
CNodePtr
>
execution_order_
;
uint32_t
graph_id_
;
uint32_t
stream_distinction_label_
;
// record map bettween front anf and backend anf,use two map implement bidirectional map
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
front_backend_anf_map_
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
cc54bb56
...
...
@@ -417,9 +417,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
KernelGraphPtr
SessionBasic
::
ConstructKernelGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
other_graph_cnode
;
auto
graph
=
std
::
make_shared
<
KernelGraph
>
();
graph
->
set_graph_id
(
graph_sum_
);
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph_sum_
;
auto
graph
=
NewKernelGraph
();
MS_LOG
(
INFO
)
<<
"Create graph: "
<<
graph
->
graph_id
();
size_t
from_other_graph_depend_num
=
0
;
for
(
const
auto
&
node
:
lst
)
{
MS_EXCEPTION_IF_NULL
(
node
);
...
...
@@ -456,7 +455,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
}
graph
->
SetExecOrderByDefault
();
opt
::
BackendCommonOptimization
(
graph
);
graphs_
[
graph_sum_
++
]
=
graph
;
return
graph
;
}
...
...
@@ -588,14 +586,14 @@ void SessionBasic::Summary(KernelGraph *graph) {
CNodePtr
SessionBasic
::
ConstructOutput
(
const
AnfNodePtrList
&
outputs
,
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
output_args
;
for
(
const
auto
&
output
:
outputs
)
{
MS_LOG
(
INFO
)
<<
"output:"
<<
output
->
DebugString
();
}
auto
FindEqu
=
[
graph
,
outputs
](
const
AnfNodePtr
&
out
)
->
AnfNodePtr
{
auto
backend_anf
=
graph
->
GetBackendAnfByFrontAnf
(
out
);
if
(
backend_anf
!=
nullptr
)
{
return
backend_anf
;
}
for
(
const
auto
&
output
:
outputs
)
{
MS_LOG
(
INFO
)
<<
"output:"
<<
output
->
DebugString
();
}
MS_LOG
(
EXCEPTION
)
<<
"Can't find the node in the equiv map!"
;
};
output_args
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
...
...
@@ -695,5 +693,12 @@ BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
MS_LOG
(
EXCEPTION
)
<<
"The output is not a base ref list or a tensor!"
;
}
}
KernelGraphPtr
SessionBasic
::
NewKernelGraph
()
{
auto
graph
=
std
::
make_shared
<
KernelGraph
>
();
graph
->
set_graph_id
(
graph_sum_
);
graphs_
[
graph_sum_
++
]
=
graph
;
return
graph
;
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/session_basic.h
浏览文件 @
cc54bb56
...
...
@@ -104,6 +104,8 @@ class SessionBasic {
const
std
::
vector
<
bool
>
&
tensors_mask
);
// trans BaseRef list to py::tuple
BaseRef
TransformBaseRefListToTuple
(
const
BaseRef
&
base_ref
);
// create a new kernel graph and update the graph sum
KernelGraphPtr
NewKernelGraph
();
std
::
unordered_map
<
GraphId
,
std
::
shared_ptr
<
KernelGraph
>>
graphs_
;
std
::
unordered_map
<
GraphInfo
,
std
::
shared_ptr
<
KernelGraph
>>
run_op_graphs_
;
...
...
mindspore/ops/_op_impl/tbe/assign.py
浏览文件 @
cc54bb56
...
...
@@ -27,6 +27,7 @@ assign_op_info = TBERegOp("Assign") \
.
input
(
1
,
"value"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
BOOL_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
I8_5HD
,
DataType
.
I8_5HD
,
DataType
.
I8_5HD
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
U8_5HD
,
DataType
.
U8_5HD
,
DataType
.
U8_5HD
)
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录