未验证 提交 5a6a2d27 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Removed friend class EigenTensor/EigenMatrix/EigenVector from Tensor (#38607)

上级 02c17c0b
...@@ -57,7 +57,7 @@ struct EigenTensor { ...@@ -57,7 +57,7 @@ struct EigenTensor {
} }
static Type From(Tensor& tensor) { // NOLINT static Type From(Tensor& tensor) { // NOLINT
return From(tensor, tensor.dims_); return From(tensor, tensor.dims());
} // NOLINT } // NOLINT
static ConstType From(const Tensor& tensor, DDim dims) { static ConstType From(const Tensor& tensor, DDim dims) {
...@@ -65,7 +65,7 @@ struct EigenTensor { ...@@ -65,7 +65,7 @@ struct EigenTensor {
} }
static ConstType From(const Tensor& tensor) { static ConstType From(const Tensor& tensor) {
return From(tensor, tensor.dims_); return From(tensor, tensor.dims());
} }
}; };
...@@ -74,7 +74,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -74,7 +74,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT
int num_col_dims) { int num_col_dims) {
int rank = tensor.dims_.size(); int rank = tensor.dims().size();
PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be " "Input dimension number(num_col_dims) must be "
...@@ -86,7 +86,7 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { ...@@ -86,7 +86,7 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
int num_col_dims) { int num_col_dims) {
int rank = tensor.dims_.size(); int rank = tensor.dims().size();
PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be " "Input dimension number(num_col_dims) must be "
...@@ -102,12 +102,12 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -102,12 +102,12 @@ template <typename T, int MajorType = Eigen::RowMajor,
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> { struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
// Flatten reshapes a Tensor into an EigenVector. // Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) { // NOLINT static typename EigenVector::Type Flatten(Tensor& tensor) { // NOLINT
return EigenVector::From(tensor, {product(tensor.dims_)}); return EigenVector::From(tensor, {product(tensor.dims())});
} }
static typename EigenVector::ConstType Flatten( static typename EigenVector::ConstType Flatten(
const Tensor& tensor) { // NOLINT const Tensor& tensor) { // NOLINT
return EigenVector::From(tensor, {product(tensor.dims_)}); return EigenVector::From(tensor, {product(tensor.dims())});
} }
}; };
......
...@@ -111,16 +111,6 @@ class Tensor { ...@@ -111,16 +111,6 @@ class Tensor {
dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef; dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
#endif #endif
public:
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
template <typename T, int MajorType, typename IndexType>
friend struct EigenMatrix;
template <typename T, int MajorType, typename IndexType>
friend struct EigenVector;
public: public:
Tensor() Tensor()
: type_(proto::VarType::FP32), : type_(proto::VarType::FP32),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册