未验证 提交 55b974e7 编写于 作者: H haosicheng 提交者: GitHub

[XPU] fix layer_norm_grad bug when bias_grad and scale_grad are nullptr (#54669)

上级 26980b7b
......@@ -63,11 +63,15 @@ void LayerNormGradKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
scale_data_fp32 = scale_data_temp;
need_cast_scale = true;
scale_grad_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
scale_grad_data_fp32 =
scale_grad == nullptr
? nullptr
: RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
} else {
// no need to cast
scale_data_fp32 = scale_ptr->data<float>();
scale_grad_data_fp32 = ctx.template Alloc<float>(scale_grad);
scale_grad_data_fp32 =
scale_grad == nullptr ? nullptr : ctx.template Alloc<float>(scale_grad);
}
// bias
......@@ -79,10 +83,14 @@ void LayerNormGradKernel(const Context& ctx,
} else if (bias_ptr->dtype() ==
phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
need_cast_bias = true;
bias_grad_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
bias_grad_data_fp32 =
bias_grad == nullptr
? nullptr
: RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
} else {
// no need to cast
bias_grad_data_fp32 = ctx.template Alloc<float>(bias_grad);
bias_grad_data_fp32 =
bias_grad == nullptr ? nullptr : ctx.template Alloc<float>(bias_grad);
}
auto* x_grad_data =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册