diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py
index f8b980effe76d9d0e9574d6b86743817ab06ff58..64e219b88342cd8f403d6c80ff4274ebb425cab2 100644
--- a/deepspeed/module_inject/containers/bloom.py
+++ b/deepspeed/module_inject/containers/bloom.py
@@ -37,7 +37,10 @@ class BLOOMLayerPolicy(TransformerPolicy):
                  inference=True,
                  use_load_prefix=True,
                  split_qkv=False):
-        super().__init__(inference, linear_layer=True)
+        super().__init__(inference,
+                         linear_layer=True,
+                         use_load_prefix=use_load_prefix,
+                         split_qkv=split_qkv)
         self.client_module = client_module
         try:
             import transformers
diff --git a/deepspeed/module_inject/containers/gptneox.py b/deepspeed/module_inject/containers/gptneox.py
index 7e1b568cefafb352f0fc1e52d72eaa6c82a18793..dba55526a4a0a1499b7b7c3e3dbdf88e75869ed9 100644
--- a/deepspeed/module_inject/containers/gptneox.py
+++ b/deepspeed/module_inject/containers/gptneox.py
@@ -32,7 +32,7 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
     version = 0
 
     def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False):
-        super().__init__(inference, megatron_v2=megatron_v2)
+        super().__init__(inference, megatron_v2=megatron_v2, split_qkv=split_qkv)
         self.client_module = client_module
         if GPTNEOXLayerPolicy._orig_layer_class is None:
             if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
diff --git a/deepspeed/module_inject/containers/opt.py b/deepspeed/module_inject/containers/opt.py
index 73dcd419cd90178d7a36b24f1cad3cbf1071dbb4..92a956c1ed965d78e83b0b610cb2e37ef2d8fbb2 100644
--- a/deepspeed/module_inject/containers/opt.py
+++ b/deepspeed/module_inject/containers/opt.py
@@ -27,7 +27,8 @@ class HFOPTLayerPolicy(TransformerPolicy):
         super().__init__(inference,
                          linear_layer=True,
                          mlp_act_func_type=ActivationFuncType.ReLU,
-                         pre_attn_norm=True)
+                         pre_attn_norm=True,
+                         use_load_prefix=use_load_prefix)
         self.client_module = client_module
 
         try: