From 5a6a2d27d73969d3a4cae4fa510a2ce547bcc572 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 31 Dec 2021 11:28:32 +0800 Subject: [PATCH] Removed friend class EigenTensor/EigenMatrix/EigenVector from Tensor (#38607) --- paddle/fluid/framework/eigen.h | 12 ++++++------ paddle/fluid/framework/tensor.h | 10 ---------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/eigen.h b/paddle/fluid/framework/eigen.h index a6abda8a83b..2970c52db1a 100644 --- a/paddle/fluid/framework/eigen.h +++ b/paddle/fluid/framework/eigen.h @@ -57,7 +57,7 @@ struct EigenTensor { } static Type From(Tensor& tensor) { // NOLINT - return From(tensor, tensor.dims_); + return From(tensor, tensor.dims()); } // NOLINT static ConstType From(const Tensor& tensor, DDim dims) { @@ -65,7 +65,7 @@ struct EigenTensor { } static ConstType From(const Tensor& tensor) { - return From(tensor, tensor.dims_); + return From(tensor, tensor.dims()); } }; @@ -74,7 +74,7 @@ template { static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT int num_col_dims) { - int rank = tensor.dims_.size(); + int rank = tensor.dims().size(); PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, platform::errors::InvalidArgument( "Input dimension number(num_col_dims) must be " @@ -86,7 +86,7 @@ struct EigenMatrix : public EigenTensor { static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, int num_col_dims) { - int rank = tensor.dims_.size(); + int rank = tensor.dims().size(); PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, platform::errors::InvalidArgument( "Input dimension number(num_col_dims) must be " @@ -102,12 +102,12 @@ template { // Flatten reshapes a Tensor into an EigenVector. static typename EigenVector::Type Flatten(Tensor& tensor) { // NOLINT - return EigenVector::From(tensor, {product(tensor.dims_)}); + return EigenVector::From(tensor, {product(tensor.dims())}); } static typename EigenVector::ConstType Flatten( const Tensor& tensor) { // NOLINT - return EigenVector::From(tensor, {product(tensor.dims_)}); + return EigenVector::From(tensor, {product(tensor.dims())}); } }; diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 4f54ce33c14..b7cc57d5e04 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -111,16 +111,6 @@ class Tensor { dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef; #endif - public: - template - friend struct EigenTensor; - - template - friend struct EigenMatrix; - - template - friend struct EigenVector; - public: Tensor() : type_(proto::VarType::FP32), -- GitLab