未验证 提交 b9bfcf14 编写于 作者: T tiancaishaonvjituizi 提交者: GitHub

fix sparse csr (#42271)

上级 d1e01232
......@@ -154,7 +154,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
std::make_shared<phi::SparseCsrTensor>(phi::DenseTensor(),
phi::DenseTensor(),
phi::DenseTensor(),
phi::DDim{-1});
phi::DDim{-1, -1});
out->set_impl(sparse_tensor);
return sparse_tensor.get();
} else {
......
......@@ -27,9 +27,11 @@ SparseCsrTensor::SparseCsrTensor() {
inline void check_shape(const DDim& dims) {
bool valid = dims.size() == 2 || dims.size() == 3;
PADDLE_ENFORCE(valid,
phi::errors::InvalidArgument(
"the SparseCsrTensor only support 2-D Tensor."));
PADDLE_ENFORCE(
valid,
phi::errors::InvalidArgument("the SparseCsrTensor only support 2-D or "
"3-D Tensor, but get %d-D Tensor",
dims.size()));
}
#define Check(non_zero_crows, non_zero_cols, non_zero_elements, dims) \
{ \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册