未验证 提交 6e6eab07 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Fix multihead op bug. (#20783)

The op should handle k=1024

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 28dd2a58
......@@ -134,7 +134,7 @@ MultiHeadMatMul Operator.
This op is used for optimize multi head calculation in ernie model.
Not suggest to use in other case except has same structure as ernie.
Example of matrix multiplication with head_number of H
Example of matrix multiplication with head_number of B
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
Both the input `Q` and `K` can carry the LoD (Level of Details) information,
......
......@@ -331,7 +331,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
auto stream = dev_ctx.stream();
int grid = m;
PADDLE_ENFORCE_LT(k, 1024,
PADDLE_ENFORCE_LE(k, 1024,
"Input head_number * size_per_head should <= 1024");
int block = k <= 1024 ? k : 1024;
add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册