Unverified commit 9726bd46, authored by Quentin Anthony, committed by GitHub

Fix comms benchmark import issues and support MPI/slurm launching (#2932)

* Fix benchmark import issues and support MPI launching with pure torch.dist

* Formatting

* Update comms benchmark README

* Formatting

* Added better error handling and support for the MPI torch.dist backend

* Update formatting versions

* Formatting again

* Trigger CI

---------
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 91d63e02
# Running Communication Benchmarks
# The DeepSpeed Communication Benchmarking Suite
The intent of these benchmarks is to measure the communication latency/bandwidth of DeepSpeed and/or PyTorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can:
- Easily debug which layer of the communication software stack hangs or performance degradations originate from.
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed.
To run benchmarks, there are two options:
1. Run a single communication operation:
For example, run with a single large message size:
For example, run with a single large message size (calculated to barely fit within GPU mem):
<pre>
deepspeed all_reduce.py
</pre>
......@@ -15,6 +18,17 @@ Scan across message sizes:
deepspeed all_reduce.py --scan
</pre>
Benchmark pure PyTorch distributed comms (without importing or using DeepSpeed) by launching with MPI:
<pre>
mpirun -np 16 --hostfile ${HOSTFILE} -x LD_LIBRARY_PATH -x PATH -x LD_PRELOAD python all_reduce.py --scan --dist="torch"
</pre>
or with Slurm:
<pre>
srun -n 16 python all_reduce.py --scan --dist="torch"
</pre>
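When going through a batch scheduler rather than an interactive allocation, the same srun line can be wrapped in a job script. A minimal sketch, assuming 2 nodes with 8 GPUs each (partition, account, and any module/environment setup are placeholders for your cluster):
<pre>
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8

srun -n 16 python all_reduce.py --scan --dist="torch"
</pre>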
2. Run all available communication benchmarks:
<pre>
......
'''Copyright The Microsoft DeepSpeed Team'''
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator
import torch
import sys, os, time
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
import time
from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
# Run all_gather and print metrics
......@@ -94,6 +98,8 @@ def run_all_gather(local_rank, args):
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
else:
raise e
sync_all()
timed_all_gather(input, output, args)
else:
......@@ -126,6 +132,8 @@ def run_all_gather(local_rank, args):
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
else:
raise e
sync_all()
timed_all_gather(input, output, args)
......
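The else: raise e hunks above (repeated in each benchmark file below) narrow the exception handling so that only out-of-memory errors are swallowed. For context, a minimal sketch of the pattern being extended; the try body and the except/OOM check are assumptions reconstructed from the visible context lines, and only the added else branch comes from this diff:
<pre>
import torch

def try_alloc(shape, dtype, device):
    # Sketch of the OOM-handling pattern used in run_all_gather(), run_all_reduce(), etc.
    try:
        return torch.ones(shape, dtype=dtype, device=device)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            # OOM is expected for large message sizes: warn and let the caller
            # break out of (or return from) the message-size loop.
            print('WARNING: Ran out of GPU memory. Exiting comm op.')
            return None
        else:
            # New in this PR: anything that is not an OOM error is re-raised
            # instead of being silently swallowed.
            raise e
</pre>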
'''Copyright The Microsoft DeepSpeed Team'''
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator
import torch
import sys, os, time
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
import time
from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
def timed_all_reduce(input, args):
......@@ -72,6 +76,8 @@ def run_all_reduce(local_rank, args):
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
else:
raise e
sync_all()
timed_all_reduce(input, args)
else:
......@@ -92,6 +98,8 @@ def run_all_reduce(local_rank, args):
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
else:
raise e
sync_all()
timed_all_reduce(input, args)
......
'''Copyright The Microsoft DeepSpeed Team'''
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator
import torch
import sys, os, time
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
import time
from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
def timed_all_to_all(input, output, args):
......@@ -73,6 +77,8 @@ def run_all_to_all(local_rank, args):
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
else:
raise e
sync_all()
timed_all_to_all(input, output, args)
else:
......@@ -99,6 +105,8 @@ def run_all_to_all(local_rank, args):
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
else:
raise e
sync_all()
if args.debug:
......
'''Copyright The Microsoft DeepSpeed Team'''
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator
import sys, os, time
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
import time
from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
def timed_broadcast(input, args):
......@@ -73,6 +76,8 @@ def run_broadcast(local_rank, args):
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
else:
raise e
sync_all()
timed_broadcast(input, args)
else:
......
......@@ -8,3 +8,4 @@ DEFAULT_BACKEND = get_accelerator().communication_backend_name()
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
'''Copyright The Microsoft DeepSpeed Team'''
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator
import torch
import sys, os, time
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
import time
from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
def timed_pt2pt(input, args):
......@@ -91,6 +95,8 @@ def run_pt2pt(local_rank, args):
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
else:
raise e
sync_all()
timed_pt2pt(input, args)
else:
......
'''Copyright The Microsoft DeepSpeed Team'''
from benchmarks.communication.utils import *
from benchmarks.communication.all_reduce import run_all_reduce
from benchmarks.communication.all_gather import run_all_gather
from benchmarks.communication.all_to_all import run_all_to_all
from benchmarks.communication.pt2pt import run_pt2pt
from benchmarks.communication.broadcast import run_broadcast
from benchmarks.communication.constants import *
import sys, os
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
from communication.utils import *
from communication.all_reduce import run_all_reduce
from communication.all_gather import run_all_gather
from communication.all_to_all import run_all_to_all
from communication.pt2pt import run_pt2pt
from communication.broadcast import run_broadcast
from communication.constants import *
# For importing
......
'''Copyright The Microsoft DeepSpeed Team'''
import torch
import os
import os, sys
import math
import argparse
from benchmarks.communication.constants import *
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
from communication.constants import *
from deepspeed.accelerator import get_accelerator
global dist
def env2int(env_list, default=-1):
for e in env_list:
val = int(os.environ.get(e, -1))
if val >= 0: return val
return default
def init_torch_distributed(backend):
global dist
import torch.distributed as dist
# discover rank/size info from env
if 'MASTER_PORT' not in os.environ:
os.environ['MASTER_PORT'] = str(TORCH_DISTRIBUTED_DEFAULT_PORT)
if 'MASTER_ADDR' not in os.environ:
try:
from mpi4py import MPI
except ModuleNotFoundError:
print(
"Cannot import mpi4py and MASTER_ADDR not set. Please either install mpi4py or set the MASTER_ADDR on all ranks"
)
raise Exception
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
os.environ['MASTER_ADDR'] = master_addr
local_rank = env2int(
['LOCAL_RANK', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'MV2_COMM_WORLD_LOCAL_RANK', 'SLURM_LOCALID'])
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(local_rank)
rank = env2int(['RANK', 'MPI_RANKID', 'OMPI_COMM_WORLD_RANK', 'MV2_COMM_WORLD_RANK', 'SLURM_PROCID'])
if 'RANK' not in os.environ:
os.environ['RANK'] = str(rank)
world_size = env2int(['WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'MV2_COMM_WORLD_SIZE', 'SLURM_NPROCS'])
if 'WORLD_SIZE' not in os.environ:
os.environ['WORLD_SIZE'] = str(world_size)
torch.distributed.init_process_group(backend)
local_rank = int(os.environ['LOCAL_RANK'])
get_accelerator().set_device(local_rank)
......@@ -169,7 +212,7 @@ def benchmark_parser():
parser.add_argument("--backend",
type=str,
default=DEFAULT_BACKEND,
choices=['nccl', 'ccl'],
choices=['nccl', 'ccl', 'mpi'],
help='Communication library to use')
parser.add_argument("--dist",
type=str,
......
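Putting the pieces together, a minimal sketch of how a benchmark script might pick the distributed framework from the --dist/--backend options. Only init_torch_distributed() and benchmark_parser() are taken from utils.py above (benchmark_parser() is assumed to return an argparse parser); the dispatcher itself and the deepspeed.init_distributed() call are illustrative:
<pre>
import os
import sys

COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)

from communication.utils import benchmark_parser, init_torch_distributed


def init_processes(args):
    # Hypothetical dispatcher; the real benchmarks may structure this differently.
    if args.dist == 'torch':
        # Pure torch.distributed: rank, world size, and MASTER_ADDR are discovered
        # from MPI/Slurm environment variables, so the same script works under
        # mpirun or srun without DeepSpeed's launcher.
        init_torch_distributed(args.backend)
    elif args.dist == 'deepspeed':
        import deepspeed
        deepspeed.init_distributed(dist_backend=args.backend)
    else:
        raise ValueError(f"Unsupported --dist option: {args.dist}")


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    init_processes(args)
</pre>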