From 642d3c4687eb91c3a7fd026e3d8ae15957c8836d Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 15:05:33 -0700 Subject: [PATCH] Refactorize Tensor to Eigen convesion --- paddle/framework/ddim.h | 11 ---- paddle/framework/eigen.h | 103 ++++++++++++++++++++++++++++++++ paddle/framework/tensor.h | 60 ------------------- paddle/framework/tensor_types.h | 67 --------------------- 4 files changed, 103 insertions(+), 138 deletions(-) create mode 100644 paddle/framework/eigen.h delete mode 100644 paddle/framework/tensor_types.h diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 070850375d1..06c4c583b3a 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -119,17 +119,6 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); -template -Eigen::DSizes ToEigenDSizes(const DDim& dims) { - int rank = arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); - Eigen::DSizes dsizes; - for (int d = 0; d < rank; d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h new file mode 100644 index 00000000000..edbbc2694a3 --- /dev/null +++ b/paddle/framework/eigen.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/platform/tensor.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +// EigenDim converts paddle::platform::DDim into Eigen::DSizes. +template +struct EigenDim { + typedef Eigen::DSizes Type; + + static Type From(const DDim& dims) { + PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); + Type ret; + for (int d = 0; d < rank; d++) { + ret[d] = dims[d]; + } + return ret; + } +}; + +// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. +template +struct EigenTensor { + using Type = Eigen::TensorMap, + Eigen::Aligned>; + + using ConstType = + Eigen::TensorMap, + Eigen::Aligned> + ConstTensor; + + static Type From(Tensor& tensor, DDim dims) { + return Type(tensor.data(), EigenDim::From(dims)); + } + + static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); } + + static ConstType From(const Tensor& tensor, DDim dims) { + return ConstType(tensor.data(), EigenDim::From(dims)); + } + + static ConstType From(const Tensor& tensor) { + return From(tensor, tensor.dims_); + } +}; + +// Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector. +template +struct EigenVector { + using EigenVector = + Eigen::TensorMap, + Eigen::Aligned>; + + using EigenConstVector = + Eigen::TensorMap, + Eigen::Aligned>; + + static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType From(const Tensor& tensor) { + return EigenTensor::From(tensor); + } +}; + +// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. +template +struct EigenMatrix { + template + using EigenMatrix = + Eigen::TensorMap, + Eigen::Aligned>; + + template + using EigenConstMatrix = + Eigen::TensorMap, + Eigen::Aligned>; + + static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType From(const Tensor& tensor) { + return EigenTensor::From(tensor); + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 4f07350e59d..1235b532273 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -86,66 +86,6 @@ class Tensor { offset_); } - template - typename TTypes::Tensor shaped(DDim new_dims) { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(raw_data(), dims); - } - - template - typename TTypes::Tensor tensor() { - return typename TTypes::Tensor( - raw_data(), paddle::framework::ToEigenDSizes(dims_)); - } - - // flat to rank = 1 - template - typename TTypes::Flat flat() { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - // to TensorType Vec - template - typename TTypes::Vec vec() { - return tensor(); - } - - // to TensorType Matrix - template - typename TTypes::Matrix matrix() { - return tensor(); - } - - // const versions of all the methods above. - template - typename TTypes::Tensor shaped(DDim new_dims) const { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(data(), dims); - } - - template - typename TTypes::ConstantTensor tensor() const { - return typename TTypes::Tensor( - data(), paddle::framework::ToEigenDSizes(dims_)); - } - - template - typename TTypes::ConstFlat flat() const { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - template - typename TTypes::ConstVec vec() const { - return tensor(); - } - - template - typename TTypes::ConstMatrix matrix() const { - return tensor(); - } - template void ShareDataFrom(const Tensor& src) { src.CheckDims(); diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h deleted file mode 100644 index 4bf27a377e8..00000000000 --- a/paddle/framework/tensor_types.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "unsupported/Eigen/CXX11/Tensor" - -namespace paddle { -namespace framework { - -// Helper to define Tensor types given that the scalar is of type T. -template -struct TTypes { - // Rank- tensor of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Tensor; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstTensor; - - // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. - typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, - Eigen::Aligned> - Scalar; - typedef Eigen::TensorMap, - Eigen::RowMajor, IndexType>, - Eigen::Aligned> - ConstScalar; - - // Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Flat; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstFlat; - typedef Eigen::TensorMap, - Eigen::Aligned> - Vec; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstVec; - - // Rank-2 tensor (matrix) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Matrix; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstMatrix; -}; - -} // namespace framework -} // namespace paddle -- GitLab