From 1896c777f1dc151d0ee8241ce57354f1a0db3e71 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Sun, 25 Apr 2021 15:04:36 +0800
Subject: [PATCH] fix gradient(nan) when two inputs are equal (#32448)

---
 paddle/fluid/operators/dist_op.h | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/dist_op.h b/paddle/fluid/operators/dist_op.h
index a2279e4062..6a34ef48a1 100644
--- a/paddle/fluid/operators/dist_op.h
+++ b/paddle/fluid/operators/dist_op.h
@@ -167,6 +167,7 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
   auto sign = (x_minux_y > static_cast<T>(0)).template cast<T>() *
                   static_cast<T>(1.0) +
               (x_minux_y < static_cast<T>(0)).template cast<T>() *
                   static_cast<T>(-1.0);
+  T epsilon = static_cast<T>(1.0e-10f);
   // 1: Lp-norm(z), z = x-y, compute dz
   if (p == 0) {
@@ -189,12 +190,14 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
     // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
     if (platform::is_cpu_place(context.GetPlace())) {
       grad_t.device(place) =
-          (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) *
+          (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
+              .pow(p - 1) *
           sign.eval() * out_grad_t.broadcast(out_bcast_dims);
     } else {
       grad_t.device(place) =
-          (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
-          out_grad_t.broadcast(out_bcast_dims);
+          (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
+              .pow(p - 1) *
+          sign * out_grad_t.broadcast(out_bcast_dims);
     }
   }
--
GitLab
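
Note on why the change works: when the two inputs are equal, both abs(x-y) and the norm out are 0, so the backward formula pow(abs(x-y)/out, p-1) evaluates 0/0 and the whole gradient becomes NaN. The following is a minimal standalone C++ sketch, using scalar stand-ins for the Eigen tensor expressions in the patch (it is not Paddle code), illustrating the failure mode and the epsilon guard:

    #include <cmath>
    #include <cstdio>

    int main() {
      // Scalar stand-in for dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
      // in the degenerate case x == y, where abs(x-y) == 0 and
      // out == Lp-norm(x-y) == 0.
      double diff = 0.0;   // abs(x - y)
      double out = 0.0;    // Lp-norm of x - y
      double sign = 0.0;   // sign(x - y) is 0 when x == y
      double p = 2.0, dout = 1.0;

      // Before the patch: 0/0 is NaN, and NaN * 0 is still NaN.
      double grad = std::pow(diff / out, p - 1) * sign * dout;
      std::printf("without epsilon: %f\n", grad);  // prints nan

      // After the patch: the denominator is nudged by epsilon, the quotient
      // becomes a clean 0, and the gradient collapses to 0 as expected.
      double epsilon = 1.0e-10;
      grad = std::pow(diff / (out + epsilon), p - 1) * sign * dout;
      std::printf("with epsilon:    %f\n", grad);  // prints 0.000000
      return 0;
    }

Because the numerator is exactly 0 in the degenerate case, shifting only the denominator recovers the expected gradient of 0, while for any nonzero out the 1e-10 offset is negligible relative to the true value.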