提交 1981eaf9 编写于 作者: Y Yi Wang

Fix Tensor::data interface

上级 2538e207
......@@ -28,7 +28,7 @@ struct EigenDim {
static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
Type ret;
for (int d = 0; d < rank; d++) {
for (int d = 0; d < arity(dims); d++) {
ret[d] = dims[d];
}
return ret;
......@@ -43,8 +43,7 @@ struct EigenTensor {
using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, D, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
ConstTensor;
Eigen::Aligned>;
static Type From(Tensor& tensor, DDim dims) {
return Type(tensor.data<T>(), EigenDim<D>::From(dims));
......@@ -64,11 +63,10 @@ struct EigenTensor {
// Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector.
template <typename T, typename IndexType = Eigen::DenseIndex>
struct EigenVector {
using EigenVector =
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
using Type = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>;
using EigenConstVector =
using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>;
......@@ -82,13 +80,10 @@ struct EigenVector {
// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix.
template <typename T, typename IndexType = Eigen::DenseIndex>
struct EigenMatrix {
template <typename T, typename IndexType = Eigen::DenseIndex>
using EigenMatrix =
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
using Type = Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>;
template <typename T, typename IndexType = Eigen::DenseIndex>
using EigenConstMatrix =
using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>;
......
......@@ -12,26 +12,32 @@
*/
#include "paddle/framework/eigen.h"
#include <gtest/gtest.h>
#include "paddle/framework/tensor.h"
namespace paddle {
namespace framework {
TEST(Eigen, Tensor) {
using paddle::platform::Tensor;
using paddle::platform::EigenTensor;
using paddle::platform::make_ddim;
TEST(EigenDim, From) {
EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3}));
EXPECT_EQ(1, ed[0]);
EXPECT_EQ(2, ed[1]);
EXPECT_EQ(3, ed[2]);
}
TEST(Eigen, Tensor) {
Tensor t;
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
for (int i = 0; i < 1 * 2 * 3; i++) {
p[i] = static_cast<float>(i);
}
EigenTensor::Type et = EigenTensor::From(t);
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
// TODO: check the content of et.
}
TEST(Eigen, Vector) {}
TEST(Eigen, Matrix) {}
} // namespace platform
} // namespace paddle
......@@ -37,13 +37,13 @@ class Tensor {
template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl;
template <typename T, size_t D, typename IndexType = Eigen::DenseIndex>
template <typename T, size_t D, typename IndexType>
friend struct EigenTensor;
template <typename T, typename IndexType = Eigen::DenseIndex>
template <typename T, typename IndexType>
friend struct EigenVector;
template <typename T, typename IndexType = Eigen::DenseIndex>
template <typename T, typename IndexType>
friend struct EigenMatrix;
public:
......@@ -57,7 +57,7 @@ class Tensor {
}
template <typename T>
T* raw_data() const {
T* data() {
CheckDims<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册