Unverified commit 8859ddd6, authored by zhaoyuchen2018, committed by GitHub

Refine multihead kernel, align block to 32 (#21961)

* Refine multihead kernel, align block to 32

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>

* Refine log comments

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent: fd9b00df
@@ -88,30 +88,33 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
                       "Multihead input bias qk should be at least 4-D tensor.");
-    int b_size = dim_bias_q.size() - 1;
-    int size = dim_q.size() - 1;
+    int b_indx = dim_bias_q.size() - 1;
+    int indx = dim_q.size() - 1;
-    PADDLE_ENFORCE_EQ(dim_bias_q[b_size], dim_q[size],
+    PADDLE_ENFORCE_EQ(
+        dim_bias_q[b_indx], dim_q[indx],
         platform::errors::InvalidArgument(
             "bias_q's last dim size should equal to"
-            " q last dim size, but bias_q's size is:%d q is:%d",
-            dim_bias_q[b_size], dim_q[size]));
-    PADDLE_ENFORCE_EQ(dim_bias_k[b_size], dim_k[size],
+            " q last dim size, but received bias_q's size is:%d q is:%d",
+            dim_bias_q[b_indx], dim_q[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_k[b_indx], dim_k[indx],
         platform::errors::InvalidArgument(
             "bias_k's last dim size should equal to"
-            " k last dim size, but bias_k's size is:%d k is:%d",
-            dim_bias_k[b_size], dim_k[size]));
-    PADDLE_ENFORCE_EQ(dim_bias_v[b_size], dim_v[size],
+            " k last dim size, but received bias_k's size is:%d k is:%d",
+            dim_bias_k[b_indx], dim_k[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_v[b_indx], dim_v[indx],
         platform::errors::InvalidArgument(
             "bias_v's last dim size should equal to"
-            " v last dim size, but bias_v's size is:%d v is:%d",
-            dim_bias_v[b_size], dim_v[size]));
+            " v last dim size, but received bias_v's size is:%d v is:%d",
+            dim_bias_v[b_indx], dim_v[indx]));
     PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
                       platform::errors::InvalidArgument(
                           "q should have same batch size"
-                          "with bias_qk, but q's batch size:%d not equal to "
-                          "bias_qk's batch size:%d",
+                          "with bias_qk, but received q's batch size is:%d "
+                          "bias_qk's batch size is:%d",
                           dim_q[0], dim_bias_qk[0]));
     int head_number = context->Attrs().Get<int>("head_number");
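Note: illustrative shapes that would satisfy the checks in this hunk (the concrete numbers below are made-up example values, not taken from the op definition):

```cuda
// Hypothetical shapes that pass the checks above (example values only):
//   Q, K, V : [batch = 8, seq_len = 128, hidden = 768]
//   bias_q / bias_k / bias_v : last dim == 768 -> matches dim_q[indx] etc.
//   bias_qk : [8, head_number, 128, 128] -> at least 4-D,
//             and dim_bias_qk[0] == dim_q[0] == 8 (batch sizes match)
```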
@@ -197,6 +197,7 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int seq_len,
                                            const unsigned mask) {
   int qk_offset = blockIdx.x * seq_len;
+  assert(blockDim.x % 32 == 0);
   __shared__ float s_sum, s_max;
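The added assert documents an invariant rather than changing behavior: the softmax kernel's reductions use warp shuffles, which assume every warp in the block is fully populated. A minimal sketch of the failure mode a non-multiple-of-32 block would cause (warpReduceSum and sumKernel are illustrative stand-ins, not the kernel's actual helpers):

```cuda
#include <cassert>

#define FINAL_MASK 0xffffffff

// Illustrative warp-level sum; every lane of the warp must participate.
__device__ float warpReduceSum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, offset);
  return val;
}

__global__ void sumKernel(const float *in, float *out, int n) {
  // With blockDim.x == 48, threads 32..47 would form a half-empty warp:
  // __shfl_xor_sync with FINAL_MASK then names lanes that never execute,
  // which is undefined behavior. Rounding the block size up to a multiple
  // of 32 (as this commit does) keeps every warp full.
  assert(blockDim.x % 32 == 0);
  float v = (threadIdx.x < n) ? in[threadIdx.x] : 0.f;
  v = warpReduceSum(v);
  if (threadIdx.x % 32 == 0) atomicAdd(out, v);
}
```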
@@ -257,15 +258,29 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
       q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
       seq_len * size_per_head, seq_len * size_per_head);
-  int m = batch_size * head_num * seq_len;
-  int k = seq_len;
-  int grid = m;
-  int block = k;
+  int grid = batch_size * head_num * seq_len;
+  int block = seq_len;
+  // Align block to 32, also limit seq_len to max block size.
+  PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
+                                       "seq_len should <= 1024, "
+                                       "but received seq_len is:%d",
+                                       seq_len));
+  if (seq_len <= 32)
+    block = 32;
+  else if (seq_len > 32 && seq_len <= 64)
+    block = 64;
+  else if (seq_len > 64 && seq_len <= 128)
+    block = 128;
+  else if (seq_len > 128 && seq_len <= 256)
+    block = 256;
+  else if (seq_len > 256 && seq_len <= 512)
+    block = 512;
+  else
+    block = 1024;
-  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
 }

 template <typename T>
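With the ladder above, block is always one of {32, 64, 128, 256, 512, 1024}, so every warp is full and the constant FINAL_MASK (all 32 lanes) is always a valid shuffle mask; the old per-launch mask computation can be dropped. Since block may now exceed seq_len, the kernel is presumably expected to mask out the surplus threads itself. The ladder amounts to rounding up to a power of two with a floor of one warp (a sketch of equivalent logic, not the commit's code):

```cuda
// Round seq_len (1..1024) up to the next power of two, minimum 32.
// Produces the same values as the if/else ladder above.
inline int RoundUpBlock(int seq_len) {
  int block = 32;
  while (block < seq_len) block <<= 1;
  return block;  // 32, 64, 128, 256, 512, or 1024
}
```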