sparse_utils_kernel.cc 14.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
16

17 18 19
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
20
#include "paddle/phi/core/visit_type.h"
21
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
22

23
namespace phi {
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
namespace sparse {

/// Reports whether a contiguous run of values is entirely zero.
///
/// @param data  Pointer to the first of `n` values to inspect.
/// @param n     Number of values to inspect; when 0 the run is trivially
///              zero and `data` is never dereferenced.
/// @return true iff every value compares equal to `static_cast<T>(0)`.
template <typename T>
inline bool IsZero(const T* data, const size_t n) {
  const T zero_value = static_cast<T>(0);
  const T* const end = data + n;
  for (const T* cursor = data; cursor != end; ++cursor) {
    if (*cursor != zero_value) {
      return false;
    }
  }
  return true;
}

// TODO(zhangkaihuo): implement a kernel to count the number of non-zero
// elements in tensor
template <typename T>
inline int64_t GetNonZeroNum(const DenseTensor& dense,
                             const int64_t sparse_dim) {
  const auto& dims = dense.dims();
  PADDLE_ENFORCE_GE(
      dims.size(),
      sparse_dim,
46
      phi::errors::InvalidArgument(
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
          "sparse_dim(%d) should be less than or equal to dense.dim(%d)",
          sparse_dim,
          dims.size()));

  auto dims_2d = flatten_to_2d(dims, sparse_dim);
  const int rows = dims_2d[0];
  const int cols = dims_2d[1];

  const T* data = dense.data<T>();
  int64_t non_zero_num = 0;
  for (int64_t i = 0; i < rows; i++) {
    if (!IsZero(data + i * cols, cols)) {
      non_zero_num = non_zero_num + 1;
    }
  }
  return non_zero_num;
}

template <typename T, typename Context>
66 67 68 69
void DenseToCooKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const int64_t sparse_dim,
                      SparseCooTensor* out) {
70 71
  const T* x_data = x.data<T>();
  const auto& x_dims = x.dims();
72 73 74 75 76 77
  PADDLE_ENFORCE_LE(sparse_dim,
                    x_dims.size(),
                    phi::errors::InvalidArgument(
                        "sparse_dim must be less than the size of x.dims()"));
  PADDLE_ENFORCE_GT(
      sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0"));
78 79 80

  int64_t non_zero_num = GetNonZeroNum<T>(x, sparse_dim);

81 82
  const auto values_dims =
      phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num);
83
  DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
84 85
  phi::DenseTensor indices =
      phi::Empty<int64_t>(dev_ctx, {sparse_dim, non_zero_num});
86
  phi::DenseTensor values = phi::Empty(dev_ctx, std::move(values_meta));
87 88
  int64_t* indices_data = indices.data<int64_t>();
  T* values_data = values.data<T>();
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105

  auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
  const int rows = dims_2d[0];
  const int cols = dims_2d[1];

  int index = 0;
  for (int i = 0; i < rows; i++) {
    if (!IsZero(x_data + i * cols, cols)) {
      int64_t sparse_index = i;
      for (int64_t j = sparse_dim - 1; j >= 0; j--) {
        indices_data[j * non_zero_num + index] = sparse_index % x_dims[j];
        sparse_index /= x_dims[j];
      }
      memcpy(values_data + index * cols, x_data + i * cols, cols * sizeof(T));
      ++index;
    }
  }
106

107 108 109
  out->SetMember(indices, values, x_dims, true);
}

110
template <typename T, typename IntT>
111 112 113
void CsrToCooCPUKernel(const CPUContext& dev_ctx,
                       const SparseCsrTensor& x,
                       SparseCooTensor* out) {
114
  const DDim& x_dims = x.dims();
115 116 117 118
  const int64_t non_zero_num = x.cols().numel();
  const auto& csr_crows = x.crows();
  const auto& csr_cols = x.cols();
  const auto& csr_values = x.values();
119 120
  const IntT* csr_crows_data = csr_crows.data<IntT>();
  const IntT* csr_cols_data = csr_cols.data<IntT>();
121 122 123 124 125 126
  const T* csr_values_data = csr_values.data<T>();

  int64_t sparse_dim = 2;
  if (x_dims.size() == 3) {
    sparse_dim = 3;
  }
127 128 129 130 131 132
  phi::DenseTensor indices =
      phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
  phi::DenseTensor values = phi::Empty<T>(dev_ctx, {non_zero_num});
  IntT* coo_indices = indices.data<IntT>();
  IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
  IntT* coo_rows_data =
133
      x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
134 135
  IntT* coo_cols_data = coo_rows_data + non_zero_num;
  T* coo_values_data = values.data<T>();
136 137 138 139 140 141 142

  int batch = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

  int index = 0;
  for (int b = 0; b < batch; b++) {
    for (int i = 0; i < rows; i++) {
143
      for (IntT j = csr_crows_data[b * (rows + 1) + i];
144 145 146 147 148 149 150 151 152 153 154
           j < csr_crows_data[b * (rows + 1) + i + 1];
           j++) {
        coo_rows_data[index] = i;
        if (batch_ptr) {
          batch_ptr[index] = b;
        }
        ++index;
      }
    }
  }

155
  memcpy(coo_cols_data, csr_cols_data, sizeof(IntT) * non_zero_num);
156 157 158 159
  memcpy(coo_values_data, csr_values_data, sizeof(T) * non_zero_num);
  out->SetMember(indices, values, x_dims, true);
}

160
template <typename T, typename Context>
161 162 163 164 165 166
void CsrToCooKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCooTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.crows().dtype(), "CsrToCooCPUKernel", ([&] {
                                 CsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
                               }));
167 168 169
}

template <typename T, typename IntT>
170 171 172
void CooToCsrCPUKernel(const CPUContext& dev_ctx,
                       const SparseCooTensor& x,
                       SparseCsrTensor* out) {
173 174 175 176
  const auto& x_dims = x.dims();
  bool valid = x_dims.size() == 2 || x_dims.size() == 3;
  PADDLE_ENFORCE_EQ(valid,
                    true,
177
                    phi::errors::InvalidArgument(
178 179 180 181 182 183 184
                        "SparseCsrTensor only support 2-D or 3-D matrix"));
  const int64_t non_zero_num = x.nnz();
  if (non_zero_num <= 0) return;

  int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

185 186 187 188 189 190
  phi::DenseTensor crows = phi::Empty<IntT>(dev_ctx, {batchs * (rows + 1)});
  phi::DenseTensor cols = phi::Empty<IntT>(dev_ctx, {non_zero_num});
  phi::DenseTensor values = phi::EmptyLike<T, CPUContext>(dev_ctx, x.values());
  IntT* csr_crows_data = crows.data<IntT>();
  IntT* csr_cols_data = cols.data<IntT>();
  T* csr_values_data = values.data<T>();
191

192 193
  const auto& coo_indices = x.indices();
  const auto& coo_values = x.values();
194 195
  const IntT* batchs_ptr = coo_indices.data<IntT>();
  const IntT* coo_rows_data =
Z
zhangkaihuo 已提交
196
      x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
197
  const IntT* coo_cols_data = coo_rows_data + non_zero_num;
198 199 200 201 202 203
  const T* coo_values_data = coo_values.data<T>();

  std::vector<int64_t> offsets(batchs, 0);
  if (batchs > 1) {
    for (int i = 0; i < non_zero_num; i++) {
      if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
Z
zhangkaihuo 已提交
204 205 206 207 208
        const int start = batchs_ptr[i];
        const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
        for (int j = start; j < end; j++) {
          offsets[j] = i + 1;
        }
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
      }
    }
  } else {
    offsets[0] = non_zero_num;
  }

  for (int b = 0; b < batchs; b++) {
    int batch_start = 0;
    int batch_non_zero_num = offsets[b];
    if (b > 0) {
      batch_start = offsets[b - 1];
      batch_non_zero_num -= batch_start;
    }
    auto* coo_rows_ptr = coo_rows_data + batch_start;
    for (int i = 0; i <= coo_rows_ptr[0]; i++) {
      csr_crows_data[b * (rows + 1) + i] = 0;
    }
    for (int64_t i = 1; i < batch_non_zero_num; i++) {
227
      for (IntT j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
228 229 230
        csr_crows_data[b * (rows + 1) + j + 1] = i;
      }
    }
231
    for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) {
232 233
      csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
    }
Z
zhangkaihuo 已提交
234 235 236
    if (batch_non_zero_num == 0) {
      memset(csr_crows_data + b * (rows + 1), 0, sizeof(IntT) * (rows + 1));
    }
237 238
  }

239
  memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num);
240
  memcpy(csr_values_data, coo_values_data, sizeof(T) * non_zero_num);
241
  out->SetMember(crows, cols, values, x_dims);
242 243
}

Z
zhangkaihuo 已提交
244
template <typename T, typename Context>
245 246 247 248 249 250
void CooToCsrKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCsrTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "CooToCsrCPUKernel", ([&] {
                                 CooToCsrCPUKernel<T, data_t>(dev_ctx, x, out);
                               }));
251 252 253
}

template <typename T, typename IntT>
254 255 256
void CooToDenseCPUKernel(const CPUContext& dev_ctx,
                         const SparseCooTensor& x,
                         DenseTensor* out) {
Z
zhangkaihuo 已提交
257 258
  const auto non_zero_num = x.nnz();
  const auto dense_dims = x.dims();
259 260
  const auto indices = x.indices();
  const auto values = x.values();
Z
zhangkaihuo 已提交
261 262 263 264 265
  const auto indices_dims = indices.dims();
  int64_t sparse_dim = indices_dims[0];
  if (indices_dims.size() == 1) {
    sparse_dim = 1;
  }
Z
zhangkaihuo 已提交
266
  const int64_t dense_dim = x.dense_dim();
Z
zhangkaihuo 已提交
267 268

  const T* x_data = values.data<T>();
269
  dev_ctx.template Alloc<T>(out);
Z
zhangkaihuo 已提交
270
  T* out_data = out->data<T>();
Z
zhangkaihuo 已提交
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
  int64_t base_offset = 1;
  for (int64_t i = 0; i < dense_dim; i++) {
    base_offset *= dense_dims[sparse_dim + i];
  }
  std::vector<int64_t> sparse_offsets(sparse_dim);
  int64_t offset = 1;
  for (int i = sparse_dim - 1; i >= 0; i--) {
    sparse_offsets[i] = offset;
    offset *= dense_dims[i];
  }

  memset(out_data, 0, sizeof(T) * out->numel());
  for (auto i = 0; i < non_zero_num; i++) {
    int64_t index = 0;
    for (int j = 0; j < sparse_dim; j++) {
286
      index += indices.data<IntT>()[j * non_zero_num + i] * sparse_offsets[j];
Z
zhangkaihuo 已提交
287 288 289 290 291 292 293 294
    }

    for (int j = 0; j < base_offset; j++) {
      out_data[index * base_offset + j] = x_data[i * base_offset + j];
    }
  }
}

295
template <typename T, typename Context>
296 297 298
void CooToDenseKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      DenseTensor* out) {
299
  PD_VISIT_BASE_INTEGRAL_TYPES(
300 301
      x.indices().dtype(), "CooToDenseCPUKernel", ([&] {
        CooToDenseCPUKernel<T, data_t>(dev_ctx, x, out);
302 303 304
      }));
}

305
}  // namespace sparse
306
}  // namespace phi
307

308
// CPU registrations for the sparse format-conversion kernels. The half
// precision type is spelled phi::dtype::float16 throughout, matching every
// other registration in this file.
PD_REGISTER_KERNEL(dense_to_coo,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(csr_to_coo,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrToCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::CooToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::DenseToCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(coo_to_dense,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::CooToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(csr_to_dense,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::CsrToDenseKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

// The accessor kernels below take a sparse tensor as input 0, so that input
// is tagged with the matching sparse data layout.
PD_REGISTER_KERNEL(values_coo,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::ValuesCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(indices_coo,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::IndicesCooKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(values_csr,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::ValuesCsrKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

// NOTE(review): unlike the other kernels in this file, int8_t is not
// registered for sparse_coo_tensor — confirm whether that is intentional.
PD_REGISTER_KERNEL(sparse_coo_tensor,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseCooTensorKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {}