Unverified commit 6bd39b5e, authored by WangXi, committed by GitHub

fix inf in fused_attention (#41933)

Parent 1d9ee667
@@ -16,6 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
@@ -117,6 +119,18 @@ class FMHARef {
       v_ptr = k_ptr + k_size;
     }
 
+    {
+      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because with
+      // float16 computation INF may appear in QK^T if we do not scale beforehand.
+      float alpha = 1.0 / sqrt(head_dim_);
+      auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
+      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
+      std::vector<const framework::Tensor*> ins = {&q_tensor};
+      std::vector<framework::Tensor*> outs = {&q_tensor};
+      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx_, ins,
+                                                                &outs, functor);
+    }
+
     // q*k^t, batched_gemm
     CBLAS_TRANSPOSE transA = CblasNoTrans;
     CBLAS_TRANSPOSE transB = CblasTrans;
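
The NOTE above can be checked with a quick back-of-the-envelope calculation: float16 only represents finite values up to about 65504, and an attention logit is a dot product over head_dim terms, so moderately large Q and K values already push the unscaled QK^T past that limit. The following standalone C++ sketch (illustrative values only, not Paddle code; head_dim and the element magnitudes are assumptions) makes the point concrete.

#include <cmath>
#include <cstdio>

int main() {
  const int head_dim = 128;         // assumed Dh
  const float q_val = 30.0f;        // assumed magnitude of a Q element
  const float k_val = 30.0f;        // assumed magnitude of a K element
  const float fp16_max = 65504.0f;  // largest finite float16 value

  // Unscaled logit: sum over head_dim of q * k.
  float unscaled = head_dim * q_val * k_val;  // 115200, not representable in fp16
  // Pre-scaled logit: Q is multiplied by 1/sqrt(Dh) before the dot product.
  float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  float prescaled = head_dim * (q_val * scale) * k_val;  // about 10183, fits in fp16

  std::printf("unscaled  = %.1f, overflows fp16: %s\n", unscaled,
              unscaled > fp16_max ? "yes" : "no");
  std::printf("prescaled = %.1f, overflows fp16: %s\n", prescaled,
              prescaled > fp16_max ? "yes" : "no");
  return 0;
}

Scaling the final QK^T by the same factor gives the same value mathematically, but only after the out-of-range intermediate has been produced, which is presumably why the commit applies the scale to Q before the batched GEMM rather than through the GEMM's alpha.
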
@@ -125,7 +139,7 @@ class FMHARef {
     int gemm_m = seq_len_;
     int gemm_n = out_seq_len;
     int gemm_k = head_dim_;
-    T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     int64_t stride_a = gemm_m * gemm_k;
     int64_t stride_b = gemm_k * gemm_n;
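
For the forward GEMM itself, setting alpha to 1.0 is bookkeeping: with s = 1/sqrt(Dh), (s*Q)*K^T equals s*(Q*K^T), so once Q is pre-scaled the GEMM must not apply s again or the logits would be scaled twice. A minimal C++ check of that identity on a single dot product (toy values, nothing Paddle-specific):

#include <cassert>
#include <cmath>

int main() {
  const int dh = 4;  // toy head_dim
  const double q[] = {1.0, 2.0, 3.0, 4.0};
  const double k[] = {4.0, 3.0, 2.0, 1.0};
  const double s = 1.0 / std::sqrt(static_cast<double>(dh));

  double old_way = 0.0;  // old code: plain dot product, scale applied via GEMM alpha
  double new_way = 0.0;  // new code: Q pre-scaled, GEMM alpha = 1.0
  for (int d = 0; d < dh; ++d) {
    old_way += q[d] * k[d];
    new_way += (q[d] * s) * k[d];
  }
  old_way *= s;

  assert(std::fabs(old_way - new_way) < 1e-12);  // same logit either way
  return 0;
}
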
@@ -300,7 +314,9 @@ class FMHARef {
     }
     T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
-    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    // NOTE(wangxi): Since we scale Q with 1/sqrt(Dh) in the forward pass, we
+    // set alpha = 1.0 in the backward pass.
+    alpha = static_cast<T>(1.0);
     // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
     // bw: dy (seq_len * head_dim) = (dout)^t * x
     transA = CblasTrans;
@@ -314,6 +330,7 @@ class FMHARef {
         qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size,
         stride_a, stride_b);
     // dx (seq_len * head_dim) = dout * y
+    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
     transA = CblasNoTrans;
     transB = CblasNoTrans;
     gemm_m = seq_len_;
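
Reading the two backward changes together (my interpretation of the diff, not text from the commit): with s = 1/sqrt(Dh), the forward pass now stores q' = s * q in place and computes qk_out = q' * K^T with alpha = 1.0, so the gradients work out as follows.

  d_k  = d(qk_out)^T * q'    uses alpha = 1.0, since q' already carries the factor s
  d_q' = d(qk_out) * K       gradient with respect to the scaled q'
  d_q  = s * d_q'            chain rule through q' = s * q, which is why alpha is
                             reset to 1/sqrt(head_dim_) before the dx GEMM above
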
......