From a9cc54820b9bae481ddc365b6fa22b933f739b83 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 12 Oct 2022 11:00:29 +0800 Subject: [PATCH] [Sparse] Rename and fix doc (#46853) --- paddle/phi/api/yaml/sparse_backward.yaml | 4 ++-- paddle/phi/api/yaml/sparse_ops.yaml | 12 +++++----- paddle/phi/kernels/sparse/coalesce_kernel.h | 10 ++++---- .../phi/kernels/sparse/cpu/coalesce_kernel.cc | 23 ++++++++++--------- paddle/phi/kernels/sparse/cpu/full_kernel.cc | 12 +++++----- paddle/phi/kernels/sparse/cpu/unary_kernel.cc | 12 +++++----- paddle/phi/kernels/sparse/full_kernel.h | 4 ++-- .../phi/kernels/sparse/gpu/coalesce_kernel.cu | 23 ++++++++++--------- paddle/phi/kernels/sparse/gpu/full_kernel.cu | 12 +++++----- paddle/phi/kernels/sparse/gpu/unary_kernel.cu | 12 +++++----- paddle/phi/kernels/sparse/unary_kernel.h | 4 ++-- .../kernels/test_sparse_conv3d_dev_api.cc | 2 +- .../tests/kernels/test_sparse_pool_dev_api.cc | 2 +- .../paddle/incubate/sparse/nn/layer/conv.py | 4 ++-- 14 files changed, 69 insertions(+), 67 deletions(-) diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 5850acb3c37..5cca834a6dc 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -124,8 +124,8 @@ cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr} data_type : out_grad -- backward_op : conv3d_coo_grad - forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter) +- backward_op : conv3d_grad + forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter) args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) output : Tensor(x_grad), Tensor(kernel_grad) infer_meta : diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 43f8688fb81..e1083ae3a65 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -119,7 +119,7 @@ func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense} layout : x intermediate: rulebook, counter - backward : conv3d_coo_grad + backward : conv3d_grad - op : divide args : (Tensor x, Tensor y) @@ -139,8 +139,8 @@ func : UnchangedInferMeta param : [x] kernel : - func : divide_coo_scalar{sparse_coo -> sparse_coo}, - divide_csr_scalar{sparse_csr -> sparse_csr} + func : divide_scalar_coo{sparse_coo -> sparse_coo}, + divide_scalar_csr{sparse_csr -> sparse_csr} backward : divide_scalar_grad - op : expm1 @@ -393,7 +393,7 @@ infer_meta : func : UnchangedInferMeta kernel : - func: coalesce{sparse_coo -> sparse_coo} + func: coalesce_coo{sparse_coo -> sparse_coo} layout : x - op: full_like @@ -403,8 +403,8 @@ func : CreateLikeInferMeta param : [x, dtype] kernel : - func : coo_full_like{sparse_coo -> sparse_coo}, - csr_full_like{sparse_csr -> sparse_csr} + func : full_like_coo{sparse_coo -> sparse_coo}, + full_like_csr{sparse_csr -> sparse_csr} layout : x data_type : dtype diff --git a/paddle/phi/kernels/sparse/coalesce_kernel.h b/paddle/phi/kernels/sparse/coalesce_kernel.h index cb8b98fd874..9f3c7faba16 100644 --- a/paddle/phi/kernels/sparse/coalesce_kernel.h +++ b/paddle/phi/kernels/sparse/coalesce_kernel.h @@ -22,14 +22,14 @@ namespace phi { namespace sparse { template -void CoalesceKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out); +void CoalesceCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out); template -SparseCooTensor Coalesce(const Context& dev_ctx, const SparseCooTensor& x) { +SparseCooTensor CoalesceCoo(const Context& dev_ctx, const SparseCooTensor& x) { SparseCooTensor coo; - CoalesceKernel(dev_ctx, x, &coo); + CoalesceCooKernel(dev_ctx, x, &coo); return coo; } diff --git a/paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc b/paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc index a2d622daaa1..b8d25741e42 100644 --- a/paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc @@ -22,9 +22,9 @@ namespace phi { namespace sparse { template -void CoalesceCPUKernel(const CPUContext& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { +void CoalesceCooCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { const DenseTensor& x_indices = x.indices(); const DenseTensor& x_values = x.values(); DenseTensor out_indices = phi::EmptyLike(dev_ctx, x_indices); @@ -95,21 +95,22 @@ void CoalesceCPUKernel(const CPUContext& dev_ctx, } template -void CoalesceKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { - PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CoalesceCPUKernel", ([&] { - CoalesceCPUKernel(dev_ctx, x, out); - })); +void CoalesceCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.indices().dtype(), "CoalesceCooCPUKernel", ([&] { + CoalesceCooCPUKernel(dev_ctx, x, out); + })); } } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(coalesce, +PD_REGISTER_KERNEL(coalesce_coo, CPU, ALL_LAYOUT, - phi::sparse::CoalesceKernel, + phi::sparse::CoalesceCooKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/cpu/full_kernel.cc b/paddle/phi/kernels/sparse/cpu/full_kernel.cc index b848751deb9..ac13327caee 100644 --- a/paddle/phi/kernels/sparse/cpu/full_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/full_kernel.cc @@ -31,7 +31,7 @@ void FullValue(const Context& dev_ctx, DenseTensor* tensor, T val) { } template -void CooFullLikeKernel(const Context& dev_ctx, +void FullLikeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, const Scalar& val, DataType dtype, @@ -51,7 +51,7 @@ void CooFullLikeKernel(const Context& dev_ctx, } template -void CsrFullLikeKernel(const Context& dev_ctx, +void FullLikeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, const Scalar& val, DataType dtype, @@ -78,10 +78,10 @@ void CsrFullLikeKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(coo_full_like, +PD_REGISTER_KERNEL(full_like_coo, CPU, ALL_LAYOUT, - phi::CooFullLikeKernel, + phi::FullLikeCooKernel, float, double, uint8_t, @@ -96,10 +96,10 @@ PD_REGISTER_KERNEL(coo_full_like, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(csr_full_like, +PD_REGISTER_KERNEL(full_like_csr, CPU, ALL_LAYOUT, - phi::CsrFullLikeKernel, + phi::FullLikeCsrKernel, float, double, uint8_t, diff --git a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc index d0df0095947..a8fc928108c 100644 --- a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc @@ -25,7 +25,7 @@ namespace phi { namespace sparse { template -void DivCooScalarKernel(const Context& dev_ctx, +void DivScalarCooKernel(const Context& dev_ctx, const SparseCooTensor& x, float scalar, SparseCooTensor* out) { @@ -41,7 +41,7 @@ void DivCooScalarKernel(const Context& dev_ctx, } template -void DivCsrScalarKernel(const Context& dev_ctx, +void DivScalarCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, float scalar, SparseCsrTensor* out) { @@ -97,19 +97,19 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu) -PD_REGISTER_KERNEL(divide_coo_scalar, +PD_REGISTER_KERNEL(divide_scalar_coo, CPU, ALL_LAYOUT, - phi::sparse::DivCooScalarKernel, + phi::sparse::DivScalarCooKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(divide_csr_scalar, +PD_REGISTER_KERNEL(divide_scalar_csr, CPU, ALL_LAYOUT, - phi::sparse::DivCsrScalarKernel, + phi::sparse::DivScalarCsrKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); diff --git a/paddle/phi/kernels/sparse/full_kernel.h b/paddle/phi/kernels/sparse/full_kernel.h index 8c84d43ff02..ea3461615b7 100644 --- a/paddle/phi/kernels/sparse/full_kernel.h +++ b/paddle/phi/kernels/sparse/full_kernel.h @@ -22,14 +22,14 @@ namespace phi { template -void CooFullLikeKernel(const Context& dev_ctx, +void FullLikeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, const Scalar& val, DataType dtype, SparseCooTensor* out); template -void CsrFullLikeKernel(const Context& dev_ctx, +void FullLikeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, const Scalar& val, DataType dtype, diff --git a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu index d369c0ecd99..e10d762c886 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu @@ -27,9 +27,9 @@ namespace phi { namespace sparse { template -void CoalesceGPUKernel(const GPUContext& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { +void CoalesceCooGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { const DenseTensor& x_indices = x.indices(); const DenseTensor& x_values = x.values(); DenseTensor out_indices = phi::EmptyLike(dev_ctx, x_indices); @@ -172,20 +172,21 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx, } template -void CoalesceKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { - PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CoalesceGPUKernel", ([&] { - CoalesceGPUKernel(dev_ctx, x, out); - })); +void CoalesceCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.indices().dtype(), "CoalesceCooGPUKernel", ([&] { + CoalesceCooGPUKernel(dev_ctx, x, out); + })); } } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(coalesce, +PD_REGISTER_KERNEL(coalesce_coo, GPU, ALL_LAYOUT, - phi::sparse::CoalesceKernel, + phi::sparse::CoalesceCooKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/gpu/full_kernel.cu b/paddle/phi/kernels/sparse/gpu/full_kernel.cu index 223561bc179..f0b3537528f 100644 --- a/paddle/phi/kernels/sparse/gpu/full_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/full_kernel.cu @@ -37,7 +37,7 @@ struct FullFunctor { }; template -void CooFullLikeKernel(const Context& dev_ctx, +void FullLikeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, const Scalar& val, DataType dtype, @@ -60,7 +60,7 @@ void CooFullLikeKernel(const Context& dev_ctx, } template -void CsrFullLikeKernel(const Context& dev_ctx, +void FullLikeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, const Scalar& val, DataType dtype, @@ -87,10 +87,10 @@ void CsrFullLikeKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(coo_full_like, +PD_REGISTER_KERNEL(full_like_coo, GPU, ALL_LAYOUT, - phi::CooFullLikeKernel, + phi::FullLikeCooKernel, float, double, uint8_t, @@ -105,10 +105,10 @@ PD_REGISTER_KERNEL(coo_full_like, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(csr_full_like, +PD_REGISTER_KERNEL(full_like_csr, GPU, ALL_LAYOUT, - phi::CsrFullLikeKernel, + phi::FullLikeCsrKernel, float, double, uint8_t, diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu index c2d3dec047a..5ff222720ac 100644 --- a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu @@ -34,7 +34,7 @@ struct DivScalarFunctor { }; template -void DivCooScalarKernel(const Context& dev_ctx, +void DivScalarCooKernel(const Context& dev_ctx, const SparseCooTensor& x, float scalar, SparseCooTensor* out) { @@ -47,7 +47,7 @@ void DivCooScalarKernel(const Context& dev_ctx, } template -void DivCsrScalarKernel(const Context& dev_ctx, +void DivScalarCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, float scalar, SparseCsrTensor* out) { @@ -102,19 +102,19 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu) -PD_REGISTER_KERNEL(divide_coo_scalar, +PD_REGISTER_KERNEL(divide_scalar_coo, GPU, ALL_LAYOUT, - phi::sparse::DivCooScalarKernel, + phi::sparse::DivScalarCooKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(divide_csr_scalar, +PD_REGISTER_KERNEL(divide_scalar_csr, GPU, ALL_LAYOUT, - phi::sparse::DivCsrScalarKernel, + phi::sparse::DivScalarCsrKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h index fb5cd21ed39..d7f3e5c9dd7 100644 --- a/paddle/phi/kernels/sparse/unary_kernel.h +++ b/paddle/phi/kernels/sparse/unary_kernel.h @@ -74,13 +74,13 @@ void ScaleCsrKernel(const Context& dev_ctx, SparseCsrTensor* out); template -void DivCooScalarKernel(const Context& dev_ctx, +void DivScalarCooKernel(const Context& dev_ctx, const SparseCooTensor& x, float scalar, SparseCooTensor* out); template -void DivCsrScalarKernel(const Context& dev_ctx, +void DivScalarCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, float scalar, SparseCsrTensor* out); diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index f9c41d1826a..4e3b00e28cc 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -212,7 +212,7 @@ void TestConv3dBase(const std::vector& indices, "Conv3d", &d_rulebook, &d_counter); - SparseCooTensor tmp_d_out = sparse::Coalesce(dev_ctx_gpu, d_out); + SparseCooTensor tmp_d_out = sparse::CoalesceCoo(dev_ctx_gpu, d_out); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz()); diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc index b4e1def372f..b5d9665d244 100644 --- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc @@ -161,7 +161,7 @@ void TestMaxPoolBase(const std::vector& indices, &d_rulebook, &d_counter); - SparseCooTensor tmp_d_out = sparse::Coalesce(dev_ctx_gpu, d_out); + SparseCooTensor tmp_d_out = sparse::CoalesceCoo(dev_ctx_gpu, d_out); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz()); diff --git a/python/paddle/incubate/sparse/nn/layer/conv.py b/python/paddle/incubate/sparse/nn/layer/conv.py index 6684a7561f4..708359be160 100644 --- a/python/paddle/incubate/sparse/nn/layer/conv.py +++ b/python/paddle/incubate/sparse/nn/layer/conv.py @@ -144,7 +144,7 @@ class Conv3D(_Conv3D): Parameters: in_channels(int): The number of input channels in the input image. out_channels(int): The number of output channels produced by the convolution. - kernel_size(int|list|tuple, optional): The size of the convolving kernel. + kernel_size(int|list|tuple): The size of the convolving kernel. stride(int|list|tuple, optional): The stride size. If stride is a list/tuple, it must contain three integers, (stride_D, stride_H, stride_W). Otherwise, the stride_D = stride_H = stride_W = stride. The default value is 1. @@ -277,7 +277,7 @@ class SubmConv3D(_Conv3D): Parameters: in_channels(int): The number of input channels in the input image. out_channels(int): The number of output channels produced by the convolution. - kernel_size(int|list|tuple, optional): The size of the convolving kernel. + kernel_size(int|list|tuple): The size of the convolving kernel. stride(int|list|tuple, optional): The stride size. If stride is a list/tuple, it must contain three integers, (stride_D, stride_H, stride_W). Otherwise, the stride_D = stride_H = stride_W = stride. The default value is 1. -- GitLab