/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/sparse_api_custom_impl.h"

#include <memory>
#include "glog/logging.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace experimental {
namespace sparse {

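// Converts a DenseTensor or SparseCsrTensor `x` into a SparseCooTensor with
// `sparse_dim` sparse dimensions, dispatching to the "dense_to_sparse_coo" or
// "sparse_csr_to_coo" kernel. Inputs already in COO layout are returned
// unchanged.
//
// Illustrative usage (a sketch, assuming a dense Tensor `dense_x`):
//   Tensor coo = to_sparse_coo_impl(dense_x, /*sparse_dim=*/2);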
Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "dense_to_sparse_coo";
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    kernel_name = "sparse_csr_to_coo";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
    kernel_context.EmplaceBackAttr(sparse_dim);
  }

  // 4. InferMeta
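  // Note: the {1} dims below are placeholders; the conversion kernel is
  // expected to resize the indices/elements outputs to the actual number of
  // non-zero entries.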
  auto indices_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

  // 5. Prepare outputs
  // create empty SparseCooTensor
  phi::DenseTensor non_zero_indices(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(indices_meta));
  phi::DenseTensor non_zero_elements(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(elements_meta));
  auto coo = std::make_shared<phi::SparseCooTensor>(
      non_zero_indices, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(coo.get());
  Tensor out;
  out.set_impl(coo);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

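// Converts a DenseTensor or SparseCooTensor `x` into a SparseCsrTensor,
// dispatching to the "dense_to_sparse_csr" or "sparse_coo_to_csr" kernel.
// Inputs already in CSR layout are returned unchanged.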
Tensor to_sparse_csr_impl(const Tensor& x) {
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    return x;
  }
  // 1. Get kernel signature and kernel
  std::string kernel_name = "dense_to_sparse_csr";
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    kernel_name = "sparse_coo_to_csr";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
  auto crows_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto cols_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

  // 5. Prepare outputs
  // create empty SparseCsrTensor
  phi::DenseTensor non_zero_crows(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(crows_meta));
  phi::DenseTensor non_zero_cols(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(cols_meta));
  phi::DenseTensor non_zero_elements(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(elements_meta));
  auto csr = std::make_shared<phi::SparseCsrTensor>(
      non_zero_crows, non_zero_cols, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(csr.get());
  Tensor out;
  out.set_impl(csr);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

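// Converts a SparseCooTensor or SparseCsrTensor `x` back to a DenseTensor,
// dispatching to the "sparse_coo_to_dense" or "sparse_csr_to_dense" kernel.
// Dense inputs are returned unchanged.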
Tensor to_dense_impl(const Tensor& x) {
  if (x.layout() != phi::DataLayout::SPARSE_CSR &&
      x.layout() != phi::DataLayout::SPARSE_COO) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "sparse_coo_to_dense";
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    kernel_name = "sparse_csr_to_dense";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
  auto dense_meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout());

  // 5. Prepare outputs
  // create empty DenseTensor
  auto dense_out = std::make_shared<phi::DenseTensor>(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(dense_meta));

  kernel_context.EmplaceBackOutput(dense_out.get());
  Tensor out;
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

}  // namespace sparse
}  // namespace experimental
}  // namespace paddle

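// Exports a registration symbol for the sparse API (paired with
// PD_DECLARE_API in api_registry.h) so that this translation unit is not
// dropped by the linker.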
PD_REGISTER_API(SparseApi);