diff --git a/paddle/fluid/operators/dist_op.h b/paddle/fluid/operators/dist_op.h index ca03400cfd1ef9a27ba8e725381515d5e4ebc0ba..a2279e40623b4ba2f0421e73a6148b89eb970e71 100644 --- a/paddle/fluid/operators/dist_op.h +++ b/paddle/fluid/operators/dist_op.h @@ -176,14 +176,26 @@ static void DistGradFunction(const framework::ExecutionContext& context) { } else if (p == INFINITY || p == -INFINITY) { // p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if // j!=i, or equals to sign(z_i) * dout if j=i. - grad_t.device(place) = - (x_minux_y_abs == out_t.broadcast(out_bcast_dims)).template cast() * - sign * out_grad_t.broadcast(out_bcast_dims); + if (platform::is_cpu_place(context.GetPlace())) { + grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims)) + .template cast() * + sign.eval() * out_grad_t.broadcast(out_bcast_dims); + } else { + grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims)) + .template cast() * + sign * out_grad_t.broadcast(out_bcast_dims); + } } else { // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout - grad_t.device(place) = - (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign * - out_grad_t.broadcast(out_bcast_dims); + if (platform::is_cpu_place(context.GetPlace())) { + grad_t.device(place) = + (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * + sign.eval() * out_grad_t.broadcast(out_bcast_dims); + } else { + grad_t.device(place) = + (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign * + out_grad_t.broadcast(out_bcast_dims); + } } Eigen::DSizes x_reshape_dims;