提交 47690d6a 编写于 作者: L liaogang

ENH: Add EigenScalar

上级 0b680772
...@@ -80,5 +80,21 @@ struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> { ...@@ -80,5 +80,21 @@ struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
} }
}; };
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
struct EigenScalar {
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
using Type = Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, MajorType, IndexType>>;
using ConstType = Eigen::TensorMap<
Eigen::TensorFixedSize<const T, Eigen::Sizes<>, MajorType, IndexType>>;
static Type From(Tensor& tensor) { return Type(tensor.data<T>()); }
static ConstType From(const Tensor& tensor) {
return ConstType(tensor.data<T>());
}
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -46,6 +46,17 @@ TEST(Eigen, Tensor) { ...@@ -46,6 +46,17 @@ TEST(Eigen, Tensor) {
} }
} }
TEST(Eigen, ScalarFrom) {
Tensor t;
int* p = t.mutable_data<int>(make_ddim({1}), platform::CPUPlace());
*p = static_cast<int>(100);
EigenScalar<int>::Type es = EigenScalar<int>::From(t);
ASSERT_EQ(0, es.dimension(0));
ASSERT_EQ(100, es(0));
}
TEST(Eigen, VectorFrom) { TEST(Eigen, VectorFrom) {
Tensor t; Tensor t;
float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace()); float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
......
...@@ -23,6 +23,10 @@ namespace operators { ...@@ -23,6 +23,10 @@ namespace operators {
using OpKernel = framework::OpKernel; using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext; using KernelContext = framework::KernelContext;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, template <typename T,
int MajorType = Eigen::RowMajor, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册