未验证 提交 b9bfcf14 编写于 作者: T tiancaishaonvjituizi 提交者: GitHub

fix sparse csr (#42271)

上级 d1e01232
...@@ -154,7 +154,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) { ...@@ -154,7 +154,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
std::make_shared<phi::SparseCsrTensor>(phi::DenseTensor(), std::make_shared<phi::SparseCsrTensor>(phi::DenseTensor(),
phi::DenseTensor(), phi::DenseTensor(),
phi::DenseTensor(), phi::DenseTensor(),
phi::DDim{-1}); phi::DDim{-1, -1});
out->set_impl(sparse_tensor); out->set_impl(sparse_tensor);
return sparse_tensor.get(); return sparse_tensor.get();
} else { } else {
......
...@@ -27,9 +27,11 @@ SparseCsrTensor::SparseCsrTensor() { ...@@ -27,9 +27,11 @@ SparseCsrTensor::SparseCsrTensor() {
inline void check_shape(const DDim& dims) { inline void check_shape(const DDim& dims) {
bool valid = dims.size() == 2 || dims.size() == 3; bool valid = dims.size() == 2 || dims.size() == 3;
PADDLE_ENFORCE(valid, PADDLE_ENFORCE(
phi::errors::InvalidArgument( valid,
"the SparseCsrTensor only support 2-D Tensor.")); phi::errors::InvalidArgument("the SparseCsrTensor only support 2-D or "
"3-D Tensor, but get %d-D Tensor",
dims.size()));
} }
#define Check(non_zero_crows, non_zero_cols, non_zero_elements, dims) \ #define Check(non_zero_crows, non_zero_cols, non_zero_elements, dims) \
{ \ { \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册