From 3673974833c32ae7f5b857a96665bf71f2512bb2 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 7 Sep 2022 09:42:18 +0800 Subject: [PATCH] [Sparse]Rename sparse kernel (#45730) --- paddle/phi/api/lib/sparse_api_custom_impl.cc | 12 +- paddle/phi/api/yaml/sparse_api.yaml | 50 ++++---- paddle/phi/api/yaml/sparse_bw_api.yaml | 45 +++---- .../kernels/sparse/cpu/elementwise_kernel.cc | 6 +- .../kernels/sparse/cpu/sparse_utils_kernel.cc | 94 ++++++++------- .../kernels/sparse/gpu/matmul_grad_kernel.cu | 4 +- .../kernels/sparse/gpu/sparse_utils_kernel.cu | 110 +++++++++--------- .../sparse/sparse_utils_grad_kernel.cc | 26 ++--- .../kernels/sparse/sparse_utils_grad_kernel.h | 10 +- .../phi/kernels/sparse/sparse_utils_kernel.h | 80 +++++++------ paddle/phi/tests/api/test_sparse_utils_api.cc | 2 +- .../kernels/test_sparse_activation_dev_api.cc | 2 +- .../test_sparse_elementwise_dev_api.cc | 87 +++++++------- .../kernels/test_sparse_utils_dev_api.cc | 26 ++--- .../fluid/dygraph/varbase_patch_methods.py | 25 +--- python/paddle/incubate/sparse/creation.py | 2 +- 16 files changed, 271 insertions(+), 310 deletions(-) diff --git a/paddle/phi/api/lib/sparse_api_custom_impl.cc b/paddle/phi/api/lib/sparse_api_custom_impl.cc index 73f5b28f459..6aaf21a5e7f 100644 --- a/paddle/phi/api/lib/sparse_api_custom_impl.cc +++ b/paddle/phi/api/lib/sparse_api_custom_impl.cc @@ -30,9 +30,9 @@ Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) { } // 1. Get kernel signature and kernel - std::string kernel_name = "dense_to_sparse_coo"; + std::string kernel_name = "dense_to_coo"; if (x.layout() == phi::DataLayout::SPARSE_CSR) { - kernel_name = "sparse_csr_to_coo"; + kernel_name = "csr_to_coo"; } auto kernel_key_set = ParseKernelKeyByInputArgs(x); @@ -88,9 +88,9 @@ Tensor to_sparse_csr_impl(const Tensor& x) { return x; } // 1. Get kernel signature and kernel - std::string kernel_name = "dense_to_sparse_csr"; + std::string kernel_name = "dense_to_csr"; if (x.layout() == phi::DataLayout::SPARSE_COO) { - kernel_name = "sparse_coo_to_csr"; + kernel_name = "coo_to_csr"; } auto kernel_key_set = ParseKernelKeyByInputArgs(x); @@ -151,9 +151,9 @@ Tensor to_dense_impl(const Tensor& x) { } // 1. Get kernel signature and kernel - std::string kernel_name = "sparse_coo_to_dense"; + std::string kernel_name = "coo_to_dense"; if (x.layout() == phi::DataLayout::SPARSE_CSR) { - kernel_name = "sparse_csr_to_dense"; + kernel_name = "csr_to_dense"; } auto kernel_key_set = ParseKernelKeyByInputArgs(x); diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml index e11306f21f2..ca40d10b496 100644 --- a/paddle/phi/api/yaml/sparse_api.yaml +++ b/paddle/phi/api/yaml/sparse_api.yaml @@ -82,34 +82,13 @@ - api : conv3d args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) - output : Tensor(out), Tensor(rulebook), Tensor(counter) + output : Tensor(out), Tensor(rulebook), Tensor(counter) kernel : func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense} layout : x intermediate: rulebook, counter backward : conv3d_coo_grad -- api : coo_to_dense - args : (Tensor x) - output : Tensor(out) - invoke : to_dense_impl(x) - backward : coo_to_dense_grad - -- api : create_sparse_coo_tensor - args : (Tensor values, Tensor indices, IntArray dense_shape) - output : Tensor(out) - kernel : - func : sparse_coo_tensor{dense, dense -> sparse_coo} - layout : values - data_type : values - backward : create_sparse_coo_tensor_grad - -- api : dense_to_coo - args : (Tensor x, int64_t sparse_dim) - output : Tensor(out) - invoke : to_sparse_coo_impl(x, sparse_dim) - backward : dense_to_coo_grad - - api : divide args : (Tensor x, Tensor y) output : Tensor(out) @@ -224,6 +203,15 @@ layout : x backward : softmax_grad +- api : sparse_coo_tensor + args : (Tensor values, Tensor indices, IntArray dense_shape) + output : Tensor(out) + kernel : + func : sparse_coo_tensor{dense, dense -> sparse_coo} + layout : values + data_type : values + backward : sparse_coo_tensor_grad + - api : sqrt args : (Tensor x) output : Tensor(out) @@ -272,24 +260,32 @@ - api : to_dense args : (Tensor x) output : Tensor(out) - invoke : to_dense_impl(x) + kernel : + func : coo_to_dense {sparse_coo -> dense}, + csr_to_dense {sparse_csr -> dense} + backward : to_dense_grad - api : to_sparse_coo args : (Tensor x, int64_t sparse_dim) output : Tensor(out) - invoke : to_sparse_coo_impl(x, sparse_dim) + kernel : + func : dense_to_coo { dense -> sparse_coo }, + csr_to_coo { sparse_csr -> sparse_coo} + backward : to_sparse_coo_grad - api : to_sparse_csr args : (Tensor x) output : Tensor(out) - invoke : to_sparse_csr_impl(x) + kernel : + func : dense_to_csr {dense -> sparse_csr}, + coo_to_csr {sparse_coo -> sparse_csr} - api : values args : (Tensor x) output : Tensor(out) kernel : - func : coo_values{sparse_coo -> dense}, - csr_values{sparse_csr -> dense} + func : values_coo{sparse_coo -> dense}, + values_csr{sparse_csr -> dense} layout : x backward : values_grad diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml index b30687f3af2..e6242f178e5 100644 --- a/paddle/phi/api/yaml/sparse_bw_api.yaml +++ b/paddle/phi/api/yaml/sparse_bw_api.yaml @@ -88,26 +88,6 @@ kernel : func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense} -- backward_api : coo_to_dense_grad - forward : coo_to_dense(Tensor x) -> Tensor(out) - args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad) - kernel : - func : sparse_coo_to_dense_grad{sparse_coo, dense-> sparse_coo} - -- backward_api : create_sparse_coo_tensor_grad - forward : create_sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out) - args : (Tensor indices, Tensor out_grad) - output : Tensor(values_grad) - kernel : - func : sparse_coo_tensor_grad{dense, sparse_coo -> dense} - -- backward_api : dense_to_coo_grad - forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out) - args : (Tensor out_grad) - output : Tensor(x_grad) - invoke : coo_to_dense(out_grad) - - backward_api : divide_grad forward : divide(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out, Tensor out_grad) @@ -239,6 +219,13 @@ kernel : func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr} +- backward_api : sparse_coo_tensor_grad + forward : sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out) + args : (Tensor indices, Tensor out_grad) + output : Tensor(values_grad) + kernel : + func : sparse_coo_tensor_grad{dense, sparse_coo -> dense} + - backward_api : sqrt_grad forward : sqrt(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) @@ -279,12 +266,26 @@ func : tanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, tanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} +- backward_api : to_dense_grad + forward : to_dense(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + kernel : + func : coo_to_dense_grad{sparse_coo, dense -> sparse_coo} + +- backward_api : to_sparse_coo_grad + forward : to_sparse_coo(Tensor x, int64_t sparse_dim) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + kernel : + func : coo_to_dense { sparse_coo -> dense } + - backward_api : values_grad - forward : coo_values(Tensor x) -> Tensor(out) + forward : values_coo(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) kernel : - func : coo_values_grad{sparse_coo, dense-> sparse_coo} + func : values_coo_grad{sparse_coo, dense-> sparse_coo} - backward_api: fused_attention_grad forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax) diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc index d41a67656d0..4156e46dc81 100644 --- a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc @@ -270,15 +270,15 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, const SparseCsrTensor& y, \ SparseCsrTensor* out) { \ funcs::name##Functor functor; \ - auto coo_x = SparseCsrToCoo(dev_ctx, x); \ - auto coo_y = SparseCsrToCoo(dev_ctx, y); \ + auto coo_x = CsrToCoo(dev_ctx, x); \ + auto coo_y = CsrToCoo(dev_ctx, y); \ DenseTensor indeces; \ DenseTensor values; \ SparseCooTensor coo_out; \ coo_out.SetMember(indeces, values, x.dims()); \ ElementWiseCooKernelImpl>( \ dev_ctx, coo_x, coo_y, &coo_out, functor); \ - *out = SparseCooToCsr(dev_ctx, coo_out); \ + *out = CooToCsr(dev_ctx, coo_out); \ } #define DEFINE_CSR_ELEMENTWISE_KERNEL(name) \ diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index bf35eaef25a..5199f42ed99 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -63,10 +63,10 @@ inline int64_t GetNonZeroNum(const DenseTensor& dense, } template -void DenseToSparseCooKernel(const Context& dev_ctx, - const DenseTensor& x, - const int64_t sparse_dim, - SparseCooTensor* out) { +void DenseToCooKernel(const Context& dev_ctx, + const DenseTensor& x, + const int64_t sparse_dim, + SparseCooTensor* out) { const T* x_data = x.data(); const auto& x_dims = x.dims(); PADDLE_ENFORCE_LE(sparse_dim, @@ -107,9 +107,9 @@ void DenseToSparseCooKernel(const Context& dev_ctx, } template -void SparseCsrToCooCPUKernel(const CPUContext& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out) { +void CsrToCooCPUKernel(const CPUContext& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { const DDim& x_dims = x.dims(); const int64_t non_zero_num = x.cols().numel(); const auto& csr_crows = x.crows(); @@ -157,19 +157,18 @@ void SparseCsrToCooCPUKernel(const CPUContext& dev_ctx, } template -void SparseCsrToCooKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out) { - PD_VISIT_BASE_INTEGRAL_TYPES( - x.crows().dtype(), "SparseCsrToCooCPUKernel", ([&] { - SparseCsrToCooCPUKernel(dev_ctx, x, out); - })); +void CsrToCooKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.crows().dtype(), "CsrToCooCPUKernel", ([&] { + CsrToCooCPUKernel(dev_ctx, x, out); + })); } template -void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out) { +void CooToCsrCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { const auto& x_dims = x.dims(); bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, @@ -247,19 +246,18 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, } template -void SparseCooToCsrKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out) { - PD_VISIT_BASE_INTEGRAL_TYPES( - x.indices().dtype(), "SparseCooToCsrCPUKernel", ([&] { - SparseCooToCsrCPUKernel(dev_ctx, x, out); - })); +void CooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CooToCsrCPUKernel", ([&] { + CooToCsrCPUKernel(dev_ctx, x, out); + })); } template -void SparseCooToDenseCPUKernel(const CPUContext& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out) { +void CooToDenseCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { const auto non_zero_num = x.nnz(); const auto dense_dims = x.dims(); const auto indices = x.indices(); @@ -300,22 +298,22 @@ void SparseCooToDenseCPUKernel(const CPUContext& dev_ctx, } template -void SparseCooToDenseKernel(const Context& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out) { +void CooToDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { PD_VISIT_BASE_INTEGRAL_TYPES( - x.indices().dtype(), "SparseCooToDenseCPUKernel", ([&] { - SparseCooToDenseCPUKernel(dev_ctx, x, out); + x.indices().dtype(), "CooToDenseCPUKernel", ([&] { + CooToDenseCPUKernel(dev_ctx, x, out); })); } } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(dense_to_sparse_coo, +PD_REGISTER_KERNEL(dense_to_coo, CPU, ALL_LAYOUT, - phi::sparse::DenseToSparseCooKernel, + phi::sparse::DenseToCooKernel, float, double, paddle::float16, @@ -325,10 +323,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_coo, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_csr_to_coo, +PD_REGISTER_KERNEL(csr_to_coo, CPU, ALL_LAYOUT, - phi::sparse::SparseCsrToCooKernel, + phi::sparse::CsrToCooKernel, float, double, paddle::float16, @@ -338,10 +336,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_coo, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_coo_to_csr, +PD_REGISTER_KERNEL(coo_to_csr, CPU, ALL_LAYOUT, - phi::sparse::SparseCooToCsrKernel, + phi::sparse::CooToCsrKernel, float, double, phi::dtype::float16, @@ -351,10 +349,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_csr, int, int64_t) {} -PD_REGISTER_KERNEL(dense_to_sparse_csr, +PD_REGISTER_KERNEL(dense_to_csr, CPU, ALL_LAYOUT, - phi::sparse::DenseToSparseCsrKernel, + phi::sparse::DenseToCsrKernel, float, double, phi::dtype::float16, @@ -364,10 +362,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_csr, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_coo_to_dense, +PD_REGISTER_KERNEL(coo_to_dense, CPU, ALL_LAYOUT, - phi::sparse::SparseCooToDenseKernel, + phi::sparse::CooToDenseKernel, float, double, phi::dtype::float16, @@ -377,10 +375,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_csr_to_dense, +PD_REGISTER_KERNEL(csr_to_dense, CPU, ALL_LAYOUT, - phi::sparse::SparseCsrToDenseKernel, + phi::sparse::CsrToDenseKernel, float, double, phi::dtype::float16, @@ -390,10 +388,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_dense, int, int64_t) {} -PD_REGISTER_KERNEL(coo_values, +PD_REGISTER_KERNEL(values_coo, CPU, ALL_LAYOUT, - phi::sparse::CooValuesKernel, + phi::sparse::ValuesCooKernel, float, double, phi::dtype::float16, @@ -405,10 +403,10 @@ PD_REGISTER_KERNEL(coo_values, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(csr_values, +PD_REGISTER_KERNEL(values_csr, CPU, ALL_LAYOUT, - phi::sparse::CsrValuesKernel, + phi::sparse::ValuesCsrKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu index c4bb66827e3..05eb6a90cb4 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu @@ -43,10 +43,10 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx, // 'cusparseSDDMM' only support CSR now, so use COO->CSR->COO, // which will increase some expenses. EmptyLikeCooKernel(dev_ctx, x, dx); - SparseCsrTensor dx_csr = SparseCooToCsr(dev_ctx, *dx); + SparseCsrTensor dx_csr = CooToCsr(dev_ctx, *dx); sparse_blas.SDDMM( false, true, static_cast(1), dout, y, static_cast(0), &dx_csr); - SparseCsrToCooKernel(dev_ctx, dx_csr, dx); + CsrToCooKernel(dev_ctx, dx_csr, dx); } // dy{Dense} = x'{SparseCoo} * dout{Dense} diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index dbd2f305936..2ceda7da750 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -93,10 +93,10 @@ __global__ void GetNonZeroElementsAndIndices(const T* dense_data, } template -void DenseToSparseCooKernel(const Context& dev_ctx, - const DenseTensor& x, - const int64_t sparse_dim, - SparseCooTensor* out) { +void DenseToCooKernel(const Context& dev_ctx, + const DenseTensor& x, + const int64_t sparse_dim, + SparseCooTensor* out) { const T* x_data = x.data(); const auto& x_dims = x.dims(); PADDLE_ENFORCE_LE(sparse_dim, @@ -208,9 +208,9 @@ __global__ void ConvertCsrCrowsToCooRows(const IntT* crows_ptr, } template -void SparseCsrToCooGPUKernel(const GPUContext& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out) { +void CsrToCooGPUKernel(const GPUContext& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { const DDim& x_dims = x.dims(); const int64_t non_zero_num = x.cols().numel(); const auto& csr_crows = x.crows(); @@ -274,13 +274,12 @@ void SparseCsrToCooGPUKernel(const GPUContext& dev_ctx, } template -void SparseCsrToCooKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out) { - PD_VISIT_BASE_INTEGRAL_TYPES( - x.crows().dtype(), "SparseCsrToCooGPUKernel", ([&] { - SparseCsrToCooGPUKernel(dev_ctx, x, out); - })); +void CsrToCooKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.crows().dtype(), "CsrToCooGPUKernel", ([&] { + CsrToCooGPUKernel(dev_ctx, x, out); + })); } template @@ -343,9 +342,9 @@ __global__ void ConvertCooRowsToCsrCrows( } template -void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out) { +void CooToCsrGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { const auto& x_dims = x.dims(); bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, @@ -416,23 +415,22 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx, } template -void SparseCooToCsrKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out) { - PD_VISIT_BASE_INTEGRAL_TYPES( - x.indices().dtype(), "SparseCooToCsrGPUKernel", ([&] { - SparseCooToCsrGPUKernel(dev_ctx, x, out); - })); +void CooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CooToCsrGPUKernel", ([&] { + CooToCsrGPUKernel(dev_ctx, x, out); + })); } template -__global__ void KernelSparseCooToDense(const IndicesT* indices, - const int64_t* sparse_offsets, - const ValueT* data, - ValueT* dense_data, - const IndicesT non_zero_num, - const int64_t base_offset, - const int64_t sparse_dim) { +__global__ void KernelCooToDense(const IndicesT* indices, + const int64_t* sparse_offsets, + const ValueT* data, + ValueT* dense_data, + const IndicesT non_zero_num, + const int64_t base_offset, + const int64_t sparse_dim) { int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { int64_t index = 0; @@ -447,9 +445,9 @@ __global__ void KernelSparseCooToDense(const IndicesT* indices, } template -void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out) { +void CooToDenseGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { const auto non_zero_num = x.nnz(); const auto dense_dims = x.dims(); const auto indices = x.indices(); @@ -490,7 +488,7 @@ void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - KernelSparseCooToDense + KernelCooToDense << -void SparseCooToDenseKernel(const Context& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out) { +void CooToDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { PD_VISIT_BASE_INTEGRAL_TYPES( - x.indices().dtype(), "SparseCooToDenseGPUKernel", ([&] { - SparseCooToDenseGPUKernel(dev_ctx, x, out); + x.indices().dtype(), "CooToDenseGPUKernel", ([&] { + CooToDenseGPUKernel(dev_ctx, x, out); })); } } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(dense_to_sparse_coo, +PD_REGISTER_KERNEL(dense_to_coo, GPU, ALL_LAYOUT, - phi::sparse::DenseToSparseCooKernel, + phi::sparse::DenseToCooKernel, float, double, phi::dtype::float16, @@ -529,10 +527,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_coo, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_csr_to_coo, +PD_REGISTER_KERNEL(csr_to_coo, GPU, ALL_LAYOUT, - phi::sparse::SparseCsrToCooKernel, + phi::sparse::CsrToCooKernel, float, double, phi::dtype::float16, @@ -542,10 +540,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_coo, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_coo_to_csr, +PD_REGISTER_KERNEL(coo_to_csr, GPU, ALL_LAYOUT, - phi::sparse::SparseCooToCsrKernel, + phi::sparse::CooToCsrKernel, float, double, phi::dtype::float16, @@ -555,10 +553,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_csr, int, int64_t) {} -PD_REGISTER_KERNEL(dense_to_sparse_csr, +PD_REGISTER_KERNEL(dense_to_csr, GPU, ALL_LAYOUT, - phi::sparse::DenseToSparseCsrKernel, + phi::sparse::DenseToCsrKernel, float, double, phi::dtype::float16, @@ -568,10 +566,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_csr, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_coo_to_dense, +PD_REGISTER_KERNEL(coo_to_dense, GPU, ALL_LAYOUT, - phi::sparse::SparseCooToDenseKernel, + phi::sparse::CooToDenseKernel, float, double, phi::dtype::float16, @@ -581,10 +579,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense, int, int64_t) {} -PD_REGISTER_KERNEL(sparse_csr_to_dense, +PD_REGISTER_KERNEL(csr_to_dense, GPU, ALL_LAYOUT, - phi::sparse::SparseCsrToDenseKernel, + phi::sparse::CsrToDenseKernel, float, double, phi::dtype::float16, @@ -594,10 +592,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_dense, int, int64_t) {} -PD_REGISTER_KERNEL(coo_values, +PD_REGISTER_KERNEL(values_coo, GPU, ALL_LAYOUT, - phi::sparse::CooValuesKernel, + phi::sparse::ValuesCooKernel, float, double, phi::dtype::float16, @@ -609,10 +607,10 @@ PD_REGISTER_KERNEL(coo_values, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(csr_values, +PD_REGISTER_KERNEL(values_csr, GPU, ALL_LAYOUT, - phi::sparse::CsrValuesKernel, + phi::sparse::ValuesCsrKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc index b41497c22c3..4c1c1f85cce 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc @@ -20,7 +20,7 @@ namespace phi { namespace sparse { template -void CooValuesGradKernel(const Context& dev_ctx, +void ValuesCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& out_grad, SparseCooTensor* x_grad) { @@ -28,20 +28,20 @@ void CooValuesGradKernel(const Context& dev_ctx, } template -void SparseCooToDenseGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& out_grad, - SparseCooTensor* x_grad) { +void CooToDenseGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& out_grad, + SparseCooTensor* x_grad) { SparseMaskKernel(dev_ctx, out_grad, x, x_grad); } } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(coo_values_grad, +PD_REGISTER_KERNEL(values_coo_grad, CPU, ALL_LAYOUT, - phi::sparse::CooValuesGradKernel, + phi::sparse::ValuesCooGradKernel, float, double, uint8_t, @@ -52,10 +52,10 @@ PD_REGISTER_KERNEL(coo_values_grad, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(sparse_coo_to_dense_grad, +PD_REGISTER_KERNEL(coo_to_dense_grad, CPU, ALL_LAYOUT, - phi::sparse::SparseCooToDenseGradKernel, + phi::sparse::CooToDenseGradKernel, float, double, uint8_t, @@ -80,10 +80,10 @@ PD_REGISTER_KERNEL(sparse_coo_tensor_grad, } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(coo_values_grad, +PD_REGISTER_KERNEL(values_coo_grad, GPU, ALL_LAYOUT, - phi::sparse::CooValuesGradKernel, + phi::sparse::ValuesCooGradKernel, float, double, phi::dtype::float16, @@ -94,10 +94,10 @@ PD_REGISTER_KERNEL(coo_values_grad, int64_t) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(sparse_coo_to_dense_grad, +PD_REGISTER_KERNEL(coo_to_dense_grad, GPU, ALL_LAYOUT, - phi::sparse::SparseCooToDenseGradKernel, + phi::sparse::CooToDenseGradKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h index 7cf97c3f48e..08e68658d84 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h @@ -22,16 +22,16 @@ namespace phi { namespace sparse { template -void CooValuesGradKernel(const Context& dev_ctx, +void ValuesCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& out_grad, SparseCooTensor* x_grad); template -void SparseCooToDenseGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& out_grad, - SparseCooTensor* x_grad); +void CooToDenseGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& out_grad, + SparseCooTensor* x_grad); template void SparseCooTensorGradKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index 70f719de04a..932427d42cd 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -24,57 +24,55 @@ namespace phi { namespace sparse { template -void DenseToSparseCooKernel(const Context& dev_ctx, - const DenseTensor& x, - const int64_t sparse_dim, - SparseCooTensor* out); +void DenseToCooKernel(const Context& dev_ctx, + const DenseTensor& x, + const int64_t sparse_dim, + SparseCooTensor* out); template -SparseCooTensor DenseToSparseCoo(const Context& dev_ctx, - const DenseTensor& x, - const int64_t sparse_dim) { +SparseCooTensor DenseToCoo(const Context& dev_ctx, + const DenseTensor& x, + const int64_t sparse_dim) { DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); - DenseToSparseCooKernel(dev_ctx, x, sparse_dim, &coo); + DenseToCooKernel(dev_ctx, x, sparse_dim, &coo); return coo; } template -void SparseCsrToCooKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out); +void CsrToCooKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out); template -SparseCooTensor SparseCsrToCoo(const Context& dev_ctx, - const SparseCsrTensor& x) { +SparseCooTensor CsrToCoo(const Context& dev_ctx, const SparseCsrTensor& x) { DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); - SparseCsrToCooKernel(dev_ctx, x, &coo); + CsrToCooKernel(dev_ctx, x, &coo); return coo; } template -void SparseCooToCsrKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out); +void CooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out); template -SparseCsrTensor SparseCooToCsr(const Context& dev_ctx, - const SparseCooTensor& x) { +SparseCsrTensor CooToCsr(const Context& dev_ctx, const SparseCooTensor& x) { DenseTensor crows; DenseTensor cols; DenseTensor non_zero_elements; SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims()); - SparseCooToCsrKernel(dev_ctx, x, &csr); + CooToCsrKernel(dev_ctx, x, &csr); return csr; } template -void DenseToSparseCsrKernel(const Context& dev_ctx, - const DenseTensor& x, - SparseCsrTensor* out) { +void DenseToCsrKernel(const Context& dev_ctx, + const DenseTensor& x, + SparseCsrTensor* out) { const auto& x_dims = x.dims(); bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, @@ -85,61 +83,61 @@ void DenseToSparseCsrKernel(const Context& dev_ctx, DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); - DenseToSparseCooKernel(dev_ctx, x, sparse_dim, &coo); - SparseCooToCsrKernel(dev_ctx, coo, out); + DenseToCooKernel(dev_ctx, x, sparse_dim, &coo); + CooToCsrKernel(dev_ctx, coo, out); } template -SparseCsrTensor DenseToSparseCsr(const Context& dev_ctx, const DenseTensor& x) { +SparseCsrTensor DenseToCsr(const Context& dev_ctx, const DenseTensor& x) { DenseTensor crows; DenseTensor cols; DenseTensor non_zero_elements; SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims()); - DenseToSparseCsrKernel(dev_ctx, x, &csr); + DenseToCsrKernel(dev_ctx, x, &csr); return csr; } template -void SparseCooToDenseKernel(const Context& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out); +void CooToDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out); template -DenseTensor SparseCooToDense(const Context& dev_ctx, const SparseCooTensor& x) { +DenseTensor CooToDense(const Context& dev_ctx, const SparseCooTensor& x) { DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout()); DenseTensor dense = phi::Empty(dev_ctx, std::move(meta)); - SparseCooToDenseKernel(dev_ctx, x, &dense); + CooToDenseKernel(dev_ctx, x, &dense); return dense; } template -void SparseCsrToDenseKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - DenseTensor* out) { +void CsrToDenseKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + DenseTensor* out) { DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); - SparseCsrToCooKernel(dev_ctx, x, &coo); - SparseCooToDenseKernel(dev_ctx, coo, out); + CsrToCooKernel(dev_ctx, x, &coo); + CooToDenseKernel(dev_ctx, coo, out); } template -DenseTensor SparseCsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) { +DenseTensor CsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) { DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout()); DenseTensor dense = phi::Empty(dev_ctx, std::move(meta)); - SparseCsrToDenseKernel(dev_ctx, x, &dense); + CsrToDenseKernel(dev_ctx, x, &dense); return dense; } template -void CooValuesKernel(const Context& dev_ctx, +void ValuesCooKernel(const Context& dev_ctx, const SparseCooTensor& x, DenseTensor* out) { *out = x.non_zero_elements(); } template -void CsrValuesKernel(const Context& dev_ctx, +void ValuesCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, DenseTensor* out) { *out = x.non_zero_elements(); diff --git a/paddle/phi/tests/api/test_sparse_utils_api.cc b/paddle/phi/tests/api/test_sparse_utils_api.cc index d5891baaf10..bf55e9256cc 100644 --- a/paddle/phi/tests/api/test_sparse_utils_api.cc +++ b/paddle/phi/tests/api/test_sparse_utils_api.cc @@ -23,7 +23,7 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/sparse_coo_tensor.h" -PD_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(dense_to_coo, CPU, ALL_LAYOUT); TEST(API, to_sparse_coo) { const auto alloc = std::make_shared( diff --git a/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc index 9c6776fb2ac..b58133f935d 100644 --- a/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc @@ -47,7 +47,7 @@ TEST(DEV_API, sparse_relu) { phi::Empty(dev_ctx_cpu, DenseTensorMeta(DataType::FLOAT32, {3, 4}, DataLayout::NCHW)); memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); - auto sparse_coo = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, 2); + auto sparse_coo = sparse::DenseToCoo(dev_ctx_cpu, dense_x, 2); auto sparse_out = sparse::ReluCoo(dev_ctx_cpu, sparse_coo); DenseTensor dense_out = diff --git a/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc index cbac854d48e..f4add7faecb 100644 --- a/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc @@ -49,12 +49,9 @@ namespace tests { const Sparse##type##Tensor& y, \ const DDim& dense_dims) { \ auto out = sparse::ElementWise##name##type(dev_ctx_cpu, x, y); \ - const DenseTensor denseX = \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, x); \ - const DenseTensor denseY = \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, y); \ - const DenseTensor denseOut = \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, out); \ + const DenseTensor denseX = sparse::type##ToDense(dev_ctx_cpu, x); \ + const DenseTensor denseY = sparse::type##ToDense(dev_ctx_cpu, y); \ + const DenseTensor denseOut = sparse::type##ToDense(dev_ctx_cpu, out); \ auto expectResult = name(dev_ctx_cpu, denseX, denseY); \ for (int j = 0; j < denseOut.numel(); ++j) { \ auto actualResultRow = denseOut.template data()[j]; \ @@ -114,8 +111,8 @@ TEST(DEV_API, sparse_elementwise_coo_kernel_double) { .GetAllocator(paddle::platform::CPUPlace()) .get()); - auto coo_x = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); - auto coo_y = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_y, sparse_dim); + auto coo_x = sparse::DenseToCoo(dev_ctx_cpu, dense_x, sparse_dim); + auto coo_y = sparse::DenseToCoo(dev_ctx_cpu, dense_y, sparse_dim); TestElementWiseAddCoo(dev_ctx_cpu, coo_x, coo_y, dense_dims); TestElementWiseSubtractCoo(dev_ctx_cpu, coo_x, coo_y, dense_dims); @@ -159,8 +156,8 @@ TEST(DEV_API, sparse_elementwise_csr_kernel_float) { .GetAllocator(paddle::platform::CPUPlace()) .get()); - auto csr_x = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); - auto csr_y = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_y); + auto csr_x = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + auto csr_y = sparse::DenseToCsr(dev_ctx_cpu, dense_y); TestElementWiseAddCsr(dev_ctx_cpu, csr_x, csr_y, dense_dims); TestElementWiseSubtractCsr(dev_ctx_cpu, csr_x, csr_y, dense_dims); @@ -190,20 +187,18 @@ TEST(DEV_API, sparse_elementwise_csr_kernel_float) { dev_ctx_cpu, \ DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); \ \ - phi::name##GradKernel( \ - dev_ctx_cpu, \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, x), \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, y), \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, out), \ - -1, \ - &expectdx, \ - &expectdy); \ + phi::name##GradKernel(dev_ctx_cpu, \ + sparse::type##ToDense(dev_ctx_cpu, x), \ + sparse::type##ToDense(dev_ctx_cpu, y), \ + sparse::type##ToDense(dev_ctx_cpu, out), \ + -1, \ + &expectdx, \ + &expectdy); \ const DenseTensor densedX = \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, dresult[0]); \ + sparse::type##ToDense(dev_ctx_cpu, dresult[0]); \ const DenseTensor densedY = \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, dresult[1]); \ - const DenseTensor denseOut = \ - sparse::Sparse##type##ToDense(dev_ctx_cpu, out); \ + sparse::type##ToDense(dev_ctx_cpu, dresult[1]); \ + const DenseTensor denseOut = sparse::type##ToDense(dev_ctx_cpu, out); \ \ for (int j = 0; j < densedX.numel(); ++j) { \ auto actualResultRow = densedX.template data()[j]; \ @@ -248,18 +243,16 @@ void TestElementWiseDivideCsrGrad(const Context& dev_ctx_cpu, dev_ctx_cpu, DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); phi::DivideGradKernel(dev_ctx_cpu, - sparse::SparseCsrToDense(dev_ctx_cpu, x), - sparse::SparseCsrToDense(dev_ctx_cpu, y), - sparse::SparseCsrToDense(dev_ctx_cpu, out), - sparse::SparseCsrToDense(dev_ctx_cpu, out), + sparse::CsrToDense(dev_ctx_cpu, x), + sparse::CsrToDense(dev_ctx_cpu, y), + sparse::CsrToDense(dev_ctx_cpu, out), + sparse::CsrToDense(dev_ctx_cpu, out), -1, &expectdx, &expectdy); - const DenseTensor densedX = - sparse::SparseCsrToDense(dev_ctx_cpu, dresult[0]); - const DenseTensor densedY = - sparse::SparseCsrToDense(dev_ctx_cpu, dresult[1]); - const DenseTensor denseOut = sparse::SparseCsrToDense(dev_ctx_cpu, out); + const DenseTensor densedX = sparse::CsrToDense(dev_ctx_cpu, dresult[0]); + const DenseTensor densedY = sparse::CsrToDense(dev_ctx_cpu, dresult[1]); + const DenseTensor denseOut = sparse::CsrToDense(dev_ctx_cpu, out); for (int j = 0; j < densedX.numel(); ++j) { auto actualResultRow = densedX.template data()[j]; auto expectResultRow = expectdx.template data()[j]; @@ -291,18 +284,16 @@ void TestElementWiseDivideCooGrad(const Context& dev_ctx_cpu, dev_ctx_cpu, DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); phi::DivideGradKernel(dev_ctx_cpu, - sparse::SparseCooToDense(dev_ctx_cpu, x), - sparse::SparseCooToDense(dev_ctx_cpu, y), - sparse::SparseCooToDense(dev_ctx_cpu, out), - sparse::SparseCooToDense(dev_ctx_cpu, out), + sparse::CooToDense(dev_ctx_cpu, x), + sparse::CooToDense(dev_ctx_cpu, y), + sparse::CooToDense(dev_ctx_cpu, out), + sparse::CooToDense(dev_ctx_cpu, out), -1, &expectdx, &expectdy); - const DenseTensor densedX = - sparse::SparseCooToDense(dev_ctx_cpu, dresult[0]); - const DenseTensor densedY = - sparse::SparseCooToDense(dev_ctx_cpu, dresult[1]); - const DenseTensor denseOut = sparse::SparseCooToDense(dev_ctx_cpu, out); + const DenseTensor densedX = sparse::CooToDense(dev_ctx_cpu, dresult[0]); + const DenseTensor densedY = sparse::CooToDense(dev_ctx_cpu, dresult[1]); + const DenseTensor denseOut = sparse::CooToDense(dev_ctx_cpu, out); for (int j = 0; j < densedX.numel(); ++j) { auto actualResultRow = densedX.template data()[j]; auto expectResultRow = expectdx.template data()[j]; @@ -356,11 +347,11 @@ TEST(DEV_API, sparse_elementwise_csr_grad_kernel_float) { .GetAllocator(paddle::platform::CPUPlace()) .get()); - auto csr_x = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); - auto csr_y = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_y); + auto csr_x = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + auto csr_y = sparse::DenseToCsr(dev_ctx_cpu, dense_y); - auto dx = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_y); - auto dy = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); + auto dx = sparse::DenseToCsr(dev_ctx_cpu, dense_y); + auto dy = sparse::DenseToCsr(dev_ctx_cpu, dense_x); TestElementWiseAddCsrGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); TestElementWiseSubtractCsrGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); @@ -402,11 +393,11 @@ TEST(DEV_API, sparse_elementwise_coo_grad_kernel_double) { .GetAllocator(paddle::platform::CPUPlace()) .get()); - auto csr_x = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); - auto csr_y = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_y, sparse_dim); + auto csr_x = sparse::DenseToCoo(dev_ctx_cpu, dense_x, sparse_dim); + auto csr_y = sparse::DenseToCoo(dev_ctx_cpu, dense_y, sparse_dim); - auto dx = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_y, sparse_dim); - auto dy = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); + auto dx = sparse::DenseToCoo(dev_ctx_cpu, dense_y, sparse_dim); + auto dy = sparse::DenseToCoo(dev_ctx_cpu, dense_x, sparse_dim); TestElementWiseAddCooGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); TestElementWiseSubtractCooGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); diff --git a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc index 29300d8f58a..73f072a3f80 100644 --- a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc @@ -94,8 +94,7 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x, .get()); // 1. test cpu - auto cpu_sparse_out = - sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); + auto cpu_sparse_out = sparse::DenseToCoo(dev_ctx_cpu, dense_x, sparse_dim); CheckResult(&dev_ctx_cpu, cpu_sparse_out, non_zero_data, @@ -129,8 +128,7 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x, DenseTensorMeta(dense_x.dtype(), dense_x.dims(), dense_x.layout())); phi::Copy(dev_ctx_gpu, dense_x, phi::GPUPlace(), true, &d_dense_x); - auto sparse_out = - sparse::DenseToSparseCoo(dev_ctx_gpu, d_dense_x, sparse_dim); + auto sparse_out = sparse::DenseToCoo(dev_ctx_gpu, d_dense_x, sparse_dim); CheckResult(&dev_ctx_gpu, sparse_out, non_zero_data, @@ -310,7 +308,7 @@ void TestSparseCsrToCoo(const DDim& dense_dims, paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(phi::CPUPlace()) .get()); - auto cpu_sparse_out = sparse::SparseCsrToCoo(dev_ctx_cpu, csr); + auto cpu_sparse_out = sparse::CsrToCoo(dev_ctx_cpu, csr); CheckResult(&dev_ctx_cpu, cpu_sparse_out, non_zero_data, @@ -345,7 +343,7 @@ void TestSparseCsrToCoo(const DDim& dense_dims, phi::Copy(dev_ctx_gpu, cols, d_cols.place(), true, &d_cols); phi::Copy(dev_ctx_gpu, values, d_values.place(), true, &d_values); phi::SparseCsrTensor d_csr(d_crows, d_cols, d_values, dense_dims); - auto cuda_sparse_out = sparse::SparseCsrToCoo(dev_ctx_gpu, d_csr); + auto cuda_sparse_out = sparse::CsrToCoo(dev_ctx_gpu, d_csr); CheckResult(&dev_ctx_gpu, cuda_sparse_out, non_zero_data, @@ -491,7 +489,7 @@ void TestCooToCsr(const DDim& dense_dims, paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(phi::CPUPlace()) .get()); - auto cpu_sparse_out = sparse::SparseCooToCsr(dev_ctx_cpu, coo); + auto cpu_sparse_out = sparse::CooToCsr(dev_ctx_cpu, coo); CheckCsrResult(&dev_ctx_cpu, cpu_sparse_out, non_zero_data, @@ -525,7 +523,7 @@ void TestCooToCsr(const DDim& dense_dims, phi::Copy(dev_ctx_gpu, indices, phi::GPUPlace(), true, &d_indices); phi::Copy(dev_ctx_gpu, values, phi::GPUPlace(), true, &d_values); phi::SparseCooTensor d_coo(d_indices, d_values, dense_dims); - auto cuda_sparse_out = sparse::SparseCooToCsr(dev_ctx_gpu, d_coo); + auto cuda_sparse_out = sparse::CooToCsr(dev_ctx_gpu, d_coo); CheckCsrResult(&dev_ctx_gpu, cuda_sparse_out, non_zero_data, @@ -591,7 +589,7 @@ void TestDenseToSparseCsr(const DenseTensor& dense_x, .get()); // 1. test cpu - auto cpu_sparse_out = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); + auto cpu_sparse_out = sparse::DenseToCsr(dev_ctx_cpu, dense_x); CheckCsrResult(&dev_ctx_cpu, cpu_sparse_out, non_zero_data, @@ -624,7 +622,7 @@ void TestDenseToSparseCsr(const DenseTensor& dense_x, .get()); dev_ctx_gpu.PartialInitWithAllocator(); phi::Copy(dev_ctx_gpu, dense_x, phi::GPUPlace(), true, &d_dense_x); - auto sparse_out = sparse::DenseToSparseCsr(dev_ctx_gpu, d_dense_x); + auto sparse_out = sparse::DenseToCsr(dev_ctx_gpu, d_dense_x); CheckCsrResult(&dev_ctx_gpu, sparse_out, @@ -731,7 +729,7 @@ void TestSparseCooToDense(const DDim& dense_dims, SparseCooTensor coo(dense_indices, dense_elements, dense_dims); - DenseTensor dense_out = sparse::SparseCooToDense(dev_ctx_cpu, coo); + DenseTensor dense_out = sparse::CooToDense(dev_ctx_cpu, coo); int cmp = memcmp( &dense_data[0], dense_out.data(), sizeof(T) * dense_data.size()); @@ -763,7 +761,7 @@ void TestSparseCooToDense(const DDim& dense_dims, phi::Copy( dev_ctx_gpu, dense_elements, phi::GPUPlace(), true, &d_dense_elements); SparseCooTensor coo_cuda(d_dense_indices, d_dense_elements, dense_dims); - auto dense_out_cuda = sparse::SparseCooToDense(dev_ctx_gpu, coo_cuda); + auto dense_out_cuda = sparse::CooToDense(dev_ctx_gpu, coo_cuda); DenseTensor h_dense_out(alloc.get(), DenseTensorMeta(dense_out_cuda.dtype(), @@ -878,7 +876,7 @@ void TestSparseCsrToDense(const DDim& dense_dims, paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(phi::CPUPlace()) .get()); - DenseTensor cpu_sparse_out = sparse::SparseCsrToDense(dev_ctx_cpu, csr); + DenseTensor cpu_sparse_out = sparse::CsrToDense(dev_ctx_cpu, csr); int cmp_cpu = memcmp(cpu_sparse_out.data(), dense_data.data(), sizeof(T) * dense_data.size()); @@ -911,7 +909,7 @@ void TestSparseCsrToDense(const DDim& dense_dims, phi::Copy(dev_ctx_gpu, cols, phi::GPUPlace(), true, &d_cols); phi::Copy(dev_ctx_gpu, values, phi::GPUPlace(), true, &d_values); phi::SparseCsrTensor d_csr(d_crows, d_cols, d_values, dense_dims); - auto cuda_sparse_out = sparse::SparseCsrToDense(dev_ctx_gpu, d_csr); + auto cuda_sparse_out = sparse::CsrToDense(dev_ctx_gpu, d_csr); phi::DenseTensor h_out(alloc.get(), cpu_sparse_out.meta()); phi::Copy(dev_ctx_gpu, cuda_sparse_out, phi::CPUPlace(), true, &h_out); int cmp_cuda = diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 8ad85895258..cb6907d842c 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -923,12 +923,7 @@ def monkey_patch_varbase(): print(sparse_x.values()) #[1, 2, 3, 4, 5] """ - - if self.is_sparse_coo() or self.is_sparse_csr(): - return _C_ops.sparse_values(self) - else: - raise ValueError( - "only SparseCooTensor and SparseCsrTensor have method values") + return _C_ops.sparse_values(self) @framework.dygraph_only def to_dense(self): @@ -956,12 +951,7 @@ def monkey_patch_varbase(): # [4., 5., 0., 0.]] """ - if self.is_sparse_coo(): - return _C_ops.sparse_coo_to_dense(self) - elif self.is_sparse_csr(): - return _C_ops.sparse_to_dense(self) - else: - return self + return _C_ops.sparse_to_dense(self) @framework.dygraph_only def to_sparse_coo(self, sparse_dim): @@ -987,16 +977,7 @@ def monkey_patch_varbase(): #values=[1., 2., 3., 4.] """ - if self.is_sparse_csr(): - return _C_ops.sparse_to_sparse_coo(self, sparse_dim) - elif self.is_sparse_coo(): - return self - elif self.is_selected_rows(): - raise ValueError( - "SelectedRows does not support to_sparse_coo method") - else: - #is dense tensor - return _C_ops.sparse_dense_to_coo(self, sparse_dim) + return _C_ops.sparse_to_sparse_coo(self, sparse_dim) if framework._in_eager_mode_ and not hasattr(core, "eager"): return diff --git a/python/paddle/incubate/sparse/creation.py b/python/paddle/incubate/sparse/creation.py index 143dbd77081..18794788831 100644 --- a/python/paddle/incubate/sparse/creation.py +++ b/python/paddle/incubate/sparse/creation.py @@ -166,7 +166,7 @@ def sparse_coo_tensor(indices, "the number of dimensions(len(shape) must be sparse_dim({}) + dense_dim({}), but get {}" .format(sparse_dim, dense_dim, len(shape))) - return _C_ops.sparse_create_sparse_coo_tensor(values, indices, shape) + return _C_ops.sparse_sparse_coo_tensor(values, indices, shape) #TODO: need to support shape is None -- GitLab