未验证 提交 266444b8 编写于 作者: Y Yan Xu 提交者: GitHub

fix dist launch script test=develop (#17404)

上级 0823a7bc
......@@ -38,6 +38,19 @@ default_envs = {
GPUS = 8
def get_gpu_ids(gpus):
    """Return the GPU device ids to use for ``gpus`` local processes.

    When the ``CUDA_VISIBLE_DEVICES`` environment variable is set and
    non-empty, the first ``gpus`` ids listed there are returned. Otherwise
    the ids ``0 .. gpus - 1`` are returned.

    Args:
        gpus: Number of GPU ids requested.

    Returns:
        A list of integer device ids.

    Raises:
        EnvironmentError: If ``CUDA_VISIBLE_DEVICES`` lists fewer devices
            than ``gpus``.
        ValueError: If an entry of ``CUDA_VISIBLE_DEVICES`` is not an
            integer literal.
    """
    # Read the variable once so the check and the parse see the same value.
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if not visible:
        return list(range(gpus))

    ids = [int(i) for i in visible.split(",")]
    if gpus > len(ids):
        # The failure is "too few visible devices", so say exactly that and
        # report both numbers to make the mismatch easy to diagnose.
        raise EnvironmentError(
            "The count of devices in env CUDA_VISIBLE_DEVICES (%d) should "
            "not be less than the passed gpus: %s" % (len(ids), gpus))
    return ids[:gpus]
def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
procs = []
log_fns = []
......@@ -61,8 +74,8 @@ def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
all_nodes_devices_endpoints += "%s:617%d" % (n, i)
nranks = num_nodes * gpus
# ======== for dist training =======
for i in range(gpus):
gpu_ids = get_gpu_ids(gpus)
for i in gpu_ids:
curr_env = {}
curr_env.update(default_envs)
curr_env.update({
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册