Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4ff41808
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4ff41808
编写于
4月 22, 2020
作者:
L
lirongzhen
提交者:
lirongzhen1
4月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable/disable allreduce_fusion
上级
9edc69af
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
38 addition
and
2 deletion
+38
-2
mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc
.../ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc
+2
-1
mindspore/ccsrc/parallel/context.cc
mindspore/ccsrc/parallel/context.cc
+1
-0
mindspore/ccsrc/parallel/context.h
mindspore/ccsrc/parallel/context.h
+5
-0
mindspore/ccsrc/pipeline/init.cc
mindspore/ccsrc/pipeline/init.cc
+4
-0
mindspore/parallel/_auto_parallel_context.py
mindspore/parallel/_auto_parallel_context.py
+17
-0
mindspore/parallel/_utils.py
mindspore/parallel/_utils.py
+5
-0
tests/ut/python/parallel/__init__.py
tests/ut/python/parallel/__init__.py
+2
-0
tests/ut/python/parallel/test_allreduce_fusion.py
tests/ut/python/parallel/test_allreduce_fusion.py
+2
-1
未找到文件。
mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc
浏览文件 @
4ff41808
...
...
@@ -31,10 +31,11 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
MS_EXCEPTION_IF_NULL
(
optimizer
);
MS_EXCEPTION_IF_NULL
(
ParallelContext
::
GetInstance
());
std
::
string
parallel_mode
=
ParallelContext
::
GetInstance
()
->
parallel_mode
();
bool
enable_all_reduce_fusion
=
ParallelContext
::
GetInstance
()
->
enable_all_reduce_fusion
();
// assume no change to graph
bool
changes
=
false
;
// control whether use model_parallel mode
if
(((
parallel_mode
!=
AUTO_PARALLEL
)
&&
(
parallel_mode
!=
SEMI_AUTO_PARALLEL
))
||
if
(((
parallel_mode
!=
AUTO_PARALLEL
)
&&
(
parallel_mode
!=
SEMI_AUTO_PARALLEL
))
||
(
!
enable_all_reduce_fusion
)
||
(
root
->
has_flag
(
ALLREDUCE_FUSION_RUN_ONCE_ONLY
)))
{
return
changes
;
}
...
...
mindspore/ccsrc/parallel/context.cc
浏览文件 @
4ff41808
...
...
@@ -55,6 +55,7 @@ void ParallelContext::Reset() {
parallel_mode_
=
STAND_ALONE
;
parameter_broadcast_
=
false
;
parameter_broadcast_is_set_
=
false
;
enable_all_reduce_fusion_
=
false
;
}
void
ParallelContext
::
set_device_num
(
int32_t
device_num
)
{
...
...
mindspore/ccsrc/parallel/context.h
浏览文件 @
4ff41808
...
...
@@ -80,6 +80,10 @@ class ParallelContext {
const
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_indices
()
const
;
void
set_all_reduce_fusion_split_sizes
(
const
std
::
vector
<
uint32_t
>
sizes
);
const
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_sizes
()
const
;
void
set_enable_all_reduce_fusion
(
bool
enable_all_reduce_fusion
)
{
enable_all_reduce_fusion_
=
enable_all_reduce_fusion
;
}
bool
enable_all_reduce_fusion
()
const
{
return
enable_all_reduce_fusion_
;
}
void
Reset
();
...
...
@@ -98,6 +102,7 @@ class ParallelContext {
bool
device_num_is_set_
;
bool
global_rank_is_set_
;
bool
parameter_broadcast_is_set_
;
bool
enable_all_reduce_fusion_
;
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_indices_
;
std
::
vector
<
uint32_t
>
all_reduce_fusion_split_sizes_
;
};
...
...
mindspore/ccsrc/pipeline/init.cc
浏览文件 @
4ff41808
...
...
@@ -183,6 +183,10 @@ PYBIND11_MODULE(_c_expression, m) {
"Set all reduce fusion split sizes."
)
.
def
(
"get_all_reduce_fusion_split_sizes"
,
&
ParallelContext
::
all_reduce_fusion_split_sizes
,
"Get all reduce fusion split sizes."
)
.
def
(
"set_enable_all_reduce_fusion"
,
&
ParallelContext
::
set_enable_all_reduce_fusion
,
"Set enable/disable all reduce fusion."
)
.
def
(
"get_enable_all_reduce_fusion"
,
&
ParallelContext
::
enable_all_reduce_fusion
,
"Get enable/disable all reduce fusion."
)
.
def
(
"get_parameter_broadcast"
,
&
ParallelContext
::
parameter_broadcast
,
"Get parameter broadcast."
)
.
def
(
"get_parameter_broadcast_is_set"
,
&
ParallelContext
::
parameter_broadcast_is_set
,
"Get parameter broadcast is set."
)
...
...
mindspore/parallel/_auto_parallel_context.py
浏览文件 @
4ff41808
...
...
@@ -259,6 +259,23 @@ class _AutoParallelContext:
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_all_reduce_fusion_split_sizes
()
def
set_enable_all_reduce_fusion
(
self
,
enable_all_reduce_fusion
):
"""
Set enable/disable all reduce fusion.
Args:
enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
"""
self
.
check_context_handle
()
if
not
isinstance
(
enable_all_reduce_fusion
,
bool
):
raise
TypeError
(
'enable_all_reduce_fusion is invalid type'
)
self
.
_context_handle
.
set_enable_all_reduce_fusion
(
enable_all_reduce_fusion
)
def
get_enable_all_reduce_fusion
(
self
):
"""Get all reduce fusion flag."""
self
.
check_context_handle
()
return
self
.
_context_handle
.
get_enable_all_reduce_fusion
()
def
get_device_num_is_set
(
self
):
"""Get device number is set or not."""
self
.
check_context_handle
()
...
...
mindspore/parallel/_utils.py
浏览文件 @
4ff41808
...
...
@@ -117,6 +117,7 @@ _cast_before_mirror = None
_loss_repeated_mean
=
None
_communication_backend
=
None
_has_checkpointed
=
False
_enable_all_reduce_fusion
=
None
def
_checkpoint_auto_parallel_context
():
...
...
@@ -133,6 +134,7 @@ def _checkpoint_auto_parallel_context():
global
_cast_before_mirror
global
_loss_repeated_mean
global
_communication_backend
global
_enable_all_reduce_fusion
_parallel_mode
=
auto_parallel_context
().
get_parallel_mode
()
_device_num
=
_get_device_num
()
_global_rank
=
_get_global_rank
()
...
...
@@ -141,6 +143,7 @@ def _checkpoint_auto_parallel_context():
_cast_before_mirror
=
auto_parallel_context
().
get_cast_before_mirror
()
_loss_repeated_mean
=
auto_parallel_context
().
get_loss_repeated_mean
()
_communication_backend
=
auto_parallel_context
().
get_communication_backend
()
_enable_all_reduce_fusion
=
auto_parallel_context
().
get_enable_all_reduce_fusion
()
_has_checkpointed
=
True
...
...
@@ -154,10 +157,12 @@ def _restore_auto_parallel_context():
global
_cast_before_mirror
global
_loss_repeated_mean
global
_communication_backend
global
_enable_all_reduce_fusion
_set_auto_parallel_context
(
parallel_mode
=
_parallel_mode
,
device_num
=
_device_num
,
global_rank
=
_global_rank
,
parameter_broadcast
=
_parameter_broadcast
,
mirror_mean
=
_mirror_mean
,
cast_before_mirror
=
_cast_before_mirror
,
loss_repeated_mean
=
_loss_repeated_mean
)
auto_parallel_context
().
set_communication_backend
(
_communication_backend
)
auto_parallel_context
().
set_enable_all_reduce_fusion
(
_enable_all_reduce_fusion
)
def
_reset_checkpoint_auto_parallel_context
():
...
...
tests/ut/python/parallel/__init__.py
浏览文件 @
4ff41808
...
...
@@ -13,10 +13,12 @@
# limitations under the License.
import
mindspore.context
as
context
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
mindspore.parallel._utils
import
_reset_op_id
def
setup_module
(
module
):
auto_parallel_context
().
set_enable_all_reduce_fusion
(
enable_all_reduce_fusion
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
False
)
_reset_op_id
()
...
...
tests/ut/python/parallel/test_allreduce_fusion.py
浏览文件 @
4ff41808
...
...
@@ -23,7 +23,7 @@ from tests.dataset_mock import MindData
from
mindspore
import
context
from
mindspore.common.api
import
_executor
from
mindspore.parallel
import
_cost_model_context
as
cost_model_context
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
class
Dataset
(
MindData
):
...
...
@@ -105,6 +105,7 @@ def train_common(net):
epoch_size
=
2
device_num
=
4
context
.
reset_auto_parallel_context
()
auto_parallel_context
().
set_enable_all_reduce_fusion
(
enable_all_reduce_fusion
=
True
)
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
SEMI_AUTO_PARALLEL
,
device_num
=
device_num
,
parameter_broadcast
=
False
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录