From 281ea2f462a5e146cb7b570cbe00b39bab90da36 Mon Sep 17 00:00:00 2001
From: umiswing
Date: Fri, 14 Apr 2023 15:19:35 +0800
Subject: [PATCH] [Dcu]: Add rocsparse_spmm for dcu. (#52200)

---
 paddle/fluid/platform/dynload/rocsparse.cc    |  37 ++
 paddle/fluid/platform/dynload/rocsparse.h     |  75 ++++
 paddle/phi/backends/dynload/CMakeLists.txt    |   9 +-
 paddle/phi/backends/dynload/dynamic_loader.cc |   2 +
 paddle/phi/backends/dynload/rocsparse.cc      |  37 ++
 paddle/phi/backends/dynload/rocsparse.h       |  86 ++++
 paddle/phi/backends/gpu/gpu_resources.cc      |  12 +
 paddle/phi/kernels/funcs/sparse/sparse_blas.h |   3 +
 .../funcs/sparse/sparse_blas_impl.hip.h       | 405 ++++++++++++++++++
 .../kernels/sparse/gpu/matmul_grad_kernel.cu  |  39 +-
 .../phi/kernels/sparse/gpu/matmul_kernel.cu   |  14 +-
 .../kernels/sparse/gpu/sparse_utils_kernel.cu |  72 +++-
 12 files changed, 772 insertions(+), 19 deletions(-)
 create mode 100644 paddle/fluid/platform/dynload/rocsparse.cc
 create mode 100644 paddle/fluid/platform/dynload/rocsparse.h
 create mode 100644 paddle/phi/backends/dynload/rocsparse.cc
 create mode 100644 paddle/phi/backends/dynload/rocsparse.h
 create mode 100644 paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h

diff --git a/paddle/fluid/platform/dynload/rocsparse.cc b/paddle/fluid/platform/dynload/rocsparse.cc
new file mode 100644
index 00000000000..8951acaaffa
--- /dev/null
+++ b/paddle/fluid/platform/dynload/rocsparse.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/dynload/rocsparse.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+#ifdef ROCSPARSE_ROUTINE_EACH
+ROCSPARSE_ROUTINE_EACH(DEFINE_WRAP);
+#endif
+
+#ifdef ROCSPARSE_ROUTINE_EACH_R2
+ROCSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
+#endif
+
+#ifdef ROCSPARSE_ROUTINE_EACH_R3
+ROCSPARSE_ROUTINE_EACH_R3(DEFINE_WRAP);
+#endif
+
+} // namespace dynload
+} // namespace platform
+} // namespace paddle
diff --git a/paddle/fluid/platform/dynload/rocsparse.h b/paddle/fluid/platform/dynload/rocsparse.h
new file mode 100644
index 00000000000..1f3e041349c
--- /dev/null
+++ b/paddle/fluid/platform/dynload/rocsparse.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
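
(Note on the header below: the wrappers it declares simply re-export the phi::dynload functors added later in this patch. Each functor lazily dlopen()s librocsparse.so on first call, resolves the routine with dlsym, and forwards every invocation through the cached pointer. The standalone sketch that follows reduces that pattern to one hand-written wrapper; the sketch namespace, the simplified void** signature, and the dlopen flags are illustrative assumptions, not code taken from the patch.)

// Illustrative sketch of the lazy-loading pattern that the
// DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP macro expands to, written by hand for
// a single routine. Compile with -ldl; librocsparse.so must be present at
// run time.
#include <dlfcn.h>

#include <mutex>

namespace dynload_sketch {

std::once_flag rocsparse_dso_flag;
void* rocsparse_dso_handle = nullptr;

// First call opens librocsparse.so exactly once, resolves the symbol, caches
// the function pointer, and forwards the arguments; later calls reuse it.
int rocsparse_create_handle_wrapper(void** handle) {
  std::call_once(rocsparse_dso_flag, [] {
    rocsparse_dso_handle = dlopen("librocsparse.so", RTLD_LAZY | RTLD_LOCAL);
  });
  using func_t = int (*)(void**);  // simplified signature for the sketch
  static auto func = reinterpret_cast<func_t>(
      dlsym(rocsparse_dso_handle, "rocsparse_create_handle"));
  return func(handle);  // 0 corresponds to rocsparse_status_success
}

}  // namespace dynload_sketch
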
+ +#pragma once + +#include +#include + +#include // NOLINT +#include + +#include "paddle/phi/backends/dynload/rocsparse.h" + +namespace paddle { +namespace platform { +namespace dynload { + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load rocsparse routine + * via operator overloading. + * + * note: default dynamic linked libs + */ +#define PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP(__name) \ + using DynLoad__##__name = phi::dynload::DynLoad__##__name; \ + extern DynLoad__##__name __name + +#if defined(PADDLE_WITH_HIP) +#define ROCSPARSE_ROUTINE_EACH(__macro) \ + __macro(rocsparse_create_handle); \ + __macro(rocsparse_destroy_handle); \ + __macro(rocsparse_set_stream); \ + __macro(rocsparse_csr2coo); + +ROCSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP) + +#if HIP_VERSION >= 402 +#define ROCSPARSE_ROUTINE_EACH_R2(__macro) \ + __macro(rocsparse_create_coo_descr); \ + __macro(rocsparse_create_csr_descr); \ + __macro(rocsparse_destroy_spmat_descr); \ + __macro(rocsparse_create_dnmat_descr); \ + __macro(rocsparse_destroy_dnmat_descr); \ + __macro(rocsparse_spmm); + +ROCSPARSE_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP) +#endif + +#if HIP_VERSION >= 403 +#define ROCSPARSE_ROUTINE_EACH_R3(__macro) \ + __macro(rocsparse_sddmm_buffer_size); \ + __macro(rocsparse_sddmm_preprocess); \ + __macro(rocsparse_sddmm); + +ROCSPARSE_ROUTINE_EACH_R3(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP) +#endif + +#endif // PADDLE_WITH_HIP + +#undef PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/phi/backends/dynload/CMakeLists.txt b/paddle/phi/backends/dynload/CMakeLists.txt index 5225d746f29..a96af96adac 100644 --- a/paddle/phi/backends/dynload/CMakeLists.txt +++ b/paddle/phi/backends/dynload/CMakeLists.txt @@ -20,7 +20,14 @@ if(NOT WITH_NV_JETSON) endif() if(WITH_ROCM) - list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc hipfft.cc) + list( + APPEND + HIP_SRCS + rocblas.cc + miopen.cc + hiprand.cc + hipfft.cc + rocsparse.cc) endif() # There is no macOS version of NCCL. diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index fc32e6fe35c..4e95c2adde5 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -427,6 +427,8 @@ void* GetCusparseDsoHandle() { #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) return GetDsoHandleFromSearchPath( FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path}); +#elif defined(PADDLE_WITH_HIP) + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocsparse.so"); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusparse.so"); #endif diff --git a/paddle/phi/backends/dynload/rocsparse.cc b/paddle/phi/backends/dynload/rocsparse.cc new file mode 100644 index 00000000000..dab0d149dd7 --- /dev/null +++ b/paddle/phi/backends/dynload/rocsparse.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/dynload/rocsparse.h" + +namespace phi { +namespace dynload { +std::once_flag rocsparse_dso_flag; +void *rocsparse_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +#ifdef ROCSPARSE_ROUTINE_EACH +ROCSPARSE_ROUTINE_EACH(DEFINE_WRAP) +#endif + +#ifdef ROCSPARSE_ROUTINE_EACH_R2 +ROCSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP); +#endif + +#ifdef ROCSPARSE_ROUTINE_EACH_R3 +ROCSPARSE_ROUTINE_EACH_R3(DEFINE_WRAP); +#endif + +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/backends/dynload/rocsparse.h b/paddle/phi/backends/dynload/rocsparse.h new file mode 100644 index 00000000000..423bb8e1c5a --- /dev/null +++ b/paddle/phi/backends/dynload/rocsparse.h @@ -0,0 +1,86 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include // NOLINT +#include + +#include "paddle/phi/backends/dynload/dynamic_loader.h" +#include "paddle/phi/backends/dynload/port.h" + +namespace phi { +namespace dynload { +extern std::once_flag rocsparse_dso_flag; +extern void *rocsparse_dso_handle; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load rocsparse routine + * via operator overloading. + * + * note: default dynamic linked libs + */ +#define DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + rocsparse_status operator()(Args... 
args) { \ + using rocsparse_func = decltype(&::__name); \ + std::call_once(rocsparse_dso_flag, []() { \ + rocsparse_dso_handle = phi::dynload::GetCusparseDsoHandle(); \ + }); \ + static void *p_##__name = dlsym(rocsparse_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#if defined(PADDLE_WITH_HIP) +#define ROCSPARSE_ROUTINE_EACH(__macro) \ + __macro(rocsparse_create_handle); \ + __macro(rocsparse_destroy_handle); \ + __macro(rocsparse_set_stream); \ + __macro(rocsparse_csr2coo); + +ROCSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP) + +#if HIP_VERSION >= 402 +#define ROCSPARSE_ROUTINE_EACH_R2(__macro) \ + __macro(rocsparse_create_coo_descr); \ + __macro(rocsparse_create_csr_descr); \ + __macro(rocsparse_destroy_spmat_descr); \ + __macro(rocsparse_create_dnmat_descr); \ + __macro(rocsparse_destroy_dnmat_descr); \ + __macro(rocsparse_spmm); + +ROCSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP) +#endif + +#if HIP_VERSION >= 403 +#define ROCSPARSE_ROUTINE_EACH_R3(__macro) \ + __macro(rocsparse_sddmm_buffer_size); \ + __macro(rocsparse_sddmm_preprocess); \ + __macro(rocsparse_sddmm); + +ROCSPARSE_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP) +#endif + +#endif // PADDLE_WITH_HIP + +#undef DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 622891c93bb..06dc74f2d27 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -33,6 +33,10 @@ #endif // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_HIP +#include "paddle/phi/backends/dynload/rocsparse.h" +#endif + #include "glog/logging.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -295,6 +299,9 @@ void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) { PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle)); PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream)); #endif +#elif defined(PADDLE_WITH_HIP) + phi::dynload::rocsparse_create_handle(handle); + phi::dynload::rocsparse_set_stream(*handle, stream); #endif } @@ -306,6 +313,11 @@ void DestroySparseHandle(sparseHandle_t handle) { handle = nullptr; } #endif +#elif defined(PADDLE_WITH_HIP) + if (handle != nullptr) { + phi::dynload::rocsparse_destroy_handle(handle); + handle = nullptr; + } #endif } diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas.h b/paddle/phi/kernels/funcs/sparse/sparse_blas.h index 858620d0d41..f6d67488d1f 100644 --- a/paddle/phi/kernels/funcs/sparse/sparse_blas.h +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas.h @@ -97,3 +97,6 @@ inline SparseBlasT GetSparseBlas( #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 #include "paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h" #endif +#if defined(PADDLE_WITH_HIP) && HIP_VERSION >= 402 +#include "paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h" +#endif diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h new file mode 100644 index 00000000000..cbd42be3cb6 --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h @@ -0,0 +1,405 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/dynload/rocsparse.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/core/visit_type.h" +namespace phi { +namespace funcs { +namespace sparse { + +template +rocsparse_indextype GetGpuIndexType() { + if (std::is_same::value) { + return rocsparse_indextype_i32; + } else if (std::is_same::value) { + return rocsparse_indextype_i64; + } +} + +template +rocsparse_datatype GetGpuDataType() { + if (std::is_same::value) { + return rocsparse_datatype_f32_r; + } else if (std::is_same::value) { + return rocsparse_datatype_f64_r; + } +} + +inline rocsparse_operation GetTransposeOperation(const bool trans) { + if (trans) { + return rocsparse_operation_transpose; + } else { + return rocsparse_operation_none; + } +} + +template +inline rocsparse_spmm_alg GetSpMMAlgorithm(const TensorType& x) { + return rocsparse_spmm_alg_default; +} + +/************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/ +template +inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x, + const phi::GPUContext& dev_ctx, + rocsparse_spmat_descr* descriptor) { + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + PADDLE_ENFORCE_GE( + x_ndims, + 2, + phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be " + "greater than or eaqual to 2.")); + int64_t M = xdim_vec[x_ndims - 2]; + int64_t N = xdim_vec[x_ndims - 1]; + int batch_size = 1; + for (int i = 0; i < x_ndims - 2; i++) { + batch_size *= xdim_vec[i]; + } + PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(), + batch_size * (M + 1), + phi::errors::PreconditionNotMet( + "the length of SparseCsrTensor crows is not right.")); + + const IntT* crows_data = x.non_zero_crows().data(); + const IntT* cols_data = x.non_zero_cols().data(); + const T* values_data = x.non_zero_elements().data(); + + int64_t batch_nnz = x.nnz() / batch_size; + rocsparse_indextype itype = GetGpuIndexType(); + rocsparse_indextype jtype = GetGpuIndexType(); + rocsparse_datatype ttype = GetGpuDataType(); + dev_ctx.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_create_csr_descr(descriptor, + M, + N, + batch_nnz, + const_cast(crows_data), + const_cast(cols_data), + const_cast(values_data), + itype, + jtype, + rocsparse_index_base_zero, + ttype); + }); + if (batch_size > 1) { + // TODO(umiswing): Add batch sparse matmul support for ROCM after 5.2.0 + PADDLE_THROW(phi::errors::Unimplemented( + "Batch Sparse matmul use 'rocsparse_coo_set_strided_batch', which is " + "supported from ROCM 5.2.0")); + } +} + +template +inline void CreateCooDescriptor(const phi::SparseCooTensor& x, + const phi::GPUContext& dev_ctx, + rocsparse_spmat_descr* descriptor) { + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + PADDLE_ENFORCE_GE( + x_ndims, 
+ 2, + phi::errors::InvalidArgument("the dim size of SparseCooTensor must be " + "greater than or eaqual to 2.")); + + int64_t M = xdim_vec[x_ndims - 2]; + int64_t N = xdim_vec[x_ndims - 1]; + int batch_size = 1; + for (int i = 0; i < x_ndims - 2; i++) { + batch_size *= xdim_vec[i]; + } + int64_t nnz = x.nnz(); + + const IntT* indices_data = x.non_zero_indices().data(); + const T* values_data = x.non_zero_elements().data(); + auto rows_data = indices_data + (x_ndims - 2) * nnz; + auto cols_data = indices_data + (x_ndims - 1) * nnz; + + int64_t batch_nnz = nnz / batch_size; + rocsparse_indextype itype = GetGpuIndexType(); + rocsparse_datatype ttype = GetGpuDataType(); + dev_ctx.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_create_coo_descr(descriptor, + M, + N, + batch_nnz, + const_cast(rows_data), + const_cast(cols_data), + const_cast(values_data), + itype, + rocsparse_index_base_zero, + ttype); + }); + + if (batch_size > 1) { + // TODO(umiswing): Add batch sparse matmul support for ROCM after 5.2.0 + PADDLE_THROW(phi::errors::Unimplemented( + "Batch Sparse matmul use 'rocsparse_coo_set_strided_batch', which is " + "supported from ROCM 5.2.0")); + } +} + +template +class RocSparseSpMatDescriptor { + public: + explicit RocSparseSpMatDescriptor(const phi::SparseCsrTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.non_zero_crows().dtype(), "Csr RocSparseSpMatDescriptor", ([&] { + CreateCsrDescriptor(x, dev_ctx_, &descriptor_); + })); + VLOG(6) << "Create csr rocsparse_spmat_descr " << &descriptor_; + } + explicit RocSparseSpMatDescriptor(const phi::SparseCooTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Coo RocSparseSpMatDescriptor", ([&] { + CreateCooDescriptor(x, dev_ctx_, &descriptor_); + })); + VLOG(6) << "Create coo rocsparse_spmat_descr " << &descriptor_; + } + + ~RocSparseSpMatDescriptor() { + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_destroy_spmat_descr(descriptor_); + }); + VLOG(6) << "Destroy roscparse_spmat_descr " << &descriptor_; + } + + const rocsparse_spmat_descr& descriptor() const { return descriptor_; } + + private: + const phi::GPUContext& dev_ctx_; + rocsparse_spmat_descr descriptor_; +}; + +/************* DENSE MATRIX DESCRIPTOR ************/ +template +class RocSparseDnMatDescriptor { + public: + explicit RocSparseDnMatDescriptor(const phi::DenseTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + PADDLE_ENFORCE_GE( + x_ndims, + 2, + phi::errors::InvalidArgument("the dim size of DenseTensor must be " + "greater than or eaqual to 2.")); + + int64_t M = xdim_vec[x_ndims - 2]; + int64_t N = xdim_vec[x_ndims - 1]; + int batch_size = 1; + for (int i = 0; i < x_ndims - 2; i++) { + batch_size *= xdim_vec[i]; + } + + const T* x_data = x.data(); + rocsparse_datatype ttype = GetGpuDataType(); + dev_ctx.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_create_dnmat_descr(&descriptor_, + M, + N, + N, + const_cast(x_data), + ttype, + rocsparse_order_row); + }); + + PADDLE_ENFORCE_EQ( + x.numel(), + batch_size * M * N, + phi::errors::InvalidArgument("The number of elements in DenseTensor " + "must equals to batch_size * M * N.")); + if (batch_size > 1) { + // TODO(umiswing): Add batch sparse matmul support for ROCM after 5.2.0 + 
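// For reference: on ROCm >= 5.2, batch support would presumably call something
// like rocsparse_dnmat_set_strided_batch(descriptor_, batch_size, M * N)
// (a hypothetical usage, not part of this patch) right after the descriptor is
// created, with each dense batch stored contiguously at a stride of M * N
// elements; until then the batched case falls through to the error below.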
PADDLE_THROW(phi::errors::Unimplemented( + "Batch Sparse matmul use 'rocsparse_dnmat_set_strided_batch', which " + "is supported from ROCM 5.2.0")); + } + VLOG(6) << "Create cusparseDnMatDescr_t " << &descriptor_; + } + + ~RocSparseDnMatDescriptor() { + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_destroy_dnmat_descr(descriptor_); + }); + VLOG(6) << "Destroy rocsparse_dnmat_descr " << &descriptor_; + } + + const rocsparse_dnmat_descr& descriptor() const { return descriptor_; } + + private: + const phi::GPUContext& dev_ctx_; + rocsparse_dnmat_descr descriptor_; +}; + +/************* SPARSE*DENSE->DENSE MATMUL ************/ +template <> +template +void SparseBlas::SPMM(bool transa, + bool transb, + T alpha, + const TensorType& mat_a, + const phi::DenseTensor& mat_b, + T beta, + phi::DenseTensor* mat_out) const { + auto a_descriptor = RocSparseSpMatDescriptor(mat_a, dev_ctx_); + auto b_descriptor = RocSparseDnMatDescriptor(mat_b, dev_ctx_); + auto out_descriptor = RocSparseDnMatDescriptor(*mat_out, dev_ctx_); + + rocsparse_datatype ttype = GetGpuDataType(); + size_t buffer_size = 0; + + // Query SpMM buffer + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_spmm(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + ttype, + GetSpMMAlgorithm(mat_a), + rocsparse_spmm_stage_buffer_size, + &buffer_size, + nullptr); + }); + + // Allocate buffer + phi::Allocator::AllocationPtr tmp_buffer = phi::memory_utils::Alloc( + dev_ctx_.GetPlace(), + buffer_size, + phi::Stream(reinterpret_cast(dev_ctx_.stream()))); + void* tmp_buffer_ptr = tmp_buffer->ptr(); + + // Preprocess data + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_spmm(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + ttype, + GetSpMMAlgorithm(mat_a), + rocsparse_spmm_stage_preprocess, + &buffer_size, + tmp_buffer_ptr); + }); + + // Performs the actual SpMM computation + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_spmm(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + ttype, + GetSpMMAlgorithm(mat_a), + rocsparse_spmm_stage_compute, + &buffer_size, + tmp_buffer_ptr); + }); +} + +/************* DENSE*DENSE->SPARSE MATMUL ************/ +#if HIP_VERSION >= 403 +template <> +template +void SparseBlas::SDDMM(bool transa, + bool transb, + T alpha, + const phi::DenseTensor& mat_a, + const phi::DenseTensor& mat_b, + T beta, + TensorType* mat_out) const { + auto a_descriptor = RocSparseDnMatDescriptor(mat_a, dev_ctx_); + auto b_descriptor = RocSparseDnMatDescriptor(mat_b, dev_ctx_); + auto out_descriptor = RocSparseSpMatDescriptor(*mat_out, dev_ctx_); + + rocsparse_datatype gpu_type = GetGpuDataType(); + size_t buffer_size = 0; + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_sddmm_buffer_size(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + gpu_type, + rocsparse_sddmm_alg_default, + &buffer_size); + }); + + phi::Allocator::AllocationPtr tmp_buffer = 
phi::memory_utils::Alloc( + dev_ctx_.GetPlace(), + buffer_size, + phi::Stream(reinterpret_cast(dev_ctx_.stream()))); + void* tmp_buffer_ptr = tmp_buffer->ptr(); + + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_sddmm_preprocess(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + gpu_type, + rocsparse_sddmm_alg_default, + tmp_buffer_ptr); + }); + + dev_ctx_.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_sddmm(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + out_descriptor.descriptor(), + gpu_type, + rocsparse_sddmm_alg_default, + tmp_buffer_ptr); + }); +} +#endif +} // namespace sparse +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu index 05eb6a90cb4..7dbdbe2acc9 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function_impl.h" #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" @@ -35,7 +36,7 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx, const DenseTensor& dout, SparseCooTensor* dx, DenseTensor* dy) { -#if CUDA_VERSION >= 11030 +#if CUDA_VERSION >= 11030 || HIP_VERSION >= 403 auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); // dx{SparseCoo} = dout{Dense} * y'{Dense} @@ -44,8 +45,13 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx, // which will increase some expenses. 
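// Note: rocsparse_sddmm (like cusparseSDDMM) computes
//   C := alpha * (op(A) * op(B)) * spy(C) + beta * C,
// where spy(C) masks the dense product with C's sparsity pattern, so the call
// also reads the existing values of the sparse output. The HIP-only zero-fill
// of dx_csr's values added below presumably keeps the beta * C term from
// picking up uninitialized memory; the CUDA path does not need it.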
EmptyLikeCooKernel(dev_ctx, x, dx); SparseCsrTensor dx_csr = CooToCsr(dev_ctx, *dx); +#ifdef PADDLE_WITH_HIP + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dx_csr.mutable_non_zero_elements(), static_cast(0.0f)); +#endif sparse_blas.SDDMM( false, true, static_cast(1), dout, y, static_cast(0), &dx_csr); + CsrToCooKernel(dev_ctx, dx_csr, dx); } @@ -56,13 +62,29 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx, meta_dy.set_dtype(y.dtype()); dev_ctx.template Alloc(dy); +#ifdef PADDLE_WITH_HIP + SparseCsrTensor x_csr = CooToCsr(dev_ctx, x); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dy, static_cast(0.0f)); + sparse_blas.SPMM( + true, false, static_cast(1), x_csr, dout, static_cast(0), dy); +#elif defined(PADDLE_WITH_CUDA) sparse_blas.SPMM( true, false, static_cast(1), x, dout, static_cast(0), dy); +#endif } #else +#ifdef PADDLE_WITH_CUDA PADDLE_THROW(phi::errors::Unimplemented( "backward of 'sparse.matmul' use cusparseSDDMM, which is supported from " "CUDA 11.3")); +#elif defined(PADDLE_WITH_HIP) + PADDLE_THROW( + phi::errors::Unimplemented("backward of 'sparse.matmul' use " + "rocsparse_sddmm with transpose, which is " + "supported from " + "ROCM 4.3.0")); +#endif #endif } @@ -73,7 +95,7 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx, const DenseTensor& dout, SparseCsrTensor* dx, DenseTensor* dy) { -#if CUDA_VERSION >= 11030 +#if CUDA_VERSION >= 11030 || HIP_VERSION >= 403 auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); // dx{SparseCsr} = dout{Dense} * y'{Dense} @@ -94,13 +116,26 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx, dev_ctx.template Alloc(dy); +#ifdef PADDLE_WITH_HIP + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dy, static_cast(0.0f)); +#endif + sparse_blas.SPMM( true, false, static_cast(1), x, dout, static_cast(0), dy); } #else +#ifdef PADDLE_WITH_CUDA PADDLE_THROW(phi::errors::Unimplemented( "backward of 'sparse.matmul' use cusparseSDDMM, which is supported from " "CUDA 11.3")); +#elif defined(PADDLE_WITH_HIP) + PADDLE_THROW( + phi::errors::Unimplemented("backward of 'sparse.matmul' use " + "rocsparse_sddmm with transpose, which is " + "supported from " + "ROCM 4.3.0")); +#endif #endif } diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu index 3adbce0dd17..f39209e9b86 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu @@ -25,6 +25,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function_impl.h" #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" @@ -36,7 +37,7 @@ void MatmulKernelImpl(const Context& dev_ctx, const TensorType& x, const DenseTensor& y, DenseTensor* out) { -#if CUDA_VERSION >= 11000 +#if CUDA_VERSION >= 11000 || HIP_VERSION >= 402 std::vector xdim_vec = phi::vectorize(x.dims()); std::vector ydim_vec = phi::vectorize(y.dims()); auto x_ndims = xdim_vec.size(); @@ -80,13 +81,24 @@ void MatmulKernelImpl(const Context& dev_ctx, dev_ctx.template Alloc(out); +#ifdef PADDLE_WITH_HIP + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out, static_cast(0.0f)); +#endif + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); sparse_blas.SPMM( false, false, static_cast(1), x, y, static_cast(0), out); #else +#ifdef PADDLE_WITH_CUDA PADDLE_THROW( phi::errors::Unimplemented("forward of 'sparse.matmul' use cusparseSpMM, " "which is supported from CUDA 11.0")); +#elif defined(PADDLE_WITH_HIP) + PADDLE_THROW(phi::errors::Unimplemented( + "forward of 'sparse.matmul' use rocsparse_spmm, " + "which is supported from ROCM 4.2.0")); +#endif #endif } diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index 2f86a643aa5..94fe0570563 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -17,12 +17,16 @@ limitations under the License. */ #include #include +#ifdef PADDLE_WITH_HIP +#include "paddle/phi/backends/dynload/rocsparse.h" +#endif #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/common_shape.h" @@ -214,55 +218,88 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx, SparseCooTensor* out) { const DDim& x_dims = x.dims(); const int64_t non_zero_num = x.cols().numel(); + +// rocsparse_csr2coo only support index with type 'rocsparse_int' (aka 'int') +// now +#ifdef PADDLE_WITH_HIP + const auto& csr_crows = Cast(dev_ctx, x.crows(), DataType::INT32); + const auto& csr_cols = Cast(dev_ctx, x.cols(), DataType::INT32); + const int* csr_crows_data = csr_crows.template data(); + const int* csr_cols_data = csr_cols.template data(); +#else const auto& csr_crows = x.crows(); const auto& csr_cols = x.cols(); - const auto& csr_values = x.values(); const IntT* csr_crows_data = csr_crows.data(); const IntT* csr_cols_data = csr_cols.data(); +#endif + const auto& csr_values = x.values(); const T* csr_values_data = csr_values.data(); int64_t sparse_dim = 2; if (x_dims.size() == 3) { sparse_dim = 3; } - int batchs = x_dims.size() == 2 ? 1 : x_dims[0]; + int batches = x_dims.size() == 2 ? 1 : x_dims[0]; int rows = x_dims.size() == 2 ? 
x_dims[0] : x_dims[1]; +#ifdef PADDLE_WITH_HIP + DenseTensor indices = phi::Empty(dev_ctx, {sparse_dim, non_zero_num}); + int* coo_indices = indices.data(); + int* coo_rows_data = coo_indices; + int* coo_cols_data = coo_rows_data + non_zero_num; +#else DenseTensor indices = phi::Empty(dev_ctx, {sparse_dim, non_zero_num}); - DenseTensor values = phi::EmptyLike(dev_ctx, csr_values); - DenseTensor offsets = phi::Empty(dev_ctx, {batchs}); + DenseTensor offsets = phi::Empty(dev_ctx, {batches}); IntT* coo_indices = indices.data(); IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices; IntT* coo_rows_data = x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num; IntT* coo_cols_data = coo_rows_data + non_zero_num; - IntT* offsets_ptr = batchs == 1 ? nullptr : offsets.data(); + IntT* offsets_ptr = batches == 1 ? nullptr : offsets.data(); +#endif + DenseTensor values = phi::EmptyLike(dev_ctx, csr_values); T* coo_values_data = values.data(); - if (batchs > 1) { - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1); - GetBatchSizes<<>>( - csr_crows_data, rows, batchs, offsets_ptr); - + if (batches > 1) { #ifdef PADDLE_WITH_HIP - thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), + PADDLE_THROW( + phi::errors::Unimplemented("'rocsparse_csr2coo' only supports batches " + "with a value of 1 currently.")); #else + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batches, 1); + GetBatchSizes<<>>( + csr_crows_data, rows, batches, offsets_ptr); + thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()), -#endif offsets_ptr, - offsets_ptr + batchs, + offsets_ptr + batches, offsets_ptr); +#endif } +#ifdef PADDLE_WITH_HIP + dev_ctx.CusparseCall([&](rocsparse_handle handle) { + phi::dynload::rocsparse_csr2coo(handle, + csr_crows_data, + non_zero_num, + rows, + coo_rows_data, + rocsparse_index_base_zero); + }); +#else auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1); - config.block_per_grid.y = batchs; + config.block_per_grid.y = batches; ConvertCsrCrowsToCooRows <<>>( csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows); - +#endif phi::backends::gpu::GpuMemcpyAsync(coo_cols_data, csr_cols_data, +#ifdef PADDLE_WITH_HIP + sizeof(int) * non_zero_num, +#else sizeof(IntT) * non_zero_num, +#endif gpuMemcpyDeviceToDevice, dev_ctx.stream()); phi::backends::gpu::GpuMemcpyAsync(coo_values_data, @@ -271,6 +308,11 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx, gpuMemcpyDeviceToDevice, dev_ctx.stream()); +#ifdef PADDLE_WITH_HIP + if (std::is_same::value) + indices = Cast(dev_ctx, indices, DataType::INT64); +#endif + out->SetMember(indices, values, x_dims, true); } -- GitLab
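
For reference, the SpMM path added in sparse_blas_impl.hip.h drives rocsparse's generic API through three stages of the same rocsparse_spmm entry point: a buffer-size query, a preprocess pass, and the compute pass. The standalone program below runs that sequence on a small float32 CSR matrix so the flow can be tried outside Paddle; the hard-coded 2x3 matrix, the omitted status checks, and the <rocsparse/rocsparse.h> include path (which differs across ROCm releases) are assumptions made for the sketch, not details taken from the patch.

// Minimal rocsparse generic-API SpMM example: C = alpha * A * B + beta * C,
// with A sparse (CSR) and B, C dense row-major matrices.
#include <hip/hip_runtime.h>
#include <rocsparse/rocsparse.h>

#include <cstdio>
#include <vector>

int main() {
  // A is 2x3 with 3 non-zeros: [[1, 0, 2], [0, 3, 0]].
  // B is 3x2 row-major, C is 2x2 row-major; expected C = [[11, 14], [9, 12]].
  std::vector<int> h_row_ptr = {0, 2, 3};
  std::vector<int> h_col_ind = {0, 2, 1};
  std::vector<float> h_val = {1.f, 2.f, 3.f};
  std::vector<float> h_b = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<float> h_c(4, 0.f);

  int *d_row_ptr, *d_col_ind;
  float *d_val, *d_b, *d_c;
  hipMalloc(&d_row_ptr, h_row_ptr.size() * sizeof(int));
  hipMalloc(&d_col_ind, h_col_ind.size() * sizeof(int));
  hipMalloc(&d_val, h_val.size() * sizeof(float));
  hipMalloc(&d_b, h_b.size() * sizeof(float));
  hipMalloc(&d_c, h_c.size() * sizeof(float));
  hipMemcpy(d_row_ptr, h_row_ptr.data(), h_row_ptr.size() * sizeof(int),
            hipMemcpyHostToDevice);
  hipMemcpy(d_col_ind, h_col_ind.data(), h_col_ind.size() * sizeof(int),
            hipMemcpyHostToDevice);
  hipMemcpy(d_val, h_val.data(), h_val.size() * sizeof(float),
            hipMemcpyHostToDevice);
  hipMemcpy(d_b, h_b.data(), h_b.size() * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(d_c, h_c.data(), h_c.size() * sizeof(float), hipMemcpyHostToDevice);

  rocsparse_handle handle;
  rocsparse_create_handle(&handle);

  // Descriptors mirror RocSparseSpMatDescriptor / RocSparseDnMatDescriptor.
  rocsparse_spmat_descr mat_a;
  rocsparse_dnmat_descr mat_b, mat_c;
  rocsparse_create_csr_descr(&mat_a, 2, 3, 3, d_row_ptr, d_col_ind, d_val,
                             rocsparse_indextype_i32, rocsparse_indextype_i32,
                             rocsparse_index_base_zero,
                             rocsparse_datatype_f32_r);
  rocsparse_create_dnmat_descr(&mat_b, 3, 2, 2, d_b, rocsparse_datatype_f32_r,
                               rocsparse_order_row);
  rocsparse_create_dnmat_descr(&mat_c, 2, 2, 2, d_c, rocsparse_datatype_f32_r,
                               rocsparse_order_row);

  float alpha = 1.f, beta = 0.f;
  size_t buffer_size = 0;
  void* buffer = nullptr;

  // Stage 1: ask how much workspace the chosen algorithm needs.
  rocsparse_spmm(handle, rocsparse_operation_none, rocsparse_operation_none,
                 &alpha, mat_a, mat_b, &beta, mat_c, rocsparse_datatype_f32_r,
                 rocsparse_spmm_alg_default, rocsparse_spmm_stage_buffer_size,
                 &buffer_size, nullptr);
  hipMalloc(&buffer, buffer_size);

  // Stage 2: preprocess / analyse the sparsity pattern into the workspace.
  rocsparse_spmm(handle, rocsparse_operation_none, rocsparse_operation_none,
                 &alpha, mat_a, mat_b, &beta, mat_c, rocsparse_datatype_f32_r,
                 rocsparse_spmm_alg_default, rocsparse_spmm_stage_preprocess,
                 &buffer_size, buffer);

  // Stage 3: run the multiplication.
  rocsparse_spmm(handle, rocsparse_operation_none, rocsparse_operation_none,
                 &alpha, mat_a, mat_b, &beta, mat_c, rocsparse_datatype_f32_r,
                 rocsparse_spmm_alg_default, rocsparse_spmm_stage_compute,
                 &buffer_size, buffer);

  hipMemcpy(h_c.data(), d_c, h_c.size() * sizeof(float), hipMemcpyDeviceToHost);
  std::printf("C = [%g %g; %g %g]\n", h_c[0], h_c[1], h_c[2], h_c[3]);

  rocsparse_destroy_spmat_descr(mat_a);
  rocsparse_destroy_dnmat_descr(mat_b);
  rocsparse_destroy_dnmat_descr(mat_c);
  rocsparse_destroy_handle(handle);
  hipFree(d_row_ptr);
  hipFree(d_col_ind);
  hipFree(d_val);
  hipFree(d_b);
  hipFree(d_c);
  hipFree(buffer);
  return 0;
}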