未验证 提交 5402f8e7 编写于 作者: X xiongkun 提交者: GitHub

bugfix: only check backend when mode == Collecive (#36758) (#36772)

* bugfix: only check backend when mode == Collecive
上级 c542d571
...@@ -334,7 +334,20 @@ def launch_ps(args, distribute_mode): ...@@ -334,7 +334,20 @@ def launch_ps(args, distribute_mode):
return return
def infer_backend(args):
if args.backend != "auto": return
if fluid.core.is_compiled_with_cuda():
args.backend = 'nccl'
elif fluid.core.is_compiled_with_npu():
args.backend = 'unknown'
elif fluid.core.is_compiled_with_xpu():
args.backend = 'bkcl'
else:
args.backend = 'gloo'
def which_distributed_mode(args): def which_distributed_mode(args):
infer_backend(args) # modify the args.backend
if args.run_mode is not None: if args.run_mode is not None:
assert args.run_mode in ["collective", "ps", "ps-heter"] assert args.run_mode in ["collective", "ps", "ps-heter"]
...@@ -368,12 +381,9 @@ def which_distributed_mode(args): ...@@ -368,12 +381,9 @@ def which_distributed_mode(args):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
accelerators = fluid.core.get_cuda_device_count() accelerators = fluid.core.get_cuda_device_count()
args.backend = 'nccl'
elif fluid.core.is_compiled_with_npu(): elif fluid.core.is_compiled_with_npu():
args.backend = 'unknown'
accelerators = fluid.core.get_npu_device_count() accelerators = fluid.core.get_npu_device_count()
elif fluid.core.is_compiled_with_xpu(): elif fluid.core.is_compiled_with_xpu():
args.backend = 'bkcl'
accelerators = fluid.core.get_xpu_device_count() accelerators = fluid.core.get_xpu_device_count()
else: else:
accelerators = 0 accelerators = 0
...@@ -400,7 +410,6 @@ def which_distributed_mode(args): ...@@ -400,7 +410,6 @@ def which_distributed_mode(args):
But found args.servers not empty, default use ps mode") But found args.servers not empty, default use ps mode")
return DistributeMode.PS return DistributeMode.PS
else: else:
args.backend = "gloo"
return DistributeMode.COLLECTIVE return DistributeMode.COLLECTIVE
else: else:
logger.warning( logger.warning(
...@@ -583,20 +592,21 @@ def launch(): ...@@ -583,20 +592,21 @@ def launch():
_print_arguments(args) _print_arguments(args)
if args.backend == 'auto': if args.backend == 'auto':
distribute_mode = which_distributed_mode(args) distribute_mode = which_distributed_mode(
assert args.backend in [ args) # which_distributed_mode must modify args.backend
'gloo', 'nccl', 'bkcl', 'unknown'
] # which_distributed_mode must modify args.backend
else: else:
assert args.run_mode == 'collective' or args.run_mode == None, "When backend is not 'auto', run mode must be collective" assert args.run_mode == 'collective' or args.run_mode == None, "When backend is not 'auto', run mode must be collective"
check_backend(args.backend) check_backend(args.backend)
distribute_mode = DistributeMode.COLLECTIVE distribute_mode = DistributeMode.COLLECTIVE
block_windows_and_macos( assert args.backend in ['gloo', 'nccl', 'bkcl', 'unknown']
args.backend) # raise error when using gloo on windows or macos
if args.backend == 'gloo': if args.backend == 'gloo':
logger.warning("launch start with CPUONLY mode") logger.warning("launch start with CPUONLY mode")
block_windows_and_macos(
args.backend) # raise error when using gloo on windows or macos
if enable_elastic(args, distribute_mode): if enable_elastic(args, distribute_mode):
launch_elastic(args, distribute_mode) launch_elastic(args, distribute_mode)
return return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册