Unverified commit a6317eb5, authored by Mayank Mishra, committed by GitHub

replace deprecated functions for communication (#2995)

* 💩 drop dead code

* replace has_all_gather_base with has_all_gather_into_tensor

* remove deprecated _all_gather_base

* remove deprecated _reduce_scatter_base

* 🎨 reformat files

* 🔧 fix _six

* Trigger CI

* Trigger CI

* Trigger CI

* 🎨 formatting

* incorporate suggestion

* incorporate suggestion

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 4b6d7c15
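At a glance, the commit swaps the deprecated private collectives (`_all_gather_base`, `_reduce_scatter_base`) for their public counterparts (`all_gather_into_tensor`, `reduce_scatter_tensor`) wherever DeepSpeed uses them. A minimal sketch of the feature-detection pattern this introduces (illustrative only; the actual implementation is the `TorchBackend` classmethods in the diff below):

```python
# Illustrative sketch of the feature detection added in this commit.
# Newer PyTorch exposes the fused collectives under public names, older
# releases only have the deprecated private ones, and very old releases
# have neither (callers then fall back to plain all_gather/reduce_scatter).
import torch.distributed as dist


def get_all_gather_function():
    if hasattr(dist, "all_gather_into_tensor"):
        return dist.all_gather_into_tensor
    elif hasattr(dist, "_all_gather_base"):
        return dist._all_gather_base
    return None


def get_reduce_scatter_function():
    if hasattr(dist, "reduce_scatter_tensor"):
        return dist.reduce_scatter_tensor
    elif hasattr(dist, "_reduce_scatter_base"):
        return dist._reduce_scatter_base
    return None
```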
......@@ -9,48 +9,36 @@ sys.path.append(COMMS_BENCH_DIR)
from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
from deepspeed.comm import TorchBackend
# Run all_gather and print metrics
def timed_all_gather(input, output, args):
if args.dist == 'torch':
import torch.distributed as dist
all_gather_func = TorchBackend.get_all_gather_function()
elif args.dist == 'deepspeed':
import deepspeed.comm as dist
all_gather_func = dist.allgather_fn
sync_all()
# Warmups, establish connections, etc.
for i in range(args.warmups):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
dist.all_gather(output_tensors, input_tensor, group=group, async_op=True)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
all_gather_func(output, input, group=None, async_op=args.async_op)
sync_all()
# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
dist.all_gather(output_tensors, input_tensor, group=group, async_op=True)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
all_gather_func(output, input, group=None, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre
# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('all_gather', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'
......@@ -103,9 +91,8 @@ def run_all_gather(local_rank, args):
sync_all()
timed_all_gather(input, output, args)
else:
# all_gather_base saves memory
if (args.dist == 'torch' and hasattr(torch.distributed, "_all_gather_base")) or (args.dist == 'deepspeed'
and dist.has_allgather_base):
# all_gather_into_tensor saves memory
if ((args.dist == 'torch' or args.dist == 'deepspeed') and dist.has_all_gather_into_tensor()):
mem_factor = args.mem_factor + 0.2
else:
mem_factor = args.mem_factor
......
......@@ -193,7 +193,7 @@ def is_available() -> bool:
return True
def set_backend(backend):
def set_backend(backend_name):
if not use_ds_backend:
utils.logger.error(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
......@@ -233,12 +233,11 @@ def all_gather(tensor_list,
return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def has_reduce_scatter_base():
def has_reduce_scatter_tensor():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
return cdb.has_reduce_scatter_base
return cdb.has_reduce_scatter_tensor()
def reduce_scatter_fn(output_tensor,
......@@ -251,18 +250,19 @@ def reduce_scatter_fn(output_tensor,
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_reduce_scatter_base:
return reduce_scatter_base(output_tensor,
tensor,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
if cdb.has_reduce_scatter_tensor():
return reduce_scatter_tensor(output_tensor,
tensor,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
else:
utils.logger.warning_once("unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
if get_rank() == 0:
utils.logger.warning_once("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
return reduce_scatter(output_tensor,
input_tensor_lst,
......@@ -274,51 +274,50 @@ def reduce_scatter_fn(output_tensor,
@timed_op
def reduce_scatter_base(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter_base',
debug=get_caller_func()):
def reduce_scatter_tensor(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter_tensor',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter_base(output_tensor=output_tensor,
input_tensor=tensor,
op=op,
group=group,
async_op=async_op)
return cdb.reduce_scatter_tensor(output_tensor=output_tensor,
input_tensor=tensor,
op=op,
group=group,
async_op=async_op)
@timed_op
def all_gather_base(output_tensor,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather_base',
debug=get_caller_func()):
def all_gather_into_tensor(output_tensor,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather_into_tensor',
debug=get_caller_func()):
global cdb
return cdb.all_gather_base(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
def has_allgather_base():
def has_all_gather_into_tensor():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
return cdb.has_allgather_base
return cdb.has_all_gather_into_tensor()
def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_allgather_base:
return all_gather_base(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
if cdb.has_all_gather_into_tensor():
return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
else:
if get_rank() == 0:
utils.logger.warning_once("unable to find torch.distributed._all_gather_base. will fall back to "
utils.logger.warning_once("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
......
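As a quick orientation for callers of the `allgather_fn` wrapper defined above: the fused op expects a single flat output buffer holding every rank's input back to back. A minimal usage sketch, assuming an already-initialized process group (tensor sizes are illustrative):

```python
# Minimal usage sketch for allgather_fn (illustrative sizes; assumes the
# DeepSpeed communication backend was initialized, e.g. via deepspeed.init_distributed()).
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

shard = torch.randn(1024, device=get_accelerator().device_name())  # this rank's contribution
flat_out = torch.empty(dist.get_world_size() * shard.numel(),
                       dtype=shard.dtype,
                       device=shard.device)

# Uses all_gather_into_tensor when the backend provides it, otherwise falls
# back to chunked all_gather as shown in the diff above.
handle = dist.allgather_fn(flat_out, shard, async_op=True)
if handle is not None:
    handle.wait()
```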
......@@ -20,8 +20,8 @@ class TorchBackend(Backend):
def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
super(TorchBackend, self).__init__()
self.torch_version_before_18 = older_torch()
self.has_allgather_base = has_allgather_base()
self.has_reduce_scatter_base = has_reduce_scatter_base()
self.all_gather_function = self.get_all_gather_function()
self.reduce_scatter_function = self.get_reduce_scatter_function()
self.initialized = True
self.name = name
# Future functionality to support ds.initialize() on a single GPU
......@@ -30,6 +30,28 @@ class TorchBackend(Backend):
self.single_gpu_mode = True
self.init_process_group(backend, timeout, init_method, rank, world_size)
@classmethod
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
elif hasattr(torch.distributed, "_all_gather_base"):
return torch.distributed._all_gather_base
return None
@classmethod
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
elif hasattr(torch.distributed, "_reduce_scatter_base"):
return torch.distributed._reduce_scatter_base
return None
def has_all_gather_into_tensor(self):
return self.all_gather_function is not None
def has_reduce_scatter_tensor(self):
return self.reduce_scatter_function is not None
def init_process_group(self, backend, timeout, init_method, rank, world_size):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend,
......@@ -59,27 +81,27 @@ class TorchBackend(Backend):
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_allgather_base:
return torch.distributed.distributed_c10d._all_gather_base(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
else:
utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
utils.logger.warning("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def reduce_scatter_base(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_base:
return torch.distributed._reduce_scatter_base(output_tensor,
input_tensor,
op=self._reduce_op(op),
group=group,
async_op=async_op)
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
input_tensor,
op=self._reduce_op(op),
group=group,
async_op=async_op)
else:
utils.logger.warning("unable to find torch.distributed._reduce_scatter_base. will fall back to "
utils.logger.warning("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
......
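The warning branches above defer to the list-based collectives. The equivalence they rely on can be sketched as follows (plain `torch.distributed`; the helper names are illustrative, not DeepSpeed API):

```python
# Illustrative equivalence behind the fallback paths (not DeepSpeed API).
# The fused ops operate on one flat tensor; the list-based ops produce the
# same result by working on chunk views of that flat tensor.
import torch
import torch.distributed as dist


def all_gather_into_tensor_fallback(flat_out, shard, group=None, async_op=False):
    # Writing into the chunk views fills flat_out in rank order.
    chunks = list(torch.chunk(flat_out, dist.get_world_size(group)))
    return dist.all_gather(chunks, shard, group=group, async_op=async_op)


def reduce_scatter_tensor_fallback(out, flat_in, op=dist.ReduceOp.SUM, group=None, async_op=False):
    shards = list(torch.chunk(flat_in, dist.get_world_size(group)))
    return dist.reduce_scatter(out, shards, op=op, group=group, async_op=async_op)
```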
......@@ -20,20 +20,6 @@ def older_torch():
return False
def has_allgather_base():
'''
Helper to check if torch.distributed has _all_gather_base
'''
return hasattr(torch.distributed, "_all_gather_base")
def has_reduce_scatter_base():
'''
Helper to check if torch.distributed has _reduce_scatter_base
'''
return hasattr(torch.distributed, "_reduce_scatter_base")
def get_local_rank_from_launcher():
# DeepSpeed launcher will set it so get from there
......
......@@ -19,9 +19,9 @@ import torch
from deepspeed import comm as dist
try:
from torch._six import inf as inf
from torch._six import inf
except ModuleNotFoundError:
from torch import inf as inf
from torch import inf
from deepspeed.utils import groups, logger
from deepspeed.runtime.constants import PIPE_REPLICATED
......
......@@ -672,11 +672,9 @@ class Init(InsertPostInitMethodToModuleSubClasses):
assert isinstance(module, torch.nn.Module)
self._convert_to_zero_parameters(module.parameters(recurse=True))
self.use_all_gather_base = False
if dist.has_allgather_base():
self.use_all_gather_base = True
else:
logger.info(f"_all_gather_base API is not available in torch {torch.__version__}")
self.use_all_gather_into_tensor = dist.has_all_gather_into_tensor()
if not self.use_all_gather_into_tensor:
logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}")
def _convert_to_zero_parameters(self, param_list):
for param in param_list:
......@@ -1124,12 +1122,12 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# param.ds_numel).view(param.ds_shape)
# param.data = replicated_tensor.data
# return None
if self.use_all_gather_base:
# try the _all_gather_base on PyTorch master branch
handle = dist.all_gather_base(flat_tensor,
param.ds_tensor.to(get_accelerator().device_name()),
group=self.ds_process_group,
async_op=async_op)
if self.use_all_gather_into_tensor:
# try the all_gather_into_tensor on PyTorch master branch
handle = dist.all_gather_into_tensor(flat_tensor,
param.ds_tensor.to(get_accelerator().device_name()),
group=self.ds_process_group,
async_op=async_op)
else:
partitions = []
for i in range(self.world_size):
......@@ -1172,12 +1170,12 @@ class Init(InsertPostInitMethodToModuleSubClasses):
for param_idx, param in enumerate(param_list):
input_tensor = local_tensors[param_idx].view(-1)
if self.use_all_gather_base:
# try the _all_gather_base from Pytorch master
h = dist.all_gather_base(allgather_params[param_idx],
input_tensor,
group=self.ds_process_group,
async_op=True)
if self.use_all_gather_into_tensor:
# try the all_gather_into_tensor from Pytorch master
h = dist.all_gather_into_tensor(allgather_params[param_idx],
input_tensor,
group=self.ds_process_group,
async_op=True)
else:
output_list = []
for i in range(self.world_size):
......
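The ZeRO-3 path above gathers each parameter's 1-D partition from every rank into one flat buffer. A minimal sketch of the shapes involved (names such as `gather_param_sketch` are illustrative; the actual implementation also handles details like padding that this sketch ignores):

```python
# Shape sketch for the fused gather used in partition_parameters.py.
# Each rank owns a flat partition of the parameter; the gathered buffer has
# world_size * partition.numel() elements and is later viewed as param.ds_shape.
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def gather_param_sketch(partition: torch.Tensor, world_size: int, group=None):
    partition = partition.to(get_accelerator().device_name()).view(-1)
    flat = torch.empty(world_size * partition.numel(),
                       dtype=partition.dtype,
                       device=partition.device)
    if dist.has_all_gather_into_tensor():
        handle = dist.all_gather_into_tensor(flat, partition, group=group, async_op=True)
    else:
        chunks = list(torch.chunk(flat, world_size))
        handle = dist.all_gather(chunks, partition, group=group, async_op=True)
    return flat, handle
```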
......@@ -31,7 +31,7 @@ def calc_bw_log(comm_op, size, duration):
if comm_op == "all_to_all_single":
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
elif comm_op == "all_gather" or comm_op == "all_gather_base" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_base":
elif comm_op == "all_gather" or comm_op == "all_gather_into_tensor" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_tensor":
size *= n
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
......
......@@ -46,9 +46,9 @@ There are currently two ways to view communication log records:
If the `enabled` configuration option is selected, all communication operations will be immediately printed to the console. This mode is intended for detailed debugging, and is not recommended for most users. The following is an example snippet of `verbose` output:
```
[2022-06-26 01:39:55,722] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: reduce_scatter_base | time (ms): 9.46 | msg size: 678.86 MB | algbw (Gbps): 1204.52 | busbw (Gbps): 1129.23
[2022-06-26 01:39:56,470] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_base | time (ms): 0.11 | msg size: 6.0 MB | algbw (Gbps): 954.41 | busbw (Gbps): 894.76
[2022-06-26 01:39:56,471] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_base | time (ms): 0.08 | msg size: 6.0 MB | algbw (Gbps): 1293.47 | busbw (Gbps): 1212.63
[2022-06-26 01:39:55,722] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: reduce_scatter_tensor | time (ms): 9.46 | msg size: 678.86 MB | algbw (Gbps): 1204.52 | busbw (Gbps): 1129.23
[2022-06-26 01:39:56,470] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_into_tensor | time (ms): 0.11 | msg size: 6.0 MB | algbw (Gbps): 954.41 | busbw (Gbps): 894.76
[2022-06-26 01:39:56,471] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_into_tensor | time (ms): 0.08 | msg size: 6.0 MB | algbw (Gbps): 1293.47 | busbw (Gbps): 1212.63
```
For advanced users, the `debug` option will append the calling function of each communication operation to that operation's `log_name`. See [Log Summaries](#log-summaries) for an example of a `deepspeed.comm.log_summary()` call with `debug` enabled.
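For reference, these options are toggled through the DeepSpeed config. A minimal sketch, assuming the standard `comms_logger` block and showing only the keys discussed in this section:

```python
# Hedged sketch of the relevant DeepSpeed config section (only the options
# mentioned above are shown; other comms_logger keys are omitted).
ds_config = {
    "comms_logger": {
        "enabled": True,   # turn communication logging on
        "verbose": False,  # True: print every communication op immediately (see output above)
        "debug": False,    # True: append the calling function to each op's log_name
    }
}
```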
......@@ -99,7 +99,7 @@ Comm. Op Message Size Count Total Latency(ms)
broadcast
2.0 KB 146 11.12 0.08 0.43 0.41
98.25 MB 1 8317.12 8317.12 0.20 0.19
reduce_scatter_base
reduce_scatter_tensor
678.86 MB 40 602.29 9.69 1468.06 1376.31
```
......@@ -111,6 +111,6 @@ Comm. Op Message Size Count Total Latency(ms)
broadcast | [Caller Func: _broadcast_model]
2.0 KB 146 9.39 0.06 0.52 0.48
98.25 MB 1 8540.60 8540.60 0.19 0.18
reduce_scatter_base | [Caller Func: reduce_scatter_fn]
reduce_scatter_tensor | [Caller Func: reduce_scatter_fn]
678.86 MB 80 1527.17 13.94 1211.75 1136.01
```