diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index 54e4cbdc1624921e6946210a6a192d10fcbdb7dd..6eb5881112f8916d93f3e7531f0a8d87f83bf340 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -16,6 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
@@ -117,6 +119,18 @@ class FMHARef {
       v_ptr = k_ptr + k_size;
     }
 
+    {
+      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because with
+      // float16 arithmetic, INF may appear in QK^T if we do not scale first.
+      float alpha = 1.0 / sqrt(head_dim_);
+      auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
+      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
+      std::vector<const framework::Tensor*> ins = {&q_tensor};
+      std::vector<framework::Tensor*> outs = {&q_tensor};
+      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx_, ins,
+                                                                &outs, functor);
+    }
+
     // q*k^t, batched_gemm
     CBLAS_TRANSPOSE transA = CblasNoTrans;
     CBLAS_TRANSPOSE transB = CblasTrans;
@@ -125,7 +139,7 @@ class FMHARef {
     int gemm_m = seq_len_;
     int gemm_n = out_seq_len;
     int gemm_k = head_dim_;
-    T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     int64_t stride_a = gemm_m * gemm_k;
     int64_t stride_b = gemm_k * gemm_n;
@@ -300,7 +314,9 @@ class FMHARef {
     }
     T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
 
-    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    // NOTE(wangxi): Since we scale Q with 1/sqrt(Dh) in the forward pass, we
+    // set alpha = 1.0 here in the backward pass.
+    alpha = static_cast<T>(1.0);
     // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
     // bw: dy (seq_len * head_dim) = (dout)^t * x
     transA = CblasTrans;
@@ -314,6 +330,7 @@ class FMHARef {
                     qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size,
                     stride_a, stride_b);
     // dx (seq_len * head_dim) = dout * y
+    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
    transA = CblasNoTrans;
    transB = CblasNoTrans;
    gemm_m = seq_len_;
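
The sketch below is not Paddle code; it is a minimal host-side illustration of the numerical issue the diff addresses, with hypothetical helper names (`ToFp16`, `Fp16Dot`) and an assumed head_dim of 128. It emulates only the limited range of float16 (finite max 65504, overflow to INF, rounding ignored) to show why applying 1/sqrt(Dh) as the GEMM alpha after QK^T can still yield INF, while scaling Q before the GEMM keeps every intermediate value finite.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

constexpr float kFp16Max = 65504.0f;  // largest finite float16 value

// Emulate a float16 store: overflow becomes INF (rounding is ignored here,
// only the limited range matters for this illustration).
float ToFp16(float x) {
  return std::fabs(x) > kFp16Max ? std::copysign(INFINITY, x) : x;
}

// Dot product whose every intermediate value is pushed through the emulated
// float16 store, mimicking a GEMM that accumulates in float16.
float Fp16Dot(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.0f;
  for (size_t i = 0; i < a.size(); ++i) {
    acc = ToFp16(acc + ToFp16(a[i] * b[i]));
  }
  return acc;
}

int main() {
  const int head_dim = 128;  // assumed Dh for the example
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  // Moderately large activations; each value is well inside float16 range.
  std::vector<float> q(head_dim, 24.0f), k(head_dim, 24.0f);

  // Old order: run QK^T in float16, then apply alpha = 1/sqrt(Dh).
  // The raw accumulation reaches 128 * 576 = 73728 > 65504, so it hits INF
  // before the scale is applied, and INF * scale is still INF.
  std::printf("scale applied after QK^T : %f\n", Fp16Dot(q, k) * scale);

  // New order (this diff): scale Q first, then run the GEMM with alpha = 1.
  // Every partial sum now stays within float16 range (final value ~6517).
  std::vector<float> q_scaled = q;
  for (float& v : q_scaled) v *= scale;
  std::printf("Q scaled before QK^T     : %f\n", Fp16Dot(q_scaled, k));
  return 0;
}
```

Since the forward pass now stores Q pre-scaled, the backward GEMMs that consume the stored Q (qk_out and dK) run with alpha = 1.0, and only the dQ GEMM in the last hunk restores alpha = 1/sqrt(Dh) so the gradient is expressed with respect to the original, unscaled Q.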