From 5161a047a9624e7676eb5fcbfd1c6e2431738113 Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Tue, 28 Jun 2022 14:23:40 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Sparse=E3=80=91add=20SparseTensor=20mv?= =?UTF-8?q?=20kernel(csr*dense=5Fvec->dence=5Fvec,=20coo*dense=5Fvec->dens?= =?UTF-8?q?e=5Fvec)=20(#43668)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Sparse]add SparseTensor mv kernel(csr*dense_vec->dence_vec, coo*dense_vec->dense_vec) * fix CI --- paddle/fluid/platform/dynload/cusparse.h | 7 +- paddle/phi/backends/dynload/cusparse.h | 7 +- paddle/phi/kernels/funcs/sparse/sparse_blas.h | 13 ++ .../funcs/sparse/sparse_blas_impl.cu.h | 155 ++++++++++++++++- .../phi/kernels/sparse/cpu/matmul_kernel.cc | 4 +- .../phi/kernels/sparse/cpu/mv_grad_kernel.cc | 56 ++++++ paddle/phi/kernels/sparse/cpu/mv_kernel.cc | 52 ++++++ paddle/phi/kernels/sparse/empty_kernel.cc | 3 - .../phi/kernels/sparse/gpu/mv_grad_kernel.cu | 163 ++++++++++++++++++ paddle/phi/kernels/sparse/gpu/mv_kernel.cu | 89 ++++++++++ paddle/phi/kernels/sparse/mv_grad_kernel.h | 43 +++++ paddle/phi/kernels/sparse/mv_kernel.h | 39 +++++ .../tests/unittests/test_sparse_mv_op.py | 111 ++++++++++++ python/paddle/incubate/sparse/__init__.py | 2 + python/paddle/incubate/sparse/binary.py | 64 ++++++- python/paddle/utils/code_gen/sparse_api.yaml | 9 + .../paddle/utils/code_gen/sparse_bw_api.yaml | 8 + 17 files changed, 809 insertions(+), 16 deletions(-) create mode 100644 paddle/phi/kernels/sparse/cpu/mv_grad_kernel.cc create mode 100644 paddle/phi/kernels/sparse/cpu/mv_kernel.cc create mode 100644 paddle/phi/kernels/sparse/gpu/mv_grad_kernel.cu create mode 100644 paddle/phi/kernels/sparse/gpu/mv_kernel.cu create mode 100644 paddle/phi/kernels/sparse/mv_grad_kernel.h create mode 100644 paddle/phi/kernels/sparse/mv_kernel.h create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_mv_op.py diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h index 5a67e34fbb..480245fec2 100644 --- a/paddle/fluid/platform/dynload/cusparse.h +++ b/paddle/fluid/platform/dynload/cusparse.h @@ -29,7 +29,6 @@ namespace dynload { extern DynLoad__##__name __name #if defined(PADDLE_WITH_CUDA) -// APIs available after CUDA 11.0 #if CUDA_VERSION >= 11000 #define CUSPARSE_ROUTINE_EACH(__macro) \ __macro(cusparseCreate); \ @@ -43,10 +42,14 @@ namespace dynload { __macro(cusparseCreateCsr); \ __macro(cusparseCreateCoo); \ __macro(cusparseCreateDnMat); \ + __macro(cusparseCreateDnVec); \ __macro(cusparseSpMM_bufferSize); \ __macro(cusparseSpMM); \ __macro(cusparseDestroySpMat); \ - __macro(cusparseDestroyDnMat); + __macro(cusparseDestroyDnMat); \ + __macro(cusparseDestroyDnVec); \ + __macro(cusparseSpMV_bufferSize); \ + __macro(cusparseSpMV); CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) #endif diff --git a/paddle/phi/backends/dynload/cusparse.h b/paddle/phi/backends/dynload/cusparse.h index 8f78ad37af..45a466b380 100644 --- a/paddle/phi/backends/dynload/cusparse.h +++ b/paddle/phi/backends/dynload/cusparse.h @@ -41,7 +41,6 @@ extern void *cusparse_dso_handle; extern DynLoad__##__name __name #if defined(PADDLE_WITH_CUDA) -// APIs available after CUDA 11.0 #if CUDA_VERSION >= 11000 #define CUSPARSE_ROUTINE_EACH(__macro) \ __macro(cusparseCreate); \ @@ -55,10 +54,14 @@ extern void *cusparse_dso_handle; __macro(cusparseCreateCsr); \ __macro(cusparseCreateCoo); \ __macro(cusparseCreateDnMat); \ + __macro(cusparseCreateDnVec); \ __macro(cusparseSpMM_bufferSize); \ __macro(cusparseSpMM); \ __macro(cusparseDestroySpMat); \ - __macro(cusparseDestroyDnMat); + __macro(cusparseDestroyDnMat); \ + __macro(cusparseDestroyDnVec); \ + __macro(cusparseSpMV_bufferSize); \ + __macro(cusparseSpMV); CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) #endif diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas.h b/paddle/phi/kernels/funcs/sparse/sparse_blas.h index c2e270f71a..858620d0d4 100644 --- a/paddle/phi/kernels/funcs/sparse/sparse_blas.h +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas.h @@ -37,6 +37,14 @@ class SparseBlas { T beta, phi::DenseTensor* mat_out) const; + template + void SPMV(bool transa, + T alpha, + const TensorType& mat_a, + const phi::DenseTensor& vec_x, + T beta, + phi::DenseTensor* vec_out) const; + template void SDDMM(bool transa, bool transb, @@ -60,6 +68,11 @@ class SparseBlasT : private SparseBlas { Base()->template SPMM(args...); } + template + void SPMV(ARGS... args) const { + Base()->template SPMV(args...); + } + template void SDDMM(ARGS... args) const { Base()->template SDDMM(args...); diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h index c65d506cce..3d92674c92 100644 --- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h @@ -20,6 +20,7 @@ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/visit_type.h" @@ -47,6 +48,8 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) { } } +/************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/ + template inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x, const phi::GPUContext& dev_ctx, @@ -102,6 +105,60 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x, } } +template +inline void CreateCooDescriptor(const phi::SparseCooTensor& x, + const phi::GPUContext& dev_ctx, + cusparseSpMatDescr_t* descriptor) { + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + PADDLE_ENFORCE_GE( + x_ndims, + 2, + phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be " + "greater than or eaqual to 2.")); + + int64_t M = xdim_vec[x_ndims - 2]; + int64_t N = xdim_vec[x_ndims - 1]; + int batch_size = 1; + for (int i = 0; i < x_ndims - 2; i++) { + batch_size *= xdim_vec[i]; + } + int64_t nnz = x.nnz(); + + const IntT* indices_data = x.non_zero_indices().data(); + const T* values_data = x.non_zero_elements().data(); + auto rows_data = indices_data + (x_ndims - 2) * nnz; + auto cols_data = indices_data + (x_ndims - 1) * nnz; + + int64_t batch_nnz = nnz / batch_size; + cudaDataType_t gpu_type = GetGpuDataType(); + dev_ctx.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseCreateCoo(descriptor, + M, + N, + batch_nnz, + const_cast(rows_data), + const_cast(cols_data), + const_cast(values_data), + CUSPARSE_INDEX_64I, + CUSPARSE_INDEX_BASE_ZERO, + gpu_type); + }); + + if (batch_size > 1) { +#if CUDA_VERSION >= 11070 + dev_ctx.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseCooSetStridedBatch( + *descriptor, batch_size, batch_nnz); + }); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "Batch Sparse matmul use 'cusparseCooSetStridedBatch', which is " + "supported from CUDA 11.7")); +#endif + } +} + template class CuSparseSpMatDescriptor { public: @@ -109,12 +166,22 @@ class CuSparseSpMatDescriptor { const phi::GPUContext& dev_ctx) : dev_ctx_(dev_ctx) { PD_VISIT_INTEGRAL_TYPES( - x.non_zero_crows().dtype(), "CuSparseSpMatDescriptor", ([&] { + x.non_zero_crows().dtype(), "Csr CuSparseSpMatDescriptor", ([&] { CreateCsrDescriptor(x, dev_ctx_, &descriptor_); })); VLOG(6) << "Create csr cusparseSpMatDescr_t " << &descriptor_; } + explicit CuSparseSpMatDescriptor(const phi::SparseCooTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Coo CuSparseSpMatDescriptor", ([&] { + CreateCooDescriptor(x, dev_ctx_, &descriptor_); + })); + VLOG(6) << "Create coo cusparseSpMatDescr_t " << &descriptor_; + } + ~CuSparseSpMatDescriptor() { dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { phi::dynload::cusparseDestroySpMat(descriptor_); @@ -129,6 +196,7 @@ class CuSparseSpMatDescriptor { cusparseSpMatDescr_t descriptor_; }; +/************* DENSE MATRIX DESCRIPTOR ************/ template class CuSparseDnMatDescriptor { public: @@ -192,6 +260,44 @@ class CuSparseDnMatDescriptor { cusparseDnMatDescr_t descriptor_; }; +/************* DENSE VECTOR DESCRIPTOR ************/ +template +class CuSparseDnVecDescriptor { + public: + explicit CuSparseDnVecDescriptor(const phi::DenseTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + PADDLE_ENFORCE_GE(x_ndims, + 1, + phi::errors::InvalidArgument( + "the dim size of Vec must be eaqual to 1.")); + + const T* x_data = x.data(); + cudaDataType_t gpu_type = GetGpuDataType(); + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseCreateDnVec( + &descriptor_, x.numel(), const_cast(x_data), gpu_type); + }); + + VLOG(6) << "Create cusparseDnVecDescr_t " << &descriptor_; + } + + ~CuSparseDnVecDescriptor() { + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseDestroyDnVec(descriptor_); + }); + VLOG(6) << "Destroy cusparseDnVecDescr_t " << &descriptor_; + } + + const cusparseDnVecDescr_t& descriptor() const { return descriptor_; } + + private: + const phi::GPUContext& dev_ctx_; + cusparseDnVecDescr_t descriptor_; +}; + template <> template void SparseBlas::SPMM(bool transa, @@ -239,6 +345,50 @@ void SparseBlas::SPMM(bool transa, }); } +template <> +template +void SparseBlas::SPMV(bool transa, + T alpha, + const TensorType& mat_a, + const phi::DenseTensor& vec_x, + T beta, + phi::DenseTensor* vec_out) const { + auto a_descriptor = CuSparseSpMatDescriptor(mat_a, dev_ctx_); + auto x_descriptor = CuSparseDnVecDescriptor(vec_x, dev_ctx_); + auto out_descriptor = CuSparseDnVecDescriptor(*vec_out, dev_ctx_); + + cudaDataType_t gpu_type = GetGpuDataType(); + size_t buffer_size = 0; + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSpMV_bufferSize(handle, + GetTransposeOperation(transa), + &alpha, + a_descriptor.descriptor(), + x_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + gpu_type, + CUSPARSE_MV_ALG_DEFAULT, + &buffer_size); + }); + + paddle::memory::allocation::AllocationPtr tmp_buffer = + paddle::memory::Alloc(dev_ctx_, buffer_size); + void* tmp_buffer_ptr = tmp_buffer->ptr(); + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSpMV(handle, + GetTransposeOperation(transa), + &alpha, + a_descriptor.descriptor(), + x_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + gpu_type, + CUSPARSE_MV_ALG_DEFAULT, + tmp_buffer_ptr); + }); +} + #if CUDA_VERSION >= 11030 template <> template @@ -249,12 +399,11 @@ void SparseBlas::SDDMM(bool transa, const phi::DenseTensor& mat_b, T beta, TensorType* mat_out) const { - cudaDataType_t gpu_type = GetGpuDataType(); - auto a_descriptor = CuSparseDnMatDescriptor(mat_a, dev_ctx_); auto b_descriptor = CuSparseDnMatDescriptor(mat_b, dev_ctx_); auto out_descriptor = CuSparseSpMatDescriptor(*mat_out, dev_ctx_); + cudaDataType_t gpu_type = GetGpuDataType(); size_t buffer_size = 0; dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { phi::dynload::cusparseSDDMM_bufferSize(handle, diff --git a/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc b/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc index 10ad848442..0818b8e900 100644 --- a/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc @@ -27,7 +27,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out) { PADDLE_THROW(phi::errors::Unimplemented( - "Not support CPU kernel of Sparse Matmul now.")); + "Not support CPU kernel of 'sparse.matmul' now.")); } // TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR" @@ -38,7 +38,7 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx, const SparseCsrTensor& mask, SparseCsrTensor* out) { PADDLE_THROW(phi::errors::Unimplemented( - "Not support CPU kernel of Matmul Mask As Sparse now.")); + "Not support CPU kernel of 'sparse.masked_matmul' now.")); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/cpu/mv_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/mv_grad_kernel.cc new file mode 100644 index 0000000000..c8936e62e4 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/mv_grad_kernel.cc @@ -0,0 +1,56 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/mv_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace sparse { + +template +void MvCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& vec, + const DenseTensor& dout, + SparseCooTensor* dx, + DenseTensor* dvec) { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support CPU backward kernel of 'sparse.mv' now.")); +} + +template +void MvCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& vec, + const DenseTensor& dout, + SparseCsrTensor* dx, + DenseTensor* dvec) { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support CPU backward kernel of 'sparse.mv' now.")); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL( + mv_coo_grad, CPU, ALL_LAYOUT, phi::sparse::MvCooGradKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL( + mv_csr_grad, CPU, ALL_LAYOUT, phi::sparse::MvCsrGradKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} diff --git a/paddle/phi/kernels/sparse/cpu/mv_kernel.cc b/paddle/phi/kernels/sparse/cpu/mv_kernel.cc new file mode 100644 index 0000000000..6f16694c6e --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/mv_kernel.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/mv_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace sparse { + +template +void MvCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& vec, + DenseTensor* out) { + PADDLE_THROW( + phi::errors::Unimplemented("Not support CPU kernel of 'sparse.mv' now.")); +} + +template +void MvCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& vec, + DenseTensor* out) { + PADDLE_THROW( + phi::errors::Unimplemented("Not support CPU kernel of 'sparse.mv' now.")); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL( + mv_csr, CPU, ALL_LAYOUT, phi::sparse::MvCsrKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL( + mv_coo, CPU, ALL_LAYOUT, phi::sparse::MvCooKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/empty_kernel.cc b/paddle/phi/kernels/sparse/empty_kernel.cc index 2d04f93521..fe7fb72b4c 100644 --- a/paddle/phi/kernels/sparse/empty_kernel.cc +++ b/paddle/phi/kernels/sparse/empty_kernel.cc @@ -28,12 +28,10 @@ void EmptyLikeCooKernel(const Context& dev_ctx, SparseCooTensor* out) { const DenseTensor& x_indices = x.non_zero_indices(); const DenseTensor& x_values = x.non_zero_elements(); - DenseTensor* out_indices = out->mutable_non_zero_indices(); DenseTensor* out_values = out->mutable_non_zero_elements(); phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices); - phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values); out_values->Resize(x_values.dims()); dev_ctx.template Alloc(out_values); @@ -48,7 +46,6 @@ void EmptyLikeCsrKernel(const Context& dev_ctx, const DenseTensor& x_crows = x.non_zero_crows(); const DenseTensor& x_cols = x.non_zero_cols(); const DenseTensor& x_values = x.non_zero_elements(); - DenseTensor* out_crows = out->mutable_non_zero_crows(); DenseTensor* out_cols = out->mutable_non_zero_cols(); DenseTensor* out_values = out->mutable_non_zero_elements(); diff --git a/paddle/phi/kernels/sparse/gpu/mv_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/mv_grad_kernel.cu new file mode 100644 index 0000000000..26e37556d3 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/mv_grad_kernel.cu @@ -0,0 +1,163 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/mv_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +__global__ void MvCooGradGpuKernel(const T *dout, + const T *vec, + const IntT *dx_indices, + T *dx_values, + int nnz) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < nnz; idx += blockDim.x * gridDim.x) { + int i = dx_indices[idx]; + int j = dx_indices[idx + nnz]; + dx_values[idx] = dout[i] * vec[j]; + } +} + +template +__global__ void MvCsrGradGpuKernel(const T *dout, + const T *vec, + const IntT *dx_crows, + const IntT *dx_cols, + T *dx_values, + int row_number) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + for (; i < row_number; i += gridDim.x * blockDim.x) { + int row_first = static_cast(dx_crows[i]); + int row_nnz = static_cast(dx_crows[i + 1] - dx_crows[i]); + + int non_zero_idx = blockIdx.y * blockDim.y + threadIdx.y; + for (; non_zero_idx < row_nnz; non_zero_idx += gridDim.y * blockDim.y) { + int j = dx_cols[row_first + non_zero_idx]; + dx_values[row_first + non_zero_idx] = dout[i] * vec[j]; + } + } +} + +template +void MvCooGradKernel(const Context &dev_ctx, + const SparseCooTensor &x, + const DenseTensor &vec, + const DenseTensor &dout, + SparseCooTensor *dx, + DenseTensor *dvec) { + // dx{SparseCoo} = dout{Dense} * vec'{Dense} + if (dx) { + // InferMeta of SparseCooTensor 'dx', CreateLikeInferMeta + EmptyLikeCooKernel(dev_ctx, x, dx); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dx->nnz()); + PD_VISIT_INTEGRAL_TYPES( + dx->non_zero_indices().dtype(), "MvCooGradKernel", ([&] { + MvCooGradGpuKernel + <<>>(dout.data(), + vec.data(), + dx->non_zero_indices().data(), + dx->mutable_non_zero_elements()->data(), + dx->nnz()); + })); + } + + // dvec{Dense} = x'{SparseCoo} * dout{Dense} + if (dvec) { +#if CUDA_VERSION >= 11000 + // InferMeta of DenseTensor 'dvec' + dvec->Resize(vec.dims()); + dev_ctx.template Alloc(dvec); + + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + sparse_blas.SPMV(true, static_cast(1), x, dout, static_cast(0), dvec); +#else + PADDLE_THROW( + phi::errors::Unimplemented(" vec.grad of 'sparse.mv' use cusparseSpMV, " + "which is supported from CUDA 11.0")); +#endif + } +} + +template +void MvCsrGradKernel(const Context &dev_ctx, + const SparseCsrTensor &x, + const DenseTensor &vec, + const DenseTensor &dout, + SparseCsrTensor *dx, + DenseTensor *dvec) { + // dx{SparseCsr} = dout{Dense} * vec'{Dense} + if (dx) { + // InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta + EmptyLikeCsrKernel(dev_ctx, x, dx); + + int row_number = dx->dims()[0]; + int col_number = dx->dims()[1]; + auto config = phi::backends::gpu::GetGpuLaunchConfig2D( + dev_ctx, col_number, row_number); + PD_VISIT_INTEGRAL_TYPES( + dx->non_zero_crows().dtype(), "MvCsrGradKernel", ([&] { + MvCsrGradGpuKernel + <<>>(dout.data(), + vec.data(), + dx->non_zero_crows().data(), + dx->non_zero_cols().data(), + dx->mutable_non_zero_elements()->data(), + row_number); + })); + } + + // dvec{Dense} = x'{SparseCsr} * dout{Dense} + if (dvec) { +#if CUDA_VERSION >= 11000 + // InferMeta of DenseTensor 'dvec' + dvec->Resize(vec.dims()); + dev_ctx.template Alloc(dvec); + + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + sparse_blas.SPMV(true, static_cast(1), x, dout, static_cast(0), dvec); +#else + PADDLE_THROW( + phi::errors::Unimplemented(" vec.grad of 'sparse.mv' use cusparseSpMV, " + "which is supported from CUDA 11.0")); +#endif + } +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL( + mv_coo_grad, GPU, ALL_LAYOUT, phi::sparse::MvCooGradKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL( + mv_csr_grad, GPU, ALL_LAYOUT, phi::sparse::MvCsrGradKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} diff --git a/paddle/phi/kernels/sparse/gpu/mv_kernel.cu b/paddle/phi/kernels/sparse/gpu/mv_kernel.cu new file mode 100644 index 0000000000..27f094fb0f --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/mv_kernel.cu @@ -0,0 +1,89 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/mv_kernel.h" + +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" + +namespace phi { +namespace sparse { + +template +void MvKernelImpl(const Context& dev_ctx, + const TensorType& x, + const DenseTensor& vec, + DenseTensor* out) { +#if CUDA_VERSION >= 11000 + std::vector x_dim = phi::vectorize(x.dims()); + std::vector vec_dim = phi::vectorize(vec.dims()); + auto x_ndims = x_dim.size(); + auto vec_ndims = vec_dim.size(); + PADDLE_ENFORCE_EQ(x_ndims, + 2, + phi::errors::InvalidArgument( + "the dims size of Input(x) must be eaqual to 2.")); + PADDLE_ENFORCE_EQ(vec_ndims, + 1, + phi::errors::InvalidArgument( + "the dims size of Input(vec) must be eaqual to 1.")); + PADDLE_ENFORCE_EQ(x_dim[x_ndims - 1], + vec_dim[vec_ndims - 1], + phi::errors::PreconditionNotMet( + "The shape of Input(x) and Input(vec) is not " + "suitable for mv opetation, " + "x_dim[-1] must be eaqual to vec_dim[-1].")); + std::vector out_dim = {x_dim[x_ndims - 2]}; + out->Resize(phi::make_ddim(out_dim)); + dev_ctx.template Alloc(out); + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + sparse_blas.SPMV(false, static_cast(1), x, vec, static_cast(0), out); +#else + PADDLE_THROW(phi::errors::Unimplemented( + " 'sparse.mv' use cusparseSpMV, which is supported from CUDA 11.0")); +#endif +} + +template +void MvCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& vec, + DenseTensor* out) { + MvKernelImpl(dev_ctx, x, vec, out); +} + +template +void MvCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& vec, + DenseTensor* out) { + MvKernelImpl(dev_ctx, x, vec, out); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL( + mv_csr, GPU, ALL_LAYOUT, phi::sparse::MvCsrKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL( + mv_coo, GPU, ALL_LAYOUT, phi::sparse::MvCooKernel, float, double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/mv_grad_kernel.h b/paddle/phi/kernels/sparse/mv_grad_kernel.h new file mode 100644 index 0000000000..778429992d --- /dev/null +++ b/paddle/phi/kernels/sparse/mv_grad_kernel.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { +namespace sparse { + +/* backward of COO @ DENSE VEC -> DENSE VEC */ +template +void MvCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& vec, + const DenseTensor& dout, + SparseCooTensor* dx, + DenseTensor* dvec); + +/* backward of CSR @ DENSE VEC -> DENSE VEC */ +template +void MvCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& vec, + const DenseTensor& dout, + SparseCsrTensor* dx, + DenseTensor* dvec); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/mv_kernel.h b/paddle/phi/kernels/sparse/mv_kernel.h new file mode 100644 index 0000000000..57c598698d --- /dev/null +++ b/paddle/phi/kernels/sparse/mv_kernel.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { +namespace sparse { + +/* COO @ DENSE VEC -> DENSE VEC */ +template +void MvCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& vec, + DenseTensor* out); + +/* CSR @ DENSE VEC -> DENSE VEC */ +template +void MvCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& vec, + DenseTensor* out); + +} // namespace sparse +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_mv_op.py b/python/paddle/fluid/tests/unittests/test_sparse_mv_op.py new file mode 100644 index 0000000000..9ac4fff850 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_mv_op.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.framework import _test_eager_guard + +import numpy as np +import scipy +import scipy.sparse as sp +import unittest +import os +import re + +paddle.seed(100) + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000, + "paddle is not compiled with CUDA and cuda version need to >= 11.0") +class TestCsrMv(unittest.TestCase): + # x: csr-matrix, y: dense-vec, out: dense-vec + def test_mv(self): + with _test_eager_guard(): + paddle.set_default_dtype('float64') + origin_x = paddle.rand([64, 32]) + mask = paddle.randint(0, 2, [64, 32]) + origin_x = origin_x * mask + origin_vec = paddle.rand([32]) + + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_vec = origin_vec.detach() + dense_vec.stop_gradient = False + dense_out = paddle.mv(dense_x, dense_vec) + dense_out.backward() + + sp_x = origin_x.detach().to_sparse_csr() + sp_x.stop_gradient = False + sp_vec = origin_vec.detach() + sp_vec.stop_gradient = False + sp_out = paddle.incubate.sparse.mv(sp_x, sp_vec) + sp_out.backward() + + self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy())) + self.assertTrue( + np.allclose(sp_x.grad.to_dense().numpy(), + (dense_x.grad * mask).numpy())) + self.assertTrue( + np.allclose(sp_vec.grad.numpy(), dense_vec.grad.numpy())) + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000, + "paddle is not compiled with CUDA and cuda version need to >= 11.0") +class TestCooMv(unittest.TestCase): + # x: csr-matrix, y: dense-vec, out: dense-vec + def test_mv(self): + with _test_eager_guard(): + paddle.set_default_dtype('float64') + origin_x = paddle.rand([64, 32]) + mask = paddle.randint(0, 2, [64, 32]) + origin_x = origin_x * mask + origin_vec = paddle.rand([32]) + + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_vec = origin_vec.detach() + dense_vec.stop_gradient = False + dense_out = paddle.mv(dense_x, dense_vec) + dense_out.backward() + + sp_x = origin_x.detach().to_sparse_coo(sparse_dim=2) + sp_x.stop_gradient = False + sp_vec = origin_vec.detach() + sp_vec.stop_gradient = False + sp_out = paddle.incubate.sparse.mv(sp_x, sp_vec) + sp_out.backward() + + self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy())) + self.assertTrue( + np.allclose(sp_x.grad.to_dense().numpy(), + (dense_x.grad * mask).numpy())) + self.assertTrue( + np.allclose(sp_vec.grad.numpy(), dense_vec.grad.numpy())) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index 05dd8b6d56..f696434118 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ b/python/paddle/incubate/sparse/__init__.py @@ -19,6 +19,7 @@ from .unary import sqrt from .unary import sin from .unary import tanh +from .binary import mv from .binary import matmul from .binary import masked_matmul @@ -35,6 +36,7 @@ __all__ = [ 'sqrt', 'sin', 'tanh', + 'mv', 'matmul', 'masked_matmul', 'add', diff --git a/python/paddle/incubate/sparse/binary.py b/python/paddle/incubate/sparse/binary.py index a158e7d684..f34378924e 100644 --- a/python/paddle/incubate/sparse/binary.py +++ b/python/paddle/incubate/sparse/binary.py @@ -21,8 +21,8 @@ __all__ = [] @dygraph_only def matmul(x, y, name=None): """ - Warning: - This API is only used from ``CUDA 11.0`` . + Note: + This API is only supported from ``CUDA 11.0`` . Applies matrix multiplication of two Tensors. @@ -83,8 +83,8 @@ def matmul(x, y, name=None): @dygraph_only def masked_matmul(x, y, mask, name=None): """ - Warning: - This API is only used from ``CUDA 11.3`` . + Note: + This API is only supported from ``CUDA 11.3`` . Applies matrix multiplication of two Dense Tensors. @@ -141,3 +141,59 @@ def masked_matmul(x, y, mask, name=None): """ return _C_ops.final_state_sparse_masked_matmul(x, y, mask) + + +@dygraph_only +def mv(x, vec, name=None): + """ + Note: + This API is only supported from ``CUDA 11.0`` . + + Applies matrix-vector product of Sparse Matrix 'x' and Dense vector 'vec' . + + The supported input/output Tensor layout are as follows: + + Note: + x[SparseCsrTensor] @ y[DenseTensor] -> out[SparseCsrTensor] + x[SparseCooTensor] @ y[DenseTensor] -> out[SparseCooTensor] + + It supports backward propagation. + + The shape of `x` should be `[M, N]` , and the shape of `y` should be `[N]` , + and the shape of `out` will be `[M]` . + + Args: + x (Tensor): The input 2D tensor. It must be SparseCooTensor/SparseCsrTensor. The data type can be float32 or float64. + y (Tensor): The input 1D tensor. It must be DenseTensor vector. The data type can be float32 or float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: 1D Tensor. + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + paddle.seed(100) + + # csr @ dense -> dense + with _test_eager_guard(): + crows = [0, 2, 3, 5] + cols = [1, 3, 2, 0, 1] + values = [1., 2., 3., 4., 5.] + dense_shape = [3, 4] + csr = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 3, 5], + # cols=[1, 3, 2, 0, 1], + # values=[1., 2., 3., 4., 5.]) + vec = paddle.randn([4]) + + out = paddle.incubate.sparse.mv(csr, vec) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [-3.85499096, -2.42975140, -1.75087738]) + + """ + return _C_ops.final_state_sparse_mv(x, vec) diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index e3b61bae15..e99009a70f 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -168,3 +168,12 @@ layout : x intermediate : rulebook backward : sparse_maxpool_grad + +- api: mv + args : (Tensor x, Tensor vec) + output : Tensor(out) + kernel : + func : mv_coo{sparse_coo, dense -> dense}, + mv_csr{sparse_csr, dense -> dense} + layout : x + backward: mv_grad diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index 7ddaba0f0a..6ceedb0978 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -63,6 +63,14 @@ func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, multiply_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} +- backward_api : mv_grad + forward : mv(Tensor x, Tensor vec) -> Tensor(out) + args : (Tensor x, Tensor vec, Tensor out_grad) + output : Tensor(x_grad), Tensor(vec_grad) + kernel : + func : mv_coo_grad{sparse_coo, dense, dense -> sparse_coo, dense}, + mv_csr_grad{sparse_csr, dense, dense -> sparse_csr, dense} + - backward_api : relu_grad forward : relu(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) -- GitLab