未验证 提交 f30ead13 编写于 作者: Y YuanRisheng 提交者: GitHub

[BUG FIX]Fix MetaTensor's bug when run infermeta (#46265)

* fix sum bug

* fix ci bugs

* fix ci bugs

* update code according comment
上级 b1771368
......@@ -39,7 +39,11 @@ int64_t MetaTensor::numel() const {
DDim MetaTensor::dims() const {
ValidCheck(*this);
return tensor_->dims();
if (phi::SelectedRows::classof(tensor_)) {
return static_cast<SelectedRows*>(tensor_)->GetCompleteDims();
} else {
return tensor_->dims();
}
}
DataType MetaTensor::dtype() const {
......@@ -61,9 +65,7 @@ void MetaTensor::set_dims(const DDim& dims) {
StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
->dims = dims;
} else if (phi::SelectedRows::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->dims = dims;
static_cast<SelectedRows*>(tensor_)->set_height(dims[0]);
} else if (phi::SparseCooTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
->dims = dims;
......@@ -179,7 +181,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) {
set_dims(meta_tensor.dims());
if (is_selected_rows) {
const auto in_tensor_base = meta_tensor.tensor();
PADDLE_ENFORCE_EQ(
......@@ -191,6 +192,11 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
selected_rows_out->set_rows(selected_rows_in->rows());
selected_rows_out->set_height(selected_rows_in->height());
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->dims = selected_rows_in->mutable_value()->dims();
} else {
set_dims(meta_tensor.dims());
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
......
......@@ -132,10 +132,7 @@ class SelectedRows : public TensorBase,
/// \brief Returns the dims of the tensor.
/// \return The dims of the tensor.
const DDim& dims() const noexcept override {
return impl_->dims();
// return phi::make_ddim(dims);
}
const DDim& dims() const noexcept override { return impl_->dims(); }
/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册