未验证 提交 9377921a 编写于 作者: H Heyang Qin 提交者: GitHub

Separate ZeRO3 InflightParamRegistry for train and eval (#3884)

* create standalone registries for training and eval respectively
---------
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
上级 db4638d1
......@@ -250,7 +250,10 @@ class DeepSpeedZeRoOffload(object):
self.__allgather_stream = get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream()
if not hasattr(module, "ds_inflight_param_registry"):
module.ds_inflight_param_registry = InflightParamRegistry()
module.ds_inflight_param_registry = dict()
# we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
module.ds_inflight_param_registry[True] = InflightParamRegistry()
module.ds_inflight_param_registry[False] = InflightParamRegistry()
self.__inflight_param_registry = module.ds_inflight_param_registry
self.forward_hooks = []
......@@ -279,7 +282,7 @@ class DeepSpeedZeRoOffload(object):
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
max_available_parameters_in_numel=self._max_available_parameters_in_numel,
allgather_stream=self.__allgather_stream,
inflight_param_registry=self.__inflight_param_registry,
inflight_param_registry=self.__inflight_param_registry[training],
prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
timers=self.timers,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册