sparse_utils_kernel.cu
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

#include <thrust/execution_policy.h>
#include <thrust/remove.h>

#ifdef PADDLE_WITH_HIP
#include "paddle/phi/backends/dynload/rocsparse.h"
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"

namespace phi {
namespace sparse {

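// Returns true iff all `cols` consecutive elements starting at `data` are
// zero, i.e. the corresponding dense row holds no non-zero entry.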
template <typename T>
inline __device__ bool DevIsZero(const T* data, const int64_t cols) {
  const T zero = static_cast<T>(0);
  // TODO(zhangkaihuo): check whether the data is zero in parallel when cols > 1
  for (int64_t i = 0; i < cols; i++) {
    if (data[i] != zero) {
      return false;
    }
  }
  return true;
}

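// Scans the dense data viewed as `rows` x `cols`: every row that contains at
// least one non-zero value writes its row index into `temp_indexs` (all-zero
// rows write -1 so they can be compacted away later), and the per-block count
// of non-zero rows is accumulated into `non_zero_num` via atomicAdd.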
template <typename T>
__global__ void GetNonZeroNums(const T* dense_data,
                               const int rows,
                               const int cols,
                               int* non_zero_num,
                               int* temp_indexs) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  __shared__ int counter;
  if (threadIdx.x == 0) counter = 0;
  __syncthreads();

  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
    int index = -1;
    // TODO(zhangkaihuo): when cols=1, vectorization can be used
    if (!DevIsZero(dense_data + i * cols, cols)) {
      // use reductions?
      atomicAdd(&counter, 1);
      index = i;
    }
    temp_indexs[i] = index;
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(non_zero_num, counter);
  }
}

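// For each kept flattened index, unravels it into `sparse_dim` coordinates
// (stored in the [sparse_dim, non_zero_num] indices tensor) and copies the
// `cols` dense values of that row into the COO values buffer.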
template <typename T>
__global__ void GetNonZeroElementsAndIndices(const T* dense_data,
                                             const int64_t sparse_dim,
                                             const int64_t cols,
                                             const int64_t* x_dims,
                                             const int non_zero_num,
                                             const int* indexs,
                                             int64_t* indices,
                                             T* sparse_data) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    int64_t sparse_index = indexs[i];
    int64_t x_index = sparse_index;
    for (int64_t j = sparse_dim - 1; j >= 0; j--) {
      indices[j * non_zero_num + i] = sparse_index % x_dims[j];
      sparse_index /= x_dims[j];
    }

    for (int j = 0; j < cols; j++) {
      sparse_data[i * cols + j] = dense_data[x_index * cols + j];
    }
  }
}

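// Converts a dense tensor to a SparseCooTensor. The leading `sparse_dim`
// dimensions are flattened to `rows` and the remaining dimensions to `cols`
// (e.g. x_dims = [2, 3, 4] with sparse_dim = 2 gives rows = 6, cols = 4);
// the non-zero rows are counted and compacted on the device, then gathered
// into COO indices and values.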
template <typename T, typename Context>
void DenseToCooKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const int64_t sparse_dim,
                      SparseCooTensor* out) {
  const T* x_data = x.data<T>();
  const auto& x_dims = x.dims();
  PADDLE_ENFORCE_LE(sparse_dim,
                    x_dims.size(),
                    phi::errors::InvalidArgument(
                        "sparse_dim must not be greater than the size of "
                        "x.dims()"));
  PADDLE_ENFORCE_GT(
      sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0"));
  auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
  const int rows = dims_2d[0];
  const int cols = dims_2d[1];
  DenseTensor nums = phi::Empty<int32_t>(dev_ctx, {1});
  DenseTensor d_x_dims = phi::Empty<int64_t>(dev_ctx, {x_dims.size()});

  // 1. Get the number of non-zero elements and record their flattened indices.
  int* nums_ptr = nums.data<int>();
  phi::backends::gpu::GpuMemsetAsync(
      nums_ptr, 0, sizeof(int), dev_ctx.stream());
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);

  DenseTensor temp_indexs = phi::Empty<int32_t>(dev_ctx, {rows});
  int* temp_indexs_ptr = temp_indexs.data<int>();

  GetNonZeroNums<<<config.block_per_grid.x,
                   config.thread_per_block.x,
                   0,
                   dev_ctx.stream()>>>(
      x_data, rows, cols, nums_ptr, temp_indexs_ptr);

#ifdef PADDLE_WITH_HIP
  thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                 temp_indexs_ptr,
                 temp_indexs_ptr + rows,
                 -1);

  // 2. Copy non_zero_num to the host and x_dims to the device.
  int non_zero_num = 0;
  phi::backends::gpu::GpuMemcpyAsync(&non_zero_num,
                                     nums_ptr,
                                     sizeof(int),
                                     gpuMemcpyDeviceToHost,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(d_x_dims.data<int64_t>(),
                                     x_dims.Get(),
                                     x_dims.size() * sizeof(x_dims[0]),
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  dev_ctx.Wait();  // wait for the copies to finish

  const auto values_dims =
      phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num);
  phi::DenseTensor indices = phi::Empty<int64_t>(
      dev_ctx, {sparse_dim, static_cast<int64_t>(non_zero_num)});
  int64_t* indices_data = indices.data<int64_t>();
  phi::DenseTensor values;
  values.Resize(values_dims);
  T* sparse_data = dev_ctx.template Alloc<T>(&values);

  // 3. Compute the COO indices from the kept flattened indices and gather the
  //    corresponding values.
  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
  GetNonZeroElementsAndIndices<<<config.block_per_grid.x,
                                 config.thread_per_block.x,
                                 0,
                                 dev_ctx.stream()>>>(x_data,
                                                     sparse_dim,
                                                     cols,
                                                     d_x_dims.data<int64_t>(),
                                                     non_zero_num,
                                                     temp_indexs_ptr,
                                                     indices_data,
                                                     sparse_data);

  out->SetMember(indices, values, x_dims, true);
}

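// One thread per batch: reads the last entry of that batch's crows segment,
// which equals the number of non-zero elements in the batch.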
template <typename IntT>
__global__ void GetBatchSizes(const IntT* crows,
                              const int rows,
                              const int batchs,
                              IntT* batch_sizes) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < batchs) {
    batch_sizes[tid] = crows[tid * (rows + 1) + rows];
  }
}

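// Expands one batch of compressed crows per blockIdx.y into per-element COO
// row indices (and, for batched input, batch indices); `crows_offsets` gives
// each batch's starting position in the output and may be null for a single
// batch.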
template <typename IntT>
__global__ void ConvertCsrCrowsToCooRows(const IntT* crows_ptr,
                                         const IntT* crows_offsets,
                                         IntT* rows_ptr,
                                         IntT* batch_ptr,
                                         const int rows) {
  const int b = blockIdx.y;
  const int64_t offset = crows_offsets ? crows_offsets[b] : 0;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
    for (int j = crows_ptr[b * (rows + 1) + i];
         j < crows_ptr[b * (rows + 1) + i + 1];
         j++) {
      rows_ptr[offset + j] = i;
      if (batch_ptr) {
        batch_ptr[offset + j] = b;
      }
    }
  }
}

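// Converts a SparseCsrTensor to a SparseCooTensor: computes per-batch offsets
// with an exclusive scan when batched, expands crows to row indices, and
// copies the column indices and values unchanged.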
template <typename T, typename IntT>
void CsrToCooGPUKernel(const GPUContext& dev_ctx,
                       const SparseCsrTensor& x,
                       SparseCooTensor* out) {
  const DDim& x_dims = x.dims();
  const int64_t non_zero_num = x.cols().numel();

// rocsparse_csr2coo currently only supports indices of type 'rocsparse_int'
// (aka 'int'), so the crows/cols are cast to int32 first.
#ifdef PADDLE_WITH_HIP
  const auto& csr_crows = Cast<IntT>(dev_ctx, x.crows(), DataType::INT32);
  const auto& csr_cols = Cast<IntT>(dev_ctx, x.cols(), DataType::INT32);
  const int* csr_crows_data = csr_crows.template data<int>();
  const int* csr_cols_data = csr_cols.template data<int>();
#else
  const auto& csr_crows = x.crows();
  const auto& csr_cols = x.cols();
  const IntT* csr_crows_data = csr_crows.data<IntT>();
  const IntT* csr_cols_data = csr_cols.data<IntT>();
#endif
  const auto& csr_values = x.values();
  const T* csr_values_data = csr_values.data<T>();

  int64_t sparse_dim = 2;
  if (x_dims.size() == 3) {
    sparse_dim = 3;
  }
  int batches = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

#ifdef PADDLE_WITH_HIP
  DenseTensor indices = phi::Empty<int>(dev_ctx, {sparse_dim, non_zero_num});
  int* coo_indices = indices.data<int>();
  int* coo_rows_data = coo_indices;
  int* coo_cols_data = coo_rows_data + non_zero_num;
#else
  DenseTensor indices = phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
  DenseTensor offsets = phi::Empty<IntT>(dev_ctx, {batches});
  IntT* coo_indices = indices.data<IntT>();
  IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
  IntT* coo_rows_data =
      x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
  IntT* coo_cols_data = coo_rows_data + non_zero_num;
  IntT* offsets_ptr = batches == 1 ? nullptr : offsets.data<IntT>();
#endif
  DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, csr_values);
  T* coo_values_data = values.data<T>();

  if (batches > 1) {
#ifdef PADDLE_WITH_HIP
    PADDLE_THROW(
        phi::errors::Unimplemented("'rocsparse_csr2coo' only supports batches "
                                   "with a value of 1 currently."));
#else
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batches, 1);
    GetBatchSizes<IntT><<<config.block_per_grid.x, config.thread_per_block.x>>>(
        csr_crows_data, rows, batches, offsets_ptr);

    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
                           offsets_ptr,
                           offsets_ptr + batches,
                           offsets_ptr);
#endif
  }

#ifdef PADDLE_WITH_HIP
  dev_ctx.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_csr2coo(handle,
                                    csr_crows_data,
                                    non_zero_num,
                                    rows,
                                    coo_rows_data,
                                    rocsparse_index_base_zero);
  });
#else
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);
  config.block_per_grid.y = batches;
  ConvertCsrCrowsToCooRows<IntT>
      <<<config.block_per_grid, config.thread_per_block.x>>>(
          csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows);
#endif
  phi::backends::gpu::GpuMemcpyAsync(coo_cols_data,
                                     csr_cols_data,
#ifdef PADDLE_WITH_HIP
                                     sizeof(int) * non_zero_num,
#else
                                     sizeof(IntT) * non_zero_num,
#endif
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(coo_values_data,
                                     csr_values_data,
                                     sizeof(T) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());

#ifdef PADDLE_WITH_HIP
  if (std::is_same<IntT, int64_t>::value)
    indices = Cast<int>(dev_ctx, indices, DataType::INT64);
#endif

  out->SetMember(indices, values, x_dims, true);
}

template <typename T, typename Context>
void CsrToCooKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCooTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.crows().dtype(), "CsrToCooGPUKernel", ([&] {
                                 CsrToCooGPUKernel<T, data_t>(dev_ctx, x, out);
                               }));
}

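// Builds a cumulative count per batch from the (sorted) batch indices of the
// COO entries: batchs_offset[b] ends up holding the number of non-zero
// elements in batches 0..b. For example, batchs_ptr = [0, 0, 1] with
// batchs = 2 yields batchs_offset = [2, 3].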
template <typename IntT>
__global__ void GetBatchsOffset(const IntT* batchs_ptr,
                                const int batchs,
                                const int non_zero_num,
                                int* batchs_offset) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
      const int start = batchs_ptr[i];
      const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
      for (int j = start; j < end; j++) {
        batchs_offset[j] = i + 1;
      }
    }
  }
}

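// Converts the sorted COO row indices of one batch (selected by blockIdx.y)
// into a CSR crows array of length rows + 1. For example, row indices
// [0, 0, 2] with rows = 3 produce crows = [0, 2, 2, 3].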
template <typename IntT>
__global__ void ConvertCooRowsToCsrCrows(
    const int* batchs_offset,  // can be null if batchs == 1
    const IntT* coo_rows_data,
    IntT* csr_crows_data,
    const int rows,
    const int64_t non_zero_num) {
  const int b = blockIdx.y;
  int batch_non_zero_num =
      batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
  IntT batch_start = 0;
  if (b > 0) {
    batch_start = batchs_offset[b - 1];
    batch_non_zero_num -= batch_start;
  }

  const IntT* coo_rows_ptr = coo_rows_data + batch_start;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) {
    if (i == 0) {
      for (IntT j = 0; j <= coo_rows_ptr[0]; j++) {
        csr_crows_data[b * (rows + 1) + j] = 0;
      }
    } else {
      for (IntT j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
        csr_crows_data[b * (rows + 1) + j + 1] = i;
      }
    }
    if (i == batch_non_zero_num - 1) {
      for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1;
           i++) {
        csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
      }
    }
  }
  if (batch_non_zero_num == 0) {
    for (int i = tid; i < rows + 1; i += gridDim.x * blockDim.x) {
      csr_crows_data[b * (rows + 1) + i] = 0;
    }
  }
}

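// Converts a SparseCooTensor to a SparseCsrTensor; the conversion assumes the
// COO entries are ordered by batch and then by row.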
template <typename T, typename IntT>
void CooToCsrGPUKernel(const GPUContext& dev_ctx,
                       const SparseCooTensor& x,
                       SparseCsrTensor* out) {
  const auto& x_dims = x.dims();
  bool valid = x_dims.size() == 2 || x_dims.size() == 3;
  PADDLE_ENFORCE_EQ(valid,
                    true,
                    phi::errors::InvalidArgument(
                        "SparseCsrTensor only supports 2-D or 3-D matrices"));
  const int64_t non_zero_num = x.nnz();
  if (non_zero_num <= 0) return;

  int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

  phi::DenseTensor crows = phi::Empty<IntT>(dev_ctx, {batchs * (rows + 1)});
  phi::DenseTensor cols = phi::Empty<IntT>(dev_ctx, {non_zero_num});
  phi::DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, x.values());
  IntT* csr_crows_data = crows.data<IntT>();
  IntT* csr_cols_data = cols.data<IntT>();
  T* csr_values_data = values.data<T>();

  const auto& coo_indices = x.indices();
  const auto& coo_values = x.values();
  const IntT* batchs_ptr = coo_indices.data<IntT>();
  const IntT* coo_rows_data =
      x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
  const IntT* coo_cols_data = coo_rows_data + non_zero_num;
  const T* coo_values_data = coo_values.data<T>();

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
  if (batchs > 1) {
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
    phi::DenseTensor batchs_offset = phi::Empty<int>(dev_ctx, {batchs});
    int* batchs_offset_ptr = batchs_offset.data<int>();
    phi::funcs::SetConstant<GPUContext, int> set_zero;
    // Zero-initialize so that leading batches with nnz == 0 keep offset 0.
    set_zero(dev_ctx, &batchs_offset, static_cast<int>(0));
    GetBatchsOffset<IntT><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(
        batchs_ptr, batchs, non_zero_num, batchs_offset_ptr);

    config.block_per_grid.y = batchs;
    ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
  } else {
    ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid.x,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
  }

  phi::backends::gpu::GpuMemcpyAsync(csr_cols_data,
                                     coo_cols_data,
                                     sizeof(IntT) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(csr_values_data,
                                     coo_values_data,
                                     sizeof(T) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  out->SetMember(crows, cols, values, x_dims);
}

template <typename T, typename Context>
void CooToCsrKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCsrTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CooToCsrGPUKernel", ([&] {
                                 CooToCsrGPUKernel<T, data_t>(dev_ctx, x, out);
                               }));
}

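// Scatters each COO entry into the dense output: the destination offset is
// rebuilt from the entry's indices and the precomputed sparse_offsets, then
// `base_offset` contiguous dense values are copied per entry.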
template <typename ValueT, typename IndicesT>
__global__ void KernelCooToDense(const IndicesT* indices,
                                 const int64_t* sparse_offsets,
                                 const ValueT* data,
                                 ValueT* dense_data,
                                 const IndicesT non_zero_num,
                                 const int64_t base_offset,
                                 const int64_t sparse_dim) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    int64_t index = 0;
    for (int j = 0; j < sparse_dim; j++) {
      index += indices[j * non_zero_num + i] * sparse_offsets[j];
    }

    for (int j = 0; j < base_offset; j++) {
      dense_data[index * base_offset + j] = data[i * base_offset + j];
    }
  }
}

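// Converts a SparseCooTensor back to a dense tensor: zero-fills the output and
// scatters the non-zero values with KernelCooToDense.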
template <typename T, typename IntT>
void CooToDenseGPUKernel(const GPUContext& dev_ctx,
                         const SparseCooTensor& x,
                         DenseTensor* out) {
  const auto non_zero_num = x.nnz();
  const auto dense_dims = x.dims();
  const auto indices = x.indices();
  const auto values = x.values();
  const auto indices_dims = indices.dims();
  int64_t sparse_dim = indices_dims[0];
  if (indices_dims.size() == 1) {
    sparse_dim = 1;
  }
  const int64_t dense_dim = values.dims().size() - 1;

  const auto place = dev_ctx.GetPlace();
  const T* x_data = values.data<T>();
  dev_ctx.template Alloc<T>(out);

  T* out_data = out->data<T>();
  int64_t base_offset = 1;
  for (int64_t i = 0; i < dense_dim; i++) {
    base_offset *= dense_dims[sparse_dim + i];
  }
  std::vector<int64_t> sparse_offsets(sparse_dim);
  int64_t offset = 1;
  for (int i = sparse_dim - 1; i >= 0; i--) {
    sparse_offsets[i] = offset;
    offset *= dense_dims[i];
  }

  DenseTensor d_sparse_offsets = Empty<int64_t>(dev_ctx, {sparse_dim});

  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<int64_t>(),
                                     sparse_offsets.data(),
                                     sparse_dim * sizeof(int64_t),
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemsetAsync(
      out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream());

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);

  KernelCooToDense<T, IntT>
      <<<config.block_per_grid.x,
         config.thread_per_block.x,
         0,
         dev_ctx.stream()>>>(indices.data<IntT>(),
                             d_sparse_offsets.data<int64_t>(),
                             x_data,
                             out_data,
                             non_zero_num,
                             base_offset,
                             sparse_dim);
}

template <typename T, typename Context>
void CooToDenseKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      DenseTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(
      x.indices().dtype(), "CooToDenseGPUKernel", ([&] {
        CooToDenseGPUKernel<T, data_t>(dev_ctx, x, out);
      }));
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(dense_to_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(csr_to_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CooToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(coo_to_dense,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CooToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(csr_to_dense,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(values_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::ValuesCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(values_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::ValuesCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

PD_REGISTER_KERNEL(indices_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::IndicesCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(sparse_coo_tensor,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCooTensorKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {}