Unverified commit 8920308c, authored by Reza Yazdani, committed by GitHub

Fix the tensor-slicing copy for qkv parameters (#2198)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 28dfca8a
@@ -483,8 +483,8 @@ def replace_transformer_layer(orig_layer_impl,
     attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
     attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
 else:
-    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
-    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
+    attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
+    attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
     attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
     attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
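
The change above swaps a generic copy for a QKV-specific one when slicing the fused attention parameters. The diff does not show the body of `qkv_copy`, so the following is only a rough sketch of why a fused QKV parameter may need its own slicing rule. It assumes the fused tensor is the concatenation of the Q, K, and V blocks along the sharded dimension. The helpers `naive_shard` and `qkv_shard` are hypothetical names used for illustration and are not DeepSpeed APIs; rows are represented by labels rather than real weights.

```python
# Illustrative sketch only: NOT DeepSpeed's implementation of copy/qkv_copy.
# Assumption: the fused parameter is laid out as [Q rows; K rows; V rows].

def naive_shard(rows, rank, world_size):
    """Return this rank's contiguous chunk of `rows` (generic slicing)."""
    chunk = len(rows) // world_size
    return rows[rank * chunk:(rank + 1) * chunk]


def qkv_shard(rows, rank, world_size):
    """Split fused rows into Q, K, V, shard each block, then re-fuse."""
    third = len(rows) // 3
    blocks = [rows[i * third:(i + 1) * third] for i in range(3)]
    shard = []
    for block in blocks:
        shard.extend(naive_shard(block, rank, world_size))
    return shard


# Fused parameter with 4 rows each for Q, K, and V.
fused = [f"{name}{i}" for name in "qkv" for i in range(4)]
world_size = 2

naive = [naive_shard(fused, r, world_size) for r in range(world_size)]
fixed = [qkv_shard(fused, r, world_size) for r in range(world_size)]

print(naive[0])  # ['q0', 'q1', 'q2', 'q3', 'k0', 'k1']
print(fixed[0])  # ['q0', 'q1', 'k0', 'k1', 'v0', 'v1']
```

Under this assumed layout, generic slicing hands rank 0 all of Q plus half of K and no V, while the QKV-aware version gives every rank a matching slice of each of the three blocks.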