提交 1981eaf9 编写于 作者: Y Yi Wang

Fix Tensor::data interface

上级 2538e207
...@@ -28,7 +28,7 @@ struct EigenDim { ...@@ -28,7 +28,7 @@ struct EigenDim {
static Type From(const DDim& dims) { static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
Type ret; Type ret;
for (int d = 0; d < rank; d++) { for (int d = 0; d < arity(dims); d++) {
ret[d] = dims[d]; ret[d] = dims[d];
} }
return ret; return ret;
...@@ -43,8 +43,7 @@ struct EigenTensor { ...@@ -43,8 +43,7 @@ struct EigenTensor {
using ConstType = using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, D, Eigen::RowMajor, IndexType>, Eigen::TensorMap<Eigen::Tensor<const T, D, Eigen::RowMajor, IndexType>,
Eigen::Aligned> Eigen::Aligned>;
ConstTensor;
static Type From(Tensor& tensor, DDim dims) { static Type From(Tensor& tensor, DDim dims) {
return Type(tensor.data<T>(), EigenDim<D>::From(dims)); return Type(tensor.data<T>(), EigenDim<D>::From(dims));
...@@ -64,11 +63,10 @@ struct EigenTensor { ...@@ -64,11 +63,10 @@ struct EigenTensor {
// Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector. // Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector.
template <typename T, typename IndexType = Eigen::DenseIndex> template <typename T, typename IndexType = Eigen::DenseIndex>
struct EigenVector { struct EigenVector {
using EigenVector = using Type = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>; Eigen::Aligned>;
using EigenConstVector = using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>; Eigen::Aligned>;
...@@ -82,13 +80,10 @@ struct EigenVector { ...@@ -82,13 +80,10 @@ struct EigenVector {
// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix.
template <typename T, typename IndexType = Eigen::DenseIndex> template <typename T, typename IndexType = Eigen::DenseIndex>
struct EigenMatrix { struct EigenMatrix {
template <typename T, typename IndexType = Eigen::DenseIndex> using Type = Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
using EigenMatrix =
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>; Eigen::Aligned>;
template <typename T, typename IndexType = Eigen::DenseIndex> using ConstType =
using EigenConstMatrix =
Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>; Eigen::Aligned>;
......
...@@ -12,26 +12,32 @@ ...@@ -12,26 +12,32 @@
*/ */
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/tensor.h" namespace paddle {
namespace framework {
TEST(Eigen, Tensor) { TEST(EigenDim, From) {
using paddle::platform::Tensor; EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3}));
using paddle::platform::EigenTensor; EXPECT_EQ(1, ed[0]);
using paddle::platform::make_ddim; EXPECT_EQ(2, ed[1]);
EXPECT_EQ(3, ed[2]);
}
TEST(Eigen, Tensor) {
Tensor t; Tensor t;
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace()); float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
for (int i = 0; i < 1 * 2 * 3; i++) { for (int i = 0; i < 1 * 2 * 3; i++) {
p[i] = static_cast<float>(i); p[i] = static_cast<float>(i);
} }
EigenTensor::Type et = EigenTensor::From(t); EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
// TODO: check the content of et. // TODO: check the content of et.
} }
TEST(Eigen, Vector) {} TEST(Eigen, Vector) {}
TEST(Eigen, Matrix) {} TEST(Eigen, Matrix) {}
} // namespace platform
} // namespace paddle
...@@ -37,13 +37,13 @@ class Tensor { ...@@ -37,13 +37,13 @@ class Tensor {
template <bool less, size_t i, typename... args> template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl; friend struct paddle::pybind::details::CastToPyBufferImpl;
template <typename T, size_t D, typename IndexType = Eigen::DenseIndex> template <typename T, size_t D, typename IndexType>
friend struct EigenTensor; friend struct EigenTensor;
template <typename T, typename IndexType = Eigen::DenseIndex> template <typename T, typename IndexType>
friend struct EigenVector; friend struct EigenVector;
template <typename T, typename IndexType = Eigen::DenseIndex> template <typename T, typename IndexType>
friend struct EigenMatrix; friend struct EigenMatrix;
public: public:
...@@ -57,7 +57,7 @@ class Tensor { ...@@ -57,7 +57,7 @@ class Tensor {
} }
template <typename T> template <typename T>
T* raw_data() const { T* data() {
CheckDims<T>(); CheckDims<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册