diff --git a/python/paddle/distributed/launch.py b/python/paddle/distributed/launch.py index d8153fa00267b00eedc52aa043af9ba7dc090f7d..845ccf27451d9c113e538122fa032a1805611686 100644 --- a/python/paddle/distributed/launch.py +++ b/python/paddle/distributed/launch.py @@ -38,6 +38,19 @@ default_envs = { GPUS = 8 +def get_gpu_ids(gpus): + if os.getenv("CUDA_VISIBLE_DEVICES"): + ids = [int(i) + for i in os.getenv("CUDA_VISIBLE_DEVICES").split(",")][:gpus] + if gpus > len(ids): + raise EnvironmentError( + "The count of env CUDA_VISIBLE_DEVICES should not greater than the passed gpus: %s" + % gpus) + return ids + else: + return [i for i in range(gpus)] + + def start_procs(gpus, entrypoint, entrypoint_args, log_dir): procs = [] log_fns = [] @@ -61,8 +74,8 @@ def start_procs(gpus, entrypoint, entrypoint_args, log_dir): all_nodes_devices_endpoints += "%s:617%d" % (n, i) nranks = num_nodes * gpus # ======== for dist training ======= - - for i in range(gpus): + gpu_ids = get_gpu_ids(gpus) + for i in gpu_ids: curr_env = {} curr_env.update(default_envs) curr_env.update({