diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
index 9b0e43d25a7ceb08a0974c15769de19473aa670e..fdfed25b3dda8fd03dcc1ab6a0fdccfc3dbb25a1 100644
--- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
@@ -42,8 +42,9 @@ struct AbsMaxAndMinGradFunctor {
 
 template <typename T>
 struct PNormGradFunctor {
-  HOSTDEVICE explicit inline PNormGradFunctor(float porder) {
+  HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
     this->porder = static_cast<T>(porder - 1.);
+    this->eps = static_cast<T>(eps);
   }
   template <typename Context,
             typename X,
@@ -58,11 +59,12 @@ struct PNormGradFunctor {
                   DY* dy,
                   const Dim& dim,
                   int size) {
-    dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() *
-                        dy->broadcast(dim) *
-                        (*y).pow(-this->porder).broadcast(dim);
+    dx->device(place) =
+        (*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
+        (*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
   }
   T porder;
+  T eps;
 };
 
 template <typename T, typename Context>
@@ -96,7 +98,7 @@ void PNormGradKernel(const Context& dev_ctx,
     funcs::LaunchReduceGradKernel<Context, T, AbsMaxAndMinGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   } else {
-    auto functor = PNormGradFunctor<T>(porder);
+    auto functor = PNormGradFunctor<T>(porder, epsilon);
     funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   }
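
For reference, the element-wise formula the patched functor evaluates is dx_i = |x_i|^(p-1) * sign(x_i) * dy * (||x||_p + eps)^-(p-1); adding eps to the norm before raising it to the negative power keeps the gradient finite when the norm itself is zero. The following is a minimal standalone C++ sketch of that arithmetic, not Paddle code: it uses plain scalars and std::pow instead of the Eigen tensor expressions in the functor, and the function and variable names are illustrative only.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch of the guarded p-norm gradient: dx_i = |x_i|^(p-1) * sign(x_i) * dy * (y + eps)^-(p-1),
// where y = ||x||_p has already been computed in the forward pass.
std::vector<float> PNormGradSketch(const std::vector<float>& x,
                                   float y, float dy, float porder, float eps) {
  std::vector<float> dx(x.size());
  const float p1 = porder - 1.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float sign = (x[i] > 0.0f) - (x[i] < 0.0f);
    dx[i] = std::pow(std::fabs(x[i]), p1) * sign * dy * std::pow(y + eps, -p1);
  }
  return dx;
}

int main() {
  // All-zero input with p = 2: without eps the (y)^-(p-1) factor is pow(0, -1) = inf,
  // and 0 * inf produces NaN gradients; with eps the result stays at 0.
  std::vector<float> x = {0.0f, 0.0f, 0.0f};
  std::vector<float> dx = PNormGradSketch(x, /*y=*/0.0f, /*dy=*/1.0f,
                                          /*porder=*/2.0f, /*eps=*/1e-12f);
  for (float v : dx) std::printf("%g\n", v);  // prints 0 for each element, not nan
  return 0;
}

With x = {0, 0, 0} and p = 2, the guarded version yields a zero gradient, whereas dropping eps gives pow(0, -1) = inf and hence 0 * inf = NaN, which is the failure mode the epsilon term in the diff guards against.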