/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

#include <thrust/execution_policy.h>
#include <thrust/remove.h>

#ifdef PADDLE_WITH_HIP
#include "paddle/phi/backends/dynload/rocsparse.h"
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"

namespace phi {
namespace sparse {

template <typename T>
inline __device__ bool DevIsZero(const T* data, const int64_t cols) {
  const T zero = static_cast<T>(0);
  // TODO(zhangkaihuo): check whether the data is zero in parallel when cols > 1
  for (int64_t i = 0; i < cols; i++) {
    if (data[i] != zero) {
      return false;
    }
  }
  return true;
}
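
// Note (added commentary, not part of the original source): in GetNonZeroNums
// below, each block accumulates its hit count in the shared-memory `counter`,
// and only thread 0 publishes it with a single global atomicAdd, keeping
// global atomic traffic to one operation per block. Rows that are entirely
// zero are marked with -1 in temp_indexs so they can later be compacted away
// with thrust::remove.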

template <typename T>
__global__ void GetNonZeroNums(const T* dense_data,
                               const int rows,
                               const int cols,
                               int* non_zero_num,
                               int* temp_indexs) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  __shared__ int counter;
  if (threadIdx.x == 0) counter = 0;
  __syncthreads();

  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
    int index = -1;
    // TODO(zhangkaihuo): when cols=1, vectorization can be used
    if (!DevIsZero(dense_data + i * cols, cols)) {
      // use reductions?
      atomicAdd(&counter, 1);
      index = i;
    }
    temp_indexs[i] = index;
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(non_zero_num, counter);
  }
}
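
// Illustrative note (added, not in the original source):
// GetNonZeroElementsAndIndices decodes each compacted flat index into
// sparse_dim coordinates in row-major order. For example, with
// x_dims = {2, 3} and sparse_index = 4, the modulo/divide loop yields
// indices (1, 1), since 4 % 3 = 1 and (4 / 3) % 2 = 1. The original flat
// index is kept in x_index so the `cols` trailing values of that row can be
// gathered into sparse_data.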

template <typename T>
__global__ void GetNonZeroElementsAndIndices(const T* dense_data,
                                             const int64_t sparse_dim,
                                             const int64_t cols,
                                             const int64_t* x_dims,
                                             const int non_zero_num,
                                             const int* indexs,
                                             int64_t* indices,
                                             T* sparse_data) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    int64_t sparse_index = indexs[i];
    int64_t x_index = sparse_index;
    for (int64_t j = sparse_dim - 1; j >= 0; j--) {
      indices[j * non_zero_num + i] = sparse_index % x_dims[j];
      sparse_index /= x_dims[j];
    }

    for (int j = 0; j < cols; j++) {
      sparse_data[i * cols + j] = dense_data[x_index * cols + j];
    }
  }
}
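
// Added summary (commentary only): DenseToCooKernel views x as a [rows, cols]
// matrix with rows = prod(x_dims[0:sparse_dim]) and
// cols = prod(x_dims[sparse_dim:]), counts the non-zero rows, compacts their
// flat indices on the device with thrust::remove, and finally decodes each
// flat index into COO indices while gathering the corresponding values.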

template <typename T, typename Context>
void DenseToCooKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const int64_t sparse_dim,
                      SparseCooTensor* out) {
  const T* x_data = x.data<T>();
  const auto& x_dims = x.dims();
  PADDLE_ENFORCE_LE(sparse_dim,
                    x_dims.size(),
                    phi::errors::InvalidArgument(
                        "sparse_dim must be less than or equal to the size of "
                        "x.dims()"));
  PADDLE_ENFORCE_GT(
      sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0"));
  auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
  const int rows = dims_2d[0];
  const int cols = dims_2d[1];
  DenseTensor nums = phi::Empty<int32_t>(dev_ctx, {1});
  DenseTensor d_x_dims = phi::Empty<int64_t>(dev_ctx, {x_dims.size()});

  // 1. Get the number of non-zero elements and the flat index of each
  // non-zero row
  int* nums_ptr = nums.data<int>();
  phi::backends::gpu::GpuMemsetAsync(
      nums_ptr, 0, sizeof(int), dev_ctx.stream());
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);

  DenseTensor temp_indexs = phi::Empty<int32_t>(dev_ctx, {rows});
  int* temp_indexs_ptr = temp_indexs.data<int>();

  GetNonZeroNums<<<config.block_per_grid.x,
                   config.thread_per_block.x,
                   0,
                   dev_ctx.stream()>>>(
      x_data, rows, cols, nums_ptr, temp_indexs_ptr);

#ifdef PADDLE_WITH_HIP
  thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                 temp_indexs_ptr,
                 temp_indexs_ptr + rows,
                 -1);

  // 2. copy non_zero_num to host, copy x_dims to device
  int non_zero_num = 0;
  phi::backends::gpu::GpuMemcpyAsync(&non_zero_num,
                                     nums_ptr,
                                     sizeof(int),
                                     gpuMemcpyDeviceToHost,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(d_x_dims.data<int64_t>(),
                                     x_dims.Get(),
                                     x_dims.size() * sizeof(x_dims[0]),
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  dev_ctx.Wait();  // wait for the copies to finish

  const auto values_dims =
      phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num);
  phi::DenseTensor indices = phi::Empty<int64_t>(
      dev_ctx, {sparse_dim, static_cast<int64_t>(non_zero_num)});
  int64_t* indices_data = indices.data<int64_t>();
  phi::DenseTensor values;
  values.Resize(values_dims);
  T* sparse_data = dev_ctx.template Alloc<T>(&values);

  // 3. Compute the COO indices from the flat indexes and gather the
  // corresponding values
  if (non_zero_num > 0) {
    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
    GetNonZeroElementsAndIndices<<<config.block_per_grid.x,
                                   config.thread_per_block.x,
                                   0,
                                   dev_ctx.stream()>>>(x_data,
                                                       sparse_dim,
                                                       cols,
                                                       d_x_dims.data<int64_t>(),
                                                       non_zero_num,
                                                       temp_indexs_ptr,
                                                       indices_data,
                                                       sparse_data);
  }

  out->SetMember(indices, values, x_dims, true);
}
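
// Added note: in a batched CSR tensor, crows stores `batchs` consecutive
// segments of length rows + 1, and the last entry of each segment is that
// batch's non-zero count; GetBatchSizes below extracts it with one thread
// per batch.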

template <typename IntT>
__global__ void GetBatchSizes(const IntT* crows,
                              const int rows,
                              const int batchs,
                              IntT* batch_sizes) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < batchs) {
    batch_sizes[tid] = crows[tid * (rows + 1) + rows];
  }
}
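
// Illustrative note (added commentary): ConvertCsrCrowsToCooRows expands the
// compressed row pointers into one explicit row id per non-zero. For a single
// batch with rows = 3 and crows = {0, 2, 2, 5}, the emitted row ids are
// {0, 0, 2, 2, 2}; for batched input, crows_offsets shifts each batch's
// output range and batch_ptr records the batch id of every non-zero.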

template <typename IntT>
__global__ void ConvertCsrCrowsToCooRows(const IntT* crows_ptr,
                                         const IntT* crows_offsets,
                                         IntT* rows_ptr,
                                         IntT* batch_ptr,
                                         const int rows) {
  const int b = blockIdx.y;
  const int64_t offset = crows_offsets ? crows_offsets[b] : 0;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
    for (int j = crows_ptr[b * (rows + 1) + i];
         j < crows_ptr[b * (rows + 1) + i + 1];
         j++) {
      rows_ptr[offset + j] = i;
      if (batch_ptr) {
        batch_ptr[offset + j] = b;
      }
    }
  }
}

template <typename T, typename IntT>
void CsrToCooGPUKernel(const GPUContext& dev_ctx,
                       const SparseCsrTensor& x,
                       SparseCooTensor* out) {
  const DDim& x_dims = x.dims();
  const int64_t non_zero_num = x.cols().numel();
  int64_t sparse_dim = 2;
  if (x_dims.size() == 3) {
    sparse_dim = 3;
  }

  if (x.nnz() <= 0) {
#ifdef PADDLE_WITH_HIP
    DenseTensor indices = phi::Empty<int>(dev_ctx, {sparse_dim, non_zero_num});
#else
    DenseTensor indices = phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
#endif
    DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, x.values());
    out->SetMember(indices, values, x_dims, true);
    return;
  }

// rocsparse_csr2coo currently only supports indices of type 'rocsparse_int'
// (aka 'int')
#ifdef PADDLE_WITH_HIP
  const auto& csr_crows = Cast<IntT>(dev_ctx, x.crows(), DataType::INT32);
  const auto& csr_cols = Cast<IntT>(dev_ctx, x.cols(), DataType::INT32);
  const int* csr_crows_data = csr_crows.template data<int>();
  const int* csr_cols_data = csr_cols.template data<int>();
#else
  const auto& csr_crows = x.crows();
  const auto& csr_cols = x.cols();
  const IntT* csr_crows_data = csr_crows.data<IntT>();
  const IntT* csr_cols_data = csr_cols.data<IntT>();
#endif
  const auto& csr_values = x.values();
  const T* csr_values_data = csr_values.data<T>();

  int batches = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

#ifdef PADDLE_WITH_HIP
  DenseTensor indices = phi::Empty<int>(dev_ctx, {sparse_dim, non_zero_num});
  int* coo_indices = indices.data<int>();
  int* coo_rows_data = coo_indices;
  int* coo_cols_data = coo_rows_data + non_zero_num;
#else
  DenseTensor indices = phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
  DenseTensor offsets = phi::Empty<IntT>(dev_ctx, {batches});
  IntT* coo_indices = indices.data<IntT>();
  IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
  IntT* coo_rows_data =
      x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
  IntT* coo_cols_data = coo_rows_data + non_zero_num;
  IntT* offsets_ptr = batches == 1 ? nullptr : offsets.data<IntT>();
#endif
  DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, csr_values);
  T* coo_values_data = values.data<T>();

  if (batches > 1) {
#ifdef PADDLE_WITH_HIP
    PADDLE_THROW(
        phi::errors::Unimplemented("'rocsparse_csr2coo' only supports batches "
                                   "with a value of 1 currently."));
#else
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batches, 1);
    GetBatchSizes<IntT><<<config.block_per_grid.x, config.thread_per_block.x>>>(
        csr_crows_data, rows, batches, offsets_ptr);

    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
                           offsets_ptr,
                           offsets_ptr + batches,
                           offsets_ptr);
#endif
  }

#ifdef PADDLE_WITH_HIP
  dev_ctx.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_csr2coo(handle,
                                    csr_crows_data,
                                    non_zero_num,
                                    rows,
                                    coo_rows_data,
                                    rocsparse_index_base_zero);
  });
#else
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);
  config.block_per_grid.y = batches;
  ConvertCsrCrowsToCooRows<IntT>
      <<<config.block_per_grid, config.thread_per_block.x>>>(
          csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows);
#endif
  phi::backends::gpu::GpuMemcpyAsync(coo_cols_data,
                                     csr_cols_data,
#ifdef PADDLE_WITH_HIP
                                     sizeof(int) * non_zero_num,
#else
                                     sizeof(IntT) * non_zero_num,
#endif
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(coo_values_data,
                                     csr_values_data,
                                     sizeof(T) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());

#ifdef PADDLE_WITH_HIP
  if (std::is_same<IntT, int64_t>::value)
    indices = Cast<int>(dev_ctx, indices, DataType::INT64);
#endif

  out->SetMember(indices, values, x_dims, true);
}

template <typename T, typename Context>
void CsrToCooKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCooTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.crows().dtype(), "CsrToCooGPUKernel", ([&] {
                                 CsrToCooGPUKernel<T, data_t>(dev_ctx, x, out);
                               }));
}
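
// Added note: GetBatchsOffset assumes the COO indices are ordered by batch
// and fills batchs_offset[b] with the cumulative number of non-zeros up to
// and including batch b (i.e. one past the position of batch b's last entry);
// leading batches with no non-zeros keep the pre-set value 0.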

template <typename IntT>
__global__ void GetBatchsOffset(const IntT* batchs_ptr,
                                const int batchs,
                                const int non_zero_num,
                                int* batchs_offset) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
      const int start = batchs_ptr[i];
      const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
      for (int j = start; j < end; j++) {
        batchs_offset[j] = i + 1;
      }
    }
  }
}
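
// Illustrative note (added commentary): ConvertCooRowsToCsrCrows assumes the
// COO row ids of each batch are sorted in non-decreasing order and builds the
// compressed row pointers from them. For one batch with rows = 3 and
// coo_rows = {0, 0, 2}, the result is csr_crows = {0, 2, 2, 3}.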

template <typename IntT>
__global__ void ConvertCooRowsToCsrCrows(
    const int* batchs_offset,  // can be null if batchs = 1
    const IntT* coo_rows_data,
    IntT* csr_crows_data,
    const int rows,
    const int64_t non_zero_num) {
  const int b = blockIdx.y;
  int batch_non_zero_num =
      batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
  IntT batch_start = 0;
  if (b > 0) {
    batch_start = batchs_offset[b - 1];
    batch_non_zero_num -= batch_start;
  }

  const IntT* coo_rows_ptr = coo_rows_data + batch_start;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) {
    if (i == 0) {
      for (IntT j = 0; j <= coo_rows_ptr[0]; j++) {
        csr_crows_data[b * (rows + 1) + j] = 0;
      }
    } else {
      for (IntT j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
        csr_crows_data[b * (rows + 1) + j + 1] = i;
      }
    }
    if (i == batch_non_zero_num - 1) {
      for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1;
           i++) {
        csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
      }
    }
  }
  if (batch_non_zero_num == 0) {
    for (int i = tid; i < rows + 1; i += gridDim.x * blockDim.x) {
      csr_crows_data[b * (rows + 1) + i] = 0;
    }
  }
}

template <typename T, typename IntT>
void CooToCsrGPUKernel(const GPUContext& dev_ctx,
                       const SparseCooTensor& x,
                       SparseCsrTensor* out) {
  const auto& x_dims = x.dims();
  bool valid = x_dims.size() == 2 || x_dims.size() == 3;
  PADDLE_ENFORCE_EQ(valid,
                    true,
                    phi::errors::InvalidArgument(
                        "SparseCsrTensor only supports 2-D or 3-D matrices"));
  const int64_t non_zero_num = x.nnz();

  int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

  phi::DenseTensor crows = phi::Empty<IntT>(dev_ctx, {batchs * (rows + 1)});
  phi::DenseTensor cols = phi::Empty<IntT>(dev_ctx, {non_zero_num});
  phi::DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, x.values());
  if (non_zero_num <= 0) {
    out->SetMember(crows, cols, values, x_dims);
    return;
  }
  IntT* csr_crows_data = crows.data<IntT>();
  IntT* csr_cols_data = cols.data<IntT>();
  T* csr_values_data = values.data<T>();

  const auto& coo_indices = x.indices();
  const auto& coo_values = x.values();
  const IntT* batchs_ptr = coo_indices.data<IntT>();
  const IntT* coo_rows_data =
      x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
  const IntT* coo_cols_data = coo_rows_data + non_zero_num;
  const T* coo_values_data = coo_values.data<T>();

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
  if (batchs > 1) {
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
    phi::DenseTensor batchs_offset = phi::Empty<int>(dev_ctx, {batchs});
    int* batchs_offset_ptr = batchs_offset.data<int>();
    phi::funcs::SetConstant<GPUContext, int> set_zero;
    // zero-initialize the offsets in case the leading batches have nnz == 0
    set_zero(dev_ctx, &batchs_offset, static_cast<IntT>(0));
    GetBatchsOffset<IntT><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(
        batchs_ptr, batchs, non_zero_num, batchs_offset_ptr);

    config.block_per_grid.y = batchs;
    ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
  } else {
    ConvertCooRowsToCsrCrows<IntT><<<config.block_per_grid.x,
                                     config.thread_per_block.x,
                                     0,
                                     dev_ctx.stream()>>>(
        nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
  }

  phi::backends::gpu::GpuMemcpyAsync(csr_cols_data,
                                     coo_cols_data,
                                     sizeof(IntT) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  phi::backends::gpu::GpuMemcpyAsync(csr_values_data,
                                     coo_values_data,
                                     sizeof(T) * non_zero_num,
                                     gpuMemcpyDeviceToDevice,
                                     dev_ctx.stream());
  out->SetMember(crows, cols, values, x_dims);
}

template <typename T, typename Context>
void CooToCsrKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCsrTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CooToCsrGPUKernel", ([&] {
                                 CooToCsrGPUKernel<T, data_t>(dev_ctx, x, out);
                               }));
}
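
// Illustrative note (added commentary): KernelCooToDense scatters every
// non-zero back into the dense output. The destination is
// sum_j(indices[j][i] * sparse_offsets[j]) scaled by base_offset, the number
// of trailing dense elements per non-zero. For example, with
// dense_dims = {2, 3}, sparse_dim = 2 and base_offset = 1, the entry with
// indices (1, 2) lands at flat position 1 * 3 + 2 * 1 = 5.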

template <typename ValueT, typename IndicesT>
__global__ void KernelCooToDense(const IndicesT* indices,
                                 const int64_t* sparse_offsets,
                                 const ValueT* data,
                                 ValueT* dense_data,
                                 const IndicesT non_zero_num,
                                 const int64_t base_offset,
                                 const int64_t sparse_dim) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
    int64_t index = 0;
    for (int j = 0; j < sparse_dim; j++) {
      index += indices[j * non_zero_num + i] * sparse_offsets[j];
    }

    for (int j = 0; j < base_offset; j++) {
      dense_data[index * base_offset + j] = data[i * base_offset + j];
    }
  }
}

template <typename T, typename IntT>
void CooToDenseGPUKernel(const GPUContext& dev_ctx,
                         const SparseCooTensor& x,
                         DenseTensor* out) {
  const auto non_zero_num = x.nnz();
  const auto dense_dims = x.dims();
  const auto indices = x.indices();
  const auto values = x.values();
  const auto indices_dims = indices.dims();
  int64_t sparse_dim = indices_dims[0];
  if (indices_dims.size() == 1) {
    sparse_dim = 1;
  }
  const int64_t dense_dim = values.dims().size() - 1;

  const auto place = dev_ctx.GetPlace();
  dev_ctx.template Alloc<T>(out);

  T* out_data = out->data<T>();
  phi::backends::gpu::GpuMemsetAsync(
      out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream());

  if (x.nnz() <= 0) {
    return;
  }

  const T* x_data = values.data<T>();
  int64_t base_offset = 1;
  for (int64_t i = 0; i < dense_dim; i++) {
    base_offset *= dense_dims[sparse_dim + i];
  }
  std::vector<int64_t> sparse_offsets(sparse_dim);
  int64_t offset = 1;
  for (int i = sparse_dim - 1; i >= 0; i--) {
    sparse_offsets[i] = offset;
    offset *= dense_dims[i];
  }

  DenseTensor d_sparse_offsets = Empty<int64_t>(dev_ctx, {sparse_dim});

  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<int64_t>(),
                                     sparse_offsets.data(),
                                     sparse_dim * sizeof(int64_t),
                                     gpuMemcpyHostToDevice,
                                     dev_ctx.stream());

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);

  KernelCooToDense<T, IntT>
      <<<config.block_per_grid.x,
         config.thread_per_block.x,
         0,
         dev_ctx.stream()>>>(indices.data<IntT>(),
                             d_sparse_offsets.data<int64_t>(),
                             x_data,
                             out_data,
                             non_zero_num,
                             base_offset,
                             sparse_dim);
}

template <typename T, typename Context>
void CooToDenseKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      DenseTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(
      x.indices().dtype(), "CooToDenseGPUKernel", ([&] {
        CooToDenseGPUKernel<T, data_t>(dev_ctx, x, out);
      }));
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(dense_to_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(csr_to_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CooToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(coo_to_dense,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CooToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(csr_to_dense,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(values_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::ValuesCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(values_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::ValuesCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

PD_REGISTER_KERNEL(indices_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::IndicesCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(sparse_coo_tensor,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCooTensorKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {}