From 2c6e819450282e7509d8ac55682edd1336026e72 Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Thu, 2 Feb 2023 23:12:53 -0800 Subject: [PATCH] Fix Checkpoint-loading with Meta-tensor (#2781) * Reset KV-cache at the beginning of text-generation * Pass the ckpt-loading arguments to work with meta-tensor * remove unrelated changes --- deepspeed/module_inject/containers/bloom.py | 5 ++++- deepspeed/module_inject/containers/gptneox.py | 2 +- deepspeed/module_inject/containers/opt.py | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py index f8b980ef..64e219b8 100644 --- a/deepspeed/module_inject/containers/bloom.py +++ b/deepspeed/module_inject/containers/bloom.py @@ -37,7 +37,10 @@ class BLOOMLayerPolicy(TransformerPolicy): inference=True, use_load_prefix=True, split_qkv=False): - super().__init__(inference, linear_layer=True) + super().__init__(inference, + linear_layer=True, + use_load_prefix=use_load_prefix, + split_qkv=split_qkv) self.client_module = client_module try: import transformers diff --git a/deepspeed/module_inject/containers/gptneox.py b/deepspeed/module_inject/containers/gptneox.py index 7e1b568c..dba55526 100644 --- a/deepspeed/module_inject/containers/gptneox.py +++ b/deepspeed/module_inject/containers/gptneox.py @@ -32,7 +32,7 @@ class GPTNEOXLayerPolicy(TransformerPolicy): version = 0 def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False): - super().__init__(inference, megatron_v2=megatron_v2) + super().__init__(inference, megatron_v2=megatron_v2, split_qkv=split_qkv) self.client_module = client_module if GPTNEOXLayerPolicy._orig_layer_class is None: if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"): diff --git a/deepspeed/module_inject/containers/opt.py b/deepspeed/module_inject/containers/opt.py index 73dcd419..92a956c1 100644 --- a/deepspeed/module_inject/containers/opt.py +++ b/deepspeed/module_inject/containers/opt.py @@ -27,7 +27,8 @@ class HFOPTLayerPolicy(TransformerPolicy): super().__init__(inference, linear_layer=True, mlp_act_func_type=ActivationFuncType.ReLU, - pre_attn_norm=True) + pre_attn_norm=True, + use_load_prefix=use_load_prefix) self.client_module = client_module try: -- GitLab