未验证 提交 03517d8a 编写于 作者: Z zhangkaihuo 提交者: GitHub

fix batch csr (#43553)

* fix to_sparse_csr
上级 5a5649c2
...@@ -206,7 +206,11 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, ...@@ -206,7 +206,11 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
if (batchs > 1) { if (batchs > 1) {
for (int i = 0; i < non_zero_num; i++) { for (int i = 0; i < non_zero_num; i++) {
if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) { if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
offsets[batchs_ptr[i]] = i + 1; const int start = batchs_ptr[i];
const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
for (int j = start; j < end; j++) {
offsets[j] = i + 1;
}
} }
} }
} else { } else {
...@@ -214,7 +218,6 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, ...@@ -214,7 +218,6 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
} }
for (int b = 0; b < batchs; b++) { for (int b = 0; b < batchs; b++) {
if (offsets[b] == 0) continue;
int batch_start = 0; int batch_start = 0;
int batch_non_zero_num = offsets[b]; int batch_non_zero_num = offsets[b];
if (b > 0) { if (b > 0) {
...@@ -233,6 +236,9 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, ...@@ -233,6 +236,9 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) { for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) {
csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num; csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
} }
if (batch_non_zero_num == 0) {
memset(csr_crows_data + b * (rows + 1), 0, sizeof(IntT) * (rows + 1));
}
} }
memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num); memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num);
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h" #include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
...@@ -283,19 +284,24 @@ void SparseCsrToCooKernel(const Context& dev_ctx, ...@@ -283,19 +284,24 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
template <typename IntT> template <typename IntT>
__global__ void GetBatchsOffset(const IntT* batchs_ptr, __global__ void GetBatchsOffset(const IntT* batchs_ptr,
const int batchs,
const int non_zero_num, const int non_zero_num,
IntT* batchs_offset) { int* batchs_offset) {
int tid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) { if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
batchs_offset[batchs_ptr[i]] = i + 1; const int start = batchs_ptr[i];
const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
for (int j = start; j < end; j++) {
batchs_offset[j] = i + 1;
}
} }
} }
} }
template <typename IntT> template <typename IntT>
__global__ void ConvertCooRowsToCsrCrows( __global__ void ConvertCooRowsToCsrCrows(
const IntT* batchs_offset, // can be null if batchs = 1 const int* batchs_offset, // can be null if batchs = 1
const IntT* coo_rows_data, const IntT* coo_rows_data,
IntT* csr_crows_data, IntT* csr_crows_data,
const int rows, const int rows,
...@@ -303,12 +309,12 @@ __global__ void ConvertCooRowsToCsrCrows( ...@@ -303,12 +309,12 @@ __global__ void ConvertCooRowsToCsrCrows(
const int b = blockIdx.y; const int b = blockIdx.y;
int batch_non_zero_num = int batch_non_zero_num =
batchs_offset == nullptr ? non_zero_num : batchs_offset[b]; batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
if (batch_non_zero_num == 0) return;
IntT batch_start = 0; IntT batch_start = 0;
if (b > 0) { if (b > 0) {
batch_start = batchs_offset[b - 1]; batch_start = batchs_offset[b - 1];
batch_non_zero_num -= batch_start; batch_non_zero_num -= batch_start;
} }
const IntT* coo_rows_ptr = coo_rows_data + batch_start; const IntT* coo_rows_ptr = coo_rows_data + batch_start;
const int tid = threadIdx.x + blockIdx.x * blockDim.x; const int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) { for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) {
...@@ -328,6 +334,11 @@ __global__ void ConvertCooRowsToCsrCrows( ...@@ -328,6 +334,11 @@ __global__ void ConvertCooRowsToCsrCrows(
} }
} }
} }
if (batch_non_zero_num == 0) {
for (int i = tid; i < rows + 1; i += gridDim.x * blockDim.x) {
csr_crows_data[b * (rows + 1) + i] = 0;
}
}
} }
template <typename T, typename IntT> template <typename T, typename IntT>
...@@ -365,13 +376,19 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx, ...@@ -365,13 +376,19 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
if (batchs > 1) { if (batchs > 1) {
phi::DenseTensor batchs_offset = phi::Empty<IntT>(dev_ctx, {batchs}); auto config =
IntT* batchs_offset_ptr = batchs_offset.data<IntT>(); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
GetBatchsOffset<IntT> phi::DenseTensor batchs_offset = phi::Empty<int>(dev_ctx, {batchs});
<<<config.block_per_grid.x, int* batchs_offset_ptr = batchs_offset.data<int>();
phi::funcs::SetConstant<GPUContext, int> set_zero;
// set zero if the nnz=0 of batchs[0]
set_zero(dev_ctx, &batchs_offset, static_cast<IntT>(0));
GetBatchsOffset<IntT><<<config.block_per_grid.x,
config.thread_per_block.x, config.thread_per_block.x,
0, 0,
dev_ctx.stream()>>>(batchs_ptr, non_zero_num, batchs_offset_ptr); dev_ctx.stream()>>>(
batchs_ptr, batchs, non_zero_num, batchs_offset_ptr);
config.block_per_grid.y = batchs; config.block_per_grid.y = batchs;
ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid, ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid,
config.thread_per_block.x, config.thread_per_block.x,
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
from paddle.incubate import sparse
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard from paddle.fluid.framework import _test_eager_guard
...@@ -315,6 +316,53 @@ class TestSparseConvert(unittest.TestCase): ...@@ -315,6 +316,53 @@ class TestSparseConvert(unittest.TestCase):
assert np.array_equal(values_sorted, assert np.array_equal(values_sorted,
sparse_x.values().numpy()) sparse_x.values().numpy())
def test_batch_csr(self):
with _test_eager_guard():
shape = [3, 3, 3]
def verify(x, crows, cols, values):
x = paddle.to_tensor(x)
csr = x.to_sparse_csr()
assert np.allclose(crows, csr.crows().numpy())
assert np.allclose(cols, csr.cols().numpy())
assert np.allclose(values, csr.values().numpy())
dense = csr.to_dense()
assert np.allclose(x.numpy(), dense.numpy())
x = [
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
]
crows = [[0, 1, 2, 3, 0, 0, 0, 0, 0, 1, 2, 3]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
x = [
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
]
crows = [[0, 0, 0, 0, 0, 1, 2, 3, 0, 1, 2, 3]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
x = [
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
]
crows = [[0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 0, 0]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
class TestCooError(unittest.TestCase): class TestCooError(unittest.TestCase):
......
...@@ -249,6 +249,7 @@ def sparse_csr_tensor(crows, ...@@ -249,6 +249,7 @@ def sparse_csr_tensor(crows,
raise ValueError( raise ValueError(
"SparseCsrTensor only support 2-D or 3-D matrix. but get shape {}". "SparseCsrTensor only support 2-D or 3-D matrix. but get shape {}".
format(shape)) format(shape))
rows = shape[len(shape) - 2]
if not crows.place._equals(place): if not crows.place._equals(place):
crows = crows._copy_to(place, False) crows = crows._copy_to(place, False)
...@@ -268,10 +269,10 @@ def sparse_csr_tensor(crows, ...@@ -268,10 +269,10 @@ def sparse_csr_tensor(crows,
raise ValueError("the length of cols must be same as length of values") raise ValueError("the length of cols must be same as length of values")
if len(shape) == 2: if len(shape) == 2:
if crows.shape[0] != shape[0] + 1: if crows.shape[0] != rows + 1:
raise ValueError( raise ValueError(
"The length({}) of crows must be equal to the rows({})+1 of matrix." "The length({}) of crows must be equal to the rows({})+1 of matrix."
.format(crows.shape[0], shape[0])) .format(crows.shape[0], rows))
if crows[0] != 0: if crows[0] != 0:
raise ValueError("the 0th value of crows must be 0") raise ValueError("the 0th value of crows must be 0")
...@@ -279,10 +280,10 @@ def sparse_csr_tensor(crows, ...@@ -279,10 +280,10 @@ def sparse_csr_tensor(crows,
raise ValueError( raise ValueError(
"the last value of crows must be equal the number of non-zero") "the last value of crows must be equal the number of non-zero")
else: else:
if crows.shape[0] % (shape[0] + 1) != 0: if crows.shape[0] % (rows + 1) != 0:
raise ValueError( raise ValueError(
"The length({}) of crows must be divisible the rows({})+1 of matrix." "The length({}) of crows must be divisible the rows({})+1 of matrix."
.format(crows.shape[0], shape[0])) .format(crows.shape[0], rows))
# TODO(zkh2016): check whether the value in crows and cols is legal # TODO(zkh2016): check whether the value in crows and cols is legal
return core.eager.sparse_csr_tensor(crows, cols, values, shape, return core.eager.sparse_csr_tensor(crows, cols, values, shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册