未验证 提交 0977106a 编写于 作者: H hablb 提交者: GitHub

zero3 performance optimizations (#3622)

* Remove dead code

params_already_reduced is not used

* Prevent evaluation of debug strings

Debug strings are evaluated even when logging is disabled

* Use contiguous gradients tensor reduce scatter between ranks

Use allreduce instead of reduce scatter. lower cpu overhead.

* move overflow tracker to optimizer.step

Don't check overflow in gradients for every bucket.
Do overflow chack once on grad flat buffer just before optimizer step

---------
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 df425097
......@@ -11,6 +11,7 @@ from enum import Enum
import functools
import itertools
from typing import List
import logging
import torch
from torch import Tensor
from deepspeed import comm as dist
......@@ -898,7 +899,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# to debug correctness issues.
params = sorted(params, key=lambda p: p.ds_id)
debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")
if safe_mode:
# ensure that same list (with same ordering) of parameters are
......
......@@ -15,6 +15,7 @@ from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
from deepspeed.accelerator import get_accelerator
import logging
def debug_rank0(message: str) -> None:
......@@ -235,25 +236,28 @@ class PartitionedParameterCoordinator:
2. kick off fetch for next few parameters we will need later (prefetch)
3. block on parameters in immediately required sub module
"""
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))
params_to_fetch = frozenset(iter_params(current_submodule))
# kick off all gather for params in the immediately required submodule
for param in params_to_fetch:
debug_rank0(f"-fetch: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
for param in params_to_fetch:
debug_rank0(f"-fetch: {param.ds_summary()}")
self.__all_gather_params(params_to_fetch)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
debug_rank0(f"-wait: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
with get_accelerator().stream(self.__allgather_stream):
while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
......@@ -328,8 +332,9 @@ class PartitionedParameterCoordinator:
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
self.__all_gather_params(params_to_prefetch)
if self.__prefetch_nvme:
......@@ -394,7 +399,8 @@ class PartitionedParameterCoordinator:
@instrument_w_nvtx
def __release_param(self, param: Parameter) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
debug_rank0(f"-release: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
param.partition()
self.__n_available_params -= param.ds_numel
......
......@@ -261,7 +261,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if self.swap_optimizer:
self._configure_tensor_swapping(offload_optimizer_config, aio_config)
self.params_in_ipg_bucket = []
self.is_gradient_accumulation_boundary: bool = True
self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
......@@ -277,7 +276,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.params_already_reduced = []
self.is_gradient_accumulation_boundary = True
self._release_ipg_buffers()
self.previous_reduced_grads = None
......@@ -291,7 +289,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1
#Largest partitioned param
......@@ -307,7 +304,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if self.offload_optimizer:
self.norm_for_param_grads = {}
self.local_overflow = False
# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
......@@ -397,20 +393,20 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
dtype=self.dtype,
device=get_accelerator().current_device_name())
grad_partitions_flat_buffer = None
self.grad_partitions_flat_buffer = None
self.__param_id_to_grad_partition: Dict[int, Tensor] = {}
all_params = list(itertools.chain.from_iterable(self.fp16_groups))
grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
if self.offload_optimizer_pin_memory:
grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer)
self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)
offset = 0
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow(
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()
......@@ -966,11 +962,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.reduce_and_partition_stream.synchronize()
# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
#in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
#TODO: use a similar code path for both cpu_offload and non-cpu offload
if not self.offload_optimizer:
......@@ -1045,18 +1036,11 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a
# garbage data and `self.average_tensor()` will crash because its params_to_reduce will be
# empty, while reduction_list will have that garbage data.
if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel)
self.__reduce_and_partition_ipg_grads()
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
self.__add_grad_to_ipg_bucket(param)
@instrument_w_nvtx
......@@ -1087,8 +1071,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter "
f"gradients whose size is not same as the params")
self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)
assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket)
while self.param_reduce_events and self.param_reduce_events[0].query():
......@@ -1100,7 +1082,13 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if safe_mode:
assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket])
grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)
if self.contiguous_gradients and not self.reduce_scatter:
grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket)
grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket)
else:
self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)
grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)
self.partition_grads(self.params_in_ipg_bucket, grad_partitions)
self.params_in_ipg_bucket.clear()
......@@ -1109,6 +1097,47 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
event.record()
self.param_reduce_events.append(event)
@instrument_w_nvtx
def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]:
dtype = buffer_to_reduce.dtype
if self.communication_data_type == self.dtype:
buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type)
if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor)
world_sz = dist.get_world_size(self.dp_process_group)
rank = dist.get_rank(self.dp_process_group)
buffer_to_reduce.div_(world_sz)
dist.all_reduce(buffer_to_reduce, group=self.dp_process_group)
if self.postscale_gradients and self.gradient_predivide_factor != world_sz:
buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor)
if self.communication_data_type != self.dtype:
buffer_to_reduce = buffer_to_reduce.to(self.dtype)
grad_partitions = []
grad_offset_in_buffer = 0
for param in self.params_in_ipg_bucket:
grad = param.grad
chunk_sz = math.ceil(grad.numel() / world_sz)
start_offset = grad_offset_in_buffer + min(rank * chunk_sz, grad.numel())
end_offset = grad_offset_in_buffer + min(rank * chunk_sz + chunk_sz, grad.numel())
partition = buffer_to_reduce[start_offset:end_offset]
if param.partition_numel() != partition.numel():
padded_partition = torch.empty(param.partition_numel(), device=grad.device, dtype=grad.dtype)
if partition.numel() > 0:
padded_partition[:partition.numel()] = partition
grad_partitions.append(padded_partition)
else:
grad_partitions.append(partition)
grad_offset_in_buffer += grad.numel()
return grad_partitions
@instrument_w_nvtx
def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]:
"""average gradients and scatter partitions across ranks"""
......@@ -1223,15 +1252,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# operations and so it can be used asynchronously
grad_buffer = cuda_grad_buffer
if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(grad_buffer).any()
self.inf_or_nan_tracker += torch.isnan(grad_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0
# offload the gradient partition if applicable
if self.offload_optimizer:
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
......@@ -1567,7 +1587,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def log_timers(self, timer_names):
if self.timers is None:
......@@ -1901,12 +1920,19 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
def has_overflow(self, partition_gradients=True):
if partition_gradients:
with get_accelerator().stream(self.reduce_and_partition_stream):
self.local_overflow = bool(self.inf_or_nan_tracker.item())
if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0
overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
self.inf_or_nan_tracker.zero_()
overflow = self.local_overflow
#overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow])
get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream)
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
else:
......
......@@ -736,10 +736,11 @@ class TestZero3ParamPartitioningBase(DistributedTest):
@pytest.mark.parametrize("init_context_manager", [True, False])
@pytest.mark.parametrize("reduce_scatter", [True, False])
class TestZero3ParamPartitioningLargeParam(DistributedTest):
world_size = 4
def test(self, init_context_manager: bool, param_sz: int = 8100) -> None:
def test(self, init_context_manager: bool, reduce_scatter: bool, param_sz: int = 8100) -> None:
class LargeParamModel(Module):
......@@ -767,6 +768,7 @@ class TestZero3ParamPartitioningLargeParam(DistributedTest):
"stage3_max_reuse_distance": 0,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": reduce_scatter,
},
"optimizer": {
"type": "Adam",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册