mask_kernel.cu 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Z
zhangkaihuo 已提交
15
#include "paddle/phi/kernels/sparse/mask_kernel.h"
16

17 18 19 20 21
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
22
#include "paddle/phi/core/tensor_utils.h"
23
#include "paddle/phi/core/visit_type.h"
24
#include "paddle/phi/kernels/empty_kernel.h"
Z
zhangkaihuo 已提交
25
#include "paddle/phi/kernels/funcs/aligned_vector.h"
26
#include "paddle/phi/kernels/funcs/math_function.h"
27
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44

namespace phi {
namespace sparse {

/**
 * Gather kernel used by SparseMaskGPUKernel.
 *
 * Every flat output position `flat` in [0, non_zero_num * cols) is split into
 * a non-zero entry `nz = flat / cols` and a column `col = flat % cols`. The
 * coordinates of entry `nz` are read from `indices_ptr`, which is laid out as
 * [sparse_dim, non_zero_num]. Their dot product with `sparse_offsets` gives
 * the linear row `row`, and x_ptr[row * cols + col] is copied to
 * out_values_ptr[flat].
 *
 * Launch: 1-D; positions are visited via CUDA_KERNEL_LOOP_TYPE with a 64-bit
 * index.
 */
template <typename T, typename IntT>
__global__ void MaskKernel(const T* x_ptr,
                           const IntT* indices_ptr,
                           const int64_t* sparse_offsets,
                           const int64_t non_zero_num,
                           const int cols,
                           const int sparse_dim,
                           T* out_values_ptr) {
  CUDA_KERNEL_LOOP_TYPE(flat, non_zero_num * cols, int64_t) {
    const int64_t nz = flat / cols;
    const int64_t col = flat % cols;
    // Linearize the sparse coordinates of entry `nz`.
    int64_t row = 0;
    for (int d = 0; d < sparse_dim; ++d) {
      row += indices_ptr[d * non_zero_num + nz] * sparse_offsets[d];
    }
    // nz * cols + col == flat, so the output can be addressed directly.
    out_values_ptr[flat] = x_ptr[row * cols + col];
  }
}

/**
 * Gathers the elements of dense tensor `x` at the non-zero coordinates of
 * `mask` and stores them in `out` as a SparseCooTensor.
 *
 * `out` receives:
 *   - a copy of mask's indices (dtype IntT);
 *   - values of dtype T, shaped like mask's values;
 *   - the dense shape of `x`.
 *
 * @param dev_ctx GPU context; copies and the kernel launch are enqueued on its
 *                stream.
 * @param x       dense input; must have exactly the same dims as `mask`.
 * @param mask    sparse COO tensor whose indices select elements of `x`.
 * @param out     output COO tensor.
 */
template <typename T, typename IntT>
void SparseMaskGPUKernel(const GPUContext& dev_ctx,
                         const DenseTensor& x,
                         const SparseCooTensor& mask,
                         SparseCooTensor* out) {
  const DDim& dims = x.dims();
  PADDLE_ENFORCE_EQ(x.dims(),
                    mask.dims(),
                    phi::errors::InvalidArgument(
                        "the input x and mask must have the same shape"));
  const DenseTensor& indices = mask.non_zero_indices();
  const DenseTensor& values = mask.non_zero_elements();

  // Indices keep the mask's integer type (IntT); values take x's type (T).
  DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, indices);
  DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);

  const int64_t non_zero_num = mask.nnz();
  if (non_zero_num <= 0) {
    // Empty mask: there is nothing to gather, and a 1-D launch config for
    // zero elements must not be requested.
    out->SetMember(out_indices, out_values, dims, true);
    return;
  }

  phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);

  // Per-dim offsets of the sparse dims. MaskKernel dots each coordinate with
  // them to obtain the row of x to read from.
  const int sparse_dim = mask.sparse_dim();
  std::vector<int64_t> h_sparse_offsets(sparse_dim);
  phi::funcs::sparse::CalcOffsetsPerDim(
      dims, sparse_dim, h_sparse_offsets.data());
  DenseTensor sparse_offsets = phi::Empty<GPUContext>(
      dev_ctx,
      DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW));
  phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(),
                                     h_sparse_offsets.data(),
                                     sizeof(int64_t) * sparse_dim,
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  // flatten_to_2d splits dims at sparse_dim; cols is the size of the
  // trailing (dense) part, i.e. the elements per sparse coordinate.
  const auto dims_2d = flatten_to_2d(dims, sparse_dim);
  const int cols = dims_2d[1];
  if (cols > 0) {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, non_zero_num * cols, 1);
    MaskKernel<T, IntT>
        <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
            x.data<T>(),
            indices.data<IntT>(),
            sparse_offsets.data<int64_t>(),
            non_zero_num,
            cols,
            sparse_dim,
            out_values.data<T>());
  }

  out->SetMember(out_indices, out_values, dims, true);
}

/**
 * @brief Masks the dense tensor `x` with the sparsity pattern of `mask`.
 *
 * Produces a SparseCooTensor whose indices come from
 * `mask.non_zero_indices()` and whose values are the elements of `x` at those
 * coordinates. `x` and `mask` must have the same shape.
 *
 * Dispatches on the integer dtype of the mask indices (bound to `data_t` by
 * PD_VISIT_BASE_INTEGRAL_TYPES) and forwards to SparseMaskGPUKernel.
 **/
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const SparseCooTensor& mask,
                      SparseCooTensor* out) {
  const auto index_dtype = mask.non_zero_indices().dtype();
  PD_VISIT_BASE_INTEGRAL_TYPES(
      index_dtype, "SparseMaskGPUKernel", ([&] {
        SparseMaskGPUKernel<T, data_t>(dev_ctx, x, mask, out);
      }));
}

/**
 * Builds a lookup table from linearized sparse coordinate to x's entry index.
 *
 * For each entry i of `x_indexs` (0 <= i < n), its position i is stored at
 * table[x_indexs[i]]. The consumer (MaskCopy) treats 0 as "absent", so
 * position 0 is stored as the sentinel -1 instead.
 */
template <typename IntT>
__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    const int slot = x_indexs[i];
    const int position = (i == 0) ? -1 : static_cast<int>(i);
    table[slot] = position;
  }
}

/**
 * Gathers value rows of x into out according to the table built by MaskTable.
 *
 * For each mask entry i (0 <= i < n), j = table[mask_indexs[i]] encodes where
 * the same linearized coordinate appears in x:
 *   - 0 means absent;
 *   - -1 means row 0;
 *   - any other value is the row index itself.
 *
 * Present rows are copied `stride` elements at a time, VecSize elements per
 * load/store. Rows that are absent are not written.
 *
 * Assumes stride is a multiple of VecSize (the inner loop steps by VecSize).
 */
template <typename T, typename IntT, int VecSize>
__global__ void MaskCopy(const IntT* mask_indexs,
                         const int* table,
                         const int n,
                         const int stride,
                         const T* x_values,
                         T* out_values) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    int j = table[mask_indexs[i]];
    if (j != 0) {
      if (j == -1) j = 0;  // -1 is the sentinel for x row 0
      // Compute row offsets in 64 bits: j * stride can exceed INT_MAX.
      const T* src = x_values + static_cast<int64_t>(j) * stride;
      T* dst = out_values + i * stride;
      for (int k = 0; k < stride; k += VecSize) {
        LoadT vec_x;
        phi::Load<T, VecSize>(src + k, &vec_x);
        phi::Store<T, VecSize>(vec_x, dst + k);
      }
    }
  }
}

/**
 * For every coordinate in `mask_indices`, looks up the matching non-zero of
 * the COO tensor `x` and copies its value row into `out`.
 *
 * Steps:
 *   1. Compute per-dim offsets of x's sparse dims and copy them to the device.
 *   2. Linearize the coordinates of x and of mask_indices
 *      (FlattenIndicesKernel).
 *   3. Build a lookup table indexed by linearized coordinate (MaskTable). The
 *      table holds prod(sparse dims of x) ints.
 *   4. For each mask entry, copy the matching x value row, if any (MaskCopy).
 *
 * @param dev_ctx      GPU context; all work is enqueued on its stream.
 * @param x            sparse COO input whose values are gathered.
 * @param mask_indices 2-D indices tensor with mask_indices.dims()[1] entries.
 *                     Presumably laid out like x's indices
 *                     ([sparse_dim, mask_nnz]) — TODO confirm with callers.
 * @param out          dense tensor shaped like x.non_zero_elements(). Row i
 *                     holds the x values at mask coordinate i, or zeros if x
 *                     has no entry there.
 *
 * NOTE(review): `out` has x.nnz() rows but is written at mask row i, so this
 * assumes mask_indices has no more entries than x has non-zeros.
 */
template <typename T, typename IntT>
void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                               const SparseCooTensor& x,
                               const DenseTensor& mask_indices,
                               DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      mask_indices.dims().size(),
      2,
      phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));

  const int32_t sparse_dim = x.sparse_dim();
  const DDim x_dims = x.dims();
  const DenseTensor& x_values = x.non_zero_elements();
  const int64_t x_nnz = x.nnz();
  const int64_t mask_nnz = mask_indices.dims()[1];

  // Elements per value row = product of all dense (non-sparse) dims. This is
  // 1 when x has no dense dims, because the values are then 1-D.
  const DDim values_dims = x_values.dims();
  int64_t stride = 1;
  for (int i = 1; i < values_dims.size(); ++i) {
    stride *= values_dims[i];
  }

  // The output mirrors x's values and is zero-filled, so rows whose mask
  // coordinate has no match in x stay zero.
  *out = phi::EmptyLike<T>(dev_ctx, x_values);
  if (x_nnz == 0) {
    // x has no value rows: out is empty and no kernel needs to run (a 1-D
    // launch config for zero elements must not be requested).
    return;
  }
  phi::funcs::SetConstant<GPUContext, T> set_zero;
  set_zero(dev_ctx, out, static_cast<T>(0));
  if (mask_nnz == 0) {
    // Nothing to look up.
    return;
  }
  T* out_ptr = out->data<T>();

  auto indices_dtype = paddle::experimental::CppTypeToDataType<IntT>::Type();
  DenseTensorMeta x_indexs_meta(indices_dtype, {x_nnz}, DataLayout::NCHW);
  DenseTensorMeta mask_indexs_meta(indices_dtype, {mask_nnz}, DataLayout::NCHW);
  DenseTensorMeta sparse_offset_meta(
      indices_dtype, {sparse_dim}, DataLayout::NCHW);

  DenseTensor x_indexs =
      phi::Empty<GPUContext>(dev_ctx, std::move(x_indexs_meta));
  DenseTensor mask_indexs =
      phi::Empty<GPUContext>(dev_ctx, std::move(mask_indexs_meta));
  DenseTensor d_sparse_offsets =
      phi::Empty<GPUContext>(dev_ctx, std::move(sparse_offset_meta));
  IntT* x_indexs_ptr = x_indexs.data<IntT>();
  IntT* mask_indexs_ptr = mask_indexs.data<IntT>();

  // 1. Per-dim offsets of the sparse dims, copied to the device.
  std::vector<IntT> sparse_offsets(sparse_dim);
  phi::funcs::sparse::CalcOffsetsPerDim(
      x_dims, sparse_dim, sparse_offsets.data());
  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
                                     sparse_offsets.data(),
                                     sizeof(IntT) * sparse_dim,
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  // 2. Linearize x indices and mask indices.
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      x.non_zero_indices().data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      x_indexs.numel(),
      sparse_dim,
      x_indexs_ptr);

  config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      mask_indices.data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      mask_indexs.numel(),
      sparse_dim,
      mask_indexs_ptr);

  // 3. Lookup table over linearized coordinates. Linearized indices range
  // over the product of the *sparse* dims only, so size the table by them.
  // (Sizing by all-but-the-last dim overflowed when x had no dense dims.)
  int64_t table_size = 1;
  for (int i = 0; i < sparse_dim; ++i) {
    table_size *= x_dims[i];
  }
  DenseTensor table = phi::Empty<int>(dev_ctx, {table_size});
  phi::backends::gpu::GpuMemsetAsync(
      table.data<int>(), 0, table_size * sizeof(int), dev_ctx.stream());
  config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1);
  MaskTable<<<config.block_per_grid,
              config.thread_per_block,
              0,
              dev_ctx.stream()>>>(
      x_indexs_ptr, x_indexs.numel(), table.data<int>());

  // 4. Gather the matching value rows, using 16-byte vector accesses when the
  // row length allows it.
  config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
  constexpr int kVecBytes = 16;
  constexpr int kVecSize = kVecBytes / sizeof(T);
  const T* x_values_ptr = x_values.data<T>();
  if (stride % kVecSize == 0) {
    MaskCopy<T, IntT, kVecSize>
        <<<config.block_per_grid,
           config.thread_per_block,
           0,
           dev_ctx.stream()>>>(mask_indexs_ptr,
                               table.data<int>(),
                               mask_indexs.numel(),
                               static_cast<int>(stride),
                               x_values_ptr,
                               out_ptr);
  } else {
    MaskCopy<T, IntT, 1><<<config.block_per_grid,
                           config.thread_per_block,
                           0,
                           dev_ctx.stream()>>>(mask_indexs_ptr,
                                               table.data<int>(),
                                               mask_indexs.numel(),
                                               static_cast<int>(stride),
                                               x_values_ptr,
                                               out_ptr);
  }
}

/**
 * Gathers the values of the COO tensor `x` at the coordinates listed in
 * `mask_indices` into the dense tensor `out`.
 *
 * Dispatches on the integer dtype of x's indices (bound to `data_t` by
 * PD_VISIT_BASE_INTEGRAL_TYPES) and forwards to SparseMaskHelperGPUKernel.
 */
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            const DenseTensor& mask_indices,
                            DenseTensor* out) {
  const auto index_dtype = x.non_zero_indices().dtype();
  PD_VISIT_BASE_INTEGRAL_TYPES(
      index_dtype, "SparseMaskHelperGPUKernel", ([&] {
        SparseMaskHelperGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
      }));
}

}  // namespace sparse
}  // namespace phi

// GPU registration of the `mask` kernel. Input 1 (the mask) is declared with
// the SPARSE_COO layout.
PD_REGISTER_KERNEL(mask,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseMaskKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// GPU registration of the `mask_helper` kernel. Input 0 (x) is declared with
// the SPARSE_COO layout.
// NOTE(review): int8_t is registered for `mask` but not here; confirm whether
// this is intentional (e.g. a missing SetConstant instantiation).
PD_REGISTER_KERNEL(mask_helper,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseMaskHelperKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}