Unverified commit 6b49b60e, authored by Olatunji Ruwase, committed by GitHub

Get correct fp16 reuse buffer size (#1071)

Parent 29b444b6
@@ -1075,6 +1075,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
     def _create_fp16_partitions_with_defragmentation(self):
         dist.barrier()
         partition_id = dist.get_rank(group=self.dp_process_group)
+        create_fp16_flat_reuse_buffer = False

         #create a flat CPU memory allocation for each param group
         if self.offload_param:
@@ -1168,16 +1169,21 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
-                see_memory_usage(f"After Flattening param group {i}", force=False)
                 #create a pinned memory to be used for swapping out params to NVME after optimizer step
-                if self.fp16_partitioned_groups_flat[
-                        -1] is None and self.param_group_fp16_flat_reuse_buffer is None:
-                    self.param_group_fp16_flat_reuse_buffer = torch.empty(
-                        max(self.fp16_partitioned_groups_flat_numel),
-                        dtype=self.dtype,
-                        device='cpu',
-                        pin_memory=True)
+                if self.fp16_partitioned_groups_flat[-1] is None:
+                    create_fp16_flat_reuse_buffer = True
+                see_memory_usage(f"After Flattening param subgroup {i}", force=False)
+
+        if create_fp16_flat_reuse_buffer:
+            assert self.param_group_fp16_flat_reuse_buffer is None, \
+                f'Unexpected that pinned memory for swapping params out to NVMe is already created'
+            self.param_group_fp16_flat_reuse_buffer = torch.empty(max(
+                self.fp16_partitioned_groups_flat_numel),
+                dtype=self.dtype,
+                device='cpu',
+                pin_memory=True)

     def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
         offset = 0
         elements_in_sub_group = sum(
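The change above defers allocation of the pinned fp16 reuse buffer until after every parameter sub-group has been flattened, so that max(self.fp16_partitioned_groups_flat_numel) is taken over all sub-groups rather than only the ones seen when the first allocation was triggered. Below is a minimal standalone sketch of that sizing logic, not the DeepSpeed implementation; the sub-group sizes are made-up illustrative values, and pinned allocation assumes a CUDA-enabled PyTorch build.

# Illustrative sketch only, not the DeepSpeed implementation.
import torch

# Hypothetical flattened sub-group sizes (numels); in stage3 these are
# accumulated in self.fp16_partitioned_groups_flat_numel as sub-groups are built.
fp16_partitioned_groups_flat_numel = [1_000_000, 4_000_000, 2_500_000]

# Allocating only after all sizes are known yields a buffer large enough for the
# biggest sub-group; allocating inside the loop after the first sub-group would
# pin just 1_000_000 elements, too small for later swaps.
reuse_buffer = torch.empty(max(fp16_partitioned_groups_flat_numel),
                           dtype=torch.float16,
                           device='cpu',
                           pin_memory=True)

print(reuse_buffer.numel())  # 4000000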