From 76d527e17f8b9f1d22c5381964cb9d50206a9907 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 9 Feb 2022 09:36:31 +0800 Subject: [PATCH] Add a Sparse Op: to_sparse_csr (#39333) * implement AllocateFrom * dense_to_sparse_coo * optimize unit testing; support rocm * 1. delete fluid related header file 2. update the copyright * fix hipMemcpy * update dense_to_sparsecoo * add namespace sparse * sparse_csr_to_dense * test to_sparse_coo: csr_to_coo * fix writing error * to_sparse_csr: dense_to_sparse_csr and sparse_coo_to_csr * fix check shape * fix unit test * replace CUDADeviceContext by GPUContext --- paddle/pten/api/include/sparse_api.h | 2 + paddle/pten/api/lib/sparse_api.cc | 71 +++++ .../kernels/sparse/cpu/sparse_utils_kernel.cc | 111 +++++++ .../kernels/sparse/gpu/sparse_utils_kernel.cu | 163 ++++++++++ .../pten/kernels/sparse/sparse_utils_kernel.h | 46 +++ .../pten/tests/api/test_sparse_utils_api.cc | 76 ++++- .../kernels/test_sparse_utils_dev_api.cc | 293 +++++++++++++++++- 7 files changed, 752 insertions(+), 10 deletions(-) diff --git a/paddle/pten/api/include/sparse_api.h b/paddle/pten/api/include/sparse_api.h index 22e511e62ab..8ec36084ff8 100644 --- a/paddle/pten/api/include/sparse_api.h +++ b/paddle/pten/api/include/sparse_api.h @@ -25,6 +25,8 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x, Backend backend, const int64_t sparse_dim); +PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend); + } // namespace sparse } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/sparse_api.cc b/paddle/pten/api/lib/sparse_api.cc index d763bb7e8d6..e2bccd47238 100644 --- a/paddle/pten/api/lib/sparse_api.cc +++ b/paddle/pten/api/lib/sparse_api.cc @@ -23,9 +23,15 @@ limitations under the License. */ #include "paddle/pten/infermeta/unary.h" PT_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(sparse_csr_to_coo, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(dense_to_sparse_csr, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(sparse_coo_to_csr, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PT_DECLARE_KERNEL(dense_to_sparse_coo, GPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(sparse_csr_to_coo, GPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(dense_to_sparse_csr, GPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(sparse_coo_to_csr, GPU, ALL_LAYOUT); #endif namespace paddle { @@ -95,6 +101,71 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x, return out; } +PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) { + if (x.layout() == pten::DataLayout::SPARSE_CSR) { + return x; + } + // 1. Get kernel signature and kernel + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + std::string kernel_name = "dense_to_sparse_csr"; + if (x.layout() == pten::DataLayout::SPARSE_COO) { + kernel_name = "sparse_coo_to_csr"; + } + + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + kernel_name, kernel_key); + + VLOG(6) << "to API kernel key: " << kernel_key; + VLOG(6) << "to API kernel: " << kernel; + + // 2. Get Device Context + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto kernel_context = pten::KernelContext(dev_ctx); + + // 3. 
Auto data transform + if (x.layout() == pten::DataLayout::SPARSE_COO) { + auto input = std::dynamic_pointer_cast(x.impl()); + kernel_context.EmplaceBackInput(input.get()); + } else { + auto input = std::dynamic_pointer_cast(x.impl()); + kernel_context.EmplaceBackInput(input.get()); + } + + // 4. InferMeta + auto crows_meta = pten::DenseTensorMeta( + pten::DataType::INT64, {-1}, pten::DataLayout::NCHW); + auto cols_meta = pten::DenseTensorMeta( + pten::DataType::INT64, {-1}, pten::DataLayout::NCHW); + auto elements_meta = pten::DenseTensorMeta(x.dtype(), {-1}, x.layout()); + + // 5. Prepare outputs + // create empty SparseCooTensor + pten::DenseTensor non_zero_crows( + pten::make_intrusive( + pten::TransToFluidPlace(backend)), + std::move(crows_meta)); + pten::DenseTensor non_zero_cols( + pten::make_intrusive( + pten::TransToFluidPlace(backend)), + std::move(cols_meta)); + pten::DenseTensor non_zero_elements( + pten::make_intrusive( + pten::TransToFluidPlace(backend)), + std::move(elements_meta)); + auto csr = std::make_shared( + non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); + + kernel_context.EmplaceBackOutput(csr.get()); + Tensor out; + out.set_impl(csr); + + // 6. Call kernel + kernel(&kernel_context); + + return out; +} } // namespace sparse } // namespace experimental } // namespace paddle diff --git a/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc index d3aac6ee7d2..d8062104ed7 100644 --- a/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -157,6 +157,91 @@ void SparseCsrToCooKernel(const Context& dev_ctx, out->SetMember(indices, values, x_dims, true); } +template +void SparseCooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { + const auto& x_dims = x.dims(); + bool valid = x_dims.size() == 2 || x_dims.size() == 3; + PADDLE_ENFORCE_EQ(valid, + true, + paddle::platform::errors::InvalidArgument( + "SparseCsrTensor only support 2-D or 3-D matrix")); + const int64_t non_zero_num = x.nnz(); + if (non_zero_num <= 0) return; + + int batchs = x_dims.size() == 2 ? 1 : x_dims[0]; + int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1]; + + const auto place = dev_ctx.GetPlace(); + DenseTensorMeta crows_meta( + DataType::INT64, {batchs * (rows + 1)}, DataLayout::NCHW); + DenseTensorMeta cols_meta(DataType::INT64, {non_zero_num}, DataLayout::NCHW); + DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout()); + pten::DenseTensor non_zero_crows( + pten::make_intrusive(place), + std::move(crows_meta)); + pten::DenseTensor non_zero_cols( + pten::make_intrusive(place), + std::move(cols_meta)); + pten::DenseTensor non_zero_elements( + pten::make_intrusive(place), + std::move(values_meta)); + int64_t* csr_crows_data = non_zero_crows.mutable_data(place); + int64_t* csr_cols_data = non_zero_cols.mutable_data(place); + T* csr_values_data = non_zero_elements.mutable_data(place); + + const auto& coo_indices = x.non_zero_indices(); + const auto& coo_values = x.non_zero_elements(); + const int64_t* batchs_ptr = coo_indices.data(); + const int64_t* coo_rows_data = + batchs == 1 ? 
batchs_ptr : batchs_ptr + non_zero_num; + const int64_t* coo_cols_data = coo_rows_data + non_zero_num; + const T* coo_values_data = coo_values.data(); + + if (!x.coalesced()) { + // TODO(zhangkahuo): call coalesced() to distinct and sort the indices + } + + std::vector offsets(batchs, 0); + if (batchs > 1) { + for (int i = 0; i < non_zero_num; i++) { + if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) { + offsets[batchs_ptr[i]] = i + 1; + } + } + } else { + offsets[0] = non_zero_num; + } + + for (int b = 0; b < batchs; b++) { + if (offsets[b] == 0) continue; + int batch_start = 0; + int batch_non_zero_num = offsets[b]; + if (b > 0) { + batch_start = offsets[b - 1]; + batch_non_zero_num -= batch_start; + } + auto* coo_rows_ptr = coo_rows_data + batch_start; + for (int i = 0; i <= coo_rows_ptr[0]; i++) { + csr_crows_data[b * (rows + 1) + i] = 0; + } + for (int64_t i = 1; i < batch_non_zero_num; i++) { + for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) { + csr_crows_data[b * (rows + 1) + j + 1] = i; + } + } + for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; + i++) { + csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num; + } + } + + memcpy(csr_cols_data, coo_cols_data, sizeof(int64_t) * non_zero_num); + memcpy(csr_values_data, coo_values_data, sizeof(T) * non_zero_num); + out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims); +} + } // namespace sparse } // namespace pten @@ -185,3 +270,29 @@ PT_REGISTER_KERNEL(sparse_csr_to_coo, int16_t, int, int64_t) {} + +PT_REGISTER_KERNEL(sparse_coo_to_csr, + CPU, + ALL_LAYOUT, + pten::sparse::SparseCooToCsrKernel, + float, + double, + pten::dtype::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PT_REGISTER_KERNEL(dense_to_sparse_csr, + CPU, + ALL_LAYOUT, + pten::sparse::DenseToSparseCsrKernel, + float, + double, + pten::dtype::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu index eb9fa7a1696..108e6a39c4d 100644 --- a/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -330,6 +330,143 @@ void SparseCsrToCooKernel(const Context& dev_ctx, out->SetMember(indices, values, x_dims, true); } +__global__ void GetBatchsOffset(const int64_t* batchs_ptr, + const int non_zero_num, + int64_t* batchs_offset) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { + if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) { + batchs_offset[batchs_ptr[i]] = i + 1; + } + } +} + +__global__ void ConvertCooRowsToCsrCrows( + const int64_t* batchs_offset, // can be null if batchs = 1 + const int64_t* coo_rows_data, + int64_t* csr_crows_data, + const int rows, + const int64_t non_zero_num) { + const int b = blockIdx.y; + int batch_non_zero_num = + batchs_offset == nullptr ? 
non_zero_num : batchs_offset[b]; + if (batch_non_zero_num == 0) return; + int batch_start = 0; + if (b > 0) { + batch_start = batchs_offset[b - 1]; + batch_non_zero_num -= batch_start; + } + auto* coo_rows_ptr = coo_rows_data + batch_start; + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) { + if (i == 0) { + for (int j = 0; j <= coo_rows_ptr[0]; j++) { + csr_crows_data[b * (rows + 1) + j] = 0; + } + } else { + for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) { + csr_crows_data[b * (rows + 1) + j + 1] = i; + } + } + if (i == batch_non_zero_num - 1) { + for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; + i++) { + csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num; + } + } + } +} + +template +void SparseCooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { + const auto& x_dims = x.dims(); + bool valid = x_dims.size() == 2 || x_dims.size() == 3; + PADDLE_ENFORCE_EQ(valid, + true, + paddle::platform::errors::InvalidArgument( + "SparseCsrTensor only support 2-D or 3-D matrix")); + const int64_t non_zero_num = x.nnz(); + if (non_zero_num <= 0) return; + + int batchs = x_dims.size() == 2 ? 1 : x_dims[0]; + int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1]; + + const auto place = dev_ctx.GetPlace(); + DenseTensorMeta crows_meta( + DataType::INT64, {batchs * (rows + 1)}, DataLayout::NCHW); + DenseTensorMeta cols_meta(DataType::INT64, {non_zero_num}, DataLayout::NCHW); + DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout()); + pten::DenseTensor non_zero_crows( + pten::make_intrusive(place), + std::move(crows_meta)); + pten::DenseTensor non_zero_cols( + pten::make_intrusive(place), + std::move(cols_meta)); + pten::DenseTensor non_zero_elements( + pten::make_intrusive(place), + std::move(values_meta)); + int64_t* csr_crows_data = non_zero_crows.mutable_data(place); + int64_t* csr_cols_data = non_zero_cols.mutable_data(place); + T* csr_values_data = non_zero_elements.mutable_data(place); + + const auto& coo_indices = x.non_zero_indices(); + const auto& coo_values = x.non_zero_elements(); + const int64_t* batchs_ptr = coo_indices.data(); + const int64_t* coo_rows_data = + batchs == 1 ? 
batchs_ptr : batchs_ptr + non_zero_num; + const int64_t* coo_cols_data = coo_rows_data + non_zero_num; + const T* coo_values_data = coo_values.data(); + + if (!x.coalesced()) { + // TODO(zhangkahuo): call coalesced() to distinct and sort the indices + } + + int grid_size = 1, block_size = 1; + GetGpuLaunchConfig1D(dev_ctx, batchs, &grid_size, &block_size); + if (batchs > 1) { + DenseTensorMeta batchs_meta(DataType::INT64, {batchs}, DataLayout::NCHW); + pten::DenseTensor batchs_offset( + pten::make_intrusive(place), + std::move(batchs_meta)); + int64_t* batchs_offset_ptr = batchs_offset.mutable_data(place); + GetBatchsOffset<<>>( + batchs_ptr, non_zero_num, batchs_offset_ptr); + dim3 grids(grid_size, batchs, 1); + ConvertCooRowsToCsrCrows<<>>( + batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num); + } else { + ConvertCooRowsToCsrCrows<<>>( + nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num); + } + +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(csr_cols_data, + coo_cols_data, + sizeof(int64_t) * non_zero_num, + hipMemcpyDeviceToDevice, + dev_ctx.stream())); + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(csr_values_data, + coo_values_data, + sizeof(T) * non_zero_num, + hipMemcpyDeviceToDevice, + dev_ctx.stream())); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(csr_cols_data, + coo_cols_data, + sizeof(int64_t) * non_zero_num, + cudaMemcpyDeviceToDevice, + dev_ctx.stream())); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(csr_values_data, + coo_values_data, + sizeof(T) * non_zero_num, + cudaMemcpyDeviceToDevice, + dev_ctx.stream())); +#endif + out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims); +} + } // namespace sparse } // namespace pten @@ -358,3 +495,29 @@ PT_REGISTER_KERNEL(sparse_csr_to_coo, int16_t, int, int64_t) {} + +PT_REGISTER_KERNEL(sparse_coo_to_csr, + GPU, + ALL_LAYOUT, + pten::sparse::SparseCooToCsrKernel, + float, + double, + pten::dtype::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PT_REGISTER_KERNEL(dense_to_sparse_csr, + GPU, + ALL_LAYOUT, + pten::sparse::DenseToSparseCsrKernel, + float, + double, + pten::dtype::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/pten/kernels/sparse/sparse_utils_kernel.h b/paddle/pten/kernels/sparse/sparse_utils_kernel.h index c353caedf31..b2cb90878aa 100644 --- a/paddle/pten/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/pten/kernels/sparse/sparse_utils_kernel.h @@ -72,5 +72,51 @@ SparseCooTensor SparseCsrToCoo(const Context& dev_ctx, return coo; } +template +void SparseCooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out); + +template +SparseCsrTensor SparseCooToCsr(const Context& dev_ctx, + const SparseCooTensor& x) { + DenseTensor non_zero_crows = pten::Empty(dev_ctx); + DenseTensor non_zero_cols = pten::Empty(dev_ctx); + DenseTensor non_zero_elements = pten::Empty(dev_ctx); + SparseCsrTensor csr( + non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); + SparseCooToCsrKernel(dev_ctx, x, &csr); + return csr; +} + +template +void DenseToSparseCsrKernel(const Context& dev_ctx, + const DenseTensor& x, + SparseCsrTensor* out) { + const auto& x_dims = x.dims(); + bool valid = x_dims.size() == 2 || x_dims.size() == 3; + PADDLE_ENFORCE_EQ(valid, + true, + paddle::platform::errors::InvalidArgument( + "SparseCsrTensor only support 2-D or 3-D Tensor.")); + const int64_t sparse_dim = x_dims.size() == 2 ? 
2 : 3; + DenseTensor indices = pten::Empty(dev_ctx); + DenseTensor values = pten::Empty(dev_ctx); + SparseCooTensor coo(indices, values, x.dims()); + DenseToSparseCooKernel(dev_ctx, x, sparse_dim, &coo); + SparseCooToCsrKernel(dev_ctx, coo, out); +} + +template +SparseCsrTensor DenseToSparseCsr(const Context& dev_ctx, const DenseTensor& x) { + DenseTensor non_zero_crows = pten::Empty(dev_ctx); + DenseTensor non_zero_cols = pten::Empty(dev_ctx); + DenseTensor non_zero_elements = pten::Empty(dev_ctx); + SparseCsrTensor csr( + non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); + DenseToSparseCsrKernel(dev_ctx, x, &csr); + return csr; +} + } // namespace sparse } // namespace pten diff --git a/paddle/pten/tests/api/test_sparse_utils_api.cc b/paddle/pten/tests/api/test_sparse_utils_api.cc index 1ec025faedc..cb9a4191674 100644 --- a/paddle/pten/tests/api/test_sparse_utils_api.cc +++ b/paddle/pten/tests/api/test_sparse_utils_api.cc @@ -8,8 +8,8 @@ You may obtain a copy of the License at Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and limitations under the License. */ #include @@ -103,3 +103,75 @@ TEST(API, to_sparse_coo) { non_zero_data.size() * sizeof(float)); ASSERT_EQ(cmp_elements2, 0); } + +TEST(API, to_sparse_csr) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + auto dense_x = std::make_shared( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + pten::framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + + pten::CPUPlace cpu; + const int64_t sparse_dim = 2; + auto* dense_x_data = dense_x->mutable_data(cpu); + float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}}; + std::vector non_zero_data = {1.0, 2.0, 3.0, 3.2}; + std::vector indices_data = {0, 1, 1, 2, 1, 0, 2, 0}; + std::vector cols_data = {1, 0, 2, 0}; + std::vector crows_data = {0, 1, 3, 4}; + const int64_t non_zero_num = 4; + + std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data); + + pten::CPUContext dev_ctx_cpu; + + // 1. test dense_to_sparse_csr + paddle::experimental::Tensor x(dense_x); + auto out = paddle::experimental::sparse::to_sparse_csr(x, pten::Backend::CPU); + auto csr = std::dynamic_pointer_cast(out.impl()); + auto check = [&](const pten::SparseCsrTensor& csr) { + ASSERT_EQ(csr.non_zero_cols().numel(), non_zero_num); + int cmp_crows = memcmp(csr.non_zero_crows().data(), + crows_data.data(), + crows_data.size() * sizeof(int64_t)); + ASSERT_EQ(cmp_crows, 0); + int cmp_cols = memcmp(csr.non_zero_cols().data(), + cols_data.data(), + cols_data.size() * sizeof(int64_t)); + ASSERT_EQ(cmp_cols, 0); + int cmp_elements = memcmp(csr.non_zero_elements().data(), + non_zero_data.data(), + non_zero_data.size() * sizeof(float)); + ASSERT_EQ(cmp_elements, 0); + }; + check(*csr); + + // 1. 
test sparse_coo_to_csr + auto dense_dims = pten::framework::make_ddim({3, 3}); + pten::DenseTensorMeta indices_meta(pten::DataType::INT64, + {sparse_dim, non_zero_num}, + pten::DataLayout::NCHW); + pten::DenseTensorMeta values_meta( + pten::DataType::FLOAT32, {non_zero_num}, pten::DataLayout::NCHW); + + pten::CPUPlace place; + pten::DenseTensor indices(alloc.get(), indices_meta); + pten::DenseTensor values(alloc.get(), values_meta); + memcpy(indices.mutable_data(place), + indices_data.data(), + indices_data.size() * sizeof(int64_t)); + memcpy(values.mutable_data(place), + non_zero_data.data(), + non_zero_data.size() * sizeof(float)); + auto coo = + std::make_shared(indices, values, dense_dims); + paddle::experimental::Tensor coo_x(coo); + auto out2 = + paddle::experimental::sparse::to_sparse_csr(coo_x, pten::Backend::CPU); + + auto csr2 = std::dynamic_pointer_cast(out.impl()); + check(*csr2); +} diff --git a/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc index 967609e9a8c..c082b08d526 100644 --- a/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc +++ b/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc @@ -43,7 +43,7 @@ inline void CheckResult( #if defined(PADDLE_WITH_CUDA) if (coo.place() == pten::GPUPlace()) { - const auto* dev_ctx_cuda = static_cast(dev_ctx); + const auto* dev_ctx_gpu = static_cast(dev_ctx); DenseTensor indices( alloc.get(), DenseTensorMeta( @@ -53,8 +53,8 @@ inline void CheckResult( DenseTensorMeta(real_elements.dtype(), real_elements.dims(), real_elements.layout())); - pten::Copy(*dev_ctx_cuda, real_indices, true, &indices); - pten::Copy(*dev_ctx_cuda, real_elements, true, &elements); + pten::Copy(*dev_ctx_gpu, real_indices, true, &indices); + pten::Copy(*dev_ctx_gpu, real_elements, true, &elements); int cmp_indices = memcmp(indices.data(), non_zero_indices.data(), @@ -103,9 +103,6 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x, // 2. 
test cuda #if defined(PADDLE_WITH_CUDA) - // paddle::platform::DeviceContextPool& pool = - // paddle::platform::DeviceContextPool::Instance(); - // auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace()); pten::GPUContext dev_ctx_gpu; dev_ctx_gpu.PartialInitWithoutAllocator(); dev_ctx_gpu.SetAllocator( @@ -327,8 +324,6 @@ void TestSparseCsrToCoo(const DDim& dense_dims, const auto cuda_alloc = std::make_shared( paddle::platform::CUDAPlace()); - // auto& pool = paddle::platform::DeviceContextPool::Instance(); - // auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace()); pten::DenseTensor d_crows(cuda_alloc.get(), crows_meta); pten::DenseTensor d_cols(cuda_alloc.get(), cols_meta); pten::DenseTensor d_values(cuda_alloc.get(), values_meta); @@ -382,5 +377,287 @@ TEST(DEV_API, sparse_csr_to_coo_batch_and_fp16) { non_zero_num); } +template +inline void CheckCsrResult( + const DeviceContext* dev_ctx, + const SparseCsrTensor& csr, + const std::vector non_zero_elements, + const std::vector& non_zero_crows, + const std::vector& non_zero_cols, + const int64_t non_zero_num, + const std::shared_ptr& alloc) { + const DenseTensor real_crows = csr.non_zero_crows(); + const DenseTensor real_cols = csr.non_zero_cols(); + const DenseTensor real_elements = csr.non_zero_elements(); + ASSERT_EQ(csr.non_zero_cols().numel(), non_zero_num); + +#if defined(PADDLE_WITH_CUDA) + if (csr.place() == paddle::platform::CUDAPlace()) { + const auto* dev_ctx_gpu = static_cast(dev_ctx); + DenseTensor crows( + alloc.get(), + DenseTensorMeta( + DataType::INT64, real_crows.dims(), real_crows.layout())); + DenseTensor cols( + alloc.get(), + DenseTensorMeta(DataType::INT64, real_cols.dims(), real_cols.layout())); + + DenseTensor elements(alloc.get(), + DenseTensorMeta(real_elements.dtype(), + real_elements.dims(), + real_elements.layout())); + pten::Copy(*dev_ctx_gpu, real_crows, true, &crows); + pten::Copy(*dev_ctx_gpu, real_cols, true, &cols); + pten::Copy(*dev_ctx_gpu, real_elements, true, &elements); + + int cmp_crows = memcmp(crows.data(), + non_zero_crows.data(), + non_zero_crows.size() * sizeof(IndicesT)); + ASSERT_EQ(cmp_crows, 0); + int cmp_cols = memcmp(cols.data(), + non_zero_cols.data(), + non_zero_cols.size() * sizeof(IndicesT)); + ASSERT_EQ(cmp_cols, 0); + int cmp_elements = memcmp(elements.data(), + non_zero_elements.data(), + non_zero_elements.size() * sizeof(ValueT)); + ASSERT_EQ(cmp_elements, 0); + } else { +#endif + int cmp_crows = memcmp(real_crows.data(), + non_zero_crows.data(), + non_zero_crows.size() * sizeof(IndicesT)); + ASSERT_EQ(cmp_crows, 0); + int cmp_cols = memcmp(real_cols.data(), + non_zero_cols.data(), + non_zero_cols.size() * sizeof(IndicesT)); + ASSERT_EQ(cmp_cols, 0); + int cmp_elements = memcmp(real_elements.data(), + non_zero_elements.data(), + non_zero_elements.size() * sizeof(ValueT)); + ASSERT_EQ(cmp_elements, 0); +#if defined(PADDLE_WITH_CUDA) + } +#endif +} + +template +void TestCooToCsr(const DDim& dense_dims, + const int64_t& non_zero_num, + const std::vector& non_zero_data, + const std::vector& non_zero_indices, + const std::vector& cols_data, + const std::vector& crows_data) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + pten::CPUPlace cpu; + DenseTensorMeta indices_meta( + DataType::INT64, + {static_cast(dense_dims.size()), non_zero_num}, + DataLayout::NCHW); + DenseTensor indices(alloc.get(), indices_meta); + DenseTensorMeta values_meta( + paddle::experimental::CppTypeToDataType::Type(), + {non_zero_num}, + DataLayout::NCHW); + 
DenseTensor values(alloc.get(), values_meta); + + memcpy(indices.mutable_data(cpu), + non_zero_indices.data(), + non_zero_indices.size() * sizeof(int64_t)); + memcpy(values.mutable_data(cpu), + non_zero_data.data(), + non_zero_data.size() * sizeof(T)); + pten::SparseCooTensor coo(indices, values, dense_dims); + + // 1. test cpu + pten::CPUContext dev_ctx_cpu; + auto cpu_sparse_out = sparse::SparseCooToCsr(dev_ctx_cpu, coo); + CheckCsrResult(&dev_ctx_cpu, + cpu_sparse_out, + non_zero_data, + crows_data, + cols_data, + non_zero_num, + alloc); + +// 2. test cuda +#if defined(PADDLE_WITH_CUDA) + const auto cuda_alloc = + std::make_shared( + paddle::platform::CUDAPlace()); + pten::GPUContext dev_ctx_gpu; + dev_ctx_gpu.PartialInitWithoutAllocator(); + dev_ctx_gpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream()) + .get()); + dev_ctx_gpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(pten::CPUPlace()) + .get()); + dev_ctx_gpu.PartialInitWithAllocator(); + pten::DenseTensor d_indices(cuda_alloc.get(), indices_meta); + pten::DenseTensor d_values(cuda_alloc.get(), values_meta); + pten::Copy(dev_ctx_gpu, indices, true, &d_indices); + pten::Copy(dev_ctx_gpu, values, true, &d_values); + pten::SparseCooTensor d_coo(d_indices, d_values, dense_dims); + auto cuda_sparse_out = sparse::SparseCooToCsr(dev_ctx_gpu, d_coo); + CheckCsrResult(&dev_ctx_gpu, + cuda_sparse_out, + non_zero_data, + crows_data, + cols_data, + non_zero_num, + alloc); +#endif +} + +TEST(DEV_API, coo_to_csr) { + // float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, + // 0.0}}; + std::vector non_zero_data = {1.0, 2.0, 3.0, 3.2}; + std::vector non_zero_indices = {0, 1, 1, 2, 1, 0, 2, 0}; + std::vector cols_data = {1, 0, 2, 0}; + std::vector crows_data = {0, 1, 3, 4}; + const int64_t non_zero_num = 4; + auto dense_dims = pten::framework::make_ddim({3, 3}); + TestCooToCsr(dense_dims, + non_zero_num, + non_zero_data, + non_zero_indices, + cols_data, + crows_data); +} + +TEST(DEV_API, batch_coo_to_csr) { + // float dense_data[2][3][3] = + // {{{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}}, + // {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {0.0, 0.0, 0.0}}}; + const int64_t non_zero_num = 7; + std::vector data = {1.0, 2.0, 3.0, 3.2, 1.0, 2.0, 3.0}; + std::vector non_zero_data(non_zero_num); + for (int64_t i = 0; i < non_zero_num; i++) { + non_zero_data[i] = static_cast(data[i]); + } + std::vector non_zero_indices = {0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 2, + 0, 1, 1, 1, 0, 2, 0, 1, 0, 2}; + std::vector cols_data = {1, 0, 2, 0, 1, 0, 2}; + std::vector crows_data = {0, 1, 3, 4, 0, 1, 3, 3}; + auto dense_dims = pten::framework::make_ddim({2, 3, 3}); + TestCooToCsr(dense_dims, + non_zero_num, + non_zero_data, + non_zero_indices, + cols_data, + crows_data); +} + +template +void TestDenseToSparseCsr(const DenseTensor& dense_x, + const int64_t non_zero_num, + const std::vector& non_zero_data, + const std::vector& crows_data, + const std::vector& cols_data) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::CPUContext dev_ctx_cpu; + + // 1. test cpu + auto cpu_sparse_out = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); + CheckCsrResult(&dev_ctx_cpu, + cpu_sparse_out, + non_zero_data, + crows_data, + cols_data, + non_zero_num, + alloc); +// 2. 
test cuda +#if defined(PADDLE_WITH_CUDA) + const auto cuda_alloc = + std::make_shared( + paddle::platform::CUDAPlace()); + DenseTensor d_dense_x( + cuda_alloc.get(), + DenseTensorMeta(dense_x.dtype(), dense_x.dims(), dense_x.layout())); + + pten::GPUContext dev_ctx_gpu; + dev_ctx_gpu.PartialInitWithoutAllocator(); + dev_ctx_gpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream()) + .get()); + dev_ctx_gpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(pten::CPUPlace()) + .get()); + dev_ctx_gpu.PartialInitWithAllocator(); + pten::Copy(dev_ctx_gpu, dense_x, true, &d_dense_x); + auto sparse_out = sparse::DenseToSparseCsr(dev_ctx_gpu, d_dense_x); + + CheckCsrResult(&dev_ctx_gpu, + sparse_out, + non_zero_data, + crows_data, + cols_data, + non_zero_num, + alloc); +#endif +} + +TEST(DEV_API, dense_to_sparse_csr) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + DenseTensor dense_x( + alloc.get(), + DenseTensorMeta( + DataType::FLOAT32, framework::make_ddim({3, 3}), DataLayout::NCHW)); + + pten::CPUPlace cpu; + auto* dense_x_data = dense_x.mutable_data(cpu); + float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}}; + std::vector non_zero_data = {1.0, 2.0, 3.0, 3.2}; + std::vector cols_data = {1, 0, 2, 0}; + std::vector crows_data = {0, 1, 3, 4}; + const int64_t non_zero_num = 4; + + std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data); + TestDenseToSparseCsr( + dense_x, non_zero_num, non_zero_data, crows_data, cols_data); +} + +TEST(DEV_API, dense_to_sparse_csr_batch) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + DenseTensor dense_x(alloc.get(), + DenseTensorMeta(DataType::FLOAT16, + framework::make_ddim({2, 3, 3}), + DataLayout::NCHW)); + + pten::CPUPlace cpu; + auto* dense_x_data = dense_x.mutable_data(cpu); + const int64_t non_zero_num = 7; + float dense_data[2][3][3] = { + {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}}, + {{0.0, 1.0, 0.0}, {2.0, 0.0, 0.0}, {3.2, 0.0, 0.0}}}; + std::vector data = {1.0, 2.0, 3.0, 3.2, 1.0, 2.0, 3.2}; + std::vector non_zero_data(non_zero_num); + for (int64_t i = 0; i < non_zero_num; i++) { + non_zero_data[i] = static_cast(data[i]); + } + std::vector cols_data = {1, 0, 2, 0, 1, 0, 0}; + std::vector crows_data = {0, 1, 3, 4, 0, 1, 2, 3}; + + float* dense_ptr = &dense_data[0][0][0]; + for (int i = 0; i < 18; i++) { + dense_x_data[i] = static_cast(dense_ptr[i]); + } + TestDenseToSparseCsr( + dense_x, non_zero_num, non_zero_data, crows_data, cols_data); +} + } // namespace tests } // namespace pten -- GitLab
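Appendix (not part of the patch): a minimal usage sketch of the C++ API declared above in paddle/pten/api/include/sparse_api.h. The tensor setup mirrors paddle/pten/tests/api/test_sparse_utils_api.cc from this patch; the allocator class and include paths are assumptions based on that test, not verified against this exact revision.

#include <algorithm>
#include <memory>

#include "paddle/pten/api/include/sparse_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"

// Build a small 3x3 dense tensor on CPU (same data as the unit test above).
paddle::experimental::Tensor MakeDenseExample() {
  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
      paddle::platform::CPUPlace());
  auto dense_x = std::make_shared<pten::DenseTensor>(
      alloc.get(),
      pten::DenseTensorMeta(pten::DataType::FLOAT32,
                            pten::framework::make_ddim({3, 3}),
                            pten::DataLayout::NCHW));
  float* data = dense_x->mutable_data<float>(pten::CPUPlace());
  const float values[9] = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0};
  std::copy(values, values + 9, data);
  return paddle::experimental::Tensor(dense_x);
}

void ToSparseCsrDemo() {
  paddle::experimental::Tensor x = MakeDenseExample();

  // Dense -> CSR: dispatches to the dense_to_sparse_csr kernel.
  auto csr = paddle::experimental::sparse::to_sparse_csr(x, pten::Backend::CPU);

  // Dense -> COO (sparse_dim = 2) -> CSR: dispatches to sparse_coo_to_csr.
  auto coo =
      paddle::experimental::sparse::to_sparse_coo(x, pten::Backend::CPU, 2);
  auto csr_from_coo =
      paddle::experimental::sparse::to_sparse_csr(coo, pten::Backend::CPU);

  // A tensor whose layout is already SPARSE_CSR is returned unchanged.
  auto same =
      paddle::experimental::sparse::to_sparse_csr(csr, pten::Backend::CPU);
}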