diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index e89092252fbcfaf777722de715ef1e189c8a50b3..8e75ecc473f2cb758afec60e183233367284d41f 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -229,7 +229,7 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad, phi::FlashAttnUnpaddedGradKernel, phi::dtype::float16, phi::dtype::bfloat16) { - kernel->InputAt(7).SetBackend(phi::Backend::CPU); // seed_offset + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } PD_REGISTER_KERNEL(flash_attn_grad, @@ -238,5 +238,5 @@ PD_REGISTER_KERNEL(flash_attn_grad, phi::FlashAttnGradKernel, phi::dtype::float16, phi::dtype::bfloat16) { - kernel->InputAt(5).SetBackend(phi::Backend::CPU); // seed_offset + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset }