未验证 提交 5161a047 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

【Sparse】add SparseTensor mv kernel(csr*dense_vec->dence_vec, coo*dense_vec->dense_vec) (#43668)

* [Sparse]add SparseTensor mv kernel(csr*dense_vec->dence_vec, coo*dense_vec->dense_vec)

* fix CI
上级 1dc2117f
......@@ -29,7 +29,6 @@ namespace dynload {
extern DynLoad__##__name __name
#if defined(PADDLE_WITH_CUDA)
// APIs available after CUDA 11.0
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
......@@ -43,10 +42,14 @@ namespace dynload {
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseCreateDnVec); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat);
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroyDnVec); \
__macro(cusparseSpMV_bufferSize); \
__macro(cusparseSpMV);
CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
......
......@@ -41,7 +41,6 @@ extern void *cusparse_dso_handle;
extern DynLoad__##__name __name
#if defined(PADDLE_WITH_CUDA)
// APIs available after CUDA 11.0
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
......@@ -55,10 +54,14 @@ extern void *cusparse_dso_handle;
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseCreateDnVec); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat);
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroyDnVec); \
__macro(cusparseSpMV_bufferSize); \
__macro(cusparseSpMV);
CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
......
......@@ -37,6 +37,14 @@ class SparseBlas {
T beta,
phi::DenseTensor* mat_out) const;
template <typename T, typename TensorType>
void SPMV(bool transa,
T alpha,
const TensorType& mat_a,
const phi::DenseTensor& vec_x,
T beta,
phi::DenseTensor* vec_out) const;
template <typename T, typename TensorType>
void SDDMM(bool transa,
bool transb,
......@@ -60,6 +68,11 @@ class SparseBlasT : private SparseBlas<DeviceContext> {
Base()->template SPMM<T>(args...);
}
template <typename... ARGS>
void SPMV(ARGS... args) const {
Base()->template SPMV<T>(args...);
}
template <typename... ARGS>
void SDDMM(ARGS... args) const {
Base()->template SDDMM<T>(args...);
......
......@@ -20,6 +20,7 @@
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/visit_type.h"
......@@ -47,6 +48,8 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
}
}
/************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/
template <typename T, typename IntT>
inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
const phi::GPUContext& dev_ctx,
......@@ -102,6 +105,60 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
}
}
template <typename T, typename IntT>
inline void CreateCooDescriptor(const phi::SparseCooTensor& x,
const phi::GPUContext& dev_ctx,
cusparseSpMatDescr_t* descriptor) {
std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
auto x_ndims = xdim_vec.size();
PADDLE_ENFORCE_GE(
x_ndims,
2,
phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be "
"greater than or eaqual to 2."));
int64_t M = xdim_vec[x_ndims - 2];
int64_t N = xdim_vec[x_ndims - 1];
int batch_size = 1;
for (int i = 0; i < x_ndims - 2; i++) {
batch_size *= xdim_vec[i];
}
int64_t nnz = x.nnz();
const IntT* indices_data = x.non_zero_indices().data<IntT>();
const T* values_data = x.non_zero_elements().data<T>();
auto rows_data = indices_data + (x_ndims - 2) * nnz;
auto cols_data = indices_data + (x_ndims - 1) * nnz;
int64_t batch_nnz = nnz / batch_size;
cudaDataType_t gpu_type = GetGpuDataType<T>();
dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseCreateCoo(descriptor,
M,
N,
batch_nnz,
const_cast<IntT*>(rows_data),
const_cast<IntT*>(cols_data),
const_cast<T*>(values_data),
CUSPARSE_INDEX_64I,
CUSPARSE_INDEX_BASE_ZERO,
gpu_type);
});
if (batch_size > 1) {
#if CUDA_VERSION >= 11070
dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseCooSetStridedBatch(
*descriptor, batch_size, batch_nnz);
});
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Batch Sparse matmul use 'cusparseCooSetStridedBatch', which is "
"supported from CUDA 11.7"));
#endif
}
}
template <typename T>
class CuSparseSpMatDescriptor {
public:
......@@ -109,12 +166,22 @@ class CuSparseSpMatDescriptor {
const phi::GPUContext& dev_ctx)
: dev_ctx_(dev_ctx) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "CuSparseSpMatDescriptor", ([&] {
x.non_zero_crows().dtype(), "Csr CuSparseSpMatDescriptor", ([&] {
CreateCsrDescriptor<T, data_t>(x, dev_ctx_, &descriptor_);
}));
VLOG(6) << "Create csr cusparseSpMatDescr_t " << &descriptor_;
}
explicit CuSparseSpMatDescriptor(const phi::SparseCooTensor& x,
const phi::GPUContext& dev_ctx)
: dev_ctx_(dev_ctx) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "Coo CuSparseSpMatDescriptor", ([&] {
CreateCooDescriptor<T, data_t>(x, dev_ctx_, &descriptor_);
}));
VLOG(6) << "Create coo cusparseSpMatDescr_t " << &descriptor_;
}
~CuSparseSpMatDescriptor() {
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDestroySpMat(descriptor_);
......@@ -129,6 +196,7 @@ class CuSparseSpMatDescriptor {
cusparseSpMatDescr_t descriptor_;
};
/************* DENSE MATRIX DESCRIPTOR ************/
template <typename T>
class CuSparseDnMatDescriptor {
public:
......@@ -192,6 +260,44 @@ class CuSparseDnMatDescriptor {
cusparseDnMatDescr_t descriptor_;
};
/************* DENSE VECTOR DESCRIPTOR ************/
template <typename T>
class CuSparseDnVecDescriptor {
public:
explicit CuSparseDnVecDescriptor(const phi::DenseTensor& x,
const phi::GPUContext& dev_ctx)
: dev_ctx_(dev_ctx) {
std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
auto x_ndims = xdim_vec.size();
PADDLE_ENFORCE_GE(x_ndims,
1,
phi::errors::InvalidArgument(
"the dim size of Vec must be eaqual to 1."));
const T* x_data = x.data<T>();
cudaDataType_t gpu_type = GetGpuDataType<T>();
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseCreateDnVec(
&descriptor_, x.numel(), const_cast<T*>(x_data), gpu_type);
});
VLOG(6) << "Create cusparseDnVecDescr_t " << &descriptor_;
}
~CuSparseDnVecDescriptor() {
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDestroyDnVec(descriptor_);
});
VLOG(6) << "Destroy cusparseDnVecDescr_t " << &descriptor_;
}
const cusparseDnVecDescr_t& descriptor() const { return descriptor_; }
private:
const phi::GPUContext& dev_ctx_;
cusparseDnVecDescr_t descriptor_;
};
template <>
template <typename T, typename TensorType>
void SparseBlas<phi::GPUContext>::SPMM(bool transa,
......@@ -239,6 +345,50 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
});
}
template <>
template <typename T, typename TensorType>
void SparseBlas<phi::GPUContext>::SPMV(bool transa,
T alpha,
const TensorType& mat_a,
const phi::DenseTensor& vec_x,
T beta,
phi::DenseTensor* vec_out) const {
auto a_descriptor = CuSparseSpMatDescriptor<T>(mat_a, dev_ctx_);
auto x_descriptor = CuSparseDnVecDescriptor<T>(vec_x, dev_ctx_);
auto out_descriptor = CuSparseDnVecDescriptor<T>(*vec_out, dev_ctx_);
cudaDataType_t gpu_type = GetGpuDataType<T>();
size_t buffer_size = 0;
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseSpMV_bufferSize(handle,
GetTransposeOperation(transa),
&alpha,
a_descriptor.descriptor(),
x_descriptor.descriptor(),
&beta,
out_descriptor.descriptor(),
gpu_type,
CUSPARSE_MV_ALG_DEFAULT,
&buffer_size);
});
paddle::memory::allocation::AllocationPtr tmp_buffer =
paddle::memory::Alloc(dev_ctx_, buffer_size);
void* tmp_buffer_ptr = tmp_buffer->ptr();
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseSpMV(handle,
GetTransposeOperation(transa),
&alpha,
a_descriptor.descriptor(),
x_descriptor.descriptor(),
&beta,
out_descriptor.descriptor(),
gpu_type,
CUSPARSE_MV_ALG_DEFAULT,
tmp_buffer_ptr);
});
}
#if CUDA_VERSION >= 11030
template <>
template <typename T, typename TensorType>
......@@ -249,12 +399,11 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
const phi::DenseTensor& mat_b,
T beta,
TensorType* mat_out) const {
cudaDataType_t gpu_type = GetGpuDataType<T>();
auto a_descriptor = CuSparseDnMatDescriptor<T>(mat_a, dev_ctx_);
auto b_descriptor = CuSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
auto out_descriptor = CuSparseSpMatDescriptor<T>(*mat_out, dev_ctx_);
cudaDataType_t gpu_type = GetGpuDataType<T>();
size_t buffer_size = 0;
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseSDDMM_bufferSize(handle,
......
......@@ -27,7 +27,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out) {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support CPU kernel of Sparse Matmul now."));
"Not support CPU kernel of 'sparse.matmul' now."));
}
// TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR"
......@@ -38,7 +38,7 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx,
const SparseCsrTensor& mask,
SparseCsrTensor* out) {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support CPU kernel of Matmul Mask As Sparse now."));
"Not support CPU kernel of 'sparse.masked_matmul' now."));
}
} // namespace sparse
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/mv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void MvCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& vec,
const DenseTensor& dout,
SparseCooTensor* dx,
DenseTensor* dvec) {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support CPU backward kernel of 'sparse.mv' now."));
}
template <typename T, typename Context>
void MvCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const DenseTensor& vec,
const DenseTensor& dout,
SparseCsrTensor* dx,
DenseTensor* dvec) {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support CPU backward kernel of 'sparse.mv' now."));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(
mv_coo_grad, CPU, ALL_LAYOUT, phi::sparse::MvCooGradKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(
mv_csr_grad, CPU, ALL_LAYOUT, phi::sparse::MvCsrGradKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/mv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void MvCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const DenseTensor& vec,
DenseTensor* out) {
PADDLE_THROW(
phi::errors::Unimplemented("Not support CPU kernel of 'sparse.mv' now."));
}
template <typename T, typename Context>
void MvCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& vec,
DenseTensor* out) {
PADDLE_THROW(
phi::errors::Unimplemented("Not support CPU kernel of 'sparse.mv' now."));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(
mv_csr, CPU, ALL_LAYOUT, phi::sparse::MvCsrKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
PD_REGISTER_KERNEL(
mv_coo, CPU, ALL_LAYOUT, phi::sparse::MvCooKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
......@@ -28,12 +28,10 @@ void EmptyLikeCooKernel(const Context& dev_ctx,
SparseCooTensor* out) {
const DenseTensor& x_indices = x.non_zero_indices();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor* out_indices = out->mutable_non_zero_indices();
DenseTensor* out_values = out->mutable_non_zero_elements();
phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices);
phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
out_values->Resize(x_values.dims());
dev_ctx.template Alloc<T>(out_values);
......@@ -48,7 +46,6 @@ void EmptyLikeCsrKernel(const Context& dev_ctx,
const DenseTensor& x_crows = x.non_zero_crows();
const DenseTensor& x_cols = x.non_zero_cols();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor* out_crows = out->mutable_non_zero_crows();
DenseTensor* out_cols = out->mutable_non_zero_cols();
DenseTensor* out_values = out->mutable_non_zero_elements();
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/mv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT>
__global__ void MvCooGradGpuKernel(const T *dout,
const T *vec,
const IntT *dx_indices,
T *dx_values,
int nnz) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < nnz; idx += blockDim.x * gridDim.x) {
int i = dx_indices[idx];
int j = dx_indices[idx + nnz];
dx_values[idx] = dout[i] * vec[j];
}
}
template <typename T, typename IntT>
__global__ void MvCsrGradGpuKernel(const T *dout,
const T *vec,
const IntT *dx_crows,
const IntT *dx_cols,
T *dx_values,
int row_number) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
for (; i < row_number; i += gridDim.x * blockDim.x) {
int row_first = static_cast<int>(dx_crows[i]);
int row_nnz = static_cast<int>(dx_crows[i + 1] - dx_crows[i]);
int non_zero_idx = blockIdx.y * blockDim.y + threadIdx.y;
for (; non_zero_idx < row_nnz; non_zero_idx += gridDim.y * blockDim.y) {
int j = dx_cols[row_first + non_zero_idx];
dx_values[row_first + non_zero_idx] = dout[i] * vec[j];
}
}
}
template <typename T, typename Context>
void MvCooGradKernel(const Context &dev_ctx,
const SparseCooTensor &x,
const DenseTensor &vec,
const DenseTensor &dout,
SparseCooTensor *dx,
DenseTensor *dvec) {
// dx{SparseCoo} = dout{Dense} * vec'{Dense}
if (dx) {
// InferMeta of SparseCooTensor 'dx', CreateLikeInferMeta
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dx->nnz());
PD_VISIT_INTEGRAL_TYPES(
dx->non_zero_indices().dtype(), "MvCooGradKernel", ([&] {
MvCooGradGpuKernel<T>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(dout.data<T>(),
vec.data<T>(),
dx->non_zero_indices().data<data_t>(),
dx->mutable_non_zero_elements()->data<T>(),
dx->nnz());
}));
}
// dvec{Dense} = x'{SparseCoo} * dout{Dense}
if (dvec) {
#if CUDA_VERSION >= 11000
// InferMeta of DenseTensor 'dvec'
dvec->Resize(vec.dims());
dev_ctx.template Alloc<T>(dvec);
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SPMV(true, static_cast<T>(1), x, dout, static_cast<T>(0), dvec);
#else
PADDLE_THROW(
phi::errors::Unimplemented(" vec.grad of 'sparse.mv' use cusparseSpMV, "
"which is supported from CUDA 11.0"));
#endif
}
}
template <typename T, typename Context>
void MvCsrGradKernel(const Context &dev_ctx,
const SparseCsrTensor &x,
const DenseTensor &vec,
const DenseTensor &dout,
SparseCsrTensor *dx,
DenseTensor *dvec) {
// dx{SparseCsr} = dout{Dense} * vec'{Dense}
if (dx) {
// InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
int row_number = dx->dims()[0];
int col_number = dx->dims()[1];
auto config = phi::backends::gpu::GetGpuLaunchConfig2D(
dev_ctx, col_number, row_number);
PD_VISIT_INTEGRAL_TYPES(
dx->non_zero_crows().dtype(), "MvCsrGradKernel", ([&] {
MvCsrGradGpuKernel<T>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(dout.data<T>(),
vec.data<T>(),
dx->non_zero_crows().data<data_t>(),
dx->non_zero_cols().data<data_t>(),
dx->mutable_non_zero_elements()->data<T>(),
row_number);
}));
}
// dvec{Dense} = x'{SparseCsr} * dout{Dense}
if (dvec) {
#if CUDA_VERSION >= 11000
// InferMeta of DenseTensor 'dvec'
dvec->Resize(vec.dims());
dev_ctx.template Alloc<T>(dvec);
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SPMV(true, static_cast<T>(1), x, dout, static_cast<T>(0), dvec);
#else
PADDLE_THROW(
phi::errors::Unimplemented(" vec.grad of 'sparse.mv' use cusparseSpMV, "
"which is supported from CUDA 11.0"));
#endif
}
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(
mv_coo_grad, GPU, ALL_LAYOUT, phi::sparse::MvCooGradKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(
mv_csr_grad, GPU, ALL_LAYOUT, phi::sparse::MvCsrGradKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/mv_kernel.h"
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
namespace phi {
namespace sparse {
template <typename T, typename Context, typename TensorType>
void MvKernelImpl(const Context& dev_ctx,
const TensorType& x,
const DenseTensor& vec,
DenseTensor* out) {
#if CUDA_VERSION >= 11000
std::vector<int64_t> x_dim = phi::vectorize(x.dims());
std::vector<int64_t> vec_dim = phi::vectorize(vec.dims());
auto x_ndims = x_dim.size();
auto vec_ndims = vec_dim.size();
PADDLE_ENFORCE_EQ(x_ndims,
2,
phi::errors::InvalidArgument(
"the dims size of Input(x) must be eaqual to 2."));
PADDLE_ENFORCE_EQ(vec_ndims,
1,
phi::errors::InvalidArgument(
"the dims size of Input(vec) must be eaqual to 1."));
PADDLE_ENFORCE_EQ(x_dim[x_ndims - 1],
vec_dim[vec_ndims - 1],
phi::errors::PreconditionNotMet(
"The shape of Input(x) and Input(vec) is not "
"suitable for mv opetation, "
"x_dim[-1] must be eaqual to vec_dim[-1]."));
std::vector<int64_t> out_dim = {x_dim[x_ndims - 2]};
out->Resize(phi::make_ddim(out_dim));
dev_ctx.template Alloc<T>(out);
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SPMV(false, static_cast<T>(1), x, vec, static_cast<T>(0), out);
#else
PADDLE_THROW(phi::errors::Unimplemented(
" 'sparse.mv' use cusparseSpMV, which is supported from CUDA 11.0"));
#endif
}
template <typename T, typename Context>
void MvCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& vec,
DenseTensor* out) {
MvKernelImpl<T>(dev_ctx, x, vec, out);
}
template <typename T, typename Context>
void MvCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const DenseTensor& vec,
DenseTensor* out) {
MvKernelImpl<T>(dev_ctx, x, vec, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(
mv_csr, GPU, ALL_LAYOUT, phi::sparse::MvCsrKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
PD_REGISTER_KERNEL(
mv_coo, GPU, ALL_LAYOUT, phi::sparse::MvCooKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
/* backward of COO @ DENSE VEC -> DENSE VEC */
template <typename T, typename Context>
void MvCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& vec,
const DenseTensor& dout,
SparseCooTensor* dx,
DenseTensor* dvec);
/* backward of CSR @ DENSE VEC -> DENSE VEC */
template <typename T, typename Context>
void MvCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const DenseTensor& vec,
const DenseTensor& dout,
SparseCsrTensor* dx,
DenseTensor* dvec);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
/* COO @ DENSE VEC -> DENSE VEC */
template <typename T, typename Context>
void MvCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& vec,
DenseTensor* out);
/* CSR @ DENSE VEC -> DENSE VEC */
template <typename T, typename Context>
void MvCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const DenseTensor& vec,
DenseTensor* out);
} // namespace sparse
} // namespace phi
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.framework import _test_eager_guard
import numpy as np
import scipy
import scipy.sparse as sp
import unittest
import os
import re
paddle.seed(100)
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"paddle is not compiled with CUDA and cuda version need to >= 11.0")
class TestCsrMv(unittest.TestCase):
# x: csr-matrix, y: dense-vec, out: dense-vec
def test_mv(self):
with _test_eager_guard():
paddle.set_default_dtype('float64')
origin_x = paddle.rand([64, 32])
mask = paddle.randint(0, 2, [64, 32])
origin_x = origin_x * mask
origin_vec = paddle.rand([32])
dense_x = origin_x.detach()
dense_x.stop_gradient = False
dense_vec = origin_vec.detach()
dense_vec.stop_gradient = False
dense_out = paddle.mv(dense_x, dense_vec)
dense_out.backward()
sp_x = origin_x.detach().to_sparse_csr()
sp_x.stop_gradient = False
sp_vec = origin_vec.detach()
sp_vec.stop_gradient = False
sp_out = paddle.incubate.sparse.mv(sp_x, sp_vec)
sp_out.backward()
self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy()))
self.assertTrue(
np.allclose(sp_x.grad.to_dense().numpy(),
(dense_x.grad * mask).numpy()))
self.assertTrue(
np.allclose(sp_vec.grad.numpy(), dense_vec.grad.numpy()))
@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"paddle is not compiled with CUDA and cuda version need to >= 11.0")
class TestCooMv(unittest.TestCase):
# x: csr-matrix, y: dense-vec, out: dense-vec
def test_mv(self):
with _test_eager_guard():
paddle.set_default_dtype('float64')
origin_x = paddle.rand([64, 32])
mask = paddle.randint(0, 2, [64, 32])
origin_x = origin_x * mask
origin_vec = paddle.rand([32])
dense_x = origin_x.detach()
dense_x.stop_gradient = False
dense_vec = origin_vec.detach()
dense_vec.stop_gradient = False
dense_out = paddle.mv(dense_x, dense_vec)
dense_out.backward()
sp_x = origin_x.detach().to_sparse_coo(sparse_dim=2)
sp_x.stop_gradient = False
sp_vec = origin_vec.detach()
sp_vec.stop_gradient = False
sp_out = paddle.incubate.sparse.mv(sp_x, sp_vec)
sp_out.backward()
self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy()))
self.assertTrue(
np.allclose(sp_x.grad.to_dense().numpy(),
(dense_x.grad * mask).numpy()))
self.assertTrue(
np.allclose(sp_vec.grad.numpy(), dense_vec.grad.numpy()))
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,7 @@ from .unary import sqrt
from .unary import sin
from .unary import tanh
from .binary import mv
from .binary import matmul
from .binary import masked_matmul
......@@ -35,6 +36,7 @@ __all__ = [
'sqrt',
'sin',
'tanh',
'mv',
'matmul',
'masked_matmul',
'add',
......
......@@ -21,8 +21,8 @@ __all__ = []
@dygraph_only
def matmul(x, y, name=None):
"""
Warning:
This API is only used from ``CUDA 11.0`` .
Note:
This API is only supported from ``CUDA 11.0`` .
Applies matrix multiplication of two Tensors.
......@@ -83,8 +83,8 @@ def matmul(x, y, name=None):
@dygraph_only
def masked_matmul(x, y, mask, name=None):
"""
Warning:
This API is only used from ``CUDA 11.3`` .
Note:
This API is only supported from ``CUDA 11.3`` .
Applies matrix multiplication of two Dense Tensors.
......@@ -141,3 +141,59 @@ def masked_matmul(x, y, mask, name=None):
"""
return _C_ops.final_state_sparse_masked_matmul(x, y, mask)
@dygraph_only
def mv(x, vec, name=None):
"""
Note:
This API is only supported from ``CUDA 11.0`` .
Applies matrix-vector product of Sparse Matrix 'x' and Dense vector 'vec' .
The supported input/output Tensor layout are as follows:
Note:
x[SparseCsrTensor] @ y[DenseTensor] -> out[SparseCsrTensor]
x[SparseCooTensor] @ y[DenseTensor] -> out[SparseCooTensor]
It supports backward propagation.
The shape of `x` should be `[M, N]` , and the shape of `y` should be `[N]` ,
and the shape of `out` will be `[M]` .
Args:
x (Tensor): The input 2D tensor. It must be SparseCooTensor/SparseCsrTensor. The data type can be float32 or float64.
y (Tensor): The input 1D tensor. It must be DenseTensor vector. The data type can be float32 or float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: 1D Tensor.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
paddle.seed(100)
# csr @ dense -> dense
with _test_eager_guard():
crows = [0, 2, 3, 5]
cols = [1, 3, 2, 0, 1]
values = [1., 2., 3., 4., 5.]
dense_shape = [3, 4]
csr = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
# Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# crows=[0, 2, 3, 5],
# cols=[1, 3, 2, 0, 1],
# values=[1., 2., 3., 4., 5.])
vec = paddle.randn([4])
out = paddle.incubate.sparse.mv(csr, vec)
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-3.85499096, -2.42975140, -1.75087738])
"""
return _C_ops.final_state_sparse_mv(x, vec)
......@@ -168,3 +168,12 @@
layout : x
intermediate : rulebook
backward : sparse_maxpool_grad
- api: mv
args : (Tensor x, Tensor vec)
output : Tensor(out)
kernel :
func : mv_coo{sparse_coo, dense -> dense},
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad
......@@ -63,6 +63,14 @@
func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
multiply_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
- backward_api : mv_grad
forward : mv(Tensor x, Tensor vec) -> Tensor(out)
args : (Tensor x, Tensor vec, Tensor out_grad)
output : Tensor(x_grad), Tensor(vec_grad)
kernel :
func : mv_coo_grad{sparse_coo, dense, dense -> sparse_coo, dense},
mv_csr_grad{sparse_csr, dense, dense -> sparse_csr, dense}
- backward_api : relu_grad
forward : relu(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册