From 8859ddd6cff0a501230fec02f0640b83862842ef Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Fri, 27 Dec 2019 21:49:38 +0800
Subject: [PATCH] Refine multihead kernel, align block to 32 (#21961)

* Refine multihead kernel, align block to 32

test=develop
Signed-off-by: zhaoyuchen

* Refine log comments

test=develop
Signed-off-by: zhaoyuchen
---
 .../operators/fused/multihead_matmul_op.cc    | 43 ++++++++++---------
 .../operators/fused/multihead_matmul_op.cu    | 29 ++++++++++---
 2 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cc b/paddle/fluid/operators/fused/multihead_matmul_op.cc
index e1b3822e482..b82cfb2d814 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cc
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cc
@@ -88,30 +88,33 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
                       "Multihead input bias qk should be at least 4-D tensor.");
 
-    int b_size = dim_bias_q.size() - 1;
-    int size = dim_q.size() - 1;
-
-    PADDLE_ENFORCE_EQ(dim_bias_q[b_size], dim_q[size],
-                      platform::errors::InvalidArgument(
-                          "bias_q's last dim size should equal to"
-                          " q last dim size, but bias_q's size is:%d q is:%d",
-                          dim_bias_q[b_size], dim_q[size]));
-    PADDLE_ENFORCE_EQ(dim_bias_k[b_size], dim_k[size],
-                      platform::errors::InvalidArgument(
-                          "bias_k's last dim size should equal to"
-                          " k last dim size, but bias_k's size is:%d k is:%d",
-                          dim_bias_k[b_size], dim_k[size]));
-    PADDLE_ENFORCE_EQ(dim_bias_v[b_size], dim_v[size],
-                      platform::errors::InvalidArgument(
-                          "bias_v's last dim size should equal to"
-                          " v last dim size, but bias_v's size is:%d v is:%d",
-                          dim_bias_v[b_size], dim_v[size]));
+    int b_indx = dim_bias_q.size() - 1;
+    int indx = dim_q.size() - 1;
+
+    PADDLE_ENFORCE_EQ(
+        dim_bias_q[b_indx], dim_q[indx],
+        platform::errors::InvalidArgument(
+            "bias_q's last dim size should equal to"
+            " q last dim size, but received bias_q's size is:%d q is:%d",
+            dim_bias_q[b_indx], dim_q[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_k[b_indx], dim_k[indx],
+        platform::errors::InvalidArgument(
+            "bias_k's last dim size should equal to"
+            " k last dim size, but received bias_k's size is:%d k is:%d",
+            dim_bias_k[b_indx], dim_k[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_v[b_indx], dim_v[indx],
+        platform::errors::InvalidArgument(
+            "bias_v's last dim size should equal to"
+            " v last dim size, but received bias_v's size is:%d v is:%d",
+            dim_bias_v[b_indx], dim_v[indx]));
 
     PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
                       platform::errors::InvalidArgument(
                           "q should have same batch size"
-                          "with bias_qk, but q's batch size:%d not equal to "
-                          "bias_qk's batch size:%d",
+                          "with bias_qk, but received q's batch size is:%d "
+                          "bias_qk's batch size is:%d",
                           dim_q[0], dim_bias_qk[0]));
 
     int head_number = context->Attrs().Get<int>("head_number");
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index 1c8b422917f..9648b62423e 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -197,6 +197,7 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int seq_len,
                                            const unsigned mask) {
   int qk_offset = blockIdx.x * seq_len;
+  assert(blockDim.x % 32 == 0);
 
   __shared__ float s_sum, s_max;
 
@@ -257,15 +258,29 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
                    q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
                    seq_len * size_per_head, seq_len * size_per_head);
 
-  int m = batch_size * head_num * seq_len;
-  int k = seq_len;
-
-  int grid = m;
-  int block = k;
+  int grid = batch_size * head_num * seq_len;
+  int block = seq_len;
+
+  // Align block to 32, also limit seq_len to max block size.
+  PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
+                                       "seq_len should <= 1024, "
+                                       "but received seq_len is:%d",
+                                       seq_len));
+  if (seq_len <= 32)
+    block = 32;
+  else if (seq_len > 32 && seq_len <= 64)
+    block = 64;
+  else if (seq_len > 64 && seq_len <= 128)
+    block = 128;
+  else if (seq_len > 128 && seq_len <= 256)
+    block = 256;
+  else if (seq_len > 256 && seq_len <= 512)
+    block = 512;
+  else
+    block = 1024;
 
-  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
 }
 
 template <typename T>
--
GitLab
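
Editor's note on the change above: the softmax kernel takes an `unsigned mask`
for its warp-shuffle reductions, and passing FINAL_MASK (presumably the
all-lanes 0xffffffff mask) unconditionally is only safe when every warp in the
block is fully populated, i.e. when blockDim.x is a multiple of 32. That is why
the patch rounds `block` up to a power of two in [32, 1024], deletes the old
per-launch mask expression, and adds the assert(blockDim.x % 32 == 0) guard.
Below is a minimal, self-contained sketch of the same rounding rule;
RoundBlockTo32 is a hypothetical helper written for illustration, not a
function from the patch.

// round_block_sketch.cc - illustration only, not part of the patch.
// Mirrors the if/else ladder added to MatMulWithHeadQK: round seq_len up
// to the next power-of-two CUDA block size, at least one warp (32) and at
// most the 1024-thread limit the patch enforces with PADDLE_ENFORCE_LE.
#include <cassert>
#include <cstdio>

static int RoundBlockTo32(int seq_len) {
  assert(seq_len >= 1 && seq_len <= 1024);  // patch rejects seq_len > 1024
  int block = 32;                           // warp size; smallest block used
  while (block < seq_len) block <<= 1;      // 32, 64, 128, 256, 512, 1024
  return block;
}

int main() {
  const int cases[] = {1, 32, 33, 100, 128, 500, 1024};
  for (int seq_len : cases)
    std::printf("seq_len=%4d -> block=%4d\n", seq_len, RoundBlockTo32(seq_len));
  return 0;
}

For example, seq_len = 33 launches with block = 64 and seq_len = 100 with
block = 128, exactly as the if/else ladder in the diff selects.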