From 0712e299205b641ddce7f2c010d45f8fe635ed67 Mon Sep 17 00:00:00 2001
From: Dino Chen
Date: Fri, 25 Aug 2023 07:22:14 +0800
Subject: [PATCH] add meta onDevice support for LLAMA2 (#4147)

Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
---
 deepspeed/module_inject/auto_tp.py        |  8 ++++++--
 deepspeed/module_inject/replace_module.py | 10 +---------
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index b3d46de8..784700d7 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -114,6 +114,11 @@ class ReplaceWithTensorSlicing:
 
 class Loading():
 
+    def is_load_module(module):
+        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
+        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
+        return module.__class__ in load_layers or module._get_name() in load_layer_names
+
     def load_buffer(module, state_dict, prefix):
         for name in module._buffers.keys():
             if module._buffers[name].data.is_meta:
@@ -399,8 +404,7 @@ class AutoTP():
             else:
                 class_name = prev_class_name + '.' + prev_name
             checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
-            if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
-                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and self.state_dict is not None:
+            if Loading.is_load_module(child) and self.state_dict is not None:
                 if any(checking_key in item for item in self.state_dict):
                     Loading.load(child, self.state_dict, checking_key, self.mp_group)
                 else:
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 7ed0f8f1..ee573f36 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -16,7 +16,6 @@ from .replace_policy import replace_policies, generic_policies
 from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
 
 from deepspeed import comm as dist
-from torch import nn
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -595,12 +594,6 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
     Returns:
         Modified ``model``.
     """
-    try:
-        import transformers
-        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
-    except:
-        OPTLearnedPositionalEmbedding = None
-    load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm, OPTLearnedPositionalEmbedding]
     for name, child in model.named_children():
         if child.__class__ in policies:
             replaced_module = policies[child.__class__][0](child,
@@ -616,8 +609,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
                 layer_id += 1
         else:
             checking_key = prefix + name + '.'
-            if (child.__class__ in load_layers
-                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and state_dict is not None:
+            if Loading.is_load_module(child) and state_dict is not None:
                 if any(checking_key in item for item in state_dict):
                     Loading.load(
                         child,
--
GitLab
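
Note (editor's sketch, not part of the patch): the change collapses two separately maintained
module lists into a single Loading.is_load_module helper that both the auto-TP path and the
policy-replacement path call. Built-in torch layers are matched by class, while layers defined
in other libraries (e.g. MPT's LPLayerNorm/SharedEmbedding, OPT's OPTLearnedPositionalEmbedding,
and LLAMA2's LlamaRMSNorm) are matched by class name, which avoids importing transformers just
to test a type and lets those layers' meta-device weights be filled from the state_dict. Below is
a minimal runnable sketch of the matching rule, assuming only torch is installed; the LlamaRMSNorm
class here is a local stand-in for the transformers implementation:

    import torch.nn as nn

    def is_load_module(module):
        # Same rule as the helper added to auto_tp.py: match either by
        # concrete torch.nn class, or by class name for layers that are
        # defined outside torch.
        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
        return module.__class__ in load_layers or module._get_name() in load_layer_names

    class LlamaRMSNorm(nn.Module):
        # Stand-in: matched by name, as transformers' LlamaRMSNorm would be.
        pass

    print(is_load_module(nn.Linear(4, 4)))  # True: matched by class
    print(is_load_module(LlamaRMSNorm()))   # True: matched by name
    print(is_load_module(nn.ReLU()))        # False: handled by other replacement logic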