未验证 提交 55b974e7 编写于 作者: H haosicheng 提交者: GitHub

[XPU] fix layer_norm_grad bug when bias_grad and scale_grad are nullptr (#54669)

上级 26980b7b
...@@ -63,11 +63,15 @@ void LayerNormGradKernel(const Context& ctx, ...@@ -63,11 +63,15 @@ void LayerNormGradKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
scale_data_fp32 = scale_data_temp; scale_data_fp32 = scale_data_temp;
need_cast_scale = true; need_cast_scale = true;
scale_grad_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel()); scale_grad_data_fp32 =
scale_grad == nullptr
? nullptr
: RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
} else { } else {
// no need to cast // no need to cast
scale_data_fp32 = scale_ptr->data<float>(); scale_data_fp32 = scale_ptr->data<float>();
scale_grad_data_fp32 = ctx.template Alloc<float>(scale_grad); scale_grad_data_fp32 =
scale_grad == nullptr ? nullptr : ctx.template Alloc<float>(scale_grad);
} }
// bias // bias
...@@ -79,10 +83,14 @@ void LayerNormGradKernel(const Context& ctx, ...@@ -79,10 +83,14 @@ void LayerNormGradKernel(const Context& ctx,
} else if (bias_ptr->dtype() == } else if (bias_ptr->dtype() ==
phi::CppTypeToDataType<phi::dtype::float16>::Type()) { phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
need_cast_bias = true; need_cast_bias = true;
bias_grad_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel()); bias_grad_data_fp32 =
bias_grad == nullptr
? nullptr
: RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
} else { } else {
// no need to cast // no need to cast
bias_grad_data_fp32 = ctx.template Alloc<float>(bias_grad); bias_grad_data_fp32 =
bias_grad == nullptr ? nullptr : ctx.template Alloc<float>(bias_grad);
} }
auto* x_grad_data = auto* x_grad_data =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册