Unverified commit 042115c8, author: Molly Smith, committer: GitHub

Fix fused qkv sizing for bloom (#4161)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Parent 60d7b0a3
@@ -68,7 +68,9 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
     def _bloom_type_transpose(input, mp_size):
-        return input
+        shape = input.shape
+        dst_shape = shape[0] // mp_size
+        return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
     def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
@@ -91,4 +93,4 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         return _transpose_fused_qkvw(src, mp_size, fused_type)
     warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
                  f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
-    return src
+    return _bloom_type_transpose(src, mp_size)
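
For reference, the slicing that the bloom-type path now performs can be sketched in isolation. This is a minimal illustration, not the DeepSpeed code: `shard_dim0` is a hypothetical helper name, and a nested list stands in for the weight tensor so the snippet runs without any dependencies. It assumes, as the diff does, that each rank takes a contiguous block of `shape[0] // mp_size` rows.

```python
def shard_dim0(weight, mp_size, gpu_index):
    """Return the contiguous block of rows owned by rank `gpu_index`.

    Hypothetical helper mirroring the slicing added in this commit.
    `weight` is any sequence of rows; a nested list is used here as a
    stand-in for a tensor whose first dimension is being split.
    """
    rows_per_rank = len(weight) // mp_size  # same role as dst_shape in the diff
    start = gpu_index * rows_per_rank
    return weight[start:start + rows_per_rank]


# Toy 6x2 "fused QKV weight" whose rows are labelled 0..5.
weight = [[r, r] for r in range(6)]

# Split across two ranks: rank 0 gets rows 0..2, rank 1 gets rows 3..5.
shards = [shard_dim0(weight, mp_size=2, gpu_index=i) for i in range(2)]
```

Before this change the bloom-type path and the unrecognized-type fallback returned the full weight on every rank. Note that with floor division, any trailing rows are dropped when the first dimension is not a multiple of `mp_size`.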