// sparse_utils_kernel.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
27 28 29 30
void DenseToCooKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const int64_t sparse_dim,
                      SparseCooTensor* out);
31 32

template <typename T, typename Context>
33 34 35
SparseCooTensor DenseToCoo(const Context& dev_ctx,
                           const DenseTensor& x,
                           const int64_t sparse_dim) {
36 37
  DenseTensor indices;
  DenseTensor values;
38
  SparseCooTensor coo(indices, values, x.dims());
39
  DenseToCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
40 41 42
  return coo;
}

template <typename T, typename Context>
44 45 46
void CsrToCooKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCooTensor* out);
47 48

template <typename T, typename Context>
49
SparseCooTensor CsrToCoo(const Context& dev_ctx, const SparseCsrTensor& x) {
50 51
  DenseTensor indices;
  DenseTensor values;
52
  SparseCooTensor coo(indices, values, x.dims());
53
  CsrToCooKernel<T, Context>(dev_ctx, x, &coo);
54 55 56
  return coo;
}

template <typename T, typename Context>
58 59 60
void CooToCsrKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCsrTensor* out);
61 62

template <typename T, typename Context>
63
SparseCsrTensor CooToCsr(const Context& dev_ctx, const SparseCooTensor& x) {
64 65
  DenseTensor crows;
  DenseTensor cols;
66
  DenseTensor non_zero_elements;
67
  SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims());
68
  CooToCsrKernel<T, Context>(dev_ctx, x, &csr);
69 70 71 72
  return csr;
}

template <typename T, typename Context>
73 74 75
void DenseToCsrKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      SparseCsrTensor* out) {
76 77 78 79
  const auto& x_dims = x.dims();
  bool valid = x_dims.size() == 2 || x_dims.size() == 3;
  PADDLE_ENFORCE_EQ(valid,
                    true,
80
                    phi::errors::InvalidArgument(
81 82
                        "SparseCsrTensor only support 2-D or 3-D Tensor."));
  const int64_t sparse_dim = x_dims.size() == 2 ? 2 : 3;
83 84
  DenseTensor indices;
  DenseTensor values;
85
  SparseCooTensor coo(indices, values, x.dims());
86 87
  DenseToCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
  CooToCsrKernel<T, Context>(dev_ctx, coo, out);
88 89 90
}

template <typename T, typename Context>
91
SparseCsrTensor DenseToCsr(const Context& dev_ctx, const DenseTensor& x) {
92 93
  DenseTensor crows;
  DenseTensor cols;
94
  DenseTensor non_zero_elements;
95
  SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims());
96
  DenseToCsrKernel<T, Context>(dev_ctx, x, &csr);
97 98 99
  return csr;
}

template <typename T, typename Context>
101 102 103
void CooToDenseKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      DenseTensor* out);
Z
zhangkaihuo 已提交
104 105

template <typename T, typename Context>
106
DenseTensor CooToDense(const Context& dev_ctx, const SparseCooTensor& x) {
Z
zhangkaihuo 已提交
107
  DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
108
  DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
109
  CooToDenseKernel<T, Context>(dev_ctx, x, &dense);
Z
zhangkaihuo 已提交
110 111 112 113
  return dense;
}

template <typename T, typename Context>
114 115 116
void CsrToDenseKernel(const Context& dev_ctx,
                      const SparseCsrTensor& x,
                      DenseTensor* out) {
117 118
  DenseTensor indices;
  DenseTensor values;
Z
zhangkaihuo 已提交
119
  SparseCooTensor coo(indices, values, x.dims());
120 121
  CsrToCooKernel<T, Context>(dev_ctx, x, &coo);
  CooToDenseKernel<T, Context>(dev_ctx, coo, out);
Z
zhangkaihuo 已提交
122 123 124
}

template <typename T, typename Context>
125
DenseTensor CsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
Z
zhangkaihuo 已提交
126
  DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
127
  DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
128
  CsrToDenseKernel<T, Context>(dev_ctx, x, &dense);
Z
zhangkaihuo 已提交
129 130 131
  return dense;
}

template <typename T, typename Context>
133
void ValuesCooKernel(const Context& dev_ctx,
134 135 136 137 138 139
                     const SparseCooTensor& x,
                     DenseTensor* out) {
  *out = x.non_zero_elements();
}

template <typename T, typename Context>
140
void ValuesCsrKernel(const Context& dev_ctx,
141 142 143 144 145
                     const SparseCsrTensor& x,
                     DenseTensor* out) {
  *out = x.non_zero_elements();
}

template <typename T, typename Context>
void SparseCooTensorKernel(const Context& dev_ctx,
                           const DenseTensor& values,
                           const DenseTensor& indices,
                           const IntArray& dense_shape,
                           SparseCooTensor* out) {
Z
zhangkaihuo 已提交
152 153
  *out =
      SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
154 155
}

}  // namespace sparse
}  // namespace phi