From 042115c80b11862f1e2cd030fd3747cadf9ef868 Mon Sep 17 00:00:00 2001
From: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Date: Tue, 29 Aug 2023 07:30:30 -0700
Subject: [PATCH] Fix fused qkv sizing for bloom (#4161)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/module_inject/fusedqkv_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py
index f25f9f8d..3ddf8f44 100644
--- a/deepspeed/module_inject/fusedqkv_utils.py
+++ b/deepspeed/module_inject/fusedqkv_utils.py
@@ -68,7 +68,9 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
 
     def _bloom_type_transpose(input, mp_size):
-        return input
+        shape = input.shape
+        dst_shape = shape[0] // mp_size
+        return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
 
     def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
 
@@ -91,4 +93,4 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
             return _transpose_fused_qkvw(src, mp_size, fused_type)
     warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
                  f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
-    return src
+    return _bloom_type_transpose(src, mp_size)
-- 
GitLab