Unverified commit 673cb608, authored by Olatunji Ruwase, committed by GitHub

Improve z3 trace management (#1916)

* Fix OOM and type mismatch

* Toggle prefetching

* Disable z3 prefetching for inference (temp workaround)

* Fix zero3 tracing issues

* Remove debug prints

* Enable prefetch for inference

* Code clarity

* Invalidate trace cache

* Trace cache invalidation when needed
Separate nvme prefetch from all-gather prefetch

* Track last used step id

* Use debug name in error message

* Construct param trace from module trace
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent a3b90030
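The hunks below all concern the ZeRO-3 parameter coordinator. As a rough mental model of the trace management this commit reworks, the coordinator moves between three trace modes. The following is a minimal standalone sketch with simplified names and logic, not code from this commit:

# Illustrative sketch only -- a simplified model of the coordinator's trace modes,
# not code shipped by this commit.
from enum import Enum

class ZeRoTraceMode(Enum):
    RECORD = 1    # record submodule order during one forward(+backward) pass
    COMPLETE = 2  # a recorded trace exists and drives prefetching
    INVALID = 3   # execution deviated from the trace; cache was discarded

class CoordinatorSketch:
    def __init__(self):
        self.mode = ZeRoTraceMode.RECORD
        self.submodule_order = []
        self.step_id = 0

    def trace_prologue(self, sub_module):
        # With a COMPLETE trace, every submodule must arrive in the recorded
        # order; any mismatch invalidates the cached trace.
        if self.mode == ZeRoTraceMode.COMPLETE:
            if (self.step_id >= len(self.submodule_order)
                    or sub_module is not self.submodule_order[self.step_id]):
                self.mode = ZeRoTraceMode.INVALID
                self.submodule_order = []

    def record_module(self, sub_module):
        if self.mode == ZeRoTraceMode.RECORD:
            self.submodule_order.append(sub_module)
        self.step_id += 1

    def reset_step(self):
        # End of an iteration: a successful recording becomes COMPLETE,
        # an invalidated trace re-enables recording for the next pass.
        if self.mode == ZeRoTraceMode.RECORD:
            self.mode = ZeRoTraceMode.COMPLETE
        elif self.mode == ZeRoTraceMode.INVALID:
            self.mode = ZeRoTraceMode.RECORD
        self.step_id = 0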
@@ -27,7 +27,14 @@ import deepspeed
from ..utils import get_only_unique_item, see_memory_usage
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
from deepspeed.utils import init_distributed, instrument_w_nvtx, logger
from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file
from deepspeed.utils.debug import (debug_param2name_id_shape,
debug_param2name_id_shape_device,
debug_module2name,
debug_param2name,
debug_param2name_id,
debug_param2name_id_shape_status,
printflock,
log_rank_file)
from deepspeed.utils.logging import logger
from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
@@ -937,9 +944,9 @@ class Init(InsertPostInitMethodToModuleSubClasses):
param.all_gather()
return param._orig_item()
def ds_summary(slf: torch.Tensor) -> dict:
def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
return {
"id": slf.ds_id,
"id": debug_param2name_id(slf) if use_debug_name else slf.ds_id,
"status": slf.ds_status.name,
"numel": slf.numel(),
"ds_numel": slf.ds_numel,
......
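As a usage note for the ds_summary change above: use_debug_name only affects how the "id" field is rendered, so error messages can show the parameter's registered debug name instead of the bare ds_id. A hypothetical call (param being any ZeRO-3 partitioned parameter) might look like:

# Hypothetical usage of the extended signature shown above; `param` is assumed
# to be a ZeRO-3 partitioned parameter. The exact string returned for "id"
# depends on debug_param2name_id.
summary = param.ds_summary(use_debug_name=True)
print(summary["id"], summary["status"], summary["ds_numel"])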
@@ -16,6 +16,7 @@ from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.offload_constants import *
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
def debug_rank0(message: str) -> None:
@@ -33,7 +34,7 @@ def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
return map(lambda pair: pair[1], get_all_parameters(module, recurse))
class TraceMode(Enum):
class ZeRoTraceMode(Enum):
# Record trace of the network during a single forward+backward (for training) or forward (for inference)
RECORD = 1
# Use recorded network trace to optimize current forward+backward or forward
@@ -75,12 +76,14 @@ class PartitionedParameterCoordinator:
# keeps track of the number of submodules invoked so far.
self.__step_id: int = 0
# network tracing mode
self.__trace_mode: TraceMode = TraceMode.RECORD
self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD
# sequence of submodules/parameters in forward pass + backward pass
self.__submodule_order: Iterable[Module] = []
self.__param_order: Iterable[__class__.__ParamInTrace] = []
self.__most_recent_step_id_param_fetched_for = collections.defaultdict(
lambda: int(-1e10))
self.__step_id_module_fetched_for = collections.defaultdict(
lambda: collections.deque())
# number of available params, and max number of available params
self.__n_available_params: int = 0
self.__max_n_available_params: int = max_available_parameters_in_numel
@@ -126,24 +129,29 @@ class PartitionedParameterCoordinator:
self.__param_queue = None
def is_complete_trace(self) -> bool:
return self.__trace_mode == TraceMode.COMPLETE
return self.__trace_mode == ZeRoTraceMode.COMPLETE
def is_invalid_trace(self) -> bool:
return self.__trace_mode == TraceMode.INVALID
return self.__trace_mode == ZeRoTraceMode.INVALID
def is_record_trace(self) -> bool:
return self.__trace_mode == TraceMode.RECORD
return self.__trace_mode == ZeRoTraceMode.RECORD
def _invalidate_trace(self) -> None:
if self.is_invalid_trace():
raise RuntimeError("attempted to invalidate already invalid trace")
self.__trace_mode = TraceMode.INVALID
self.__trace_mode = ZeRoTraceMode.INVALID
self._clear_trace_structures()
def trace_prologue(self, sub_module: Module) -> None:
if self.is_complete_trace():
# sub_module must match expectation else invalidate trace cache
if sub_module != self.__submodule_order[self.__step_id]:
expected_module_id = self.__submodule_order[self.__step_id].id
debug_rank0(
f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.id}"
)
self._invalidate_trace()
def record_module(self, sub_module: Module) -> None:
@@ -151,17 +159,27 @@ class PartitionedParameterCoordinator:
if not self.is_record_trace():
raise RuntimeError(
f"attempted to record trace when status = {self.__trace_mode}")
self.__submodule_order.append(sub_module)
self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
def record_parameters(self, sub_module: Module) -> None:
"""adds sub module to trace"""
if not self.is_record_trace():
raise RuntimeError(
f"attempted to record trace when status = {self.__trace_mode}")
step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
self.__param_order.append(
__class__.__ParamInTrace(param=param,
step_id_last_used_at=self.__step_id))
step_id_last_used_at=step_id))
def construct_parameter_trace_from_module_trace(self):
"""use module trace to construct parameter trace"""
self.__param_order = []
for sub_module in self.__submodule_order:
self.record_parameters(sub_module)
def reset_step(self) -> None:
"""indicate that we have completed one fwd+bwd for the model"""
@@ -180,22 +198,38 @@ class PartitionedParameterCoordinator:
if self.is_record_trace():
# Successfully recorded a trace
self.construct_parameter_trace_from_module_trace()
self.__submodule_order = tuple(self.__submodule_order) # freeze
self.__param_order = tuple(self.__param_order) # freeze
self.__trace_mode = TraceMode.COMPLETE # self.trace_complete = True
self.__trace_mode = ZeRoTraceMode.COMPLETE
print_rank_0(
f"completed trace: {[m.id for m in self.__submodule_order]}",
f"completed record trace: {[m.id for m in self.__submodule_order]}",
force=False)
else:
# Enable trace recording for next forward/backward pass
self.__trace_mode = TraceMode.RECORD
self.__trace_mode = ZeRoTraceMode.RECORD
self.__param_queue = collections.deque(self.__param_order) # reset fetch queue
self.__most_recent_step_id_param_fetched_for = collections.defaultdict(
lambda: int(-1e10))
self.__step_id_module_fetched_for = collections.defaultdict(
lambda: collections.deque())
self.__step_id = 0
self.__n_available_params = 0
def _dump_params(self, tag, sub_module, params, step_id=None):
if step_id is None:
step_id = self.__step_id
param_names = [debug_param2name_id(p) for p in params]
print(
f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}'
)
def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
if step_id is None:
step_id = self.__step_id
print(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}')
"""Fetch and Release
Fetching, prefetching, and releasing parameters
"""
@@ -264,15 +298,23 @@ class PartitionedParameterCoordinator:
self.__most_recent_step_id_param_fetched_for[
param_in_trace.param] = param_in_trace.step_id_last_used_at
discarded_from_prefetch_queue.add(param_in_trace.param)
if discarded_from_prefetch_queue != params_not_already_fetched:
raise RuntimeError(
f"tracing error at step {self.__step_id}: \n"
f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
f"expected the next {len(params_not_already_fetched)} parameters in the "
f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} \n"
f"but got \n {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}."
f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}."
)
def _is_currently_on_nvme(param):
if param.nvme_swapper is None:
return False
return param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE \
and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE
# kick off all gather for params in the next few submodules (prefetch)
if self.__prefetch_bucket_sz > 0:
max_params_to_prefetch = min(
@@ -283,11 +325,25 @@ class PartitionedParameterCoordinator:
while self.__param_queue and numel_prefetching < max_params_to_prefetch:
param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft(
)
self.__most_recent_step_id_param_fetched_for[
param_in_trace.param] = param_in_trace.step_id_last_used_at
if param_in_trace.param not in params_to_prefetch:
if _is_currently_on_nvme(param_in_trace.param):
# nvme prefetch is handled elsewhere. Need to break here to preserve fetch order
self.__param_queue.appendleft(param_in_trace)
break
do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
if param_in_trace.param in params_to_prefetch:
# Avoid duplicates
do_prefetch = False
self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \
max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param],
param_in_trace.step_id_last_used_at)
if do_prefetch:
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
self.__all_gather_params(params_to_prefetch)
......
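A rough paraphrase of the prefetch-queue walk in the hunk above (the predicates and the last_used mapping are stand-ins for DeepSpeed internals): a parameter still resident on NVMe stops the walk so that fetch order is preserved and NVMe prefetching is handled separately, while parameters that are already available or already queued only update their last-used step id.

# Simplified sketch of the prefetch selection loop above.
def select_params_to_prefetch(queue, on_nvme, is_available, max_numel, last_used):
    # queue: collections.deque of (param, step_id_last_used_at) from the trace
    # last_used: defaultdict(lambda: -10**10) mapping param -> latest trace step id
    to_prefetch, numel = set(), 0
    while queue and numel < max_numel:
        param, step_id = queue.popleft()
        if on_nvme(param):
            # NVMe prefetch is handled elsewhere; stop so the all-gather
            # prefetch order stays aligned with the recorded trace.
            queue.appendleft((param, step_id))
            break
        do_prefetch = not is_available(param) and param not in to_prefetch
        last_used[param] = max(last_used[param], step_id)
        if do_prefetch:
            to_prefetch.add(param)
            numel += param.ds_numel
    return to_prefetch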
@@ -1222,8 +1222,6 @@ class DeepSpeedZeroOptimizer_Stage3(object):
force=False)
param_coordinator = self._get_param_coordinator(training=sub_module.training)
if param_coordinator.is_record_trace():
param_coordinator.record_parameters(sub_module)
param_coordinator.release_sub_module(sub_module)
see_memory_usage(
@@ -1236,7 +1234,6 @@ class DeepSpeedZeroOptimizer_Stage3(object):
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
param_coordinator.record_parameters(sub_module)
param_coordinator.fetch_sub_module(sub_module)
@torch.no_grad()
......
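Reading the two stage3 hunks above together: record_parameters is no longer called from the forward hooks at all. The pre-forward hook records only the module order, and the parameter trace is reconstructed from it in reset_step. In simplified form (a paraphrase, not the literal file contents):

def pre_sub_module_forward(param_coordinator, sub_module):
    param_coordinator.trace_prologue(sub_module)   # validate against any cached trace
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)  # module order + step id only
    param_coordinator.fetch_sub_module(sub_module)

def post_sub_module_forward(param_coordinator, sub_module):
    # parameters are no longer recorded here; releasing is all that remains
    param_coordinator.release_sub_module(sub_module)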