diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 70e42cc849bb0746c0e3c71e3864e00f51bca2f4..534efcdc79efb08d3b1f5a958612cc72d3a5f90b 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -39,7 +39,11 @@ int64_t MetaTensor::numel() const { DDim MetaTensor::dims() const { ValidCheck(*this); - return tensor_->dims(); + if (phi::SelectedRows::classof(tensor_)) { + return static_cast(tensor_)->GetCompleteDims(); + } else { + return tensor_->dims(); + } } DataType MetaTensor::dtype() const { @@ -61,9 +65,7 @@ void MetaTensor::set_dims(const DDim& dims) { StringTensorUtils::GetMutableMeta(static_cast(tensor_)) ->dims = dims; } else if (phi::SelectedRows::classof(tensor_)) { - DenseTensorUtils::GetMutableMeta( - static_cast(tensor_)->mutable_value()) - ->dims = dims; + static_cast(tensor_)->set_height(dims[0]); } else if (phi::SparseCooTensor::classof(tensor_)) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) ->dims = dims; @@ -179,7 +181,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) { bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_); bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_); if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) { - set_dims(meta_tensor.dims()); if (is_selected_rows) { const auto in_tensor_base = meta_tensor.tensor(); PADDLE_ENFORCE_EQ( @@ -191,6 +192,11 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) { auto* selected_rows_in = static_cast(in_tensor_base); selected_rows_out->set_rows(selected_rows_in->rows()); selected_rows_out->set_height(selected_rows_in->height()); + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_)->mutable_value()) + ->dims = selected_rows_in->mutable_value()->dims(); + } else { + set_dims(meta_tensor.dims()); } } else { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/phi/core/selected_rows.h b/paddle/phi/core/selected_rows.h index a71c0471cc431c8e988f00062352aaf8dfcaec3e..c011605809e4417e585f3ba3a4e5d577a08a1837 100644 --- a/paddle/phi/core/selected_rows.h +++ b/paddle/phi/core/selected_rows.h @@ -132,10 +132,7 @@ class SelectedRows : public TensorBase, /// \brief Returns the dims of the tensor. /// \return The dims of the tensor. - const DDim& dims() const noexcept override { - return impl_->dims(); - // return phi::make_ddim(dims); - } + const DDim& dims() const noexcept override { return impl_->dims(); } /// \brief Returns the data type of the tensor. /// \return The data type of the tensor.