coalesced_kernel.cu 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
#include "paddle/phi/kernels/sparse/coalesced_kernel.h"

namespace phi {
namespace sparse {

/**
 * Coalesces a COO sparse tensor on the GPU.
 *
 * Steps:
 *   1. Flatten every index column of x into one linear index.
 *   2. Sort the linear indices, carrying each entry's original position.
 *   3. Deduplicate the sorted indices.
 *   4. Combine the values of duplicate coordinates with ScatterKernel
 *      (presumably a sum; see funcs/sparse/scatter.cu.h).
 *   5. Convert the unique linear indices back into coordinates.
 *
 * @tparam T     element type of x's non-zero values.
 * @tparam IntT  integer type of x's COO indices.
 * @param dev_ctx  GPU context; every kernel, copy and thrust call is issued on
 *                 dev_ctx.stream().
 * @param x        input COO tensor. Indices are [sparse_dim, nnz]. Values are
 *                 [nnz, ...], one row of `stride` elements per non-zero.
 * @param out      receives the deduplicated indices/values with x's dims. The
 *                 trailing `true` passed to SetMember presumably flags it as
 *                 coalesced.
 *
 * The host blocks once (dev_ctx.Wait()) to read back the number of unique
 * coordinates, which is needed to size the outputs.
 */
template <typename T, typename IntT>
void CoalescedGPUKernel(const GPUContext& dev_ctx,
                        const SparseCooTensor& x,
                        SparseCooTensor* out) {
  const DenseTensor& x_indices = x.non_zero_indices();
  const DenseTensor& x_values = x.non_zero_elements();
  DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
  DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x_values);

  const int64_t nnz = x.nnz();
  // An empty tensor is already coalesced. Return before launching anything.
  // A 0-element 1-D launch config yields an empty grid, and DistanceKernel
  // below would write its result into the 0-element out_indices buffer.
  if (nnz == 0) {
    out->SetMember(out_indices, out_values, x.dims(), true);
    return;
  }

  const int64_t sparse_dim = x_indices.dims()[0];
  std::vector<IntT> sparse_offsets(sparse_dim);

  phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
      x.dims(), sparse_dim, sparse_offsets.data());

  DenseTensorMeta sparse_offset_meta(
      paddle::experimental::CppTypeToDataType<IntT>::Type(),
      {sparse_dim},
      DataLayout::NCHW);
  DenseTensor d_sparse_offsets =
      phi::Empty<GPUContext>(dev_ctx, std::move(sparse_offset_meta));
  DenseTensor indexs = phi::Empty(
      dev_ctx, DenseTensorMeta(x_indices.dtype(), {nnz}, x_indices.layout()));
  IntT* indexs_ptr = indexs.data<IntT>();

  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
                                     sparse_offsets.data(),
                                     sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyHostToDevice,
#else
                                     cudaMemcpyHostToDevice,
#endif
                                     dev_ctx.stream());

  // 1. flatten indices: one linear index per non-zero entry.
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz, 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      x_indices.data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      indexs.numel(),
      sparse_dim,
      indexs_ptr);

  // 2. number of value elements stored per non-zero entry.
  // This is the product of all trailing (dense) dims of x_values, and 1 for
  // 1-D values. It must be an element count: using the number of dense dims
  // (rank difference) would give ScatterKernel the wrong row width.
  int64_t stride = 1;
  for (int d = 1; d < x_values.dims().size(); ++d) {
    stride *= x_values.dims()[d];
  }

  const T* x_values_ptr = x_values.data<T>();
  // NOTE(review): the position buffers are int32, so this assumes nnz fits
  // in an int.
  DenseTensor values_indexs = phi::Empty(
      dev_ctx, DenseTensorMeta(DataType::INT32, {nnz}, DataLayout::NCHW));
  int* values_indexs_ptr = values_indexs.data<int>();
  DenseTensor public_indexs = phi::EmptyLike<int>(dev_ctx, values_indexs);

  // values_indexs = public_indexs = [0, 1, 2, ..., nnz-1]
  phi::IndexKernel<int, kps::IdentityFunctor<int>>(
      dev_ctx, &values_indexs, kps::IdentityFunctor<int>());
  phi::IndexKernel<int, kps::IdentityFunctor<int>>(
      dev_ctx, &public_indexs, kps::IdentityFunctor<int>());

#ifdef PADDLE_WITH_HIP
  const auto exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  const auto exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif

  // 3. sort linear indices, carrying each entry's original position along.
  thrust::sort_by_key(
      exec_policy, indexs_ptr, indexs_ptr + nnz, values_indexs_ptr);

  // 4. unique index. public_indexs[k] ends up as the sorted position where
  // the k-th unique index first appears.
  thrust::pair<IntT*, int*> new_end = thrust::unique_by_key(
      exec_policy, indexs_ptr, indexs_ptr + nnz, public_indexs.data<int>());

  // The unique count is written to device memory, then read back to the host.
  phi::funcs::sparse::DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>(
      indexs_ptr, new_end.first, out_indices.data<IntT>());

  IntT out_nnz = 0;
  phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
                                     out_indices.data<IntT>(),
                                     sizeof(IntT),
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyDeviceToHost,
#else
                                     cudaMemcpyDeviceToHost,
#endif
                                     dev_ctx.stream());
  // out_nnz is a plain host variable; wait until the copy has landed.
  dev_ctx.Wait();

  out_indices.Resize({x_indices.dims()[0], out_nnz});
  // Keep every trailing (dense) dim of the values; only the row count shrinks.
  DDim out_values_dims = x_values.dims();
  out_values_dims[0] = out_nnz;
  out_values.Resize(out_values_dims);

  // 5. scatter the values. Only out_nnz * stride output elements exist.
  config = phi::backends::gpu::GetGpuLaunchConfig1D(
      dev_ctx, static_cast<int64_t>(out_nnz) * stride, 1);
  phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid,
                                         config.thread_per_block,
                                         0,
                                         dev_ctx.stream()>>>(
      x_values_ptr,
      public_indexs.data<int>(),
      values_indexs_ptr,
      out_nnz,
      nnz,
      stride,
      out_values.data<T>());

  // 6. convert each unique linear index back into a coordinate column.
  Dim<DDim::kMaxRank> const_dims;
  for (int i = 0; i < x.dims().size(); i++) {
    const_dims[i] = x.dims()[i];
  }

  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
  phi::funcs::sparse::IndexToCoordinateKernel<<<config.block_per_grid,
                                                config.thread_per_block,
                                                0,
                                                dev_ctx.stream()>>>(
      indexs_ptr, const_dims, out_nnz, sparse_dim, out_indices.data<IntT>());

  out->SetMember(out_indices, out_values, x.dims(), true);
}

/**
 * Entry point for the GPU coalesce kernel.
 *
 * Reads the integer dtype of x's COO indices and forwards to
 * CoalescedGPUKernel<T, data_t>. `data_t` is the index type supplied by
 * PD_VISIT_INTEGRAL_TYPES for the matching dtype.
 */
template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
                     const SparseCooTensor& x,
                     SparseCooTensor* out) {
  const auto index_dtype = x.non_zero_indices().dtype();
  PD_VISIT_INTEGRAL_TYPES(index_dtype, "CoalescedGPUKernel", ([&] {
                            CoalescedGPUKernel<T, data_t>(dev_ctx, x, out);
                          }));
}

}  // namespace sparse
}  // namespace phi

// Registers phi::sparse::CoalescedKernel as a GPU kernel for each value type
// listed below. The index dtype is not part of the registration; it is
// dispatched at runtime inside CoalescedKernel.
// NOTE(review): the kernel is registered under the name `sort`, which does
// not match CoalescedKernel. Anything that looks this kernel up by name
// depends on this string. Confirm the intended name against the rest of the
// codebase (e.g. a CPU registration) before changing it.
PD_REGISTER_KERNEL(sort,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CoalescedKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {
  // Input 0 is declared to use the sparse COO layout.
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}