Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6316343e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6316343e
编写于
8月 29, 2020
作者:
L
laiyongqiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add atomic clean op for every communication op's input
上级
034453e4
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
49 addition
and
3 deletion
+49
-3
mindspore/ccsrc/device/ascend/kernel_build_ascend.cc
mindspore/ccsrc/device/ascend/kernel_build_ascend.cc
+49
-3
未找到文件。
mindspore/ccsrc/device/ascend/kernel_build_ascend.cc
浏览文件 @
6316343e
...
...
@@ -184,11 +184,17 @@ bool IsAtomicNode(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL
(
kernel_node
);
auto
kernel_mod
=
AnfAlgo
::
GetKernelMod
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_mod
);
auto
atomic_flag
=
false
;
std
::
vector
<
size_t
>
clean_output_indexs
;
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrAutomicOutputIndexs
,
kernel_node
))
{
clean_output_indexs
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
size_t
>>
(
kernel_node
,
kAttrAutomicOutputIndexs
);
atomic_flag
=
true
;
}
auto
parameters_indexs
=
kernel_mod
->
GenParameters
();
if
(
parameters_indexs
.
empty
())
{
return
false
;
return
atomic_flag
;
}
auto
atomic_flag
=
false
;
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
size_t
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
auto
workspace_size_list
=
kernel_mod
->
GetWorkspaceSizeList
();
...
...
@@ -199,7 +205,7 @@ bool IsAtomicNode(const CNodePtr &kernel_node) {
parameters_indexs
.
push_back
(
0
);
}
}
std
::
vector
<
size_t
>
clean_output_indexs
;
// in parameters data sort as input->workspace->output
size_t
index
=
0
;
while
(
index
<
output_num
)
{
...
...
@@ -210,6 +216,8 @@ bool IsAtomicNode(const CNodePtr &kernel_node) {
index
++
;
}
if
(
atomic_flag
)
{
std
::
set
<
size_t
>
s
(
clean_output_indexs
.
begin
(),
clean_output_indexs
.
end
());
clean_output_indexs
.
assign
(
s
.
begin
(),
s
.
end
());
AnfAlgo
::
SetNodeAttr
(
kAttrAutomicOutputIndexs
,
MakeValue
(
clean_output_indexs
),
kernel_node
);
}
for
(
size_t
i
=
0
;
i
<
workspace_num
;
++
i
)
{
...
...
@@ -238,11 +246,49 @@ bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
return
ret
;
}
std
::
map
<
AnfNodePtr
,
std
::
vector
<
size_t
>>
GetCommunicationOpInputInfo
(
const
mindspore
::
session
::
KernelGraph
*
kernel_graph
)
{
std
::
map
<
AnfNodePtr
,
std
::
vector
<
size_t
>>
comm_input_info_map
;
for
(
auto
&
kernel
:
kernel_graph
->
execution_order
())
{
auto
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel
);
if
(
mindspore
::
session
::
AnfRuntimeAlgorithm
::
IsCommunicationOp
(
kernel
))
{
for
(
size_t
i
=
0
;
i
<
input_num
;
i
++
)
{
auto
input_node
=
kernel
->
input
(
i
+
1
);
auto
kernel_input
=
AnfAlgo
::
VisitKernelWithReturnType
(
input_node
,
0
,
true
);
MS_LOG
(
INFO
)
<<
" Add atomic clean for single communication op input, comm:"
<<
kernel
->
fullname_with_scope
()
<<
" input_node: "
<<
kernel_input
.
first
->
fullname_with_scope
()
<<
" index: "
<<
kernel_input
.
second
;
auto
iter
=
comm_input_info_map
.
find
(
kernel_input
.
first
);
if
(
iter
!=
comm_input_info_map
.
end
())
{
iter
->
second
.
push_back
(
kernel_input
.
second
);
}
else
{
std
::
vector
<
size_t
>
indexes
=
{
kernel_input
.
second
};
comm_input_info_map
[
kernel_input
.
first
]
=
indexes
;
}
}
}
}
// remove duplicate index
for
(
auto
&
info
:
comm_input_info_map
)
{
std
::
set
<
size_t
>
s
(
info
.
second
.
begin
(),
info
.
second
.
end
());
info
.
second
.
assign
(
s
.
begin
(),
s
.
end
());
}
return
comm_input_info_map
;
}
void
KernelBuildPreprocess
(
mindspore
::
session
::
KernelGraph
*
kernel_graph
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
std
::
vector
<
CNodePtr
>
new_nodes
;
std
::
map
<
AnfNodePtr
,
std
::
vector
<
size_t
>>
comm_input_info_map
=
GetCommunicationOpInputInfo
(
kernel_graph
);
for
(
const
auto
&
anf_node
:
kernel_graph
->
execution_order
())
{
std
::
string
apply_function_name
=
AnfAlgo
::
GetCNodeName
(
anf_node
);
if
(
comm_input_info_map
.
find
(
anf_node
)
!=
comm_input_info_map
.
end
())
{
auto
indexes
=
comm_input_info_map
[
anf_node
];
AnfAlgo
::
SetNodeAttr
(
kAttrAutomicOutputIndexs
,
MakeValue
(
indexes
),
anf_node
);
}
if
(
apply_function_name
==
prim
::
kPrimMaxPoolGrad
->
name
()
&&
AnfAlgo
::
GetKernelType
(
anf_node
)
==
KernelType
::
AKG_KERNEL
)
{
auto
clear_zero_prim
=
std
::
make_shared
<
Primitive
>
(
kClearZeroOpName
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录