From 5d8e4395b61929627151f6fd4a607589288a78bf Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 2 Jun 2021 00:19:11 +0800 Subject: [PATCH] [Cherry-pick] Fix spawn default nprocs get error (#33215) (#33249) * fix spawn default nprocs get error * polish error message --- python/paddle/distributed/spawn.py | 25 ++++++++++--------- .../test_spawn_and_init_parallel_env.py | 11 +++++++- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py index c46672dca09..e21f142f10b 100644 --- a/python/paddle/distributed/spawn.py +++ b/python/paddle/distributed/spawn.py @@ -89,6 +89,18 @@ def _options_valid_check(options): % key) +def _get_default_nprocs(): + device = get_device() + if 'gpu' in device: + return core.get_cuda_device_count() + elif 'xpu' in device: + return core.get_xpu_device_count() + else: + raise RuntimeError( + "`paddle.distributed.spawn` does not support parallel training on device `{}` now.". + format(device)) + + def _get_node_ip(ips): node_ip = None node_ips = [x.strip() for x in ips.split(',')] @@ -448,18 +460,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): # get default nprocs if nprocs == -1: - device = get_device() - if device == 'cpu': - # TODO: not supports cpu parallel now - nprocs = _cpu_num() - elif device == 'gpu': - nprocs = core.get_cuda_device_count() - elif device == 'xpu': - nprocs = core.get_xpu_device_count() - else: - raise ValueError( - "`device` should be a string of `cpu`, 'gpu' or 'xpu', but got {}". - format(device)) + nprocs = _get_default_nprocs() # NOTE(chenweihang): [ why need get cluster info before run? ] # when using `paddle.distributed.spawn` start parallel training, diff --git a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py index 6efab81a265..14547eca5ac 100644 --- a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py +++ b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py @@ -20,7 +20,7 @@ import unittest import paddle import paddle.distributed as dist -from paddle.distributed.spawn import _get_subprocess_env_list, _options_valid_check +from paddle.distributed.spawn import _get_subprocess_env_list, _options_valid_check, _get_default_nprocs from paddle.fluid import core from paddle.fluid.dygraph import parallel_helper @@ -87,6 +87,15 @@ class TestSpawnAssistMethod(unittest.TestCase): options['error'] = "error" _options_valid_check(options) + def test_get_default_nprocs(self): + paddle.set_device('cpu') + with self.assertRaises(RuntimeError): + nprocs = _get_default_nprocs() + + paddle.set_device('gpu') + nprocs = _get_default_nprocs() + self.assertEqual(nprocs, core.get_cuda_device_count()) + if __name__ == "__main__": unittest.main() -- GitLab