未验证 提交 5a6a2d27 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Removed friend class EigenTensor/EigenMatrix/EigenVector from Tensor (#38607)

上级 02c17c0b
......@@ -57,7 +57,7 @@ struct EigenTensor {
}
static Type From(Tensor& tensor) { // NOLINT
return From(tensor, tensor.dims_);
return From(tensor, tensor.dims());
} // NOLINT
static ConstType From(const Tensor& tensor, DDim dims) {
......@@ -65,7 +65,7 @@ struct EigenTensor {
}
static ConstType From(const Tensor& tensor) {
return From(tensor, tensor.dims_);
return From(tensor, tensor.dims());
}
};
......@@ -74,7 +74,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT
int num_col_dims) {
int rank = tensor.dims_.size();
int rank = tensor.dims().size();
PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be "
......@@ -86,7 +86,7 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
int num_col_dims) {
int rank = tensor.dims_.size();
int rank = tensor.dims().size();
PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be "
......@@ -102,12 +102,12 @@ template <typename T, int MajorType = Eigen::RowMajor,
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) { // NOLINT
return EigenVector::From(tensor, {product(tensor.dims_)});
return EigenVector::From(tensor, {product(tensor.dims())});
}
static typename EigenVector::ConstType Flatten(
const Tensor& tensor) { // NOLINT
return EigenVector::From(tensor, {product(tensor.dims_)});
return EigenVector::From(tensor, {product(tensor.dims())});
}
};
......
......@@ -111,16 +111,6 @@ class Tensor {
dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
#endif
public:
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
template <typename T, int MajorType, typename IndexType>
friend struct EigenMatrix;
template <typename T, int MajorType, typename IndexType>
friend struct EigenVector;
public:
Tensor()
: type_(proto::VarType::FP32),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册