diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 5f3358c69b3fbbbfcd97a96ab50fde3d8b9efad0..a4667cc51fadfc020d3211b7a82356db386fced1 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -80,5 +80,21 @@ struct EigenVector : public EigenTensor { } }; +template +struct EigenScalar { + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + using Type = Eigen::TensorMap< + Eigen::TensorFixedSize, MajorType, IndexType>>; + using ConstType = Eigen::TensorMap< + Eigen::TensorFixedSize, MajorType, IndexType>>; + + static Type From(Tensor& tensor) { return Type(tensor.data()); } + + static ConstType From(const Tensor& tensor) { + return ConstType(tensor.data()); + } +}; + } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index a9fa728e49a0dcc781e520a22c1ee5f921c4c733..dc1957691b1a202826e10e84c21ac8874df9e378 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -46,6 +46,17 @@ TEST(Eigen, Tensor) { } } +TEST(Eigen, ScalarFrom) { + Tensor t; + int* p = t.mutable_data(make_ddim({1}), platform::CPUPlace()); + *p = static_cast(100); + + EigenScalar::Type es = EigenScalar::From(t); + + ASSERT_EQ(0, es.dimension(0)); + ASSERT_EQ(100, es(0)); +} + TEST(Eigen, VectorFrom) { Tensor t; float* p = t.mutable_data(make_ddim({6}), platform::CPUPlace()); diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 9d1f5fba2ad3ada4742ada30b41d68d15a69ca45..93b62cddc819e0d1fd48323e474a294ff0d327e1 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -25,6 +25,10 @@ using OpKernel = framework::OpKernel; using InferShapeContext = framework::InferShapeContext; using ExecutionContext = framework::ExecutionContext; using Variable = framework::Variable; +template +using EigenScalar = framework::EigenScalar; template