diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 54bbeafcabdeeb1e2c1017c156b3512c83dada3a..c3c29dcdbcf17b11e54ac063ba36758cfc4314db 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -14,27 +14,43 @@ limitations under the License. */ #pragma once +#include "paddle/framework/ddim.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/variant.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace framework { -// EigenDim converts paddle::platform::DDim into Eigen::DSizes. -template -struct EigenDim { - using Type = Eigen::DSizes; +template +using EigenDim = Eigen::DSizes; - static Type From(const DDim& dims) { - PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); - Type ret; - for (int64_t d = 0; d < arity(dims); d++) { +using EigenDDim = boost::variant, EigenDim<2>, EigenDim<3>, + EigenDim<4>, EigenDim<5>, EigenDim<6>, + EigenDim<7>, EigenDim<8>, EigenDim<9>>; + +struct EigenDDimConvertVisitor : public boost::static_visitor { + template + EigenDDim operator()(const DimType& dims) const { + constexpr int arity = DimType::dimensions; + Eigen::DSizes ret; + for (int64_t d = 0; d < arity; ++d) { ret[d] = dims[d]; } return ret; } }; +inline EigenDDim DDimToEigenDDim(const DDim& dims) { + return boost::apply_visitor(EigenDDimConvertVisitor(), dims); +} + +template +inline auto VisitEigenDDim(Visitor visitor, const EigenDDim& ddim) -> + typename Visitor::result_type { + return boost::apply_visitor(visitor, ddim); +} + // Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. template @@ -47,13 +63,15 @@ struct EigenTensor { Eigen::TensorMap>; static Type From(Tensor& tensor, DDim dims) { - return Type(tensor.data(), EigenDim::From(dims)); + return Type(tensor.data(), + boost::get>(DDimToEigenDDim(dims))); } static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); } static ConstType From(const Tensor& tensor, DDim dims) { - return ConstType(tensor.data(), EigenDim::From(dims)); + return ConstType(tensor.data(), + boost::get>(DDimToEigenDDim(dims))); } static ConstType From(const Tensor& tensor) { diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index bc4a2db32cfba66bef2c444e1f822e0d2a57b91e..d4190e2fe9cecf84131abd0b27c3c6429c0e3610 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -17,11 +17,36 @@ namespace paddle { namespace framework { -TEST(EigenDim, From) { - EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3})); - ASSERT_EQ(1, ed[0]); - ASSERT_EQ(2, ed[1]); - ASSERT_EQ(3, ed[2]); +TEST(EigenDim, FromDDimToEigenDDim) { + auto eigen_ddim = DDimToEigenDDim({1, 2, 3}); + auto& eigen_dim = boost::get>(eigen_ddim); + ASSERT_EQ(1, eigen_dim[0]); + ASSERT_EQ(2, eigen_dim[1]); + ASSERT_EQ(3, eigen_dim[2]); +} + +struct ProductVisit : public boost::static_visitor { + template + int64_t operator()(const EigenDim& dim) const { + int64_t prod = 1; + for (auto& item : dim) { + prod *= item; + } + return prod; + } +}; + +TEST(EigenDim, Visit) { + std::vector tmp(5); + int64_t expect = 1; + for (int i = 0; i < 5; ++i) { + tmp[i] = i + 1; + expect *= tmp[i]; + } + auto eigen_ddim = DDimToEigenDDim(make_ddim(tmp)); + + int64_t actual = VisitEigenDDim(ProductVisit(), eigen_ddim); + ASSERT_EQ(expect, actual); } TEST(Eigen, Tensor) {