Unverified · Commit a855405e authored by Sam Ade Jacobs, committed by GitHub

DeepSpeed Ulysses release (#4198)

Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com>
Parent 6df15873
......@@ -346,6 +346,12 @@ def all_to_all_single(output,
                                 async_op=async_op)


+@timed_op
+def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
+    global cdb
+    return cdb.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)


@timed_op
def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
    global cdb
......
......@@ -273,6 +273,9 @@ class TorchBackend(Backend):
                                                  group=group,
                                                  async_op=async_op)

+    def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
+        return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

    def send(self, tensor, dst, group=None, tag=0):
        return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)
......
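For orientation, here is a minimal usage sketch of the new collective wrapper (an illustration, not part of this commit); it assumes one process per GPU and that deepspeed.comm has been initialized via deepspeed.init_distributed():

```python
# Minimal sketch (not from this diff) of deepspeed.comm.all_to_all, which forwards
# to torch.distributed.all_to_all through the TorchBackend method shown above.
import torch
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device('cuda', rank % torch.cuda.device_count())  # one process per GPU assumed

# Each rank prepares one chunk per destination rank; chunk i is sent to rank i.
input_tensor_list = [torch.full((4,), float(rank), device=device) for _ in range(world_size)]
output_tensor_list = [torch.empty(4, device=device) for _ in range(world_size)]

# After the call, output_tensor_list[i] holds the chunk that rank i sent to this rank.
dist.all_to_all(output_tensor_list, input_tensor_list)
```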
......@@ -906,9 +906,7 @@ class DeepSpeedEngine(Module):
logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
self.checkpoint_engine = TorchCheckpointEngine()
dp_rank = self.global_rank
if self.mpu:
dp_rank = self.mpu.get_data_parallel_rank()
dp_rank = groups._get_sequence_data_parallel_rank()
rank = self.local_rank if self.use_node_local_storage() else dp_rank
......@@ -1040,7 +1038,7 @@ class DeepSpeedEngine(Module):
                                   group=self.expert_data_parallel_group[p.group_name])
            else:
                if torch.is_tensor(p) and is_replicated(p):
-                    dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.data_parallel_group)
+                    dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group)

    @staticmethod
    def __check_params(model: Module, dtype: torch.dtype) -> None:
......@@ -1110,6 +1108,8 @@ class DeepSpeedEngine(Module):
        self.local_all_to_all_group = groups._get_local_all_to_all_group()
        self.data_parallel_group = groups._get_data_parallel_group()
        self.dp_world_size = groups._get_data_parallel_world_size()
+        self.seq_data_parallel_group = groups._get_sequence_data_parallel_group()
+        self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size()
        self.mp_world_size = groups._get_model_parallel_world_size()
        self.expert_parallel_group = groups._get_expert_parallel_group_dict()
        self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
......@@ -1406,7 +1406,7 @@ class DeepSpeedEngine(Module):
                mpu=self.mpu,
                clip_grad=clip_grad,
                allgather_bucket_size=self.zero_allgather_bucket_size(),
-                dp_process_group=self.data_parallel_group,
+                dp_process_group=self.seq_data_parallel_group,
                timers=timers)

        return optimizer
......@@ -1457,7 +1457,7 @@ class DeepSpeedEngine(Module):
                contiguous_gradients=contiguous_gradients,
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                allgather_bucket_size=self.zero_allgather_bucket_size(),
-                dp_process_group=self.data_parallel_group,
+                dp_process_group=self.seq_data_parallel_group,
                expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
                expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
                reduce_scatter=self.zero_reduce_scatter(),
......@@ -1527,7 +1527,7 @@ class DeepSpeedEngine(Module):
                max_live_parameters=self.zero_max_live_parameters(),
                param_persistence_threshold=self.zero_param_persistence_threshold(),
                model_persistence_threshold=self.zero_model_persistence_threshold(),
-                dp_process_group=self.data_parallel_group,
+                dp_process_group=self.seq_data_parallel_group,
                all2all_process_group=self.local_all_to_all_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
......@@ -1569,7 +1569,7 @@ class DeepSpeedEngine(Module):
                max_live_parameters=self.zero_max_live_parameters(),
                param_persistence_threshold=self.zero_param_persistence_threshold(),
                model_persistence_threshold=self.zero_model_persistence_threshold(),
-                dp_process_group=self.data_parallel_group,
+                dp_process_group=self.seq_data_parallel_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
                offload_optimizer_config=self.zero_offload_optimizer(),
......@@ -2837,10 +2837,10 @@ class DeepSpeedEngine(Module):
            zero_sd_list = None
            checkpoint_folder = f'{os.path.join(load_dir, tag)}'
        else:
-            if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size:
+            if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size:
                raise ZeRORuntimeException("The checkpoint being loaded used a DP " \
                    f"world size of {self.loaded_checkpoint_dp_world_size} but the " \
-                    f"current world size is {self.dp_world_size}. Automatic adjustment " \
+                    f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \
                    "of ZeRO's optimizer state partitioning with a new world size is not " \
                    "currently supported.")
            checkpoint_folder = None
......@@ -3191,7 +3191,7 @@ class DeepSpeedEngine(Module):
            skipped_steps=self.skipped_steps,
            global_steps=self.global_steps,
            global_samples=self.global_samples,
-            dp_world_size=self.dp_world_size,
+            dp_world_size=self.seq_dp_world_size,
            mp_world_size=self.mp_world_size,
            ds_config=self.config,
            ds_version=version)
......
......@@ -713,22 +713,21 @@ class Init(InsertPostInitMethodToModuleSubClasses):
    apply_param_persistence = False
    override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")

-    def __init__(
-        self,
-        module=None,
-        data_parallel_group=None,
-        mem_efficient_linear=True,
-        remote_device=None,
-        pin_memory=False,
-        config_dict_or_path=None,
-        config=None,
-        enabled=True,
-        dtype=None,
-        mpu=None,
-        zero_param_parallel_group=None,
-        zero_quantized_weights=False,
-        zero_quantized_nontrainable_weights=False,
-    ):
+    def __init__(self,
+                 module=None,
+                 data_parallel_group=None,
+                 mem_efficient_linear=True,
+                 remote_device=None,
+                 pin_memory=False,
+                 config_dict_or_path=None,
+                 config=None,
+                 enabled=True,
+                 dtype=None,
+                 mpu=None,
+                 zero_param_parallel_group=None,
+                 zero_quantized_weights=False,
+                 zero_quantized_nontrainable_weights=False,
+                 sequence_data_parallel_group=None):
        """A context to enable massive model construction for training with
        ZeRO-3. Models are automatically partitioned (or, sharded) across the
        system and converted to half precision.
......@@ -833,10 +832,17 @@ class Init(InsertPostInitMethodToModuleSubClasses):

        if not dist.is_initialized():
            init_distributed()

        assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"

-        if data_parallel_group is None:
+        if data_parallel_group is None and sequence_data_parallel_group is None:
            self.ds_process_group = dist.get_world_group()
-        else:
+        elif sequence_data_parallel_group is not None:
+            self.ds_process_group = sequence_data_parallel_group
+        elif data_parallel_group is not None:
            self.ds_process_group = data_parallel_group
+        else:  # both given
+            raise ValueError(
+                "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments."
+            )

        self.rank = dist.get_rank(group=self.ds_process_group)
        self.dp_world_size = dist.get_world_size(group=self.ds_process_group)
......
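As a usage illustration (a hedged sketch, not code from this diff), the new sequence_data_parallel_group argument would be forwarded into ZeRO-3 partitioning roughly as follows; `mpu`, `ds_config`, and `MyTransformer` are placeholders for the user's own process-group utility, DeepSpeed config, and model class:

```python
# Hypothetical sketch: sharding a model with ZeRO-3 across the combined
# data x sequence parallel group. `mpu`, `ds_config`, and `MyTransformer`
# are assumptions, not part of this commit.
import torch
import deepspeed

with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(),
                         config_dict_or_path=ds_config,
                         dtype=torch.half):
    model = MyTransformer()  # parameters are partitioned across the group as they are constructed
```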
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch

from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module

import deepspeed.comm as dist


class _SeqAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:

        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        seq_world_size = dist.get_world_size(group)

        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
        output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
        # TODO Use all_to_all_single instead
        dist.all_to_all(output_list, input_list, group=group)
        return torch.cat(output_list, dim=gather_idx).contiguous()

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
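To make the data movement concrete, here is a single-process sketch (an illustration, not part of the commit) of the shape transformation that the forward all-to-all produces with scatter_idx=2 and gather_idx=0; the values differ from a real distributed exchange, but the shapes match:

```python
import torch

P = 4                        # pretend sequence-parallel world size
s_local, b, h = 128, 2, 64   # local sequence length, batch, hidden
x = torch.randn(s_local, b, h)                                      # one rank's input: [s/P, b, h]

chunks = [t.contiguous() for t in torch.tensor_split(x, P, dim=2)]  # the P chunks a rank would send
out = torch.cat(chunks, dim=0)                                      # shape of what a rank receives
print(out.shape)  # torch.Size([512, 2, 16]) -> full sequence, 1/P of the hidden/heads
```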
class DistributedAttention(torch.nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:

        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
        """ forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # TODO Merge three alltoall calls into one
        # in shape : e.g., [s/p:h:]
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # out shape : e.g., [s:h/p:]
        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)

        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out e.g., [s/p::h]
        return output
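A hedged usage sketch of the new layer (not from this diff): `local_attention`, `sp_group`, and the query/key/value tensors are placeholders, and the import path deepspeed.sequence.layer is assumed for this release.

```python
# Hypothetical usage: wrapping an existing local attention module so it runs
# under sequence parallelism. `local_attention` and `sp_group` are assumptions.
from deepspeed.sequence.layer import DistributedAttention

dist_attn = DistributedAttention(local_attention, sp_group)  # defaults: scatter_idx=2, gather_idx=0

# query/key/value arrive sharded along the sequence dimension, e.g. [s/p, ..., h];
# the returned context tensor is sharded the same way.
context = dist_attn(query, key, value)
```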
......@@ -361,7 +361,7 @@ def _get_data_parallel_group():
def _get_broadcast_src_rank():
-    return dist.get_global_rank(_get_data_parallel_group(), 0)
+    return dist.get_global_rank(_get_sequence_data_parallel_group(), 0)
def _get_expert_broadcast_src_rank(group_name):
......@@ -414,12 +414,57 @@ def _get_model_parallel_world_size():
def _get_data_parallel_rank():
    """Return my rank for the data parallel group."""
-    global mpu
-    if mpu is not None:
-        return mpu.get_data_parallel_rank()
    return dist.get_rank(group=_get_data_parallel_group())


+def _get_sequence_parallel_world_size():
+    """Return world size for the sequence parallel group."""
+    global mpu
+    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
+        return mpu.get_sequence_parallel_world_size()
+    return 1


+def _get_sequence_parallel_rank():
+    """Return my rank for the sequence parallel group."""
+    global mpu
+    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
+        return mpu.get_sequence_parallel_rank()
+    return 0


+def _get_sequence_parallel_group():
+    global mpu
+    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
+        return mpu.get_sequence_parallel_group()
+    return None


+def _get_sequence_data_parallel_world_size():
+    """Return world size for the combined sequence and data parallel group."""
+    global mpu
+    if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_world_size'):
+        return mpu.get_sequence_data_parallel_world_size()
+    return _get_data_parallel_world_size()


+def _get_sequence_data_parallel_rank():
+    """Return my rank for the combined sequence and data parallel group."""
+    global mpu
+    if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_rank'):
+        return mpu.get_sequence_data_parallel_rank()
+    return _get_data_parallel_rank()


+def _get_sequence_data_parallel_group():
+    global mpu
+    # When sequence parallelism is enabled, the process group for zero sharding and
+    # gradient allreduce must be across both dimensions of data and sequence parallelism.
+    if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_group'):
+        return mpu.get_sequence_data_parallel_group()
+    return _get_data_parallel_group()


def _get_expert_model_parallel_world_size():
    global expert_tensor_parallel_world_size
    return expert_tensor_parallel_world_size
......
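The hasattr() probes above define an informal, duck-typed interface on the user-supplied mpu object: when the sequence hooks are present their groups are used, and when they are absent the helpers fall back to the plain data-parallel group, so behavior without sequence parallelism is unchanged. A hedged sketch of such an object (purely illustrative, not part of DeepSpeed):

```python
# Illustrative only: a minimal mpu-style object exposing the hooks that the helpers
# above probe for. A real mpu would also provide the standard model/data parallel
# accessors and the matching *_rank / *_world_size methods; the process groups
# themselves would be created elsewhere (e.g. with torch.distributed.new_group).
class SequenceParallelMPU:

    def __init__(self, sp_group, dp_group, sp_dp_group):
        self._sp_group = sp_group        # ranks that jointly hold one sequence
        self._dp_group = dp_group        # plain data-parallel replicas
        self._sp_dp_group = sp_dp_group  # union of both dimensions, used for ZeRO sharding

    def get_sequence_parallel_group(self):
        return self._sp_group

    def get_data_parallel_group(self):
        return self._dp_group

    def get_sequence_data_parallel_group(self):
        return self._sp_dp_group
```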