From dfa63126b2d276ed463c3a726876b1c5dc265bf2 Mon Sep 17 00:00:00 2001
From: zhiboniu <31800336+zhiboniu@users.noreply.github.com>
Date: Thu, 7 Apr 2022 10:58:16 +0800
Subject: [PATCH] fix p_norm gpu nan bug while divide zero (#41359)

---
 paddle/phi/kernels/gpu/p_norm_grad_kernel.cu | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
index 9b0e43d25a7..fdfed25b3dd 100644
--- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
@@ -42,8 +42,9 @@ struct AbsMaxAndMinGradFunctor {
 
 template <typename T>
 struct PNormGradFunctor {
-  HOSTDEVICE explicit inline PNormGradFunctor(float porder) {
+  HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
     this->porder = static_cast<T>(porder - 1.);
+    this->eps = static_cast<T>(eps);
   }
   template <typename Context,
             typename X,
@@ -58,11 +59,12 @@ struct PNormGradFunctor {
                   DY* dy,
                   const Dim& dim,
                   int size) {
-    dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() *
-                        dy->broadcast(dim) *
-                        (*y).pow(-this->porder).broadcast(dim);
+    dx->device(place) =
+        (*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
+        (*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
   }
   T porder;
+  T eps;
 };
 
 template <typename T, typename Context>
@@ -96,7 +98,7 @@ void PNormGradKernel(const Context& dev_ctx,
     funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   } else {
-    auto functor = PNormGradFunctor<T>(porder);
+    auto functor = PNormGradFunctor<T>(porder, epsilon);
     funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
  }
-- 
GitLab
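
Note on the fix (commentary, not part of the patch): for finite nonzero porder the backward pass computes dx = |x|^(p-1) * sign(x) * dy * y^-(p-1), where y is the forward p-norm of the reduced slice. When an entire slice is zero, y == 0, so y^-(p-1) evaluates to inf, and multiplying it by the zero factor |x|^(p-1) * sign(x) produces NaN. Adding epsilon to y before taking the negative power keeps the result finite. The standalone C++ sketch below reproduces the scalar math; pnorm_grad is a hypothetical helper written for illustration and is not a Paddle API.

    // Minimal sketch of the scalar math behind the patch (illustrative only,
    // not Paddle code). pnorm_grad is a hypothetical helper.
    #include <cmath>
    #include <cstdio>

    // Per-element p-norm gradient: |x|^(p-1) * sign(x) * dy * (y + eps)^-(p-1),
    // where y is the forward p-norm of the reduced slice.
    double pnorm_grad(double x, double y, double dy, double p, double eps) {
      double porder = p - 1.0;
      double sign = (x > 0.0) - (x < 0.0);
      return std::pow(std::fabs(x), porder) * sign * dy *
             std::pow(y + eps, -porder);
    }

    int main() {
      // All-zero slice: forward norm y == 0, so y^-(p-1) is inf and the
      // product 0 * inf is NaN. This is the bug the patch fixes.
      std::printf("eps = 0    : %f\n", pnorm_grad(0.0, 0.0, 1.0, 2.0, 0.0));
      // With a small eps the inf disappears and the gradient is a clean 0.
      std::printf("eps = 1e-12: %f\n", pnorm_grad(0.0, 0.0, 1.0, 2.0, 1e-12));
      return 0;
    }

The same guard appears in the patch as (*y + y->constant(eps)): it perturbs only the norm being inverted, so gradients for well-behaved inputs change by at most a factor of (1 + eps/y) while the degenerate all-zero case returns 0 instead of NaN.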