Unverified commit 9b70ce56, authored by Quentin Anthony and committed by GitHub

Comms Benchmarks (#2040)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 76ea0534
......@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: script
-exclude: ^(deepspeed/comm/|docs/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
+exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
- repo: https://github.com/codespell-project/codespell
......
# Running Communication Benchmarks
To run benchmarks, there are two options:
1. Run a single communication operation:
For example, run with a single large message size:
<pre>
deepspeed all_reduce.py
</pre>
Scan across message sizes:
<pre>
deepspeed all_reduce.py --scan
</pre>
Each individual communication operation has its own benchmarking options. For `all_reduce.py`, for example:
<pre>
usage: ds_bench [-h] [--local_rank LOCAL_RANK] [--trials TRIALS] [--warmup WARMUP] [--maxsize MAXSIZE] [--async-op] [--bw-unit {Gbps,GBps}] [--backend {nccl}] [--dist {deepspeed,torch}] [--scan] [--dtype DTYPE] [--mem-factor MEM_FACTOR] [--debug]
optional arguments:
-h, --help show this help message and exit
--local_rank LOCAL_RANK
--trials TRIALS Number of timed iterations
--warmup WARMUP Number of warmup (non-timed) iterations
--maxsize MAXSIZE Max message size as a power of 2
--async-op Enables non-blocking communication
--bw-unit {Gbps,GBps}
--backend {nccl} Communication library to use
--dist {deepspeed,torch}
Distributed DL framework to use
--scan Enables scanning all message sizes
--dtype DTYPE PyTorch tensor dtype
--mem-factor MEM_FACTOR
Proportion of max available GPU memory to use for single-size evals
--debug Enables alltoall debug prints
</pre>
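For example, to run 20 timed iterations per message size at half precision and report results in GB/s:
<pre>
deepspeed all_reduce.py --scan --trials=20 --warmup=10 --dtype=float16 --bw-unit GBps
</pre>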
2. Run all available communication benchmarks:
<pre>
deepspeed run_all.py
</pre>
Like the individual benchmarks, `run_all.py` accepts the same arguments (scan mode, max message size, bw-unit, etc.). Simply pass the desired arguments to `run_all.py` and they will be propagated to each comm op.
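For example:
<pre>
deepspeed run_all.py --scan --async-op --bw-unit GBps
</pre>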
Note that `ds_bench` is a pre-packaged wrapper around `run_all.py`. Users can pass the same arguments as well:
<pre>
<path to deepspeed>/bin/ds_bench --scan --trials=10
</pre>
# Adding Communication Benchmarks
To add new communication benchmarks, follow this general procedure (a hypothetical `reduce_scatter` sketch of steps 2 and 3 follows the list):
1. Copy a similar benchmark file (e.g. to add `reduce_scatter`, copy `all_reduce.py` as a template)
2. Add a new bw formula in `utils.get_bw`
3. Add a new maximum tensor element formula in `utils.max_numel`
4. Replace the comm op calls in the new file (e.g. via find-and-replace)
5. Find a good default `mem_factor` for use in the `run_<collective>_single()` function
6. Add new comm op to `run_all.py`
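As a sketch of steps 2 and 3, the helpers below show what a hypothetical `reduce_scatter` entry could look like. They are written as standalone functions rather than a patch to `utils.py`, and the `(n - 1) / n` bus-bandwidth factor is an assumption that mirrors the existing `allgather` formula:
<pre>
import math

# Hypothetical step 2: throughput and bus bandwidth (bytes/sec) for reduce_scatter.
def reduce_scatter_bw(size, duration, n):
    tput = size / duration
    busbw = tput * ((n - 1) / n)
    return tput, busbw

# Hypothetical step 3: largest per-GPU element count, following the allgather
# pattern in utils.max_numel (divide by world size, round down to a power of two).
def reduce_scatter_max_numel(max_memory_per_gpu, dtype_size, world_size):
    elements_per_gpu = int(max_memory_per_gpu // dtype_size // world_size)
    return int(pow(2, int(math.log(elements_per_gpu, 2))))
</pre>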
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
import time
import argparse
import os
import math
# Run allgather and print metrics
def timed_allgather(input, output, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
                # Fall back to all_gather into a list of chunks of the output buffer
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
sync_all()
# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
                # Fall back to all_gather into a list of chunks of the output buffer
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre
# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('allgather', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'
print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)
def run_allgather(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
# Prepare benchmark header
print_header(args, 'allgather')
global_rank = dist.get_rank()
world_size = dist.get_world_size()
if args.scan:
# Create list of message sizes
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)
sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(input.nelement() * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_allgather(input, output, args)
else:
# all_gather_base saves memory
if (args.dist == 'torch'
and hasattr(torch.distributed,
"_all_gather_base")) or (args.dist == 'deepspeed'
and dist.has_allgather_base):
mem_factor = args.mem_factor + 0.2
else:
mem_factor = args.mem_factor
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
sync_all()
elements_per_gpu = max_numel(comm_op='allgather',
dtype=getattr(torch,
args.dtype),
mem_factor=mem_factor,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
# multiply each GPU's tensor by the rank to ease debugging
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(elements_per_gpu * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return
sync_all()
timed_allgather(input, output, args)
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_allgather(local_rank=rank, args=args)
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
import time
import argparse
import os
import math
def timed_allreduce(input, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
dist.all_reduce(input, async_op=args.async_op)
sync_all()
# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
dist.all_reduce(input, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre
# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('allreduce', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'
print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)
def run_allreduce(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
# Prepare benchmark header
print_header(args, 'allreduce')
world_size = dist.get_world_size()
global_rank = dist.get_rank()
if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)
sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_allreduce(input, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
elements_per_gpu = max_numel(comm_op='allreduce',
dtype=getattr(torch,
args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return
sync_all()
timed_allreduce(input, args)
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_allreduce(local_rank=rank, args=args)
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
import time
import argparse
import os
import math
def timed_alltoall(input, output, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
dist.all_to_all_single(output, input, async_op=args.async_op)
sync_all()
# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
dist.all_to_all_single(output, input, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre
# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('alltoall', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'
print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)
def run_alltoall(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
world_size = dist.get_world_size()
global_rank = dist.get_rank()
# Prepare benchmark header
print_header(args, 'alltoall')
if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)
sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
                assert mat.numel() % world_size == 0, f"tensor cannot be divided into {world_size} chunks"
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
output = (mat.clone().view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_alltoall(input, output, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
elements_per_gpu = max_numel(comm_op='alltoall',
dtype=getattr(torch,
args.dtype),
mem_factor=args.mem_factor,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
            assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided into {world_size} chunks"
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return
sync_all()
if args.debug:
for i in range(world_size):
if i == global_rank:
print(f"Before AllToAll Input List at rank {global_rank}: {input}")
dist.barrier()
timed_alltoall(input, output, args)
if args.debug:
for i in range(world_size):
if i == global_rank:
print(f"AllToAll Results at rank {global_rank}: {output}")
dist.barrier()
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_alltoall(local_rank=rank, args=args)
import torch
DEFAULT_WARMUPS = 5
DEFAULT_TRIALS = 50
DEFAULT_TYPE = 'float'
DEFAULT_BACKEND = 'nccl'
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24
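# Note: --maxsize is an exponent. With --scan, each benchmark sweeps M = 2**1 .. 2**(maxsize - 1)
# and allocates a (world_size x M) tensor per step, so the largest message also grows with the
# number of ranks.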
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
import time
import argparse
import os
import math
def timed_pt2pt(input, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
if dist.get_rank() == 0:
if args.async_op:
dist.isend(input, 1)
else:
dist.send(input, 1)
if dist.get_rank() == 1:
if args.async_op:
dist.irecv(input, src=0)
else:
dist.recv(input, src=0)
sync_all()
# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
if dist.get_rank() == 0:
if args.async_op:
dist.isend(input, 1)
else:
dist.send(input, 1)
if dist.get_rank() == 1:
if args.async_op:
dist.irecv(input, src=0)
else:
dist.recv(input, src=0)
sync_all()
duration = time.perf_counter() - pre
# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('pt2pt', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'
print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)
def run_pt2pt(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
# Prepare benchmark header
print_header(args, 'pt2pt')
global_rank = dist.get_rank()
world_size = dist.get_world_size()
if args.scan:
# Create list of message sizes
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)
sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_pt2pt(input, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so double mem_factor
elements_per_gpu = max_numel(comm_op='pt2pt',
dtype=getattr(torch,
args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return
sync_all()
timed_pt2pt(input, args)
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_pt2pt(local_rank=rank, args=args)
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.all_reduce import run_allreduce
from benchmarks.communication.all_gather import run_allgather
from benchmarks.communication.all_to_all import run_alltoall
from benchmarks.communication.pt2pt import run_pt2pt
from benchmarks.communication.constants import *
import time
import argparse
import os
# Entry point when imported (e.g. by the ds_bench wrapper script)
def main(args, rank):
init_processes(local_rank=rank, args=args)
for comm_op in ['allreduce', 'alltoall', 'allgather', 'pt2pt']:
if comm_op == 'allreduce':
run_allreduce(local_rank=rank, args=args)
if comm_op == 'allgather':
run_allgather(local_rank=rank, args=args)
if comm_op == 'alltoall':
run_alltoall(local_rank=rank, args=args)
if comm_op == 'pt2pt':
run_pt2pt(local_rank=rank, args=args)
# For directly calling benchmark
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
main(args, rank)
import torch
import os
import math
import argparse
from benchmarks.communication.constants import *
global dist
def init_torch_distributed(backend):
global dist
import torch.distributed as dist
torch.distributed.init_process_group(backend)
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
def init_deepspeed_comm(backend):
global dist
import deepspeed
import deepspeed.comm as dist
deepspeed.init_distributed(dist_backend=backend)
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
def init_processes(local_rank, args):
if args.dist == 'deepspeed':
init_deepspeed_comm(args.backend)
elif args.dist == 'torch':
init_torch_distributed(args.backend)
else:
print_rank_0(f"distributed framework {args.dist} not supported")
exit(0)
def print_rank_0(message):
if dist.get_rank() == 0:
print(message)
def print_header(args, comm_op):
if comm_op == 'pt2pt':
world_size = 2
else:
world_size = dist.get_world_size()
tput = f'Throughput ({args.bw_unit})'
busbw = f'BusBW ({args.bw_unit})'
header = f"\n---- Performance of {comm_op} on {world_size} devices ---------------------------------------------------------\n"
header += f"{'Size (Bytes)':20s} {'Description':25s} {'Duration':20s} {tput:20s} {busbw:20s}\n"
header += "----------------------------------------------------------------------------------------------------"
print_rank_0(header)
def get_bw(comm_op, size, duration, args):
n = dist.get_world_size()
tput = 0
busbw = 0
if comm_op == "alltoall":
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
elif comm_op == "allgather":
size *= n
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
elif comm_op == "allreduce":
tput = (size * 2 / duration)
busbw = (size / duration) * (2 * (n - 1) / n)
elif comm_op == "pt2pt":
tput = (size / duration)
busbw = tput
else:
print_rank_0("wrong comm_op specified")
exit(0)
if args.bw_unit == 'Gbps':
tput *= 8
busbw *= 8
return tput, busbw
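# Worked example (illustrative numbers only, not measured results): an allreduce of a
# 1 GB tensor across 8 ranks that averages 0.05 s per iteration yields
#   tput  = 2 * 1e9 / 0.05                    = 40 GB/s
#   busbw = (1e9 / 0.05) * (2 * (8 - 1) / 8)  = 35 GB/s
# With --bw-unit Gbps (the default) both values are multiplied by 8 (320 Gbps and 280 Gbps).
# The (n - 1) / n corrections follow the bus-bandwidth convention used by nccl-tests.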
def get_metric_strings(args, tput, busbw, duration):
duration_ms = duration * 1e3
duration_us = duration * 1e6
tput = f'{tput / 1e9:.3f}'
busbw = f'{busbw /1e9:.3f}'
if duration_us < 1e3:
duration = f'{duration_us:.3f} us'
else:
duration = f'{duration_ms:.3f} ms'
return tput, busbw, duration
def sync_all():
torch.cuda.synchronize()
dist.barrier()
def max_numel(comm_op, dtype, mem_factor, local_rank, args):
dtype_size = torch._utils._element_size(dtype)
max_memory_per_gpu = torch.cuda.get_device_properties(
local_rank).total_memory * mem_factor
if comm_op == 'allreduce' or comm_op == 'pt2pt':
elements_per_gpu = int(max_memory_per_gpu // dtype_size)
elif comm_op == 'allgather':
# all_gather performance is lower for non-powers of two, and the output buffer size scales with world size
# Therefore, divide by world size and round down to nearest power of 2
elements_per_gpu = int(max_memory_per_gpu // dtype_size // dist.get_world_size())
elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
elif comm_op == 'alltoall':
# Number of elements must be divisible by world_size
# all_to_all performance is lower for non-powers of two. Round down like allgather.
elements_per_gpu = int(max_memory_per_gpu // dtype_size)
elements_per_gpu = int(dist.get_world_size() *
round(elements_per_gpu / dist.get_world_size()))
elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
else:
print(f"This communication operation: {comm_op} is not supported yet")
exit(0)
return elements_per_gpu
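# Illustrative sizing (assumed hardware, not taken from any benchmark run): on a GPU with
# 40 GB of total memory, --mem-factor 0.4 and float32 (4 bytes) give max_memory_per_gpu =
# 16e9 bytes, so allreduce/pt2pt use ~4e9 elements per rank, while allgather on 8 ranks
# gets 16e9 / 4 / 8 = 5e8 elements, rounded down to the nearest power of two (2**28).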
# Helper function to pretty-print message sizes
def convert_size(size_bytes):
if size_bytes == 0:
return "0B"
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
i = int(math.floor(math.log(size_bytes, 1024)))
p = math.pow(1024, i)
s = round(size_bytes / p, 2)
return "%s %s" % (s, size_name[i])
def benchmark_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
parser.add_argument("--trials",
type=int,
default=DEFAULT_TRIALS,
help='Number of timed iterations')
parser.add_argument("--warmup",
type=int,
default=DEFAULT_WARMUPS,
help='Number of warmup (non-timed) iterations')
parser.add_argument("--maxsize",
type=int,
default=24,
help='Max message size as a power of 2')
parser.add_argument("--async-op",
action="store_true",
help='Enables non-blocking communication')
parser.add_argument("--bw-unit",
type=str,
default=DEFAULT_UNIT,
choices=['Gbps',
'GBps'])
parser.add_argument("--backend",
type=str,
default=DEFAULT_BACKEND,
choices=['nccl'],
help='Communication library to use')
parser.add_argument("--dist",
type=str,
default=DEFAULT_DIST,
choices=['deepspeed',
'torch'],
help='Distributed DL framework to use')
parser.add_argument("--scan",
action="store_true",
help='Enables scanning all message sizes')
parser.add_argument("--dtype",
type=str,
default=DEFAULT_TYPE,
help='PyTorch tensor dtype')
parser.add_argument(
"--mem-factor",
type=float,
default=.4,
help='Proportion of max available GPU memory to use for single-size evals')
parser.add_argument("--debug",
action="store_true",
help='Enables alltoall debug prints')
return parser
#!/usr/bin/env python3
from benchmarks.communication.run_all import main
from benchmarks.communication.constants import *
from benchmarks.communication.utils import *
import argparse
import os
import sys
# Re-run this file with the deepspeed launcher. This is required since setuptools auto-detects
# python files and inserts a python shebang for both 'scripts' and 'entry_points', and these
# benchmarks require the DeepSpeed launcher.
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if not all(map(lambda v: v in os.environ, required_env)):
import subprocess
subprocess.run("deepspeed $(which ds_bench) " + " ".join(sys.argv[1:]), shell=True)
else:
args = benchmark_parser().parse_args()
rank = args.local_rank
main(args, rank)
......@@ -206,7 +206,7 @@ def allgather_fn(output_tensor: torch.Tensor,
group=group,
async_op=True)
else:
-if not has_warned_all_gather:
+if not has_warned_all_gather and get_rank() == 0:
utils.logger.warning(
"unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
......
......@@ -300,6 +300,7 @@ setup(name='deepspeed',
'bin/ds',
'bin/ds_ssh',
'bin/ds_report',
+'bin/ds_bench',
'bin/dsr',
'bin/ds_elastic'
],
......