Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c9453a61
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c9453a61
编写于
7月 31, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 31, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3239 Insert concat for outputs of AllGather
Merge pull request !3239 from YuJianfeng/allgther
上级
6adcb5a8
47ab812e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
156 additions
and
3 deletions
+156
-3
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
...c/backend/optimizer/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc
...ptimizer/ascend/enhancer/concat_outputs_for_all_gather.cc
+104
-0
mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h
...optimizer/ascend/enhancer/concat_outputs_for_all_gather.h
+42
-0
mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
...e/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
+7
-3
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-0
未找到文件。
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
浏览文件 @
c9453a61
...
...
@@ -102,6 +102,7 @@
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
...
...
@@ -341,6 +342,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
// Communication-op fusion passes plus the fix-ups that depend on them.
// ConcatOutputsForAllGather keys off the kAttrFusion / kAttrRankSize
// attributes, so it must be registered after AllGatherFusion.
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ConcatOutputsForAllGather>());
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
...
...
mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc
0 → 100644
浏览文件 @
c9453a61
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
// Expands a fused AllGather node's abstract tuple and selected kernel build
// info so the node exposes (rank_size * original_output_count) outputs: one
// copy of every original output per rank, in rank-major order.  The per-output
// device type and format are duplicated to match.
void AddOutputs(const AnfNodePtr &node, int rank_size) {
  MS_EXCEPTION_IF_NULL(node);
  auto origin_abstract = node->abstract();
  MS_EXCEPTION_IF_NULL(origin_abstract);
  auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  const auto &origin_abstracts = tuple_abstract->elements();

  AbstractBasePtrList abstract_list;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_device_format;
  // Replicate every original output once per rank (rank-major order).
  for (int rank = 0; rank < rank_size; ++rank) {
    for (size_t out_idx = 0; out_idx < origin_abstracts.size(); ++out_idx) {
      abstract_list.push_back(origin_abstracts[out_idx]);
      outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, out_idx));
      outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, out_idx));
    }
  }
  // Update the abstract with the replicated element list.
  node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  // Rebuild the selected kernel build info to match the new output count.
  auto builder =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
  builder->SetOutputsDeviceType(outputs_device_type);
  builder->SetOutputsFormat(outputs_device_format);
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
}
// namespace
// Inserts one Concat node per original AllGather input.  The slice of input i
// produced by rank j lives at new_tuple_getitems[j * inputs_size + i]; the
// i-th Concat joins those rank_size slices along axis 0.  Returns a MakeTuple
// of the Concat nodes so callers can replace the fused node's outputs.
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                            const std::vector<AnfNodePtr> &new_tuple_getitems,
                                                            int rank_size) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> make_tuple_inputs;
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < inputs_size; ++i) {
    // Collect all rank_size slices of input i BEFORE creating the Concat, so
    // the node's real input count agrees with the kAttrInputNums and
    // kAttrDynInputSizes attributes set below.  (The previous version created
    // one single-input Concat per slice while still claiming rank_size inputs.)
    std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
      MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]);
      concat_inputs.push_back(new_tuple_getitems[idx]);
    }
    auto concat = func_graph->NewCNode(concat_inputs);
    MS_EXCEPTION_IF_NULL(concat);
    // NOTE(review): the abstract of a single slice is reused for the Concat
    // output; presumably the shape is re-inferred downstream — confirm.
    concat->set_abstract(new_tuple_getitems[i]->abstract());
    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
    std::vector<int> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
    kernel_select_->SelectKernel(concat);
    make_tuple_inputs.push_back(concat);
  }
  return func_graph->NewCNode(make_tuple_inputs);
}
// Matches any AllGather CNode: the AllGather primitive followed by an
// arbitrary-length argument sequence.
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
  auto all_gather_prim = std::make_shared<Primitive>(kAllGatherOpName);
  VarPtr args_seq = std::make_shared<SeqVar>();
  return VectorRef({all_gather_prim, args_seq});
}
// Pass entry point.  Only AllGather nodes produced by fusion (positive
// kAttrFusion and a kAttrRankSize attribute) are rewritten; any other match
// returns nullptr and is left untouched.
const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Skip AllGather nodes that did not come out of communication-op fusion.
  const bool came_from_fusion =
    AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::HasNodeAttr(kAttrRankSize, cnode);
  if (!came_from_fusion) {
    return nullptr;
  }
  if (AnfAlgo::GetNodeAttr<int>(cnode, kAttrFusion) <= 0) {
    return nullptr;
  }
  auto rank_size = AnfAlgo::GetNodeAttr<int>(node, kAttrRankSize);
  // Expose one output per (rank, input) pair on the fused node...
  AddOutputs(node, rank_size);
  // ...materialize each of them as a TupleGetItem...
  std::vector<AnfNodePtr> new_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
  // ...and stitch the per-rank slices back together with Concat nodes.
  return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h
0 → 100644
浏览文件 @
c9453a61
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
// Pass that restores the user-visible layout of a fused AllGather's outputs:
// after fusion the node yields one slice per (rank, input) pair, and this pass
// concatenates the per-rank slices of each input along axis 0.
class ConcatOutputsForAllGather : public PatternProcessPass {
 public:
  explicit ConcatOutputsForAllGather(bool multigraph = true)
      : PatternProcessPass("concat_outputs_for_all_gather", multigraph),
        kernel_select_(std::make_shared<KernelSelect>()) {}
  ~ConcatOutputsForAllGather() override = default;

  // Matches any AllGather CNode.
  const BaseRef DefinePattern() const override;
  // Rewrites fused AllGather nodes; returns nullptr for non-fused matches.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Builds one Concat per original input over its rank_size slices and wraps
  // the Concat nodes in a MakeTuple.
  AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                   const std::vector<AnfNodePtr> &new_tuple_getitems, int rank_size) const;

  // Selects a device kernel for each newly created Concat node.
  KernelSelectPtr kernel_select_;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
浏览文件 @
c9453a61
...
...
@@ -188,9 +188,13 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
fused_node->set_abstract(abstract_tuple);
// Propagate fusion-related attributes from the last op being fused so later
// passes (e.g. ConcatOutputsForAllGather) can recognize the fused node.  The
// named kAttr constants replace the earlier duplicated string-literal copies
// ("fusion"/"op"/"group"), which copied each attribute a second time.
auto final_node = communication_op_info.communication_op_nodes[end_index];
AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node);
AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node);
AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node);
// rank_size only exists on some communication ops (e.g. AllGather); copy it
// only when present so other op kinds do not raise on a missing attribute.
if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node)) {
  AnfAlgo::CopyNodeAttr(kAttrRankSize, final_node, fused_node);
}
return fused_node;
}
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
c9453a61
...
...
@@ -250,6 +250,7 @@ constexpr auto kAttrChildGraph = "child_graph";
// Node-attribute keys used by the backend optimizer passes.
constexpr auto kAttrInputNums = "inputNums";
constexpr auto kAttrT = "T";
constexpr auto kAttrNum = "num";
constexpr auto kAttrRankSize = "rank_size";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录