From 57f9723d36f1740bc306a8e5022ac3cf01595c2f Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 6 Sep 2017 18:43:33 +0800 Subject: [PATCH] Using EigenVector to replace EigenMatrix for some variables. --- paddle/operators/squared_l2_distance_op.h | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h index 77c5a0a5c9..ad3347a0b3 100644 --- a/paddle/operators/squared_l2_distance_op.h +++ b/paddle/operators/squared_l2_distance_op.h @@ -20,6 +20,9 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; template using EigenMatrix = framework::EigenMatrix; @@ -46,7 +49,7 @@ class SquaredL2DistanceKernel : public framework::OpKernel { out0->mutable_data(context.GetPlace()); out1->mutable_data(context.GetPlace()); auto sub_result = EigenMatrix::From(*out0); - auto z = EigenMatrix::From(*out1); + auto z = EigenVector::Flatten(*out1); auto place = context.GetEigenDevice(); auto x_dims = x.dimensions(); @@ -55,13 +58,12 @@ class SquaredL2DistanceKernel : public framework::OpKernel { if (y_dims[0] == 1 && x_dims[0] > y_dims[0]) { sub_result.device(place) = x - - y.broadcast(Eigen::array({static_cast(x_dims[0]), 1})); + y.broadcast(Eigen::array({{static_cast(x_dims[0]), 1}})); } else { sub_result.device(place) = x - y; } auto sub_res_pow2 = sub_result * sub_result; - // z is TensorMap, no need reshape - z.device(place) = sub_res_pow2.sum(Eigen::array({1})); + z.device(place) = sub_res_pow2.sum(Eigen::array({{1}})); } }; @@ -82,8 +84,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel { int cols = framework::product(x_dims) / x_dims[0]; // calculate gradient - auto grad_mat = - 2 * (out_grad.broadcast(Eigen::array({1, cols}))) * sub_result; + auto grad_mat = 2 * + (out_grad.broadcast(Eigen::array({{1, cols}}))) * + sub_result; // propagate back to input auto eigen_place = context.GetEigenDevice(); @@ -98,18 +101,18 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel { if (y_g) { y_g->mutable_data(context.GetPlace()); - auto y_grad = - EigenMatrix::From(*y_g, framework::make_ddim({y_dims[0], cols})); PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0], "First dimension of gradient must be greater or " "equal than first dimension of target."); if (sub_result.dimensions()[0] == y_dims[0]) { + auto y_grad = + EigenMatrix::From(*y_g, framework::make_ddim({y_dims[0], cols})); y_grad.device(eigen_place) = -1 * grad_mat; } else { - auto col_sum_res = -1 * (grad_mat.sum(Eigen::array({0}))); - // y_grad is TensorMap, no need reshape + auto col_sum_res = -1 * (grad_mat.sum(Eigen::array({{0}}))); + auto y_grad = EigenVector::Flatten(*y_g); y_grad.device(eigen_place) = col_sum_res; } } -- GitLab