From 47690d6a7bb410ef5f268d0bea8a7bbc3ed89474 Mon Sep 17 00:00:00 2001 From: liaogang Date: Tue, 1 Aug 2017 15:50:21 +0800 Subject: [PATCH] ENH: Add EigenScalar --- paddle/framework/eigen.h | 16 ++++++++++++++++ paddle/framework/eigen_test.cc | 11 +++++++++++ paddle/operators/type_alias.h | 4 ++++ 3 files changed, 31 insertions(+) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 5f3358c69..a4667cc51 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -80,5 +80,21 @@ struct EigenVector : public EigenTensor { } }; +template +struct EigenScalar { + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + using Type = Eigen::TensorMap< + Eigen::TensorFixedSize, MajorType, IndexType>>; + using ConstType = Eigen::TensorMap< + Eigen::TensorFixedSize, MajorType, IndexType>>; + + static Type From(Tensor& tensor) { return Type(tensor.data()); } + + static ConstType From(const Tensor& tensor) { + return ConstType(tensor.data()); + } +}; + } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index a9fa728e4..dc1957691 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -46,6 +46,17 @@ TEST(Eigen, Tensor) { } } +TEST(Eigen, ScalarFrom) { + Tensor t; + int* p = t.mutable_data(make_ddim({1}), platform::CPUPlace()); + *p = static_cast(100); + + EigenScalar::Type es = EigenScalar::From(t); + + ASSERT_EQ(0, es.dimension(0)); + ASSERT_EQ(100, es(0)); +} + TEST(Eigen, VectorFrom) { Tensor t; float* p = t.mutable_data(make_ddim({6}), platform::CPUPlace()); diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index b712e457f..275c0c1ac 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -23,6 +23,10 @@ namespace operators { using OpKernel = framework::OpKernel; using KernelContext = framework::KernelContext; +template +using EigenScalar = framework::EigenScalar; template -- GitLab