未验证 提交 53182531 编写于 作者: O Olatunji Ruwase 提交者: GitHub

Refactor universal checkpointing and tensor fragments (#2253)

* Refactor universal checkpointing and tensor fragments

* Formatting
上级 47e030f5
......@@ -11,3 +11,5 @@ from .reshape_utils import (merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)
from .zero_checkpoint import ZeROCheckpoint
from .universal_checkpoint import enable_universal_checkpoint
......@@ -21,8 +21,11 @@ FP32_WEIGHT_KEY = "fp32"
#########################################
# Module checkpoint keys
#########################################
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
CAT_DIM = "cat_dim"
#########################################
# Checkpoint naming constants
......
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY,
PARAM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
CAT_DIM)
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    """Load this parameter's high-precision weight and optimizer state from ``folder``.

    Intended to be bound to a low-precision parameter as a method. For the
    fp32 weight and for every optimizer-state key reported by
    ``self._hp_mapping``, the file ``<folder>/<key>.pt`` is loaded, the slice
    for ``tp_rank`` is extracted, and the portion addressed by
    ``lp_fragment_address`` is copied into the matching destination fragment.

    Args:
        folder: Directory containing one ``<key>.pt`` file per state key.
        tp_rank: Index of the tensor-parallel slice to extract.
        tp_world_size: Number of slices the full tensor is chunked into.

    Raises:
        AssertionError: If a checkpoint file is missing, or element counts of
            the loaded tensor and the destination fragment do not line up.
    """
    hp_mapping = self._hp_mapping
    hp_keys = [FP32_WEIGHT_KEY] + hp_mapping.get_optim_state_keys()
    checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
    for file in checkpoint_files.values():
        assert os.path.isfile(file), f'{file} is not a valid file'

    lp_frag_address = hp_mapping.lp_fragment_address
    for key in hp_keys:
        ckpt_file = checkpoint_files[key]
        ckpt_dict = torch.load(ckpt_file)
        full_hp_param = ckpt_dict[PARAM]

        # Slices that were averaged when merging are stored with the same
        # shape as the local parameter; the inverse of averaging is a plain
        # copy, so treat them as a single slice. The decision is kept in
        # loop-local names so one key never changes the slicing of the next.
        if full_hp_param.shape == self.shape:
            slice_rank, slice_count = 0, 1
        else:
            slice_rank, slice_count = tp_rank, tp_world_size

        # Word-embedding weights are padded differently depending on the TP
        # degree. The saved weight is padding-free, so padding for the target
        # TP degree is re-created here.
        vocab_divisibility_padding_tensor = ckpt_dict.get(
            VOCAB_DIVISIBILITY_PADDING_TENSOR,
            None)
        if vocab_divisibility_padding_tensor is not None:
            # The padded vocab size is derived from the target shape.
            padded_target_vocab_size = self.shape[0] * slice_count
            if padded_target_vocab_size > full_hp_param.shape[0]:
                # Need to expand: equivalent to concatenating padding rows at
                # the end of dim 0, done with pad to avoid an extra copy.
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                full_hp_param = torch.nn.functional.pad(full_hp_param,
                                                        (0,
                                                         0,
                                                         0,
                                                         padding_size),
                                                        "constant",
                                                        0)
                # Fill only the appended rows; the original rows must stay
                # untouched. Assumes the saved padding tensor broadcasts to
                # (padding_size, hidden) -- TODO confirm against the converter.
                full_hp_param[-padding_size:, :] = vocab_divisibility_padding_tensor
            else:
                # Need to shrink or keep the same
                full_hp_param = full_hp_param[:padded_target_vocab_size, :]

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
        assert full_param_numel == slice_count * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {slice_count}'

        if key == FP32_WEIGHT_KEY:
            dst_tensor = hp_mapping.hp_fragment
        else:
            dst_tensor = hp_mapping.get_optim_state_fragment(key)

        # Merging TP slices concatenates on dim 0 or dim 1 depending on the
        # tensor; chunking along the recorded dim reverses that.
        chunk_dim = ckpt_dict.get(CAT_DIM, 0)
        tp_hp_slice = full_hp_param.chunk(slice_count, chunk_dim)[slice_rank].flatten()
        tp_hp_fragment = tp_hp_slice.narrow(0,
                                            lp_frag_address.start,
                                            lp_frag_address.numel)
        assert dst_tensor.numel() == lp_frag_address.numel, \
            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'

        dst_tensor.data.copy_(tp_hp_fragment.data)
def enable_universal_checkpoint(param_list):
    """Bind ``load_hp_checkpoint_state`` as a method on every parameter in ``param_list``."""
    for lp_param in param_list:
        bound_loader = types.MethodType(load_hp_checkpoint_state, lp_param)
        lp_param.load_hp_checkpoint_state = bound_loader
......@@ -21,162 +21,15 @@ from deepspeed.runtime.utils import (get_global_norm_of_tensors,
is_model_parallel_parameter,
see_memory_usage)
from deepspeed.utils import link_hp_params
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION,
PARTITION_COUNT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
CLIP_GRAD,
GROUP_PADDINGS,
PARAM_SLICE_MAPPINGS,
FP32_WEIGHT_KEY)
import types
from dataclasses import dataclass
@dataclass
class fragment_address:
    """Location of a fragment inside a flat tensor: element count and start offset."""
    # Number of elements in the fragment.
    numel: int
    # Offset of the first element of the fragment.
    start: int
@dataclass
class tensor_fragment:
    """Groups a low-precision fragment, its high-precision counterpart and optimizer-state fragments."""
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    # Optimizer-state fragments keyed by state name. Annotated with the type
    # ``dict`` (the original used a dict instance, which is not a type).
    optim_fragment: dict

    def update_hp(self):
        """Copy the low-precision fragment's data into the high-precision fragment."""
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        """Copy the high-precision fragment's data into the low-precision fragment."""
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        """Return the optimizer-state fragment stored under ``key``.

        Raises:
            ValueError: If ``key`` is not present in ``optim_fragment``.
        """
        if key not in self.optim_fragment:
            raise ValueError(f'{key} not found in optimizer state fragment')
        return self.optim_fragment[key]

    def get_hp_fragment_address(self):
        """Return the address of the high-precision fragment."""
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        """Return the optimizer-state keys as a list."""
        return list(self.optim_fragment.keys())
def get_full_hp_param(self, optim_state_key=None):
    """Return a float32 tensor shaped like ``self``, assembled across ``self._dp_group``.

    The local high-precision fragment (or the optimizer-state fragment named
    by ``optim_state_key``) is written into a zero-filled flat buffer at this
    rank's fragment address, and the buffer is then all-reduced over the
    group. Ranks without a mapping contribute only zeros.
    """
    flat_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    mapping = self._hp_mapping
    if mapping is not None:
        address = mapping.lp_fragment_address
        window = torch.narrow(flat_buffer, 0, address.start, address.numel)
        source = (mapping.hp_fragment if optim_state_key is None else
                  mapping.get_optim_state_fragment(optim_state_key))
        window.data.copy_(source.data)

    # Every rank takes part in the collective, with or without a local fragment.
    dist.all_reduce(flat_buffer, group=self._dp_group)
    return flat_buffer.reshape_as(self)
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    """Load this parameter's high-precision weight and optimizer state from ``folder``.

    Intended to be bound to a low-precision parameter as a method. For the
    fp32 weight and for every optimizer-state key reported by
    ``self._hp_mapping``, the file ``<folder>/<key>.pt`` is loaded, the slice
    for ``tp_rank`` is extracted, and the portion addressed by
    ``lp_fragment_address`` is copied into the matching destination fragment.

    Args:
        folder: Directory containing one ``<key>.pt`` file per state key.
        tp_rank: Index of the tensor-parallel slice to extract.
        tp_world_size: Number of slices the full tensor is chunked into.

    Raises:
        AssertionError: If a checkpoint file is missing, or element counts of
            the loaded tensor and the destination fragment do not line up.
    """
    hp_mapping = self._hp_mapping
    hp_keys = [FP32_WEIGHT_KEY] + hp_mapping.get_optim_state_keys()
    checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
    for file in checkpoint_files.values():
        assert os.path.isfile(file), f'{file} is not a valid file'

    lp_frag_address = hp_mapping.lp_fragment_address
    for key in hp_keys:
        ckpt_file = checkpoint_files[key]
        ckpt_dict = torch.load(ckpt_file)
        full_hp_param = ckpt_dict['param']

        # Slices that were averaged when merging are stored with the same
        # shape as the local parameter; the inverse of averaging is a plain
        # copy, so treat them as a single slice. The decision is kept in
        # loop-local names so one key never changes the slicing of the next.
        if full_hp_param.shape == self.shape:
            slice_rank, slice_count = 0, 1
        else:
            slice_rank, slice_count = tp_rank, tp_world_size

        # Word-embedding weights are padded differently depending on the TP
        # degree. The saved weight is padding-free, so padding for the target
        # TP degree is re-created here.
        vocab_divisibility_padding_tensor = ckpt_dict.get(
            'vocab_divisibility_padding_tensor',
            None)
        if vocab_divisibility_padding_tensor is not None:
            # The padded vocab size is derived from the target shape.
            padded_target_vocab_size = self.shape[0] * slice_count
            if padded_target_vocab_size > full_hp_param.shape[0]:
                # Need to expand: equivalent to concatenating padding rows at
                # the end of dim 0, done with pad to avoid an extra copy.
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                full_hp_param = torch.nn.functional.pad(full_hp_param,
                                                        (0,
                                                         0,
                                                         0,
                                                         padding_size),
                                                        "constant",
                                                        0)
                # Fill only the appended rows; the original rows must stay
                # untouched. Assumes the saved padding tensor broadcasts to
                # (padding_size, hidden) -- TODO confirm against the converter.
                full_hp_param[-padding_size:, :] = vocab_divisibility_padding_tensor
            else:
                # Need to shrink or keep the same
                full_hp_param = full_hp_param[:padded_target_vocab_size, :]

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
        assert full_param_numel == slice_count * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {slice_count}'

        if key == FP32_WEIGHT_KEY:
            dst_tensor = hp_mapping.hp_fragment
        else:
            dst_tensor = hp_mapping.get_optim_state_fragment(key)

        # Merging TP slices concatenates on dim 0 or dim 1 depending on the
        # tensor; chunking along the recorded dim reverses that.
        chunk_dim = ckpt_dict.get('cat_dim', 0)
        tp_hp_slice = full_hp_param.chunk(slice_count, chunk_dim)[slice_rank].flatten()
        tp_hp_fragment = tp_hp_slice.narrow(0,
                                            lp_frag_address.start,
                                            lp_frag_address.numel)
        assert dst_tensor.numel() == lp_frag_address.numel, \
            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'

        dst_tensor.data.copy_(tp_hp_fragment.data)
PARAM_SLICE_MAPPINGS)
class BF16_Optimizer(ZeROOptimizer):
......@@ -327,8 +180,13 @@ class BF16_Optimizer(ZeROOptimizer):
# Need optimizer states initialized before linking lp to optimizer state
self._link_all_hp_params()
self._enable_universal_checkpoint()
self._param_slice_mappings = self._create_param_mapping()
def _enable_universal_checkpoint(self):
    """Call ``enable_universal_checkpoint`` on each group in ``self.bf16_groups``."""
    for group_params in self.bf16_groups:
        enable_universal_checkpoint(param_list=group_params)
def _create_param_mapping(self):
param_mapping = []
for i, _ in enumerate(self.optimizer.param_groups):
......@@ -344,93 +202,18 @@ class BF16_Optimizer(ZeROOptimizer):
def _link_all_hp_params(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
for i, param_group in enumerate(self.optimizer.param_groups):
for i, _ in enumerate(self.optimizer.param_groups):
# Link bf16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
self._link_hp_params(self.bf16_groups[i],
self.fp32_groups_flat_partition[i],
partition_id * partition_size,
partition_size,
self.real_dp_process_group[i])
def _init_lp_to_hp_mapping(self,
                           lp_param_list,
                           partition_start,
                           partition_size,
                           dp_group):
    """Reset per-parameter linkage state and find the parameters overlapping a partition.

    Every parameter in ``lp_param_list`` gets ``_hp_mapping`` cleared,
    ``_dp_group`` set, and ``get_full_hp_param`` / ``load_hp_checkpoint_state``
    bound as methods. Parameters are treated as laid out back to back in a
    flat buffer starting at offset 0.

    Returns:
        A list of ``(lp_param, offset)`` tuples for the parameters whose span
        intersects ``[partition_start, partition_start + partition_size)``.
    """
    partition_end = partition_start + partition_size
    overlapping = []
    offset = 0
    for lp_param in lp_param_list:
        lp_param._hp_mapping = None
        lp_param._dp_group = dp_group
        lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
        lp_param.load_hp_checkpoint_state = types.MethodType(
            load_hp_checkpoint_state,
            lp_param)

        # The parameter overlaps the partition when it starts before the
        # partition ends and ends after the partition starts.
        param_end = offset + lp_param.numel()
        starts_before_end = offset < partition_end
        ends_after_start = param_end > partition_start
        if starts_before_end and ends_after_start:
            overlapping.append((lp_param, offset))
        offset = param_end

    return overlapping
def _link_hp_params(self,
                    lp_param_list,
                    flat_hp_partition,
                    partition_start,
                    partition_size,
                    dp_group):
    """Attach a ``tensor_fragment`` to each low-precision parameter overlapping the partition.

    For every overlapping parameter, the intersection of the parameter's span
    and the partition's span is computed, and matching views are taken from
    the flat high-precision partition, the optimizer state tensors and the
    flattened parameter. The result is stored on ``lp_param._hp_mapping``.
    """
    overlapping_params = self._init_lp_to_hp_mapping(lp_param_list,
                                                     partition_start,
                                                     partition_size,
                                                     dp_group)

    hp_start = partition_start
    hp_end = partition_start + partition_size
    for lp_param, lp_start in overlapping_params:
        lp_end = lp_start + lp_param.numel()

        # Intersection of [lp_start, lp_end) and [hp_start, hp_end).
        fragment_start = max(lp_start, hp_start)
        fragment_end = min(lp_end, hp_end)
        assert fragment_start < fragment_end, \
            f'fragment start {fragment_start} should be < fragment_end {fragment_end}'
        fragment_numel = fragment_end - fragment_start

        # View into the flat high-precision partition, relative to its start.
        hp_frag_address = fragment_address(start=fragment_start - hp_start,
                                           numel=fragment_numel)
        hp_fragment_tensor = flat_hp_partition.narrow(0,
                                                      hp_frag_address.start,
                                                      hp_frag_address.numel)

        # Same window over every non-scalar tensor in the optimizer state.
        optim_fragment = {}
        for state_key, state_value in self.optimizer.state[flat_hp_partition].items():
            if torch.is_tensor(state_value) and state_value.dim() > 0:
                optim_fragment[state_key] = state_value.narrow(0,
                                                               hp_frag_address.start,
                                                               hp_frag_address.numel)

        # View into the flattened low-precision parameter, relative to its start.
        lp_frag_address = fragment_address(start=fragment_start - lp_start,
                                           numel=fragment_numel)
        lp_fragment_tensor = lp_param.flatten().narrow(0,
                                                       lp_frag_address.start,
                                                       lp_frag_address.numel)

        lp_param._hp_mapping = tensor_fragment(lp_fragment=lp_fragment_tensor,
                                               lp_fragment_address=lp_frag_address,
                                               hp_fragment=hp_fragment_tensor,
                                               hp_fragment_address=hp_frag_address,
                                               optim_fragment=optim_fragment)
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(
lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
partition_start=partition_id * partition_size,
partition_size=partition_size,
partition_optimizer_state=self.optimizer.state[flat_hp_partition],
dp_group=self.real_dp_process_group[i])
def initialize_optimizer_states(self):
"""Take an optimizer step with zero-valued gradients to allocate internal
......
......@@ -4,4 +4,6 @@ from .comms_logging import get_caller_func
from .init_on_device import OnDevice
from .groups import *
from .nvtx import instrument_w_nvtx
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import types
from deepspeed.utils import get_full_hp_param, get_hp_fragment_mapping
def link_hp_params(lp_param_list,
                   flat_hp_partition,
                   partition_start,
                   partition_size,
                   partition_optimizer_state,
                   dp_group):
    """Set ``_hp_mapping`` on every low-precision parameter that overlaps the partition.

    All parameters in ``lp_param_list`` are first initialised by
    ``_init_lp_to_hp_mapping``; those it reports as overlapping the partition
    then receive the mapping built by ``get_hp_fragment_mapping``.
    """
    overlapping_params = _init_lp_to_hp_mapping(lp_param_list,
                                                partition_start,
                                                partition_size,
                                                dp_group)

    for lp_param, lp_start in overlapping_params:
        mapping = get_hp_fragment_mapping(lp_param,
                                          lp_start,
                                          flat_hp_partition,
                                          partition_start,
                                          partition_size,
                                          partition_optimizer_state)
        lp_param._hp_mapping = mapping
def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
    """Reset per-parameter linkage state and find the parameters overlapping a partition.

    Every parameter in ``lp_param_list`` gets ``_hp_mapping`` cleared,
    ``_dp_group`` set, and ``get_full_hp_param`` bound as a method. Parameters
    are treated as laid out back to back in a flat buffer starting at offset 0.

    Returns:
        A list of ``(lp_param, offset)`` tuples for the parameters whose span
        intersects ``[partition_start, partition_start + partition_size)``.
    """
    partition_end = partition_start + partition_size
    overlapping = []
    offset = 0
    for lp_param in lp_param_list:
        lp_param._hp_mapping = None
        lp_param._dp_group = dp_group
        lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)

        # The parameter overlaps the partition when it starts before the
        # partition ends and ends after the partition starts.
        param_end = offset + lp_param.numel()
        starts_before_end = offset < partition_end
        ends_after_start = param_end > partition_start
        if starts_before_end and ends_after_start:
            overlapping.append((lp_param, offset))
        offset = param_end

    return overlapping
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import torch
from dataclasses import dataclass
from deepspeed import comm as dist
@dataclass
class fragment_address:
    """Location of a fragment inside a flat tensor: element count and start offset."""
    # Number of elements in the fragment.
    numel: int
    # Offset of the first element of the fragment.
    start: int
@dataclass
class tensor_fragment:
    """Groups a low-precision fragment, its high-precision counterpart and optimizer-state fragments."""
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    # Optimizer-state fragments keyed by state name. Annotated with the type
    # ``dict`` (the original used a dict instance, which is not a type).
    optim_fragment: dict

    def update_hp(self):
        """Copy the low-precision fragment's data into the high-precision fragment."""
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        """Copy the high-precision fragment's data into the low-precision fragment."""
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        """Return the optimizer-state fragment stored under ``key``.

        Raises:
            ValueError: If ``key`` is not present in ``optim_fragment``.
        """
        if key not in self.optim_fragment:
            raise ValueError(f'{key} not found in optimizer state fragment')
        return self.optim_fragment[key]

    def get_hp_fragment_address(self):
        """Return the address of the high-precision fragment."""
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        """Return the optimizer-state keys as a list."""
        return list(self.optim_fragment.keys())
def get_full_hp_param(self, optim_state_key=None):
    """Return a float32 tensor shaped like ``self``, assembled across ``self._dp_group``.

    The local high-precision fragment (or the optimizer-state fragment named
    by ``optim_state_key``) is written into a zero-filled flat buffer at this
    rank's fragment address, and the buffer is then all-reduced over the
    group. Ranks without a mapping contribute only zeros.
    """
    flat_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    mapping = self._hp_mapping
    if mapping is not None:
        address = mapping.lp_fragment_address
        window = torch.narrow(flat_buffer, 0, address.start, address.numel)
        source = (mapping.hp_fragment if optim_state_key is None else
                  mapping.get_optim_state_fragment(optim_state_key))
        window.data.copy_(source.data)

    # Every rank takes part in the collective, with or without a local fragment.
    dist.all_reduce(flat_buffer, group=self._dp_group)
    return flat_buffer.reshape_as(self)
def get_hp_fragment_mapping(lp_param,
                            lp_start,
                            flat_hp_partition,
                            partition_start,
                            partition_size,
                            optimizer_state_dict):
    """Build the ``tensor_fragment`` linking ``lp_param`` to its share of a flat partition.

    The fragment is the intersection of the parameter's span
    ``[lp_start, lp_start + lp_param.numel())`` with the partition's span
    ``[partition_start, partition_start + partition_size)``. Matching views
    are taken from ``flat_hp_partition``, from every non-scalar tensor in
    ``optimizer_state_dict``, and from the flattened ``lp_param``.

    Raises:
        AssertionError: If the two spans do not intersect.
    """
    hp_start = partition_start
    hp_end = partition_start + partition_size
    lp_end = lp_start + lp_param.numel()

    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'
    fragment_numel = fragment_end - fragment_start

    # View into the flat high-precision partition, relative to its start.
    hp_frag_address = fragment_address(start=fragment_start - hp_start,
                                       numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0,
                                                  hp_frag_address.start,
                                                  hp_frag_address.numel)

    # Same window over every non-scalar tensor in the optimizer state.
    optim_fragment = {}
    for state_key, state_value in optimizer_state_dict.items():
        if torch.is_tensor(state_value) and state_value.dim() > 0:
            optim_fragment[state_key] = state_value.narrow(0,
                                                           hp_frag_address.start,
                                                           hp_frag_address.numel)

    # View into the flattened low-precision parameter, relative to its start.
    lp_frag_address = fragment_address(start=fragment_start - lp_start,
                                       numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0,
                                                   lp_frag_address.start,
                                                   lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册