未验证 提交 380a9bf7 编写于 作者: C Chitsing KUI 提交者: GitHub

fix backend bug (#52526)

上级 8ac5a6b6
......@@ -229,7 +229,7 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
phi::FlashAttnUnpaddedGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::CPU); // seed_offset
kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
PD_REGISTER_KERNEL(flash_attn_grad,
......@@ -238,5 +238,5 @@ PD_REGISTER_KERNEL(flash_attn_grad,
phi::FlashAttnGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(phi::Backend::CPU); // seed_offset
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册