sparse_api_custom_impl.cc 7.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/api/lib/sparse_api_custom_impl.h"
16 17 18

#include <memory>
#include "glog/logging.h"
19 20 21 22
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"
23 24 25 26 27

namespace paddle {
namespace experimental {
namespace sparse {

28 29 30
// Converts `x` to a SparseCooTensor.
//
// If `x` already has SPARSE_COO layout it is returned unchanged. Otherwise
// the "sparse_csr_to_coo" kernel is used for SPARSE_CSR inputs and the
// "dense_to_sparse_coo" kernel for every other (dense) layout.
//
// @param x          Input tensor (dense or SPARSE_CSR).
// @param backend    Requested backend. It is OR-ed into the backend set
//                   derived from `x`; the highest-priority kernel key of that
//                   set decides which kernel and device context are used.
// @param sparse_dim Passed to the kernel as an attribute for dense inputs
//                   only; it is not used for SPARSE_CSR inputs.
// @return A Tensor whose impl is the newly created SparseCooTensor.
Tensor to_sparse_coo_impl(const Tensor& x,
                          Backend backend,
                          const int64_t sparse_dim) {
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    return x;
  }
  const bool is_csr_input = x.layout() == phi::DataLayout::SPARSE_CSR;

  // 1. Get kernel signature and kernel
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
  const std::string kernel_name =
      is_csr_input ? "sparse_csr_to_coo" : "dense_to_sparse_coo";
  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (is_csr_input) {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
    kernel_context.EmplaceBackAttr(sparse_dim);
  }

  // 4. InferMeta
  auto indices_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
  // The values tensor is a DenseTensor: it must not inherit the SPARSE_CSR
  // layout of a CSR input. A dense input keeps its own layout.
  auto elements_meta = phi::DenseTensorMeta(
      x.dtype(), {-1}, is_csr_input ? phi::DataLayout::NCHW : x.layout());

  // 5. Prepare outputs
  // Create an empty SparseCooTensor. Storage is placed on the backend the
  // kernel was actually selected for (the device context above), which may
  // differ from the requested `backend`.
  const auto place = phi::TransToPhiPlace(kernel_key.backend());
  phi::DenseTensor non_zero_indices(
      phi::make_intrusive<paddle::experimental::SharedStorage>(place),
      std::move(indices_meta));
  phi::DenseTensor non_zero_elements(
      phi::make_intrusive<paddle::experimental::SharedStorage>(place),
      std::move(elements_meta));
  auto coo = std::make_shared<phi::SparseCooTensor>(
      non_zero_indices, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(coo.get());
  Tensor out;
  out.set_impl(coo);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

91
// Converts `x` to a SparseCsrTensor.
//
// If `x` already has SPARSE_CSR layout it is returned unchanged. Otherwise
// the "sparse_coo_to_csr" kernel is used for SPARSE_COO inputs and the
// "dense_to_sparse_csr" kernel for every other (dense) layout.
//
// @param x       Input tensor (dense or SPARSE_COO).
// @param backend Requested backend. It is OR-ed into the backend set derived
//                from `x`; the highest-priority kernel key of that set
//                decides which kernel and device context are used.
// @return A Tensor whose impl is the newly created SparseCsrTensor.
Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    return x;
  }
  const bool is_coo_input = x.layout() == phi::DataLayout::SPARSE_COO;

  // 1. Get kernel signature and kernel
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
  const std::string kernel_name =
      is_coo_input ? "sparse_coo_to_csr" : "dense_to_sparse_csr";
  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (is_coo_input) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
  auto crows_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
  auto cols_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
  // The values tensor is a DenseTensor: it must not inherit the SPARSE_COO
  // layout of a COO input. A dense input keeps its own layout.
  auto elements_meta = phi::DenseTensorMeta(
      x.dtype(), {-1}, is_coo_input ? phi::DataLayout::NCHW : x.layout());

  // 5. Prepare outputs
  // Create an empty SparseCsrTensor. Storage is placed on the backend the
  // kernel was actually selected for (the device context above), which may
  // differ from the requested `backend`.
  const auto place = phi::TransToPhiPlace(kernel_key.backend());
  phi::DenseTensor non_zero_crows(
      phi::make_intrusive<paddle::experimental::SharedStorage>(place),
      std::move(crows_meta));
  phi::DenseTensor non_zero_cols(
      phi::make_intrusive<paddle::experimental::SharedStorage>(place),
      std::move(cols_meta));
  phi::DenseTensor non_zero_elements(
      phi::make_intrusive<paddle::experimental::SharedStorage>(place),
      std::move(elements_meta));
  auto csr = std::make_shared<phi::SparseCsrTensor>(
      non_zero_crows, non_zero_cols, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(csr.get());
  Tensor out;
  out.set_impl(csr);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}
Z
zhangkaihuo 已提交
156

157
// Converts a sparse tensor (SPARSE_COO or SPARSE_CSR) to a DenseTensor.
//
// Inputs with any other layout are returned unchanged. SPARSE_COO inputs use
// the "sparse_coo_to_dense" kernel; SPARSE_CSR inputs use
// "sparse_csr_to_dense".
//
// @param x       Input tensor.
// @param backend Requested backend. It is OR-ed into the backend set derived
//                from `x`; the highest-priority kernel key of that set
//                decides which kernel and device context are used.
// @return A Tensor whose impl is the newly created DenseTensor, with the same
//         dtype and dims as `x`.
Tensor to_dense_impl(const Tensor& x, Backend backend) {
  const bool is_coo_input = x.layout() == phi::DataLayout::SPARSE_COO;
  const bool is_csr_input = x.layout() == phi::DataLayout::SPARSE_CSR;
  if (!is_coo_input && !is_csr_input) {
    return x;
  }

  // 1. Get kernel signature and kernel
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
  const std::string kernel_name =
      is_csr_input ? "sparse_csr_to_dense" : "sparse_coo_to_dense";
  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (is_coo_input) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
  // The output is a DenseTensor, so it gets a dense layout. Copying
  // x.layout() here would always stamp SPARSE_COO/SPARSE_CSR onto the dense
  // result, and the layout checks in the to_* APIs would then misroute it.
  auto dense_meta =
      phi::DenseTensorMeta(x.dtype(), x.dims(), phi::DataLayout::NCHW);

  // 5. Prepare outputs
  // Create an empty DenseTensor. Storage is placed on the backend the kernel
  // was actually selected for (the device context above), which may differ
  // from the requested `backend`.
  auto dense_out = std::make_shared<phi::DenseTensor>(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_key.backend())),
      std::move(dense_meta));

  kernel_context.EmplaceBackOutput(dense_out.get());
  Tensor out;
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

210 211 212 213
}  // namespace sparse
}  // namespace experimental
}  // namespace paddle

214
// Registers this file's APIs under the name "SparseApi" via PD_REGISTER_API,
// which comes from paddle/phi/api/lib/api_registry.h.
// NOTE(review): presumably this emits a symbol that other translation units
// reference so this file is not dropped at link time — confirm against
// api_registry.h.
PD_REGISTER_API(SparseApi);