diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
index 1d2a39751905e24acfc1666cfd22952b673cf698..5491b451368c825c10f1e957d85e30ccacdd1dc7 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
@@ -18,7 +18,7 @@ import unittest
 import time
 
 import paddle.fluid as fluid
-from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, get_gpus, start_local_trainers
+from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, start_local_trainers
 
 
 def get_cluster_from_args(selected_gpus):
@@ -41,6 +41,11 @@ def get_cluster_from_args(selected_gpus):
     return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
 
 
+def get_gpus(selected_gpus):
+    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
+    return selected_gpus
+
+
 class TestMultipleGpus(unittest.TestCase):
     def run_mnist_2gpu(self, target_file_name):
         if not fluid.core.is_compiled_with_cuda(
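
Note (not part of the diff): the change drops the get_gpus import from paddle.distributed.utils and defines an equivalent helper locally in the test, so the test no longer depends on that utility being exported. A minimal standalone sketch of what the added helper does, with made-up inputs chosen only for illustration:

def get_gpus(selected_gpus):
    # Split a comma-separated device-id string and strip whitespace
    # around each entry, e.g. "0, 1" -> ['0', '1'].
    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
    return selected_gpus

# Illustrative usage (hypothetical inputs, not taken from the PR):
assert get_gpus('0,1') == ['0', '1']
assert get_gpus(' 2 , 3 ') == ['2', '3']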