未验证 提交 380a9bf7 编写于 作者: C Chitsing KUI 提交者: GitHub

fix backend bug (#52526)

上级 8ac5a6b6
...@@ -229,7 +229,7 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad, ...@@ -229,7 +229,7 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
phi::FlashAttnUnpaddedGradKernel, phi::FlashAttnUnpaddedGradKernel,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) { phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::CPU); // seed_offset kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
} }
PD_REGISTER_KERNEL(flash_attn_grad, PD_REGISTER_KERNEL(flash_attn_grad,
...@@ -238,5 +238,5 @@ PD_REGISTER_KERNEL(flash_attn_grad, ...@@ -238,5 +238,5 @@ PD_REGISTER_KERNEL(flash_attn_grad,
phi::FlashAttnGradKernel, phi::FlashAttnGradKernel,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) { phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(phi::Backend::CPU); // seed_offset kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册