# Extracted from patch 266444b8afcf ("fix dist launch script", #17404) against
# python/paddle/distributed/launch.py.  The same patch changes start_procs to
# compute `gpu_ids = get_gpu_ids(gpus)` and iterate `for i in gpu_ids`
# instead of `for i in range(gpus)`.
# NOTE(review): the loop variable in start_procs thereby becomes a device id
# rather than a 0-based index -- confirm that anything derived from it there
# (e.g. the "%s:617%d" endpoint formatting) still behaves as intended.
def get_gpu_ids(gpus):
    """Return the GPU ids that the launched processes should use.

    When the environment variable ``CUDA_VISIBLE_DEVICES`` is set to a
    non-empty value, it is split on commas, each entry is converted with
    ``int`` and the first ``gpus`` ids are returned.  Otherwise (unset or
    empty) the ids ``0 .. gpus - 1`` are returned.

    Args:
        gpus (int): number of GPU ids requested.

    Returns:
        list[int]: exactly ``gpus`` ids.

    Raises:
        EnvironmentError: ``CUDA_VISIBLE_DEVICES`` lists fewer devices than
            ``gpus``.
        ValueError: an entry of ``CUDA_VISIBLE_DEVICES`` is not an integer
            (propagated from ``int``).
    """
    # Read the variable once so the check and the parse see the same value.
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if not visible:
        return list(range(gpus))

    all_ids = [int(entry) for entry in visible.split(",")]
    if gpus > len(all_ids):
        # The failing condition is "too few visible devices"; say so, and
        # report both numbers to make the mismatch easy to diagnose.
        raise EnvironmentError(
            "The count of devices in env CUDA_VISIBLE_DEVICES (%d) should "
            "not be less than the passed gpus: %s" % (len(all_ids), gpus))
    return all_ids[:gpus]