diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index cd87b042df8557b763e3b40873412c1c038b06ee..f5865635bea4d28e32108bd389699b74b8fd4b17 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -23,7 +23,7 @@ namespace framework { // EigenDim converts paddle::platform::DDim into Eigen::DSizes. template struct EigenDim { - typedef Eigen::DSizes Type; + using Type = Eigen::DSizes; static Type From(const DDim& dims) { PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); @@ -69,12 +69,23 @@ struct EigenVector { using ConstType = Eigen::TensorMap, Eigen::Aligned>; - + // From is to transfer a one dimension Tensor into a one dimension EigenVector static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + // Flatten is to reshape a Tensor into a one dimension EigenVector + static Type Flatten(Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } + static ConstType From(const Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType Flatten(const Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } }; // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index 23eec7533f0ed5055648e13f74a8fe55745f2f30..eca2dce60e73407d15ff046b32416a0c06011ee1 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -32,6 +32,10 @@ TEST(Eigen, Tensor) { } EigenTensor::Type et = EigenTensor::From(t); + + for (int i = 0; i < 1 * 2 * 3; i++) { + EXPECT_EQ(et(i), i); + } // TODO: check the content of et. } @@ -39,5 +43,5 @@ TEST(Eigen, Vector) {} TEST(Eigen, Matrix) {} -} // namespace platform +} // namespace framework } // namespace paddle diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e7c106a23ffe34baec179f58efa67943a6164a23..39d54a63bd16cdafeec1cfcd86ef5d142382e880 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -30,9 +30,10 @@ public: output->mutable_data(context.GetPlace()); - framework::EigenVector::From(*output).device( + framework::EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - framework::EigenVector(*input0) + framework::EigenVector(*input1); + framework::EigenVector::Flatten(input0) + + framework::EigenVector::Flatten(input1); } };