Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e6a4d932
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6a4d932
编写于
8月 29, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 29, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5350 [AutoParallel]Rectification distributed init
Merge pull request !5350 from lichen/rectification_init
上级
d81b30e6
d3e55b54
变更
31
隐藏空白更改
内联
并排
Showing
31 changed file
with
52 addition
and
63 deletion
+52
-63
mindspore/communication/management.py
mindspore/communication/management.py
+12
-2
mindspore/ops/operations/comm_ops.py
mindspore/ops/operations/comm_ops.py
+4
-4
model_zoo/official/cv/googlenet/train.py
model_zoo/official/cv/googlenet/train.py
+1
-1
model_zoo/official/cv/inceptionv3/train.py
model_zoo/official/cv/inceptionv3/train.py
+1
-4
model_zoo/official/cv/mobilenetv2/train.py
model_zoo/official/cv/mobilenetv2/train.py
+1
-1
model_zoo/official/cv/mobilenetv2_quant/train.py
model_zoo/official/cv/mobilenetv2_quant/train.py
+1
-1
model_zoo/official/cv/mobilenetv3/train.py
model_zoo/official/cv/mobilenetv3/train.py
+1
-1
model_zoo/official/cv/resnet/src/dataset.py
model_zoo/official/cv/resnet/src/dataset.py
+2
-2
model_zoo/official/cv/resnet/train.py
model_zoo/official/cv/resnet/train.py
+1
-1
model_zoo/official/cv/resnet50_quant/src/dataset.py
model_zoo/official/cv/resnet50_quant/src/dataset.py
+2
-2
model_zoo/official/cv/resnet_thor/src/dataset.py
model_zoo/official/cv/resnet_thor/src/dataset.py
+1
-1
model_zoo/official/cv/resnet_thor/train.py
model_zoo/official/cv/resnet_thor/train.py
+1
-1
model_zoo/official/cv/resnext50/eval.py
model_zoo/official/cv/resnext50/eval.py
+1
-4
model_zoo/official/cv/resnext50/train.py
model_zoo/official/cv/resnext50/train.py
+1
-4
model_zoo/official/cv/vgg16/train.py
model_zoo/official/cv/vgg16/train.py
+1
-1
model_zoo/official/cv/warpctc/train.py
model_zoo/official/cv/warpctc/train.py
+1
-1
model_zoo/official/nlp/bert/run_pretrain.py
model_zoo/official/nlp/bert/run_pretrain.py
+2
-2
model_zoo/official/nlp/bert_thor/run_pretrain.py
model_zoo/official/nlp/bert_thor/run_pretrain.py
+2
-2
model_zoo/official/nlp/mass/train.py
model_zoo/official/nlp/mass/train.py
+1
-4
model_zoo/official/nlp/tinybert/run_general_distill.py
model_zoo/official/nlp/tinybert/run_general_distill.py
+2
-2
model_zoo/official/recommend/deepfm/train.py
model_zoo/official/recommend/deepfm/train.py
+1
-1
model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
...l/recommend/wide_and_deep/train_and_eval_auto_parallel.py
+1
-4
model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py
...cial/recommend/wide_and_deep/train_and_eval_distribute.py
+1
-4
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
...ecommend/wide_and_deep/train_and_eval_parameter_server.py
+1
-4
tests/st/nccl/test_nccl_all_gather_op.py
tests/st/nccl/test_nccl_all_gather_op.py
+1
-1
tests/st/nccl/test_nccl_all_reduce_op.py
tests/st/nccl/test_nccl_all_reduce_op.py
+1
-1
tests/st/nccl/test_nccl_broadcast_op.py
tests/st/nccl/test_nccl_broadcast_op.py
+1
-1
tests/st/nccl/test_nccl_lenet.py
tests/st/nccl/test_nccl_lenet.py
+1
-1
tests/st/nccl/test_nccl_reduce_scatter_op.py
tests/st/nccl/test_nccl_reduce_scatter_op.py
+1
-1
tests/st/ps/multi_full_ps/test_multi_full_ps.py
tests/st/ps/multi_full_ps/test_multi_full_ps.py
+1
-1
tests/ut/python/train/test_dataset_helper.py
tests/ut/python/train/test_dataset_helper.py
+3
-3
未找到文件。
mindspore/communication/management.py
浏览文件 @
e6a4d932
...
...
@@ -14,6 +14,7 @@
# ============================================================================
"""Communication management API"""
import
os
from
mindspore
import
context
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
._comm_helper
import
Backend
,
_get_rank_helper
,
_get_size_helper
,
\
_get_world_rank_from_group_rank_helper
,
_get_group_rank_from_world_rank_helper
,
\
...
...
@@ -45,7 +46,7 @@ class GlobalComm:
WORLD_COMM_GROUP
=
DEFAULT_WORLD_COMM_GROUP
def
init
(
backend_name
=
"hccl"
):
def
init
(
backend_name
=
None
):
"""
Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used.
...
...
@@ -57,11 +58,20 @@ def init(backend_name="hccl"):
backend_name (str): Backend.
Raises:
TypeError: If backend name is not a string.
TypeError: If backen_name is not a string.
RuntimeError: If device target is invalid.
RuntimeError: If backend is invalid or distributed init fails.
"""
if
MS_ROLE
in
(
"MS_PSERVER"
,
"MS_SCHED"
):
return
if
backend_name
is
None
:
device_target
=
context
.
get_context
(
"device_target"
)
if
device_target
==
"Ascend"
:
backend_name
=
"hccl"
elif
device_target
==
"GPU"
:
backend_name
=
"nccl"
else
:
raise
RuntimeError
(
"Device target {} is not supported."
.
format
(
device_target
))
if
not
isinstance
(
backend_name
,
str
):
raise
TypeError
(
"Backend name must be a string, but got {}"
.
format
(
type
(
backend_name
)))
...
...
mindspore/ops/operations/comm_ops.py
浏览文件 @
e6a4d932
...
...
@@ -73,7 +73,7 @@ class AllReduce(PrimitiveWithInfer):
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>>
>>> init(
'nccl'
)
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
...
...
@@ -136,7 +136,7 @@ class AllGather(PrimitiveWithInfer):
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>>
>>> init(
'nccl'
)
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
...
...
@@ -246,7 +246,7 @@ class ReduceScatter(PrimitiveWithInfer):
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>>
>>> init(
'nccl'
)
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
...
...
@@ -360,7 +360,7 @@ class Broadcast(PrimitiveWithInfer):
>>> import mindspore.nn as nn
>>> import mindspore.ops.operations as P
>>>
>>> init(
'nccl'
)
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
...
...
model_zoo/official/cv/googlenet/train.py
浏览文件 @
e6a4d932
...
...
@@ -82,7 +82,7 @@ if __name__ == '__main__':
mirror_mean
=
True
)
init
()
elif
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
if
device_num
>
1
:
context
.
reset_auto_parallel_context
()
...
...
model_zoo/official/cv/inceptionv3/train.py
浏览文件 @
e6a4d932
...
...
@@ -57,10 +57,7 @@ if __name__ == '__main__':
cfg
=
config_ascend
if
args_opt
.
platform
==
'Ascend'
else
config_gpu
# init distributed
if
args_opt
.
is_distributed
:
if
args_opt
.
platform
==
"Ascend"
:
init
()
else
:
init
(
"nccl"
)
init
()
cfg
.
rank
=
get_rank
()
cfg
.
group_size
=
get_group_size
()
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
...
...
model_zoo/official/cv/mobilenetv2/train.py
浏览文件 @
e6a4d932
...
...
@@ -65,7 +65,7 @@ elif args_opt.device_target == "GPU":
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
,
save_graphs
=
False
)
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
...
...
model_zoo/official/cv/mobilenetv2_quant/train.py
浏览文件 @
e6a4d932
...
...
@@ -58,7 +58,7 @@ if args_opt.device_target == "Ascend":
device_target
=
"Ascend"
,
device_id
=
device_id
,
save_graphs
=
False
)
elif
args_opt
.
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
...
...
model_zoo/official/cv/mobilenetv3/train.py
浏览文件 @
e6a4d932
...
...
@@ -55,7 +55,7 @@ if args_opt.device_target == "GPU":
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
,
save_graphs
=
False
)
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
...
...
model_zoo/official/cv/resnet/src/dataset.py
浏览文件 @
e6a4d932
...
...
@@ -38,7 +38,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
if
target
==
"Ascend"
:
device_num
,
rank_id
=
_get_rank_info
()
else
:
init
(
"nccl"
)
init
()
rank_id
=
get_rank
()
device_num
=
get_group_size
()
...
...
@@ -93,7 +93,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
if
target
==
"Ascend"
:
device_num
,
rank_id
=
_get_rank_info
()
else
:
init
(
"nccl"
)
init
()
rank_id
=
get_rank
()
device_num
=
get_group_size
()
...
...
model_zoo/official/cv/resnet/train.py
浏览文件 @
e6a4d932
...
...
@@ -87,7 +87,7 @@ if __name__ == '__main__':
init
()
# GPU target
else
:
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
if
args_opt
.
net
==
"resnet50"
:
...
...
model_zoo/official/cv/resnet50_quant/src/dataset.py
浏览文件 @
e6a4d932
...
...
@@ -46,7 +46,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
device_num
=
int
(
os
.
getenv
(
"RANK_SIZE"
))
rank_id
=
int
(
os
.
getenv
(
"RANK_ID"
))
else
:
init
(
"nccl"
)
init
()
rank_id
=
get_rank
()
device_num
=
get_group_size
()
...
...
@@ -114,7 +114,7 @@ def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, targe
device_num
=
int
(
os
.
getenv
(
"RANK_SIZE"
))
rank_id
=
int
(
os
.
getenv
(
"RANK_ID"
))
else
:
init
(
"nccl"
)
init
()
rank_id
=
get_rank
()
device_num
=
get_group_size
()
...
...
model_zoo/official/cv/resnet_thor/src/dataset.py
浏览文件 @
e6a4d932
...
...
@@ -40,7 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
if
target
==
"Ascend"
:
device_num
,
rank_id
=
_get_rank_info
()
else
:
init
(
"nccl"
)
init
()
rank_id
=
get_rank
()
device_num
=
get_group_size
()
...
...
model_zoo/official/cv/resnet_thor/train.py
浏览文件 @
e6a4d932
...
...
@@ -106,7 +106,7 @@ if __name__ == '__main__':
init
()
# GPU target
else
:
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
ckpt_save_dir
=
config
.
save_checkpoint_path
+
"ckpt_"
+
str
(
get_rank
())
+
"/"
...
...
model_zoo/official/cv/resnext50/eval.py
浏览文件 @
e6a4d932
...
...
@@ -112,10 +112,7 @@ def test(cloud_args=None):
# init distributed
if
args
.
is_distributed
:
if
args
.
platform
==
"Ascend"
:
init
()
elif
args
.
platform
==
"GPU"
:
init
(
"nccl"
)
init
()
args
.
rank
=
get_rank
()
args
.
group_size
=
get_group_size
()
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
...
...
model_zoo/official/cv/resnext50/train.py
浏览文件 @
e6a4d932
...
...
@@ -172,10 +172,7 @@ def train(cloud_args=None):
# init distributed
if
args
.
is_distributed
:
if
args
.
platform
==
"Ascend"
:
init
()
else
:
init
(
"nccl"
)
init
()
args
.
rank
=
get_rank
()
args
.
group_size
=
get_group_size
()
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
...
...
model_zoo/official/cv/vgg16/train.py
浏览文件 @
e6a4d932
...
...
@@ -136,7 +136,7 @@ if __name__ == '__main__':
init
()
context
.
set_context
(
device_id
=
args
.
device_id
)
elif
args
.
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
args
.
rank
=
get_rank
()
args
.
group_size
=
get_group_size
()
...
...
model_zoo/official/cv/warpctc/train.py
浏览文件 @
e6a4d932
...
...
@@ -61,7 +61,7 @@ if __name__ == '__main__':
device_num
=
int
(
os
.
environ
.
get
(
"RANK_SIZE"
))
rank
=
int
(
os
.
environ
.
get
(
"RANK_ID"
))
else
:
init
(
'nccl'
)
init
()
lr_scale
=
0.5
device_num
=
get_group_size
()
rank
=
get_rank
()
...
...
model_zoo/official/nlp/bert/run_pretrain.py
浏览文件 @
e6a4d932
...
...
@@ -70,11 +70,11 @@ def run_pretrain():
ckpt_save_dir
=
args_opt
.
save_checkpoint_path
if
args_opt
.
distribute
==
"true"
:
if
args_opt
.
device_target
==
'Ascend'
:
D
.
init
(
'hccl'
)
D
.
init
()
device_num
=
args_opt
.
device_num
rank
=
args_opt
.
device_id
%
device_num
else
:
D
.
init
(
'nccl'
)
D
.
init
()
device_num
=
D
.
get_group_size
()
rank
=
D
.
get_rank
()
ckpt_save_dir
=
args_opt
.
save_checkpoint_path
+
'ckpt_'
+
str
(
rank
)
+
'/'
...
...
model_zoo/official/nlp/bert_thor/run_pretrain.py
浏览文件 @
e6a4d932
...
...
@@ -73,11 +73,11 @@ def run_pretrain():
ckpt_save_dir
=
args_opt
.
save_checkpoint_path
if
args_opt
.
distribute
==
"true"
:
if
args_opt
.
device_target
==
'Ascend'
:
D
.
init
(
'hccl'
)
D
.
init
()
device_num
=
args_opt
.
device_num
rank
=
args_opt
.
device_id
%
device_num
else
:
D
.
init
(
'nccl'
)
D
.
init
()
device_num
=
D
.
get_group_size
()
rank
=
D
.
get_rank
()
ckpt_save_dir
=
args_opt
.
save_checkpoint_path
+
'ckpt_'
+
str
(
rank
)
+
'/'
...
...
model_zoo/official/nlp/mass/train.py
浏览文件 @
e6a4d932
...
...
@@ -228,10 +228,7 @@ def _build_training_pipeline(config: TransformerConfig,
def
_setup_parallel_env
(
platform
):
context
.
reset_auto_parallel_context
()
if
platform
==
"GPU"
:
MultiAscend
.
init
(
"nccl"
)
else
:
MultiAscend
.
init
()
MultiAscend
.
init
()
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
device_num
=
MultiAscend
.
get_group_size
(),
...
...
model_zoo/official/nlp/tinybert/run_general_distill.py
浏览文件 @
e6a4d932
...
...
@@ -69,11 +69,11 @@ def run_general_distill():
if
args_opt
.
distribute
==
"true"
:
if
args_opt
.
device_target
==
'Ascend'
:
D
.
init
(
'hccl'
)
D
.
init
()
device_num
=
args_opt
.
device_num
rank
=
args_opt
.
device_id
%
device_num
else
:
D
.
init
(
'nccl'
)
D
.
init
()
device_num
=
D
.
get_group_size
()
rank
=
D
.
get_rank
()
save_ckpt_dir
=
save_ckpt_dir
+
'_ckpt_'
+
str
(
rank
)
...
...
model_zoo/official/recommend/deepfm/train.py
浏览文件 @
e6a4d932
...
...
@@ -64,7 +64,7 @@ if __name__ == '__main__':
init
()
rank_id
=
int
(
os
.
environ
.
get
(
'RANK_ID'
))
elif
args_opt
.
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
args_opt
.
device_target
)
context
.
reset_auto_parallel_context
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
...
...
model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
浏览文件 @
e6a4d932
...
...
@@ -146,10 +146,7 @@ if __name__ == "__main__":
context
.
set_context
(
variable_memory_max_size
=
"24GB"
)
context
.
set_context
(
enable_sparse
=
True
)
set_multi_subgraphs
()
if
wide_deep_config
.
device_target
==
"Ascend"
:
init
(
"hccl"
)
elif
wide_deep_config
.
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
if
wide_deep_config
.
host_device_mix
==
1
:
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
SEMI_AUTO_PARALLEL
,
mirror_mean
=
True
)
...
...
model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py
浏览文件 @
e6a4d932
...
...
@@ -118,10 +118,7 @@ if __name__ == "__main__":
wide_deep_config
.
argparse_init
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
wide_deep_config
.
device_target
,
save_graphs
=
True
)
if
wide_deep_config
.
device_target
==
"Ascend"
:
init
(
"hccl"
)
elif
wide_deep_config
.
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
,
device_num
=
get_group_size
())
...
...
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
浏览文件 @
e6a4d932
...
...
@@ -118,10 +118,7 @@ if __name__ == "__main__":
wide_deep_config
.
argparse_init
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
wide_deep_config
.
device_target
)
if
wide_deep_config
.
device_target
==
"Ascend"
:
init
(
"hccl"
)
elif
wide_deep_config
.
device_target
==
"GPU"
:
init
(
"nccl"
)
init
()
context
.
set_auto_parallel_context
(
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
,
device_num
=
get_group_size
())
...
...
tests/st/nccl/test_nccl_all_gather_op.py
浏览文件 @
e6a4d932
...
...
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
init
(
'nccl'
)
init
()
rank
=
get_rank
()
size
=
get_group_size
()
x
=
np
.
ones
([
1
,
1
,
3
,
3
]).
astype
(
np
.
float32
)
*
0.01
*
(
rank
+
1
)
...
...
tests/st/nccl/test_nccl_all_reduce_op.py
浏览文件 @
e6a4d932
...
...
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
init
(
'nccl'
)
init
()
rank
=
get_rank
()
size
=
get_group_size
()
x
=
np
.
ones
([
3
,
1
,
3
,
3
]).
astype
(
np
.
float32
)
*
0.01
*
(
rank
+
1
)
...
...
tests/st/nccl/test_nccl_broadcast_op.py
浏览文件 @
e6a4d932
...
...
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
init
(
'nccl'
)
init
()
rank
=
get_rank
()
size
=
get_group_size
()
x
=
np
.
ones
([
3
,
1
,
3
,
3
]).
astype
(
np
.
float32
)
*
0.01
*
(
rank
+
1
)
...
...
tests/st/nccl/test_nccl_lenet.py
浏览文件 @
e6a4d932
...
...
@@ -25,7 +25,7 @@ from mindspore.nn.optim import Momentum
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
init
(
'nccl'
)
init
()
epoch
=
5
total
=
5000
...
...
tests/st/nccl/test_nccl_reduce_scatter_op.py
浏览文件 @
e6a4d932
...
...
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
init
(
'nccl'
)
init
()
rank
=
get_rank
()
size
=
get_group_size
()
x
=
np
.
ones
([
size
,
1
,
3
,
3
]).
astype
(
np
.
float32
)
*
0.01
*
(
rank
+
1
)
...
...
tests/st/ps/multi_full_ps/test_multi_full_ps.py
浏览文件 @
e6a4d932
...
...
@@ -30,7 +30,7 @@ args, _ = parser.parse_known_args()
device_target
=
args
.
device_target
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
device_target
)
if
device_target
==
"GPU"
:
init
(
'nccl'
)
init
()
def
conv
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
):
...
...
tests/ut/python/train/test_dataset_helper.py
浏览文件 @
e6a4d932
...
...
@@ -75,7 +75,7 @@ def test_dataset_iter_normal():
@
pytest
.
mark
.
skipif
(
'not context.get_context("enable_ge")'
)
def
test_dataset_iter_ge
():
init
()
init
(
"hccl"
)
dataset
=
get_dataset
(
32
)
dataset_helper
=
DatasetHelper
(
dataset
,
dataset_sink_mode
=
True
,
sink_size
=
10
)
count
=
0
...
...
@@ -87,7 +87,7 @@ def test_dataset_iter_ge():
@
pytest
.
mark
.
skipif
(
'context.get_context("enable_ge")'
)
def
test_dataset_iter_ms_loop_sink
():
init
()
init
(
"hccl"
)
context
.
set_context
(
enable_loop_sink
=
True
)
dataset
=
get_dataset
(
32
)
dataset_helper
=
DatasetHelper
(
dataset
,
dataset_sink_mode
=
True
,
sink_size
=
10
)
...
...
@@ -101,7 +101,7 @@ def test_dataset_iter_ms_loop_sink():
@
pytest
.
mark
.
skipif
(
'context.get_context("enable_ge")'
)
def
test_dataset_iter_ms
():
init
()
init
(
"hccl"
)
context
.
set_context
(
enable_loop_sink
=
False
)
dataset
=
get_dataset
(
32
)
DatasetHelper
(
dataset
,
dataset_sink_mode
=
True
,
sink_size
=
10
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录