/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/sparse_api_custom_impl.h"

#include <memory>

#include "glog/logging.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace experimental {
namespace sparse {

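// Converts a dense or SPARSE_CSR Tensor into SPARSE_COO format by dispatching
// to the "dense_to_coo" / "csr_to_coo" kernel. A Tensor that is already in
// COO format is returned unchanged.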
Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "dense_to_coo";
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    kernel_name = "csr_to_coo";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);
  const auto& kernel = kernel_result.kernel;

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
    kernel_context.EmplaceBackAttr(sparse_dim);
  }

  // 4. InferMeta
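  // The real output shapes depend on the number of non-zero elements, which is
  // only known once the kernel has run, so {1} is used as a placeholder shape
  // and the kernel resizes the outputs.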
  auto indices_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

  // 5. Prepare outputs
  // create empty SparseCooTensor
  phi::DenseTensor non_zero_indices(std::make_shared<phi::Allocation>(),
                                    std::move(indices_meta));
  phi::DenseTensor non_zero_elements(std::make_shared<phi::Allocation>(),
                                     std::move(elements_meta));
  auto coo = std::make_shared<phi::SparseCooTensor>(
      non_zero_indices, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(coo.get());
  Tensor out;
  out.set_impl(coo);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

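// Converts a dense or SPARSE_COO Tensor into SPARSE_CSR format by dispatching
// to the "dense_to_csr" / "coo_to_csr" kernel. A Tensor that is already in
// CSR format is returned unchanged.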
Tensor to_sparse_csr_impl(const Tensor& x) {
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "dense_to_csr";
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    kernel_name = "coo_to_csr";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);
  const auto& kernel = kernel_result.kernel;

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
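  // As above, the non-zero counts are only known after the kernel runs, so the
  // metas below use {1} as a placeholder shape and the kernel resizes the
  // outputs.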
  auto crows_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto cols_meta =
      phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
  auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

  // 5. Prepare outputs
  // create empty SparseCsrTensor
  phi::DenseTensor non_zero_crows(std::make_shared<phi::Allocation>(),
                                  std::move(crows_meta));
  phi::DenseTensor non_zero_cols(std::make_shared<phi::Allocation>(),
                                 std::move(cols_meta));
  phi::DenseTensor non_zero_elements(std::make_shared<phi::Allocation>(),
                                     std::move(elements_meta));
  auto csr = std::make_shared<phi::SparseCsrTensor>(
      non_zero_crows, non_zero_cols, non_zero_elements, x.dims());

  kernel_context.EmplaceBackOutput(csr.get());
  Tensor out;
  out.set_impl(csr);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

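// Converts a SPARSE_COO or SPARSE_CSR Tensor back into a dense Tensor by
// dispatching to the "coo_to_dense" / "csr_to_dense" kernel. A Tensor that is
// not sparse is returned unchanged.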
Tensor to_dense_impl(const Tensor& x) {
  if (x.layout() != phi::DataLayout::SPARSE_CSR &&
      x.layout() != phi::DataLayout::SPARSE_COO) {
    return x;
  }

  // 1. Get kernel signature and kernel
  std::string kernel_name = "coo_to_dense";
  if (x.layout() == phi::DataLayout::SPARSE_CSR) {
    kernel_name = "csr_to_dense";
  }

  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);
  const auto& kernel = kernel_result.kernel;

  VLOG(6) << "to API kernel key: " << kernel_key;
  VLOG(6) << "to API kernel: " << kernel;

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = phi::KernelContext(dev_ctx);

  // 3. Auto data transform
  if (x.layout() == phi::DataLayout::SPARSE_COO) {
    auto input = std::dynamic_pointer_cast<phi::SparseCooTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  } else {
    auto input = std::dynamic_pointer_cast<phi::SparseCsrTensor>(x.impl());
    kernel_context.EmplaceBackInput(input.get());
  }

  // 4. InferMeta
  auto dense_meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout());

  // 5. Prepare outputs
  // create empty DenseTensor
  auto dense_out = std::make_shared<phi::DenseTensor>(
      std::make_shared<phi::Allocation>(), std::move(dense_meta));

  kernel_context.EmplaceBackOutput(dense_out.get());
  Tensor out;
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

}  // namespace sparse
}  // namespace experimental
}  // namespace paddle