/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

#include <thrust/execution_policy.h>
#include <thrust/remove.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"

namespace phi {
namespace sparse {

/**
 * Returns true when every one of the `cols` consecutive elements starting at
 * `data` compares equal to T(0); returns false at the first element that
 * does not. The scan is sequential inside the calling thread.
 */
template <typename T>
inline __device__ bool DevIsZero(const T* data, const int64_t cols) {
  // TODO(zhangkaihuo): check the data is zero or not in parallel when cols > 1
  const T kZero = static_cast<T>(0);
  const T* const stop = data + cols;
  for (const T* p = data; p < stop; ++p) {
    if (*p != kZero) {
      return false;
    }
  }
  return true;
}

/**
 * Classifies each row of a row-major (rows x cols) dense buffer.
 *
 * Launch: 1-D grid; rows are visited with a grid-stride loop. For row i,
 * temp_indexs[i] is set to i when any of its `cols` elements is non-zero and
 * to -1 otherwise. The count of non-zero rows is accumulated in a per-block
 * __shared__ counter and then added to *non_zero_num with one atomicAdd per
 * block, so *non_zero_num must be zero-initialized before the launch.
 */
template <typename T>
__global__ void GetNonZeroNums(const T* dense_data,
                               const int rows,
                               const int cols,
                               int* non_zero_num,
                               int* temp_indexs) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  __shared__ int counter;
  if (threadIdx.x == 0) counter = 0;
  __syncthreads();

  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
    int index = -1;
    // Widen before multiplying: rows * cols (the element count) can exceed
    // INT_MAX, and an int product would silently wrap.
    const T* row = dense_data + static_cast<int64_t>(i) * cols;
    // TODO(zhangkaihuo): when cols=1, vectorization can be used
    if (!DevIsZero(row, cols)) {
      atomicAdd(&counter, 1);
      index = i;
    }
    temp_indexs[i] = index;
  }
  // Every thread reaches this barrier (it is outside the loop), so it is safe.
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(non_zero_num, counter);
  }
}

/**
 * Gathers the non-zero rows selected by `indexs` into COO indices and values.
 *
 * Launch: 1-D grid, grid-stride loop over the `non_zero_num` selected rows.
 * For selected slot n, the flat row id indexs[n] is unravelled (last dim
 * first) against x_dims[0..sparse_dim) and coordinate d is written to
 * indices[d * non_zero_num + n]. The row's `cols` dense elements are copied to
 * sparse_data[n * cols ...].
 */
template <typename T>
__global__ void GetNonZeroElementsAndIndices(const T* dense_data,
                                             const int64_t sparse_dim,
                                             const int64_t cols,
                                             const int64_t* x_dims,
                                             const int non_zero_num,
                                             const int* indexs,
                                             int64_t* indices,
                                             T* sparse_data) {
  const int stride = gridDim.x * blockDim.x;
  for (int n = threadIdx.x + blockIdx.x * blockDim.x; n < non_zero_num;
       n += stride) {
    const int64_t flat_row = indexs[n];

    // Unravel the flat row id into per-dimension coordinates.
    int64_t remaining = flat_row;
    for (int64_t d = sparse_dim - 1; d >= 0; --d) {
      const int64_t extent = x_dims[d];
      indices[d * non_zero_num + n] = remaining % extent;
      remaining /= extent;
    }

    // Copy the dense tail of this row into the packed values buffer.
    const T* src = dense_data + flat_row * cols;
    T* dst = sparse_data + n * cols;
    for (int c = 0; c < cols; ++c) {
      dst[c] = src[c];
    }
  }
}

/**
 * Converts a dense tensor into a SparseCooTensor on the GPU.
 *
 * The first `sparse_dim` dims of x become sparse coordinates; the remaining
 * dims stay dense inside each stored value. x is viewed as a (rows, cols)
 * matrix via flatten_to_2d. A row is stored iff any of its `cols` elements is
 * non-zero, and each stored row's flat id is unravelled back into
 * `sparse_dim` coordinates.
 *
 * @param dev_ctx    GPU context; all device work is issued on its stream.
 * @param x          Dense input.
 * @param sparse_dim Number of leading dims to sparsify; must be in
 *                   (0, x.dims().size()].
 * @param out        Receives int64 indices of shape [sparse_dim, nnz] and
 *                   values whose dims come from InferDenseDims.
 *
 * The host blocks once (dev_ctx.Wait()) to read back the non-zero count.
 * An input with no non-zero rows yields a valid COO tensor with nnz == 0.
 * Kernel launches are skipped when their element count is zero, because a
 * 0-block grid is an invalid launch configuration.
 */
template <typename T, typename Context>
void DenseToSparseCooKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const int64_t sparse_dim,
                            SparseCooTensor* out) {
  const T* x_data = x.data<T>();
  const auto& x_dims = x.dims();
  PADDLE_ENFORCE_LE(sparse_dim,
                    x_dims.size(),
                    phi::errors::InvalidArgument(
                        "sparse_dim must be less than the size of x.dims()"));
  PADDLE_ENFORCE_GT(
      sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0"));

  auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
  const int rows = dims_2d[0];
  const int cols = dims_2d[1];

  DenseTensor nums = phi::Empty<int32_t>(dev_ctx, {1});
  DenseTensor d_x_dims = phi::Empty<int64_t>(dev_ctx, {x_dims.size()});

  // 1. get numbers of non zero elements, and get the index of non zero elements
  int* nums_ptr = nums.data<int>();
  phi::backends::gpu::GpuMemsetAsync(
      nums_ptr, 0, sizeof(int), dev_ctx.stream());

  DenseTensor temp_indexs = phi::Empty<int32_t>(dev_ctx, {rows});
  int* temp_indexs_ptr = temp_indexs.data<int>();

  if (rows > 0) {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);
    GetNonZeroNums<<<config.block_per_grid.x,
                     config.thread_per_block.x,
                     0,
                     dev_ctx.stream()>>>(
        x_data, rows, cols, nums_ptr, temp_indexs_ptr);

    // Compact the row ids: drop the -1 markers of all-zero rows so the first
    // non_zero_num entries hold the ids of the rows to keep.
#ifdef PADDLE_WITH_HIP
    thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
    thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                   temp_indexs_ptr,
                   temp_indexs_ptr + rows,
                   -1);
  }

  // 2. copy non_zero_num to host, copy x_dims to device
  int non_zero_num = 0;
  phi::backends::gpu::GpuMemcpyAsync(&non_zero_num,
                                     nums_ptr,
                                     sizeof(int),
                                     gpuMemcpyDeviceToHost,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(d_x_dims.data<int64_t>(),
                                     x_dims.Get(),
                                     x_dims.size() * sizeof(x_dims[0]),
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  dev_ctx.Wait();  // wait the copy: non_zero_num is read on the host below

  const auto values_dims =
      phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num);
  phi::DenseTensor indices = phi::Empty<int64_t>(
      dev_ctx, {sparse_dim, static_cast<int64_t>(non_zero_num)});
  int64_t* indices_data = indices.data<int64_t>();
  phi::DenseTensor values;
  values.Resize(values_dims);
  T* sparse_data = dev_ctx.template Alloc<T>(&values);

  // 3. calc indices by indexs and get values by indexs
  if (non_zero_num > 0) {
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
    GetNonZeroElementsAndIndices<<<config.block_per_grid.x,
                                   config.thread_per_block.x,
                                   0,
                                   dev_ctx.stream()>>>(x_data,
                                                       sparse_dim,
                                                       cols,
                                                       d_x_dims.data<int64_t>(),
                                                       non_zero_num,
                                                       temp_indexs_ptr,
                                                       indices_data,
                                                       sparse_data);
  }

  out->SetMember(indices, values, x_dims, true);
}

/**
 * Reads the per-batch non-zero count out of batched CSR row pointers.
 *
 * Launch: 1-D, one thread per batch (threads with id >= batchs exit). Each
 * batch owns rows + 1 consecutive crows entries, and the last of them is that
 * batch's nnz; it is written to batch_sizes[batch].
 */
template <typename IntT>
__global__ void GetBatchSizes(const IntT* crows,
                              const int rows,
                              const int batchs,
                              IntT* batch_sizes) {
  const int batch = blockIdx.x * blockDim.x + threadIdx.x;
  if (batch >= batchs) {
    return;
  }
  const IntT* batch_crows = crows + batch * (rows + 1);
  batch_sizes[batch] = batch_crows[rows];
}

/**
 * Expands batched CSR row pointers into explicit COO row (and batch) ids.
 *
 * Launch: x covers rows with a grid-stride loop; blockIdx.y selects the
 * batch. Batch b's pointers are crows_ptr[b * (rows + 1) ...]. Its output
 * starts at crows_offsets[b], or at 0 when crows_offsets is null. For each
 * non-zero j of row i, rows_ptr[offset + j] = i, and batch_ptr (if non-null)
 * receives b.
 *
 * The inner counter uses IntT rather than int so 64-bit row pointers are not
 * truncated.
 */
template <typename IntT>
__global__ void ConvertCsrCrowsToCooRows(const IntT* crows_ptr,
                                         const IntT* crows_offsets,
                                         IntT* rows_ptr,
                                         IntT* batch_ptr,
                                         const int rows) {
  const int b = blockIdx.y;
  const int64_t offset = crows_offsets ? crows_offsets[b] : 0;
  const IntT* batch_crows = crows_ptr + b * (rows + 1);
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
    const IntT row_begin = batch_crows[i];
    const IntT row_end = batch_crows[i + 1];
    for (IntT j = row_begin; j < row_end; ++j) {
      rows_ptr[offset + j] = i;
      if (batch_ptr) {
        batch_ptr[offset + j] = b;
      }
    }
  }
}

/**
 * Converts a 2-D or 3-D (batched) SparseCsrTensor into a SparseCooTensor.
 *
 * COO indices are laid out as [sparse_dim, nnz]. For 3-D input, row 0 holds
 * batch ids, row 1 row ids and row 2 column ids. For 2-D input there is no
 * batch row. Column ids and values are copied through unchanged.
 *
 * Every kernel, thrust call and copy is issued on dev_ctx.stream(). The
 * original launched GetBatchSizes and ConvertCsrCrowsToCooRows on the default
 * stream, which left them unordered relative to the thrust scan and memcpys
 * on the context stream.
 */
template <typename T, typename IntT>
void SparseCsrToCooGPUKernel(const GPUContext& dev_ctx,
                             const SparseCsrTensor& x,
                             SparseCooTensor* out) {
  const DDim& x_dims = x.dims();
  const int64_t non_zero_num = x.non_zero_cols().numel();
  const auto& csr_crows = x.non_zero_crows();
  const auto& csr_cols = x.non_zero_cols();
  const auto& csr_values = x.values();
  const IntT* csr_crows_data = csr_crows.data<IntT>();
  const IntT* csr_cols_data = csr_cols.data<IntT>();
  const T* csr_values_data = csr_values.data<T>();

  int64_t sparse_dim = 2;
  if (x_dims.size() == 3) {
    sparse_dim = 3;
  }
  int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

  DenseTensor indices = phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
  DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, csr_values);
  DenseTensor offsets = phi::Empty<IntT>(dev_ctx, {batchs});
  IntT* coo_indices = indices.data<IntT>();
  IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
  IntT* coo_rows_data =
      x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
  IntT* coo_cols_data = coo_rows_data + non_zero_num;
  IntT* offsets_ptr = batchs == 1 ? nullptr : offsets.data<IntT>();
  T* coo_values_data = values.data<T>();

  if (batchs > 1) {
    // Per-batch nnz, then an exclusive scan gives each batch's start offset.
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
    GetBatchSizes<IntT><<<config.block_per_grid.x,
                          config.thread_per_block.x,
                          0,
                          dev_ctx.stream()>>>(
        csr_crows_data, rows, batchs, offsets_ptr);

#ifdef PADDLE_WITH_HIP
    thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                           offsets_ptr,
                           offsets_ptr + batchs,
                           offsets_ptr);
  }

  // A 0-sized grid dimension is an invalid launch configuration.
  if (rows > 0 && batchs > 0) {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);
    config.block_per_grid.y = batchs;
    ConvertCsrCrowsToCooRows<IntT><<<config.block_per_grid,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows);
  }

  phi::backends::gpu::GpuMemcpyAsync(coo_cols_data,
                                     csr_cols_data,
                                     sizeof(IntT) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(coo_values_data,
                                     csr_values_data,
                                     sizeof(T) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());

  out->SetMember(indices, values, x_dims, true);
}

/**
 * CSR -> COO entry point. It picks the integer index type from the dtype of
 * x's row-pointer tensor and forwards to SparseCsrToCooGPUKernel. Inside the
 * lambda, `data_t` is the index type supplied by the visitor macro.
 */
template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx,
                          const SparseCsrTensor& x,
                          SparseCooTensor* out) {
  const auto index_dtype = x.non_zero_crows().dtype();
  PD_VISIT_BASE_INTEGRAL_TYPES(index_dtype, "SparseCsrToCooGPUKernel", ([&] {
                                 SparseCsrToCooGPUKernel<T, data_t>(
                                     dev_ctx, x, out);
                               }));
}

/**
 * Computes inclusive running non-zero counts per batch from sorted batch ids.
 *
 * Launch: 1-D grid-stride loop over the non-zeros. Only the last non-zero of
 * each run of equal batch ids writes: it stores i + 1 into batchs_offset for
 * its own batch and for any following batch ids up to, but excluding, the next
 * run's id (or `batchs` for the final run). Batches before the first run are
 * never written, so batchs_offset must be zero-initialized by the caller.
 */
template <typename IntT>
__global__ void GetBatchsOffset(const IntT* batchs_ptr,
                                const int batchs,
                                const int non_zero_num,
                                int* batchs_offset) {
  const int stride = gridDim.x * blockDim.x;
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < non_zero_num;
       i += stride) {
    const bool is_last = (i == non_zero_num - 1);
    if (!is_last && batchs_ptr[i] == batchs_ptr[i + 1]) {
      continue;  // not the end of a run of equal batch ids
    }
    const int first_batch = batchs_ptr[i];
    const int stop_batch = is_last ? batchs : batchs_ptr[i + 1];
    for (int b = first_batch; b < stop_batch; ++b) {
      batchs_offset[b] = i + 1;
    }
  }
}

/**
 * Builds CSR row pointers from COO row ids, one batch per blockIdx.y.
 *
 * Launch: x is a grid-stride loop over the batch's non-zeros; y selects the
 * batch b.
 *
 * batchs_offset holds inclusive running nnz counts: batch b's non-zeros are
 * coo_rows_data[batchs_offset[b-1] .. batchs_offset[b]). A null batchs_offset
 * means one batch holding all `non_zero_num` entries.
 *
 * Writes rows + 1 entries of csr_crows_data starting at b * (rows + 1); an
 * empty batch gets all zeros.
 *
 * NOTE(review): correctness assumes row ids are non-decreasing within a batch.
 */
template <typename IntT>
__global__ void ConvertCooRowsToCsrCrows(
    const int* batchs_offset,  // can be null if batchs = 1
    const IntT* coo_rows_data,
    IntT* csr_crows_data,
    const int rows,
    const int64_t non_zero_num) {
  const int b = blockIdx.y;
  IntT* crows = csr_crows_data + b * (rows + 1);

  int batch_nnz = batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
  IntT batch_start = 0;
  if (b > 0) {
    batch_start = batchs_offset[b - 1];
    batch_nnz -= batch_start;
  }
  const IntT* batch_rows = coo_rows_data + batch_start;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int stride = gridDim.x * blockDim.x;
  for (int i = tid; i < batch_nnz; i += stride) {
    if (i == 0) {
      // Rows before the first occupied row, and that row's start, are 0.
      for (IntT r = 0; r <= batch_rows[0]; ++r) {
        crows[r] = 0;
      }
    } else {
      // Each row boundary crossed between non-zero i-1 and i starts at i.
      for (IntT r = batch_rows[i - 1]; r < batch_rows[i]; ++r) {
        crows[r + 1] = i;
      }
    }
    if (i == batch_nnz - 1) {
      // Entries after the last occupied row all equal the batch's nnz.
      for (IntT r = batch_rows[batch_nnz - 1] + 1; r < rows + 1; ++r) {
        crows[r] = batch_nnz;
      }
    }
  }

  if (batch_nnz == 0) {
    for (int r = tid; r < rows + 1; r += stride) {
      crows[r] = 0;
    }
  }
}

/**
 * Converts a 2-D or 3-D (batched) SparseCooTensor into a SparseCsrTensor.
 *
 * COO indices are read from x.non_zero_indices(). For 2-D input, row 0 holds
 * row ids and row 1 column ids. For 3-D input, row 0 holds batch ids,
 * followed by row ids and column ids.
 * NOTE(review): the conversion assumes batch ids and per-batch row ids are
 * sorted (coalesced input) — confirm callers guarantee this.
 *
 * Outputs: crows of length batchs * (rows + 1); cols and values copied
 * through. An input with no non-zeros produces all-zero crows and empty
 * cols/values instead of leaving `out` untouched.
 *
 * All work runs on dev_ctx.stream(). The row-pointer kernel is launched with
 * a grid sized by nnz in both the batched and unbatched paths.
 */
template <typename T, typename IntT>
void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
                             const SparseCooTensor& x,
                             SparseCsrTensor* out) {
  const auto& x_dims = x.dims();
  bool valid = x_dims.size() == 2 || x_dims.size() == 3;
  PADDLE_ENFORCE_EQ(valid,
                    true,
                    phi::errors::InvalidArgument(
                        "SparseCsrTensor only support 2-D or 3-D matrix"));
  const int64_t non_zero_num = x.nnz();

  int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

  phi::DenseTensor non_zero_crows =
      phi::Empty<IntT>(dev_ctx, {batchs * (rows + 1)});
  phi::DenseTensor non_zero_cols = phi::Empty<IntT>(dev_ctx, {non_zero_num});
  phi::DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, x.values());
  IntT* csr_crows_data = non_zero_crows.data<IntT>();

  if (non_zero_num <= 0) {
    // An empty matrix is still a valid CSR tensor: every row pointer is 0.
    phi::backends::gpu::GpuMemsetAsync(csr_crows_data,
                                       0,
                                       sizeof(IntT) * non_zero_crows.numel(),
                                       dev_ctx.stream());
    out->SetMember(non_zero_crows, non_zero_cols, values, x_dims);
    return;
  }

  IntT* csr_cols_data = non_zero_cols.data<IntT>();
  T* csr_values_data = values.data<T>();

  const auto& coo_indices = x.non_zero_indices();
  const auto& coo_values = x.values();
  const IntT* batchs_ptr = coo_indices.data<IntT>();
  const IntT* coo_rows_data =
      x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
  const IntT* coo_cols_data = coo_rows_data + non_zero_num;
  const T* coo_values_data = coo_values.data<T>();

  // Both kernels below walk the non-zeros, so size the grid by nnz.
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
  if (batchs > 1) {
    phi::DenseTensor batchs_offset = phi::Empty<int>(dev_ctx, {batchs});
    int* batchs_offset_ptr = batchs_offset.data<int>();
    // Batches before the first non-empty one are never written by
    // GetBatchsOffset, so they must start out as 0.
    phi::funcs::SetConstant<GPUContext, int> set_zero;
    set_zero(dev_ctx, &batchs_offset, 0);
    GetBatchsOffset<IntT><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(
        batchs_ptr, batchs, non_zero_num, batchs_offset_ptr);

    config.block_per_grid.y = batchs;
    ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
  } else {
    ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid.x,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
  }

  phi::backends::gpu::GpuMemcpyAsync(csr_cols_data,
                                     coo_cols_data,
                                     sizeof(IntT) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(csr_values_data,
                                     coo_values_data,
                                     sizeof(T) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  out->SetMember(non_zero_crows, non_zero_cols, values, x_dims);
}

/**
 * COO -> CSR entry point. It picks the integer index type from the dtype of
 * x's indices tensor and forwards to SparseCooToCsrGPUKernel. Inside the
 * lambda, `data_t` is the index type supplied by the visitor macro.
 */
template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx,
                          const SparseCooTensor& x,
                          SparseCsrTensor* out) {
  const auto index_dtype = x.non_zero_indices().dtype();
  PD_VISIT_BASE_INTEGRAL_TYPES(index_dtype, "SparseCooToCsrGPUKernel", ([&] {
                                 SparseCooToCsrGPUKernel<T, data_t>(
                                     dev_ctx, x, out);
                               }));
}

/**
 * Scatters COO values into a zero-initialized dense buffer.
 *
 * Launch: 1-D grid-stride loop over the non-zeros.
 *
 * `indices` is laid out as [sparse_dim, non_zero_num]: coordinate d of
 * non-zero i is at d * non_zero_num + i. `sparse_offsets` are the row-major
 * strides of the sparse dims. Each non-zero carries `base_offset` contiguous
 * dense elements, which are copied to dense_data[linear_index * base_offset].
 *
 * non_zero_num is int64_t rather than IndicesT. With narrow index types
 * (int8_t/uint8_t/int16_t), the nnz count easily exceeds the index type's
 * range, and the implicit narrowing truncated the loop bound. Callers passing
 * an IndicesT value still compile (widening conversion).
 */
template <typename ValueT, typename IndicesT>
__global__ void KernelSparseCooToDense(const IndicesT* indices,
                                       const int64_t* sparse_offsets,
                                       const ValueT* data,
                                       ValueT* dense_data,
                                       const int64_t non_zero_num,
                                       const int64_t base_offset,
                                       const int64_t sparse_dim) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < non_zero_num;
       i += stride) {
    int64_t index = 0;
    for (int64_t j = 0; j < sparse_dim; ++j) {
      index += static_cast<int64_t>(indices[j * non_zero_num + i]) *
               sparse_offsets[j];
    }

    for (int64_t j = 0; j < base_offset; ++j) {
      dense_data[index * base_offset + j] = data[i * base_offset + j];
    }
  }
}

/**
 * Scatters a SparseCooTensor into a freshly allocated dense tensor.
 *
 * `out` is allocated with x's dtype, dims and the layout of x.values(), then
 * zero-filled on dev_ctx.stream(). Each non-zero's coordinates (rows of
 * non_zero_indices()) are linearized with row-major strides over the first
 * `sparse_dim` dims. Its `base_offset` trailing dense elements are then
 * copied into place by KernelSparseCooToDense.
 *
 * With nnz == 0 the result is the zero-filled tensor and no kernel is
 * launched (a 0-block grid would be an invalid launch configuration).
 */
template <typename T, typename IntT>
void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx,
                               const SparseCooTensor& x,
                               DenseTensor* out) {
  const auto non_zero_num = x.nnz();
  const auto dense_dims = x.dims();
  const auto& indices = x.non_zero_indices();
  const auto& values = x.values();
  const auto indices_dims = indices.dims();
  int64_t sparse_dim = indices_dims[0];
  if (indices_dims.size() == 1) {
    sparse_dim = 1;
  }
  const int64_t dense_dim = values.dims().size() - 1;

  const T* x_data = values.data<T>();
  *out = phi::Empty(
      dev_ctx, phi::DenseTensorMeta(x.dtype(), x.dims(), x.values().layout()));
  T* out_data = out->data<T>();
  phi::backends::gpu::GpuMemsetAsync(
      out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream());

  if (non_zero_num <= 0) {
    return;
  }

  // Number of dense elements carried by each non-zero.
  int64_t base_offset = 1;
  for (int64_t i = 0; i < dense_dim; i++) {
    base_offset *= dense_dims[sparse_dim + i];
  }
  // Row-major strides of the sparse dims.
  std::vector<int64_t> sparse_offsets(sparse_dim);
  int64_t offset = 1;
  for (int i = sparse_dim - 1; i >= 0; i--) {
    sparse_offsets[i] = offset;
    offset *= dense_dims[i];
  }

  DenseTensor d_sparse_offsets = Empty<int64_t>(dev_ctx, {sparse_dim});
  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<int64_t>(),
                                     sparse_offsets.data(),
                                     sparse_dim * sizeof(int64_t),
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
  KernelSparseCooToDense<T, IntT>
      <<<config.block_per_grid.x,
         config.thread_per_block.x,
         0,
         dev_ctx.stream()>>>(indices.data<IntT>(),
                             d_sparse_offsets.data<int64_t>(),
                             x_data,
                             out_data,
                             non_zero_num,
                             base_offset,
                             sparse_dim);
}

/**
 * COO -> dense entry point. It picks the integer index type from the dtype of
 * x's indices tensor and forwards to SparseCooToDenseGPUKernel. Inside the
 * lambda, `data_t` is the index type supplied by the visitor macro.
 */
template <typename T, typename Context>
void SparseCooToDenseKernel(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            DenseTensor* out) {
  const auto index_dtype = x.non_zero_indices().dtype();
  PD_VISIT_BASE_INTEGRAL_TYPES(index_dtype, "SparseCooToDenseGPUKernel", ([&] {
                                 SparseCooToDenseGPUKernel<T, data_t>(
                                     dev_ctx, x, out);
                               }));
}

}  // namespace sparse
}  // namespace phi

// GPU registrations for the sparse format-conversion kernels.
// `csr_values` consumes a SparseCsrTensor, so its input layout is SPARSE_CSR.
// The previous registration declared SPARSE_COO, a copy of `coo_values`.
PD_REGISTER_KERNEL(dense_to_sparse_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToSparseCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(sparse_csr_to_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCsrToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(sparse_coo_to_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCooToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(dense_to_sparse_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToSparseCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(sparse_coo_to_dense,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCooToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(sparse_csr_to_dense,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCsrToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(coo_values,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CooValuesKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(csr_values,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrValuesKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

PD_REGISTER_KERNEL(sparse_coo_tensor,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCooTensorKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {}