// sparse_utils_kernel.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
28 29 30 31
void DenseToCooKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const int64_t sparse_dim,
                      SparseCooTensor* out);
32 33

template <typename T, typename Context>
34 35 36
SparseCooTensor DenseToCoo(const Context& dev_ctx,
                           const DenseTensor& x,
                           const int64_t sparse_dim) {
37 38
  DenseTensor indices;
  DenseTensor values;
39
  SparseCooTensor coo(indices, values, x.dims());
Z
zhangkaihuo 已提交
40 41
  MetaTensor meta_out(&coo);
  phi::UnchangedInferMeta(x, &meta_out);
42
  DenseToCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
43 44 45
  return coo;
}

46
template <typename T, typename Context>
47 48 49
void CsrToCooKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCooTensor* out);
50 51

template <typename T, typename Context>
52
SparseCooTensor CsrToCoo(const Context& dev_ctx, const SparseCsrTensor& x) {
53 54
  DenseTensor indices;
  DenseTensor values;
55
  SparseCooTensor coo(indices, values, x.dims());
Z
zhangkaihuo 已提交
56 57
  MetaTensor meta_out(&coo);
  phi::UnchangedInferMeta(x, &meta_out);
58
  CsrToCooKernel<T, Context>(dev_ctx, x, &coo);
59 60 61
  return coo;
}

62
template <typename T, typename Context>
63 64 65
void CooToCsrKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCsrTensor* out);
66 67

template <typename T, typename Context>
68
SparseCsrTensor CooToCsr(const Context& dev_ctx, const SparseCooTensor& x) {
69 70
  DenseTensor crows;
  DenseTensor cols;
71
  DenseTensor non_zero_elements;
72
  SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims());
Z
zhangkaihuo 已提交
73 74
  MetaTensor meta_out(&csr);
  phi::UnchangedInferMeta(x, &meta_out);
75
  CooToCsrKernel<T, Context>(dev_ctx, x, &csr);
76 77 78 79
  return csr;
}

template <typename T, typename Context>
80 81 82
void DenseToCsrKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      SparseCsrTensor* out) {
83 84 85 86
  const auto& x_dims = x.dims();
  bool valid = x_dims.size() == 2 || x_dims.size() == 3;
  PADDLE_ENFORCE_EQ(valid,
                    true,
87
                    phi::errors::InvalidArgument(
88
                        "SparseCsrTensor only support 2-D or 3-D Tensor."));
Z
zhangkaihuo 已提交
89

90
  const int64_t sparse_dim = x_dims.size() == 2 ? 2 : 3;
91 92
  DenseTensor indices;
  DenseTensor values;
93
  SparseCooTensor coo(indices, values, x.dims());
Z
zhangkaihuo 已提交
94 95
  MetaTensor meta_out(&coo);
  phi::UnchangedInferMeta(x, &meta_out);
96 97
  DenseToCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
  CooToCsrKernel<T, Context>(dev_ctx, coo, out);
98 99 100
}

template <typename T, typename Context>
101
SparseCsrTensor DenseToCsr(const Context& dev_ctx, const DenseTensor& x) {
102 103
  DenseTensor crows;
  DenseTensor cols;
104
  DenseTensor non_zero_elements;
105
  SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims());
Z
zhangkaihuo 已提交
106 107
  MetaTensor meta_out(&csr);
  phi::UnchangedInferMeta(x, &meta_out);
108
  DenseToCsrKernel<T, Context>(dev_ctx, x, &csr);
109 110 111
  return csr;
}

Z
zhangkaihuo 已提交
112
template <typename T, typename Context>
113 114 115
void CooToDenseKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      DenseTensor* out);
Z
zhangkaihuo 已提交
116 117

template <typename T, typename Context>
118
DenseTensor CooToDense(const Context& dev_ctx, const SparseCooTensor& x) {
Z
zhangkaihuo 已提交
119
  DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
120
  DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
121
  CooToDenseKernel<T, Context>(dev_ctx, x, &dense);
Z
zhangkaihuo 已提交
122 123 124 125
  return dense;
}

template <typename T, typename Context>
126 127 128
void CsrToDenseKernel(const Context& dev_ctx,
                      const SparseCsrTensor& x,
                      DenseTensor* out) {
129 130
  DenseTensor indices;
  DenseTensor values;
Z
zhangkaihuo 已提交
131
  SparseCooTensor coo(indices, values, x.dims());
Z
zhangkaihuo 已提交
132 133
  MetaTensor meta_out(&coo);
  phi::UnchangedInferMeta(x, &meta_out);
134 135
  CsrToCooKernel<T, Context>(dev_ctx, x, &coo);
  CooToDenseKernel<T, Context>(dev_ctx, coo, out);
Z
zhangkaihuo 已提交
136 137 138
}

template <typename T, typename Context>
139
DenseTensor CsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
Z
zhangkaihuo 已提交
140
  DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
141
  DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
142
  CsrToDenseKernel<T, Context>(dev_ctx, x, &dense);
Z
zhangkaihuo 已提交
143 144 145
  return dense;
}

146
template <typename T, typename Context>
147
void ValuesCooKernel(const Context& dev_ctx,
148 149 150 151 152 153
                     const SparseCooTensor& x,
                     DenseTensor* out) {
  *out = x.non_zero_elements();
}

template <typename T, typename Context>
154
void ValuesCsrKernel(const Context& dev_ctx,
155 156 157 158 159
                     const SparseCsrTensor& x,
                     DenseTensor* out) {
  *out = x.non_zero_elements();
}

Z
zhangkaihuo 已提交
160 161 162 163 164 165 166
template <typename T, typename Context>
void IndicesCooKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      DenseTensor* out) {
  *out = x.indices();
}

167 168 169 170 171 172
template <typename T, typename Context>
void SparseCooTensorKernel(const Context& dev_ctx,
                           const DenseTensor& values,
                           const DenseTensor& indices,
                           const IntArray& dense_shape,
                           SparseCooTensor* out) {
Z
zhangkaihuo 已提交
173 174
  *out =
      SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
175 176
}

177
}  // namespace sparse
}  // namespace phi