From 723020337d9bf392398e4990a7abf6df7940a69d Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Thu, 8 Apr 2021 13:20:21 +0800
Subject: [PATCH] fix bug (#32135)

---
 .../tests/unittests/test_parallel_dygraph_dataparallel.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
index 1d2a3975190..5491b451368 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
@@ -18,7 +18,7 @@ import unittest
 import time
 
 import paddle.fluid as fluid
-from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, get_gpus, start_local_trainers
+from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, start_local_trainers
 
 
 def get_cluster_from_args(selected_gpus):
@@ -41,6 +41,11 @@ def get_cluster_from_args(selected_gpus):
     return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
 
 
+def get_gpus(selected_gpus):
+    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
+    return selected_gpus
+
+
 class TestMultipleGpus(unittest.TestCase):
     def run_mnist_2gpu(self, target_file_name):
         if not fluid.core.is_compiled_with_cuda(
-- 
GitLab