未验证 提交 1896c777 编写于 作者: Z Zhang Ting 提交者: GitHub

Fix NaN gradient when the two inputs are equal (#32448)

上级 727b28d7
......@@ -167,6 +167,7 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
auto sign =
(x_minux_y > static_cast<T>(0)).template cast<T>() * static_cast<T>(1.0) +
(x_minux_y < static_cast<T>(0)).template cast<T>() * static_cast<T>(-1.0);
T epsilon = static_cast<T>(1.0e-10f);
// 1: Lp-norm(z), z = x-y, compute dz
if (p == 0) {
......@@ -189,12 +190,14 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
// dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
if (platform::is_cpu_place(context.GetPlace())) {
grad_t.device(place) =
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) *
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
.pow(p - 1) *
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
} else {
grad_t.device(place) =
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
out_grad_t.broadcast(out_bcast_dims);
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
.pow(p - 1) *
sign * out_grad_t.broadcast(out_bcast_dims);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册