diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index a4ee407cae1ec4974ae5addd8626686d908eaf32..54d2231425293f6cfb3adc9cb34d903a75fcdcd0 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,9 +28,13 @@ public: output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device(context.GetEigenDevice()) = - framework::EigenVector::Flatten(*input0) + - framework::EigenVector::Flatten(*input1); + auto X = EigenVector::Flatten(*input0); + auto Y = EigenVector::Flatten(*input1); + auto Z = EigenVector::Flatten(*output); + + auto place = context.GetEigenDevice(); + + Z.device(place) = X + Y; } }; diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 20f21105298a20b2083c35e3a5d8b8049b0752ca..5c339bffbf8e39f36ee9b4f857ab380cbac82879 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -27,8 +27,11 @@ public: output->mutable_data(context.GetPlace()); - EigenScalar::From(*output).device(context.GetEigenDevice()) = - EigenVector::Flatten(*input).mean(); + auto X = EigenVector::Flatten(*input); + auto y = EigenScalar::From(*output); + auto place = context.GetEigenDevice(); + + y.device(place) = X.mean(); } }; diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 1d0617ab8bd44297ee112c58542cf6487118e801..c7b78ad39045d25d73bfc2c930063c255a514864 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -26,13 +26,18 @@ public: Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; + auto input0 = context.Input("X"); + auto input1 = context.Input("Y"); auto output = context.Output(0); + output->mutable_data(context.GetPlace()); - EigenMatrix::From(*output).device(context.GetEigenDevice()) = - EigenMatrix::From(*context.Input("X")) - .contract(EigenMatrix::From(*context.Input("Y")), - dim_pair); + auto X = EigenMatrix::From(*input0); + auto Y = EigenMatrix::From(*input1); + auto Z = EigenMatrix::From(*output); + auto place = context.GetEigenDevice(); + + Z.device(place) = X.contract(Y, dim_pair); } }; } // namespace operators diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index d8ddbac573ffae8bd5d4b53041ed4dd35bb5ae0c..0c3a240f9a4a5fc7bc4898e82786810cee2f7010 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -29,8 +29,12 @@ public: param_out->mutable_data(ctx.GetPlace()); - EigenVector::Flatten(*param_out).device(ctx.GetEigenDevice()) = - EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); + auto p = EigenVector::Flatten(*param); + auto g = EigenVector::Flatten(*grad); + auto o = EigenVector::Flatten(*param_out); + auto place = ctx.GetEigenDevice(); + + o.device(place) = p - lr * g; } }; diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index f518ddcf3b2381f0b8ceec3eee260f033a847263..1412e4398440c8e946d3ab434a50e978079637ab 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -27,8 +27,11 @@ public: auto output = context.Output(0); output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device(context.GetEigenDevice()) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); + auto X = EigenVector::Flatten(*input); + auto Y = EigenVector::Flatten(*output); + auto place = context.GetEigenDevice(); + + Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); } }; } // namespace operators