/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/binary_search.h>

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"

namespace phi {
namespace sparse {

/**
 * @brief Gather values of a dense tensor at COO coordinates.
 *
 * The kernel loops over non_zero_num * cols work items, and each item copies
 * exactly one element. For work item i:
 *   - out_i = i / cols is the non-zero entry being produced.
 *   - col_i = i % cols is the element inside that entry's dense row.
 *
 * @param x_ptr          Dense input. The caller (SparseMaskGPUKernel) treats
 *                       it as a row-major [prod(sparse dims), cols] buffer.
 * @param indices_ptr    COO coordinates stored as [sparse_dim, non_zero_num].
 *                       Coordinate j of entry e is at
 *                       indices_ptr[j * non_zero_num + e].
 * @param sparse_offsets Device array of length sparse_dim. It holds the
 *                       row-major strides of the sparse dims, counted in rows
 *                       of x_ptr.
 * @param non_zero_num   Number of non-zero entries to gather.
 * @param cols           Number of elements per entry (size of the dense part).
 * @param sparse_dim     Number of sparse dims (rows of indices_ptr).
 * @param out_values_ptr Output buffer of non_zero_num * cols elements.
 */
template <typename T, typename IntT>
__global__ void MaskKernel(const T* x_ptr,
                           const IntT* indices_ptr,
                           const int64_t* sparse_offsets,
                           const int64_t non_zero_num,
                           const int cols,
                           const int sparse_dim,
                           T* out_values_ptr) {
  CUDA_KERNEL_LOOP_TYPE(i, non_zero_num * cols, int64_t) {
    const int64_t out_i = i / cols;
    const int64_t col_i = i - out_i * cols;
    int64_t index = 0;
    for (int j = 0; j < sparse_dim; j++) {
      // The coordinate must be read at column out_i (the entry), not at the
      // flat work-item id i. Using i selects the wrong entry and runs past
      // the end of indices whenever cols > 1.
      index += static_cast<int64_t>(indices_ptr[j * non_zero_num + out_i]) *
               sparse_offsets[j];
    }
    out_values_ptr[out_i * cols + col_i] = x_ptr[index * cols + col_i];
  }
}

/**
 * @brief GPU sparse_mask for one index type IntT.
 *
 * Reads the values of the dense tensor x at the coordinates in
 * mask.non_zero_indices(). The output out receives:
 *   - a copy of mask's indices;
 *   - the gathered values, one row of `cols` elements per non-zero entry;
 *   - the dims of x.
 *
 * @param dev_ctx GPU context. All work is enqueued on dev_ctx.stream().
 * @param x       Dense input; must have the same dims as mask.
 * @param mask    COO tensor whose indices select the positions of x.
 * @param out     Destination COO tensor.
 */
template <typename T, typename IntT>
void SparseMaskGPUKernel(const GPUContext& dev_ctx,
                         const DenseTensor& x,
                         const SparseCooTensor& mask,
                         SparseCooTensor* out) {
  const DDim& dims = x.dims();
  PADDLE_ENFORCE_EQ(x.dims(),
                    mask.dims(),
                    phi::errors::InvalidArgument(
                        "the input x and mask must have the same shape"));
  const DenseTensor& indices = mask.non_zero_indices();
  const DenseTensor& values = mask.non_zero_elements();
  // indices is laid out as [sparse_dim, nnz]. The number of sparse dims is
  // therefore its first extent; its rank is always 2 and is not sparse_dim.
  const int sparse_dim = static_cast<int>(indices.dims()[0]);

  // Row-major strides of the sparse dims of x, counted in dense rows.
  DenseTensor sparse_offsets = phi::Empty<GPUContext>(
      dev_ctx,
      DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW));
  std::vector<int64_t> h_sparse_offsets(sparse_dim);
  int64_t offset = 1;
  for (int i = sparse_dim - 1; i >= 0; i--) {
    h_sparse_offsets[i] = offset;
    offset *= dims[i];
  }

  phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(),
                                     h_sparse_offsets.data(),
                                     sizeof(int64_t) * sparse_dim,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyHostToDevice,
#else
                                     cudaMemcpyHostToDevice,
#endif
                                     dev_ctx.stream());

  DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
  DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);

  phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);

  const IntT* indices_ptr = indices.data<IntT>();
  T* out_values_ptr = out_values.data<T>();
  const T* x_ptr = x.data<T>();
  const int64_t non_zero_num = mask.nnz();
  // cols = product of the dense (trailing) dims of x.
  auto dims_2d = flatten_to_2d(dims, sparse_dim);
  const int cols = dims_2d[1];

  const int64_t num_elements = non_zero_num * cols;
  // An empty mask would produce a zero-block launch, which is an invalid
  // configuration, so the launch is skipped.
  if (num_elements > 0) {
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, num_elements, 1);
    // Launch on dev_ctx.stream() so the kernel runs after the asynchronous
    // sparse_offsets copy issued on that same stream.
    MaskKernel<T, IntT><<<config.block_per_grid,
                          config.thread_per_block,
                          0,
                          dev_ctx.stream()>>>(x_ptr,
                                              indices_ptr,
                                              sparse_offsets.data<int64_t>(),
                                              non_zero_num,
                                              cols,
                                              sparse_dim,
                                              out_values_ptr);
  }

  out->SetMember(out_indices, out_values, dims, true);
}

/**
 * @brief Filter the DenseTensor x by mask.non_zero_indices() and write the
 * gathered values into the SparseCooTensor out. x and mask must have the same
 * shape.
 *
 * This function only dispatches on the integral dtype of the mask indices,
 * so that SparseMaskGPUKernel is instantiated with the matching index type.
 */
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const SparseCooTensor& mask,
                      SparseCooTensor* out) {
  const auto index_dtype = mask.non_zero_indices().dtype();
  PD_VISIT_INTEGRAL_TYPES(index_dtype, "SparseMaskGPUKernel", ([&] {
                            SparseMaskGPUKernel<T, data_t>(
                                dev_ctx, x, mask, out);
                          }));
}

/**
 * @brief Copy, for every mask entry, the matching value row of x.
 *
 * Work item i handles mask entry i. bound_out[i] is the position that
 * lower_bound returned for mask_indexs[i] within x_indexs. The row is copied
 * only when both of these hold:
 *   - the position lies inside x (j < x_nnz);
 *   - x_indexs[j] holds the same flattened index.
 * Rows of out_values without a match are left untouched.
 *
 * @param x_indexs    Flattened indices of x, length x_nnz.
 * @param mask_indexs Flattened indices of the mask, length n.
 * @param bound_out   lower_bound positions of mask_indexs in x_indexs,
 *                    length n.
 * @param x_values    Values of x: x_nnz rows of `stride` elements.
 * @param x_nnz       Number of non-zero entries of x.
 * @param n           Number of mask entries.
 * @param stride      Number of elements per value row.
 * @param out_values  Output: at least n rows of `stride` elements.
 */
template <typename T, typename IntT>
__global__ void SparseMaskCopyKernel(const IntT* x_indexs,
                                     const IntT* mask_indexs,
                                     const IntT* bound_out,
                                     const T* x_values,
                                     const int64_t x_nnz,
                                     const int64_t n,
                                     const int64_t stride,
                                     T* out_values) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    const int64_t j = static_cast<int64_t>(bound_out[i]);
    // lower_bound yields x_nnz when there is no match. j must therefore be
    // bounded by x's size, not by the mask size n.
    if (j >= 0 && j < x_nnz && mask_indexs[i] == x_indexs[j]) {
      for (int64_t k = 0; k < stride; k++) {
        out_values[i * stride + k] = x_values[j * stride + k];
      }
    }
  }
}

/**
 * @brief GPU sparse_mask_helper for one index type IntT.
 *
 * mask_indices is a 2-D tensor. Each of its columns is read as one COO
 * coordinate of sparse_dim values. For each column, this function looks up
 * the entry of x with the same coordinate and copies that entry's value row
 * into row i of *out. Rows whose coordinate is not present in x stay zero.
 *
 * *out is allocated with the shape of x.non_zero_elements(). The number of
 * mask columns must therefore not exceed x.nnz().
 *
 * NOTE(review): the lookup uses thrust::lower_bound on the flattened indices
 * of x. x's indices are therefore assumed to be sorted ascending (e.g. x
 * coalesced). Confirm this with the callers.
 */
template <typename T, typename IntT>
void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                               const SparseCooTensor& x,
                               const DenseTensor& mask_indices,
                               DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      mask_indices.dims().size(),
      2,
      phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));

  const int64_t sparse_dim = x.non_zero_indices().dims()[0];
  const int64_t x_nnz = x.nnz();
  const int64_t mask_nnz = mask_indices.dims()[1];
  // out has x_nnz rows (it is shaped like x's values) and one row is written
  // per mask entry. More mask entries than x entries would therefore write
  // out of bounds.
  PADDLE_ENFORCE_LE(
      mask_nnz,
      x_nnz,
      phi::errors::InvalidArgument(
          "the number of mask indices (%d) must not exceed the number of "
          "non-zero elements of x (%d)",
          mask_nnz,
          x_nnz));

  auto indices_dtype = paddle::experimental::CppTypeToDataType<IntT>::Type();

  std::vector<IntT> sparse_offsets(sparse_dim);

  DenseTensorMeta x_indexs_meta(indices_dtype, {x_nnz}, DataLayout::NCHW);
  DenseTensorMeta mask_indexs_meta(
      indices_dtype, {mask_nnz}, DataLayout::NCHW);
  DenseTensorMeta sparse_offset_meta(
      indices_dtype, {sparse_dim}, DataLayout::NCHW);

  DenseTensor x_indexs =
      phi::Empty<GPUContext>(dev_ctx, std::move(x_indexs_meta));
  // bound_out shares mask_indexs' meta. It receives its own copy so the meta
  // is never read after being moved from.
  DenseTensor bound_out =
      phi::Empty<GPUContext>(dev_ctx, DenseTensorMeta(mask_indexs_meta));
  DenseTensor mask_indexs =
      phi::Empty<GPUContext>(dev_ctx, std::move(mask_indexs_meta));
  DenseTensor d_sparse_offsets =
      phi::Empty<GPUContext>(dev_ctx, std::move(sparse_offset_meta));

  // Output starts at zero so mask entries absent from x read as 0.
  *out = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
  phi::funcs::SetConstant<GPUContext, T> set_zero;
  set_zero(dev_ctx, out, static_cast<T>(0));

  // With no mask entries there is nothing to look up. The x_nnz == 0 case
  // also lands here because of the check above. Returning early avoids
  // zero-block launches, which are invalid configurations.
  if (mask_nnz == 0) {
    return;
  }
  T* out_ptr = out->data<T>();

  IntT* x_indexs_ptr = x_indexs.data<IntT>();
  IntT* mask_indexs_ptr = mask_indexs.data<IntT>();
  IntT* bound_out_ptr = bound_out.data<IntT>();

  // 1. calc the offsets of per dim
  phi::funcs::sparse::CalcOffsetsPerDim(
      x.dims(), sparse_dim, sparse_offsets.data());

  // 2. copy sparse_offsets to device
  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
                                     sparse_offsets.data(),
                                     sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyHostToDevice,
#else
                                     cudaMemcpyHostToDevice,
#endif
                                     dev_ctx.stream());

  // 3. flatten x indices and mask indices
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      x.non_zero_indices().data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      x_nnz,
      sparse_dim,
      x_indexs_ptr);

  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_nnz, 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      mask_indices.data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      mask_nnz,
      sparse_dim,
      mask_indexs_ptr);

// 4. call thrust::lower_bound
#ifdef PADDLE_WITH_HIP
  thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                      x_indexs_ptr,
                      x_indexs_ptr + x_nnz,
                      mask_indexs_ptr,
                      mask_indexs_ptr + mask_nnz,
                      bound_out_ptr);

  // 5. copy value to out
  // A value row holds the product of the dense (trailing) dims of x. The
  // count of dense dims (rank - sparse_dim) is not the row length.
  const int64_t stride =
      flatten_to_2d(x.dims(), static_cast<int>(sparse_dim))[1];

  SparseMaskCopyKernel<<<config.block_per_grid,
                         config.thread_per_block,
                         0,
                         dev_ctx.stream()>>>(x_indexs_ptr,
                                             mask_indexs_ptr,
                                             bound_out_ptr,
                                             x.non_zero_elements().data<T>(),
                                             x_nnz,
                                             mask_nnz,
                                             stride,
                                             out_ptr);
}

/**
 * @brief Entry point of sparse_mask_helper.
 *
 * Dispatches on the integral dtype of x's COO indices.
 * SparseMaskHelperGPUKernel reads mask_indices with that same index type, so
 * mask_indices presumably has to share x's index dtype (TODO confirm).
 */
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            const DenseTensor& mask_indices,
                            DenseTensor* out) {
  const auto index_dtype = x.non_zero_indices().dtype();
  PD_VISIT_INTEGRAL_TYPES(index_dtype, "SparseMaskHelperGPUKernel", ([&] {
                            SparseMaskHelperGPUKernel<T, data_t>(
                                dev_ctx, x, mask_indices, out);
                          }));
}

}  // namespace sparse
}  // namespace phi

// Registers phi::sparse::SparseMaskKernel as the GPU `sparse_mask` kernel for
// the value dtypes listed below. The index dtype is dispatched at run time
// inside the kernel.
PD_REGISTER_KERNEL(sparse_mask,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseMaskKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  // Input 1 (`mask`) is a SparseCooTensor.
  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// Registers phi::sparse::SparseMaskHelperKernel as the GPU
// `sparse_mask_helper` kernel for the value dtypes listed below.
// NOTE(review): unlike `sparse_mask`, int8_t is not registered here. Confirm
// whether this is intentional (e.g. a missing SetConstant instantiation).
PD_REGISTER_KERNEL(sparse_mask_helper,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseMaskHelperKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {
  // Input 0 (`x`) is a SparseCooTensor.
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}