未验证 提交 5d8e4395 编写于 作者: C Chen Weihang 提交者: GitHub

[Cherry-pick] Fix spawn default nprocs get error (#33215) (#33249)

* fix spawn default nprocs get error

* polish error message
上级 8a5a45f8
......@@ -89,6 +89,18 @@ def _options_valid_check(options):
% key)
def _get_default_nprocs():
device = get_device()
if 'gpu' in device:
return core.get_cuda_device_count()
elif 'xpu' in device:
return core.get_xpu_device_count()
else:
raise RuntimeError(
"`paddle.distributed.spawn` does not support parallel training on device `{}` now.".
format(device))
def _get_node_ip(ips):
node_ip = None
node_ips = [x.strip() for x in ips.split(',')]
......@@ -448,18 +460,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
# get default nprocs
if nprocs == -1:
device = get_device()
if device == 'cpu':
# TODO: not supports cpu parallel now
nprocs = _cpu_num()
elif device == 'gpu':
nprocs = core.get_cuda_device_count()
elif device == 'xpu':
nprocs = core.get_xpu_device_count()
else:
raise ValueError(
"`device` should be a string of `cpu`, 'gpu' or 'xpu', but got {}".
format(device))
nprocs = _get_default_nprocs()
# NOTE(chenweihang): [ why need get cluster info before run? ]
# when using `paddle.distributed.spawn` start parallel training,
......
......@@ -20,7 +20,7 @@ import unittest
import paddle
import paddle.distributed as dist
from paddle.distributed.spawn import _get_subprocess_env_list, _options_valid_check
from paddle.distributed.spawn import _get_subprocess_env_list, _options_valid_check, _get_default_nprocs
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper
......@@ -87,6 +87,15 @@ class TestSpawnAssistMethod(unittest.TestCase):
options['error'] = "error"
_options_valid_check(options)
def test_get_default_nprocs(self):
paddle.set_device('cpu')
with self.assertRaises(RuntimeError):
nprocs = _get_default_nprocs()
paddle.set_device('gpu')
nprocs = _get_default_nprocs()
self.assertEqual(nprocs, core.get_cuda_device_count())
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册