Unverified commit 1164626c, authored by Ruibiao Chen, committed via GitHub

Fix fused_attention_op and fused_feedforward_op bugs in xpu (#53318)

* Fix fused_attention_op and fused_feedforward_op bugs in xpu

* Fix d_x alloc errors for fused_feedforward_grad_kernel
Parent commit: e72cad59
......@@ -478,7 +478,7 @@ void FusedFeedForwardGradKernel(
dropout2_fix_seed,
nullptr,
dropout2_seed_val);
dev_ctx.template Alloc<T>(d_x);
dev_ctx.template Alloc<float>(d_ln_scale);
dev_ctx.template Alloc<float>(d_ln_bias);
dev_ctx.template Alloc<T>(d_linear1_bias);
......@@ -529,7 +529,7 @@ void FusedFeedForwardGradKernel(
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_feedward_grad,
PD_REGISTER_KERNEL(fused_feedforward_grad,
XPU,
ALL_LAYOUT,
phi::fusion::FusedFeedForwardGradKernel,
......
......@@ -377,7 +377,7 @@ void FusedFeedForwardKernel(const Context& dev_ctx,
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_feedward,
PD_REGISTER_KERNEL(fused_feedforward,
XPU,
ALL_LAYOUT,
phi::fusion::FusedFeedForwardKernel,
......
......@@ -181,12 +181,12 @@ void FusedAttentionKernel(const Context &dev_ctx,
float *ln_mean_ptr =
(ln_mean == nullptr)
? (nullptr)
: reinterpret_cast<float *>(dev_ctx.template Alloc<T>(ln_mean));
: reinterpret_cast<float *>(dev_ctx.template Alloc<float>(ln_mean));
float *ln_var_ptr =
(ln_var == nullptr)
? (nullptr)
: reinterpret_cast<float *>(dev_ctx.template Alloc<T>(ln_var));
: reinterpret_cast<float *>(dev_ctx.template Alloc<float>(ln_var));
XPUTypeT *ln_out_ptr =
(ln_out == nullptr)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.