Commit 8f7aa5bd
auto parallel context modify

Repository: 正统之独孤求败 / mindspore (fork of MindSpore / mindspore)
Author: yao_yf
Date: Aug 29, 2020
Parent: 042ac51f

Showing 10 changed files with 57 additions and 87 deletions (+57, -87)
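In short: the auto-parallel flag cast_before_mirror is renamed to gradient_fp32_sync across the C++ ParallelContext, its pybind11 bindings, the Python context API, and the unit tests, and the user-settable communication_backend is removed from ParallelContext. ParallelInit() now derives the communication backend from the device target (HCCL on Ascend/Davinci, NCCL on GPU), so mindspore.communication.management.init() no longer records a backend on the auto-parallel context.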
Changed files:

  mindspore/ccsrc/frontend/parallel/context.cc                 +3   -10
  mindspore/ccsrc/frontend/parallel/context.h                  +3   -7
  mindspore/ccsrc/frontend/parallel/step_parallel.cc           +13  -8
  mindspore/ccsrc/pipeline/jit/init.cc                         +2   -4
  mindspore/communication/management.py                        +0   -4
  mindspore/context.py                                         +5   -5
  mindspore/parallel/_auto_parallel_context.py                 +15  -29
  tests/ut/python/hccl_test/manage/api.py                      +1   -1
  tests/ut/python/parallel/test_element_wise_function.py       +7   -7
  tests/ut/python/parallel/test_set_auto_parallel_context.py   +8   -12
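Taken from the updated docstring examples in mindspore/context.py below, a minimal before/after sketch of the user-facing rename:

    from mindspore import context

    # Before this commit:
    #   context.set_auto_parallel_context(cast_before_mirror=False)
    # After this commit, the same switch is named for what it does:
    # all-reduce gradients in fp32 even when they were computed in fp16.
    context.set_auto_parallel_context(gradient_fp32_sync=False)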
mindspore/ccsrc/frontend/parallel/context.cc

@@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() {
-  communication_backend_ = HCCL_BACKEND;
-  Reset();
-}
+ParallelContext::ParallelContext() { Reset(); }
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
   full_batch_ = false;
-  cast_before_mirror_ = true;
+  gradient_fp32_sync_ = true;
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
@@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
 
 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
 
-void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
+void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
 
 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
 
-void ParallelContext::set_communication_backend(const std::string &communication_backend) {
-  communication_backend_ = communication_backend;
-}
-
 bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
   auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
   if (iter == PARALLEL_MODE_LIST.end()) {
mindspore/ccsrc/frontend/parallel/context.h

@@ -58,8 +58,8 @@ class ParallelContext {
   void set_full_batch(bool full_batch);
   bool full_batch() const { return full_batch_; }
 
-  void set_cast_before_mirror(bool cast_before_mirror);
-  bool cast_before_mirror() const { return cast_before_mirror_; }
+  void set_gradient_fp32_sync(bool gradient_fp32_sync);
+  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
 
   void set_loss_repeated_mean(bool loss_repeated_mean);
   bool loss_repeated_mean() const { return loss_repeated_mean_; }
@@ -70,9 +70,6 @@ class ParallelContext {
   void set_global_rank(int32_t global_rank);
   int32_t global_rank() const { return global_rank_; }
 
-  void set_communication_backend(const std::string &communication_backend);
-  std::string communication_backend() const { return communication_backend_; }
-
   bool set_parallel_mode(const std::string &parallel_mode);
   std::string parallel_mode() const { return parallel_mode_; }
@@ -112,11 +109,10 @@ class ParallelContext {
   static std::shared_ptr<ParallelContext> inst_context_;
   bool mirror_mean_;
   bool full_batch_;
-  bool cast_before_mirror_;
+  bool gradient_fp32_sync_;
   bool loss_repeated_mean_;
   int32_t device_num_;
   int32_t global_rank_;
-  std::string communication_backend_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
   bool parameter_broadcast_;
mindspore/ccsrc/frontend/parallel/step_parallel.cc

@@ -43,6 +43,7 @@
 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "utils/comm_manager.h"
 #include "utils/symbolic.h"
+#include "utils/ms_context.h"
 
 using mindspore::tensor::Tensor;
@@ -869,8 +870,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 }
 
 bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
-  // only if cast_before_mirror is true, pre node is cast and type is not float32 return true
-  if (!ParallelContext::GetInstance()->cast_before_mirror()) {
+  // only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true
+  if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
     return false;
   }
   auto pre_node = node->input(index);
@@ -2421,13 +2422,17 @@ Status ParallelInit() {
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   int32_t device_num = ParallelContext::GetInstance()->device_num();
   int32_t global_rank = ParallelContext::GetInstance()->global_rank();
-  std::string backend = ParallelContext::GetInstance()->communication_backend();
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   std::string world_group;
-  if (backend == HCCL_BACKEND) {
+  std::string communication_backend;
+  if (backend == kAscendDevice || backend == kDavinciDevice) {
     world_group = HCCL_WORLD_GROUP;
-  } else if (backend == NCCL_BACKEND) {
+    communication_backend = HCCL_BACKEND;
+  } else if (backend == kGPUDevice) {
     world_group = NCCL_WORLD_GROUP;
+    communication_backend = NCCL_BACKEND;
   } else {
     MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
   }
@@ -2450,14 +2455,14 @@
     MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
   }
 
-  if (!InitDevice(device_num, global_rank, backend)) {
+  if (!InitDevice(device_num, global_rank, communication_backend)) {
     MS_LOG(ERROR) << "Init device failed";
     return FAILED;
   }
 
   MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
                << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean()
-               << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror();
+               << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
 
   return SUCCESS;
 }
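The substantive change in this file: ParallelInit() no longer reads a user-set communication backend from ParallelContext; it derives one from the device target held by MsContext. A hedged Python sketch of the mapping the new branches encode (constant names stand in for the C++ constants; their string values are not shown in this diff):

    # Device target -> (world group, communication backend),
    # mirroring the kAscendDevice/kDavinciDevice/kGPUDevice branches above.
    BACKEND_BY_DEVICE_TARGET = {
        "Ascend": ("HCCL_WORLD_GROUP", "HCCL_BACKEND"),
        "Davinci": ("HCCL_WORLD_GROUP", "HCCL_BACKEND"),
        "GPU": ("NCCL_WORLD_GROUP", "NCCL_BACKEND"),
    }
    # Any other device target fails with MS_LOG(EXCEPTION):
    # "Invalid communication backend: <device target>"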
mindspore/ccsrc/pipeline/jit/init.cc

@@ -209,12 +209,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
     .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
     .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
-    .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.")
-    .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.")
+    .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
+    .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
     .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
     .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
-    .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.")
-    .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.")
     .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
     .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
     .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
mindspore/communication/management.py

@@ -15,7 +15,6 @@
 """Communication management API"""
 import os
 from mindspore import context
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -86,9 +85,6 @@ def init(backend_name=None):
     else:
         raise RuntimeError("Backend name {} is not supported.".format(backend_name))
-
-    auto_parallel_context().set_communication_backend(backend_name)
-
 
 def release():
     """
     Release distributed resource. e.g., hccl/nccl.
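With set_communication_backend removed, init() now only validates the backend name and brings up the collective library; nothing is recorded on the auto-parallel context. A minimal sketch of the resulting call pattern (assuming a distributed Ascend environment; the device target is illustrative):

    from mindspore import context
    from mindspore.communication.management import init

    context.set_context(device_target="Ascend")
    init()  # no auto_parallel_context().set_communication_backend(...) anymore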
mindspore/context.py

@@ -434,7 +434,7 @@ def _context():
     return _k_context
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def set_auto_parallel_context(**kwargs):
@@ -454,9 +454,9 @@ def set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
                      "stand_alone" do not support mirror_mean. Default: False.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
+        gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
                      "stand_alone", "data_parallel" and "hybrid_parallel" do not support
-                     cast_before_mirror. Default: True.
+                     gradient_fp32_sync. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@@ -492,7 +492,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(device_num=8)
         >>> context.set_auto_parallel_context(global_rank=0)
         >>> context.set_auto_parallel_context(mirror_mean=True)
-        >>> context.set_auto_parallel_context(cast_before_mirror=False)
+        >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
@@ -524,7 +524,7 @@ def reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: "".
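Per the reset defaults documented above, the renamed flag reads back True after a reset; a quick sketch mirroring the updated unit test:

    from mindspore import context

    context.reset_auto_parallel_context()
    assert context.get_auto_parallel_context("gradient_fp32_sync")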
mindspore/parallel/_auto_parallel_context.py

@@ -113,24 +113,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_mirror_mean()
 
-    def set_cast_before_mirror(self, cast_before_mirror):
+    def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
-        Set cast_before_mirror.
+        Set gradient_fp32_sync.
 
         Note:
-            If cast_before_mirror is true,
+            If gradient_fp32_sync is true,
             it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
 
         Args:
-            cast_before_mirror (bool): The cast_before_mirror flag.
+            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
         """
         self.check_context_handle()
-        self._context_handle.set_cast_before_mirror(cast_before_mirror)
+        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
 
-    def get_cast_before_mirror(self):
-        """Get cast_before_mirror flag."""
+    def get_gradient_fp32_sync(self):
+        """Get gradient_fp32_sync flag."""
         self.check_context_handle()
-        return self._context_handle.get_cast_before_mirror()
+        return self._context_handle.get_gradient_fp32_sync()
 
     def set_loss_repeated_mean(self, loss_repeated_mean):
         """
@@ -152,21 +152,6 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_loss_repeated_mean()
 
-    def set_communication_backend(self, communication_backend):
-        """
-        Set communication backend.
-
-        Args:
-            communication_backend (str): The communication backend.
-        """
-        self.check_context_handle()
-        self._context_handle.set_communication_backend(communication_backend)
-
-    def get_communication_backend(self):
-        """Get communication backend."""
-        self.check_context_handle()
-        return self._context_handle.get_communication_backend()
-
     def set_parallel_mode(self, parallel_mode):
         """
         Set parallel mode for auto parallel.
@@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
     "mirror_mean": auto_parallel_context().set_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
@@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
     "mirror_mean": auto_parallel_context().get_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
@@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
-                       calculations. Default: True.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+                       calculations. Default: True.
+        gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
+                       Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
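The renamed wrapper pair round-trips through the pybind handle; this sketch mirrors the updated test_set_auto_parallel_context.py below:

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    auto_parallel_context().set_gradient_fp32_sync(False)
    assert not auto_parallel_context().get_gradient_fp32_sync()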
tests/ut/python/hccl_test/manage/api.py

@@ -61,7 +61,7 @@ def get_rank_id(group=None):
 
 def get_rank_size(group=None):
     hccl = Hccl()
-    if group is None:
+    if group is None or "nccl_world_group" in group:
         return hccl.rank_size
     if isinstance(group, str):
         return int(group.split("-")[0])
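This one-line change lets the HCCL test stub answer rank-size queries for the NCCL world group as well. Expected behavior of the patched stub (a sketch; Hccl is the test double defined in this file, and "2-devices" is a made-up group name):

    get_rank_size()                    # group is None -> hccl.rank_size
    get_rank_size("nccl_world_group")  # now also returns hccl.rank_size
    get_rank_size("2-devices")         # other strings: int("2-devices".split("-")[0]) == 2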
tests/ut/python/parallel/test_element_wise_function.py

@@ -830,7 +830,7 @@ def test_matmul_cast():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror():
+def test_gradient_fp32_sync():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -843,7 +843,7 @@ def test_cast_before_mirror():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -854,7 +854,7 @@ def test_cast_before_mirror():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror1():
+def test_gradient_fp32_sync1():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -867,7 +867,7 @@ def test_cast_before_mirror1():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -878,7 +878,7 @@ def test_cast_before_mirror1():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror2():
+def test_gradient_fp32_sync2():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -891,7 +891,7 @@ def test_cast_before_mirror2():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -902,7 +902,7 @@ def test_cast_before_mirror2():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror3():
+def test_gradient_fp32_sync3():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
tests/ut/python/parallel/test_set_auto_parallel_context.py

@@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 
 def test_set_auto_parallel_context():
-    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False,
+    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False,
                                       parallel_mode="auto_parallel", parameter_broadcast=False)
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     assert device_num == 4
     assert global_rank == 3
     assert mirror_mean
-    assert not cast_before_mirror
+    assert not gradient_fp32_sync
     assert parallel_mode == "auto_parallel"
     assert not parameter_broadcast
 
-    auto_parallel_context().set_communication_backend("hccl")
-    backend = auto_parallel_context().get_communication_backend()
-    assert backend == "hccl"
-
     auto_parallel_context().set_device_num(4)
     device_num = auto_parallel_context().get_device_num()
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -53,9 +49,9 @@ def test_set_auto_parallel_context():
     mirror_mean = auto_parallel_context().get_mirror_mean()
     assert mirror_mean
 
-    auto_parallel_context().set_cast_before_mirror(False)
-    cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
-    assert not cast_before_mirror
+    auto_parallel_context().set_gradient_fp32_sync(False)
+    gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
+    assert not gradient_fp32_sync
 
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set
@@ -91,7 +87,7 @@ def test_reset_auto_parallel_context():
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -99,7 +95,7 @@ def test_reset_auto_parallel_context():
     assert device_num == 1
     assert global_rank == 0
     assert not mirror_mean
-    assert cast_before_mirror
+    assert gradient_fp32_sync
     assert parallel_mode == "stand_alone"
     assert not parameter_broadcast
     assert not device_num_is_set