From aca864702de07a5ce339059a959ab4fdb100c597 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Wed, 9 Feb 2022 13:59:11 +0800
Subject: [PATCH] Add a Sparse Op to_dense (#39335)

* implement AllocateFrom
* dense_to_sparse_coo
* optimize unit testing; support rocm
* 1. delete fluid related header file
  2. update the copyright
* fix hipMemcpy
* update dense_to_sparse_coo
* add namespace sparse
* sparse_csr_to_dense
* test to_sparse_coo: csr_to_coo
* fix writing error
* to_sparse_csr: dense_to_sparse_csr and sparse_coo_to_csr
* fix check shape
* fix unit test
* to_dense: sparse_coo_to_dense, sparse_csr_to_dense
* replace CUDADeviceContext by GPUContext
---
 paddle/pten/api/include/sparse_api.h          |   2 +
 paddle/pten/api/lib/sparse_api.cc             |  58 ++++
 .../kernels/sparse/cpu/sparse_utils_kernel.cc |  69 +++++
 .../kernels/sparse/gpu/sparse_utils_kernel.cu | 116 ++++++++
 .../pten/kernels/sparse/sparse_utils_kernel.h |  32 +++
 .../pten/tests/api/test_sparse_utils_api.cc   |  69 +++++
 .../kernels/test_sparse_utils_dev_api.cc      | 271 ++++++++++++++++++
 7 files changed, 617 insertions(+)

diff --git a/paddle/pten/api/include/sparse_api.h b/paddle/pten/api/include/sparse_api.h
index 8ec36084ff..bc0a8152d1 100644
--- a/paddle/pten/api/include/sparse_api.h
+++ b/paddle/pten/api/include/sparse_api.h
@@ -27,6 +27,8 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x,
 PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend);
 
+PADDLE_API Tensor to_dense(const Tensor& x, Backend backend);
+
 }  // namespace sparse
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/pten/api/lib/sparse_api.cc b/paddle/pten/api/lib/sparse_api.cc
index e2bccd4723..d6df7de71e 100644
--- a/paddle/pten/api/lib/sparse_api.cc
+++ b/paddle/pten/api/lib/sparse_api.cc
@@ -26,12 +26,16 @@ PT_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT);
 PT_DECLARE_KERNEL(sparse_csr_to_coo, CPU, ALL_LAYOUT);
 PT_DECLARE_KERNEL(dense_to_sparse_csr, CPU, ALL_LAYOUT);
 PT_DECLARE_KERNEL(sparse_coo_to_csr, CPU, ALL_LAYOUT);
+PT_DECLARE_KERNEL(sparse_coo_to_dense, CPU, ALL_LAYOUT);
+PT_DECLARE_KERNEL(sparse_csr_to_dense, CPU, ALL_LAYOUT);
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PT_DECLARE_KERNEL(dense_to_sparse_coo, GPU, ALL_LAYOUT);
 PT_DECLARE_KERNEL(sparse_csr_to_coo, GPU, ALL_LAYOUT);
 PT_DECLARE_KERNEL(dense_to_sparse_csr, GPU, ALL_LAYOUT);
 PT_DECLARE_KERNEL(sparse_coo_to_csr, GPU, ALL_LAYOUT);
+PT_DECLARE_KERNEL(sparse_coo_to_dense, GPU, ALL_LAYOUT);
+PT_DECLARE_KERNEL(sparse_csr_to_dense, GPU, ALL_LAYOUT);
 #endif
 
 namespace paddle {
@@ -166,6 +170,60 @@ PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) {
   return out;
 }
 
+PADDLE_API Tensor to_dense(const Tensor& x, Backend backend) {
+  if (x.layout() != pten::DataLayout::SPARSE_CSR &&
+      x.layout() != pten::DataLayout::SPARSE_COO) {
+    return x;
+  }
+
+  // 1. Get kernel signature and kernel
+  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
+  kernel_key_set.backend_set =
+      kernel_key_set.backend_set | BackendSet(backend);
+  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
+  std::string kernel_name = "sparse_coo_to_dense";
+  if (x.layout() == pten::DataLayout::SPARSE_CSR) {
+    kernel_name = "sparse_csr_to_dense";
+  }
+
+  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, kernel_key);
+
+  VLOG(6) << "to API kernel key: " << kernel_key;
+  VLOG(6) << "to API kernel: " << kernel;
+
+  // 2. Get Device Context
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto kernel_context = pten::KernelContext(dev_ctx);
+
+  // 3. Auto data transform
+  if (x.layout() == pten::DataLayout::SPARSE_COO) {
+    auto input = std::dynamic_pointer_cast<pten::SparseCooTensor>(x.impl());
+    kernel_context.EmplaceBackInput(input.get());
+  } else {
+    auto input = std::dynamic_pointer_cast<pten::SparseCsrTensor>(x.impl());
+    kernel_context.EmplaceBackInput(input.get());
+  }
+
+  // 4. InferMeta
+  auto dense_meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
+
+  // 5. Prepare outputs: create an empty DenseTensor
+  auto dense_out = std::make_shared<pten::DenseTensor>(
+      pten::make_intrusive<paddle::experimental::SharedStorage>(
+          pten::TransToFluidPlace(backend)),
+      std::move(dense_meta));
+
+  kernel_context.EmplaceBackOutput(dense_out.get());
+  Tensor out;
+  out.set_impl(dense_out);
+
+  // 6. Call kernel
+  kernel(&kernel_context);
+
+  return out;
+}
+
 }  // namespace sparse
 }  // namespace experimental
 }  // namespace paddle
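
For orientation before the kernel implementations below: the entry point added above is deliberately forgiving, returning non-sparse inputs unchanged. A minimal caller's-eye sketch, assuming a CPU build; the wrapper name Densify is a placeholder for illustration, not part of this patch:

#include "paddle/pten/api/include/sparse_api.h"

// Hypothetical convenience wrapper: densify a tensor regardless of its
// current layout. to_dense is a no-op for non-sparse layouts, so the
// unconditional call is safe for both dense and sparse inputs.
paddle::experimental::Tensor Densify(const paddle::experimental::Tensor& x) {
  return paddle::experimental::sparse::to_dense(x, pten::Backend::CPU);
}
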
diff --git a/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc
index d8062104ed..de648f76c1 100644
--- a/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc
+++ b/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc
@@ -242,6 +242,49 @@ void SparseCooToCsrKernel(const Context& dev_ctx,
   out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims);
 }
 
+template <typename T, typename Context>
+void SparseCooToDenseKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            DenseTensor* out) {
+  const auto non_zero_num = x.nnz();
+  const auto dense_dims = x.dims();
+  const auto indices = x.non_zero_indices();
+  const auto values = x.non_zero_elements();
+  const auto indices_dims = indices.dims();
+  int64_t sparse_dim = indices_dims[0];
+  if (indices_dims.size() == 1) {
+    sparse_dim = 1;
+  }
+  const int64_t dense_dim = values.dims().size() - 1;
+
+  const auto place = dev_ctx.GetPlace();
+  const T* x_data = values.data<T>();
+  T* out_data = out->mutable_data<T>(place);
+  int64_t base_offset = 1;
+  for (int64_t i = 0; i < dense_dim; i++) {
+    base_offset *= dense_dims[sparse_dim + i];
+  }
+  std::vector<int64_t> sparse_offsets(sparse_dim);
+  int64_t offset = 1;
+  for (int i = sparse_dim - 1; i >= 0; i--) {
+    sparse_offsets[i] = offset;
+    offset *= dense_dims[i];
+  }
+
+  memset(out_data, 0, sizeof(T) * out->numel());
+  for (auto i = 0; i < non_zero_num; i++) {
+    int64_t index = 0;
+    for (int j = 0; j < sparse_dim; j++) {
+      index +=
+          indices.data<int64_t>()[j * non_zero_num + i] * sparse_offsets[j];
+    }
+
+    for (int j = 0; j < base_offset; j++) {
+      out_data[index * base_offset + j] = x_data[i * base_offset + j];
+    }
+  }
+}
+
 }  // namespace sparse
 }  // namespace pten
@@ -296,3 +339,29 @@ PT_REGISTER_KERNEL(dense_to_sparse_csr,
                    int16_t,
                    int,
                    int64_t) {}
+
+PT_REGISTER_KERNEL(sparse_coo_to_dense,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::sparse::SparseCooToDenseKernel,
+                   float,
+                   double,
+                   pten::dtype::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
+
+PT_REGISTER_KERNEL(sparse_csr_to_dense,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::sparse::SparseCsrToDenseKernel,
+                   float,
+                   double,
+                   pten::dtype::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
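
The index arithmetic in SparseCooToDenseKernel above is shared verbatim by the GPU path below: sparse_offsets holds the row-major strides of the sparse dimensions, and base_offset is the element count of one dense slice, so non-zero i lands at out_data[index * base_offset]. A self-contained sketch of the same scatter, with hypothetical names and float values for concreteness:

#include <cstring>
#include <vector>

// Illustrative standalone version of the COO->dense scatter (not the Paddle
// kernel itself): indices is [sparse_dim x nnz], values is [nnz x base_offset],
// out has out_numel = prod(dims) elements.
void ScatterCooToDense(const std::vector<int64_t>& dims, int64_t sparse_dim,
                       const int64_t* indices, const float* values,
                       int64_t nnz, float* out, int64_t out_numel) {
  int64_t base_offset = 1;  // number of elements in one dense slice
  for (size_t i = static_cast<size_t>(sparse_dim); i < dims.size(); i++) {
    base_offset *= dims[i];
  }
  std::vector<int64_t> sparse_offsets(sparse_dim);  // row-major strides
  int64_t offset = 1;
  for (int64_t i = sparse_dim - 1; i >= 0; i--) {
    sparse_offsets[i] = offset;
    offset *= dims[i];
  }
  std::memset(out, 0, sizeof(float) * out_numel);
  for (int64_t i = 0; i < nnz; i++) {
    int64_t index = 0;  // flat offset of non-zero i over the sparse dims
    for (int64_t j = 0; j < sparse_dim; j++) {
      index += indices[j * nnz + i] * sparse_offsets[j];
    }
    std::memcpy(out + index * base_offset, values + i * base_offset,
                sizeof(float) * base_offset);
  }
}
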
diff --git a/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu
index 108e6a39c4..cb8e307531 100644
--- a/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu
+++ b/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -467,6 +467,96 @@ void SparseCooToCsrKernel(const Context& dev_ctx,
   out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims);
 }
 
+template <typename ValueT, typename IndicesT>
+__global__ void KernelSparseCooToDense(const IndicesT* indices,
+                                       const IndicesT* sparse_offsets,
+                                       const ValueT* data,
+                                       ValueT* dense_data,
+                                       const IndicesT non_zero_num,
+                                       const int64_t base_offset,
+                                       const int64_t sparse_dim) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
+    int64_t index = 0;
+    for (int j = 0; j < sparse_dim; j++) {
+      index += indices[j * non_zero_num + i] * sparse_offsets[j];
+    }
+
+    for (int j = 0; j < base_offset; j++) {
+      dense_data[index * base_offset + j] = data[i * base_offset + j];
+    }
+  }
+}
+
+template <typename T, typename Context>
+void SparseCooToDenseKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            DenseTensor* out) {
+  const auto non_zero_num = x.nnz();
+  const auto dense_dims = x.dims();
+  const auto indices = x.non_zero_indices();
+  const auto values = x.non_zero_elements();
+  const auto indices_dims = indices.dims();
+  int64_t sparse_dim = indices_dims[0];
+  if (indices_dims.size() == 1) {
+    sparse_dim = 1;
+  }
+  const int64_t dense_dim = values.dims().size() - 1;
+
+  const auto place = dev_ctx.GetPlace();
+  const T* x_data = values.data<T>();
+  T* out_data = out->mutable_data<T>(place);
+  int64_t base_offset = 1;
+  for (int64_t i = 0; i < dense_dim; i++) {
+    base_offset *= dense_dims[sparse_dim + i];
+  }
+  std::vector<int64_t> sparse_offsets(sparse_dim);
+  int64_t offset = 1;
+  for (int i = sparse_dim - 1; i >= 0; i--) {
+    sparse_offsets[i] = offset;
+    offset *= dense_dims[i];
+  }
+
+  auto sparse_offset_meta = pten::DenseTensorMeta(
+      DataType::INT64, {sparse_dim}, pten::DataLayout::NCHW);
+  DenseTensor d_sparse_offsets =
+      pten::Empty<int64_t, Context>(dev_ctx, std::move(sparse_offset_meta));
+
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      hipMemcpyAsync(d_sparse_offsets.mutable_data<int64_t>(place),
+                     sparse_offsets.data(),
+                     sparse_dim * sizeof(int64_t),
+                     hipMemcpyHostToDevice,
+                     dev_ctx.stream()));
+
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      hipMemsetAsync(out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream()));
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      cudaMemcpyAsync(d_sparse_offsets.mutable_data<int64_t>(place),
+                      sparse_offsets.data(),
+                      sparse_dim * sizeof(int64_t),
+                      cudaMemcpyHostToDevice,
+                      dev_ctx.stream()));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      cudaMemsetAsync(out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream()));
+#endif
+
+  int grid_size = 1, block_size = 1;
+  GetGpuLaunchConfig1D(dev_ctx, non_zero_num, &grid_size, &block_size);
+
+  KernelSparseCooToDense<
+      T,
+      int64_t><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+      indices.data<int64_t>(),
+      d_sparse_offsets.data<int64_t>(),
+      x_data,
+      out_data,
+      non_zero_num,
+      base_offset,
+      sparse_dim);
+}
+
 }  // namespace sparse
 }  // namespace pten
@@ -521,3 +611,29 @@ PT_REGISTER_KERNEL(dense_to_sparse_csr,
                    int16_t,
                    int,
                    int64_t) {}
+
+PT_REGISTER_KERNEL(sparse_coo_to_dense,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::sparse::SparseCooToDenseKernel,
+                   float,
+                   double,
+                   pten::dtype::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
+
+PT_REGISTER_KERNEL(sparse_csr_to_dense,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::sparse::SparseCsrToDenseKernel,
+                   float,
+                   double,
+                   pten::dtype::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
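
Note that KernelSparseCooToDense uses a grid-stride loop, so correctness does not depend on the grid and block sizes chosen by GetGpuLaunchConfig1D; any launch configuration visits every non-zero. The pattern in isolation (an illustrative sketch, not part of this patch):

// Grid-stride idiom: each thread starts at its global id and advances by the
// total number of launched threads, covering [0, n) for any grid/block size.
__global__ void GridStrideExample(const float* in, float* out, int n) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < n; i += gridDim.x * blockDim.x) {
    out[i] = in[i];  // stand-in for the per-element work
  }
}
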
diff --git a/paddle/pten/kernels/sparse/sparse_utils_kernel.h b/paddle/pten/kernels/sparse/sparse_utils_kernel.h
index b2cb90878a..0e880bc5af 100644
--- a/paddle/pten/kernels/sparse/sparse_utils_kernel.h
+++ b/paddle/pten/kernels/sparse/sparse_utils_kernel.h
@@ -118,5 +118,37 @@ SparseCsrTensor DenseToSparseCsr(const Context& dev_ctx, const DenseTensor& x) {
   return csr;
 }
 
+template <typename T, typename Context>
+void SparseCooToDenseKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            DenseTensor* out);
+
+template <typename T, typename Context>
+DenseTensor SparseCooToDense(const Context& dev_ctx, const SparseCooTensor& x) {
+  DenseTensorMeta meta(x.dtype(), x.dims(), x.layout());
+  DenseTensor dense = pten::Empty<T, Context>(dev_ctx, std::move(meta));
+  SparseCooToDenseKernel<T, Context>(dev_ctx, x, &dense);
+  return dense;
+}
+
+template <typename T, typename Context>
+void SparseCsrToDenseKernel(const Context& dev_ctx,
+                            const SparseCsrTensor& x,
+                            DenseTensor* out) {
+  DenseTensor indices = pten::Empty<T, Context>(dev_ctx);
+  DenseTensor values = pten::Empty<T, Context>(dev_ctx);
+  SparseCooTensor coo(indices, values, x.dims());
+  SparseCsrToCooKernel<T, Context>(dev_ctx, x, &coo);
+  SparseCooToDenseKernel<T, Context>(dev_ctx, coo, out);
+}
+
+template <typename T, typename Context>
+DenseTensor SparseCsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
+  DenseTensorMeta meta(x.dtype(), x.dims(), x.layout());
+  DenseTensor dense = pten::Empty<T, Context>(dev_ctx, std::move(meta));
+  SparseCsrToDenseKernel<T, Context>(dev_ctx, x, &dense);
+  return dense;
+}
+
 }  // namespace sparse
 }  // namespace pten
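
SparseCsrToDenseKernel above is pure composition: it materializes a COO tensor via the existing SparseCsrToCooKernel and then reuses SparseCooToDenseKernel, so no new CSR-specific scatter is needed. The only CSR-specific step is expanding the compressed row pointer into one explicit row index per non-zero, sketched here as a hypothetical standalone helper (not the Paddle kernel):

#include <vector>

// Expand a CSR row-pointer array (crows, length rows + 1) into one explicit
// row index per non-zero, as required by the COO representation.
std::vector<int64_t> ExpandCrows(const std::vector<int64_t>& crows) {
  std::vector<int64_t> rows;
  for (int64_t r = 0; r + 1 < static_cast<int64_t>(crows.size()); r++) {
    for (int64_t k = crows[r]; k < crows[r + 1]; k++) {
      rows.push_back(r);
    }
  }
  return rows;  // e.g. crows {0, 1, 3, 4} -> rows {0, 1, 1, 2}
}
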
diff --git a/paddle/pten/tests/api/test_sparse_utils_api.cc b/paddle/pten/tests/api/test_sparse_utils_api.cc
index cb9a419167..40cb20fdd0 100644
--- a/paddle/pten/tests/api/test_sparse_utils_api.cc
+++ b/paddle/pten/tests/api/test_sparse_utils_api.cc
@@ -175,3 +175,72 @@ TEST(API, to_sparse_csr) {
   auto csr2 = std::dynamic_pointer_cast<pten::SparseCsrTensor>(out.impl());
   check(*csr2);
 }
+
+TEST(API, to_dense) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  const int64_t sparse_dim = 2;
+  float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> indices_data = {0, 1, 1, 2, 1, 0, 2, 0};
+  std::vector<int64_t> cols_data = {1, 0, 2, 0};
+  std::vector<int64_t> crows_data = {0, 1, 3, 4};
+  const int64_t non_zero_num = 4;
+  auto dense_dims = pten::framework::make_ddim({3, 3});
+
+  pten::CPUContext dev_ctx_cpu;
+
+  // 1. test sparse_coo_to_dense
+  pten::DenseTensorMeta indices_meta(pten::DataType::INT64,
+                                     {sparse_dim, non_zero_num},
+                                     pten::DataLayout::NCHW);
+  pten::DenseTensorMeta values_meta(
+      pten::DataType::FLOAT32, {non_zero_num}, pten::DataLayout::NCHW);
+
+  pten::CPUPlace place;
+  pten::DenseTensor indices(alloc.get(), indices_meta);
+  pten::DenseTensor values(alloc.get(), values_meta);
+  memcpy(indices.mutable_data<int64_t>(place),
+         indices_data.data(),
+         indices_data.size() * sizeof(int64_t));
+  memcpy(values.mutable_data<float>(place),
+         non_zero_data.data(),
+         non_zero_data.size() * sizeof(float));
+  auto coo =
+      std::make_shared<pten::SparseCooTensor>(indices, values, dense_dims);
+
+  paddle::experimental::Tensor coo_x(coo);
+  auto out = paddle::experimental::sparse::to_dense(coo_x, pten::Backend::CPU);
+  auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
+  int cmp1 =
+      memcmp(dense_out->data<float>(), &dense_data[0][0], 9 * sizeof(float));
+  ASSERT_EQ(cmp1, 0);
+
+  // 2. test sparse_csr_to_dense
+  pten::DenseTensorMeta crows_meta(
+      pten::DataType::INT64, {dense_dims[0] + 1}, pten::DataLayout::NCHW);
+  pten::DenseTensorMeta cols_meta(
+      pten::DataType::INT64, {non_zero_num}, pten::DataLayout::NCHW);
+  pten::DenseTensor crows(alloc.get(), crows_meta);
+  pten::DenseTensor cols(alloc.get(), cols_meta);
+  memcpy(crows.mutable_data<int64_t>(place),
+         crows_data.data(),
+         crows_data.size() * sizeof(int64_t));
+  memcpy(cols.mutable_data<int64_t>(place),
+         cols_data.data(),
+         cols_data.size() * sizeof(int64_t));
+  memcpy(values.mutable_data<float>(place),
+         non_zero_data.data(),
+         non_zero_data.size() * sizeof(float));
+  auto csr =
+      std::make_shared<pten::SparseCsrTensor>(crows, cols, values, dense_dims);
+  paddle::experimental::Tensor csr_x(csr);
+  auto out2 = paddle::experimental::sparse::to_dense(csr_x, pten::Backend::CPU);
+  auto dense_out2 = std::dynamic_pointer_cast<pten::DenseTensor>(out2.impl());
+  int cmp2 =
+      memcmp(dense_out2->data<float>(), &dense_data[0][0], 9 * sizeof(float));
+  ASSERT_EQ(cmp2, 0);
+}
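
The device-API tests below reuse the same 3x3 fixture. As a quick cross-check of the data: the COO rows {0, 1, 1, 2} and cols {1, 0, 2, 0} with values {1.0, 2.0, 3.0, 3.2} scatter to flat offsets {0*3+1, 1*3+0, 1*3+2, 2*3+0} = {1, 3, 5, 6}, exactly the non-zero slots of dense_data, and the CSR form encodes the same matrix, since crows {0, 1, 3, 4} expands to rows {0, 1, 1, 2}.
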
diff --git a/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc
index c082b08d52..4b4e372cfe 100644
--- a/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc
+++ b/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc
@@ -659,5 +659,276 @@ TEST(DEV_API, dense_to_sparse_csr_batch) {
       dense_x, non_zero_num, non_zero_data, crows_data, cols_data);
 }
 
+template <typename T>
+void TestSparseCooToDense(const DDim& dense_dims,
+                          const std::vector<T>& dense_data,
+                          const std::vector<T>& non_zero_data,
+                          const std::vector<int64_t>& indices_data,
+                          const int64_t non_zero_num,
+                          const int64_t sparse_dim) {
+  pten::CPUContext dev_ctx_cpu;
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  DenseTensor dense_indices(
+      alloc.get(),
+      DenseTensorMeta(DataType::INT64,
+                      framework::make_ddim({sparse_dim, non_zero_num}),
+                      DataLayout::NCHW));
+  std::vector<int64_t> dense_elements_vec;
+  dense_elements_vec.push_back(non_zero_num);
+  for (int64_t i = sparse_dim; i < dense_dims.size(); i++) {
+    dense_elements_vec.push_back(dense_dims[i]);
+  }
+  DDim dense_elements_dims = framework::make_ddim(dense_elements_vec);
+  DenseTensor dense_elements(
+      alloc.get(),
+      DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
+                      dense_elements_dims,
+                      DataLayout::NCHW));
+
+  pten::CPUPlace cpu_place;
+  memcpy(dense_indices.mutable_data<int64_t>(cpu_place),
+         indices_data.data(),
+         indices_data.size() * sizeof(int64_t));
+  memcpy(dense_elements.mutable_data<T>(cpu_place),
+         non_zero_data.data(),
+         non_zero_num * sizeof(T));
+
+  SparseCooTensor coo(dense_indices, dense_elements, dense_dims);
+
+  DenseTensor dense_out = sparse::SparseCooToDense<T>(dev_ctx_cpu, coo);
+
+  int cmp = memcmp(
+      &dense_data[0], dense_out.data<T>(), sizeof(T) * dense_data.size());
+  ASSERT_EQ(cmp, 0);
+
+#if defined(PADDLE_WITH_CUDA)
+  const auto cuda_alloc =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CUDAPlace());
+  pten::GPUContext dev_ctx_gpu;
+  dev_ctx_gpu.PartialInitWithoutAllocator();
+  dev_ctx_gpu.SetAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream())
+          .get());
+  dev_ctx_gpu.SetHostAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(pten::CPUPlace())
+          .get());
+  dev_ctx_gpu.PartialInitWithAllocator();
+
+  DenseTensor d_dense_indices(cuda_alloc.get(), dense_indices.meta());
+  DenseTensor d_dense_elements(cuda_alloc.get(), dense_elements.meta());
+  pten::Copy(dev_ctx_gpu, dense_indices, true, &d_dense_indices);
+  pten::Copy(dev_ctx_gpu, dense_elements, true, &d_dense_elements);
+  SparseCooTensor coo_cuda(d_dense_indices, d_dense_elements, dense_dims);
+  auto dense_out_cuda = sparse::SparseCooToDense<T>(dev_ctx_gpu, coo_cuda);
+
+  DenseTensor h_dense_out(alloc.get(),
+                          DenseTensorMeta(dense_out_cuda.dtype(),
+                                          dense_out_cuda.dims(),
+                                          dense_out_cuda.layout()));
+  pten::Copy(dev_ctx_gpu, dense_out_cuda, true, &h_dense_out);
+  int cmp_cuda = memcmp(
+      &dense_data[0], h_dense_out.data<T>(), sizeof(T) * dense_data.size());
+  ASSERT_EQ(cmp_cuda, 0);
+#endif
+}
+
+TEST(DEV_API, sparse_coo_to_dense) {
+  const int non_zero_num = 4;
+  const int sparse_dim = 2;
+  std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> indices_data = {0, 1, 1, 2, 1, 0, 2, 0};
+  DDim dense_dims = framework::make_ddim({3, 3});
+  TestSparseCooToDense(dense_dims,
+                       dense_data,
+                       non_zero_data,
+                       indices_data,
+                       non_zero_num,
+                       sparse_dim);
+}
+
+TEST(DEV_API, sparse_coo_to_dense_batch_and_fp16) {
+  std::vector<float> dense_data = {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0,
+                                   0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 4.0};
+  std::vector<int64_t> indices_data = {0, 0, 1, 1, 0, 2, 1, 2, 1, 0, 1, 0};
+  const int non_zero_num = 4;
+  const int sparse_dim = 3;
+  DDim dense_dims = framework::make_ddim({2, 3, 3});
+  using float16 = pten::dtype::float16;
+  std::vector<float16> dense_data_fp16(dense_data.size()),
+      non_zero_data_fp16(non_zero_num);
+  for (uint64_t i = 0; i < dense_data.size(); i++) {
+    dense_data_fp16[i] = static_cast<float16>(dense_data[i]);
+  }
+  for (int64_t i = 0; i < non_zero_num; i++) {
+    non_zero_data_fp16[i] = static_cast<float16>(non_zero_data[i]);
+  }
+  TestSparseCooToDense(dense_dims,
+                       dense_data_fp16,
+                       non_zero_data_fp16,
+                       indices_data,
+                       non_zero_num,
+                       sparse_dim);
+}
+
+template <typename T>
+void TestSparseCsrToDense(const DDim& dense_dims,
+                          const std::vector<T>& dense_data,
+                          const std::vector<T>& non_zero_data,
+                          const std::vector<int64_t>& crows_data,
+                          const std::vector<int64_t>& cols_data,
+                          const int64_t non_zero_num) {
+  int batchs = 1;
+  int rows = dense_dims[0];
+  if (dense_dims.size() == 3) {
+    batchs = dense_dims[0];
+    rows = dense_dims[1];
+  }
+  pten::DenseTensorMeta crows_meta(DataType::INT64,
+                                   framework::make_ddim({batchs * (rows + 1)}),
+                                   DataLayout::NCHW);
+  pten::DenseTensorMeta cols_meta(
+      DataType::INT64, framework::make_ddim({non_zero_num}), DataLayout::NCHW);
+  pten::DenseTensorMeta values_meta(
+      paddle::experimental::CppTypeToDataType<T>::Type(),
+      framework::make_ddim({non_zero_num}),
+      DataLayout::NCHW);
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  pten::CPUPlace place;
+  pten::DenseTensor crows(alloc.get(), crows_meta);
+  pten::DenseTensor cols(alloc.get(), cols_meta);
+  pten::DenseTensor values(alloc.get(), values_meta);
+  memcpy(crows.mutable_data<int64_t>(place),
+         crows_data.data(),
+         crows_data.size() * sizeof(int64_t));
+  memcpy(cols.mutable_data<int64_t>(place),
+         cols_data.data(),
+         cols_data.size() * sizeof(int64_t));
+  memcpy(values.mutable_data<T>(place),
+         non_zero_data.data(),
+         non_zero_data.size() * sizeof(T));
+  pten::SparseCsrTensor csr(crows, cols, values, dense_dims);
+
+  // 1. test cpu
+  pten::CPUContext dev_ctx_cpu;
+  DenseTensor cpu_sparse_out = sparse::SparseCsrToDense<T>(dev_ctx_cpu, csr);
+  int cmp_cpu = memcmp(cpu_sparse_out.data<T>(),
+                       dense_data.data(),
+                       sizeof(T) * dense_data.size());
+  ASSERT_EQ(cmp_cpu, 0);
+
+// 2. test cuda
+#if defined(PADDLE_WITH_CUDA)
+  const auto cuda_alloc =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CUDAPlace());
+  pten::GPUContext dev_ctx_gpu;
+  dev_ctx_gpu.PartialInitWithoutAllocator();
+  dev_ctx_gpu.SetAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream())
+          .get());
+  dev_ctx_gpu.SetHostAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(pten::CPUPlace())
+          .get());
+  dev_ctx_gpu.PartialInitWithAllocator();
+
+  pten::DenseTensor d_crows(cuda_alloc.get(), crows_meta);
+  pten::DenseTensor d_cols(cuda_alloc.get(), cols_meta);
+  pten::DenseTensor d_values(cuda_alloc.get(), values_meta);
+  pten::Copy(dev_ctx_gpu, crows, true, &d_crows);
+  pten::Copy(dev_ctx_gpu, cols, true, &d_cols);
+  pten::Copy(dev_ctx_gpu, values, true, &d_values);
+  pten::SparseCsrTensor d_csr(d_crows, d_cols, d_values, dense_dims);
+  auto cuda_sparse_out = sparse::SparseCsrToDense<T>(dev_ctx_gpu, d_csr);
+  pten::DenseTensor h_out(alloc.get(), cpu_sparse_out.meta());
+  pten::Copy(dev_ctx_gpu, cuda_sparse_out, true, &h_out);
+  int cmp_cuda = memcmp(
+      h_out.data<T>(), dense_data.data(), sizeof(T) * dense_data.size());
+  ASSERT_EQ(cmp_cuda, 0);
+#endif
+}
+
+TEST(DEV_API, sparse_csr_to_dense) {
+  DDim dense_dims = framework::make_ddim({3, 3});
+  std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> cols_data = {1, 0, 2, 0};
+  std::vector<int64_t> crows_data = {0, 1, 3, 4};
+  const int64_t non_zero_num = 4;
+
+  TestSparseCsrToDense(dense_dims,
+                       dense_data,
+                       non_zero_data,
+                       crows_data,
+                       cols_data,
+                       non_zero_num);
+}
+
+TEST(DEV_API, sparse_csr_to_dense_batch_and_fp16) {
+  DDim dense_dims = framework::make_ddim({2, 3, 3});
+  std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0,
+                                   0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2, 1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> cols_data = {1, 0, 2, 0, 1, 0, 2, 0};
+  std::vector<int64_t> crows_data = {0, 1, 3, 4, 0, 1, 3, 4};
+  const int64_t non_zero_num = 8;
+
+  using float16 = pten::dtype::float16;
+  std::vector<float16> dense_data_fp16(dense_data.size()),
+      non_zero_data_fp16(non_zero_num);
+  for (uint64_t i = 0; i < dense_data.size(); i++) {
+    dense_data_fp16[i] = static_cast<float16>(dense_data[i]);
+  }
+  for (int64_t i = 0; i < non_zero_num; i++) {
+    non_zero_data_fp16[i] = static_cast<float16>(non_zero_data[i]);
+  }
+  TestSparseCsrToDense(dense_dims,
+                       dense_data_fp16,
+                       non_zero_data_fp16,
+                       crows_data,
+                       cols_data,
+                       non_zero_num);
+}
+
 }  // namespace tests
 }  // namespace pten
-- 
GitLab