未验证 提交 5c1bafbb 编写于 作者: Z Zhang Ting 提交者: GitHub

use eval to improve performance, test=develop (#25459)

上级 5c4eed66
...@@ -176,15 +176,27 @@ static void DistGradFunction(const framework::ExecutionContext& context) { ...@@ -176,15 +176,27 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
} else if (p == INFINITY || p == -INFINITY) { } else if (p == INFINITY || p == -INFINITY) {
// p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if // p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if
// j!=i, or equals to sign(z_i) * dout if j=i. // j!=i, or equals to sign(z_i) * dout if j=i.
grad_t.device(place) = if (platform::is_cpu_place(context.GetPlace())) {
(x_minux_y_abs == out_t.broadcast(out_bcast_dims)).template cast<T>() * grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
.template cast<T>() *
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
} else {
grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
.template cast<T>() *
sign * out_grad_t.broadcast(out_bcast_dims); sign * out_grad_t.broadcast(out_bcast_dims);
}
} else { } else {
// dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
if (platform::is_cpu_place(context.GetPlace())) {
grad_t.device(place) =
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) *
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
} else {
grad_t.device(place) = grad_t.device(place) =
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign * (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
out_grad_t.broadcast(out_bcast_dims); out_grad_t.broadcast(out_bcast_dims);
} }
}
Eigen::DSizes<int, Rank * 2> x_reshape_dims; Eigen::DSizes<int, Rank * 2> x_reshape_dims;
Eigen::DSizes<int, Rank * 2> y_reshape_dims; Eigen::DSizes<int, Rank * 2> y_reshape_dims;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册