未验证 提交 03517d8a 编写于 作者: Z zhangkaihuo 提交者: GitHub

fix batch csr (#43553)

* fix to_sparse_csr
上级 5a5649c2
......@@ -206,7 +206,11 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
if (batchs > 1) {
for (int i = 0; i < non_zero_num; i++) {
if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
offsets[batchs_ptr[i]] = i + 1;
const int start = batchs_ptr[i];
const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
for (int j = start; j < end; j++) {
offsets[j] = i + 1;
}
}
}
} else {
......@@ -214,7 +218,6 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
}
for (int b = 0; b < batchs; b++) {
if (offsets[b] == 0) continue;
int batch_start = 0;
int batch_non_zero_num = offsets[b];
if (b > 0) {
......@@ -233,6 +236,9 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) {
csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
}
if (batch_non_zero_num == 0) {
memset(csr_crows_data + b * (rows + 1), 0, sizeof(IntT) * (rows + 1));
}
}
memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num);
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
......@@ -283,19 +284,24 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
// Records, per batch index, the position one past the last non-zero entry
// that belongs to that batch: whenever entry i is the final entry or the next
// entry carries a different batch id, i + 1 is written into batchs_offset.
//
//   batchs_ptr     batch id of each non-zero entry (read at i and i + 1)
//   batchs         total number of batches; upper bound of the fill range
//                  used for the final entry
//   non_zero_num   number of entries readable through batchs_ptr
//   batchs_offset  output array indexed by batch id
//
// NOTE(review): presumably batchs_ptr is sorted in non-decreasing order, so
// that the [start, end) ranges filled by different threads do not overlap --
// confirm against the caller.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped, so superseded and replacement lines appear side by side. Confirm
// against the repository before treating it as compilable source.
template <typename IntT>
__global__ void GetBatchsOffset(const IntT* batchs_ptr,
const int batchs,
const int non_zero_num,
// NOTE(review): the next two lines both declare batchs_offset and both close
// the parameter list; they look like the old (IntT*) and new (int*) form of
// the same line.
IntT* batchs_offset) {
int* batchs_offset) {
// Global thread index; the loop below advances by the total number of
// launched threads (gridDim.x * blockDim.x) until non_zero_num is covered.
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
// Entry i closes a batch when it is the last entry overall or when the next
// entry has a different batch id. The i == non_zero_num - 1 test comes first
// so batchs_ptr[i + 1] is never read past the end.
if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
// NOTE(review): single-slot write; appears to be the superseded form of the
// range fill that follows.
batchs_offset[batchs_ptr[i]] = i + 1;
// Fill every batch id from the current one up to (but excluding) the next
// batch id that occurs, or up to batchs for the final entry. Batch ids in
// between that own no entries therefore receive the same offset i + 1.
const int start = batchs_ptr[i];
const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
for (int j = start; j < end; j++) {
batchs_offset[j] = i + 1;
}
}
}
}
template <typename IntT>
__global__ void ConvertCooRowsToCsrCrows(
const IntT* batchs_offset, // can be null if batchs = 1
const int* batchs_offset, // can be null if batchs = 1
const IntT* coo_rows_data,
IntT* csr_crows_data,
const int rows,
......@@ -303,12 +309,12 @@ __global__ void ConvertCooRowsToCsrCrows(
const int b = blockIdx.y;
int batch_non_zero_num =
batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
if (batch_non_zero_num == 0) return;
IntT batch_start = 0;
if (b > 0) {
batch_start = batchs_offset[b - 1];
batch_non_zero_num -= batch_start;
}
const IntT* coo_rows_ptr = coo_rows_data + batch_start;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) {
......@@ -328,6 +334,11 @@ __global__ void ConvertCooRowsToCsrCrows(
}
}
}
if (batch_non_zero_num == 0) {
for (int i = tid; i < rows + 1; i += gridDim.x * blockDim.x) {
csr_crows_data[b * (rows + 1) + i] = 0;
}
}
}
template <typename T, typename IntT>
......@@ -365,13 +376,19 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
if (batchs > 1) {
phi::DenseTensor batchs_offset = phi::Empty<IntT>(dev_ctx, {batchs});
IntT* batchs_offset_ptr = batchs_offset.data<IntT>();
GetBatchsOffset<IntT>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(batchs_ptr, non_zero_num, batchs_offset_ptr);
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
phi::DenseTensor batchs_offset = phi::Empty<int>(dev_ctx, {batchs});
int* batchs_offset_ptr = batchs_offset.data<int>();
phi::funcs::SetConstant<GPUContext, int> set_zero;
// set zero if the nnz=0 of batchs[0]
set_zero(dev_ctx, &batchs_offset, static_cast<IntT>(0));
GetBatchsOffset<IntT><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
batchs_ptr, batchs, non_zero_num, batchs_offset_ptr);
config.block_per_grid.y = batchs;
ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid,
config.thread_per_block.x,
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.incubate import sparse
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
......@@ -315,6 +316,53 @@ class TestSparseConvert(unittest.TestCase):
assert np.array_equal(values_sorted,
sparse_x.values().numpy())
def test_batch_csr(self):
    """Check dense -> CSR -> dense conversion for batched (3-D) input.

    Every case stacks three 3x3 matrices, one of which is all zero, and
    moves the all-zero matrix to a different batch position (middle,
    first, last). For each case the crows, cols and values of the CSR
    tensor are compared with hand-written expectations, and the CSR
    tensor converted back to dense must equal the input.
    """
    diag = [[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]]
    zero = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]

    # Expected per-batch row pointers: one non-zero per row for the
    # diagonal matrix, none at all for the zero matrix.
    diag_crows = [0, 1, 2, 3]
    zero_crows = [0, 0, 0, 0]

    # Each case holds exactly two diagonal batches, so the expected cols
    # and values are the same whichever batch is the empty one.
    cols = [0, 1, 2, 0, 1, 2]
    values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]

    cases = [
        ("empty middle batch", [diag, zero, diag],
         diag_crows + zero_crows + diag_crows),
        ("empty first batch", [zero, diag, diag],
         zero_crows + diag_crows + diag_crows),
        ("empty last batch", [diag, diag, zero],
         diag_crows + diag_crows + zero_crows),
    ]

    def verify(name, x, crows):
        dense_in = paddle.to_tensor(x)
        csr = dense_in.to_sparse_csr()
        assert np.allclose(crows, csr.crows().numpy()), \
            "crows mismatch ({})".format(name)
        assert np.allclose(cols, csr.cols().numpy()), \
            "cols mismatch ({})".format(name)
        assert np.allclose(values, csr.values().numpy()), \
            "values mismatch ({})".format(name)
        dense_out = csr.to_dense()
        assert np.allclose(dense_in.numpy(), dense_out.numpy()), \
            "dense round trip mismatch ({})".format(name)

    with _test_eager_guard():
        for name, x, crows in cases:
            verify(name, x, crows)
class TestCooError(unittest.TestCase):
......
......@@ -249,6 +249,7 @@ def sparse_csr_tensor(crows,
raise ValueError(
"SparseCsrTensor only support 2-D or 3-D matrix. but get shape {}".
format(shape))
rows = shape[len(shape) - 2]
if not crows.place._equals(place):
crows = crows._copy_to(place, False)
......@@ -268,10 +269,10 @@ def sparse_csr_tensor(crows,
raise ValueError("the length of cols must be same as length of values")
if len(shape) == 2:
if crows.shape[0] != shape[0] + 1:
if crows.shape[0] != rows + 1:
raise ValueError(
"The length({}) of crows must be equal to the rows({})+1 of matrix."
.format(crows.shape[0], shape[0]))
.format(crows.shape[0], rows))
if crows[0] != 0:
raise ValueError("the 0th value of crows must be 0")
......@@ -279,10 +280,10 @@ def sparse_csr_tensor(crows,
raise ValueError(
"the last value of crows must be equal the number of non-zero")
else:
if crows.shape[0] % (shape[0] + 1) != 0:
if crows.shape[0] % (rows + 1) != 0:
raise ValueError(
"The length({}) of crows must be divisible the rows({})+1 of matrix."
.format(crows.shape[0], shape[0]))
.format(crows.shape[0], rows))
# TODO(zkh2016): check whether the value in crows and cols is legal
return core.eager.sparse_csr_tensor(crows, cols, values, shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册