diff --git a/paddle/pten/api/CMakeLists.txt b/paddle/pten/api/CMakeLists.txt
index 89f115524f79ae615d6bb2af65126ae754e899cc..a993cb3ff8041dcaa9734687c0409aaa3e6cebc8 100644
--- a/paddle/pten/api/CMakeLists.txt
+++ b/paddle/pten/api/CMakeLists.txt
@@ -1,3 +1,2 @@
 add_subdirectory(lib)
-
-cc_library(pten_api SRCS all.cc DEPS pten_function_api pten_bw_function_api manual_api)
+cc_library(pten_api SRCS all.cc DEPS pten_function_api pten_bw_function_api manual_api sparse_api)
diff --git a/paddle/pten/api/all.h b/paddle/pten/api/all.h
index bf39ee27295a47df3ea38b5cb77f3fa3cbc32faa..a327bd998cb76e29f49ed108d2a97c5cc6ca9d69 100644
--- a/paddle/pten/api/all.h
+++ b/paddle/pten/api/all.h
@@ -27,6 +27,7 @@ limitations under the License. */
 // new pten apis
 #include "paddle/pten/api/include/api.h"
 #include "paddle/pten/api/include/manual_api.h"
+#include "paddle/pten/api/include/sparse_api.h"
 #include "paddle/pten/api/include/tensor.h"
 
 // pten common headers
diff --git a/paddle/pten/api/include/sparse_api.h b/paddle/pten/api/include/sparse_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..22e511e62ab63bfc3c5a7991f568afcd41a7686b
--- /dev/null
+++ b/paddle/pten/api/include/sparse_api.h
@@ -0,0 +1,30 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/pten/api/include/tensor.h"
+#include "paddle/pten/common/backend.h"
+
+namespace paddle {
+namespace experimental {
+namespace sparse {
+
+PADDLE_API Tensor to_sparse_coo(const Tensor& x,
+                                Backend backend,
+                                const int64_t sparse_dim);
+
+}  // namespace sparse
+}  // namespace experimental
+}  // namespace paddle
diff --git a/paddle/pten/api/lib/CMakeLists.txt b/paddle/pten/api/lib/CMakeLists.txt
index b4a9c65d55f24774280750b54f1d17b1b9643c2d..0b899f1abda9a49ed4a6ee8ec03970cba6777361 100644
--- a/paddle/pten/api/lib/CMakeLists.txt
+++ b/paddle/pten/api/lib/CMakeLists.txt
@@ -63,5 +63,6 @@ add_custom_command(
   VERBATIM)
 
 cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
+cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch)
 cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch)
 cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_function_api)
diff --git a/paddle/pten/api/lib/api_declare.h b/paddle/pten/api/lib/api_declare.h
index 0023170714fa6bfeed4793313833278dc2bbc373..998e01e41eae2cf8f150d9def70e4717951ad4fd 100644
--- a/paddle/pten/api/lib/api_declare.h
+++ b/paddle/pten/api/lib/api_declare.h
@@ -19,3 +19,4 @@ limitations under the License. */
 
 PT_DECLARE_API(Math);
 PT_DECLARE_API(Utils);
+PT_DECLARE_API(SparseApi);
diff --git a/paddle/pten/api/lib/sparse_api.cc b/paddle/pten/api/lib/sparse_api.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d763bb7e8d620d3a980b32454cc4fe0399404405
--- /dev/null
+++ b/paddle/pten/api/lib/sparse_api.cc
@@ -0,0 +1,102 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/api/include/sparse_api.h"
+
+#include <memory>
+#include "glog/logging.h"
+#include "paddle/pten/api/lib/api_registry.h"
+#include "paddle/pten/api/lib/kernel_dispatch.h"
+#include "paddle/pten/api/lib/utils/storage.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/infermeta/unary.h"
+
+PT_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT);
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PT_DECLARE_KERNEL(dense_to_sparse_coo, GPU, ALL_LAYOUT);
+#endif
+
+namespace paddle {
+namespace experimental {
+namespace sparse {
+
+PADDLE_API Tensor to_sparse_coo(const Tensor& x,
+                                Backend backend,
+                                const int64_t sparse_dim) {
+  if (x.layout() == pten::DataLayout::SPARSE_COO) {
+    return x;
+  }
+  // 1. Get kernel signature and kernel
+  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
+  kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
+  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
+  std::string kernel_name = "dense_to_sparse_coo";
+  if (x.layout() == pten::DataLayout::SPARSE_CSR) {
+    kernel_name = "sparse_csr_to_coo";
+  }
+
+  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, kernel_key);
+
+  VLOG(6) << "to API kernel key: " << kernel_key;
+  VLOG(6) << "to API kernel: " << kernel;
+
+  // 2. Get Device Context
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto kernel_context = pten::KernelContext(dev_ctx);
+
+  // 3. Auto data transform
+  if (x.layout() == pten::DataLayout::SPARSE_CSR) {
+    auto input = std::dynamic_pointer_cast<pten::SparseCsrTensor>(x.impl());
+    kernel_context.EmplaceBackInput(input.get());
+  } else {
+    auto input = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
+    kernel_context.EmplaceBackInput(input.get());
+    kernel_context.EmplaceBackAttr(sparse_dim);
+  }
+
+  // 4. InferMeta
+  auto indices_meta = pten::DenseTensorMeta(
+      pten::DataType::INT64, {-1}, pten::DataLayout::NCHW);
+  auto elements_meta = pten::DenseTensorMeta(x.dtype(), {-1}, x.layout());
+
+  // 5. Prepare outputs
+  // create empty SparseCooTensor
+  pten::DenseTensor non_zero_indices(
+      pten::make_intrusive<paddle::experimental::SharedStorage>(
+          pten::TransToFluidPlace(backend)),
+      std::move(indices_meta));
+  pten::DenseTensor non_zero_elements(
+      pten::make_intrusive<paddle::experimental::SharedStorage>(
+          pten::TransToFluidPlace(backend)),
+      std::move(elements_meta));
+  auto coo = std::make_shared<pten::SparseCooTensor>(
+      non_zero_indices, non_zero_elements, x.dims());
+
+  kernel_context.EmplaceBackOutput(coo.get());
+  Tensor out;
+  out.set_impl(coo);
+
+  // 6. Call kernel
+  kernel(&kernel_context);
+
+  return out;
+}
+
+}  // namespace sparse
+}  // namespace experimental
+}  // namespace paddle
+
+PT_REGISTER_API(SparseApi);
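Note (not part of the patch): a minimal sketch of how the new dispatch entry is intended to be called from C++, assuming a CPU build and a DenseTensor-backed paddle::experimental::Tensor constructed the same way as in test_sparse_utils_api.cc further below; the helper name ConvertExample is hypothetical.

    #include "paddle/pten/api/include/sparse_api.h"

    // Hypothetical helper: converts a dense tensor to COO format with the
    // leading two dimensions treated as sparse.
    paddle::experimental::Tensor ConvertExample(
        const paddle::experimental::Tensor& dense) {
      const int64_t sparse_dim = 2;
      // Dispatches to the "dense_to_sparse_coo" kernel registered above.
      return paddle::experimental::sparse::to_sparse_coo(
          dense, pten::Backend::CPU, sparse_dim);
    }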
diff --git a/paddle/pten/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt
index e14c2f6b6c47c74d5ee4f79741a5cbde2b49a7b4..a9b81ad4eb2b3914b005caf348d65bf87e788dca 100644
--- a/paddle/pten/kernels/CMakeLists.txt
+++ b/paddle/pten/kernels/CMakeLists.txt
@@ -9,7 +9,7 @@ add_subdirectory(funcs)
 
 # pten depends all pten kernel targets
 set_property(GLOBAL PROPERTY PTEN_KERNELS "")
-set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
+set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function)
 # remove this dep after removing fluid deps on tensor creation
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
@@ -23,5 +23,6 @@ endif()
 # auto build kernel targets by cmake
 register_kernels(EXCLUDES math_kernel DEPS ${COMMON_KERNEL_DEPS})
 kernel_library(math_kernel DEPS ${MATH_KERNEL_DEPS})
+add_subdirectory(sparse)
 
 copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
diff --git a/paddle/pten/kernels/sparse/CMakeLists.txt b/paddle/pten/kernels/sparse/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3e4a968b7a8a569fc518366c321a57e2738bf12a
--- /dev/null
+++ b/paddle/pten/kernels/sparse/CMakeLists.txt
@@ -0,0 +1,3 @@
+
+set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
+register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel")
diff --git a/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e4cd2d42be382e9ba33e46e3ec87270a92dec97a
--- /dev/null
+++ b/paddle/pten/kernels/sparse/cpu/sparse_utils_kernel.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"
+#include "paddle/pten/api/lib/utils/allocator.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/core/tensor_meta.h"
+
+namespace pten {
+namespace sparse {
+
+template <typename T>
+inline bool IsZero(const T* data, const size_t n) {
+  const T zero = static_cast<T>(0);
+  for (size_t i = 0; i < n; i++) {
+    if (data[i] != zero) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// TODO(zhangkaihuo): implement a kernel to count the number of non-zero
+// elements in tensor
+template <typename T>
+inline int64_t GetNonZeroNum(const DenseTensor& dense,
+                             const int64_t sparse_dim) {
+  const auto& dims = dense.dims();
+  PADDLE_ENFORCE_GE(
+      dims.size(),
+      sparse_dim,
+      paddle::platform::errors::InvalidArgument(
+          "sparse_dim(%d) should be less than or equal to dense.dim(%d)",
+          sparse_dim,
+          dims.size()));
+
+  auto dims_2d = flatten_to_2d(dims, sparse_dim);
+  const int rows = dims_2d[0];
+  const int cols = dims_2d[1];
+
+  const T* data = dense.data<T>();
+  int64_t non_zero_num = 0;
+  for (int64_t i = 0; i < rows; i++) {
+    if (!IsZero(data + i * cols, cols)) {
+      non_zero_num = non_zero_num + 1;
+    }
+  }
+  return non_zero_num;
+}
+
+template <typename T, typename Context>
+void DenseToSparseCooKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const int64_t sparse_dim,
+                            SparseCooTensor* out) {
+  const T* x_data = x.data<T>();
+  const auto& x_dims = x.dims();
+
+  int64_t non_zero_num = GetNonZeroNum<T>(x, sparse_dim);
+
+  const auto place = dev_ctx.GetPlace();
+  const auto values_dims = InferDenseDims(x_dims, sparse_dim, non_zero_num);
+  DenseTensorMeta indices_meta(DataType::INT64,
+                               {sparse_dim, static_cast<int64_t>(non_zero_num)},
+                               DataLayout::NCHW);
+  DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
+  pten::DenseTensor indices =
+      pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
+  pten::DenseTensor values =
+      pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
+  int64_t* indices_data = indices.mutable_data<int64_t>(place);
+  T* values_data = values.mutable_data<T>(place);
+
+  auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
+  const int rows = dims_2d[0];
+  const int cols = dims_2d[1];
+
+  int index = 0;
+  for (int i = 0; i < rows; i++) {
+    if (!IsZero(x_data + i * cols, cols)) {
+      int64_t sparse_index = i;
+      for (int64_t j = sparse_dim - 1; j >= 0; j--) {
+        indices_data[j * non_zero_num + index] = sparse_index % x_dims[j];
+        sparse_index /= x_dims[j];
+      }
+      memcpy(values_data + index * cols, x_data + i * cols, cols * sizeof(T));
+      ++index;
+    }
+  }
+  out->SetMember(indices, values, x_dims, true);
+}
+
+}  // namespace sparse
+}  // namespace pten
+
+PT_REGISTER_KERNEL(dense_to_sparse_coo,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::sparse::DenseToSparseCooKernel,
+                   float,
+                   double,
+                   paddle::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
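Note (not part of the patch): the indices layout the CPU kernel produces is [sparse_dim, nnz], filled by repeatedly taking sparse_index % x_dims[j] from the innermost dimension outward. A self-contained sketch of that decomposition, using the same 3x3 non-zero pattern as the API test further below (illustrative data only):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // 3x3 dense input with non-zeros at flattened row-major positions
      // 1, 3, 5 and 6 (the pattern used in test_sparse_utils_api.cc).
      const int64_t x_dims[2] = {3, 3};
      const int64_t sparse_dim = 2;
      const int64_t flat_pos[4] = {1, 3, 5, 6};
      int64_t indices[2][4];  // shape [sparse_dim, non_zero_num]
      for (int i = 0; i < 4; i++) {
        int64_t sparse_index = flat_pos[i];
        for (int64_t j = sparse_dim - 1; j >= 0; j--) {
          indices[j][i] = sparse_index % x_dims[j];
          sparse_index /= x_dims[j];
        }
      }
      // Prints rows {0, 1, 1, 2} then cols {1, 0, 2, 0}, matching the
      // indices_data the tests expect.
      for (int j = 0; j < 2; j++) {
        for (int i = 0; i < 4; i++) {
          printf("%lld ", static_cast<long long>(indices[j][i]));
        }
        printf("\n");
      }
      return 0;
    }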
diff --git a/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fa37220660ff38909f8b3b7388ab2d62e382ccf5
--- /dev/null
+++ b/paddle/pten/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -0,0 +1,231 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <thrust/execution_policy.h>
+#include <thrust/remove.h>
+
+#include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/core/tensor_meta.h"
+#include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"
+
+namespace pten {
+namespace sparse {
+
+template <typename T>
+inline __device__ bool DevIsZero(const T* data, const int64_t cols) {
+  const T zero = static_cast<T>(0);
+  // TODO(zhangkaihuo): check whether the data is zero in parallel when cols > 1
+  for (int64_t i = 0; i < cols; i++) {
+    if (data[i] != zero) {
+      return false;
+    }
+  }
+  return true;
+}
+
+template <typename T>
+__global__ void GetNonZeroNums(const T* dense_data,
+                               const int rows,
+                               const int cols,
+                               int* non_zero_num,
+                               int* temp_indexs) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  __shared__ int counter;
+  if (threadIdx.x == 0) counter = 0;
+  __syncthreads();
+
+  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
+    int index = -1;
+    // TODO(zhangkaihuo): when cols=1, vectorization can be used
+    if (!DevIsZero(dense_data + i * cols, cols)) {
+      // use reductions?
+      atomicAdd(&counter, 1);
+      index = i;
+    }
+    temp_indexs[i] = index;
+  }
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    atomicAdd(non_zero_num, counter);
+  }
+}
+
+template <typename T>
+__global__ void GetNonZeroElementsAndIndices(const T* dense_data,
+                                             const int64_t sparse_dim,
+                                             const int64_t cols,
+                                             const int64_t* x_dims,
+                                             const int non_zero_num,
+                                             const int* indexs,
+                                             int64_t* indices,
+                                             T* sparse_data) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
+    int64_t sparse_index = indexs[i];
+    int64_t x_index = sparse_index;
+    for (int64_t j = sparse_dim - 1; j >= 0; j--) {
+      indices[j * non_zero_num + i] = sparse_index % x_dims[j];
+      sparse_index /= x_dims[j];
+    }
+
+    for (int j = 0; j < cols; j++) {
+      sparse_data[i * cols + j] = dense_data[x_index * cols + j];
+    }
+  }
+}
+
+template <typename Context>
+void GetGpuLaunchConfig1D(const Context& dev_ctx,
+                          const int64_t n,
+                          int* grid_size,
+                          int* block_size) {
+  const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
+  const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
+  *block_size = (n >= MAX_BLOCK_DIM)
+                    ? MAX_BLOCK_DIM
+                    : (1 << static_cast<int>(std::log2(n)));
+  *grid_size = n / *block_size;
+  *grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
+}
+
+template <typename T, typename Context>
+void DenseToSparseCooKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const int64_t sparse_dim,
+                            SparseCooTensor* out) {
+  const T* x_data = x.data<T>();
+  const auto& x_dims = x.dims();
+  auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
+  const int rows = dims_2d[0];
+  const int cols = dims_2d[1];
+  auto nums_meta =
+      pten::DenseTensorMeta(DataType::INT32, {1}, pten::DataLayout::NCHW);
+  DenseTensor nums = pten::Empty<int, Context>(dev_ctx, std::move(nums_meta));
+  auto x_dims_meta =
+      pten::DenseTensorMeta(DataType::INT64,
+                            {static_cast<int64_t>(x_dims.size())},
+                            pten::DataLayout::NCHW);
+  DenseTensor d_x_dims =
+      pten::Empty<int64_t, Context>(dev_ctx, std::move(x_dims_meta));
+
+  const auto place = dev_ctx.GetPlace();
+
+  // 1. Get the number of non-zero elements and the index of each non-zero row
+  int* nums_ptr = nums.mutable_data<int>(place);
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      hipMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream()));
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      cudaMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream()));
+#endif
+  int grid_size = 1, block_size = 1;
+  GetGpuLaunchConfig1D(dev_ctx, rows, &grid_size, &block_size);
+
+  auto temp_indexs_meta =
+      pten::DenseTensorMeta(DataType::INT32, {rows}, pten::DataLayout::NCHW);
+  DenseTensor temp_indexs =
+      pten::Empty<int, Context>(dev_ctx, std::move(temp_indexs_meta));
+  int* temp_indexs_ptr = temp_indexs.mutable_data<int>(place);
+  GetNonZeroNums<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+      x_data, rows, cols, nums_ptr, temp_indexs_ptr);
+#ifdef PADDLE_WITH_HIP
+  thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
+#else
+  thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
+#endif
+                 temp_indexs_ptr,
+                 temp_indexs_ptr + rows,
+                 -1);
+
+  // 2. copy non_zero_num to host, copy x_dims to device
+  int non_zero_num = 0;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&non_zero_num,
+                                            nums_ptr,
+                                            sizeof(int),
+                                            hipMemcpyDeviceToHost,
+                                            dev_ctx.stream()));
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&non_zero_num,
+                                             nums_ptr,
+                                             sizeof(int),
+                                             cudaMemcpyDeviceToHost,
+                                             dev_ctx.stream()));
+#endif
+
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      hipMemcpyAsync(d_x_dims.mutable_data<int64_t>(place),
+                     x_dims.Get(),
+                     x_dims.size() * sizeof(x_dims[0]),
+                     hipMemcpyHostToDevice,
+                     dev_ctx.stream()));
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      cudaMemcpyAsync(d_x_dims.mutable_data<int64_t>(place),
+                      x_dims.Get(),
+                      x_dims.size() * sizeof(x_dims[0]),
+                      cudaMemcpyHostToDevice,
+                      dev_ctx.stream()));
+#endif
+
+  dev_ctx.Wait();  // wait for the copies to finish
+
+  const auto values_dims = InferDenseDims(x_dims, sparse_dim, non_zero_num);
+  DenseTensorMeta indices_meta(DataType::INT64,
+                               {sparse_dim, static_cast<int64_t>(non_zero_num)},
+                               DataLayout::NCHW);
+  DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
+  pten::DenseTensor indices(
+      pten::make_intrusive<paddle::experimental::SharedStorage>(
+          dev_ctx.GetPlace()),
+      std::move(indices_meta));
+  pten::DenseTensor values(
+      pten::make_intrusive<paddle::experimental::SharedStorage>(
+          dev_ctx.GetPlace()),
+      std::move(values_meta));
+  int64_t* indices_data = indices.mutable_data<int64_t>(place);
+  T* sparse_data = values.mutable_data<T>(place);
+
+  // 3. Compute the indices and gather the values from the compacted indexs
+  GetGpuLaunchConfig1D(dev_ctx, non_zero_num, &grid_size, &block_size);
+  GetNonZeroElementsAndIndices<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+      x_data,
+      sparse_dim,
+      cols,
+      d_x_dims.data<int64_t>(),
+      non_zero_num,
+      temp_indexs_ptr,
+      indices_data,
+      sparse_data);
+  out->SetMember(indices, values, x_dims, true);
+}
+
+}  // namespace sparse
+}  // namespace pten
+
+PT_REGISTER_KERNEL(dense_to_sparse_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::sparse::DenseToSparseCooKernel,
+                   float,
+                   double,
+                   pten::dtype::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
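Note (not part of the patch): the GPU path is a two-pass compaction. The first kernel marks every all-zero row with -1 in temp_indexs and counts the survivors; thrust::remove then squeezes out the -1 sentinels. The compaction semantics are the same as std::remove, so they can be sketched on the host with illustrative data:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      // temp_indexs as written by GetNonZeroNums for a 9-row flattened view:
      // -1 marks an all-zero row, otherwise the row id is kept.
      std::vector<int> temp_indexs = {-1, 1, -1, 3, -1, 5, 6, -1, -1};
      // thrust::remove(policy, first, last, -1) behaves like std::remove:
      // surviving row ids keep their relative order at the front.
      auto new_end = std::remove(temp_indexs.begin(), temp_indexs.end(), -1);
      temp_indexs.erase(new_end, temp_indexs.end());
      assert((temp_indexs == std::vector<int>{1, 3, 5, 6}));
      // These compacted row ids are the `indexs` input consumed by
      // GetNonZeroElementsAndIndices in the second launch.
      return 0;
    }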
diff --git a/paddle/pten/kernels/sparse/sparse_utils_kernel.h b/paddle/pten/kernels/sparse/sparse_utils_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..a705044c5d2111f6144544d403baff30c1f0f3de
--- /dev/null
+++ b/paddle/pten/kernels/sparse/sparse_utils_kernel.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/sparse_coo_tensor.h"
+#include "paddle/pten/core/sparse_csr_tensor.h"
+#include "paddle/pten/kernels/empty_kernel.h"
+
+namespace pten {
+namespace sparse {
+
+inline const DDim InferDenseDims(const DDim& x_dims,
+                                 const int64_t sparse_dim,
+                                 const int64_t non_zero_num) {
+  auto dense_dim = x_dims.size() - sparse_dim;
+  DDim values_dims;
+  if (dense_dim) {
+    std::vector<int64_t> dense_dim_vec(dense_dim + 1);
+    dense_dim_vec[0] = non_zero_num;
+    memcpy(&dense_dim_vec[1],
+           x_dims.Get() + sparse_dim,
+           dense_dim * sizeof(x_dims[0]));
+    values_dims = pten::framework::make_ddim(dense_dim_vec);
+  } else {
+    values_dims = pten::framework::make_ddim({non_zero_num});
+  }
+  return values_dims;
+}
+
+template <typename T, typename Context>
+void DenseToSparseCooKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const int64_t sparse_dim,
+                            SparseCooTensor* out);
+
+template <typename T, typename Context>
+SparseCooTensor DenseToSparseCoo(const Context& dev_ctx,
+                                 const DenseTensor& x,
+                                 const int64_t sparse_dim) {
+  DenseTensor indices = pten::Empty<int64_t, Context>(dev_ctx);
+  DenseTensor values = pten::Empty<T, Context>(dev_ctx);
+  SparseCooTensor coo(indices, values, x.dims());
+  DenseToSparseCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
+  return coo;
+}
+
+}  // namespace sparse
+}  // namespace pten
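Note (not part of the patch): InferDenseDims above returns {nnz} when every dimension is sparse and {nnz, trailing dims...} in the hybrid case. A sketch of the same shape rule, with std::vector standing in for DDim:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Same shape rule as InferDenseDims, using std::vector in place of DDim.
    std::vector<int64_t> InferDenseDimsVec(const std::vector<int64_t>& x_dims,
                                           int64_t sparse_dim,
                                           int64_t non_zero_num) {
      std::vector<int64_t> values_dims = {non_zero_num};
      // Dimensions past sparse_dim stay dense and are carried by the values.
      values_dims.insert(
          values_dims.end(), x_dims.begin() + sparse_dim, x_dims.end());
      return values_dims;
    }

    int main() {
      // Fully sparse 3x3: the values are 4 scalars.
      assert((InferDenseDimsVec({3, 3}, 2, 4) == std::vector<int64_t>{4}));
      // Hybrid case from the to_sparse_coo_hybird test below:
      // each non-zero "element" is a whole length-3 row.
      assert((InferDenseDimsVec({3, 3}, 1, 2) == std::vector<int64_t>{2, 3}));
      return 0;
    }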
diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt
index 1da7fb9613024803dc1b5c6a7c3d3d2f3d59d093..33a1e25f3c534a03b3c1e10dc4cf807012c302cb 100644
--- a/paddle/pten/tests/api/CMakeLists.txt
+++ b/paddle/pten/tests/api/CMakeLists.txt
@@ -22,3 +22,4 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api
 cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
 cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils)
 cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils)
+cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils)
diff --git a/paddle/pten/tests/api/test_sparse_utils_api.cc b/paddle/pten/tests/api/test_sparse_utils_api.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e102bffea8e1c27be071247fe437695aba201c02
--- /dev/null
+++ b/paddle/pten/tests/api/test_sparse_utils_api.cc
@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/pten/api/include/api.h"
+
+#include "paddle/pten/api/include/sparse_api.h"
+
+#include "paddle/pten/api/lib/utils/allocator.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/core/sparse_coo_tensor.h"
+
+TEST(API, to_sparse_coo) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  auto dense_x = std::make_shared<pten::DenseTensor>(
+      alloc.get(),
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            pten::framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+
+  pten::CPUPlace cpu;
+  const int64_t sparse_dim = 2;
+  auto* dense_x_data = dense_x->mutable_data<float>(cpu);
+  float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> indices_data = {0, 1, 1, 2, 1, 0, 2, 0};
+  std::vector<int64_t> cols_data = {1, 0, 2, 0};
+  std::vector<int64_t> crows_data = {0, 1, 3, 4};
+  const int64_t non_zero_num = 4;
+
+  std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
+
+  pten::CPUContext dev_ctx_cpu;
+
+  // 1. test dense_to_sparse_coo
+  paddle::experimental::Tensor x(dense_x);
+  auto out = paddle::experimental::sparse::to_sparse_coo(
+      x, pten::Backend::CPU, sparse_dim);
+  auto coo = std::dynamic_pointer_cast<pten::SparseCooTensor>(out.impl());
+  ASSERT_EQ(coo->nnz(), non_zero_num);
+  int cmp_indices = memcmp(coo->non_zero_indices().data<int64_t>(),
+                           indices_data.data(),
+                           indices_data.size() * sizeof(int64_t));
+  ASSERT_EQ(cmp_indices, 0);
+  int cmp_elements = memcmp(coo->non_zero_elements().data<float>(),
+                            non_zero_data.data(),
+                            non_zero_data.size() * sizeof(float));
+  ASSERT_EQ(cmp_elements, 0);
+}
diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt
index 407e5c097aec44d9f70d0d774b04c49f283bdd0e..e2063241689f929e6d173bcb29dde849ca5a3f48 100644
--- a/paddle/pten/tests/kernels/CMakeLists.txt
+++ b/paddle/pten/tests/kernels/CMakeLists.txt
@@ -11,3 +11,4 @@ cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_uti
 cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils)
 cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils)
 cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils)
+cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils)
diff --git a/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc
new file mode 100644
index 0000000000000000000000000000000000000000..04c288aa06a755d11f7c783a62eb0733c913c0b8
--- /dev/null
+++ b/paddle/pten/tests/kernels/test_sparse_utils_dev_api.cc
@@ -0,0 +1,250 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/pten/kernels/copy_kernel.h"
+#include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"
+
+#include "paddle/pten/api/lib/utils/allocator.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+namespace pten {
+namespace tests {
+
+template <typename ValueT, typename IndicesT>
+inline void CheckResult(
+    const DeviceContext* dev_ctx,
+    const SparseCooTensor& coo,
+    const std::vector<ValueT> non_zero_elements,
+    const std::vector<IndicesT>& non_zero_indices,
+    const int64_t non_zero_num,
+    const std::shared_ptr<paddle::experimental::DefaultAllocator>& alloc) {
+  const DenseTensor real_indices = coo.non_zero_indices();
+  const DenseTensor real_elements = coo.non_zero_elements();
+  ASSERT_EQ(coo.nnz(), non_zero_num);
+
+#if defined(PADDLE_WITH_CUDA)
+  if (coo.place() == paddle::platform::CUDAPlace()) {
+    const auto* dev_ctx_cuda =
+        static_cast<const paddle::platform::CUDADeviceContext*>(dev_ctx);
+    DenseTensor indices(
+        alloc.get(),
+        DenseTensorMeta(
+            DataType::INT64, real_indices.dims(), real_indices.layout()));
+
+    DenseTensor elements(alloc.get(),
+                         DenseTensorMeta(real_elements.dtype(),
+                                         real_elements.dims(),
+                                         real_elements.layout()));
+    pten::Copy(*dev_ctx_cuda, real_indices, true, &indices);
+    pten::Copy(*dev_ctx_cuda, real_elements, true, &elements);
+
+    int cmp_indices = memcmp(indices.data<IndicesT>(),
+                             non_zero_indices.data(),
+                             non_zero_indices.size() * sizeof(IndicesT));
+    ASSERT_EQ(cmp_indices, 0);
+    int cmp_elements = memcmp(elements.data<ValueT>(),
+                              non_zero_elements.data(),
+                              non_zero_elements.size() * sizeof(ValueT));
+    ASSERT_EQ(cmp_elements, 0);
+  } else {
+#endif
+    int cmp_indices = memcmp(real_indices.data<IndicesT>(),
+                             non_zero_indices.data(),
+                             non_zero_indices.size() * sizeof(IndicesT));
+    ASSERT_EQ(cmp_indices, 0);
+    int cmp_elements = memcmp(real_elements.data<ValueT>(),
+                              non_zero_elements.data(),
+                              non_zero_elements.size() * sizeof(ValueT));
+    ASSERT_EQ(cmp_elements, 0);
+#if defined(PADDLE_WITH_CUDA)
+  }
+#endif
+}
+
+template <typename T>
+void TestDenseToSparseCoo(const DenseTensor& dense_x,
+                          const int64_t sparse_dim,
+                          const std::vector<T>& non_zero_data,
+                          const std::vector<int64_t>& indices_data,
+                          const int64_t non_zero_num) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  pten::CPUContext dev_ctx_cpu;
+  // 1. test cpu
+  auto cpu_sparse_out =
+      sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
+  CheckResult<T, int64_t>(&dev_ctx_cpu,
+                          cpu_sparse_out,
+                          non_zero_data,
+                          indices_data,
+                          non_zero_num,
+                          alloc);
+
+// 2. test cuda
+#if defined(PADDLE_WITH_CUDA)
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace());
+  const auto cuda_alloc =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CUDAPlace());
+  DenseTensor d_dense_x(
+      cuda_alloc.get(),
+      DenseTensorMeta(dense_x.dtype(), dense_x.dims(), dense_x.layout()));
+
+  pten::Copy(*dev_ctx_cuda, dense_x, true, &d_dense_x);
+  auto sparse_out =
+      sparse::DenseToSparseCoo<T>(*dev_ctx_cuda, d_dense_x, sparse_dim);
+  CheckResult<T, int64_t>(dev_ctx_cuda,
+                          sparse_out,
+                          non_zero_data,
+                          indices_data,
+                          non_zero_num,
+                          alloc);
+#endif
+}
+
+TEST(DEV_API, to_sparse_coo) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  std::default_random_engine random(time(NULL));
+  std::uniform_real_distribution<float> dis(0.0, 1.0);
+  std::uniform_int_distribution<int> dis_int(4, 64);
+  const int rows = dis_int(random), cols = dis_int(random);
+  DenseTensor dense_x(
+      alloc.get(),
+      DenseTensorMeta(DataType::FLOAT32, {rows, cols}, DataLayout::NCHW));
+
+  pten::CPUPlace cpu;
+  auto* dense_x_data = dense_x.mutable_data<float>(cpu);
+  std::vector<float> dense_data(rows * cols);
+  std::vector<float> non_zero_data;
+  std::vector<int64_t> rows_data, cols_data;
+  const int64_t sparse_dim = 2;
+
+  const float zero_rate = dis(random);
+
+  int64_t non_zero_num = 0;
+  for (int i = 0; i < rows; i++) {
+    for (int j = 0; j < cols; j++) {
+      bool iszero = dis(random) < zero_rate;
+      if (iszero) {
+        dense_data[i * cols + j] = 0.0;
+      } else {
+        float data = dis(random);
+        dense_data[i * cols + j] = data;
+        non_zero_data.push_back(data);
+        rows_data.push_back(i);
+        cols_data.push_back(j);
+        non_zero_num += 1;
+      }
+    }
+  }
+
+  std::copy(
+      dense_data.data(), dense_data.data() + dense_data.size(), dense_x_data);
+
+  std::vector<int64_t> indices_data(non_zero_num * 2);
+  memcpy(&indices_data[0], &rows_data[0], non_zero_num * sizeof(int64_t));
+  memcpy(&indices_data[non_zero_num],
+         &cols_data[0],
+         non_zero_num * sizeof(int64_t));
+
+  TestDenseToSparseCoo(
+      dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
+}
+
+TEST(DEV_API, to_sparse_coo_hybird) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  DenseTensor dense_x(
+      alloc.get(),
+      DenseTensorMeta(DataType::FLOAT32, {3, 3}, DataLayout::NCHW));
+
+  pten::CPUPlace cpu;
+  const int64_t sparse_dim = 1;  // each non-zero element is a vector
+  auto* dense_x_data = dense_x.mutable_data<float>(cpu);
+  float dense_data[3][3] = {{0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}, {3.2, 0.0, 0.0}};
+  std::vector<float> non_zero_data = {
+      /*element0(*/ 0.0, 1.0, 0.0 /*)*/, /*element1(*/ 3.2, 0.0, 0.0 /*)*/};
+  std::vector<int64_t> indices_data = {0, 2};
+  const int64_t non_zero_num = 2;
+
+  std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
+  TestDenseToSparseCoo(
+      dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
+}
+
+TEST(DEV_API, to_sparse_coo_fp16) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  DenseTensor dense_x(
+      alloc.get(),
+      DenseTensorMeta(DataType::FLOAT16, {3, 3}, DataLayout::NCHW));
+
+  pten::CPUPlace cpu;
+  const int64_t sparse_dim = 2;
+  const int64_t non_zero_num = 2;
+  auto* dense_x_data = dense_x.mutable_data<pten::dtype::float16>(cpu);
+  float dense_data[3][3] = {{0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}, {3.2, 0.0, 0.0}};
+  std::vector<float> data = {1.0, 3.2};
+  std::vector<pten::dtype::float16> non_zero_data(non_zero_num);
+  for (int i = 0; i < non_zero_num; i++) {
+    non_zero_data[i] = static_cast<pten::dtype::float16>(data[i]);
+  }
+  std::vector<int64_t> indices_data = {0, 2, 1, 0};
+
+  std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
+  TestDenseToSparseCoo(
+      dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
+}
+
+TEST(DEV_API, to_sparse_coo_batch) {
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  DenseTensor dense_x(
+      alloc.get(),
+      DenseTensorMeta(DataType::FLOAT32, {2, 3, 3}, DataLayout::NCHW));
+
+  pten::CPUPlace cpu;
+  const int64_t sparse_dim = 3;
+  const int64_t non_zero_num = 4;
+  auto* dense_x_data = dense_x.mutable_data<float>(cpu);
+  float dense_data[2][3][3] = {
+      {{0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}, {2.0, 0.0, 0.0}},
+      {{0.0, 0.0, 0.0}, {0.0, 3.0, 0.0}, {4.0, 0.0, 0.0}}};
+  std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 4.0};
+  std::vector<int64_t> indices_data = {0, 0, 1, 1, 0, 2, 1, 2, 1, 0, 1, 0};
+  /*
+     0, 0, 1, 1,
+     0, 2, 1, 2,
+     1, 0, 1, 0
+   */
+
+  std::copy(&dense_data[0][0][0], &dense_data[0][0][0] + 18, dense_x_data);
+  TestDenseToSparseCoo(
+      dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
+}
+
+}  // namespace tests
+}  // namespace pten
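Note (not part of the patch): the kernel-level entry point exercised by these tests can also be called directly, as TestDenseToSparseCoo does. A minimal CPU-only sketch, assuming the CPU kernel target is linked in and the helper name ToCooExample is hypothetical:

    #include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"

    // Sketch: convert a float DenseTensor to COO form, treating every
    // dimension as sparse (indices: [ndim, nnz], values: [nnz]).
    pten::SparseCooTensor ToCooExample(const pten::CPUContext& dev_ctx,
                                       const pten::DenseTensor& x) {
      const int64_t sparse_dim = x.dims().size();
      return pten::sparse::DenseToSparseCoo<float>(dev_ctx, x, sparse_dim);
    }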