From b36205e296bb0f2ec67bdd5df48b9c33d1903167 Mon Sep 17 00:00:00 2001 From: liaogang Date: Wed, 2 Aug 2017 19:26:07 +0800 Subject: [PATCH] Refine compute code in operators --- paddle/framework/eigen.h | 1 + paddle/operators/add_op.h | 11 +++++++---- paddle/operators/mean_op.h | 7 +++++-- paddle/operators/mul_op.h | 13 +++++++++---- paddle/operators/sgd_op.h | 8 ++++++-- paddle/operators/sigmoid_op.h | 8 +++++--- 6 files changed, 33 insertions(+), 15 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index a4667cc51f..85006bb16e 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -28,6 +28,7 @@ struct EigenDim { static Type From(const DDim& dims) { PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); Type ret; +#pragma unroll for (int d = 0; d < arity(dims); d++) { ret[d] = dims[d]; } diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index d2b649fcbd..2e17334860 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,10 +28,13 @@ public: output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = - framework::EigenVector::Flatten(*input0) + - framework::EigenVector::Flatten(*input1); + auto X = EigenVector::Flatten(*input0); + auto Y = EigenVector::Flatten(*input1); + auto Z = EigenVector::Flatten(*output); + + auto place = *context.GetEigenDevice(); + + Z.device(place) = X + Y; } }; diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 5f7d443751..658686c108 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -27,8 +27,11 @@ public: output->mutable_data(context.GetPlace()); - EigenScalar::From(*output).device(*(context.GetEigenDevice())) = - EigenVector::Flatten(*input).mean(); + auto X = EigenVector::Flatten(*input); + auto y = EigenScalar::From(*output); + auto place = *context.GetEigenDevice(); + + y.device(place) = X.mean(); } }; diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index eef72ab293..60fa6bdc4a 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -26,13 +26,18 @@ public: Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; + auto input0 = context.Input("X"); + auto input1 = context.Input("Y"); auto output = context.Output(0); + output->mutable_data(context.GetPlace()); - EigenMatrix::From(*output).device(*(context.GetEigenDevice())) = - EigenMatrix::From(*context.Input("X")) - .contract(EigenMatrix::From(*context.Input("Y")), - dim_pair); + auto X = EigenMatrix::From(*input0); + auto Y = EigenMatrix::From(*input1); + auto Z = EigenMatrix::From(*output); + auto place = *context.GetEigenDevice(); + + Z.device(place) = X.contract(Y, dim_pair); } }; } // namespace operators diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index af1dfdd756..43681ab82a 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -29,8 +29,12 @@ public: param_out->mutable_data(ctx.GetPlace()); - EigenVector::Flatten(*param_out).device(*(ctx.GetEigenDevice())) = - EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); + auto p = EigenVector::Flatten(*param); + auto g = EigenVector::Flatten(*grad); + auto o = EigenVector::Flatten(*param_out); + auto place = *ctx.GetEigenDevice(); + + o.device(place) = p - lr * g; } }; diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 3dd23a9ebc..16272da789 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -27,9 +27,11 @@ public: auto output = context.Output(0); output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); + auto X = EigenVector::Flatten(*input); + auto Y = EigenVector::Flatten(*output); + auto place = *context.GetEigenDevice(); + + Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); } }; } // namespace operators -- GitLab