Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
362bbacf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
362bbacf
编写于
5月 12, 2020
作者:
K
kswang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add group for allreduce fusion
上级
4ecc9389
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
117 addition
and
40 deletion
+117
-40
mindspore/ccsrc/parallel/context.cc
mindspore/ccsrc/parallel/context.cc
+16
-8
mindspore/ccsrc/parallel/context.h
mindspore/ccsrc/parallel/context.h
+7
-6
mindspore/ccsrc/pipeline/init.cc
mindspore/ccsrc/pipeline/init.cc
+4
-4
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
+3
-3
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
+1
-1
mindspore/parallel/_auto_parallel_context.py
mindspore/parallel/_auto_parallel_context.py
+86
-18
未找到文件。
mindspore/ccsrc/parallel/context.cc
浏览文件 @
362bbacf
...
...
@@ -113,20 +113,28 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
strategy_ckpt_save_file_
=
strategy_ckpt_save_file
;
}
void
ParallelContext
::
set_all_reduce_fusion_split_indices
(
const
std
::
vector
<
uint32_t
>
indices
)
{
all_reduce_fusion_split_indices_
=
indices
;
void
ParallelContext
::
SetAllReduceFusionSplitIndices
(
const
std
::
vector
<
uint32_t
>
indices
,
const
std
::
string
&
group
)
{
all_reduce_fusion_split_indices_
[
group
]
=
indices
;
}
const
std
::
vector
<
uint32_t
>
ParallelContext
::
all_reduce_fusion_split_indices
()
const
{
return
all_reduce_fusion_split_indices_
;
const
std
::
vector
<
uint32_t
>
ParallelContext
::
GetAllReduceFusionSplitIndices
(
const
std
::
string
&
group
)
const
{
auto
iter
=
all_reduce_fusion_split_indices_
.
find
(
group
);
if
(
iter
!=
all_reduce_fusion_split_indices_
.
end
())
{
return
iter
->
second
;
}
return
{};
}
void
ParallelContext
::
set_all_reduce_fusion_split_sizes
(
const
std
::
vector
<
uint32_t
>
sizes
)
{
all_reduce_fusion_split_sizes_
=
sizes
;
void
ParallelContext
::
SetAllReduceFusionSplitSizes
(
const
std
::
vector
<
uint32_t
>
sizes
,
const
std
::
string
&
group
)
{
all_reduce_fusion_split_sizes_
[
group
]
=
sizes
;
}
const
std
::
vector
<
uint32_t
>
ParallelContext
::
all_reduce_fusion_split_sizes
()
const
{
return
all_reduce_fusion_split_sizes_
;
const
std
::
vector
<
uint32_t
>
ParallelContext
::
GetAllReduceFusionSplitSizes
(
const
std
::
string
&
group
)
const
{
auto
iter
=
all_reduce_fusion_split_sizes_
.
find
(
group
);
if
(
iter
!=
all_reduce_fusion_split_sizes_
.
end
())
{
return
iter
->
second
;
}
return
{};
}
}
// namespace parallel
}
// namespace mindspore
mindspore/ccsrc/parallel/context.h
浏览文件 @
362bbacf
...
...
@@ -19,6 +19,7 @@
#include <cstdint>
#include <memory>
#include <map>
#include <string>
#include <vector>
...
...
@@ -76,10 +77,10 @@ class ParallelContext {
bool
global_rank_is_set
()
const
{
return
global_rank_is_set_
;
}
bool
parameter_broadcast_is_set
()
const
{
return
parameter_broadcast_is_set_
;
}
void
set_all_reduce_fusion_split_indices
(
const
std
::
vector
<
uint32_t
>
indices
);
const
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_indices
(
)
const
;
void
set_all_reduce_fusion_split_sizes
(
const
std
::
vector
<
uint32_t
>
sizes
);
const
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_sizes
(
)
const
;
void
SetAllReduceFusionSplitIndices
(
const
std
::
vector
<
uint32_t
>
indices
,
const
std
::
string
&
group
);
const
std
::
vector
<
uint32_t
>
GetAllReduceFusionSplitIndices
(
const
std
::
string
&
group
)
const
;
void
SetAllReduceFusionSplitSizes
(
const
std
::
vector
<
uint32_t
>
sizes
,
const
std
::
string
&
group
);
const
std
::
vector
<
uint32_t
>
GetAllReduceFusionSplitSizes
(
const
std
::
string
&
group
)
const
;
void
set_enable_all_reduce_fusion
(
bool
enable_all_reduce_fusion
)
{
enable_all_reduce_fusion_
=
enable_all_reduce_fusion
;
}
...
...
@@ -108,8 +109,8 @@ class ParallelContext {
bool
global_rank_is_set_
;
bool
parameter_broadcast_is_set_
;
bool
enable_all_reduce_fusion_
;
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_indices_
;
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_sizes_
;
std
::
map
<
std
::
string
,
std
::
vector
<
uint32_t
>
>
all_reduce_fusion_split_indices_
;
std
::
map
<
std
::
string
,
std
::
vector
<
uint32_t
>
>
all_reduce_fusion_split_sizes_
;
std
::
string
strategy_ckpt_load_file_
;
std
::
string
strategy_ckpt_save_file_
;
};
...
...
mindspore/ccsrc/pipeline/init.cc
浏览文件 @
362bbacf
...
...
@@ -159,13 +159,13 @@ PYBIND11_MODULE(_c_expression, m) {
.
def
(
"set_parallel_mode"
,
&
ParallelContext
::
set_parallel_mode
,
"Set parallel mode."
)
.
def
(
"get_strategy_search_mode"
,
&
ParallelContext
::
strategy_search_mode
,
"Get strategy search mode."
)
.
def
(
"set_strategy_search_mode"
,
&
ParallelContext
::
set_strategy_search_mode
,
"Set strategy search mode."
)
.
def
(
"set_all_reduce_fusion_split_indices"
,
&
ParallelContext
::
set_all_reduce_fusion_split_i
ndices
,
.
def
(
"set_all_reduce_fusion_split_indices"
,
&
ParallelContext
::
SetAllReduceFusionSplitI
ndices
,
"Set all reduce fusion split indices."
)
.
def
(
"get_all_reduce_fusion_split_indices"
,
&
ParallelContext
::
all_reduce_fusion_split_i
ndices
,
.
def
(
"get_all_reduce_fusion_split_indices"
,
&
ParallelContext
::
GetAllReduceFusionSplitI
ndices
,
"Get all reduce fusion split indices."
)
.
def
(
"set_all_reduce_fusion_split_sizes"
,
&
ParallelContext
::
set_all_reduce_fusion_split_s
izes
,
.
def
(
"set_all_reduce_fusion_split_sizes"
,
&
ParallelContext
::
SetAllReduceFusionSplitS
izes
,
"Set all reduce fusion split sizes."
)
.
def
(
"get_all_reduce_fusion_split_sizes"
,
&
ParallelContext
::
all_reduce_fusion_split_s
izes
,
.
def
(
"get_all_reduce_fusion_split_sizes"
,
&
ParallelContext
::
GetAllReduceFusionSplitS
izes
,
"Get all reduce fusion split sizes."
)
.
def
(
"set_enable_all_reduce_fusion"
,
&
ParallelContext
::
set_enable_all_reduce_fusion
,
"Set enable/disable all reduce fusion."
)
...
...
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
浏览文件 @
362bbacf
...
...
@@ -92,7 +92,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
}
// namespace
bool
CommunicationOpFusion
::
GetSplitSegments
(
const
CommunicationOpInfo
&
communication_op_info
,
size_t
*
segment_num
,
std
::
vector
<
size_t
>
*
segment_index
)
const
{
std
::
vector
<
size_t
>
*
segment_index
,
const
std
::
string
&
group
)
const
{
MS_EXCEPTION_IF_NULL
(
segment_num
);
MS_EXCEPTION_IF_NULL
(
segment_index
);
size_t
communication_op_node_size
=
communication_op_info
.
communication_op_nodes
.
size
();
...
...
@@ -100,7 +100,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
auto
parallel_context
=
parallel
::
ParallelContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
parallel_context
);
const
std
::
vector
<
uint32_t
>
split_indices
=
parallel_context
->
all_reduce_fusion_split_indices
(
);
const
auto
&
split_indices
=
parallel_context
->
GetAllReduceFusionSplitIndices
(
group
);
size_t
segments
=
0
;
if
(
split_indices
.
size
()
!=
0
)
{
...
...
@@ -255,7 +255,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
}
size_t
segment_num
=
0
;
std
::
vector
<
size_t
>
segment_index
;
if
(
GetSplitSegments
(
it
.
second
,
&
segment_num
,
&
segment_index
))
{
if
(
GetSplitSegments
(
it
.
second
,
&
segment_num
,
&
segment_index
,
it
.
first
))
{
if
(
DoFusion
(
func_graph
,
it
.
second
,
segment_num
,
segment_index
))
{
changed
=
true
;
}
...
...
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
浏览文件 @
362bbacf
...
...
@@ -46,7 +46,7 @@ class CommunicationOpFusion : public Pass {
const
CommunicationOpInfo
&
communication_op_info
,
size_t
start_index
,
size_t
end_index
)
const
;
bool
GetSplitSegments
(
const
CommunicationOpInfo
&
communication_op_info
,
size_t
*
segment_num
,
std
::
vector
<
size_t
>
*
segment_index
)
const
;
std
::
vector
<
size_t
>
*
segment_index
,
const
std
::
string
&
group
)
const
;
std
::
string
op_name_
;
size_t
groups_
=
1
;
};
...
...
mindspore/parallel/_auto_parallel_context.py
浏览文件 @
362bbacf
...
...
@@ -19,6 +19,8 @@ from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx,
from
mindspore._c_expression
import
AutoParallelContext
from
mindspore._checkparam
import
args_type_check
_MAX_GROUP_NAME_LEN
=
127
class
_AutoParallelContext
:
"""
...
...
@@ -243,51 +245,117 @@ class _AutoParallelContext:
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_parameter_broadcast_is_set
()
def
set_all_reduce_fusion_split_indices
(
self
,
indices
):
def
set_all_reduce_fusion_split_indices
(
self
,
indices
,
group
=
""
):
"""
Set allreduce fusion strategy by parameters indices.
Args:
indices (list): Indices list.
group (str): The hccl communication group.
Raises:
TypeError: If type of indices item is not int.
TypeError: If group is not a python str.
"""
self
.
check_context_handle
()
for
index
in
indices
:
if
not
isinstance
(
index
,
int
):
raise
TypeError
(
'indices has invalid value'
)
self
.
_context_handle
.
set_all_reduce_fusion_split_indices
(
indices
)
if
isinstance
(
indices
,
(
list
)):
for
index
in
indices
:
if
not
isinstance
(
index
,
int
):
raise
TypeError
(
'indices has invalid value'
)
else
:
raise
TypeError
(
'indices must be a python list'
)
if
isinstance
(
group
,
(
str
)):
group_len
=
len
(
group
)
if
group_len
>
_MAX_GROUP_NAME_LEN
:
raise
ValueError
(
'Group name len is out of range {_MAX_GROUP_NAME_LEN}'
)
else
:
raise
TypeError
(
'Group must be a python str'
)
self
.
_context_handle
.
set_all_reduce_fusion_split_indices
(
indices
,
group
)
if
context
.
get_context
(
"device_target"
)
==
"Ascend"
:
_set_fusion_strategy_by_idx
(
indices
)
if
group
==
""
:
_set_fusion_strategy_by_idx
(
indices
)
else
:
_set_fusion_strategy_by_idx
(
indices
,
group
)
def
get_all_reduce_fusion_split_indices
(
self
,
group
=
""
):
"""
Get allreduce fusion split indices.
Args:
group (str): The hccl communication group.
Returns:
Return split sizes list according to the group.
def
get_all_reduce_fusion_split_indices
(
self
):
"""Get allreduce fusion split indices."""
Raises:
TypeError: If group is not a python str.
"""
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_all_reduce_fusion_split_indices
()
if
isinstance
(
group
,
(
str
)):
group_len
=
len
(
group
)
if
group_len
>
_MAX_GROUP_NAME_LEN
:
raise
ValueError
(
'Group name len is out of range {_MAX_GROUP_NAME_LEN}'
)
else
:
raise
TypeError
(
'Group must be a python str'
)
return
self
.
_context_handle
.
get_all_reduce_fusion_split_indices
(
group
)
def
set_all_reduce_fusion_split_sizes
(
self
,
sizes
):
def
set_all_reduce_fusion_split_sizes
(
self
,
sizes
,
group
=
""
):
"""
Set allreduce fusion strategy by parameters data sizes.
Args:
sizes (list): Sizes list.
group (str): The hccl communication group.
Raises:
TypeError: If type of sizes item is not int.
TypeError: If group is not a python str.
"""
self
.
check_context_handle
()
for
size
in
sizes
:
if
not
isinstance
(
size
,
int
):
raise
TypeError
(
'sizes has invalid value'
)
self
.
_context_handle
.
set_all_reduce_fusion_split_sizes
(
sizes
)
if
isinstance
(
sizes
,
(
list
)):
for
size
in
sizes
:
if
not
isinstance
(
size
,
int
):
raise
TypeError
(
'sizes has invalid value'
)
else
:
raise
TypeError
(
'sizes must be a python list'
)
if
isinstance
(
group
,
(
str
)):
group_len
=
len
(
group
)
if
group_len
>
_MAX_GROUP_NAME_LEN
:
raise
ValueError
(
'Group name len is out of range {_MAX_GROUP_NAME_LEN}'
)
else
:
raise
TypeError
(
'Group must be a python str'
)
self
.
_context_handle
.
set_all_reduce_fusion_split_sizes
(
sizes
,
group
)
if
context
.
get_context
(
"device_target"
)
==
"Ascend"
:
_set_fusion_strategy_by_size
(
sizes
)
if
group
==
""
:
_set_fusion_strategy_by_size
(
sizes
)
else
:
_set_fusion_strategy_by_size
(
sizes
,
group
)
def
get_all_reduce_fusion_split_sizes
(
self
):
"""Get allreduce fusion split sizes."""
def
get_all_reduce_fusion_split_sizes
(
self
,
group
=
""
):
"""
Get allreduce fusion split sizes.
Args:
group (str): The hccl communication group.
Returns:
Return split sizes list according to the group.
Raises:
TypeError: If group is not a python str.
"""
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_all_reduce_fusion_split_sizes
()
if
isinstance
(
group
,
(
str
)):
group_len
=
len
(
group
)
if
group_len
>
_MAX_GROUP_NAME_LEN
:
raise
ValueError
(
'Group name len is out of range {_MAX_GROUP_NAME_LEN}'
)
else
:
raise
TypeError
(
'Group must be a python str'
)
return
self
.
_context_handle
.
get_all_reduce_fusion_split_sizes
(
group
)
def
set_enable_all_reduce_fusion
(
self
,
enable_all_reduce_fusion
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录