From eec4e0347855fc6c56e8dfdf6e51b669b5dabcc1 Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Fri, 24 Jun 2022 16:37:03 +0800
Subject: [PATCH] [Sparse] support batch compute of SparseTensor
 matmul/masked_matmul/softmax (#43703)

---
 paddle/fluid/platform/dynload/cusparse.h      |  43 +++---
 paddle/phi/backends/dynload/cusparse.h        |  43 +++---
 paddle/phi/kernels/funcs/sparse/sparse_blas.h |  34 ++--
 .../funcs/sparse/sparse_blas_impl.cu.h        | 146 +++++++++++-------
 .../kernels/sparse/cpu/softmax_grad_kernel.cc |  47 +++---
 .../phi/kernels/sparse/cpu/softmax_kernel.cc  |  50 +++---
 paddle/phi/kernels/sparse/empty_kernel.cc     |  52 ++++++-
 paddle/phi/kernels/sparse/empty_kernel.h      |   6 +
 .../kernels/sparse/gpu/matmul_grad_kernel.cu  |  30 +---
 .../phi/kernels/sparse/gpu/matmul_kernel.cu   |  33 ++--
 .../kernels/sparse/gpu/softmax_grad_kernel.cu |  33 ++--
 .../phi/kernels/sparse/gpu/softmax_kernel.cu  |  38 +++--
 .../tests/unittests/test_sparse_matmul_op.py  |  72 ++++++++-
 ...e_softmax.py => test_sparse_softmax_op.py} |  58 ++++++-
 .../sparse/nn/functional/activation.py        |   2 +-
 .../incubate/sparse/nn/layer/activation.py    |   2 +-
 16 files changed, 457 insertions(+), 232 deletions(-)
 rename python/paddle/fluid/tests/unittests/{test_sparse_softmax.py => test_sparse_softmax_op.py} (53%)

diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h
index c0620a110c..5a67e34fbb 100644
--- a/paddle/fluid/platform/dynload/cusparse.h
+++ b/paddle/fluid/platform/dynload/cusparse.h
@@ -31,24 +31,22 @@ namespace dynload {
 #if defined(PADDLE_WITH_CUDA)
 // APIs available after CUDA 11.0
 #if CUDA_VERSION >= 11000
-#define CUSPARSE_ROUTINE_EACH(__macro)   \
-  __macro(cusparseCreate);               \
-  __macro(cusparseSetStream);            \
-  __macro(cusparseCreateMatDescr);       \
-  __macro(cusparseDestroy);              \
-  __macro(cusparseSnnz);                 \
-  __macro(cusparseDnnz);                 \
-  __macro(cusparseSetMatType);           \
-  __macro(cusparseSetMatIndexBase);      \
-  __macro(cusparseCreateCsr);            \
-  __macro(cusparseCreateCoo);            \
-  __macro(cusparseCreateDnMat);          \
-  __macro(cusparseSpMM_bufferSize);      \
-  __macro(cusparseSpMM);                 \
-  __macro(cusparseDestroySpMat);         \
-  __macro(cusparseDestroyDnMat);         \
-  __macro(cusparseDnMatSetStridedBatch); \
-  __macro(cusparseCsrSetStridedBatch);
+#define CUSPARSE_ROUTINE_EACH(__macro) \
+  __macro(cusparseCreate);             \
+  __macro(cusparseSetStream);          \
+  __macro(cusparseCreateMatDescr);     \
+  __macro(cusparseDestroy);            \
+  __macro(cusparseSnnz);               \
+  __macro(cusparseDnnz);               \
+  __macro(cusparseSetMatType);         \
+  __macro(cusparseSetMatIndexBase);    \
+  __macro(cusparseCreateCsr);          \
+  __macro(cusparseCreateCoo);          \
+  __macro(cusparseCreateDnMat);        \
+  __macro(cusparseSpMM_bufferSize);    \
+  __macro(cusparseSpMM);               \
+  __macro(cusparseDestroySpMat);       \
+  __macro(cusparseDestroyDnMat);
 
 CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
@@ -62,8 +60,17 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 CUSPARSE_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 
+#if CUDA_VERSION >= 11070
+#define CUSPARSE_ROUTINE_EACH_R3(__macro)  \
+  __macro(cusparseDnMatSetStridedBatch);   \
+  __macro(cusparseCooSetStridedBatch);     \
+  __macro(cusparseCsrSetStridedBatch);
+
+CUSPARSE_ROUTINE_EACH_R3(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 
+#endif  // PADDLE_WITH_CUDA
+
 #undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
 }  // namespace dynload
 }  // namespace platform
diff --git a/paddle/phi/backends/dynload/cusparse.h b/paddle/phi/backends/dynload/cusparse.h
index 2e96e5788f..8f78ad37af 100644
--- a/paddle/phi/backends/dynload/cusparse.h
+++ b/paddle/phi/backends/dynload/cusparse.h
@@ -43,24 +43,22 @@ extern void *cusparse_dso_handle;
 #if defined(PADDLE_WITH_CUDA)
 // APIs available after CUDA 11.0
 #if CUDA_VERSION >= 11000
-#define CUSPARSE_ROUTINE_EACH(__macro)   \
-  __macro(cusparseCreate);               \
-  __macro(cusparseSetStream);            \
-  __macro(cusparseCreateMatDescr);       \
-  __macro(cusparseDestroy);              \
-  __macro(cusparseSnnz);                 \
-  __macro(cusparseDnnz);                 \
-  __macro(cusparseSetMatType);           \
-  __macro(cusparseSetMatIndexBase);      \
-  __macro(cusparseCreateCsr);            \
-  __macro(cusparseCreateCoo);            \
-  __macro(cusparseCreateDnMat);          \
-  __macro(cusparseSpMM_bufferSize);      \
-  __macro(cusparseSpMM);                 \
-  __macro(cusparseDestroySpMat);         \
-  __macro(cusparseDestroyDnMat);         \
-  __macro(cusparseDnMatSetStridedBatch); \
-  __macro(cusparseCsrSetStridedBatch);
+#define CUSPARSE_ROUTINE_EACH(__macro) \
+  __macro(cusparseCreate);             \
+  __macro(cusparseSetStream);          \
+  __macro(cusparseCreateMatDescr);     \
+  __macro(cusparseDestroy);            \
+  __macro(cusparseSnnz);               \
+  __macro(cusparseDnnz);               \
+  __macro(cusparseSetMatType);         \
+  __macro(cusparseSetMatIndexBase);    \
+  __macro(cusparseCreateCsr);          \
+  __macro(cusparseCreateCoo);          \
+  __macro(cusparseCreateDnMat);        \
+  __macro(cusparseSpMM_bufferSize);    \
+  __macro(cusparseSpMM);               \
+  __macro(cusparseDestroySpMat);       \
+  __macro(cusparseDestroyDnMat);
 
 CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
@@ -74,8 +72,17 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 
+#if CUDA_VERSION >= 11070
+#define CUSPARSE_ROUTINE_EACH_R3(__macro)  \
+  __macro(cusparseDnMatSetStridedBatch);   \
+  __macro(cusparseCooSetStridedBatch);     \
+  __macro(cusparseCsrSetStridedBatch);
+
+CUSPARSE_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 
+#endif  // PADDLE_WITH_CUDA
+
 #undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
 }  // namespace dynload
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas.h b/paddle/phi/kernels/funcs/sparse/sparse_blas.h
index edad70edd7..c2e270f71a 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas.h
@@ -28,33 +28,23 @@ class SparseBlas {
  public:
   explicit SparseBlas(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {}
 
-  // TODO(zhouwei25): implement "COO @ DENSE -> DENSE" of DSDMM
-  template <typename T>
-  void DSDMM(bool transa,
-             bool transb,
-             T alpha,
-             const phi::SparseCooTensor& mat_a,
-             const phi::DenseTensor& mat_b,
-             T beta,
-             phi::DenseTensor* mat_c) const;
-
-  template <typename T>
-  void DSDMM(bool transa,
-             bool transb,
-             T alpha,
-             const phi::SparseCsrTensor& mat_a,
-             const phi::DenseTensor& mat_b,
-             T beta,
-             phi::DenseTensor* mat_c) const;
+  template <typename T, typename TensorType>
+  void SPMM(bool transa,
+            bool transb,
+            T alpha,
+            const TensorType& mat_a,
+            const phi::DenseTensor& mat_b,
+            T beta,
+            phi::DenseTensor* mat_out) const;
 
-  template <typename T>
+  template <typename T, typename TensorType>
   void SDDMM(bool transa,
              bool transb,
             T alpha,
             const phi::DenseTensor& mat_a,
             const phi::DenseTensor& mat_b,
             T beta,
-             phi::SparseCsrTensor* mat_c) const;
+             TensorType* mat_out) const;
 
  private:
   const DeviceContext& dev_ctx_;
@@ -66,8 +56,8 @@ class SparseBlasT : private SparseBlas<DeviceContext> {
   using SparseBlas<DeviceContext>::SparseBlas;
 
   template <typename... ARGS>
-  void DSDMM(ARGS... args) const {
-    Base()->template DSDMM<T>(args...);
+  void SPMM(ARGS... args) const {
+    Base()->template SPMM<T>(args...);
   }
 
   template <typename... ARGS>
diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
index 0c54f99bef..c65d506cce 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -47,6 +47,61 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
   }
 }
 
+template <typename T, typename IntT>
+inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
+                                const phi::GPUContext& dev_ctx,
+                                cusparseSpMatDescr_t* descriptor) {
+  std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
+  auto x_ndims = xdim_vec.size();
+  PADDLE_ENFORCE_GE(
+      x_ndims,
+      2,
+      phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be "
+                                   "greater than or equal to 2."));
+  int64_t M = xdim_vec[x_ndims - 2];
+  int64_t N = xdim_vec[x_ndims - 1];
+  int batch_size = 1;
+  for (int i = 0; i < x_ndims - 2; i++) {
+    batch_size *= xdim_vec[i];
+  }
+  PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(),
+                    batch_size * (M + 1),
+                    phi::errors::PreconditionNotMet(
+                        "the length of SparseCsrTensor crows is incorrect."));
+
+  const IntT* crows_data = x.non_zero_crows().data<IntT>();
+  const IntT* cols_data = x.non_zero_cols().data<IntT>();
+  const T* values_data = x.non_zero_elements().data<T>();
+
+  int64_t batch_nnz = x.nnz() / batch_size;
+  cudaDataType_t gpu_type = GetGpuDataType<T>();
+  dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseCreateCsr(descriptor,
+                                    M,
+                                    N,
+                                    batch_nnz,
+                                    const_cast<IntT*>(crows_data),
+                                    const_cast<IntT*>(cols_data),
+                                    const_cast<T*>(values_data),
+                                    CUSPARSE_INDEX_64I,
+                                    CUSPARSE_INDEX_64I,
+                                    CUSPARSE_INDEX_BASE_ZERO,
+                                    gpu_type);
+  });
+  if (batch_size > 1) {
+#if CUDA_VERSION >= 11070
+    dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
+      phi::dynload::cusparseCsrSetStridedBatch(
+          *descriptor, batch_size, M + 1, batch_nnz);
+    });
+#else
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Batch Sparse matmul uses 'cusparseCsrSetStridedBatch', which is "
+        "supported from CUDA 11.7"));
+#endif
+  }
+}
+
 template <typename T>
 class CuSparseSpMatDescriptor {
  public:
@@ -55,45 +110,9 @@ class CuSparseSpMatDescriptor {
       : dev_ctx_(dev_ctx) {
     PD_VISIT_INTEGRAL_TYPES(
         x.non_zero_crows().dtype(), "CuSparseSpMatDescriptor", ([&] {
-          const data_t* crows_data = x.non_zero_crows().data<data_t>();
-          const data_t* cols_data = x.non_zero_cols().data<data_t>();
-          const T* values_data = x.non_zero_elements().data<T>();
-          int64_t nnz = x.nnz();
-
-          std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
-          auto x_ndims = xdim_vec.size();
-          int64_t M = xdim_vec[x_ndims - 2];
-          int64_t N = xdim_vec[x_ndims - 1];
-          int batch_size = 1;
-          for (int i = 0; i < x_ndims - 2; i++) {
-            batch_size *= xdim_vec[i];
-          }
-
-          cudaDataType_t gpu_type = GetGpuDataType<T>();
-          dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
-            phi::dynload::cusparseCreateCsr(&descriptor_,
-                                            M,
-                                            N,
-                                            nnz,
-                                            const_cast<data_t*>(crows_data),
-                                            const_cast<data_t*>(cols_data),
-                                            const_cast<T*>(values_data),
-                                            CUSPARSE_INDEX_64I,
-                                            CUSPARSE_INDEX_64I,
-                                            CUSPARSE_INDEX_BASE_ZERO,
-                                            gpu_type);
-          });
-          PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(), batch_size * (M + 1));
-          PADDLE_ENFORCE_EQ(x.non_zero_cols().numel(), x.nnz());
-          if (batch_size > 1) {
-            dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
-              phi::dynload::cusparseCsrSetStridedBatch(
-                  descriptor_, batch_size, M + 1, nnz);
-            });
-          }
+          CreateCsrDescriptor<T, data_t>(x, dev_ctx_, &descriptor_);
         }));
-
-    VLOG(6) << "Create cusparseSpMatDescr_t " << &descriptor_;
+    VLOG(6) << "Create csr cusparseSpMatDescr_t " << &descriptor_;
   }
 
   ~CuSparseSpMatDescriptor() {
@@ -116,9 +135,14 @@ class CuSparseDnMatDescriptor {
   explicit CuSparseDnMatDescriptor(const phi::DenseTensor& x,
                                    const phi::GPUContext& dev_ctx)
       : dev_ctx_(dev_ctx) {
-    const T* x_data = x.data<T>();
     std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
     auto x_ndims = xdim_vec.size();
+    PADDLE_ENFORCE_GE(
+        x_ndims,
+        2,
+        phi::errors::InvalidArgument("the dim size of DenseTensor must be "
+                                     "greater than or equal to 2."));
+
     int64_t M = xdim_vec[x_ndims - 2];
     int64_t N = xdim_vec[x_ndims - 1];
     int batch_size = 1;
@@ -126,6 +150,7 @@ class CuSparseDnMatDescriptor {
       batch_size *= xdim_vec[i];
     }
 
+    const T* x_data = x.data<T>();
     cudaDataType_t gpu_type = GetGpuDataType<T>();
     dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
       phi::dynload::cusparseCreateDnMat(&descriptor_,
@@ -139,10 +164,16 @@ class CuSparseDnMatDescriptor {
     PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
     if (batch_size > 1) {
+#if CUDA_VERSION >= 11070
       dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
         phi::dynload::cusparseDnMatSetStridedBatch(
             descriptor_, batch_size, M * N);
       });
+#else
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Batch Sparse matmul uses 'cusparseDnMatSetStridedBatch', which is "
+          "supported from CUDA 11.7"));
+#endif
     }
     VLOG(6) << "Create cusparseDnMatDescr_t " << &descriptor_;
   }
@@ -162,20 +193,19 @@ class CuSparseDnMatDescriptor {
 };
 
 template <>
-template <typename T>
-void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
-                                        bool transb,
-                                        T alpha,
-                                        const phi::SparseCsrTensor& mat_a,
-                                        const phi::DenseTensor& mat_b,
-                                        T beta,
-                                        phi::DenseTensor* mat_c) const {
-  cudaDataType_t gpu_type = GetGpuDataType<T>();
-
+template <typename T, typename TensorType>
+void SparseBlas<phi::GPUContext>::SPMM(bool transa,
+                                       bool transb,
+                                       T alpha,
+                                       const TensorType& mat_a,
+                                       const phi::DenseTensor& mat_b,
+                                       T beta,
+                                       phi::DenseTensor* mat_out) const {
   auto a_descriptor = CuSparseSpMatDescriptor<T>(mat_a, dev_ctx_);
   auto b_descriptor = CuSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
-  auto c_descriptor = CuSparseDnMatDescriptor<T>(*mat_c, dev_ctx_);
+  auto out_descriptor = CuSparseDnMatDescriptor<T>(*mat_out, dev_ctx_);
 
+  cudaDataType_t gpu_type = GetGpuDataType<T>();
   size_t buffer_size = 0;
   dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
     phi::dynload::cusparseSpMM_bufferSize(handle,
@@ -185,7 +215,7 @@ void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
                                           a_descriptor.descriptor(),
                                           b_descriptor.descriptor(),
                                           &beta,
-                                          c_descriptor.descriptor(),
+                                          out_descriptor.descriptor(),
                                           gpu_type,
                                           CUSPARSE_SPMM_ALG_DEFAULT,
                                           &buffer_size);
@@ -202,7 +232,7 @@ void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
                                a_descriptor.descriptor(),
                                b_descriptor.descriptor(),
                                &beta,
-                               c_descriptor.descriptor(),
+                               out_descriptor.descriptor(),
                                gpu_type,
                                CUSPARSE_SPMM_ALG_DEFAULT,
                                tmp_buffer_ptr);
  });
 }
 
 #if CUDA_VERSION >= 11030
 template <>
-template <typename T>
+template <typename T, typename TensorType>
 void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                         bool transb,
                                         T alpha,
                                         const phi::DenseTensor& mat_a,
                                         const phi::DenseTensor& mat_b,
                                         T beta,
-                                        phi::SparseCsrTensor* mat_c) const {
+                                        TensorType* mat_out) const {
   cudaDataType_t gpu_type = GetGpuDataType<T>();
 
   auto a_descriptor = CuSparseDnMatDescriptor<T>(mat_a, dev_ctx_);
   auto b_descriptor = CuSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
-  auto c_descriptor = CuSparseSpMatDescriptor<T>(*mat_c, dev_ctx_);
+  auto out_descriptor = CuSparseSpMatDescriptor<T>(*mat_out, dev_ctx_);
 
   size_t buffer_size = 0;
   dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
@@ -234,7 +264,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                            a_descriptor.descriptor(),
                                            b_descriptor.descriptor(),
                                            &beta,
-                                           c_descriptor.descriptor(),
+                                           out_descriptor.descriptor(),
                                            gpu_type,
                                            CUSPARSE_SDDMM_ALG_DEFAULT,
                                            &buffer_size);
@@ -252,7 +282,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                  a_descriptor.descriptor(),
                                  b_descriptor.descriptor(),
                                  &beta,
-                                 c_descriptor.descriptor(),
+                                 out_descriptor.descriptor(),
                                  gpu_type,
                                  CUSPARSE_SDDMM_ALG_DEFAULT,
                                  tmp_buffer_ptr);
@@ -266,7 +296,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                  a_descriptor.descriptor(),
                                  b_descriptor.descriptor(),
                                  &beta,
-                                 c_descriptor.descriptor(),
+                                 out_descriptor.descriptor(),
                                  gpu_type,
                                  CUSPARSE_SDDMM_ALG_DEFAULT,
                                  tmp_buffer_ptr);
diff --git a/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
index a97718436a..18d4f4a9c2 100644
--- a/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
@@ -38,11 +38,17 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
                         "SparseCsrTensor only support axis=-1 for softmax, "
                         "which is faster when reading data by row (axis=-1)"));
   EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);
-
   auto out_dim = out.dims();
-  int rows = 1;
-  for (int i = 0; i < out_dim.size() - 1; ++i) {
-    rows *= out_dim[i];
+  auto out_rank = out_dim.size();
+
+  int batch_size = 1;
+  int row_number = 1;
+  for (int i = 0; i < out_rank - 1; ++i) {
+    if (i < out_rank - 2) {
+      batch_size *= out_dim[i];
+    } else if (i == out_rank - 2) {
+      row_number = out_dim[i];
+    }
   }
 
   const DenseTensor& out_crows = out.non_zero_crows();
@@ -50,7 +56,6 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
   const DenseTensor& dout_values = dout.non_zero_elements();
   DenseTensor* dx_values = dx->mutable_non_zero_elements();
 
-  int row_first = 0;
   int row_nnz = 0;
   const T* out_data = out_values.data<T>();
   const T* dout_data = dout_values.data<T>();
@@ -60,20 +65,24 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
   PD_VISIT_INTEGRAL_TYPES(
       out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
         const data_t* out_crows_data = out_crows.data<data_t>();
-        for (int i = 0; i < rows; ++i) {
-          row_first = static_cast<int>(out_crows_data[i]);
-          row_nnz = static_cast<int>(out_crows_data[i + 1] - out_crows_data[i]);
-
-          out_data = out_data + row_first;
-          dout_data = dout_data + row_first;
-          dx_data = dx_data + row_first;
-
-          T sum = 0;
-          phi::funcs::vec_mul_reduce(
-              row_nnz, dout_data, out_data, &sum);
-          phi::funcs::vec_add_bias(
-              row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
-          phi::funcs::vec_mul(row_nnz, dx_data, out_data, dx_data);
+        for (int i = 0; i < batch_size; ++i) {
+          for (int j = 0; j < row_number; ++j) {
+            int crow_idx = i * (row_number + 1) + j;
+            row_nnz = static_cast<int>(out_crows_data[crow_idx + 1] -
+                                       out_crows_data[crow_idx]);
+
+            T sum = 0;
+            phi::funcs::vec_mul_reduce(
+                row_nnz, dout_data, out_data, &sum);
+            phi::funcs::vec_add_bias(
+                row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
+            phi::funcs::vec_mul(
+                row_nnz, dx_data, out_data, dx_data);
+
+            out_data = out_data + row_nnz;
+            dout_data = dout_data + row_nnz;
+            dx_data = dx_data + row_nnz;
+          }
         }
       }));
 }
diff --git a/paddle/phi/kernels/sparse/cpu/softmax_kernel.cc b/paddle/phi/kernels/sparse/cpu/softmax_kernel.cc
index 46ff09004c..5f7342b52a 100644
--- a/paddle/phi/kernels/sparse/cpu/softmax_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/softmax_kernel.cc
@@ -37,18 +37,23 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
                         "SparseCsrTensor only support axis=-1 for softmax, "
                         "which is faster when reading data by row (axis=-1)"));
   EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
-
   auto x_dim = x.dims();
+  auto x_rank = x_dim.size();
+
+  int batch_size = 1;
   int row_number = 1;
-  for (int i = 0; i < x_dim.size() - 1; ++i) {
-    row_number *= x_dim[i];
+  for (int i = 0; i < x_rank - 1; ++i) {
+    if (i < x_rank - 2) {
+      batch_size *= x_dim[i];
+    } else if (i == x_rank - 2) {
+      row_number = x_dim[i];
+    }
   }
 
   const DenseTensor& x_crows = x.non_zero_crows();
   const DenseTensor& x_values = x.non_zero_elements();
   DenseTensor* out_values = out->mutable_non_zero_elements();
 
-  int row_first = 0;
   int row_nnz = 0;
   T row_max_val = 0;
   const T* x_data = x_values.data<T>();
@@ -58,23 +63,26 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
         const data_t* x_crows_data = x_crows.data<data_t>();
-        for (int i = 0; i < row_number; ++i) {
-          row_first = static_cast<int>(x_crows_data[i]);
-          row_nnz = static_cast<int>(x_crows_data[i + 1] - x_crows_data[i]);
-
-          x_data = x_data + row_first;
-          out_data = out_data + row_first;
-
-          row_max_val = *std::max_element(x_data, x_data + row_nnz);
-          phi::funcs::vec_add_bias(
-              row_nnz, static_cast<T>(-1) * row_max_val, x_data, out_data);
-
-          phi::funcs::vec_exp(row_nnz, out_data, out_data);
-
-          T sum = 0;
-          phi::funcs::vec_sum(row_nnz, out_data, &sum);
-          phi::funcs::vec_scal(
-              row_nnz, static_cast<T>(1) / sum, out_data, out_data);
+        for (int i = 0; i < batch_size; ++i) {
+          for (int j = 0; j < row_number; ++j) {
+            int crow_idx = i * (row_number + 1) + j;
+            row_nnz = static_cast<int>(x_crows_data[crow_idx + 1] -
+                                       x_crows_data[crow_idx]);
+
+            row_max_val = *std::max_element(x_data, x_data + row_nnz);
+            phi::funcs::vec_add_bias(
+                row_nnz, static_cast<T>(-1) * row_max_val, x_data, out_data);
+
+            phi::funcs::vec_exp(row_nnz, out_data, out_data);
+
+            T sum = 0;
+            phi::funcs::vec_sum(row_nnz, out_data, &sum);
+            phi::funcs::vec_scal(
+                row_nnz, static_cast<T>(1) / sum, out_data, out_data);
+
+            x_data = x_data + row_nnz;
+            out_data = out_data + row_nnz;
+          }
         }
       }));
 }
diff --git a/paddle/phi/kernels/sparse/empty_kernel.cc b/paddle/phi/kernels/sparse/empty_kernel.cc
index b8e164de53..4b7a5fe615 100644
--- a/paddle/phi/kernels/sparse/empty_kernel.cc
+++ b/paddle/phi/kernels/sparse/empty_kernel.cc
@@ -22,6 +22,25 @@ limitations under the License. */
 namespace phi {
 namespace sparse {
 
+template <typename T, typename Context>
+void EmptyLikeCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        SparseCooTensor* out) {
+  const DenseTensor& x_indices = x.non_zero_indices();
+  const DenseTensor& x_values = x.non_zero_elements();
+
+  DenseTensor* out_indices = out->mutable_non_zero_indices();
+  DenseTensor* out_values = out->mutable_non_zero_elements();
+
+  phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices);
+  phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
+
+  out_values->Resize(x_values.dims());
+  dev_ctx.template Alloc<T>(out_values);
+
+  out->set_dims(x.dims());
+}
+
 template <typename T, typename Context>
 void EmptyLikeCsrKernel(const Context& dev_ctx,
                         const SparseCsrTensor& x,
@@ -34,17 +53,33 @@ void EmptyLikeCsrKernel(const Context& dev_ctx,
   DenseTensor* out_cols = out->mutable_non_zero_cols();
   DenseTensor* out_values = out->mutable_non_zero_elements();
 
-  out->set_dims(x.dims());
   phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows);
   phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols);
 
   out_values->Resize(x_values.dims());
   dev_ctx.template Alloc<T>(out_values);
+
+  out->set_dims(x.dims());
 }
 
 }  // namespace sparse
 }  // namespace phi
 
+PD_REGISTER_KERNEL(empty_like_coo,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::EmptyLikeCooKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
 PD_REGISTER_KERNEL(empty_like_csr,
                    CPU,
                    ALL_LAYOUT,
@@ -61,6 +96,21 @@ PD_REGISTER_KERNEL(empty_like_csr,
 }
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(empty_like_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::EmptyLikeCooKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
 PD_REGISTER_KERNEL(empty_like_csr,
                    GPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/sparse/empty_kernel.h b/paddle/phi/kernels/sparse/empty_kernel.h
index 9656f3e7b3..29eb20af58 100644
--- a/paddle/phi/kernels/sparse/empty_kernel.h
+++ b/paddle/phi/kernels/sparse/empty_kernel.h
@@ -14,11 +14,17 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 
 namespace phi {
 namespace sparse {
 
+template <typename T, typename Context>
+void EmptyLikeCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        SparseCooTensor* out);
+
 template <typename T, typename Context>
 void EmptyLikeCsrKernel(const Context& dev_ctx,
                         const SparseCsrTensor& x,
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
index 6db1bfac6b..8bc162eaae 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/copy_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace phi {
@@ -38,23 +39,8 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx,
 
   // dx{SparseCsr} = dout{Dense} * y'{Dense}
   if (dx) {
-    // InferMeta of SparseCsrTensor 'dx'
-    dx->set_dims(x.dims());
-
-    phi::Copy(dev_ctx,
-              x.non_zero_crows(),
-              dev_ctx.GetPlace(),
-              false,
-              dx->mutable_non_zero_crows());
-    phi::Copy(dev_ctx,
-              x.non_zero_cols(),
-              dev_ctx.GetPlace(),
-              false,
-              dx->mutable_non_zero_cols());
-
-    DenseTensor* values = dx->mutable_non_zero_elements();
-    values->Resize(x.non_zero_elements().dims());
-    dev_ctx.template Alloc<T>(values);
+    // InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
 
     sparse_blas.SDDMM(
         false, true, static_cast<T>(1), dout, y, static_cast<T>(0), dx);
@@ -69,13 +55,13 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx,
 
     dev_ctx.template Alloc<T>(dy);
 
-    sparse_blas.DSDMM(
-        true, false, static_cast<T>(1), x, dout, static_cast<T>(0), dy);
+    sparse_blas.SPMM(
+        true, false, static_cast<T>(1), x, dout, static_cast<T>(0), dy);
   }
 #else
   PADDLE_THROW(phi::errors::Unimplemented(
-      " backward of 'sparse.mm' use cusparseSDDMM, Only "
-      "support it from CUDA 11.3"));
+      "backward of 'sparse.matmul' uses cusparseSDDMM, which is supported from "
+      "CUDA 11.3"));
 #endif
 }
 
@@ -97,7 +83,7 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx,
     meta_dx.set_dtype(x.dtype());
     dev_ctx.template Alloc<T>(dx);
 
-    sparse_blas.DSDMM(
-        false, true, static_cast<T>(1), dout, y, static_cast<T>(0), dx);
+    sparse_blas.SPMM(
+        false, true, static_cast<T>(1), dout, y, static_cast<T>(0), dx);
   }
 
@@ -109,7 +95,7 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx,
     std::swap(trans_dim_vec[rank - 1], trans_dim_vec[rank - 2]);
     DenseTensor trans_dy = phi::Empty(dev_ctx, trans_dim_vec);
 
-    sparse_blas.DSDMM(
-        true, false, static_cast<T>(1), dout, x, static_cast<T>(0), &trans_dy);
+    sparse_blas.SPMM(
+        true, false, static_cast<T>(1), dout, x, static_cast<T>(0), &trans_dy);
 
     // InferMeta of DenseTensor 'dy'
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
index e1192687ac..df5a4b5752 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/copy_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
 
 namespace phi {
 namespace sparse {
@@ -59,7 +60,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(xdim_vec[i],
                       ydim_vec[i],
                       phi::errors::InvalidArgument(
-                          "x.dim[%d] and x.dim[%d] must match.", i, i));
+                          "x.dim[%d] and y.dim[%d] must be equal.", i, i));
   }
 
   PADDLE_ENFORCE_GE(
@@ -80,11 +81,11 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
 
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<phi::GPUContext, T>(dev_ctx);
-  sparse_blas.DSDMM(
-      false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
+  sparse_blas.SPMM(
+      false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
 #else
   PADDLE_THROW(
-      phi::errors::Unimplemented(" forward of 'sparse.mm' use cusparseSpMM, "
-                                 "which is supported from CUDA 11.0"));
+      phi::errors::Unimplemented("forward of 'sparse.matmul' uses cusparseSpMM, "
+                                 "which is supported from CUDA 11.0"));
 #endif
 }
 
@@ -159,32 +160,16 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx,
           "The shape of Input(x) and Input(y) is not suitable for matmul "
           "opetation, mask_dim[-1] must be eaqual to y_dim[-1]."));
 
-  // InferMeta of SparseCsrTensor 'out'
-  out->set_dims(mask.dims());
-
-  phi::Copy(dev_ctx,
-            mask.non_zero_crows(),
-            dev_ctx.GetPlace(),
-            false,
-            out->mutable_non_zero_crows());
-  phi::Copy(dev_ctx,
-            mask.non_zero_cols(),
-            dev_ctx.GetPlace(),
-            false,
-            out->mutable_non_zero_cols());
-
-  DenseTensor* values = out->mutable_non_zero_elements();
-  values->Resize(mask.non_zero_elements().dims());
-  dev_ctx.template Alloc<T>(values);
+  // InferMeta of SparseCsrTensor 'out', CreateLikeInferMeta
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, mask, out);
 
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<phi::GPUContext, T>(dev_ctx);
   sparse_blas.SDDMM(
       false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
 #else
-  PADDLE_THROW(
-      phi::errors::Unimplemented(" forward of 'sparse.masked_mm' use "
-                                 "cusparseSDDMM, which is supported from "
-                                 "CUDA 11.3"));
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "forward of 'sparse.masked_matmul' uses cusparseSDDMM, which is supported "
+      "from CUDA 11.3"));
 #endif
 }
 
diff --git a/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu
index 0eb0acc3fc..14b9ec9a37 100644
--- a/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu
@@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
-#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
 
 namespace phi {
 namespace sparse {
@@ -27,13 +28,20 @@ __global__ void SoftmaxGradGpuKernel(const IntT* out_crows,
                                      const T* out_values,
                                      const T* dout_values,
                                      T* dx_values,
-                                     int row_number) {
+                                     int row_number,
+                                     int total_row_number) {
   // dx = (dout - sum(dout * out)) * out
   int row = blockIdx.x * blockDim.y + threadIdx.y;
   int non_zero_idx = threadIdx.x;
-  if (row >= row_number) return;
-  int row_first = static_cast<int>(out_crows[row]);
-  int row_nnz = static_cast<int>(out_crows[row + 1] - out_crows[row]);
+  if (row >= total_row_number) return;
+  int cur_batch = row / row_number;
+  int crow_idx = cur_batch * (row_number + 1) + (row % row_number);
+  int cur_batch_offset = 0;
+  for (int i = 1; i < cur_batch + 1; ++i) {
+    cur_batch_offset += out_crows[i * (row_number + 1) - 1];
+  }
+  int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
+  int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
   if (row_nnz == 0) return;
 
   int kIteration = (row_nnz + warpSize - 1) / warpSize;
@@ -70,12 +78,18 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
   EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);
 
   auto out_dim = out.dims();
+  auto out_rank = out_dim.size();
+
+  int total_row_number = 1;
   int row_number = 1;
-  for (int i = 0; i < out_dim.size() - 1; ++i) {
-    row_number *= out_dim[i];
+  for (int i = 0; i < out_rank - 1; ++i) {
+    total_row_number *= out_dim[i];
+    if (i == out_rank - 2) {
+      row_number = out_dim[i];
+    }
   }
 
-  dim3 grid((row_number + 3) / 4);
+  dim3 grid((total_row_number + 3) / 4);
   dim3 block(32, 4);
 
   PD_VISIT_INTEGRAL_TYPES(
@@ -85,7 +99,8 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
             out.non_zero_elements().data<T>(),
             dout.non_zero_elements().data<T>(),
             dx->mutable_non_zero_elements()->data<T>(),
-            row_number);
+            row_number,
+            total_row_number);
       }));
 }
 
diff --git a/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu b/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu
index 6353a46eea..9c9f5cfbca 100644
--- a/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/phi/kernels/sparse/softmax_kernel.h"
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"
@@ -19,7 +21,6 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/activation_functor.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
-#include "paddle/phi/kernels/sparse/softmax_kernel.h"
 
 namespace phi {
 namespace sparse {
@@ -28,13 +29,20 @@ template <typename T, typename IntT>
 __global__ void SoftmaxGpuKernel(const IntT* x_crows,
                                  const T* x_values,
                                  T* out_values,
-                                 int row_number) {
+                                 int row_number,
+                                 int total_row_number) {
   // out = exp(x-x_max) / sum(exp(x-x_max))
   int row = blockIdx.x * blockDim.y + threadIdx.y;
   int non_zero_idx = threadIdx.x;
-  if (row >= row_number) return;
-  int row_first = static_cast<int>(x_crows[row]);
-  int row_nnz = static_cast<int>(x_crows[row + 1] - x_crows[row]);
+  if (row >= total_row_number) return;
+  int cur_batch = row / row_number;
+  int crow_idx = cur_batch * (row_number + 1) + (row % row_number);
+  int cur_batch_offset = 0;
+  for (int i = 1; i < cur_batch + 1; ++i) {
+    cur_batch_offset += x_crows[i * (row_number + 1) - 1];
+  }
+  int row_first = cur_batch_offset + static_cast<int>(x_crows[crow_idx]);
+  int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
   if (row_nnz == 0) return;
 
   int kIteration = (row_nnz + warpSize - 1) / warpSize;
@@ -81,17 +89,20 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
                         "SparseCsrTensor only support axis=-1 for softmax, "
                         "which is faster when reading data by row (axis=-1)"));
   EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
-
   auto x_dim = x.dims();
+  auto x_rank = x_dim.size();
+
+  int total_row_number = 1;
   int row_number = 1;
-  for (int i = 0; i < x_dim.size() - 1; ++i) {
-    row_number *= x_dim[i];
+  for (int i = 0; i < x_rank - 1; ++i) {
+    total_row_number *= x_dim[i];
+    if (i == x_rank - 2) {
+      row_number = x_dim[i];
+    }
   }
 
-  dim3 grid((row_number + 3) / 4);
-  dim3 block(32, 4);
-  DenseTensor tmp_tensor =
-      phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());
+  dim3 grid((total_row_number + 3) / 4);
+  dim3 block(32, 4);
 
   PD_VISIT_INTEGRAL_TYPES(x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
                             SoftmaxGpuKernel<T, data_t>
@@ -99,7 +110,8 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
                                 x.non_zero_crows().data<data_t>(),
                                 x.non_zero_elements().data<T>(),
                                 out->mutable_non_zero_elements()->data<T>(),
-                                row_number);
+                                row_number,
+                                total_row_number);
                           }));
 }
 
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py b/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py
index 64087ed950..96adf959b2 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py
@@ -114,7 +114,77 @@ class TestCsrMaskedMatmul2D(unittest.TestCase):
         self.assertTrue(np.allclose(np_y_grad, y.grad.numpy()))
 
 
-#TODO(zhouwei25): support unit test of batch 'paddle.sparse.mm/masked_mm'
+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda() or get_cuda_version() < 11070,
+    "paddle is not compiled with CUDA or the CUDA version is below 11.7")
+class TestCsrDenseMatmul3D(unittest.TestCase):
+    # x: csr, y: dense, out: dense
+    def test_matmul(self):
+        with _test_eager_guard():
+            paddle.set_default_dtype('float32')
+            origin_x = paddle.rand([16, 16, 12])
+            mask = paddle.randint(0, 2, [16, 12])
+            origin_x = origin_x * mask
+            origin_y = paddle.rand([16, 12, 10])
+
+            dense_x = origin_x.detach()
+            dense_x.stop_gradient = False
+            dense_y = origin_y.detach()
+            dense_y.stop_gradient = False
+            dense_out = paddle.matmul(dense_x, dense_y)
+            dense_out.backward()
+
+            sp_x = origin_x.detach().to_sparse_csr()
+            sp_x.stop_gradient = False
+            sp_y = origin_y.detach()
+            sp_y.stop_gradient = False
+            sp_out = paddle.incubate.sparse.matmul(sp_x, sp_y)
+            sp_out.backward()
+
+            self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy()))
+            self.assertTrue(
+                np.allclose(sp_x.grad.to_dense().numpy(),
+                            (dense_x.grad * mask).numpy()))
+            self.assertTrue(np.allclose(sp_y.grad.numpy(),
+                                        dense_y.grad.numpy()))
+
+
+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda() or get_cuda_version() < 11070,
+    "paddle is not compiled with CUDA or the CUDA version is below 11.7")
+class TestCsrMaskedMatmul3D(unittest.TestCase):
+    # x: dense, y: dense, out: csr
+    def test_matmul(self):
+        with _test_eager_guard():
+            paddle.set_default_dtype('float64')
+            origin_x = paddle.rand([16, 16, 12])
+            origin_y = paddle.rand([16, 12, 10])
+
+            mask = paddle.randint(0, 2, [16, 10])
+
+            dense_x = origin_x.detach()
+            dense_x.stop_gradient = False
+            dense_y = origin_y.detach()
+            dense_y.stop_gradient = False
+            dense_out = paddle.matmul(dense_x, dense_y)
+            dense_out = dense_out * mask
+            dense_out.backward()
+
+            sp_x = origin_x.detach()
+            sp_x.stop_gradient = False
+            sp_y = origin_y.detach()
+            sp_y.stop_gradient = False
+            sp_out = paddle.incubate.sparse.masked_matmul(
+                sp_x, sp_y, dense_out.to_sparse_csr())
+            sp_out.backward()
+
+            self.assertTrue(
+                np.allclose(sp_out.to_dense().numpy(), dense_out.numpy()))
+            self.assertTrue(np.allclose(sp_x.grad.numpy(),
+                                        dense_x.grad.numpy()))
+            self.assertTrue(np.allclose(sp_y.grad.numpy(),
+                                        dense_y.grad.numpy()))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_softmax.py b/python/paddle/fluid/tests/unittests/test_sparse_softmax_op.py
similarity index 53%
rename from python/paddle/fluid/tests/unittests/test_sparse_softmax.py
rename to python/paddle/fluid/tests/unittests/test_sparse_softmax_op.py
index 5f33a9e474..b1026b080c 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_softmax.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_softmax_op.py
@@ -28,10 +28,10 @@ np.random.seed(2022)
 
 class TestCsrSoftmax(unittest.TestCase):
 
-    def test_softmax(self):
+    def test_softmax2d(self):
         with _test_eager_guard():
-            mask = np.random.rand(1, 5) < 0.5
-            np_x = np.random.rand(1, 5) * mask
+            mask = np.random.rand(16, 128) < 0.5
+            np_x = np.random.rand(16, 128) * mask
             np_csr = sp.csr_matrix(np_x)
 
             row_number = np_csr.shape[0]
@@ -56,6 +56,7 @@ class TestCsrSoftmax(unittest.TestCase):
 
             # dx = (dout - sum(dout * out)) * out, dout=rand_x
             out.backward(csr.detach())
+            dx = np.array([])
             for i in range(row_number):
                 start = np_csr.indptr[i]
                 end = np_csr.indptr[i + 1]
@@ -64,7 +65,7 @@ class TestCsrSoftmax(unittest.TestCase):
                 out = np_out[start:end]
                 dout = np_csr.data[start:end]
                 sum = np.sum(dout * out, keepdims=True)
-                dx = (dout - sum) * out
+                dx = np.concatenate([dx, (dout - sum) * out])
 
             self.assertTrue(np.allclose(csr.grad.crows().numpy(),
                                         np_csr.indptr))
@@ -72,6 +73,55 @@ class TestCsrSoftmax(unittest.TestCase):
                                         np_csr.indices))
             self.assertTrue(np.allclose(csr.grad.values().numpy(), dx))
 
+    def test_softmax3d(self):
+        with _test_eager_guard():
+            batchNum = 16
+            mask = np.random.rand(batchNum, 16, 128) < 0.5
+            np_x = np.random.rand(batchNum, 16, 128) * mask
+
+            np_out_list = []
+            np_out = np.array([])
+            for i in range(batchNum):
+                np_csr = sp.csr_matrix(np_x[i, :, :])
+                row_number = np_csr.shape[0]
+                for j in range(row_number, ):
+                    start = np_csr.indptr[j]
+                    end = np_csr.indptr[j + 1]
+                    if start == end:
+                        continue
+                    x = np_csr.data[start:end]
+                    x_max = np.max(x, keepdims=True)
+                    x_exp = np.exp(x - x_max)
+                    x_exp_sum = np.sum(x_exp, keepdims=True)
+                    np_out_list.append(x_exp / x_exp_sum)
+                    np_out = np.concatenate([np_out, x_exp / x_exp_sum])
+
+            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
+            m = paddle.incubate.sparse.nn.Softmax()
+            out = m(csr)
+            self.assertTrue(np.allclose(out.values().numpy(), np_out))
+
+            # dx = (dout - sum(dout * out)) * out, dout=rand_x
+            out.backward(csr.detach())
+            dx = np.array([])
+            batch_offset = 0
+            for i in range(batchNum):
+                np_csr = sp.csr_matrix(np_x[i, :, :])
+                row_number = np_csr.shape[0]
+                for j in range(row_number):
+                    start = np_csr.indptr[j]
+                    end = np_csr.indptr[j + 1]
+                    if start == end:
+                        continue
+                    dout = np_csr.data[start:end]
+                    out = np_out[batch_offset + start:batch_offset + end]
+                    sum = np.sum(dout * out, keepdims=True)
+                    dx = np.concatenate([dx, (dout - sum) * out])
+
+                batch_offset += np_csr.nnz
+
+            self.assertTrue(np.allclose(csr.grad.values().numpy(), dx))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/incubate/sparse/nn/functional/activation.py b/python/paddle/incubate/sparse/nn/functional/activation.py
index 3d4e77010b..12d44063e0 100644
--- a/python/paddle/incubate/sparse/nn/functional/activation.py
+++ b/python/paddle/incubate/sparse/nn/functional/activation.py
@@ -55,7 +55,7 @@ def softmax(x, axis=-1, name=None):
     sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor.
 
     Note:
-        Only supported axis=-1 for SparseCsrTensor, which is faster when read data
+        Only supports axis=-1 for SparseCsrTensor, which is faster when reading data
         by row (axis=-1).
 
     From the point of view of dense matrix, for each row :math:`i` and each column :math:`j`
diff --git a/python/paddle/incubate/sparse/nn/layer/activation.py b/python/paddle/incubate/sparse/nn/layer/activation.py
index 011fef90c0..9aec20603a 100644
--- a/python/paddle/incubate/sparse/nn/layer/activation.py
+++ b/python/paddle/incubate/sparse/nn/layer/activation.py
@@ -66,7 +66,7 @@ class Softmax(Layer):
     sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor.
 
    Note:
-        Only supported axis=-1 for SparseCsrTensor, which is faster when read data
+        Only supports axis=-1 for SparseCsrTensor, which is faster when reading data
         by row (axis=-1).
 
     From the point of view of dense matrix, for each row :math:`i` and each column :math:`j`
-- 
GitLab
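
For quick reference, the sketch below illustrates the batched behaviour this patch enables; it is illustrative only (not part of the patch itself), mirrors the new 3-D unit tests above, and assumes a Paddle build with CUDA >= 11.7 and the paddle.incubate.sparse API paths as of this PR.

    import paddle

    # Batched "CSR @ dense -> dense": a [batch, M, K] SparseCsrTensor
    # multiplied with a [batch, K, N] dense tensor (same pattern as
    # TestCsrDenseMatmul3D above).
    x = paddle.rand([16, 16, 12])
    mask = paddle.randint(0, 2, [16, 12]).astype(x.dtype)
    sp_x = (x * mask).to_sparse_csr()          # 3-D sparse CSR [16, 16, 12]
    y = paddle.rand([16, 12, 10])

    out = paddle.incubate.sparse.matmul(sp_x, y)   # dense [16, 16, 10]

    # Batched softmax over the last axis of the 3-D SparseCsrTensor,
    # as exercised by test_softmax3d above.
    softmax = paddle.incubate.sparse.nn.Softmax()
    sp_out = softmax(sp_x)                         # sparse CSR [16, 16, 12]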