From a5a8d14414213fadcfcd7dc60c794d1a515a390e Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Mon, 16 Dec 2019 20:40:23 +0800
Subject: [PATCH] Fix softmax cuda bug (#21720)

* Fix softmax cuda bug

* Refine multihead log and softmax logic
---
 .../operators/fused/multihead_matmul_op.cc    | 31 ++++++++++++++++---
 .../operators/fused/multihead_matmul_op.cu    |  4 +--
 .../test_fused_multihead_matmul_op.py         |  3 +-
 3 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cc b/paddle/fluid/operators/fused/multihead_matmul_op.cc
index fbf372ba6e1..e1b3822e482 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cc
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cc
@@ -84,15 +84,36 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
                       "Multihead input bias should have same batch size");
 
-    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
-                      "Multihead input bias should have same size");
-    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
-                      "Multihead input bias should have same size");
-
     auto dim_bias_qk = context->GetInputDim("BiasQK");
     PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
                       "Multihead input bias qk should be at least 4-D tensor.");
 
+    int b_size = dim_bias_q.size() - 1;
+    int size = dim_q.size() - 1;
+
+    PADDLE_ENFORCE_EQ(dim_bias_q[b_size], dim_q[size],
+                      platform::errors::InvalidArgument(
+                          "bias_q's last dim size should equal to"
+                          " q last dim size, but bias_q's size is:%d q is:%d",
+                          dim_bias_q[b_size], dim_q[size]));
+    PADDLE_ENFORCE_EQ(dim_bias_k[b_size], dim_k[size],
+                      platform::errors::InvalidArgument(
+                          "bias_k's last dim size should equal to"
+                          " k last dim size, but bias_k's size is:%d k is:%d",
+                          dim_bias_k[b_size], dim_k[size]));
+    PADDLE_ENFORCE_EQ(dim_bias_v[b_size], dim_v[size],
+                      platform::errors::InvalidArgument(
+                          "bias_v's last dim size should equal to"
+                          " v last dim size, but bias_v's size is:%d v is:%d",
+                          dim_bias_v[b_size], dim_v[size]));
+
+    PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
+                      platform::errors::InvalidArgument(
+                          "q should have same batch size"
+                          "with bias_qk, but q's batch size:%d not equal to "
+                          "bias_qk's batch size:%d",
+                          dim_q[0], dim_bias_qk[0]));
+
     int head_number = context->Attrs().Get<int>("head_number");
     PADDLE_ENFORCE_GT(head_number, 1,
                       "Multihead input head number should be at least 1.");
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index 74bc7731a93..1c8b422917f 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -196,15 +196,13 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int head_num, const int seq_len,
                                            const unsigned mask) {
-  int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
-  int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
 
   __shared__ float s_sum, s_max;
 
   float qk = threadIdx.x < seq_len
                 ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
-                                      bias_qk_[threadIdx.x + bias_offset]))
+                                      bias_qk_[threadIdx.x + qk_offset]))
                 : 0.0f;
  float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
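
For reference, the kernel change above removes the separate bias index entirely. softmax_kernel_with_eltadd launches one block per row of the flattened (batch, head, seq) score layout, so row blockIdx.x of qk_buf_ starts at blockIdx.x * seq_len. The deleted bias_offset wrapped that index modulo head_num * seq_len, reading the same bias rows for every batch; once BiasQK carries a real batch dimension (see the test change below), the bias row must be fetched with the same qk_offset as the scores. A minimal NumPy sketch of the two indexings (illustrative shapes and names only, not Paddle code):

    import numpy as np

    batch, heads, seq = 2, 4, 8
    bias = np.random.random((batch, heads, seq, seq)).astype("float32")
    flat_bias = bias.reshape(-1, seq)  # the flat row layout the kernel sees

    for block_idx in range(batch * heads * seq):
        new_bias_row = block_idx                  # fixed: same row as qk_buf_
        old_bias_row = block_idx % (heads * seq)  # removed: wraps back to batch 0
        # The two agree only while block_idx is inside batch 0, i.e. only
        # when BiasQK was a broadcast tensor with batch size 1.
        assert (old_bias_row == new_bias_row) == (block_idx < heads * seq)
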
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
index ffb4cee8fb4..9890cbb1222 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
@@ -54,7 +54,8 @@ class TestFusedMultiheadMatmulOp(OpTest):
         self.BiasK = np.random.random((1, w)).astype("float32")
         self.BiasV = np.random.random((1, w)).astype("float32")
         self.BiasQK = np.random.random(
-            (1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
+            (self.batch_size, self.head_number, self.seq_len,
+             self.seq_len)).astype("float32")
         # Compute Q path
         fc_q = self.Q + self.BiasQ
         reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
-- 
GitLab
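
With the shape fix, the unit test now draws BiasQK with a leading batch dimension, which is what the new dim_q[0] == dim_bias_qk[0] check in the .cc file enforces. As a rough reference for what the fused kernel computes per row, here is a hedged NumPy sketch of the softmax over (QK + BiasQK); the max-subtraction mirrors the kernel's s_max / -1e20f padding, and all sizes are made-up examples rather than the test's actual configuration:

    import numpy as np

    batch_size, head_number, seq_len = 2, 4, 8  # example sizes only
    qk = np.random.random(
        (batch_size, head_number, seq_len, seq_len)).astype("float32")
    bias_qk = np.random.random(
        (batch_size, head_number, seq_len, seq_len)).astype("float32")

    x = qk + bias_qk                       # the "eltadd" part of the kernel
    x = x - x.max(axis=-1, keepdims=True)  # stable softmax, cf. s_max
    e = np.exp(x)
    out = e / e.sum(axis=-1, keepdims=True)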