Unverified commit 261f97fb, authored by zhiboniu, committed by GitHub

fix p_norm GPU NaN bug when dividing by zero (#41804)

Parent 9294ba25
@@ -42,8 +42,9 @@ struct AbsMaxAndMinGradFunctor {
 
 template <typename T>
 struct PNormGradFunctor {
-  HOSTDEVICE explicit inline PNormGradFunctor(float porder) {
+  HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
     this->porder = static_cast<T>(porder - 1.);
+    this->eps = static_cast<T>(eps);
   }
   template <typename Context,
             typename X,
@@ -58,11 +59,12 @@ struct PNormGradFunctor {
                   DY* dy,
                   const Dim& dim,
                   int size) {
-    dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() *
-                        dy->broadcast(dim) *
-                        (*y).pow(-this->porder).broadcast(dim);
+    dx->device(place) =
+        (*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
+        (*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
   }
   T porder;
+  T eps;
 };
 
 template <typename T, typename Context>
@@ -96,7 +98,7 @@ void PNormGradKernel(const Context& dev_ctx,
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   } else {
-    auto functor = PNormGradFunctor<T>(porder);
+    auto functor = PNormGradFunctor<T>(porder, epsilon);
     funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
         dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
   }
...
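
Why the NaN appears: the backward formula implemented above is dx = |x|^(p-1) * sign(x) * dy * y^-(p-1), so the forward-pass norm y is raised to a negative power. For an all-zero input, y is 0, y^-(p-1) evaluates to inf, and the product 0 * inf yields NaN on the GPU. The patch adds the kernel's epsilon to y before taking the negative power. The following standalone scalar sketch (illustrative only, not the Paddle kernel; the function name and the values are made up) reproduces the effect:

// Scalar stand-in for the Eigen tensor expression in the diff above.
#include <cmath>
#include <cstdio>

// dx = |x|^(p-1) * sign(x) * dy * (y + eps)^-(p-1)
double p_norm_grad(double x, double y, double dy, double p, double eps) {
  double porder = p - 1.0;
  double sign = (x > 0.0) - (x < 0.0);
  // With eps == 0 and y == 0, pow(y, -porder) is inf, and the product
  // 0 * inf becomes NaN -- the bug this commit fixes.
  return std::pow(std::fabs(x), porder) * sign * dy *
         std::pow(y + eps, -porder);
}

int main() {
  double p = 2.0, dy = 1.0;
  // All-zero input, so the forward-pass norm y is 0.
  std::printf("eps = 0:     %f\n", p_norm_grad(0.0, 0.0, dy, p, 0.0));    // nan
  std::printf("eps = 1e-12: %f\n", p_norm_grad(0.0, 0.0, dy, p, 1e-12));  // 0.000000
  return 0;
}

With eps == 0 the first call prints nan; with a small positive eps the power stays finite, the zero factor from |x|^(p-1) dominates, and the gradient at the zero input comes out as 0.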