diff --git a/paddle/fluid/operators/multihead_matmul_op.cc b/paddle/fluid/operators/multihead_matmul_op.cc
index fbf372ba6e15aca7b849a8696ac5551dc383ee51..b82cfb2d81422210addcdd7c3b6955263e769113 100644
--- a/paddle/fluid/operators/multihead_matmul_op.cc
+++ b/paddle/fluid/operators/multihead_matmul_op.cc
@@ -84,15 +84,39 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
                       "Multihead input bias should have same batch size");
 
-    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
-                      "Multihead input bias should have same size");
-    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
-                      "Multihead input bias should have same size");
-
     auto dim_bias_qk = context->GetInputDim("BiasQK");
     PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
                       "Multihead input bias qk should be at least 4-D tensor.");
 
+    int b_indx = dim_bias_q.size() - 1;
+    int indx = dim_q.size() - 1;
+
+    PADDLE_ENFORCE_EQ(
+        dim_bias_q[b_indx], dim_q[indx],
+        platform::errors::InvalidArgument(
+            "bias_q's last dim size should equal to"
+            " q last dim size, but received bias_q's size is:%d q is:%d",
+            dim_bias_q[b_indx], dim_q[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_k[b_indx], dim_k[indx],
+        platform::errors::InvalidArgument(
+            "bias_k's last dim size should equal to"
+            " k last dim size, but received bias_k's size is:%d k is:%d",
+            dim_bias_k[b_indx], dim_k[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_v[b_indx], dim_v[indx],
+        platform::errors::InvalidArgument(
+            "bias_v's last dim size should equal to"
+            " v last dim size, but received bias_v's size is:%d v is:%d",
+            dim_bias_v[b_indx], dim_v[indx]));
+
+    PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
+                      platform::errors::InvalidArgument(
+                          "q should have same batch size"
+                          "with bias_qk, but received q's batch size is:%d "
+                          "bias_qk's batch size is:%d",
+                          dim_q[0], dim_bias_qk[0]));
+
     int head_number = context->Attrs().Get<int>("head_number");
     PADDLE_ENFORCE_GT(head_number, 1,
                       "Multihead input head number should be at least 1.");
diff --git a/paddle/fluid/operators/multihead_matmul_op.cu b/paddle/fluid/operators/multihead_matmul_op.cu
index 74bc7731a93dec045d03bde46627bbd57d11daca..9648b62423e94528f5c73069baee9724c4babfe9 100644
--- a/paddle/fluid/operators/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/multihead_matmul_op.cu
@@ -196,15 +196,14 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
                                            const int seq_len,
                                            const unsigned mask) {
-  int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
-  int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
+  assert(blockDim.x % 32 == 0);
 
   __shared__ float s_sum, s_max;
 
   float qk = threadIdx.x < seq_len
                  ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
-                                       bias_qk_[threadIdx.x + bias_offset]))
+                                       bias_qk_[threadIdx.x + qk_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
@@ -259,15 +258,29 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
                    q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
                    seq_len * size_per_head, seq_len * size_per_head);
 
-  int m = batch_size * head_num * seq_len;
-  int k = seq_len;
-
-  int grid = m;
-  int block = k;
+  int grid = batch_size * head_num * seq_len;
+  int block = seq_len;
+
+  // Align block to 32, also limit seq_len to max block size.
+  PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
+                                       "seq_len should <= 1024, "
+                                       "but received seq_len is:%d",
+                                       seq_len));
+  if (seq_len <= 32)
+    block = 32;
+  else if (seq_len > 32 && seq_len <= 64)
+    block = 64;
+  else if (seq_len > 64 && seq_len <= 128)
+    block = 128;
+  else if (seq_len > 128 && seq_len <= 256)
+    block = 256;
+  else if (seq_len > 256 && seq_len <= 512)
+    block = 512;
+  else
+    block = 1024;
 
-  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
 }
 
 template <typename T>
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
index ffb4cee8fb4d58d485b26128346e200f8f492611..9890cbb12220a352b0d626a01587fd3497745543 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
@@ -54,7 +54,8 @@ class TestFusedMultiheadMatmulOp(OpTest):
         self.BiasK = np.random.random((1, w)).astype("float32")
         self.BiasV = np.random.random((1, w)).astype("float32")
         self.BiasQK = np.random.random(
-            (1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
+            (self.batch_size, self.head_number, self.seq_len,
+             self.seq_len)).astype("float32")
         # Compute Q path
         fc_q = self.Q + self.BiasQ
         reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
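Reviewer note on the softmax hunk: the switch from bias_offset to qk_offset follows from the new BiasQK shape. Once BiasQK is (batch_size, head_num, seq_len, seq_len) rather than (1, head_num, seq_len, seq_len), it shares the layout of qk_buf_, so each row's bias sits at the same flat offset as its qk row, and the modulo-wrapped bias_offset (correct only for a batch-1 bias) can be dropped. A quick index-arithmetic check, plain C++ and purely illustrative:

    #include <cassert>

    int main() {
      const int head_num = 4, seq_len = 8;
      // One CUDA block per (batch, head, row) triple, as in the kernel launch.
      // Within the first batch the old wrapped offset matches the direct one.
      int block_id = 5;
      assert(block_id % (head_num * seq_len) * seq_len == block_id * seq_len);
      // From the second batch on (block_id >= head_num * seq_len) the old
      // formula wraps back into batch 0, which is correct only when BiasQK
      // has batch size 1 and wrong for the new full-batch BiasQK.
      block_id = head_num * seq_len + 5;
      assert(block_id % (head_num * seq_len) * seq_len != block_id * seq_len);
      return 0;
    }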
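On the launch-config hunk: the if/else ladder rounds seq_len up to the next power of two, clamped to one warp (32 threads) below and the CUDA per-block limit (1024 threads) above, so the warp-shuffle reductions inside softmax_kernel_with_eltadd always see a blockDim.x that is a multiple of 32 (hence the added assert, and why FINAL_MASK can replace the computed mask). A minimal standalone sketch of the same selection logic; the helper name RoundUpBlock is hypothetical and not part of this patch:

    #include <cassert>

    // Next power of two >= seq_len, clamped to [32, 1024]; mirrors the
    // if/else ladder above. The caller must enforce seq_len <= 1024 first,
    // as the patch does with PADDLE_ENFORCE_LE.
    static int RoundUpBlock(int seq_len) {
      assert(seq_len >= 1 && seq_len <= 1024);
      int block = 32;
      while (block < seq_len) block <<= 1;
      return block;
    }

    int main() {
      assert(RoundUpBlock(1) == 32);      // clamped up to one warp
      assert(RoundUpBlock(33) == 64);     // same answers as the ladder
      assert(RoundUpBlock(500) == 512);
      assert(RoundUpBlock(1024) == 1024);
      return 0;
    }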