提交 1a968b4f 编写于 作者: Q qijun

init

上级 1038bc46
......@@ -6,6 +6,7 @@
#include <vector>
#include "paddle/framework/dim.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
......@@ -91,6 +92,15 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(DDim dims) const {
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < paddle::framework::arity(dims); d++) {
dsizes[d] = dims[d];
}
return dsizes;
}
} // namespace framework
} // namespace paddle
......
......@@ -18,8 +18,10 @@ limitations under the License. */
#include <type_traits>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/framework/tensor_types.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
......@@ -33,6 +35,13 @@ class Tensor {
return static_cast<const T*>(holder_->Ptr());
}
template <typename T>
T* data() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data.");
return static_cast<T*>(holder_->Ptr());
}
template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims, paddle::platform::Place place) {
......@@ -41,14 +50,23 @@ class Tensor {
place) /* some versions of boost::variant don't have operator!= */
|| holder_->Size() < product(dims) * sizeof(T)) {
holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
dims_ = dims;
}
return static_cast<T*>(holder_->Ptr());
}
template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims) {
return mutable_data<T>(dims, paddle::platform::get_place());
DDim dim() const { return dims_; }
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstantTensor Tensor::tensor() {
return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
private:
......@@ -92,6 +110,7 @@ class Tensor {
};
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_;
};
} // namespace framework
......
#pragma once
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
// Helper to define Tensor types given that the scalar is of type T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
// Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Tensor;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstTensor;
// Unaligned Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>>
UnalignedTensor;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>>
UnalignedConstTensor;
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
Eigen::Aligned>
Tensor32Bit;
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
typedef Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Scalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType>,
Eigen::Aligned>
ConstScalar;
// Unaligned Scalar tensor of scalar type T.
typedef Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>>
UnalignedScalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType>>
UnalignedConstScalar;
// Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Flat;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Vec;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstVec;
// Unaligned Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>>
UnalignedFlat;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>>
UnalignedConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>>
UnalignedVec;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>>
UnalignedConstVec;
// Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Matrix;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstMatrix;
// Unaligned Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>>
UnalignedMatrix;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
UnalignedConstMatrix;
};
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册