Unverified commit d10b8ca0, authored by Wang, Yi, committed by GitHub

add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage (#3102)

* add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix gptj sharded checkpoint loading problem
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 0a61d5d6
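For context, below is a minimal usage sketch of the AutoTP sharded-loading path this commit adds. The only descriptor key this diff actually reads is "checkpoints"; the "type"/"version" fields, the model name, the shard file names, and the tp_size are illustrative assumptions, not values taken from this commit.

# Hypothetical end-to-end sketch (not part of this commit): build the model on the
# meta device, then let DeepSpeed AutoTP load the checkpoint shard by shard.
import json
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "bigscience/bloom-7b1"  # hypothetical model
shards = ["pytorch_model-00001-of-00002.bin",  # hypothetical shard files
          "pytorch_model-00002-of-00002.bin"]
with open("checkpoints.json", "w") as f:
    # "checkpoints" is the key read by this commit; the other fields follow
    # DeepSpeed's usual descriptor layout and may differ per model.
    json.dump({"type": "ds_model", "version": 1.0, "checkpoints": shards}, f)

# No real weights are allocated here; every parameter starts as a meta tensor.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))

# replace_with_kernel_inject=False selects the AutoTP path patched below, so the
# shards listed in checkpoints.json are loaded one at a time during injection,
# keeping peak host memory close to the size of a single shard.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 2},
                                  dtype=torch.float16,
                                  checkpoint="checkpoints.json",
                                  replace_with_kernel_inject=False)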
@@ -151,9 +151,6 @@ class InferenceEngine(Module):
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"
if config.checkpoint and not config.replace_with_kernel_inject:
self._load_checkpoint(config.checkpoint)
# convert model to intended dtype
if config.dtype:
self._convert_to_dtype(config)
@@ -173,10 +170,6 @@ class InferenceEngine(Module):
if moe and dist.get_world_size() > 1:
self._create_ep_parallel_group(config.moe.moe_experts)
# retain this from the old conditional argument being passed to apply_injection_policy()
if not config.replace_with_kernel_inject:
config.checkpoint = None
# We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
if self.injection_dict:
# 1. User specified Tensor Parallelism
@@ -343,6 +336,11 @@ class InferenceEngine(Module):
def load(module, state_dict, prefix):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
if hasattr(module, 'weight'):
if module.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data,
device="cpu"),
requires_grad=module.weight.data.requires_grad)
if 'query_key_value' in prefix:
module.weight = self.mp_replace.strided_copy(module.weight.data,
state_dict[prefix + 'weight'],
@@ -350,11 +348,26 @@ class InferenceEngine(Module):
else:
module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
else:
if module.norm.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.weight = torch.nn.parameter.Parameter(
data=torch.empty_like(module.norm.weight.data, device="cpu"),
requires_grad=module.norm.weight.data.requires_grad)
module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
if prefix + 'bias' in self.key_list:
if hasattr(module, 'norm'):
if module.norm.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.norm.bias.data, device="cpu"),
requires_grad=module.norm.bias.data.requires_grad)
module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
else:
if module.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data,
device="cpu"),
requires_grad=module.bias.data.requires_grad)
data = state_dict[prefix + 'bias']
data = data.to(get_accelerator().current_device_name())
module.bias = self.mp_replace.copy(module.bias, data)
@@ -383,6 +396,15 @@ class InferenceEngine(Module):
load_module_recursive(r_module)
embedding_weight = None
for n, p in r_module.named_parameters():
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
embedding_weight = p
if embedding_weight is not None and hasattr(r_module, "lm_head") and hasattr(
r_module.lm_head, "weight") and r_module.lm_head.weight.is_meta:
r_module.lm_head.weight = embedding_weight
def _apply_injection_policy(self, config, client_module=None):
# client_module is only passed when using the injection_dict method.
checkpoint_dir = config.checkpoint
@@ -434,16 +456,18 @@ class InferenceEngine(Module):
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
checkpoint = sd_loader['checkpoints']
if type(checkpoint) is list:
self.sd = torch.load(checkpoint[0], map_location='cpu')
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
for i in range(1, len(sd_loader)):
for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name())
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
@@ -25,6 +25,8 @@ import time
from .utils import policy_to_ds_container
import gc
class ReplaceWithTensorSlicing:
@@ -361,11 +363,13 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
return _container.module
def replace_wo_policy(module, all_reduce_linears):
def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
mp_size = config.tensor_parallel.tp_size
mp_group = config.tensor_parallel.tp_group
def _replace(child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
weight_shape = child.weight.shape
if name in all_reduce_linears:
@@ -381,6 +385,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype)
if child.bias is not None:
new_bias.data.copy_(child.bias.data)
setattr(child, "replaced", True)
return LinearAllreduce(data, child.bias if child.bias is None else \
torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group)
else:
@@ -402,6 +407,8 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)
def _slice_embedding(child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size),
device=child.weight.device,
@@ -411,9 +418,12 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
child.weight.data)
new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size)
new_embedding.weight.data.copy_(data)
setattr(child, "replaced", True)
return new_embedding
def update_mp_params(child):
if getattr(child, "replaced", False) == True:
return
if hasattr(child, 'n_heads'):
assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
child.n_heads, mp_size)
@@ -446,6 +456,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
child.hidden_size, mp_size)
child.hidden_size = child.hidden_size // mp_size
setattr(child, "replaced", True)
conv_linear_layer = False
if linear_layer_setting is not None:
@@ -465,6 +476,14 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
def _replace_module(r_module, prev_name=''):
for name, child in r_module.named_children():
checking_key = prefix + '.' + prev_name + '.' + name + '.' if prev_name != "" else prefix + '.' + name + '.'
if child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm] and state_dict != None:
if any(checking_key in item for item in state_dict):
load(child, state_dict, checking_key, mp_group)
else:
continue
if len(child._buffers) != 0 and state_dict != None:
load_buffer(child, state_dict, checking_key)
if child.__class__ in linear_policies:
setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
conv_linear_layer))
@@ -475,7 +494,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
return _replace_module(module)
def replace_fn(child, _policy, layer_id=0):
def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
training = False # todo: refactor this part to go in the config
if training:
# copy relevant state from child -> new module
@@ -490,19 +509,32 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
inference=True,
layer_id=layer_id)
else:
new_module = replace_wo_policy(child, _policy)
new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
return new_module
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)
if checkpoint_dict != None and not config.replace_with_kernel_inject:
# AutoTP shard loading
checkpoint = checkpoint_dict["checkpoints"]
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)):
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple,
checkpoint=checkpoint[i])
pbar.update(1)
gc.collect()
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)
quantizer = GroupQuantizer(q_int8=quantize)
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
if checkpoint_dict is not None:
if checkpoint_dict is not None and config.replace_with_kernel_inject:
assert container_g.ckpt_load_enabled, \
f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
start_time = time.time()
@@ -527,7 +559,6 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
container=container_g)
pbar.update(1)
else:
import gc
num_checkpoints = len(ckpt_list) // ckpt_mp_size
tp_split_size = (world_size / ckpt_mp_size)
sd_offset = int(rank / tp_split_size)
@@ -709,7 +740,7 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
_replace_policy=None)
def replace_module(model, orig_class, replace_fn, _replace_policy):
def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
""" Scan the model for instances of ``orig_clas:`` to replace using ``replace_fn``.
Arguments:
model (torch.nn.Module): the model to augment
@@ -719,6 +750,9 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
Returns:
A modified ``model``.
"""
sd = None
if checkpoint != None:
sd = torch.load(checkpoint, map_location='cpu')
policy = {}
if orig_class is not None:
policy.update({orig_class: (replace_fn, _replace_policy)})
@@ -735,14 +769,48 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
"No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\
"You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"
replaced_module, _ = _replace_module(model, policy)
replaced_module, _ = _replace_module(model, policy, state_dict=sd)
if checkpoint != None:
embedding_weight = None
for n, p in replaced_module.named_parameters():
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
embedding_weight = p
if embedding_weight is not None and hasattr(replaced_module, "lm_head") and hasattr(
replaced_module.lm_head, "weight") and replaced_module.lm_head.weight.is_meta:
replaced_module.lm_head.weight = embedding_weight
return replaced_module
from ..pipe import PipelineModule
import re
def skip_level_0_prefix(model, name):
model = str(model)
key = re.search(r": (.*?)Model", model)
if key is None:
key = re.search(r": (.*?)Stack", model)
if key is None:
key = re.match(r"(.*?)Model", model)
if key is not None and key.group(1).lower() in "bloom":
# if keys start with 'model.', don't skip level 0 prefix
if not re.match("^model[.]", name):
return True
return False
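# Illustration (not part of this commit): how skip_level_0_prefix above is expected
# to behave for a Bloom-style module. The repr string and the key below are
# hypothetical, abridged stand-ins for str(model) and a real state-dict key.
import re
fake_repr = "BloomForCausalLM(\n  (transformer): BloomModel("  # hypothetical module repr
match = re.search(r": (.*?)Model", fake_repr)
assert match is not None and match.group(1).lower() in "bloom"  # captures "Bloom"
# Keys that do not start with "model." cause the level-0 module name to be skipped
# when checkpoint keys are built, matching the Bloom state-dict layout.
assert not re.match("^model[.]", "word_embeddings.weight")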
def load_buffer(module, state_dict, prefix):
for name in module._buffers.keys():
if prefix + name in state_dict.keys():
if module._buffers[name].data.is_meta:
module._buffers[name] = torch.nn.parameter.Parameter(
data=torch.empty_like(module._buffers[name].data, device="cpu"),
requires_grad=module._buffers[name].data.requires_grad)
module._buffers[name].data.copy_(state_dict[prefix + name])
def _replace_module(model, policies, layer_id=0):
def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
""" Traverse model's children recursively and apply any transformations in ``policies``.
Arguments:
model (torch.nn.Module): model to augment
@@ -750,9 +818,14 @@ def _replace_module(model, policies, layer_id=0):
Returns:
Modified ``model``.
"""
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
for name, child in model.named_children():
if child.__class__ in policies:
replaced_module = policies[child.__class__][0](child, policies[child.__class__][-1], layer_id)
replaced_module = policies[child.__class__][0](child,
policies[child.__class__][-1],
layer_id,
prefix=prefix + name,
state_dict=state_dict)
setattr(model, name, replaced_module)
if isinstance(model, PipelineModule):
assert hasattr(model, 'forward_funcs'),\
@@ -760,8 +833,63 @@ def _replace_module(model, policies, layer_id=0):
model.forward_funcs[model.fwd_map[name]] = replaced_module
layer_id += 1
else:
_, layer_id = _replace_module(child, policies, layer_id=layer_id)
checking_key = prefix + name + '.'
if child.__class__ in load_layers and state_dict != None:
if any(checking_key in item for item in state_dict):
load(
child,
state_dict,
checking_key,
)
else:
continue
if len(child._buffers) != 0 and state_dict != None:
load_buffer(child, state_dict, checking_key)
_, layer_id = _replace_module(child,
policies,
prefix if level_id == 0 and skip_level_0_prefix(model, name) else \
prefix + name + '.',
layer_id=layer_id,
level_id=level_id + 1,
state_dict=state_dict)
# Add the reset_cache func to the model, so that it can be called at the beginning of text generation.
model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
return model, layer_id
def load(module, state_dict, prefix, mp_group=None):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
if hasattr(module, 'weight'):
if module.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
requires_grad=module.weight.data.requires_grad)
if 'query_key_value' in prefix:
module.weight = mp_replace.qkv_copy(module.weight.data, state_dict[prefix + 'weight'])
else:
module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
else:
if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
if module.norm.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.norm.weight.data,
device="cpu"),
requires_grad=module.norm.weight.data.requires_grad)
module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
if prefix + 'bias' in state_dict.keys():
if hasattr(module, 'bias'):
if module.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
else:
if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
if module.norm.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.norm.bias.data,
device="cpu"),
requires_grad=module.norm.bias.data.requires_grad)
module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
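The pattern all of these load helpers share, shown in isolation below, is what keeps peak memory low: every parameter starts as a meta tensor and is only materialized on CPU right before its slice of the current shard is copied in. The helper name and tensors here are illustrative, not DeepSpeed API.

# Minimal sketch of the meta-tensor materialization used throughout the hunks above.
import torch

def materialize_and_copy(param, src):
    # A meta tensor has no storage, so it cannot be cast or copied into; replace it
    # with an empty CPU tensor of the same shape/dtype first, then copy the shard data.
    if param.data.is_meta:
        param = torch.nn.parameter.Parameter(data=torch.empty_like(param.data, device="cpu"),
                                              requires_grad=param.data.requires_grad)
    param.data.copy_(src)
    return param

# Weights stay on "meta" until their shard is read, so only one shard's worth of real
# tensors exists at any point during initialization.
meta_param = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
shard_tensor = torch.randn(4, 4)  # stands in for state_dict[prefix + 'weight']
real_param = materialize_and_copy(meta_param, shard_tensor)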