// api_gen_utils.cc
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/api_gen_utils.h"

namespace paddle {
namespace experimental {

/* ------------------ for input ----------------------- */

22
std::shared_ptr<phi::DenseTensor> TensorToDenseTensor(const Tensor& tensor) {
Z
zyfncg 已提交
23
  return std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
24 25
}

26 27
paddle::optional<phi::DenseTensor> TensorToDenseTensor(
    const paddle::optional<Tensor>& tensor) {
28
  if (tensor) {
29
    return {*std::static_pointer_cast<phi::DenseTensor>(tensor->impl())};
30 31 32 33
  }
  return nullptr;
}

34
std::unique_ptr<std::vector<phi::DenseTensor*>> TensorToDenseTensor(
35
    const std::vector<Tensor>& tensors) {
36
  auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor*>>();
37 38 39 40
  pt_tensors->reserve(tensors.size());

  for (const auto& t : tensors) {
    pt_tensors->push_back(
41
        std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()).get());
42 43
  }

44
  return pt_tensors;
45 46
}

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
std::vector<const phi::DenseTensor*> TensorToConstDenseTensorPtr(
    const std::vector<Tensor>& tensors) {
  std::vector<const phi::DenseTensor*> pt_tensors(tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    pt_tensors[i] = static_cast<phi::DenseTensor*>(tensors[i].impl().get());
  }

  return pt_tensors;
}

paddle::optional<std::vector<const phi::DenseTensor*>>
TensorToConstDenseTensorPtr(
    const paddle::optional<std::vector<Tensor>>& tensors) {
  // Absent input -> absent output.
  if (!tensors) {
    return paddle::optional<std::vector<const phi::DenseTensor*>>();
  }

  // Present: mirror the vector as non-owning const DenseTensor pointers.
  paddle::optional<std::vector<const phi::DenseTensor*>> dense_ptrs(
      tensors->size());
  for (size_t i = 0; i < tensors->size(); ++i) {
    dense_ptrs->at(i) =
        static_cast<phi::DenseTensor*>(tensors->at(i).impl().get());
  }
  return dense_ptrs;
}

75
std::shared_ptr<phi::SelectedRows> TensorToSelectedRows(const Tensor& tensor) {
Z
zyfncg 已提交
76
  return std::static_pointer_cast<phi::SelectedRows>(tensor.impl());
77 78
}

79 80
paddle::optional<phi::SelectedRows> TensorToSelectedRows(
    const paddle::optional<Tensor>& tensor) {
81
  if (tensor) {
82
    return {*std::static_pointer_cast<phi::SelectedRows>(tensor->impl())};
83 84 85 86
  }
  return nullptr;
}

J
Jack Zhou 已提交
87 88 89 90
std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor) {
  return std::dynamic_pointer_cast<phi::StringTensor>(tensor.impl());
}

91 92
/* ----------------- for infer_meta --------------------- */

93
phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
94 95 96
  return phi::MetaTensor(tensor);
}

97 98
phi::MetaTensor MakeMetaTensor(
    const paddle::optional<phi::DenseTensor>& tensor) {
Z
zyfncg 已提交
99 100 101
  if (tensor) {
    return {phi::MetaTensor(*tensor)};
  }
102
  return phi::MetaTensor();
Z
zyfncg 已提交
103 104
}

105
std::vector<phi::MetaTensor> MakeMetaTensor(
106
    const std::vector<const phi::DenseTensor*>& tensors) {
107 108
  std::vector<phi::MetaTensor> meta_tensors;
  meta_tensors.reserve(tensors.size());
109 110
  for (const auto* t : tensors) {
    meta_tensors.emplace_back(*t);
111 112 113 114
  }
  return meta_tensors;
}

115 116 117 118 119 120 121 122 123 124
std::vector<phi::MetaTensor> MakeMetaTensor(
    const std::vector<phi::DenseTensor*>& tensors) {
  std::vector<phi::MetaTensor> meta_tensors;
  meta_tensors.reserve(tensors.size());
  for (auto* t : tensors) {
    meta_tensors.emplace_back(*t);
  }
  return meta_tensors;
}

125 126
phi::MetaTensor MakeMetaTensor(
    const paddle::optional<phi::SelectedRows>& tensor) {
Z
zyfncg 已提交
127 128 129
  if (tensor) {
    return {phi::MetaTensor(*tensor)};
  }
130
  return phi::MetaTensor();
Z
zyfncg 已提交
131 132
}

133 134 135 136 137 138 139 140 141 142 143 144
std::vector<phi::MetaTensor> MakeMetaTensor(
    const paddle::optional<std::vector<const phi::DenseTensor*>>& tensors) {
  std::vector<phi::MetaTensor> meta_tensors;
  if (tensors) {
    meta_tensors.reserve(tensors->size());
    for (auto* t : tensors.get()) {
      meta_tensors.emplace_back(*t);
    }
  }
  return meta_tensors;
}

145 146
/* ------------------ for output ----------------------- */

147
phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
148 149 150 151 152
  if (out) {
    if (out->impl() == nullptr) {
      out->set_impl(std::make_shared<phi::DenseTensor>());
    }
    return static_cast<phi::DenseTensor*>(out->impl().get());
153
  }
154
  return nullptr;
155 156
}

157 158 159
std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
                                               Backend backend,
                                               std::vector<Tensor>* out) {
160 161 162
  out->reserve(out_size);
  std::vector<phi::DenseTensor*> results(out_size);
  for (size_t i = 0; i < out_size; ++i) {
163
    auto tensor_ptr = std::make_shared<phi::DenseTensor>();
164 165 166 167 168 169 170
    results[i] = tensor_ptr.get();
    out->emplace_back();
    out->back().set_impl(tensor_ptr);
  }
  return results;
}

171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
    size_t out_size, Backend backend, std::vector<Tensor>* out) {
  std::vector<phi::DenseTensor*> results(out->size(), nullptr);
  for (size_t i = 0; i < out->size(); ++i) {
    results[i] = static_cast<phi::DenseTensor*>(out->at(i).impl().get());
  }
  return results;
}

std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput(
    size_t out_size,
    Backend backend,
    const paddle::optional<std::vector<Tensor>>& out) {
  // Absent optional -> empty result vector.
  // NOTE(review): `out_size` and `backend` are unused — the result size
  // follows `out->size()`.
  std::vector<phi::DenseTensor*> results;
  if (!out) {
    return results;
  }
  results.resize(out->size(), nullptr);
  for (size_t i = 0; i < out->size(); ++i) {
    results[i] = static_cast<phi::DenseTensor*>(out->at(i).impl().get());
  }
  return results;
}

194 195 196 197 198 199 200 201 202 203 204 205
std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out) {
  std::vector<phi::DenseTensor*> results(out->size(), nullptr);
  for (size_t i = 0; i < out->size(); ++i) {
    if (out->at(i)) {
      auto tensor_ptr = std::make_shared<phi::DenseTensor>();
      results[i] = tensor_ptr.get();
      (*out)[i]->set_impl(tensor_ptr);
    }
  }
  return results;
}

206
phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out) {
207 208 209 210 211 212 213 214
  if (!out->initialized()) {
    auto select_rows = std::make_shared<phi::SelectedRows>();
    out->set_impl(select_rows);
    return select_rows.get();
  }
  return static_cast<phi::SelectedRows*>(out->impl().get());
}

215 216 217 218 219 220 221 222 223 224 225 226
phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
  if (!out->initialized()) {
    if (type == TensorType::SPARSE_COO) {
      auto sparse_tensor = std::make_shared<phi::SparseCooTensor>(
          phi::DenseTensor(), phi::DenseTensor(), phi::DDim{-1});
      out->set_impl(sparse_tensor);
      return sparse_tensor.get();
    } else if (type == TensorType::SPARSE_CSR) {
      auto sparse_tensor =
          std::make_shared<phi::SparseCsrTensor>(phi::DenseTensor(),
                                                 phi::DenseTensor(),
                                                 phi::DenseTensor(),
T
tiancaishaonvjituizi 已提交
227
                                                 phi::DDim{-1, -1});
228 229 230 231 232 233 234 235 236 237 238
      out->set_impl(sparse_tensor);
      return sparse_tensor.get();
    } else {
      auto dense_tensor = std::make_shared<phi::DenseTensor>();
      out->set_impl(dense_tensor);
      return dense_tensor.get();
    }
  }
  return out->impl().get();
}

J
Jack Zhou 已提交
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253
phi::TensorBase* SetStringsKernelOutput(Backend backend,
                                        Tensor* out,
                                        TensorType type) {
  if (!out->initialized()) {
    if (type == TensorType::STRING_TENSOR) {
      if (out->impl() == nullptr) {
        auto strings_tensor = std::make_shared<phi::StringTensor>();
        out->set_impl(strings_tensor);
      }
      return out->impl().get();
    }
  }
  return out->impl().get();
}

254 255
}  // namespace experimental
}  // namespace paddle