/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/sparse_api_custom_impl.h"

#include <memory>
#include "glog/logging.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace experimental {
namespace sparse {

Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "dense_to_sparse_coo";
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    kernel_name = "sparse_csr_to_coo";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "add API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
    kernel_context.EmplaceBackAttr(sparse_dim);
  }

  // 4. InferMeta
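  // Note: the {1} shapes below are placeholders; the selected kernel is
  // expected to resize and allocate the real indices/values once the number
  // of non-zero elements is known.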
  auto indices_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

  // 5. Prepare outputs
  // create empty SparseCooTensor
  phi::DenseTensor non_zero_indices(std::make_shared<phi::Allocation>(),
                                    std::move(indices_meta));
  phi::DenseTensor non_zero_elements(std::make_shared<phi::Allocation>(),
                                     std::move(elements_meta));
  auto coo = std::make_shared<phi::SparseCooTensor>(
      non_zero_indices, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(coo.get());
  Tensor out;
  out.set_impl(coo);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}
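
// A minimal usage sketch (illustrative only, not part of the original file):
// assuming `dense` is a hypothetical DenseTensor-backed
// paddle::experimental::Tensor whose leading two dimensions should be
// treated as sparse dimensions:
//
//   // paddle::experimental::Tensor dense = ...;  // hypothetical input
//   // Tensor coo = to_sparse_coo_impl(dense, /*sparse_dim=*/2);
//   // coo is now backed by a phi::SparseCooTensor with the same dims.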

Tensor to_sparse_csr_impl(const Tensor& x) {
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    return x;
  }
  // 1. Get kernel signature and kernel
  std::string kernel_name = "dense_to_sparse_csr";
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    kernel_name = "sparse_coo_to_csr";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "add API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
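  // As above, the {1} shapes are placeholders that the kernel is expected to
  // resize once the number of non-zero elements is known.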
  auto crows_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto cols_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

  // 5. Prepare outputs
  // create empty SparseCsrTensor
  phi::DenseTensor non_zero_crows(std::make_shared<phi::Allocation>(),
                                  std::move(crows_meta));
  phi::DenseTensor non_zero_cols(std::make_shared<phi::Allocation>(),
                                 std::move(cols_meta));
  phi::DenseTensor non_zero_elements(std::make_shared<phi::Allocation>(),
                                     std::move(elements_meta));
  auto csr = std::make_shared<phi::SparseCsrTensor>(
      non_zero_crows, non_zero_cols, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(csr.get());
  Tensor out;
  out.set_impl(csr);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}
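
// A minimal usage sketch (illustrative only): either a dense or a COO tensor
// may be passed; the kernel name is switched accordingly above.
//
//   // Tensor csr = to_sparse_csr_impl(coo);  // `coo` from the sketch above
//   // csr is now backed by a phi::SparseCsrTensor (crows/cols/values).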

Tensor to_dense_impl(const Tensor& x) {
  if (x.layout() != phi::DataLayout::SPARSE_CSR &&
      x.layout() != phi::DataLayout::SPARSE_COO) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "sparse_coo_to_dense";
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    kernel_name = "sparse_csr_to_dense";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);

  VLOG(6) << "add API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
  auto dense_meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout());

  // 5. Prepare outputs
  // create empty DenseTensor
  auto dense_out = std::make_shared<phi::DenseTensor>(
      std::make_shared<phi::Allocation>(), std::move(dense_meta));

  kernel_context.EmplaceBackOutput(dense_out.get());
  Tensor out;
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}
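
// A minimal usage sketch (illustrative only): converting back to a dense
// layout works for both COO and CSR inputs; any other layout is returned
// unchanged.
//
//   // Tensor dense_again = to_dense_impl(csr);
//   // dense_again is backed by a phi::DenseTensor with the original dims.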

}  // namespace sparse
}  // namespace experimental
}  // namespace paddle