Unverified commit 0712e299, authored by Dino Chen and committed by GitHub

add meta onDevice support for LLAMA2 (#4147)

Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Parent commit: f6903190
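Context for the change: with meta-device ("OnDevice") initialization, the model skeleton is built without allocating real weights, and DeepSpeed's AutoTP path later materializes each supported leaf module from the checkpoint state_dict; this commit extends that support to LLAMA2 by recognizing its modules (e.g. LlamaRMSNorm). A minimal usage sketch of that flow is below. The model id, checkpoint manifest path, and TP size are placeholders, and `init_inference` keyword names vary across DeepSpeed releases.

```python
# Sketch of meta-device ("OnDevice") inference init for a LLaMA-2 model with AutoTP.
# Placeholders/assumptions: model id, checkpoint manifest path, and tp_size.
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Build the model skeleton on the meta device: no real weights are allocated yet.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# AutoTP then walks the modules; Loading.is_load_module() decides which leaves
# (nn.Linear, nn.Embedding, LlamaRMSNorm, ...) get their weights pulled from the
# sharded checkpoint instead of remaining meta tensors.
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},    # placeholder TP degree
    dtype=torch.float16,
    checkpoint="checkpoints.json",     # placeholder checkpoint manifest
    replace_with_kernel_inject=False,  # AutoTP path
)
```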
@@ -114,6 +114,11 @@ class ReplaceWithTensorSlicing:
 class Loading():
+    def is_load_module(module):
+        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
+        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
+        return module.__class__ in load_layers or module._get_name() in load_layer_names
+
     def load_buffer(module, state_dict, prefix):
         for name in module._buffers.keys():
             if module._buffers[name].data.is_meta:
@@ -399,8 +404,7 @@ class AutoTP():
             else:
                 class_name = prev_class_name + '.' + prev_name
             checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
-            if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
-                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and self.state_dict is not None:
+            if Loading.is_load_module(child) and self.state_dict is not None:
                 if any(checking_key in item for item in self.state_dict):
                     Loading.load(child, self.state_dict, checking_key, self.mp_group)
                 else:
@@ -16,7 +16,6 @@ from .replace_policy import replace_policies, generic_policies
 from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
 from deepspeed import comm as dist
-from torch import nn
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -595,12 +594,6 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
     Returns:
         Modified ``model``.
     """
-    try:
-        import transformers
-        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
-    except:
-        OPTLearnedPositionalEmbedding = None
-    load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm, OPTLearnedPositionalEmbedding]
     for name, child in model.named_children():
         if child.__class__ in policies:
             replaced_module = policies[child.__class__][0](child,
@@ -616,8 +609,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
             layer_id += 1
         else:
             checking_key = prefix + name + '.'
-            if (child.__class__ in load_layers
-                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and state_dict is not None:
+            if Loading.is_load_module(child) and state_dict is not None:
                 if any(checking_key in item for item in state_dict):
                     Loading.load(
                         child,
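Side note on the name-based check above: LlamaRMSNorm does not subclass nn.LayerNorm, so a class-identity test alone would skip it and leave its weight as a meta tensor; matching on `module._get_name()` (which returns the class name) is what lets LLaMA-2's norm layers be loaded. A small, hypothetical stand-in illustrating the difference:

```python
# Hypothetical stand-in for transformers' LlamaRMSNorm, used only to show why the
# name-based check is needed (this class is not part of the commit).
import torch
from torch import nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

norm = LlamaRMSNorm(16)
print(norm.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm])   # False: class check misses it
print(norm._get_name() in ["LPLayerNorm", "SharedEmbedding",
                           "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"])  # True: name check matches
```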