Unverified commit a5a8d144, authored by zhaoyuchen2018, committed by GitHub

Fix softmax cuda bug (#21720)

* Fix softmax cuda bug

* Refine multihead log and softmax logic
Parent 943a4449
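The root cause, as the kernel hunk further below shows, is how each softmax row indexes into `bias_qk_`: the old code wrapped the offset around `head_num * seq_len`, which silently assumes a `BiasQK` broadcast over the batch dimension, while the op now takes a per-batch `BiasQK`. A minimal NumPy sketch of the two offset schemes (shapes and names here are illustrative only, not taken from the Paddle sources):

```python
import numpy as np

batch, head_num, seq_len = 2, 4, 8

# One qk row of length seq_len per (batch, head, query-position) triple,
# matching one CUDA block per row in the kernel below.
qk = np.random.random((batch * head_num * seq_len, seq_len)).astype("float32")
# BiasQK now carries a real batch dimension (see the test change at the bottom).
bias_qk = np.random.random((batch, head_num, seq_len, seq_len)).astype("float32")
bias_rows = bias_qk.reshape(-1, seq_len)

row = 37  # stands in for blockIdx.x

old_bias_row = row % (head_num * seq_len)  # old: wraps around per batch
new_bias_row = row                         # new: bias row matches the qk row

# For any row outside batch 0 the two schemes pick different bias slices.
print(old_bias_row, new_bias_row,
      np.allclose(bias_rows[old_bias_row], bias_rows[new_bias_row]))
```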
@@ -84,15 +84,36 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
"Multihead input bias should have same batch size");
PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
"Multihead input bias should have same size");
PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
"Multihead input bias should have same size");
auto dim_bias_qk = context->GetInputDim("BiasQK");
PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
"Multihead input bias qk should be at least 4-D tensor.");
int b_size = dim_bias_q.size() - 1;
int size = dim_q.size() - 1;
PADDLE_ENFORCE_EQ(dim_bias_q[b_size], dim_q[size],
platform::errors::InvalidArgument(
"bias_q's last dim size should equal to"
" q last dim size, but bias_q's size is:%d q is:%d",
dim_bias_q[b_size], dim_q[size]));
PADDLE_ENFORCE_EQ(dim_bias_k[b_size], dim_k[size],
platform::errors::InvalidArgument(
"bias_k's last dim size should equal to"
" k last dim size, but bias_k's size is:%d k is:%d",
dim_bias_k[b_size], dim_k[size]));
PADDLE_ENFORCE_EQ(dim_bias_v[b_size], dim_v[size],
platform::errors::InvalidArgument(
"bias_v's last dim size should equal to"
" v last dim size, but bias_v's size is:%d v is:%d",
dim_bias_v[b_size], dim_v[size]));
PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
platform::errors::InvalidArgument(
"q should have same batch size"
"with bias_qk, but q's batch size:%d not equal to "
"bias_qk's batch size:%d",
dim_q[0], dim_bias_qk[0]));
int head_number = context->Attrs().Get<int>("head_number");
PADDLE_ENFORCE_GT(head_number, 1,
"Multihead input head number should be at least 1.");
......
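For reference, the intent of the new enforce checks above reads roughly as follows in NumPy terms; the concrete shapes are only an assumed example, not taken from the op definition:

```python
import numpy as np

batch_size, seq_len, head_number, hidden = 2, 8, 4, 64

Q = np.zeros((batch_size, seq_len, hidden), dtype="float32")
BiasQ = np.zeros((1, hidden), dtype="float32")
BiasQK = np.zeros((batch_size, head_number, seq_len, seq_len), dtype="float32")

# Last dim of each bias must match the last dim of the corresponding input ...
assert BiasQ.shape[-1] == Q.shape[-1]
# ... BiasQK must be at least 4-D ...
assert BiasQK.ndim >= 4
# ... and its batch size must match Q's batch size (the new requirement).
assert BiasQK.shape[0] == Q.shape[0]
```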
@@ -196,15 +196,13 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
const int head_num,
const int seq_len,
const unsigned mask) {
-  int seq_id = blockIdx.x % seq_len;
  int qk_offset = blockIdx.x * seq_len;
-  int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
  __shared__ float s_sum, s_max;
  float qk = threadIdx.x < seq_len
                 ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
-                                      bias_qk_[threadIdx.x + bias_offset]))
+                                      bias_qk_[threadIdx.x + qk_offset]))
                 : 0.0f;
  float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
......
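A NumPy reference for what one block of `softmax_kernel_with_eltadd` computes after the fix (a sketch, assuming the usual max-subtracted softmax that the shared `s_max`/`s_sum` reductions implement):

```python
import numpy as np

def softmax_row_with_bias(qk_row, bias_row):
    # Elementwise add the bias, then a numerically stable softmax over the row,
    # mirroring the block-wide max/sum reductions in the CUDA kernel.
    x = qk_row.astype("float32") + bias_row.astype("float32")
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

seq_len = 8
qk_row = np.random.random(seq_len).astype("float32")
bias_row = np.random.random(seq_len).astype("float32")
print(softmax_row_with_bias(qk_row, bias_row).sum())  # ~1.0
```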
@@ -54,7 +54,8 @@ class TestFusedMultiheadMatmulOp(OpTest):
self.BiasK = np.random.random((1, w)).astype("float32")
self.BiasV = np.random.random((1, w)).astype("float32")
self.BiasQK = np.random.random(
-            (1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
+            (self.batch_size, self.head_number, self.seq_len,
+             self.seq_len)).astype("float32")
# Compute Q path
fc_q = self.Q + self.BiasQ
reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
......
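Downstream of the hunk above, the test's reference result amounts to ordinary batched attention with the per-batch bias added before the softmax; a compressed sketch with made-up sizes (the real test builds Q/K/V through the fused op's reshape/transpose path):

```python
import numpy as np

batch_size, seq_len, head_number, size_per_head = 1, 8, 2, 4

q = np.random.random((batch_size, head_number, seq_len, size_per_head)).astype("float32")
k = np.random.random((batch_size, head_number, seq_len, size_per_head)).astype("float32")
bias_qk = np.random.random(
    (batch_size, head_number, seq_len, seq_len)).astype("float32")

scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(size_per_head)
scores = scores + bias_qk  # per-batch bias, no broadcast over the batch dim
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)
print(weights.shape)  # (batch_size, head_number, seq_len, seq_len)
```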