Unverified · Commit 1164626c authored by Ruibiao Chen, committed by GitHub

Fix fused_attention_op and fused_feedforward_op bugs in xpu (#53318)

* Fix fused_attention_op and fused_feedforward_op bugs in xpu

* Fix d_x alloc errors for fused_feedforward_grad_kernel
Parent e72cad59
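The first hunk below adds `dev_ctx.template Alloc<T>(d_x);` so that the `d_x` output of the grad kernel is backed by device memory up front, alongside the other outputs, before the XPU calls write into it. As a rough illustration of why the output allocation has to come first, here is a minimal, self-contained sketch; `Tensor`, `ToyGradKernel`, and the `Alloc`/`data` members are simplified stand-ins, not the real phi API:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified stand-in for a tensor whose data pointer is only valid after an
// explicit Alloc, loosely mirroring how phi DenseTensor outputs behave.
struct Tensor {
  std::vector<float> storage;
  bool allocated = false;

  float* Alloc(std::size_t n) {  // analogue of dev_ctx.template Alloc<T>(d_x)
    storage.assign(n, 0.0f);
    allocated = true;
    return storage.data();
  }
  float* data() {  // writing through this before Alloc would be the bug
    assert(allocated && "output tensor used before it was allocated");
    return storage.data();
  }
};

// A toy grad "kernel" that writes its result into the d_x output tensor.
void ToyGradKernel(const std::vector<float>& d_out, Tensor* d_x) {
  float* out = d_x->data();  // relies on the caller having allocated d_x
  for (std::size_t i = 0; i < d_out.size(); ++i) out[i] = 2.0f * d_out[i];
}

int main() {
  std::vector<float> d_out = {1.0f, 2.0f, 3.0f};
  Tensor d_x;
  d_x.Alloc(d_out.size());  // conceptually, the line the patch adds
  ToyGradKernel(d_out, &d_x);
  std::printf("d_x[0] = %.1f\n", d_x.data()[0]);
  return 0;
}
```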
@@ -478,7 +478,7 @@ void FusedFeedForwardGradKernel(
                     dropout2_fix_seed,
                     nullptr,
                     dropout2_seed_val);
+  dev_ctx.template Alloc<T>(d_x);
   dev_ctx.template Alloc<float>(d_ln_scale);
   dev_ctx.template Alloc<float>(d_ln_bias);
   dev_ctx.template Alloc<T>(d_linear1_bias);
@@ -529,7 +529,7 @@ void FusedFeedForwardGradKernel(
 }  // namespace fusion
 }  // namespace phi

-PD_REGISTER_KERNEL(fused_feedward_grad,
+PD_REGISTER_KERNEL(fused_feedforward_grad,
                    XPU,
                    ALL_LAYOUT,
                    phi::fusion::FusedFeedForwardGradKernel,
......
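The second hunk in this file, and the matching one in the forward kernel below, fix the registration string: `PD_REGISTER_KERNEL` keys the kernel by name, so a kernel registered under the misspelled `fused_feedward_grad` is never found when the framework looks up `fused_feedforward_grad`. A toy, name-keyed registry showing the effect of the typo (this is an illustrative stand-in, not the real phi registry or macro):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy registry keyed by kernel name. The real phi registry (and the
// PD_REGISTER_KERNEL macro) is far richer, but lookup is still name-based.
std::map<std::string, std::function<void()>>& Registry() {
  static std::map<std::string, std::function<void()>> registry;
  return registry;
}

void RegisterKernel(const std::string& name, std::function<void()> kernel) {
  Registry()[name] = std::move(kernel);
}

void RunOp(const std::string& op_name) {
  auto it = Registry().find(op_name);
  if (it == Registry().end()) {
    std::cout << "no kernel registered for '" << op_name << "'\n";
    return;
  }
  it->second();
}

int main() {
  // With the typo, the kernel sits under a name no operator ever asks for.
  RegisterKernel("fused_feedward_grad", [] { std::cout << "grad kernel ran\n"; });
  RunOp("fused_feedforward_grad");  // lookup misses

  // After the fix, the registered name matches the operator name.
  RegisterKernel("fused_feedforward_grad", [] { std::cout << "grad kernel ran\n"; });
  RunOp("fused_feedforward_grad");  // kernel found and executed
  return 0;
}
```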
@@ -377,7 +377,7 @@ void FusedFeedForwardKernel(const Context& dev_ctx,
 }  // namespace fusion
 }  // namespace phi

-PD_REGISTER_KERNEL(fused_feedward,
+PD_REGISTER_KERNEL(fused_feedforward,
                    XPU,
                    ALL_LAYOUT,
                    phi::fusion::FusedFeedForwardKernel,
......
@@ -181,12 +181,12 @@ void FusedAttentionKernel(const Context &dev_ctx,
   float *ln_mean_ptr =
       (ln_mean == nullptr)
           ? (nullptr)
-          : reinterpret_cast<float *>(dev_ctx.template Alloc<T>(ln_mean));
+          : reinterpret_cast<float *>(dev_ctx.template Alloc<float>(ln_mean));
   float *ln_var_ptr =
       (ln_var == nullptr)
           ? (nullptr)
-          : reinterpret_cast<float *>(dev_ctx.template Alloc<T>(ln_var));
+          : reinterpret_cast<float *>(dev_ctx.template Alloc<float>(ln_var));
   XPUTypeT *ln_out_ptr =
       (ln_out == nullptr)
......
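The attention-kernel change replaces `Alloc<T>` with `Alloc<float>` for the LayerNorm mean and variance outputs. Those statistics are consumed as `float *`; when `T` is a 16-bit type, allocating by `T` produces a buffer typed and sized for half precision and then reinterprets it as `float *`, which under-allocates the tensor. A simplified, self-contained sketch of that size mismatch (the `float16` struct and the `Alloc` helper here are illustrative stand-ins, not the phi implementations):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for a 16-bit floating-point storage type (e.g. phi::dtype::float16).
struct float16 { std::uint16_t bits; };

// Hypothetical mini version of dev_ctx.template Alloc<T>(tensor): like the
// real call, the buffer it hands back is sized by sizeof(T) * numel.
template <typename T>
T* Alloc(std::vector<std::uint8_t>& storage, std::size_t numel) {
  storage.resize(sizeof(T) * numel);
  return reinterpret_cast<T*>(storage.data());
}

int main() {
  const std::size_t numel = 8;  // e.g. one mean/variance value per row

  // Buggy pattern: allocate by the compute dtype T (= float16) and then
  // reinterpret the pointer as float*. Only numel * 2 bytes are backed,
  // while the consumer expects numel * 4 bytes of float statistics.
  std::vector<std::uint8_t> bad_buf;
  float* bad = reinterpret_cast<float*>(Alloc<float16>(bad_buf, numel));
  (void)bad;
  std::printf("allocated as float16: %zu bytes (float consumer needs %zu)\n",
              bad_buf.size(), numel * sizeof(float));

  // Fixed pattern (what the patch does): allocate the statistics tensors as
  // float regardless of T, so the reinterpretation is a no-op.
  std::vector<std::uint8_t> good_buf;
  float* good = Alloc<float>(good_buf, numel);
  (void)good;
  std::printf("allocated as float:   %zu bytes\n", good_buf.size());
  return 0;
}
```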