From 1981eaf922f3636a9f49209757d52c527d2dbe96 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 18:37:29 -0700 Subject: [PATCH] Fix Tensor::data interface --- paddle/framework/eigen.h | 21 ++++++++------------- paddle/framework/eigen_test.cc | 22 ++++++++++++++-------- paddle/framework/tensor.h | 8 ++++---- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 28641a389f0..cd87b042df8 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -28,7 +28,7 @@ struct EigenDim { static Type From(const DDim& dims) { PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); Type ret; - for (int d = 0; d < rank; d++) { + for (int d = 0; d < arity(dims); d++) { ret[d] = dims[d]; } return ret; @@ -43,8 +43,7 @@ struct EigenTensor { using ConstType = Eigen::TensorMap, - Eigen::Aligned> - ConstTensor; + Eigen::Aligned>; static Type From(Tensor& tensor, DDim dims) { return Type(tensor.data(), EigenDim::From(dims)); @@ -64,11 +63,10 @@ struct EigenTensor { // Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector. template struct EigenVector { - using EigenVector = - Eigen::TensorMap, - Eigen::Aligned>; + using Type = Eigen::TensorMap, + Eigen::Aligned>; - using EigenConstVector = + using ConstType = Eigen::TensorMap, Eigen::Aligned>; @@ -82,13 +80,10 @@ struct EigenVector { // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. template struct EigenMatrix { - template - using EigenMatrix = - Eigen::TensorMap, - Eigen::Aligned>; + using Type = Eigen::TensorMap, + Eigen::Aligned>; - template - using EigenConstMatrix = + using ConstType = Eigen::TensorMap, Eigen::Aligned>; diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index c5f27a32984..23eec7533f0 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -12,26 +12,32 @@ */ #include "paddle/framework/eigen.h" - #include -#include "paddle/framework/tensor.h" +namespace paddle { +namespace framework { -TEST(Eigen, Tensor) { - using paddle::platform::Tensor; - using paddle::platform::EigenTensor; - using paddle::platform::make_ddim; +TEST(EigenDim, From) { + EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3})); + EXPECT_EQ(1, ed[0]); + EXPECT_EQ(2, ed[1]); + EXPECT_EQ(3, ed[2]); +} +TEST(Eigen, Tensor) { Tensor t; - float* p = t.mutable_data(make_ddim({1, 2, 3}), CPUPlace()); + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); for (int i = 0; i < 1 * 2 * 3; i++) { p[i] = static_cast(i); } - EigenTensor::Type et = EigenTensor::From(t); + EigenTensor::Type et = EigenTensor::From(t); // TODO: check the content of et. } TEST(Eigen, Vector) {} TEST(Eigen, Matrix) {} + +} // namespace platform +} // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 405393fb113..8fbf42e7f65 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -37,13 +37,13 @@ class Tensor { template friend struct paddle::pybind::details::CastToPyBufferImpl; - template + template friend struct EigenTensor; - template + template friend struct EigenVector; - template + template friend struct EigenMatrix; public: @@ -57,7 +57,7 @@ class Tensor { } template - T* raw_data() const { + T* data() { CheckDims(); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); -- GitLab