mv_grad_kernel.cu 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/mv_grad_kernel.h"

#include <cstdint>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace phi {
namespace sparse {

// Fills the values of a COO-format gradient: for every stored element `idx`,
//   dx_values[idx] = dout[row(idx)] * vec[col(idx)].
//
// Parameters:
//   dout       - dense values, indexed by the row index of each element.
//   vec        - dense values, indexed by the column index of each element.
//   dx_indices - 2 * nnz index entries: entry `idx` is the row and entry
//                `idx + nnz` is the column of stored element `idx`.
//   dx_values  - output, one value per stored element (nnz entries written).
//   nnz        - number of stored elements.
//
// Launch layout: 1-D (only the x components of the grid/block are read). The
// grid-stride loop makes the result independent of the grid size.
//
// All position arithmetic is done in 64 bits and the row/column indices keep
// their storage type IntT, so 64-bit indices are not truncated and
// `idx + nnz` cannot overflow a 32-bit int.
template <typename T, typename IntT>
__global__ void MvCooGradGpuKernel(const T *dout,
                                   const T *vec,
                                   const IntT *dx_indices,
                                   T *dx_values,
                                   int64_t nnz) {
  // Widen before multiplying: blockDim.x * gridDim.x is a 32-bit product.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  for (; idx < nnz; idx += stride) {
    const IntT i = dx_indices[idx];        // row of this element
    const IntT j = dx_indices[idx + nnz];  // column of this element
    dx_values[idx] = dout[i] * vec[j];
  }
}

// Fills the values of a CSR-format gradient: for every row `i` and every
// stored element `k` of that row,
//   dx_values[dx_crows[i] + k] = dout[i] * vec[dx_cols[dx_crows[i] + k]].
//
// Parameters:
//   dout       - dense values, indexed by row.
//   vec        - dense values, indexed by column.
//   dx_crows   - row offsets; entries [0, row_number] are read, so it must
//                hold at least row_number + 1 entries.
//   dx_cols    - column index of every stored element.
//   dx_values  - output, one value per stored element.
//   row_number - number of rows to process.
//
// Launch layout: 2-D. The x dimension strides over rows and the y dimension
// strides over the stored elements inside a row; both loops are stride loops,
// so any grid/block shape (including a 1-D launch) yields the same result.
//
// Offsets into dx_cols / dx_values are kept in 64 bits so that row offsets
// stored as 64-bit IntT are not truncated to a 32-bit int.
template <typename T, typename IntT>
__global__ void MvCsrGradGpuKernel(const T *dout,
                                   const T *vec,
                                   const IntT *dx_crows,
                                   const IntT *dx_cols,
                                   T *dx_values,
                                   int row_number) {
  // Widen before multiplying: gridDim * blockDim is a 32-bit product.
  const int64_t row_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t nz_stride = static_cast<int64_t>(gridDim.y) * blockDim.y;

  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; i < row_number; i += row_stride) {
    const int64_t row_first = static_cast<int64_t>(dx_crows[i]);
    const int64_t row_nnz = static_cast<int64_t>(dx_crows[i + 1]) - row_first;

    int64_t non_zero_idx =
        static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    for (; non_zero_idx < row_nnz; non_zero_idx += nz_stride) {
      const IntT j = dx_cols[row_first + non_zero_idx];
      dx_values[row_first + non_zero_idx] = dout[i] * vec[j];
    }
  }
}

// Backward of sparse 'mv' for a COO matrix operand.
//
// Either output may be null, in which case that gradient is skipped:
//   dx   - gradient w.r.t. the sparse matrix 'x'; its values are filled by
//          MvCooGradGpuKernel on dev_ctx's stream.
//   dvec - gradient w.r.t. the dense vector 'vec'; computed through the
//          sparse BLAS SPMV routine, available only when CUDA_VERSION >= 11000
//          (otherwise an Unimplemented error is thrown).
template <typename T, typename Context>
void MvCooGradKernel(const Context &dev_ctx,
                     const SparseCooTensor &x,
                     const DenseTensor &vec,
                     const DenseTensor &dout,
                     SparseCooTensor *dx,
                     DenseTensor *dvec) {
  // Sparse gradient: dx{SparseCoo} = dout{Dense} * vec'{Dense}
  if (dx != nullptr) {
    // Prepare 'dx' from 'x' first
    // (InferMeta of SparseCooTensor 'dx', CreateLikeInferMeta).
    EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);

    // 1-D launch sized by the number of stored elements.
    const auto launch_cfg =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dx->nnz());

    // Pointers that do not depend on the index dtype are fetched once here;
    // the index pointer needs 'data_t', which only exists inside the visitor.
    const T *dout_ptr = dout.data<T>();
    const T *vec_ptr = vec.data<T>();
    T *dx_values_ptr = dx->mutable_non_zero_elements()->data<T>();

    PD_VISIT_INTEGRAL_TYPES(
        dx->non_zero_indices().dtype(), "MvCooGradKernel", ([&] {
          const data_t *dx_indices_ptr =
              dx->non_zero_indices().data<data_t>();
          MvCooGradGpuKernel<T>
              <<<launch_cfg.block_per_grid.x,
                 launch_cfg.thread_per_block.x,
                 0,
                 dev_ctx.stream()>>>(
                  dout_ptr, vec_ptr, dx_indices_ptr, dx_values_ptr, dx->nnz());
        }));
  }

  // Dense gradient: dvec{Dense} = x'{SparseCoo} * dout{Dense}
  if (dvec != nullptr) {
#if CUDA_VERSION >= 11000
    // InferMeta of DenseTensor 'dvec': it takes the dims of 'vec'.
    dvec->Resize(vec.dims());
    dev_ctx.template Alloc<T>(dvec);

    const T alpha = static_cast<T>(1);
    const T beta = static_cast<T>(0);
    auto blas_handle = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
    blas_handle.SPMV(true, alpha, x, dout, beta, dvec);
#else
    PADDLE_THROW(
        phi::errors::Unimplemented(" vec.grad of 'sparse.mv' use cusparseSpMV, "
                                   "which is supported from CUDA 11.0"));
#endif
  }
}

// Backward of sparse 'mv' for a CSR matrix operand.
//
// Either output may be null, in which case that gradient is skipped:
//   dx   - gradient w.r.t. the sparse matrix 'x'; its values are filled by
//          MvCsrGradGpuKernel on dev_ctx's stream.
//   dvec - gradient w.r.t. the dense vector 'vec'; computed through the
//          sparse BLAS SPMV routine, available only when CUDA_VERSION >= 11000
//          (otherwise an Unimplemented error is thrown).
template <typename T, typename Context>
void MvCsrGradKernel(const Context &dev_ctx,
                     const SparseCsrTensor &x,
                     const DenseTensor &vec,
                     const DenseTensor &dout,
                     SparseCsrTensor *dx,
                     DenseTensor *dvec) {
  // dx{SparseCsr} = dout{Dense} * vec'{Dense}
  if (dx) {
    // InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
    EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);

    int row_number = dx->dims()[0];
    int col_number = dx->dims()[1];
    // MvCsrGradGpuKernel strides over rows along x and over the stored
    // elements of a row along y, so rows are given as the x extent and columns
    // (the upper bound on stored elements per row) as the y extent.
    // NOTE(review): assumes GetGpuLaunchConfig2D takes (x_dim, y_dim) in that
    // order - confirm against gpu_launch_config.h.
    auto config = phi::backends::gpu::GetGpuLaunchConfig2D(
        dev_ctx, row_number, col_number);
    PD_VISIT_INTEGRAL_TYPES(
        dx->non_zero_crows().dtype(), "MvCsrGradKernel", ([&] {
          // Pass the full grid/block (not just the .x components): otherwise
          // gridDim.y == blockDim.y == 1 and every row's stored elements are
          // walked serially by a single thread.
          MvCsrGradGpuKernel<T>
              <<<config.block_per_grid,
                 config.thread_per_block,
                 0,
                 dev_ctx.stream()>>>(dout.data<T>(),
                                     vec.data<T>(),
                                     dx->non_zero_crows().data<data_t>(),
                                     dx->non_zero_cols().data<data_t>(),
                                     dx->mutable_non_zero_elements()->data<T>(),
                                     row_number);
        }));
  }

  // dvec{Dense} = x'{SparseCsr} * dout{Dense}
  if (dvec) {
#if CUDA_VERSION >= 11000
    // InferMeta of DenseTensor 'dvec': it takes the dims of 'vec'.
    dvec->Resize(vec.dims());
    dev_ctx.template Alloc<T>(dvec);

    auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
    sparse_blas.SPMV(true, static_cast<T>(1), x, dout, static_cast<T>(0), dvec);
#else
    PADDLE_THROW(
        phi::errors::Unimplemented(" vec.grad of 'sparse.mv' use cusparseSpMV, "
                                   "which is supported from CUDA 11.0"));
#endif
  }
}

}  // namespace sparse
}  // namespace phi

// Registers phi::sparse::MvCooGradKernel as the GPU kernel 'mv_coo_grad' for
// float and double, and marks input 0 as using the SPARSE_COO data layout.
PD_REGISTER_KERNEL(
    mv_coo_grad, GPU, ALL_LAYOUT, phi::sparse::MvCooGradKernel, float, double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// Registers phi::sparse::MvCsrGradKernel as the GPU kernel 'mv_csr_grad' for
// float and double, and marks input 0 as using the SPARSE_CSR data layout.
PD_REGISTER_KERNEL(
    mv_csr_grad, GPU, ALL_LAYOUT, phi::sparse::MvCsrGradKernel, float, double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}