Unverified commit bafea65c, authored by zhangkaihuo, committed by GitHub

Add a Sparse OP: sparse_csr_to_coo (#39266)

* dense_to_sparse_coo

* optimize unit testing; support rocm

* 1. delete fluid related header file
2. update the copyright

* fix hipMemcpy

* update dense_to_sparsecoo

* add namespace sparse

* sparse_csr_to_dense

* test to_sparse_coo: csr_to_coo

* fix writing error
Parent 7e29cea9
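Before the diff, here is a minimal standalone sketch (not part of this commit; all names are illustrative) of the CSR-to-COO expansion the new kernel performs for a 2-D matrix. The data mirrors the CPU unit test added below.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // 3x3 matrix with 4 non-zeros, same data as the sparse_csr_to_coo unit test.
  std::vector<int64_t> crows = {0, 1, 3, 4};  // compressed row pointers
  std::vector<int64_t> cols = {1, 0, 2, 0};   // CSR column indices
  std::vector<float> values = {1.0f, 2.0f, 3.0f, 3.2f};
  const int rows = 3;

  // Expand the compressed row pointers into explicit COO row indices:
  // row i owns the entries in [crows[i], crows[i + 1]).
  std::vector<int64_t> coo_rows;
  for (int i = 0; i < rows; ++i) {
    for (int64_t j = crows[i]; j < crows[i + 1]; ++j) {
      coo_rows.push_back(i);
    }
  }

  // COO columns and values are the CSR columns and values, unchanged.
  for (std::size_t k = 0; k < coo_rows.size(); ++k) {
    std::printf("(%lld, %lld) = %.1f\n",
                static_cast<long long>(coo_rows[k]),
                static_cast<long long>(cols[k]),
                values[k]);
  }
  return 0;
}

The CPU and GPU kernels in this commit implement this same expansion, extended with a leading batch dimension for 3-D inputs.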
......@@ -26,10 +26,6 @@ inline void check_shape(const DDim& dims) {
#define Check(non_zero_crows, non_zero_cols, non_zero_elements, dims) \
{ \
check_shape(dims); \
PADDLE_ENFORCE_EQ(dims.size(), \
2, \
paddle::platform::errors::InvalidArgument( \
"the SparseCsrTensor only support 2-D Tensor.")); \
PADDLE_ENFORCE_EQ( \
non_zero_cols.place(), \
non_zero_crows.place(), \
......@@ -50,7 +46,12 @@ SparseCsrTensor::SparseCsrTensor(const DenseTensor& non_zero_crows,
non_zero_cols_(non_zero_cols),
non_zero_elements_(non_zero_elements),
dims_(dims) {
Check(non_zero_crows_, non_zero_cols_, non_zero_elements_, dims_);
if (non_zero_crows.initialized()) {
Check(non_zero_crows_, non_zero_cols_, non_zero_elements_, dims_);
} else {
// create an empty tensor
check_shape(dims);
}
}
SparseCsrTensor::SparseCsrTensor(const SparseCsrTensor& other)
......
......@@ -102,6 +102,61 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
out->SetMember(indices, values, x_dims, true);
}
template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out) {
const DDim& x_dims = x.dims();
const int64_t non_zero_num = x.non_zero_cols().numel();
const auto& csr_crows = x.non_zero_crows();
const auto& csr_cols = x.non_zero_cols();
const auto& csr_values = x.non_zero_elements();
const int64_t* csr_crows_data = csr_crows.data<int64_t>();
const int64_t* csr_cols_data = csr_cols.data<int64_t>();
const T* csr_values_data = csr_values.data<T>();
int64_t sparse_dim = 2;
if (x_dims.size() == 3) {
sparse_dim = 3;
}
const auto place = dev_ctx.GetPlace();
DenseTensorMeta indices_meta(
DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
pten::DenseTensor indices =
pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
pten::DenseTensor values =
pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
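// Layout of `indices`: a [sparse_dim, nnz] tensor. For 3-D input the first
// row holds batch ids, the second row ids, the third column ids; for 2-D
// input only row and column ids are stored.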
int64_t* coo_indices = indices.mutable_data<int64_t>(place);
int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
int64_t* coo_rows_data =
x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
int64_t* coo_cols_data = coo_rows_data + non_zero_num;
T* coo_values_data = values.mutable_data<T>(place);
int batch = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
int index = 0;
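// Expand the compressed row pointers: within each batch, row i owns the
// entries in [crows[i], crows[i + 1]), so emit i (and the batch id) once
// per entry.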
for (int b = 0; b < batch; b++) {
for (int i = 0; i < rows; i++) {
for (int j = csr_crows_data[b * (rows + 1) + i];
j < csr_crows_data[b * (rows + 1) + i + 1];
j++) {
coo_rows_data[index] = i;
if (batch_ptr) {
batch_ptr[index] = b;
}
++index;
}
}
}
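// Column indices and values are identical in CSR and COO; copy them verbatim.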
memcpy(coo_cols_data, csr_cols_data, sizeof(int64_t) * non_zero_num);
memcpy(coo_values_data, csr_values_data, sizeof(T) * non_zero_num);
out->SetMember(indices, values, x_dims, true);
}
} // namespace sparse
} // namespace pten
......@@ -117,3 +172,16 @@ PT_REGISTER_KERNEL(dense_to_sparse_coo,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(sparse_csr_to_coo,
CPU,
ALL_LAYOUT,
pten::sparse::SparseCsrToCooKernel,
float,
double,
paddle::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
......@@ -214,6 +214,122 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
out->SetMember(indices, values, x_dims, true);
}
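// One thread per batch: the last entry of a batch's crows segment,
// crows[b * (rows + 1) + rows], is that batch's non-zero count.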
__global__ void GetBatchSizes(const int64_t* crows,
const int rows,
const int batchs,
int* batch_sizes) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < batchs) {
batch_sizes[tid] = crows[tid * (rows + 1) + rows];
}
}
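// Grid-stride over rows within one batch (blockIdx.y selects the batch);
// each CSR row expands into crows[i + 1] - crows[i] explicit COO row indices,
// written at the batch's offset into the flattened output arrays.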
__global__ void ConvertCsrCrowsToCooRows(const int64_t* crows_ptr,
const int* crows_offsets,
int64_t* rows_ptr,
int64_t* batch_ptr,
const int rows) {
const int b = blockIdx.y;
const int64_t offset = crows_offsets ? crows_offsets[b] : 0;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
for (int j = crows_ptr[b * (rows + 1) + i];
j < crows_ptr[b * (rows + 1) + i + 1];
j++) {
rows_ptr[offset + j] = i;
if (batch_ptr) {
batch_ptr[offset + j] = b;
}
}
}
}
template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out) {
const DDim& x_dims = x.dims();
const int64_t non_zero_num = x.non_zero_cols().numel();
const auto& csr_crows = x.non_zero_crows();
const auto& csr_cols = x.non_zero_cols();
const auto& csr_values = x.non_zero_elements();
const int64_t* csr_crows_data = csr_crows.data<int64_t>();
const int64_t* csr_cols_data = csr_cols.data<int64_t>();
const T* csr_values_data = csr_values.data<T>();
int64_t sparse_dim = 2;
if (x_dims.size() == 3) {
sparse_dim = 3;
}
int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
const auto place = dev_ctx.GetPlace();
DenseTensorMeta indices_meta(
DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
DenseTensorMeta offsets_meta(DataType::INT32, {batchs}, DataLayout::NCHW);
DenseTensor indices =
pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
DenseTensor values = pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
DenseTensor offsets =
pten::Empty<int, Context>(dev_ctx, std::move(offsets_meta));
int64_t* coo_indices = indices.mutable_data<int64_t>(place);
int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
int64_t* coo_rows_data =
x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
int64_t* coo_cols_data = coo_rows_data + non_zero_num;
int* offsets_ptr = batchs == 1 ? nullptr : offsets.mutable_data<int>(place);
T* coo_values_data = values.mutable_data<T>(place);
int grid_size = 1, block_size = 1;
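// For batched input, each batch's non-zeros occupy a contiguous slice of the
// COO arrays. GetBatchSizes reads the per-batch non-zero counts and the
// exclusive scan below turns them into per-batch start offsets.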
if (batchs > 1) {
GetGpuLaunchConfig1D(dev_ctx, batchs, &grid_size, &block_size);
GetBatchSizes<<<grid_size, block_size>>>(
csr_crows_data, rows, batchs, offsets_ptr);
#ifdef PADDLE_WITH_HIP
thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
offsets_ptr,
offsets_ptr + batchs,
offsets_ptr);
}
GetGpuLaunchConfig1D(dev_ctx, rows, &grid_size, &block_size);
dim3 grids(grid_size, batchs, 1);
ConvertCsrCrowsToCooRows<<<grids, block_size>>>(
csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(coo_cols_data,
csr_cols_data,
sizeof(int64_t) * non_zero_num,
hipMemcpyDeviceToDevice,
dev_ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(coo_values_data,
csr_values_data,
sizeof(T) * non_zero_num,
hipMemcpyDeviceToDevice,
dev_ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(coo_cols_data,
csr_cols_data,
sizeof(int64_t) * non_zero_num,
cudaMemcpyDeviceToDevice,
dev_ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(coo_values_data,
csr_values_data,
sizeof(T) * non_zero_num,
cudaMemcpyDeviceToDevice,
dev_ctx.stream()));
#endif
out->SetMember(indices, values, x_dims, true);
}
} // namespace sparse
} // namespace pten
......@@ -229,3 +345,16 @@ PT_REGISTER_KERNEL(dense_to_sparse_coo,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(sparse_csr_to_coo,
GPU,
ALL_LAYOUT,
pten::sparse::SparseCsrToCooKernel,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
......@@ -57,5 +57,20 @@ SparseCooTensor DenseToSparseCoo(const Context& dev_ctx,
return coo;
}
template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out);
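// Convenience wrapper used by the tests: the empty indices/values tensors are
// only placeholders that the kernel replaces via SetMember.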
template <typename T, typename Context>
SparseCooTensor SparseCsrToCoo(const Context& dev_ctx,
const SparseCsrTensor& x) {
DenseTensor indices = pten::Empty<T, Context>(dev_ctx);
DenseTensor values = pten::Empty<T, Context>(dev_ctx);
SparseCooTensor coo(indices, values, x.dims());
SparseCsrToCooKernel<T, Context>(dev_ctx, x, &coo);
return coo;
}
} // namespace sparse
} // namespace pten
......@@ -62,4 +62,43 @@ TEST(API, to_sparse_coo) {
non_zero_data.data(),
non_zero_data.size() * sizeof(float));
ASSERT_EQ(cmp_elements, 0);
// 1. test sparse_csr_to_coo
auto dense_dims = pten::framework::make_ddim({3, 3});
pten::DenseTensorMeta crows_meta(
pten::DataType::INT64, {dense_dims[0] + 1}, pten::DataLayout::NCHW);
pten::DenseTensorMeta cols_meta(
pten::DataType::INT64, {non_zero_num}, pten::DataLayout::NCHW);
pten::DenseTensorMeta values_meta(
pten::DataType::FLOAT32, {non_zero_num}, pten::DataLayout::NCHW);
pten::CPUPlace place;
pten::DenseTensor crows(alloc.get(), crows_meta);
pten::DenseTensor cols(alloc.get(), cols_meta);
pten::DenseTensor values(alloc.get(), values_meta);
memcpy(crows.mutable_data<int64_t>(place),
crows_data.data(),
crows_data.size() * sizeof(int64_t));
memcpy(cols.mutable_data<int64_t>(place),
cols_data.data(),
cols_data.size() * sizeof(int64_t));
memcpy(values.mutable_data<float>(place),
non_zero_data.data(),
non_zero_data.size() * sizeof(float));
auto csr =
std::make_shared<pten::SparseCsrTensor>(crows, cols, values, dense_dims);
paddle::experimental::Tensor csr_x(csr);
auto out2 = paddle::experimental::sparse::to_sparse_coo(
csr_x, pten::Backend::CPU, sparse_dim);
auto coo2 = std::dynamic_pointer_cast<pten::SparseCooTensor>(out2.impl());
ASSERT_EQ(coo2->nnz(), non_zero_num);
int cmp_indices2 = memcmp(coo2->non_zero_indices().data<int64_t>(),
indices_data.data(),
indices_data.size() * sizeof(int64_t));
ASSERT_EQ(cmp_indices2, 0);
int cmp_elements2 = memcmp(coo2->non_zero_elements().data<float>(),
non_zero_data.data(),
non_zero_data.size() * sizeof(float));
ASSERT_EQ(cmp_elements2, 0);
}
......@@ -246,5 +246,112 @@ TEST(DEV_API, to_sparse_coo_batch) {
dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
}
template <typename T>
void TestSparseCsrToCoo(const DDim& dense_dims,
const std::vector<T>& non_zero_data,
const std::vector<int64_t>& crows_data,
const std::vector<int64_t>& cols_data,
const std::vector<int64_t>& indices_data,
const int64_t non_zero_num) {
int batchs = 1;
int rows = dense_dims[0];
if (dense_dims.size() == 3) {
batchs = dense_dims[0];
rows = dense_dims[1];
}
pten::DenseTensorMeta crows_meta(
DataType::INT64, {batchs * (rows + 1)}, DataLayout::NCHW);
pten::DenseTensorMeta cols_meta(
DataType::INT64, {non_zero_num}, DataLayout::NCHW);
pten::DenseTensorMeta values_meta(
paddle::experimental::CppTypeToDataType<T>::Type(),
{non_zero_num},
DataLayout::NCHW);
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::CPUPlace place;
pten::DenseTensor crows(alloc.get(), crows_meta);
pten::DenseTensor cols(alloc.get(), cols_meta);
pten::DenseTensor values(alloc.get(), values_meta);
memcpy(crows.mutable_data<int64_t>(place),
crows_data.data(),
crows_data.size() * sizeof(int64_t));
memcpy(cols.mutable_data<int64_t>(place),
cols_data.data(),
cols_data.size() * sizeof(int64_t));
memcpy(values.mutable_data<T>(place),
non_zero_data.data(),
non_zero_data.size() * sizeof(T));
pten::SparseCsrTensor csr(crows, cols, values, dense_dims);
// 1. test cpu
pten::CPUContext dev_ctx_cpu;
auto cpu_sparse_out = sparse::SparseCsrToCoo<T>(dev_ctx_cpu, csr);
CheckResult<T, int64_t>(&dev_ctx_cpu,
cpu_sparse_out,
non_zero_data,
indices_data,
non_zero_num,
alloc);
// 2. test cuda
#if defined(PADDLE_WITH_CUDA)
const auto cuda_alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
auto& pool = paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace());
pten::DenseTensor d_crows(cuda_alloc.get(), crows_meta);
pten::DenseTensor d_cols(cuda_alloc.get(), cols_meta);
pten::DenseTensor d_values(cuda_alloc.get(), values_meta);
pten::Copy(*dev_ctx_cuda, crows, true, &d_crows);
pten::Copy(*dev_ctx_cuda, cols, true, &d_cols);
pten::Copy(*dev_ctx_cuda, values, true, &d_values);
pten::SparseCsrTensor d_csr(d_crows, d_cols, d_values, dense_dims);
auto cuda_sparse_out = sparse::SparseCsrToCoo<T>(*dev_ctx_cuda, d_csr);
CheckResult<T, int64_t>(dev_ctx_cuda,
cuda_sparse_out,
non_zero_data,
indices_data,
non_zero_num,
alloc);
#endif
}
TEST(DEV_API, sparse_csr_to_coo) {
DDim dense_dims = framework::make_ddim({3, 3});
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> indices_data = {0, 1, 1, 2, 1, 0, 2, 0};
std::vector<int64_t> cols_data = {1, 0, 2, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4};
const int64_t non_zero_num = 4;
TestSparseCsrToCoo(dense_dims,
non_zero_data,
crows_data,
cols_data,
indices_data,
non_zero_num);
}
TEST(DEV_API, sparse_csr_to_coo_batch_and_fp16) {
DDim dense_dims = framework::make_ddim({2, 3, 3});
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2, 1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> cols_data = {1, 0, 2, 0, 1, 0, 2, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4, 0, 1, 3, 4};
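// indices_data is laid out as [sparse_dim, nnz]: 8 batch ids, then 8 row ids,
// then 8 column ids for the 2x3x3 batched input.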
std::vector<int64_t> indices_data = {0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 2,
0, 1, 1, 2, 1, 0, 2, 0, 1, 0, 2, 0};
const int64_t non_zero_num = 8;
using float16 = pten::dtype::float16;
std::vector<float16> non_zero_data_fp16(non_zero_num);
for (int64_t i = 0; i < non_zero_num; i++) {
non_zero_data_fp16[i] = static_cast<float16>(non_zero_data[i]);
}
TestSparseCsrToCoo(dense_dims,
non_zero_data_fp16,
crows_data,
cols_data,
indices_data,
non_zero_num);
}
} // namespace tests
} // namespace pten