From d9fa6159b7b9109e76c8841388c7811eeac2eb6b Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 19 Jul 2017 14:06:58 +0800 Subject: [PATCH] add Flatten method to EigenVector --- paddle/framework/eigen.h | 15 +++++++++++++-- paddle/framework/eigen_test.cc | 6 +++++- paddle/operators/add_op.h | 5 +++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index cd87b042df8..f5865635bea 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -23,7 +23,7 @@ namespace framework { // EigenDim converts paddle::platform::DDim into Eigen::DSizes. template struct EigenDim { - typedef Eigen::DSizes Type; + using Type = Eigen::DSizes; static Type From(const DDim& dims) { PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); @@ -69,12 +69,23 @@ struct EigenVector { using ConstType = Eigen::TensorMap, Eigen::Aligned>; - + // From is to transfer a one dimension Tensor into a one dimension EigenVector static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + // Flatten is to reshape a Tensor into a one dimension EigenVector + static Type Flatten(Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } + static ConstType From(const Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType Flatten(const Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } }; // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index 23eec7533f0..eca2dce60e7 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -32,6 +32,10 @@ TEST(Eigen, Tensor) { } EigenTensor::Type et = EigenTensor::From(t); + + for (int i = 0; i < 1 * 2 * 3; i++) { + EXPECT_EQ(et(i), i); + } // TODO: check the content of et. } @@ -39,5 +43,5 @@ TEST(Eigen, Vector) {} TEST(Eigen, Matrix) {} -} // namespace platform +} // namespace framework } // namespace paddle diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e7c106a23ff..39d54a63bd1 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -30,9 +30,10 @@ public: output->mutable_data(context.GetPlace()); - framework::EigenVector::From(*output).device( + framework::EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - framework::EigenVector(*input0) + framework::EigenVector(*input1); + framework::EigenVector::Flatten(input0) + + framework::EigenVector::Flatten(input1); } }; -- GitLab