Commit 39945d0f

Add AllGather fusion pass

Authored Apr 20, 2020 by YuJianfeng
Parent: 38ad5673
Showing 6 changed files with 146 additions and 53 deletions (+146, -53).
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc   +2  -1
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc         +72 -50
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h          +67 -0
mindspore/ccsrc/session/gpu_session.cc                               +1  -1
mindspore/ccsrc/utils/utils.h                                        +3  -0
tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc  +1  -1
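In brief: this change folds the AllReduce-specific fusion pass into a generic CommunicationOpFusion base pass parameterized by pass name and communication op name, adds AllGatherFusion as a second thin subclass, registers the new pass in the Ascend backend optimization pipeline, moves fusion-group key computation into a shared GetFusionGroupKey helper, and promotes the "fusion"/"group"/"op" attribute keys to shared constants in utils.h.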
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc

@@ -21,7 +21,7 @@
 #include "pre_activate/ascend/ir_fission/bn_grad_split.h"
 #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include "pre_activate/ascend/ir_fusion/square_sum_fusion.h"
 #include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
 #include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
@@ -254,6 +254,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto other_pm = std::make_shared<PassManager>("other_pm");
   other_pm->AddPass(std::make_shared<AllReduceFusion>());
+  other_pm->AddPass(std::make_shared<AllGatherFusion>());
   other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
   other_pm->AddPass(std::make_shared<BufferFusion>());
   other_pm->AddPass(std::make_shared<GetitemTuple>());
mindspore/ccsrc/pre_activate/pass/allreduce_fusion.cc → mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
@@ -13,14 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include <vector>
 #include <string>
 #include <memory>
 #include <unordered_map>
 #include "utils/utils.h"
 #include "utils/graph_utils.h"
 #include "operator/ops.h"
 #include "device/kernel_info.h"
@@ -31,9 +29,12 @@
 namespace mindspore {
 namespace opt {
 namespace {
-kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allreduce_node_info, size_t start_index,
+constexpr auto kAttrDefaultGroup = "default_group";
+constexpr auto kAttrDefaultOp = "default_op";
+
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
                                                    size_t end_index) {
-  if (end_index >= allreduce_node_info.allreduce_node.size()) {
+  if (end_index >= communication_op_info.communication_op_nodes.size()) {
     MS_LOG(EXCEPTION) << "end index out of vector size";
   }
   std::vector<std::string> inputs_device_format;
@@ -43,7 +44,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred
   std::vector<std::vector<size_t>> outputs_shape;
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
   for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = allreduce_node_info.allreduce_node[idx];
+    auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
     for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
       inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
@@ -64,14 +65,38 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred
   builder.SetOutputsDeviceType(outputs_device_type);
   return builder.Build();
 }
+
+std::string GetFusionGroupKey(const AnfNodePtr &node) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
+  if (attr_fusion == nullptr) {
+    return "";
+  }
+  int fusion = GetValue<int>(attr_fusion);
+  if (fusion == 0) {
+    return "";
+  }
+  std::string group = kAttrDefaultGroup;
+  ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
+  if (attr_group != nullptr) {
+    group = GetValue<std::string>(attr_group);
+  }
+  std::string op = kAttrDefaultOp;
+  ValuePtr attr_op = primitive->GetAttr(kAttrOp);
+  if (attr_op != nullptr) {
+    op = GetValue<std::string>(attr_op);
+  }
+  return group + op + std::to_string(fusion);
+}
 }  // namespace
 
-bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
-                                       std::vector<size_t> *segment_index) const {
+bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
+                                             std::vector<size_t> *segment_index) const {
   MS_EXCEPTION_IF_NULL(segment_num);
   MS_EXCEPTION_IF_NULL(segment_index);
-  size_t allreduce_node_size = allreduce_node_info.allreduce_node.size();
-  MS_LOG(INFO) << "graph all reduce node size " << allreduce_node_size;
+  size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
+  MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;
   auto parallel_context = parallel::ParallelContext::GetInstance();
   MS_EXCEPTION_IF_NULL(parallel_context);
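For orientation, here is a minimal, self-contained sketch of the grouping rule that GetFusionGroupKey implements; MakeFusionGroupKey and the example group string are illustrative stand-ins, not code from this commit:

#include <iostream>
#include <string>

// Nodes sharing the same (group, op, fusion) attribute triple land in one
// candidate fusion group; fusion == 0 opts a node out of fusion entirely.
std::string MakeFusionGroupKey(const std::string &group, const std::string &op, int fusion) {
  if (fusion == 0) {
    return "";  // empty key: skip this node
  }
  return group + op + std::to_string(fusion);
}

int main() {
  // Two ops with identical attributes produce the same key and get fused together.
  std::cout << MakeFusionGroupKey("hccl_world_group", "sum", 1) << "\n";
  // fusion == 0 yields an empty key, so the pass leaves the node untouched.
  std::cout << (MakeFusionGroupKey("hccl_world_group", "sum", 0).empty() ? "skipped" : "fused") << "\n";
}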
@@ -82,30 +107,31 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
     uint32_t last_index = 0;
     for (size_t i = 0; i < split_indices.size(); ++i) {
       uint32_t index = split_indices[i];
-      if (index <= last_index || index >= allreduce_node_size) {
-        MS_LOG(EXCEPTION) << "invalid allreduce split index " << i << " " << index;
+      if (index <= last_index || index >= communication_op_node_size) {
+        MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
       }
       segment_index->push_back(index);
       last_index = index;
       segments++;
     }
-    if (last_index != allreduce_node_size - 1) {
-      segment_index->push_back(allreduce_node_size - 1);
+    if (last_index != communication_op_node_size - 1) {
+      segment_index->push_back(communication_op_node_size - 1);
       segments++;
     }
   } else {
     segments = groups_;
     for (size_t i = 0; i < segments - 1; ++i) {
-      segment_index->push_back((i + 1) * (allreduce_node_size / segments) - 1);
+      segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
     }
-    segment_index->push_back(allreduce_node_size - 1);
+    segment_index->push_back(communication_op_node_size - 1);
   }
 
-  if (segments >= allreduce_node_size) {
-    MS_LOG(INFO) << "fusion not changed: segment_num=" << segments << ", allreduce_node_size=" << allreduce_node_size;
+  if (segments >= communication_op_node_size) {
+    MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
+                 << ", communication_op_node_size=" << communication_op_node_size;
     return false;
   }
-  if (segment_index->at(segments - 1) != allreduce_node_size - 1) {
+  if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
     MS_LOG(EXCEPTION) << "the last segment index is invalid.";
   }
   for (size_t i = 0; i < segments - 1; ++i) {
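As a worked example of the even-split branch above: with 10 communication ops and groups_ = 3, the recorded segment-end indices are 2, 5, and 9. A standalone sketch (EvenSplit is a hypothetical extraction, not a function in this diff):

#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors the else-branch of GetSplitSegments: each entry records the index
// of the last node in a segment; the final segment absorbs any remainder.
std::vector<size_t> EvenSplit(size_t node_size, size_t segments) {
  std::vector<size_t> segment_index;
  for (size_t i = 0; i < segments - 1; ++i) {
    segment_index.push_back((i + 1) * (node_size / segments) - 1);
  }
  segment_index.push_back(node_size - 1);
  return segment_index;
}

int main() {
  for (size_t idx : EvenSplit(10, 3)) {
    std::cout << idx << " ";  // prints: 2 5 9
  }
  std::cout << "\n";
}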
@@ -118,19 +144,19 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
   return true;
 }
 
-AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
-                                                 const AllReduceInfo_t &allreduce_node_info, size_t start_index,
-                                                 size_t end_index) const {
+AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
+                                                             const CommunicationOpInfo &communication_op_info,
+                                                             size_t start_index, size_t end_index) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  auto prim = std::make_shared<Primitive>(kAllReduceOpName);
+  auto prim = std::make_shared<Primitive>(op_name_);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
   // get all inputs of current segment
-  if (end_index >= allreduce_node_info.allreduce_node.size()) {
+  if (end_index >= communication_op_info.communication_op_nodes.size()) {
     MS_LOG(EXCEPTION) << "end index out of vector size";
   }
   for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = allreduce_node_info.allreduce_node[idx];
+    auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
     fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
   }
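The input-gathering loop above can be pictured with plain vectors; this sketch uses made-up tensor names, with strings standing in for AnfNodePtr:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-ins for three AllGather CNodes; slot 0 holds the primitive and the
  // remaining slots hold the data inputs.
  std::vector<std::vector<std::string>> segment = {
      {"AllGather", "grad_a"}, {"AllGather", "grad_b"}, {"AllGather", "grad_c"}};
  // Mirrors fusion_inputs = {NewValueNode(prim)} followed by
  // insert(end, cnode->inputs().begin() + 1, cnode->inputs().end()).
  std::vector<std::string> fusion_inputs = {"AllGather"};
  for (const auto &node : segment) {
    fusion_inputs.insert(fusion_inputs.end(), node.begin() + 1, node.end());
  }
  for (const auto &in : fusion_inputs) {
    std::cout << in << " ";  // AllGather grad_a grad_b grad_c
  }
  std::cout << "\n";
}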
@@ -141,14 +167,14 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
   fused_node->set_kernel_info(kernel_info);
   AbstractBasePtrList abstract_list;
   for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = allreduce_node_info.allreduce_node[idx];
+    auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
     AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
     AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
     AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
     abstract_list.push_back(cnode->abstract());
   }
-  auto kernel_build_info = GenerateKernelBuildInfo(allreduce_node_info, start_index, end_index);
+  auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
   MS_EXCEPTION_IF_NULL(abstract_tuple);
@@ -156,8 +182,8 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
   return fused_node;
 }
 
-bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
-                               size_t segment_num, const std::vector<size_t> &segment_index) const {
+bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
+                                     size_t segment_num, const std::vector<size_t> &segment_index) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -169,12 +195,13 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
       start_index = end_index + 1;
       continue;
     }
-    AnfNodePtr new_allreduce = CreateFusedAllReduce(func_graph, allreduce_node_info, start_index, end_index);
-    // replace old allreduce with new allreduce
+    AnfNodePtr new_communication_op =
+      CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
+    // replace old communication op with new communication op
     for (auto idx = start_index; idx <= end_index; ++idx) {
       std::vector<AnfNodePtr> tuple_getitem_input;
       tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
-      tuple_getitem_input.push_back(new_allreduce);
+      tuple_getitem_input.push_back(new_communication_op);
       auto index = NewValueNode(SizeToInt(idx - start_index));
       MS_EXCEPTION_IF_NULL(index);
       auto imm = std::make_shared<Int32Imm>(idx - start_index);
@@ -185,10 +212,10 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
       tuple_getitem_input.push_back(index);
       AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
       MS_EXCEPTION_IF_NULL(tuple_getitem);
-      auto allreduce_node_item = allreduce_node_info.allreduce_node.at(idx);
-      MS_EXCEPTION_IF_NULL(allreduce_node_item);
-      tuple_getitem->set_abstract(allreduce_node_item->abstract());
-      if (!manager->Replace(allreduce_node_item, tuple_getitem)) {
+      auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
+      MS_EXCEPTION_IF_NULL(communication_op_node_item);
+      tuple_getitem->set_abstract(communication_op_node_item->abstract());
+      if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
         MS_LOG(EXCEPTION) << "manager replace node failed";
       }
     }
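In effect, each original communication op's single output is re-exposed as TupleGetItem(new_communication_op, idx - start_index), carrying the original node's abstract, so downstream consumers are unaffected; manager->Replace rewires all users of the old node in one step.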
@@ -198,29 +225,24 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
   return changed;
 }
 
-bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
+bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   const float input_grad_size_num = 0.0;
   const float input_grad_time_num = 0.0;
   // divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
-  std::unordered_map<std::string, AllReduceInfo_t> candidate_groups;
+  std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
-    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
-      auto primitive = AnfAlgo::GetCNodePrimitive(node);
-      MS_EXCEPTION_IF_NULL(primitive);
-      int fusion = GetValue<int>(primitive->GetAttr("fusion"));
-      if (fusion == 0) {
+    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
+      std::string key = GetFusionGroupKey(node);
+      if (key.empty()) {
         continue;
       }
-      std::string group = GetValue<std::string>(primitive->GetAttr("group"));
-      std::string op = GetValue<std::string>(primitive->GetAttr("op"));
-      std::string key = group + op + std::to_string(fusion);
       if (candidate_groups.find(key) == candidate_groups.end()) {
-        AllReduceInfo_t allreduce_node_info;
-        candidate_groups[key] = allreduce_node_info;
+        CommunicationOpInfo communication_op_info;
+        candidate_groups[key] = communication_op_info;
       }
-      candidate_groups[key].allreduce_node.push_back(node->cast<CNodePtr>());
+      candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
       candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
       candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
     }
@@ -228,7 +250,7 @@ bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
   // split candidate group to segments according to _group class member
   bool changed = false;
   for (auto &it : candidate_groups) {
-    if (it.second.allreduce_node.size() <= 1) {
+    if (it.second.communication_op_nodes.size() <= 1) {
       continue;
     }
     size_t segment_num = 0;
mindspore/ccsrc/pre_activate/pass/allreduce_fusion.h → mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
@@ -13,37 +13,55 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
+#include <utility>
 #include <vector>
 #include <string>
 
 #include "pre_activate/common/pass.h"
 #include "ir/func_graph.h"
 #include "ir/anf.h"
+#include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
-struct AllReduceInfo_t {
-  std::vector<CNodePtr> allreduce_node;
+struct CommunicationOpInfo {
+  std::vector<CNodePtr> communication_op_nodes;
   std::vector<float> input_grad_size;
   std::vector<float> input_grad_time;
 };
 
-class AllReduceFusion : public Pass {
+class CommunicationOpFusion : public Pass {
  public:
-  explicit AllReduceFusion(size_t groups = 1) : Pass("all_reduce_fusion"), groups_(groups) {}
-  ~AllReduceFusion() override = default;
+  explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1)
+      : Pass(name), op_name_(std::move(op_name)), groups_(groups) {}
+  ~CommunicationOpFusion() override = default;
   bool Run(const FuncGraphPtr &graph) override;
 
  private:
-  bool DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info, size_t segment_num,
+  bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num,
                 const std::vector<size_t> &segment_index) const;
-  AnfNodePtr CreateFusedAllReduce(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
-                                  size_t start_index, size_t end_index) const;
-  bool GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
+  AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
+                                        const CommunicationOpInfo &communication_op_info, size_t start_index,
+                                        size_t end_index) const;
+  bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
                         std::vector<size_t> *segment_index) const;
+  std::string op_name_;
   size_t groups_ = 1;
 };
+
+class AllReduceFusion : public CommunicationOpFusion {
+ public:
+  explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {}
+  ~AllReduceFusion() override = default;
+};
+
+class AllGatherFusion : public CommunicationOpFusion {
+ public:
+  explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
+  ~AllGatherFusion() override = default;
+};
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
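A note on the design visible in this header: AllReduceFusion keeps its old name and constructor signature, so existing call sites such as the registration in ascend_backend_optimization.cc keep compiling unchanged, while AllGatherFusion reuses the entire segmentation-and-fusion machinery by supplying only a pass name and an op name (kAllGatherOpName).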
mindspore/ccsrc/session/gpu_session.cc

@@ -20,7 +20,7 @@
 #include "device/gpu/gpu_stream_assign.h"
 #include "pre_activate/common/optimizer.h"
 #include "pre_activate/common/pass_manager.h"
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include "device/kernel_runtime_manager.h"
 #include "predict/predict.h"
 #include "common/utils.h"
mindspore/ccsrc/utils/utils.h

@@ -154,6 +154,9 @@ constexpr auto kAttrOutputUsedNum = "output_used_num";
 constexpr auto kAttrHasBias = "has_bias";
 constexpr auto kAttrN = "n";
 constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
+constexpr auto kAttrFusion = "fusion";
+constexpr auto kAttrGroup = "group";
+constexpr auto kAttrOp = "op";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc

@@ -20,7 +20,7 @@
 #include "ir/manager.h"
 #include "debug/anf_ir_dump.h"
 #include "session/anf_runtime_algorithm.h"
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include "pre_activate/common/optimizer.h"
 #include "device/kernel_info.h"
 #include "pre_activate/common/pass_manager.h"